In [None]:
import os
import random

import lark
import nltk
import pandas as pd
import pyrootutils
import seaborn as sns
import tqdm.auto as tqdm

from formal_gym import grammar as fg_grammar

In [None]:
# Resolve the project root with pyrootutils, searching from the current working
# directory for the ".project-root" indicator.
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""),
    indicator=".project-root",
)

# Grammar under study; the previously used one is kept for reference.
# grammar_name = "grammar_20250312172959_597104"
grammar_name = "grammar_20250319112222_631725"

# Each grammar lives in its own directory, in a file named after the grammar.
grammar_dir = PROJECT_ROOT / "data" / "grammars" / grammar_name
grammar_path = grammar_dir / f"{grammar_name}.cfg"

In [None]:
# Load the grammar from the .cfg file selected in the configuration cell.
grammar = fg_grammar.Grammar.from_file(grammar_path)
# Last expression of the cell: displays the productions of the grammar's CFG view.
grammar.as_cfg.productions()

In [None]:
# Sanity check: build a string of randomly drawn terminals and ask the grammar
# whether it accepts it.
print(grammar.terminals)

# A dedicated, seeded RNG keeps this cell reproducible under
# "Restart Kernel -> Run All" without touching the global `random` state.
RANDOM_SAMPLE_SEED = 0
RANDOM_SAMPLE_LENGTH = 10
sample_rng = random.Random(RANDOM_SAMPLE_SEED)

# Sort before sampling so the draw does not depend on the iteration order of
# `grammar.terminals` (which may vary between kernels if it is an unordered
# collection).
new_sample = " ".join(
    sample_rng.choices(sorted(grammar.terminals), k=RANDOM_SAMPLE_LENGTH)
)
print(new_sample)

print(grammar.test_sample(new_sample))

In [None]:
def cfg_to_lark(grammar: fg_grammar.Grammar) -> str:
    """Render each production of ``grammar.as_cfg`` as one ``lhs : rhs`` line.

    String symbols on the right-hand side are wrapped in double quotes and
    nonterminals are written by name; symbols of any other type are skipped.
    A left-hand side named ``S`` is written as ``start``.

    NOTE(review): names keep their original casing, ``S`` is not renamed when
    it appears on a right-hand side, and productions sharing a left-hand side
    are emitted as separate lines. ``convert_cfg_to_ebnf`` in this notebook
    lowercases names and merges alternatives instead -- confirm whether this
    helper is still needed.

    Args:
        grammar: Grammar whose ``as_cfg`` exposes ``productions()``.

    Returns:
        The concatenated rule lines, each terminated by a newline.
    """
    lines = []
    for production in grammar.as_cfg.productions():
        head = str(production.lhs())
        if head == "S":
            head = "start"

        body = []
        for symbol in production.rhs():
            if isinstance(symbol, str):
                body.append(f'"{symbol}"')
            elif isinstance(symbol, nltk.grammar.Nonterminal):
                body.append(str(symbol))

        lines.append(f"{head} : {' '.join(body)}\n")
    return "".join(lines)


from collections import defaultdict


def convert_cfg_to_ebnf(grammar: fg_grammar.Grammar) -> str:
    """Convert ``grammar.as_cfg`` into a Lark grammar string.

    Productions sharing a left-hand side are merged into a single rule whose
    alternatives are joined with ``|``, in the order the productions appear
    (duplicates removed). Nonterminal names are lowercased, and the CFG's
    start symbol is written as ``start`` wherever it occurs, on either side of
    a rule. String terminals are written as double-quoted literals; symbols
    that are neither strings nor NLTK nonterminals are skipped. Directives
    importing and ignoring ``WS_INLINE`` are appended after the rules.

    NOTE(review): terminals are quoted verbatim, so a terminal containing a
    double quote or backslash would need escaping -- confirm none exist.

    Args:
        grammar: Grammar whose ``as_cfg`` exposes ``start()`` and
            ``productions()``.

    Returns:
        The Lark grammar source: one rule per line followed by the whitespace
        directives.
    """
    cfg = grammar.as_cfg
    start_symbol = cfg.start()

    def rule_name(nonterminal) -> str:
        # The start symbol must get the same name on both sides of a rule,
        # otherwise a recursive reference to it would point at an undefined rule.
        if nonterminal == start_symbol:
            return "start"
        return str(nonterminal).lower()

    # lhs -> ordered set of rhs tuples. Dict keys keep insertion order, so the
    # generated grammar is identical from run to run.
    rules = defaultdict(dict)
    for production in cfg.productions():
        rules[production.lhs()][production.rhs()] = None

    lark_rules = []
    for lhs, rhs_alternatives in rules.items():
        rendered_alternatives = []
        for rhs in rhs_alternatives:
            rendered_symbols = []
            for symbol in rhs:
                if isinstance(symbol, str):
                    rendered_symbols.append(f'"{symbol}"')
                elif isinstance(symbol, nltk.grammar.Nonterminal):
                    rendered_symbols.append(rule_name(symbol))
            rendered_alternatives.append(" ".join(rendered_symbols))
        lark_rules.append(f"{rule_name(lhs)} : {' | '.join(rendered_alternatives)}")

    directives = "%import common.WS_INLINE\n%ignore WS_INLINE"
    return "\n".join(lark_rules) + "\n" + directives


