<a href="https://colab.research.google.com/github/kanishqvijay/DSA0328-Natural-Language-Processing/blob/main/Program-12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
class State:
    """A single Earley chart item: a dotted production plus the span it covers.

    Attributes:
        rule: ``(lhs, rhs)`` pair; ``rhs`` is normalised to a tuple so the
            state is hashable regardless of whether a list or tuple was given.
        dot: Index into ``rhs`` of the next symbol still to be recognised.
        start: Input position at which recognition of this rule began.
        end: Input position reached so far.
        children: Parse-tree nodes collected for the symbols left of the dot.
            Not part of equality/hash, so two states that differ only in
            their children are treated as the same chart item.
    """

    def __init__(self, rule, dot, start, end=0, children=None):
        self.rule = (rule[0], tuple(rule[1]))
        self.dot = dot
        self.start = start
        self.end = end
        self.children = children if children is not None else []

    def _key(self):
        """Return the tuple that defines this state's identity in a chart."""
        return (self.rule, self.dot, self.start, self.end)

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on a missing ``rule`` attribute.
        if not isinstance(other, State):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        lhs, rhs = self.rule
        dotted = list(rhs[:self.dot]) + ['.'] + list(rhs[self.dot:])
        body = ' '.join(str(sym) for sym in dotted)
        return f"State({lhs} -> {body}, [{self.start}, {self.end}])"

    def next(self):
        """Return the symbol right after the dot, or ``None`` if there is none."""
        return self.rule[1][self.dot] if self.dot < len(self.rule[1]) else None

    def complete(self):
        """Return ``True`` when the dot has moved past the whole right-hand side."""
        return self.dot >= len(self.rule[1])

class TreeNode:
    """Node of a parse tree: a grammar symbol and its ordered child nodes."""

    def __init__(self, symbol, children=None):
        self.symbol = symbol
        # A fresh list per node when no children are supplied.
        self.children = [] if children is None else children

    def pretty_print(self, level=0):
        """Return the subtree as text, one symbol per line.

        Each line is indented by two spaces per depth level, starting at
        ``level`` for this node, and ends with a newline.
        """
        header = "{}{}\n".format("  " * level, self.symbol)
        body = "".join(kid.pretty_print(level + 1) for kid in self.children)
        return header + body

    def __str__(self):
        return self.pretty_print()

class EarleyParser:
    """Earley parser for a context-free grammar supplied as a dict.

    The grammar maps each nonterminal to a list of right-hand sides, each a
    sequence of symbols. A symbol that is a key of the grammar is treated as
    a nonterminal; any other symbol is matched literally against the input
    tokens. An empty right-hand side denotes an epsilon production.
    """

    def __init__(self, grammar):
        # Right-hand sides are stored as tuples so states built from them
        # are hashable.
        self.grammar = {k: [tuple(rhs) for rhs in v] for k, v in grammar.items()}

    def parse(self, tokens, start='S'):
        """Parse ``tokens`` and return a parse tree, or ``None`` if rejected.

        Args:
            tokens: Iterable of input symbols.
            start: Start nonterminal of the grammar.

        Returns:
            A ``TreeNode`` rooted at ``start`` for the first complete
            derivation found in the final chart column, otherwise ``None``.

        Raises:
            KeyError: If ``start`` is not a nonterminal of the grammar.
        """
        tokens = list(tokens)  # allow any iterable; len() and indexing follow
        chart = [set() for _ in range(len(tokens) + 1)]

        # Seed with every alternative of the start symbol, not just the
        # first one, so all top-level productions are considered.
        for rhs in self.grammar[start]:
            chart[0].add(State((start, rhs), 0, 0))

        for i in range(len(tokens) + 1):
            # Repeat predict/scan/complete until column i stops growing.
            while True:
                size = len(chart[i])
                self.predict(chart, i)
                if i < len(tokens):
                    self.scan(chart, i, tokens)
                self.complete(chart, i)
                if size == len(chart[i]):
                    break

        for state in chart[len(tokens)]:
            if state.rule[0] == start and state.complete() and state.start == 0:
                return self.build_tree(state, chart, tokens)
        return None

    def predict(self, chart, i):
        """Add a fresh item at column ``i`` for every nonterminal expected there."""
        for state in list(chart[i]):
            next_sym = state.next()
            if next_sym in self.grammar:
                for rhs in self.grammar[next_sym]:
                    # A predicted item covers the empty span [i, i].
                    chart[i].add(State((next_sym, rhs), 0, i, i))

    def scan(self, chart, i, tokens):
        """Advance items in column ``i`` whose next symbol equals ``tokens[i]``."""
        for state in list(chart[i]):
            if state.next() == tokens[i]:
                children = state.children + [TreeNode(tokens[i])]
                chart[i + 1].add(
                    State(state.rule, state.dot + 1, state.start, i + 1, children)
                )

    def complete(self, chart, i):
        """Advance items that were waiting for a nonterminal finished at ``i``."""
        for completed in list(chart[i]):
            if not completed.complete():
                continue
            lhs = completed.rule[0]
            # Iterate a snapshot: for an epsilon production the origin column
            # is column i itself, and adding to a set while iterating it
            # raises RuntimeError.
            for state in list(chart[completed.start]):
                if state.next() == lhs:
                    children = state.children + [TreeNode(lhs, completed.children)]
                    chart[i].add(
                        State(state.rule, state.dot + 1, state.start, i, children)
                    )

    def build_tree(self, state, chart, tokens):
        """Wrap the children collected on a finished ``state`` in a root node.

        ``chart`` and ``tokens`` are accepted but not used here.
        """
        return TreeNode(state.rule[0], state.children)

def print_colored_tree(node, level=0):
    """Print a parse tree to stdout, one symbol per line, using ANSI colors.

    Nodes without children are printed with the 'terminal' color, all other
    nodes with the 'nonterminal' color. Each depth level is indented by two
    spaces, starting at ``level`` for ``node`` itself.
    """
    palette = {
        'nonterminal': '\033[94m',
        'terminal': '\033[92m',
        'reset': '\033[0m'
    }

    kind = 'nonterminal' if node.children else 'terminal'
    pad = "  " * level
    tint = palette[kind]
    print(f"{pad}{tint}{node.symbol}{palette['reset']}")

    for kid in node.children:
        print_colored_tree(kid, level + 1)


if __name__ == "__main__":
    # Demo: a toy grammar with one production per nonterminal, run against
    # a single five-token sentence.
    grammar = {
        'S': [['NP', 'VP']],
        'NP': [['Det', 'N']],
        'VP': [['V', 'NP']],
        'Det': [['the']],
        'N': [['cat']],
        'V': [['saw']]
    }
    tokens = ['the', 'cat', 'saw', 'the', 'cat']

    parser = EarleyParser(grammar)
    tree = parser.parse(tokens)

    # parse() yields None when the input is not derivable from the grammar.
    if not tree:
        print("Input rejected")
    else:
        print("\nParse Tree:")
        print_colored_tree(tree)


Parse Tree:
S
  NP
    Det
      the
    N
      cat
  VP
    V
      saw
    NP
      Det
        the
      N
        cat
