In [2]:
import os
from itertools import product
from mrg_utils import *
import nltk.corpus
import random

# DATA PROCESSING

In [3]:
def parsedSents(wsjDir):
    reader = nltk.corpus.BracketParseCorpusReader(wsjDir, ".*/.*\.mrg")
    return reader.parsed_sents()


wsj = "LDC99T42-20211111T235121Z-001/LDC99T42/treebank_3/parsed/mrg/wsj"
wsjdev = "LDC99T42-20211111T235121Z-001/LDC99T42/treebank_3/parsed/mrg/dev"
wsjs = "LDC99T42-20211111T235121Z-001/LDC99T42/treebank_3/parsed/mrg/small"

train = parsedSents(wsj)
dev = parsedSents(wsjdev)
s = parsedSents(wsjs)

In [10]:
def generate(trees):
    CFG = []
    for tree in trees:
        tree.chomsky_normal_form()
        CFG = CFG + tree.productions()
    cfg = list(set(CFG))
    return cfg,CFG    

In [5]:
#writing cfgs into file for further use
def writeinto(cfg,CFG,sample):
    fnamecfg = 'cfg_'+sample
    fnameCFG = 'CFG_'+sample
    with open(fnamecfg, 'w') as filehandle:
        for listitem in cfg:
            filehandle.write('%s\n' % listitem)
        
    with open(fnameCFG, 'w') as filehandle:
        for listitem in CFG:
            filehandle.write('%s\n' % listitem)


In [6]:
cfgsamp,CFGsamp = generate(s)  #generate sample data
writeinto(cfgsamp,CFGsamp,'sample')

In [7]:
cfgdev,CFGdev = generate(dev) #generate dev data
writeinto(cfgdev,CFGdev,'dev')

In [11]:
cfgtrain,CFGtrain = generate(train) #generate train data
writeinto(cfgtrain,CFGtrain,'train')

In [187]:
with open('UD_English-EWT/en_ewt-ud-test.txt', encoding="mbcs") as f1:
    sents1 = f1.read()
test=sents1.split('\n\n')

# INSIDE OUTSIDE

In [18]:
def inside(wo, ur, br, non_terminals):  
    in_prob = []
    temp = []
    for i in range(0, len(wo)):
        for j in range(0, len(wo)):
            temp.append({})
        in_prob.append(temp)
        temp = []
    
    for i in range(0, len(wo)):
        for unary_rule in ur:
            if unary_rule[1] == wo[i]:
                in_prob[i][i][unary_rule[0]] = unary_rule[2]
                
   
    j = 1
    while j < len(wo):
        for i in range(0, len(wo)):
            if (i + j) < len(wo):
                for nt_left in non_terminals:
                    for nt_right in non_terminals:
                        for binary_rule in br:
                            if binary_rule[1] == nt_left and binary_rule[2] == nt_right:
                                sProb = 0
                                for d in range(i, i + j):
                                    if nt_left in in_prob[i][d].keys() and nt_right in \
                                    in_prob[d + 1][i + j].keys():
                                        sProb += binary_rule[3] * in_prob[i][d][nt_left] * in_prob[d + 1][i +j][nt_right]
                                        
                                if sProb > 0:
                                    if binary_rule[0] in in_prob[i][i + j].keys():
                                        in_prob[i][i + j][binary_rule[0]] += sProb
                                    else:
                                        in_prob[i][i + j][binary_rule[0]] = sProb                            
        j += 1
        

    return in_prob

In [19]:
def outside(words, inside_probs, binary_rules, nts):   
    outside_probs = []
    new = []
    for i in range(0, len(words)):
        for j in range(0, len(words)):
            new.append({})
        outside_probs.append(new)
        new = []
    
    outside_probs[0][len(words) - 1]['S'] = 1.0

    j = len(words) - 1 
    while j >= 0:
        for i in range(0, len(words)):
            if (i + j) < len(words):
                # check rules to the right
                for nt_start in nts:
                    for nt_right in nts:
                        for binary_rule in binary_rules:
                            if binary_rule[0] == nt_start and binary_rule[2] == nt_right and binary_rule[1] != nt_right:
                                sum_prob = 0
                                for e in range(i + j + 1, len(words)):
                                    if nt_start in outside_probs[i][e].keys() and nt_right in inside_probs[i + j + 1][e].keys():
                                        sum_prob += binary_rule[3] * outside_probs[i][e][nt_start] * inside_probs[i + j + 1][e][nt_right]
                                if sum_prob > 0:
                                    if binary_rule[1] in outside_probs[i][i + j].keys():
                                        outside_probs[i][i + j][binary_rule[1]] += sum_prob
                                    else:
                                        outside_probs[i][i + j][binary_rule[1]] = sum_prob
                                        
                # check rules above
                for nt_start in nts:
                    for nt_left in nts:
                        for binary_rule in binary_rules:
                            if binary_rule[0] == nt_start and binary_rule[1] \
                            == nt_left:
                                sum_prob = 0
                                # add up probabilities corresponding to the 
                                # rule
                                for e in range(0, i):
                                    if nt_start in \
                                    outside_probs[e][i + j].keys() and \
                                    nt_left in inside_probs[e][i - 1].keys():
                                        sum_prob += binary_rule[3] * \
                                        outside_probs[e][i + j][nt_start] * \
                                        inside_probs[e][i - 1][nt_left]
                                if sum_prob > 0:
                                    if binary_rule[2] in \
                                    outside_probs[i][i + j].keys():
                                        outside_probs[i][i + j] \
                                        [binary_rule[2]] += sum_prob
                                    else:
                                        outside_probs[i][i + j] \
                                        [binary_rule[2]] = sum_prob                        
        
        j -= 1
        
    return outside_probs

