In [None]:
import os

import nltk
import pcfg
import pyrootutils
from nltk import CFG, Nonterminal, Production

from formal_gym import grammar as fg_grammar

In [None]:
# Locate the project root (the directory marked by a `.project-root` file),
# starting the search from the notebook's current working directory.
PROJECT_ROOT = pyrootutils.find_root(
    indicator=".project-root",
    search_from=os.path.abspath(""),
)

# Grammar file to load. Alternative sample, swap in to use it instead:
# grammar_path = PROJECT_ROOT / "data" / "sample_raw_20241022141532.cfg"
grammar_path = PROJECT_ROOT.joinpath("data", "sample_trim_20241022141559.cfg")

In [None]:
# Load the grammar file, then print its CFG view followed by its PCFG view.
dual_grammar = fg_grammar.ContextFreeGrammar.from_file(grammar_path)
print(str(dual_grammar.as_cfg))
print(str(dual_grammar.as_pcfg))

In [None]:
# Generate from the loaded grammar with space-separated tokens; 1000 is the
# depth cap handed to `generate` (presumably one sampled string is returned).
dual_grammar.generate(max_depth=1000, sep=" ")

In [None]:
# Load the same grammar file through the generic `Grammar` wrapper and keep a
# handle on the wrapped grammar object (presumably an NLTK CFG — it exposes
# `productions()`).
grammar = fg_grammar.Grammar.from_grammar(grammar_path)
cfg_grammar = grammar.grammar_obj

# Peek at the first production as a quick sanity check.
cfg_grammar.productions()[0]

In [None]:
# Build a uniform-probability PCFG from the plain CFG: every production of a
# given left-hand side gets probability 1 / (number of productions for that LHS).
rules = {}
for prod in cfg_grammar.productions():
    lhs, rhs = prod.lhs(), prod.rhs()
    rules.setdefault(lhs, []).append(rhs)

pcfg_rules = [
    (lhs, rhs, 1 / len(rhs_list))
    for lhs, rhs_list in rules.items()
    for rhs in rhs_list
]

# Serialize each rule in NLTK grammar syntax. `repr` (not `str`) is required
# for RHS symbols: a Nonterminal reprs as its bare symbol, while a terminal
# string reprs *quoted* (e.g. 'the'). With `str`, terminals would be written
# unquoted and the parser would read them back as nonterminals.
pcfg_productions = [
    f"{lhs} -> {' '.join(repr(sym) for sym in rhs)} [{prob:0.5f}]"
    for lhs, rhs, prob in pcfg_rules
]

pcfg_grammar_str = "\n".join(pcfg_productions)
pcfg_grammar = pcfg.PCFG.fromstring(pcfg_grammar_str)

In [None]:
# Confirm which grammar class `pcfg.PCFG.fromstring` produced.
pcfg_grammar.__class__

In [None]:
# Print each string yielded by `grammar.generate(3, sep=" ")`, one per line
# (the first argument presumably sets how many samples — confirm).
for sentence in grammar.generate(3, sep=" "):
    print(sentence)

In [None]:
# Rebuild a plain (non-probabilistic) CFG carrying the same rules as `pcfg_grammar`.
cfg_productions = [Production(p.lhs(), p.rhs()) for p in pcfg_grammar.productions()]

# NLTK grammars' `start()` already returns a Nonterminal. Wrapping it again
# would give Nonterminal(Nonterminal('S')), which never compares equal to any
# production's LHS, leaving the CFG with a start symbol that has no rules.
# Only wrap when a bare symbol is returned.
cfg_start = pcfg_grammar.start()
if not isinstance(cfg_start, Nonterminal):
    cfg_start = Nonterminal(cfg_start)
cfg = CFG(cfg_start, cfg_productions)

In [None]:
# Plain-text rendering of the CFG rebuilt from the PCFG's rules.
print(str(cfg))

In [None]:
# Plain-text rendering of the originally loaded grammar object, for comparison.
print(str(cfg_grammar))

In [None]:
# Inspect the class of the PCFG's production objects, using the first as a sample.
pcfg_grammar.productions()[0].__class__

In [None]:
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

from nltk import PCFG, Nonterminal, Production