# Build the Lark grammar source for the loaded grammar and show it for inspection.
lark_g = convert_cfg_to_ebnf(grammar)
print(lark_g)

In [None]:
# Build a Lark parser from the generated grammar. `ambiguity="explicit"` is
# requested so that alternative derivations are kept in the returned tree.
lark_parser = lark.Lark(lark_g, ambiguity="explicit")
# NOTE(review): dumps the parser's instance attributes; looks like a debugging
# aid -- consider removing it or displaying a specific attribute instead.
lark_parser.__dict__

In [None]:
# Three NLTK parsers built over the same CFG so their results can be compared.
chart_parser = nltk.ChartParser(grammar.as_cfg)
# NOTE(review): NLTK documents that recursive descent does not terminate on
# left-recursive grammars -- confirm the grammar has no left recursion.
recdescent_parser = nltk.RecursiveDescentParser(grammar.as_cfg)
# NOTE(review): NLTK's shift-reduce parser does not backtrack, so it may find
# no parse for a grammatical input -- confirm this is acceptable here.
shift_reduce_parser = nltk.ShiftReduceParser(grammar.as_cfg)

In [None]:
# Earlier hand-written sample, kept for reference:
# positive_sample = "t0 t0 t0 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3 t3".split(" ")
# Generate a sample from the grammar; the result is indexed by "string" below,
# which holds the space-separated token sequence.
sample_with_parse = grammar.generate_tree()
# Token list form of the sample, as consumed by the NLTK parsers.
positive_sample = sample_with_parse["string"].split(" ")
print(sample_with_parse)

In [None]:
# Parse the generated string with the Lark parser.
parse_tree = lark_parser.parse(sample_with_parse["string"])
# Number of direct children of the returned root node.
# NOTE(review): this is the number of alternative parses only when the root is
# an `_ambig` node; for an unambiguous parse it counts the start rule's
# children instead -- confirm which is intended.
len(parse_tree.children)

In [None]:
# def parse_with_lark(sample: str):
#     try:
#         return lark_parser.parse(sample)
#     except lark.exceptions.LarkError as e:
#         # print(f"Error parsing sample: {sample}")
#         # print(e)
#         return None

# parse_with_lark("t0 t0 t0 t3")
# parse_with_lark("t3")

In [None]:
from lark import Token, Tree


def clean_tree(tree, in_string: str) -> list:
    """Render a Lark parse result as bracketed parse strings.

    Each node becomes ``(LABEL children...)`` with the rule name uppercased
    and ``START`` written as ``S``. A node without children takes the next
    unused whitespace-separated token of ``in_string`` as its content, so the
    tokens are consumed left to right.

    If the root of ``tree`` is an ``_ambig`` node, every child of the root is
    treated as one alternative parse. Otherwise ``tree`` itself is the single
    parse; previously its children were mistaken for alternatives, which
    split an unambiguous parse into fragments.

    NOTE(review): assumes every token of ``in_string`` corresponds, in order,
    to exactly one childless node, and that all children expose ``data`` and
    ``children``. ``_ambig`` nodes nested below the root are not expanded
    (Lark's ``CollapseAmbiguities`` may help) -- confirm against the grammars
    in use.

    Args:
        tree: Root node of the parse result; nodes expose ``data`` and
            ``children``.
        in_string: The parsed input, tokens separated by whitespace.

    Returns:
        A list with one bracketed string per parse.

    Raises:
        IndexError: If there are more childless nodes than tokens.
    """

    def to_bracketed(node, remaining):
        rendered = " ".join(to_bracketed(child, remaining) for child in node.children)
        if rendered == "":
            # `remaining` is reversed, so pop() yields tokens in input order
            # at O(1) each.
            rendered = remaining.pop()
        label = node.data.upper()
        if label == "START":
            label = "S"
        return f"({label} {rendered})"

    if tree.data == "_ambig":
        parses = tree.children
    else:
        parses = [tree]

    # Every parse consumes its own fresh copy of the token sequence.
    return [to_bracketed(parse, in_string.split()[::-1]) for parse in parses]


# Show the Lark parse(s) of the generated sample in bracketed form.
print(clean_tree(parse_tree, sample_with_parse["string"]))

In [None]:
# Print every parse the chart parser yields for the positive sample.
c_parses = chart_parser.parse(positive_sample)
for p in c_parses:
    print(p)

In [None]:
# Print every parse the shift-reduce parser yields for the positive sample.
sr_parses = shift_reduce_parser.parse(positive_sample)

for p in sr_parses:
    print(p)

In [None]:
# Print every parse the recursive-descent parser yields for the positive sample.
rd_parses = recdescent_parser.parse(positive_sample)

for p in rd_parses:
    print(p)

In [None]:
# Re-run the chart parser on the positive sample, printing each parse and then
# the total number of parses found.
c_parses = chart_parser.parse(positive_sample)

n = 0  # stays 0 when the parser yields nothing
for n, p in enumerate(c_parses, start=1):
    print(p)

print(n)