In [106]:
def print_rules(u_rules, b_rules, output_file):   
    with open(output_file, 'w'):
        pass
    unary_rules, binary_rules = u_rules, b_rules
    for binary_rule in binary_rules:
        if binary_rule[-1] >= 0.0:
            with open(output_file, 'a+') as o:
                o.write(' '.join([str(binary_rule[0]), '->', str(binary_rule[1]), str(binary_rule[2]), str(binary_rule[3]), '\n']))
    for unary_rule in unary_rules:
        if unary_rule[-1] >= 0.0:
            with open(output_file, 'a+') as o:
                o.write(' '.join([str(unary_rule[0]), '->', str(unary_rule[1]), str(unary_rule[2]), '\n']))

In [20]:
def train_iterate(words, inside_probs, outside_probs, binary_rules):

    updated_rules = []

    for binary_rule in binary_rules:
        numerator = 0
        for i in range(0, len(words)):
            for j in range(i + 1, len(words)):
                if binary_rule[0] in outside_probs[i][j].keys():
                        inside_sum = 0
                        for d in range(i, j):
                            if binary_rule[1] in inside_probs[i][d].keys() and binary_rule[2] in inside_probs[d + 1][j].keys():
                                inside_sum += inside_probs[i][d][binary_rule[1]] * inside_probs[d + 1][j][binary_rule[2]]                                
                        outside_sum = outside_probs[i][j][binary_rule[0]] * binary_rule[3]
                        numerator += inside_sum * outside_sum
        
        denominator = 0
        for i in range(0, len(words)):
            for j in range(i, len(words)):
                if binary_rule[0] in outside_probs[i][j].keys() and binary_rule[0] in inside_probs[i][j].keys():
                    denominator += outside_probs[i][j][binary_rule[0]] * inside_probs[i][j][binary_rule[0]] 
               
        try:
            new_prob = numerator / denominator
        except ZeroDivisionError:
            new_prob =  0.0
        
        if new_prob == 0.0:
            new_prob = binary_rule[-1]
        updated_rules.append((binary_rule[0], binary_rule[1], binary_rule[2],new_prob))
        
    return updated_rules

In [21]:
def improvement(old, new):
    improv = 0

    for i in range(0, len(old)):
        if abs(old[i][-1] - new[i][-1]) > improv:
            improv = abs(old[i][-1] - new[i][-1])
    
    return improv

In [24]:
def training(unary, binary, non_terminals, i, t):
    if t == 'dev':
        with open('UD_English-EWT/en_ewt-ud-dev.txt', encoding="mbcs") as f:
            sents1 = f.read()
    elif t == 'train':
        with open('UD_English-EWT/en_ewt-ud-train.txt', encoding="mbcs") as f1:
            sents1 = f1.read()
    senon_terminals=senon_terminals1.split('\n\n')        
    iterations = 0
    ud_rules = binary
    print('Training ' + str(i) + '...\n')
    for sent in senon_terminals:
        words = sent.split()
        inside_probs = inside(words, unary, ud_rules, non_terminals)
        outside_probs = outside(words, inside_probs, ud_rules, non_terminals)
        ud_rules = train_iterate(words, inside_probs, outside_probs, ud_rules)
        
        
    iterations += 1
    print('Iteration', iterations)
    print_rules(unary, binary, 'log/' + str(iterations) + '.log')
    

    impr = improvement(binary, ud_rules)
    threshold = 1e-04
    while impr >= threshold:
        temp_u = []
        for ud_rule in ud_rules:
            if ud_rule[-1] != 0.0:
                temp_u.append(ud_rule)
        ud_rules = temp_u
        binary = ud_rules
        
        for sent in senon_terminals:
            words = sent.split()
            
            inside_probs = inside(words, unary, ud_rules, non_terminals)
            outside_probs = outside(words, inside_probs, ud_rules, non_terminals)
            ud_rules = train_iterate(words, inside_probs, outside_probs, ud_rules)
            
    
            
        iterations += 1
        impr = check_improvement(binary, ud_rules)
        
    print('terminated')
    print_rules(unary, binary, 'output_' + str(i) + '.txt')
    return ud_rules



