In [None]:
"""
COMS W4705 - Natural Language Processing - Spring 2023
Homework 2 - Parsing with Probabilistic Context Free Grammars 
Daniel Bauer
"""
import math
import sys
from collections import defaultdict
import itertools
import numpy as np
from grammar import Pcfg

In [None]:
### Use the following two functions to check the format of your data structures in part 3 ###
def check_table_format(table):
    """
    Return True if the backpointer table object is formatted correctly.
    Otherwise return False and write an error message to stderr.

    The accepted layout is:
      * table is a dict whose keys are spans, i.e. tuples (i, j) of two ints;
      * each value is a dict keyed by nonterminal strings;
      * each inner value is either a string (leaf node) or a pair of
        backpointers, where every backpointer is a 3-tuple checked as
        (str, int, int).
    """
    if not isinstance(table, dict):
        sys.stderr.write("Backpointer table is not a dict.\n")
        return False
    for split in table:
        # The whole conjunction must be negated: a key is only valid when it
        # is a tuple AND has length 2 AND both members are ints.
        if not (isinstance(split, tuple) and len(split) == 2
                and isinstance(split[0], int) and isinstance(split[1], int)):
            sys.stderr.write("Keys of the backpointer table must be tuples (i,j) representing spans.\n")
            return False
        if not isinstance(table[split], dict):
            sys.stderr.write("Value of backpointer table (for each span) is not a dict.\n")
            return False
        for nt in table[split]:
            if not isinstance(nt, str):
                sys.stderr.write("Keys of the inner dictionary (for each span) must be strings representing nonterminals.\n")
                return False
            bps = table[split][nt]
            if isinstance(bps, str):  # Leaf nodes may be strings
                continue
            if not isinstance(bps, tuple):
                sys.stderr.write("Values of the inner dictionary (for each span and nonterminal) must be a pair ((i,k,A),(k,j,B)) of backpointers. Incorrect type: {}\n".format(bps))
                return False
            if len(bps) != 2:
                sys.stderr.write("Values of the inner dictionary (for each span and nonterminal) must be a pair ((i,k,A),(k,j,B)) of backpointers. Found more than two backpointers: {}\n".format(bps))
                return False
            for bp in bps:
                if not isinstance(bp, tuple) or len(bp) != 3:
                    # The offending backpointer is now actually interpolated
                    # into the message (the placeholder used to be missing).
                    sys.stderr.write("Values of the inner dictionary (for each span and nonterminal) must be a pair ((i,k,A),(k,j,B)) of backpointers. Backpointer has length != 3. {}\n".format(bp))
                    return False
                if not (isinstance(bp[0], str) and isinstance(bp[1], int) and isinstance(bp[2], int)):
                    sys.stderr.write("Values of the inner dictionary (for each span and nonterminal) must be a pair ((i,k,A),(k,j,B)) of backpointers. Backpointer has incorrect type. {}\n".format(bp))
                    return False
    return True

In [None]:
def check_probs_format(table):
    """
    Return True if the probability table object is formatted correctly.
    Otherwise return False and write an error message to stderr.

    The accepted layout is:
      * table is a dict whose keys are spans, i.e. tuples (i, j) of two ints;
      * each value is a dict keyed by nonterminal strings;
      * each inner value is a float log probability, so it may not be > 0.
    """
    if not isinstance(table, dict):
        sys.stderr.write("Probability table is not a dict.\n")
        return False
    for split in table:
        # The whole conjunction must be negated: a key is only valid when it
        # is a tuple AND has length 2 AND both members are ints.
        if not (isinstance(split, tuple) and len(split) == 2
                and isinstance(split[0], int) and isinstance(split[1], int)):
            sys.stderr.write("Keys of the probability must be tuples (i,j) representing spans.\n")
            return False
        if not isinstance(table[split], dict):
            sys.stderr.write("Value of probability table (for each span) is not a dict.\n")
            return False
        for nt in table[split]:
            if not isinstance(nt, str):
                sys.stderr.write("Keys of the inner dictionary (for each span) must be strings representing nonterminals.\n")
                return False
            prob = table[split][nt]
            if not isinstance(prob, float):
                sys.stderr.write("Values of the inner dictionary (for each span and nonterminal) must be a float.{}\n".format(prob))
                return False
            if prob > 0:
                sys.stderr.write("Log probability may not be > 0.  {}\n".format(prob))
                return False
    return True

