In [1]:
# imports
import itertools
from tabulate import tabulate
import math
from collections import defaultdict
import numpy as np
from binarytree import Node

In [2]:
class Rule:
    """A single grammar production ``left -> right`` with an optional weight.

    Attributes:
        left: symbol on the left-hand side of the production, e.g. ``"S"``.
        right: right-hand side as one string, e.g. ``"NP VP"`` or a word.
        prob: probability attached to the production; 0 when not supplied.
    """

    def __init__(self, left, right, prob = 0):
        # Store the three fields exactly as given; no validation is performed.
        self.left, self.right, self.prob = left, right, prob

In [3]:
class Grammar:
    """Indexes a collection of rules in both directions.

    Attributes:
        grammar: maps each left-hand symbol to the set of right-hand sides
            it can produce.
        reversed_grammar: maps each right-hand side to a ``defaultdict(int)``
            whose keys are ``"<left>##<prob>"`` strings and whose values
            count how many times that exact rule was supplied.
    """

    def __init__(self, rules):
        forward = {}
        backward = {}
        for rule in rules:
            # left -> {right, ...}
            forward.setdefault(rule.left, set()).add(rule.right)

            # right -> {"left##prob": number of occurrences}
            producers = backward.setdefault(rule.right, defaultdict(int))
            producers["##".join((str(rule.left), str(rule.prob)))] += 1

        self.grammar = forward
        self.reversed_grammar = backward

In [4]:
class TreeNode(Node):
    """Parse-tree node built on the ``Node`` class of the binarytree package.

    Inheriting from ``Node`` makes the tree easy to visualise. On top of
    that, each node stores the accumulated log score of its subtree and the
    corresponding probability.

    Args:
        value: label of the node (a grammar symbol).
        log_prob: log probability of the rule applied at this node.
        left: left child, or None.
        right: right child, or None.
        word: when given, the node label becomes ``"<value>-<word>"``.
    """

    def __init__(self, value, log_prob, left, right, word=None):
        label = f"{value}-{word}" if word is not None else value
        super().__init__(label, left, right)

        # Subtree score = this rule's log probability plus the scores of the
        # children; a missing child contributes 0.
        child_scores = [
            0 if child is None else child.score
            for child in (self.left, self.right)
        ]
        self.score = log_prob + child_scores[0] + child_scores[1]

        # Probability of the whole subtree.
        self.prob = math.exp(self.score)

    def __repr__(self):
        # Compact form: the subtree probability rounded to 4 decimals.
        return str(round(self.prob, 4))
    

def marginalize_trees(trees):
    """Return the marginal probability of a collection of parse trees.

    The result is the sum of the ``prob`` attribute of every tree. A falsy
    argument (for example an empty list or None) yields 0.
    """
    return sum(tree.prob for tree in trees) if trees else 0

In [5]:
class CKYParser(Grammar):
    """CKY chart parser over the rules indexed by ``Grammar``.

    The constructor is inherited from ``Grammar`` and takes the list of
    rules. Right-hand sides are looked up in ``self.reversed_grammar``;
    a two-symbol right-hand side is matched as ``"<left> <right>"``.
    """

    # Placeholder that ``parse`` shows in off-diagonal cells holding no constituent.
    EMPTY_CELL = '\u03A6'

    def _producers(self, right):
        """Return ``(left, prob_text)`` for every rule whose right side is ``right``.

        ``prob_text`` is the probability exactly as stored in the key, i.e.
        still a string. An unknown right-hand side gives an empty list.
        """
        # Keys have the form "<left>##<prob>". Splitting on the LAST "##"
        # keeps working even if a left-hand symbol itself contains "##".
        return [
            tuple(key.rsplit('##', 1))
            for key in self.reversed_grammar.get(right, {})
        ]

    @staticmethod
    def _empty_table(size):
        """Build a ``size`` x ``size`` chart: ``""`` below the diagonal, a fresh set elsewhere."""
        # A separate set per cell avoids the aliasing of ``[set()] * n``.
        return [
            [""] * row + [set() for _ in range(size - row)]
            for row in range(size)
        ]

    # cky parser algorithm
    def parse(self, words):
        """Fill the unweighted CKY chart for ``words``.

        Args:
            words: sequence of tokens.

        Returns:
            A square list-of-lists. Cells below the diagonal are ``""``.
            Cell ``[i][j]`` with ``i <= j`` is the set of symbols that can
            span tokens ``i..j``. Empty off-diagonal cells are replaced by
            ``EMPTY_CELL``, while an empty diagonal cell (unknown word)
            stays an empty set. An empty ``words`` gives ``[]``.
        """
        size = len(words)
        table = self._empty_table(size)

        # iterate over the columns
        for j in range(size):
            # diagonal: every symbol that rewrites directly to the word
            table[j][j] = {left for left, _ in self._producers(words[j])}

            # fill the column from the bottom row up
            for i in reversed(range(j)):
                cell = set()
                for split in range(i, j):
                    pairs = itertools.product(table[i][split], table[split + 1][j])
                    for left_sym, right_sym in pairs:
                        cell.update(
                            left
                            for left, _ in self._producers(f"{left_sym} {right_sym}")
                        )
                table[i][j] = cell

        # Placeholders are applied only once the chart is complete, so the
        # placeholder string is never iterated as if it were a set of symbols.
        for i in range(size):
            for j in range(i + 1, size):
                if not table[i][j]:
                    table[i][j] = self.EMPTY_CELL
        return table

    # weighted cky parser algorithm
    def weigheted_parse(self, words, start_symbol="S"):
        """Fill the weighted CKY chart and collect the complete parse trees.

        Every derivation is kept (no pruning), so a cell can hold several
        entries for the same symbol.

        Args:
            words: sequence of tokens.
            start_symbol: symbol a tree must have at its root to be returned;
                defaults to ``"S"``.

        Returns:
            ``(table, tree_list)``. ``table`` has ``""`` below the diagonal
            and, elsewhere, sets of ``(symbol, TreeNode)`` tuples.
            ``tree_list`` holds the root nodes from the top-right cell whose
            symbol equals ``start_symbol``. An empty ``words`` gives
            ``([], [])``.
        """
        size = len(words)
        if size == 0:
            # There is no top-right cell to read; previously this raised IndexError.
            return [], []

        table = self._empty_table(size)

        # iterate over the columns
        for j in range(size):
            # diagonal: one leaf per lexical rule matching the word
            table[j][j] = {
                (left, TreeNode(left, np.log(float(prob)), left=None, right=None, word=words[j]))
                for left, prob in self._producers(words[j])
            }

            # fill the column from the bottom row up
            for i in reversed(range(j)):
                cell = set()
                for split in range(i, j):
                    pairs = itertools.product(table[i][split], table[split + 1][j])
                    for (l_const, l_node), (r_const, r_node) in pairs:
                        # A distinct loop name keeps the split index intact.
                        for left, prob in self._producers(f"{l_const} {r_const}"):
                            new_node = TreeNode(left, np.log(float(prob)), left=l_node, right=r_node)
                            cell.add((left, new_node))
                table[i][j] = cell

        # Only keep trees whose root is the start symbol.
        tree_list = [root for tag, root in table[0][-1] if tag == start_symbol]
        return table, tree_list

    # Correctly spelled alias; the original name is kept for existing callers.
    weighted_parse = weigheted_parse

