# The HDP-PCFG

In [42]:
from rule import Rule
from cfg import WCFG, read_grammar_rules
from parser import cky
from earley import earley
from symbol import make_symbol, is_nonterminal, is_terminal
import numpy as np
from collections import defaultdict
from InsideOutside import inside, outside, inside_outside, EM, plot_EM
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def generate_sample(grammar, items=('[E]',)):
    """
    Given a grammar returns a sentence from it using
    the probabilities specfied in the grammar.
    :param items: call the function with (start,) where 
                  start is the start symbol of the grammar
    :returns: a sentence from the language as a list
    """
    frags = []
    for item in items:
        if is_nonterminal(item):
            productions = grammar.get(item)
            ps = [production.prob for production in productions]
            random_index = np.argmax(np.random.multinomial(1, ps, size=1))
            prod = productions[random_index]
            frags.extend(generate_sample(grammar, items=prod.rhs))
        else:
            frags.append(item)
    return frags

def generate_corpus(grammar, n, start=('[E]',)):
    """
    Generates a corpus using the grammar
    :param n: size of the corpus
    :params: same a s generate corpus
    :returns: a corpus in the form of a list
    """
    return [generate_sample(grammar, items=start) for i in range(n)]

def initialize(grammar, alpha=20.0):
    """
    Takes a grammar and returns that grammar with 
    the probabilities replaced by random probabilities
    generated from a Dirichlet distribution.
    :param: alpha is the Dirichlet concentration parameter
    """
    init_grammar = WCFG()
    for nonterminal in grammar.nonterminals:
        rules = grammar.get(nonterminal)
        init_prob = np.random.dirichlet(len(rules)*[alpha])
        for i, rule in enumerate(rules):
            init_grammar.add(Rule(rule.lhs, rule.rhs, init_prob[i]))
    return init_grammar

In [14]:
def counting(forest, start):  # acyclic hypergraph
    N = dict()
    
    def get_count(symbol):
        w = N.get(symbol, None)
        if w is not None:
            return w
        incoming = forest.get(symbol, set())
        if len(incoming) == 0:  # terminals have already been handled, this must be a nonterminal dead end
            N[symbol] = w
            return 0
        w = 0
        for rule in incoming:
            k = 1
            for child in rule.rhs:
                k *= get_count(child)
            w += k
        N[symbol] = w
        return w
    
    # handles terminals
    for sym in forest.terminals:
        N[sym] = 1
    # handles nonterminals
    #for sym in forest.nonterminals:
    #    get_inside(sym)
    get_count(start)
        
    return N

In [45]:
import matplotlib.pyplot as plt
%matplotlib inline

def get_instances(rule, forest):
    """
    Given a rule
    
    A -> B C 
    
    get_instances collects all instances of rules of the form
    
    [A:i-j] -> [B:i-k] [C:k-j] 
    
    from the forest and returns them in a list.
    """
    instances = []
    for r in forest:
        if r.lhs[1] == rule.lhs[1] and len(r.rhs)==len(rule.rhs):
            test = []
            for i in range(len(r.rhs)):
                try:
                    # if for example r.rhs[i] = '[A]' 
                    v = r.rhs[i][1] == rule.rhs[i][1]
                    test.append(v)
                except IndexError:
                    # if for example r.rhs[i] = '*' 
                    v = r.rhs[i][0] == rule.rhs[i][0]
                    test.append(v)
            if np.all(test):
                instances.append(r)
    return instances

def inside_outside(training_sents, grammar, start_sym='[E]'):
    f = defaultdict(float)
    for sent in training_sents:
        forest = cky(grammar, sent)
        goal = make_symbol(start_sym, 0, len(sent))
        I = inside(forest, goal)
        O = outside(forest, goal, I)     
        for rule in grammar:
            w = 0.0
            for instance in get_instances(rule, forest):
                k = rule.prob
                k *= O[instance.lhs]
                for child in instance.rhs:
                    try:
                        alpha = I[child]
                    except KeyError:
                        # same solution as in outside
                        alpha = 0.0
                    k *= alpha
                w += k
            f[rule] += w/I[goal]
    return f


In [3]:
hdp_pcfg = WCFG(read_grammar_rules(open('examples/hdp-pcfg-grammar', 'r')))
print hdp_pcfg
hdp_corpus = generate_corpus(hdp_pcfg, 10, start=('[S]',))
print hdp_corpus

