In [2]:
from collections import defaultdict

import pandas as pd

In [39]:
loose_token_type = tuple[str, str] | tuple[str]
class CykParser:
    def __init__(self, rules: list[tuple[str, loose_token_type]], probs: dict[str, dict[loose_token_type, float]]):
        self.rules = rules
        self.probs = probs
        self.binary_rules = [r for r in self.rules if len(r[1]) == 2]
    
    def parse(self, tokens: list[str]):
        n = len(tokens)
        CYK = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))
        PTR  = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        
        for s in range(n):
            current_token = (tokens[s],)
            for A, rhs in self.rules:
                if rhs == current_token:
                    CYK[s][s+1][A] = self.probs[A][current_token]
                    PTR[s][s+1][A].append(current_token)
        
        for l in range(2, n+1):
            for start in range(0, n-l+1):
                end = start + l
                for split in range(start+1, end):
                    for A, (B,C) in self.binary_rules:
                        p = CYK[start][split][B] * CYK[split][end][C] * self.probs[A][(B, C)]
                        if p > CYK[start][end][A]:
                            CYK[start][end][A] = p
                            PTR[start][end][A].append(([start, split, B], [split, end, C]))
        return CYK, PTR

In [44]:
data = [
    ('S', ('NP', 'VP'), 0.4),
    ('S', ('V', 'NP'), 0.4),
    ('S', ('time',), 0.1),
    ('S', ('flies',), 0.1),
    ('NP', ('N', 'N'), 1.0),
    ('VP', ('V', 'NP'), 0.5),
    ('VP', ('time', ), 0.3),
    ('VP', ('flies',), 0.2),
    ('N', ('time',), 0.4),
    ('N', ('flies',), 0.4),
    ('N', ('fruit',), 0.2),
    ('V', ('time', ), 0.6),
    ('V', ('flies',), 0.4),
]
rules = [(lhs, rhs) for lhs, rhs, prob in data]
probs = defaultdict(lambda: defaultdict(float))
for lhs, rhs, prob in data:
    probs[lhs][rhs] = prob


In [51]:
tokens = ['time', 'fruit', 'flies']
token_length = len(tokens)
parser = CykParser(rules, probs)
CYK, PTR = parser.parse(tokens)
CYK[0][token_length]

defaultdict(<function __main__.CykParser.parse.<locals>.<lambda>.<locals>.<lambda>.<locals>.<lambda>()>,
            {'S': 0.019200000000000005, 'NP': 0, 'VP': 0.024000000000000004})

In [53]:
for i in range(token_length):
    for j in range(token_length):
        print((i, j+1), dict(CYK[i][j+1]))

(0, 1) {'S': 0.1, 'VP': 0.3, 'N': 0.4, 'V': 0.6, 'NP': 0}
(0, 2) {'S': 0, 'NP': 0.08000000000000002, 'VP': 0, 'V': 0, 'N': 0}
(0, 3) {'S': 0.019200000000000005, 'NP': 0, 'VP': 0.024000000000000004}
(1, 1) {}
(1, 2) {'N': 0.2, 'VP': 0, 'NP': 0, 'V': 0}
(1, 3) {'S': 0, 'NP': 0.08000000000000002, 'VP': 0, 'N': 0}
(2, 1) {}
(2, 2) {}
(2, 3) {'S': 0.1, 'VP': 0.2, 'N': 0.4, 'V': 0.4, 'NP': 0}


In [55]:
df_data = [
    [dict(CYK[i][j+1]) for j in range(token_length)]
    for i in range(token_length)
]
df_data

[[{'S': 0.1, 'VP': 0.3, 'N': 0.4, 'V': 0.6, 'NP': 0},
  {'S': 0, 'NP': 0.08000000000000002, 'VP': 0, 'V': 0, 'N': 0},
  {'S': 0.019200000000000005, 'NP': 0, 'VP': 0.024000000000000004}],
 [{},
  {'N': 0.2, 'VP': 0, 'NP': 0, 'V': 0},
  {'S': 0, 'NP': 0.08000000000000002, 'VP': 0, 'N': 0}],
 [{}, {}, {'S': 0.1, 'VP': 0.2, 'N': 0.4, 'V': 0.4, 'NP': 0}]]

In [57]:
pd.DataFrame(df_data, columns=tokens, index=tokens)

Unnamed: 0,time,fruit,flies
time,"{'S': 0.1, 'VP': 0.3, 'N': 0.4, 'V': 0.6, 'NP'...","{'S': 0, 'NP': 0.08000000000000002, 'VP': 0, '...","{'S': 0.019200000000000005, 'NP': 0, 'VP': 0.0..."
fruit,{},"{'N': 0.2, 'VP': 0, 'NP': 0, 'V': 0}","{'S': 0, 'NP': 0.08000000000000002, 'VP': 0, '..."
flies,{},{},"{'S': 0.1, 'VP': 0.2, 'N': 0.4, 'V': 0.4, 'NP'..."