In [None]:
class CkyParser(object):
    """
    A CKY parser.

    The grammar object is used through two attributes only:
      * ``rhs_to_rules``: a mapping (supporting ``.get``) from a right-hand
        side tuple -- ``(token,)`` or ``(A, B)`` -- to an iterable of rules.
        The first element of a rule is read as its left-hand side
        nonterminal and the last element as its probability.
      * ``startsymbol``: the nonterminal that must cover the whole sentence.
    """

    def __init__(self, grammar):
        """
        Initialize a new parser instance from a grammar.
        """
        self.grammar = grammar

    def is_in_language(self, tokens):
        """
        Membership checking. Parse the input tokens and return True if
        the sentence is in the language described by the grammar. Otherwise
        return False.

        An empty token sequence is reported as not in the language, and
        tokens without any matching rule simply produce empty chart cells.
        """
        length = len(tokens)
        if length == 0:
            # There is no span (0, -1) to look the start symbol up in.
            return False

        rhs_to_rules = self.grammar.rhs_to_rules
        # table[(i, j)] is the set of nonterminals deriving tokens[i..j]
        # (both ends inclusive). Sets keep duplicates from multiplying the
        # work done for longer spans.
        table = {}

        # Diagonal: nonterminals that rewrite directly to each token.
        # .get avoids a KeyError (or a key insertion on a defaultdict) for
        # unknown tokens.
        for i, token in enumerate(tokens):
            table[(i, i)] = {rule[0] for rule in rhs_to_rules.get((token,), ())}

        # Upper triangle, from the shortest spans to the longest.
        for width in range(1, length):
            for i in range(length - width):
                end = i + width
                cell = set()
                for split in range(i, end):
                    left = table[(i, split)]
                    right = table[(split + 1, end)]
                    for a in left:
                        for b in right:
                            for rule in rhs_to_rules.get((a, b), ()):
                                cell.add(rule[0])
                table[(i, end)] = cell

        return self.grammar.startsymbol in table[(0, length - 1)]

    def parse_with_backpointers(self, tokens):
        """
        Parse the input tokens and return a parse table and a probability table.

        Both tables are dicts keyed by spans (i, j) with j exclusive, and
        contain an entry (possibly an empty dict) for every span of the
        sentence. For each nonterminal found over a span:
          * the parse table holds the token itself for spans of length 1,
            otherwise the pair ((A, i, k), (B, k, j)) of backpointers for
            the best-scoring split;
          * the probability table holds the corresponding log probability.
        For an empty token sequence both tables are empty.
        """
        rhs_to_rules = self.grammar.rhs_to_rules
        length = len(tokens)
        table = {}
        probs = {}

        # Spans of length 1: one entry per rule that rewrites to the token.
        for i, token in enumerate(tokens):
            span = (i, i + 1)
            table[span] = {}
            probs[span] = {}
            for rule in rhs_to_rules.get((token,), ()):
                table[span][rule[0]] = token
                probs[span][rule[0]] = np.log(rule[-1])

        # Longer spans, built from the already completed shorter ones.
        for width in range(2, length + 1):
            for i in range(length - width + 1):
                end = i + width
                cell_bps = table[(i, end)] = {}
                cell_probs = probs[(i, end)] = {}
                for k in range(i + 1, end):
                    left_probs = probs[(i, k)]
                    right_probs = probs[(k, end)]
                    for a, a_score in left_probs.items():
                        for b, b_score in right_probs.items():
                            for rule in rhs_to_rules.get((a, b), ()):
                                lhs = rule[0]
                                score = a_score + b_score + np.log(rule[-1])
                                # Keep the first derivation seen, and replace
                                # it only by a strictly better one.
                                if lhs not in cell_probs or score > cell_probs[lhs]:
                                    cell_probs[lhs] = score
                                    cell_bps[lhs] = ((a, i, k), (b, k, end))

        return table, probs

In [None]:
def get_tree(chart, i, j, nt):
    """
    Return the parse-tree rooted in non-terminal nt and covering span i,j.

    chart is a backpointer table: chart[(i, j)][nt] is either a leaf value
    (returned as ``(nt, leaf)``) or a pair of backpointers
    ``((A, i, k), (B, k, j))``, in which case the result is
    ``(nt, left_subtree, right_subtree)`` built recursively.

    Raises KeyError if the span or the nonterminal is missing from chart.
    """
    entry = chart[(i, j)][nt]
    # isinstance (rather than an exact type comparison) also accepts tuple
    # subclasses as backpointer pairs.
    if isinstance(entry, tuple):
        left, right = entry[0], entry[1]
        return (
            nt,
            get_tree(chart, left[1], left[2], left[0]),
            get_tree(chart, right[1], right[2], right[0]),
        )
    return (nt, entry)

In [None]:
# Script entry point: parse one fixed example sentence with the grammar in
# 'test2.pcfg' and print the membership result, the backpointer table and the
# outcome of both format checks.
if __name__ == "__main__":
    
    with open('test2.pcfg','r') as grammar_file:
        # Build the grammar object from the open file handle.
        grammar = Pcfg(grammar_file) 
        parser = CkyParser(grammar)
        # Example sentence, already split into tokens.
        toks =['flights', 'from', 'miami','to','cleveland' ,'.']
        # True/False: is the sentence accepted by the grammar?
        print(parser.is_in_language(toks))
        table,probs = parser.parse_with_backpointers(toks)
        print(table)
        # Validate the layout of both returned tables; each check prints
        # True/False here and writes any error message to stderr.
        print(check_table_format(table))
        print(check_probs_format(probs))
        # Disabled: would print the tree rooted in the start symbol that
        # covers the whole sentence.
        #print(get_tree(table, 0, len(toks), grammar.startsymbol))