[NN] -> mouse (0.33)
[NN] -> cat (0.33)
[NN] -> dog (0.34)
[JJ] -> big (0.5)
[JJ] -> black (0.5)
[DT] -> the (0.5)
[DT] -> a (0.5)
[NPBAR] -> [JJ] [NN] (0.5)
[NPBAR] -> [JJ] [NPBAR] (0.5)
[VP] -> [VB] [NP] (1.0)
[S] -> [NP] [VP] (1.0)
[VB] -> chased (0.5)
[VB] -> ate (0.5)
[NP] -> [DT] [NN] (0.5)
[NP] -> [DT] [NPBAR] (0.5)
[['a', 'cat', 'chased', 'a', 'big', 'cat'], ['the', 'big', 'mouse', 'chased', 'a', 'dog'], ['the', 'big', 'mouse', 'ate', 'a', 'big', 'black', 'mouse'], ['a', 'mouse', 'chased', 'the', 'dog'], ['a', 'big', 'black', 'dog', 'chased', 'a', 'dog'], ['the', 'mouse', 'chased', 'the', 'black', 'dog'], ['the', 'dog', 'ate', 'the', 'big', 'mouse'], ['the', 'mouse', 'ate', 'a', 'big', 'big', 'black', 'big', 'cat'], ['the', 'big', 'black', 'big', 'dog', 'chased', 'the', 'big', 'black', 'dog'], ['the', 'black', 'cat', 'ate', 'a', 'black', 'big', 'dog']]


In [26]:
sentence = hdp_corpus[1]
for sentence in hdp_corpus:
    cky_forest = cky(hdp_pcfg, sentence)
#     print cky_forest
    goal = make_symbol('[S]', 0, len(sentence))
    print counting(cky_forest, goal)[goal]

1
1
1
1
1
1
1
1
1
1


In [50]:
print hdp_pcfg
f = inside_outside(hdp_corpus, hdp_pcfg, start_sym='[S]')
print f

s = 0
for sent in hdp_corpus:
    print sent
    s += sum(map(lambda x : x=='chased', sent))
print s