In [6]:
# Example 1: grammar without probabilities (every Rule keeps its default prob).
sentence = "British left waffles on Falklands"
tokens = sentence.split(" ")

# Phrase-level productions first, then the word-level ones.
rules = [
    Rule(left, right)
    for left, right in [
        ("S", "NP VP"),
        ("NP", "JJ NP"),
        ("VP", "VP NP"),
        ("VP", "VP PP"),
        ("PP", "P NP"),
        ("NP", "British"),
        ("JJ", "British"),
        ("NP", "left"),
        ("VP", "left"),
        ("NP", "waffles"),
        ("VP", "waffles"),
        ("P", "on"),
        ("NP", "Falklands"),
    ]
]
cky_parser = CKYParser(rules)

In [7]:
# Fill the unweighted chart and print it with one column per token.
table = cky_parser.parse(tokens)
print("CKY Table:", "", tabulate(table, headers=tokens, showindex="always"), sep="\n")

CKY Table:

    British       left          waffles       on     Falklands
--  ------------  ------------  ------------  -----  -----------
 0  {'JJ', 'NP'}  {'S', 'NP'}   {'S'}         Φ      {'S'}
 1                {'NP', 'VP'}  {'S', 'VP'}   Φ      {'S', 'VP'}
 2                              {'NP', 'VP'}  Φ      {'VP'}
 3                                            {'P'}  {'PP'}
 4                                                   {'NP'}


In [8]:
# Example 2: the same kind of grammar, now with a probability on every rule.
sentence = "astronomers saw stars with ears"
tokens = sentence.split(" ")

# (left, right, probability) triples, phrase-level rules and lexical rules.
rules = [
    Rule(left, right, prob)
    for left, right, prob in [
        ("S", "NP VP", 1.0),
        ("PP", "P NP", 1.0),
        ("VP", "V NP", 0.7),
        ("VP", "VP PP", 0.3),
        ("P", "with", 1.0),
        ("V", "saw", 1.0),
        ("NP", "NP PP", 0.4),
        ("NP", "astronomers", 0.4),
        ("NP", "ears", 0.18),
        ("NP", "saw", 0.04),
        ("NP", "stars", 0.18),
        ("NP", "telescopes", 0.1),
    ]
]
cky_parser = CKYParser(rules)

In [9]:
# Fill the weighted chart; parse_trees holds the trees rooted in "S".
table, parse_trees = cky_parser.weigheted_parse(tokens)

print("Weighted CKY Table:")
print()
print(tabulate(table, headers=tokens, showindex="always"))
print()

# best tree
# max() raises ValueError on an empty sequence, so a sentence with no parse
# is reported instead of aborting the cell.
best_tree = max(parse_trees, key=lambda node: node.score, default=None)
if best_tree is None:
    print("No parse tree found for the sentence.")
else:
    print("Most Probable Parse Tree:")
    print()
    print('Score : ', best_tree.prob)
    print(best_tree)
print()

# Marginalize over all trees
print(f"Probability of Sentence marginalized over the trees: {marginalize_trees(parse_trees)}")

Weighted CKY Table:

    astronomers    saw                         stars            with          ears
--  -------------  --------------------------  ---------------  ------------  --------------------------------
 0  {('NP', 0.4)}  set()                       {('S', 0.0504)}  set()         {('S', 0.0036), ('S', 0.0027)}
 1                 {('V', 1.0), ('NP', 0.04)}  {('VP', 0.126)}  set()         {('VP', 0.0068), ('VP', 0.0091)}
 2                                             {('NP', 0.18)}   set()         {('NP', 0.013)}
 3                                                              {('P', 1.0)}  {('PP', 0.18)}
 4                                                                            {('NP', 0.18)}

Most Probable Parse Tree:

Score :  0.0036288000000000015

        _______S______
       /              \
NP-astronomers       __VP_________
                    /             \
                 V-saw         ____NP_______
                              /             \
                