class PCFGSampler:
    """Draw random token sequences from a PCFG, with a cap on derivation depth.

    Each nonterminal is expanded top-down by drawing one of its productions
    with weight ``prod.prob()``. A nonterminal may only be expanded at depth
    <= ``max_depth`` (the start symbol is at depth 0). At ``max_depth`` only
    all-terminal productions are eligible, because any nonterminal child would
    land beyond the cap. If an attempt still fails (e.g. a nonterminal at
    ``max_depth`` has no all-terminal production), the whole attempt is
    discarded and sampling restarts, up to a bounded number of attempts.

    Args:
        grammar: Probabilistic grammar exposing ``productions()`` (objects with
            ``lhs()``, ``rhs()`` and ``prob()``) and ``start()``.
        max_depth: Deepest level at which a nonterminal may be expanded.
        rng: Optional random source with a ``choices`` method, e.g.
            ``random.Random(seed)``, for reproducible draws. When omitted the
            module-level ``random.choices`` is used.
    """

    def __init__(
        self,
        grammar: PCFG,
        max_depth: int = 50,
        rng: Optional[random.Random] = None,
    ):
        self.grammar = grammar
        self.max_depth = max_depth
        # Bound `choices` of the selected RNG; global `random` by default.
        self._choices = random.choices if rng is None else rng.choices

        # Cache productions by left-hand side, keeping probabilities in the same
        # order so they can be passed straight to `choices(weights=...)`.
        self.productions_by_lhs = defaultdict(list)
        self.probs_by_lhs = defaultdict(list)
        for prod in grammar.productions():
            self.productions_by_lhs[prod.lhs()].append(prod)
            self.probs_by_lhs[prod.lhs()].append(prod.prob())

        # Nonterminals that can derive some terminal-only string (ignoring depth).
        self.can_terminate = self._find_terminating_nts()

    def _find_terminating_nts(self) -> Set[Nonterminal]:
        """Return all nonterminals that can eventually derive only terminals.

        Fixed-point iteration: a nonterminal terminates if it has a production
        whose right-hand side contains only terminals and already-terminating
        nonterminals (an all-terminal or empty RHS qualifies immediately).
        """
        can_terminate: Set[Nonterminal] = set()
        changed = True
        while changed:
            changed = False
            for prod in self.grammar.productions():
                lhs = prod.lhs()
                if lhs in can_terminate:
                    continue
                if all(
                    not isinstance(sym, Nonterminal) or sym in can_terminate
                    for sym in prod.rhs()
                ):
                    can_terminate.add(lhs)
                    changed = True
        return can_terminate

    def _choose_production(self, lhs: Nonterminal, depth: int) -> Optional[Production]:
        """Draw a production for ``lhs`` at ``depth``, or None if none can succeed.

        Below ``max_depth`` every production of ``lhs`` is eligible. At (or
        beyond) ``max_depth`` only productions with no nonterminal on the RHS
        are eligible: a nonterminal child would be expanded at
        ``depth + 1 > max_depth`` and rejected. Probabilities of the eligible
        productions act as relative weights (``choices`` renormalizes them).

        Returns:
            The chosen production, or None when ``lhs`` has no eligible
            production or the eligible weights do not sum to a positive value.
        """
        # `.get` avoids inserting empty entries into the defaultdict caches.
        productions = self.productions_by_lhs.get(lhs, [])
        probs = self.probs_by_lhs.get(lhs, [])

        if depth >= self.max_depth:
            eligible = [
                (prod, prob)
                for prod, prob in zip(productions, probs)
                if not any(isinstance(sym, Nonterminal) for sym in prod.rhs())
            ]
            productions = [prod for prod, _ in eligible]
            probs = [prob for _, prob in eligible]

        # `choices` raises on an empty population or a non-positive weight total.
        if not productions or sum(probs) <= 0:
            return None
        return self._choices(productions, weights=probs)[0]

    def _sample_once(self, start) -> Optional[List[str]]:
        """Run a single sampling attempt from ``start``.

        Uses an explicit stack rather than recursion so a large ``max_depth``
        cannot hit Python's recursion limit. Symbols are expanded depth-first,
        left to right — the same order a recursive descent draws productions.

        Returns:
            The terminal tokens as strings, or None if the attempt hit the
            depth cap or reached a nonterminal with no usable production.
        """
        tokens: List[str] = []
        # (symbol, depth) pairs still to process; top of the stack is leftmost.
        pending: List[Tuple[object, int]] = [(start, 0)]
        while pending:
            symbol, depth = pending.pop()
            if not isinstance(symbol, Nonterminal):
                tokens.append(str(symbol))
                continue
            if depth > self.max_depth:
                return None
            production = self._choose_production(symbol, depth)
            if production is None:
                return None
            # Push children right-to-left so the leftmost child is popped first.
            pending.extend((child, depth + 1) for child in reversed(production.rhs()))
        return tokens

    def sample(
        self, start: Optional[Nonterminal] = None, max_attempts: int = 1000
    ) -> List[str]:
        """Generate one random token sequence from the grammar.

        Args:
            start: Symbol to derive from; defaults to ``grammar.start()``.
            max_attempts: Number of depth-capped attempts to make before
                giving up (retries are a loop, not recursion).

        Returns:
            The sampled terminal tokens as strings.

        Raises:
            ValueError: If ``start`` is a nonterminal that can never derive a
                terminal-only string.
            RuntimeError: If all ``max_attempts`` attempts failed to finish
                within ``max_depth``.
        """
        if start is None:
            start = self.grammar.start()

        if isinstance(start, Nonterminal) and start not in self.can_terminate:
            raise ValueError(
                f"Start symbol {start!r} cannot derive a terminal-only string"
            )

        for _ in range(max_attempts):
            result = self._sample_once(start)
            if result is not None:
                return result

        raise RuntimeError(
            f"Failed to sample from {start!r} within max_depth={self.max_depth} "
            f"after {max_attempts} attempts"
        )

In [None]:
# Small English PCFG used to exercise PCFGSampler. NP and VP are left-recursive
# (NP -> NP PP, VP -> VP PP), so derivations can grow deep and depend on the
# sampler's depth cap. NOTE: this rebinds `grammar`, replacing the
# fg_grammar.Grammar object loaded earlier in the notebook.
grammar = PCFG.fromstring(
    """
    S -> NP VP [1.0]
    NP -> Det N [0.5] | NP PP [0.5]
    VP -> V NP [0.6] | VP PP [0.4]
    PP -> P NP [1.0]
    Det -> 'the' [1.0]
    N -> 'cat' [0.4] | 'dog' [0.6]
    V -> 'saw' [1.0]
    P -> 'with' [1.0]
"""
)

In [None]:
# Rich display of the PCFG most recently bound to `grammar` (the toy grammar
# built with PCFG.fromstring in the previous cell).
grammar

In [None]:
# Draw one token sequence from the toy grammar and print it space-separated.
sampler = PCFGSampler(grammar)
result = sampler.sample()
print(*result, sep=" ")

In [None]:
# Show the raw token list returned by `sampler.sample()`.
result