[NN] -> mouse (0.33)
[NN] -> cat (0.33)
[NN] -> dog (0.34)
[JJ] -> big (0.5)
[JJ] -> black (0.5)
[DT] -> the (0.5)
[DT] -> a (0.5)
[NPBAR] -> [JJ] [NN] (0.5)
[NPBAR] -> [JJ] [NPBAR] (0.5)
[VP] -> [VB] [NP] (1.0)
[S] -> [NP] [VP] (1.0)
[VB] -> chased (0.5)
[VB] -> ate (0.5)
[NP] -> [DT] [NN] (0.5)
[NP] -> [DT] [NPBAR] (0.5)
defaultdict(<type 'float'>, {[NN] -> mouse (0.33): 15.73529411764706, [DT] -> the (0.5): 11.0, [VP] -> [VB] [NP] (1.0): 10.0, [JJ] -> black (0.5): 8.0, [NN] -> cat (0.33): 4.0, [NPBAR] -> [JJ] [NPBAR] (0.5): 21.0, [NN] -> dog (0.34): 16.212121212121215, [S] -> [NP] [VP] (1.0): 10.0, [JJ] -> big (0.5): 13.0, [NPBAR] -> [JJ] [NN] (0.5): 21.0, [DT] -> a (0.5): 9.0, [VB] -> chased (0.5): 6.0, [NP] -> [DT] [NN] (0.5): 20.0, [VB] -> ate (0.5): 4.0, [NP] -> [DT] [NPBAR] (0.5): 20.0})
['a', 'cat', 'chased', 'a', 'big', 'cat']
['the', 'big', 'mouse', 'chased', 'a', 'dog']
['the', 'big', 'mouse', 'ate', 'a', 'big', 'black', 'mouse']
['a', 'mouse', 'chased', 'the', 'dog']
['a',

In [33]:
synthetic = WCFG(read_grammar_rules(open('examples/synthetic', 'r')))
print synthetic
synthetic_corpus = generate_corpus(synthetic, 10, start=('[S]',))
print synthetic_corpus
# print sum(map(lambda (x,y): x==y, synthetic_corpus))

[S] -> [X1] [X1] (0.25)
[S] -> [X2] [X2] (0.25)
[S] -> [X3] [X3] (0.25)
[S] -> [X4] [X4] (0.25)
[X3] -> a3 (0.25)
[X3] -> b3 (0.25)
[X3] -> c3 (0.25)
[X3] -> d3 (0.25)
[X2] -> a2 (0.25)
[X2] -> b2 (0.25)
[X2] -> c2 (0.25)
[X2] -> d2 (0.25)
[X4] -> a4 (0.25)
[X4] -> b4 (0.25)
[X4] -> c4 (0.25)
[X4] -> d4 (0.25)
[X1] -> a1 (0.25)
[X1] -> b1 (0.25)
[X1] -> c1 (0.25)
[X1] -> d1 (0.25)
[['a4', 'a4'], ['d1', 'c1'], ['b4', 'd4'], ['c2', 'd2'], ['c4', 'd4'], ['d4', 'c4'], ['d4', 'c4'], ['b2', 'b2'], ['c3', 'd3'], ['c3', 'c3']]


In [29]:
sentence = synthetic_corpus[1]
cky_forest = cky(synthetic, sentence)
print cky_forest

[X3:0-1] -> c3 (0.25)
[X3:1-2] -> b3 (0.25)
[S:0-2] -> [X3:0-1] [X3:1-2] (0.25)


## Update $q(\phi)$ ('M-step')

Needed: corpus, inside-outside

In [119]:
# expected_counts = inside_outside(hdp_corpus, hdp_pcfg, start_sym='[S]')
from itertools import product

grammar = hdp_pcfg
corpus = hdp_corpus

sentence = hdp_corpus[1]
print sentence

expected_counts1 = inside_outside([sentence], grammar, start_sym='[S]')

expected_counts = inside_outside(corpus, grammar, start_sym='[S]')
normalized_counts = {rule: count/float(len(corpus)) for rule,count in expected_counts.iteritems()}

# print expected_counts1
# print expected_counts
# print normalized_counts


sigma = grammar.terminals
S = grammar.nonterminals
S2 = list(product(S,S))
T = ['E', 'B']

# print sigma
# print S
# print S2

a_E = 0.5
a_B = 0.5
a_T = 0.5

alpha_E = {w: a_E for w in sigma}
alpha_B = {N: a_B for N in S2}
alpha_T = {t: a_T for t in T}

gamma_E = {z: {w: a_E for w in sigma} for z in S}
gamma_B = {z: {N: a_B for N in S2} for z in S}
gamma_T = {z: {t: a_T for t in T} for z in S}


for rule, count in normalized_counts.iteritems():
    if len(rule.rhs) > 1:
        gamma_B[rule.lhs][rule.rhs] += count
        gamma_T[rule.lhs]['B'] += count
    else:
        gamma_E[rule.lhs][rule.rhs[0]] += count
        gamma_T[rule.lhs]['E'] += count

# print '\n'
# print gamma_E
# print '\n'
# print gamma_B
# print '\n'
# print gamma_T




['the', 'big', 'mouse', 'chased', 'a', 'dog']


## Multinomial updates ('E-step')

In [126]:
from scipy.special import digamma
from math import exp
# for each nonterminal z a dictionary with keys in sigma
W_E = {z: {w: a_E for w in sigma} for z in S}
W_B = {z: {N: a_B for N in S2} for z in S}
W_T = {z: {t: a_T for t in T} for z in S}

for z in S:
    E_total = 0.0
    for w in sigma:
        v = gamma_E[z][w]
        W_E[z][w] = exp(digamma(v))
        E_total += v
    for w in sigma:
        W_E[z][w] *= exp(digamma(E_total))
    
    B_total = 0.0
    for N in S2:
        v = gamma_B[z][N]
        W_B[z][N] = exp(digamma(v))
        B_total += v
    for N in S2:
        W_B[z][N] *= exp(digamma(B_total))
    
    T_total = 0.0
    for t in T:
        v = gamma_T[z][t]
        W_T[z][t] = exp(digamma(v))
        T_total += v
    for t in T:
        W_T[z][t] *= exp(digamma(T_total))


In [151]:
sentence =  hdp_corpus[6]
print sentence

def get_base_symbol(symbol):
    """
    Returns the span from a symbol
    E.g. input [NN:2-3] returns [NN]
    """
    return symbol.split(':')[0]+']'

forest = cky(grammar, sentence)
print forest

new_grammar = {}
for rule in forest:
    q_z = 1.0
    if len(rule.rhs) > 1:
        # we've found a B
        z = get_base_symbol(rule.lhs)
        c1 = get_base_symbol(rule.rhs[0])
        c2 = get_base_symbol(rule.rhs[1])
        q_z *= W_B[z][(c1,c2)] * W_T[z]['B']
    else:
        # we've found an E
        z = get_base_symbol(rule.lhs)
        w = rule.rhs[0]
        q_z *= W_E[z][w] * W_T[z]['E']
print q_z




['the', 'dog', 'ate', 'the', 'big', 'mouse']
[NN:1-2] -> dog (0.34)
[DT:0-1] -> the (0.5)
[S:0-6] -> [NP:0-2] [VP:2-6] (1.0)
[NP:3-6] -> [DT:3-4] [NPBAR:4-6] (0.5)
[VB:2-3] -> ate (0.5)
[VP:2-6] -> [VB:2-3] [NP:3-6] (1.0)
[NN:5-6] -> mouse (0.33)
[NPBAR:4-6] -> [JJ:4-5] [NN:5-6] (0.5)
[NP:0-2] -> [DT:0-1] [NN:1-2] (0.5)
[JJ:4-5] -> big (0.5)
[DT:3-4] -> the (0.5)
45.013780626