In [61]:
def read(cfg):
    binary_rules, unary_rules, nts, ts = [], [], [], []
    nts = list(map(lambda x : x.lhs(), cfg))
    i = 0
    for line in cfg:
        l = [line.lhs()]
        for e in list(line.rhs()):
            l.append(e)
        if len(l) >= 3:
            binary_rules.append((l[0], l[1], l[2], 0.0))
        elif len(l) >= 2:
            unary_rules.append((l[0], l[1], 0.0))
    
    return unary, binary, nts

In [57]:
unary, binary, nonTerminals=read(cfgtrain)
rules = training(unary, binary, nonTerminals, 0,'train')

# CYK

In [158]:
class Node:
    def __init__(self, left, right1, right2=None, prob=0.5):
        self.left = left
        self.right1 = right1
        self.right2 = right2
        self.prob = prob
    def __repr__(self):
        return self.left

In [159]:
def read_input(sentence):
    with open(sentence, encoding="mbcs") as f1:
        sents1 = f1.read()
    sents=sents1.split('\n\n') 
    return sents[0].split()   

In [160]:
def read_grammar(grammar):
    with open(grammar) as fr2:
        lines = fr2.readlines()
        g=[]
        for each in lines:
            s = each.replace('->','')
            g.append(s.split())
    return g    

In [161]:
def parse(text,grammar):
        length = len(text)
        parse_table = [[[] for i in range(length)] for j in range(length)]
        for j, word in enumerate(text):
            # go through every column, from left to right
            for rule in grammar:
                # fill the terminal word cell
                if f"'{word}'" == rule[1]:
                    parse_table[j][j].append(Node(rule[0], word, prob=rule[-1]))
            # go through every row, from bottom to top
            for i in range(j-1, -1, -1):
                for k in range(i, j):
                    right1_cell = parse_table[i][k]  # cell left
                    right2_cell = parse_table[k+1][j]  # cell beneath
                    for rule in grammar:
                        right1_node = [n for n in right1_cell if n.left == rule[1]]
                        if right1_node:
                            right2_node = [n for n in right2_cell if n.left == rule[2]]
                            parse_table[i][j].extend(
                                [Node(rule[0], right1, right2, rule[-1]) for right1 in right1_node for right2 in right2_node])
        return parse_table                  

In [177]:
def get_parses(grammar,parse_table):
    start_symbol = grammar[0][0]
    final_nodes = [n for n in parse_table[0][-1] if n.left == start_symbol]
    if final_nodes:
        write_trees = [parseString(node) for node in final_nodes]
        poss_trees = [poss_tree(node) for node in final_nodes]
        maximum = poss_trees.index(max(poss_trees))
        print(write_trees[maximum], round(poss_trees[maximum], 6))
        return(write_trees[maximum], round(poss_trees[maximum], 6))
    else:
        print("Not Valid")
        return("Not Valid")

In [184]:
def parseString(node):
    k = node
    if k.right2 is None:
        return f"({k.left} {k.right1})"
    return f"({k.left} {generate_tree(k.right1)} {generate_tree(k.right2)})"

In [185]:
def poss_tree(node):
    k = node
    if k.right2 is None:
        p = float(k.prob)
        return p
    return float(k.prob) * poss_tree(k.right1) * poss_tree(k.right2)


In [205]:
def cyk(text):
    grammar_path = './output.txt'
    text_path = './str1.txt'
    grammar=read_grammar(grammar_path)
    input_text = text.split()
    matrix=parse(input_text,grammar)
    return get_parses(grammar,matrix)

In [207]:
test_parses = []
for each in test:
    cyk(each)

1
Not Valid
1
Not Valid
1
Not Valid
1


KeyboardInterrupt: 

In [210]:
grammar_path = './output.txt'
text_path = './str1.txt'
grammar=read_grammar(grammar_path)
input_text = read_input(text_path)
matrix=parse(input_text,grammar)
get_parses(grammar,matrix)

Not Valid


In [209]:
for i in matrix:
    print(i)

[[Verb, VP, S], [], [VP, VP1, VP2, S], [], [VP, VP1, VP2, S, VP, VP, S, S]]
[[], [Det], [NP], [], [NP]]
[[], [], [Noun, Nominal, NP], [], [Nominal, NP]]
[[], [], [], [Preposition], [PP]]
[[], [], [], [], [Proper-Noun, NP]]


# EVALUATE