In [1]:
import numpy as np
import pandas as pd

%load_ext autoreload
%autoreload 2

# CoNLL-U treebank files are UTF-8 by specification, so the encoding is passed
# explicitly instead of relying on the platform's default locale encoding.
with open('./data/en_ewt-ud-train.conllu', encoding='utf-8') as f:
    lines = f.readlines()

with open('./data/en_ewt-ud-dev.conllu', encoding='utf-8') as f:
    lines_dev = f.readlines()


In [2]:
from dataset import read_syntax_tree

# Parse the raw training-split lines with the project's reader; the resulting
# frame (columns: sentence, pos, head, deprel) is shown by df.head() in the next cell.
df = read_syntax_tree(lines)

In [3]:
df.head()

Unnamed: 0,sentence,pos,head,deprel
0,"[Al, -, Zaman, :, American, forces, killed, Sh...","[NNP, HYPH, NNP, :, JJ, NNS, VBD, NNP, NNP, NN...","[0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8,...","[root, punct, flat, punct, amod, nsubj, parata..."
1,"[[, This, killing, of, a, respected, cleric, w...","[-LRB-, DT, NN, IN, DT, JJ, NN, MD, VB, VBG, P...","[10, 3, 10, 7, 7, 7, 3, 10, 10, 0, 10, 10, 14,...","[punct, det, nsubj, case, det, amod, nmod, aux..."
2,"[DPA, :, Iraqi, authorities, announced, that, ...","[NNP, :, JJ, NNS, VBD, IN, PRP, VBD, VBN, RP, ...","[0, 1, 4, 5, 1, 9, 9, 9, 5, 9, 13, 13, 9, 13, ...","[root, punct, amod, nsubj, parataxis, mark, ns..."
3,"[Two, of, them, were, being, run, by, 2, offic...","[CD, IN, PRP, VBD, VBG, VBN, IN, CD, NNS, IN, ...","[6, 3, 1, 6, 6, 0, 9, 9, 6, 12, 12, 9, 15, 15,...","[nsubj:pass, case, nmod, aux, aux:pass, root, ..."
4,"[The, MoI, in, Iraq, is, equivalent, to, the, ...","[DT, NNP, IN, NNP, VBZ, JJ, IN, DT, NNP, NNP, ...","[2, 6, 4, 2, 6, 0, 10, 10, 10, 6, 6, 17, 15, 1...","[det, nsubj, case, nmod, cop, root, case, det,..."


In [4]:
from data_structures import Vocab, TwoStack, RelationGraph


In [6]:
# Build one vocabulary per annotation layer from the training frame.
# Words use min_freq=10 while POS tags and dependency labels use min_freq=1
# (presumably a frequency cut-off below which items are not kept -- confirm in Vocab).
word_vocab = Vocab(min_freq=10).build(df['sentence'].tolist())
pos_vocab = Vocab(min_freq=1).build(df['pos'].tolist())
deprel_labels = Vocab(min_freq=1).build(df['deprel'].tolist())

In [71]:
# `torch` itself must be imported here: this cell calls torch.tensor, and a
# fresh kernel run top-to-bottom has not imported it yet at this point.
import torch
import torch.nn as nn

# Sanity check of nn.Embedding: a 2-dimensional table with one row per word id,
# looked up for a (2, 3) batch of ids, which yields a (2, 3, 2) tensor.
emb = nn.Embedding(len(word_vocab), 2)
emb(torch.tensor([[1, 2, 3], [1, 1, 3]]))


tensor([[[ 0.0291,  0.3924],
         [-0.6424,  0.3659],
         [ 0.0710,  0.1849]],

        [[ 0.0291,  0.3924],
         [ 0.0291,  0.3924],
         [ 0.0710,  0.1849]]], grad_fn=<EmbeddingBackward0>)

In [24]:
def word_map(sentence):
    """Map token positions to words: 0 -> 'root', then 1..n -> the sentence's words in order."""
    positions_to_words = {0: 'root'}
    # enumerate(..., start=1) yields (position, word) pairs with 1-based positions
    positions_to_words.update(enumerate(sentence, start=1))
    return positions_to_words

# word_map(sentence)

In [103]:
import copy

In [5]:
# Worked example from the SLP3 textbook, chapter 14.

sentence = ['book', 'me', 'the', 'morning', 'flight']
# sentence = ['book', 'the', 'flight', 'through', 'houston']

# head[i] is the 1-based position of token i's parent; 0 means the token hangs
# off the artificial root (the same index 0 that word_map labels 'root').
head = [0, 1, 5, 5, 1]

# One relation id per token; these look like placeholder ids rather than real
# dependency labels -- TODO confirm what set_true_labels expects.
deprel = [1, 2, 3, 4, 5]

# Expected transition sequence for this sentence (without a terminating 'done').
target = ['shift', 'shift', 'right-arc', 'shift', 'shift', 'shift', 'left-arc', 'left-arc', 'right-arc', 'right-arc']

In [15]:
from utils import unravel_sentence
from data_structures import *
import copy

In [16]:
# # df.iloc[4]['sentence'], df.iloc[4]['head'], df.iloc[4]['deprel']
# df.iloc[0]['head']


In [17]:
def unravel_sentence(sentence, head, deprel):
    """Replay a sentence's gold dependency tree as a sequence of parser transitions.

    Starting with only the root (index 0) on the stack, each iteration copies
    the current configuration, chooses the transition that agrees with the
    gold graph, applies it, and records the (configuration, transition) pair.
    The loop runs for as long as ``stack`` evaluates truthy.

    Args:
        sentence: tokens of the sentence; only its length is used here.
        head: gold heads, forwarded to ``RelationGraph.set_true_labels``.
        deprel: gold relations, forwarded to ``RelationGraph.set_true_labels``.

    Returns:
        ``(train_states, train_targets, graph, true_graph)``. ``train_states[i]``
        is a ``(stack, buffer, graph)`` copy taken before ``train_targets[i]``
        was applied; ``graph`` holds the arcs added by the transitions and
        ``true_graph`` is the gold graph.
    """
    n_tokens = len(sentence)

    true_graph = RelationGraph(n_tokens)
    true_graph.set_true_labels(head, deprel)

    stack = TwoStack()
    word_list = Buffer(n_tokens)
    graph = RelationGraph(n_tokens)

    train_states, train_targets = [], []

    # the root node is the initial stack content
    stack.add(0)

    while stack:
        # copy the configuration before the chosen transition mutates it
        snapshot = (copy.deepcopy(stack), copy.deepcopy(word_list), copy.deepcopy(graph))

        if len(stack) != 1:
            if true_graph.contains(stack.view1(), stack.view2()):
                # left-arc: gold graph has view1 -> view2
                graph.add_relation(stack.view1(), stack.view2())
                stack.pop2()
                action = 'left-arc'
            elif true_graph.contains(stack.view2(), stack.view1()) and all(
                [graph.contains(stack.view1(), child) for child in true_graph.get_children(stack.view1())]
            ):
                # right-arc: gold graph has view2 -> view1, and every gold
                # child of view1 has already been attached in `graph`
                graph.add_relation(stack.view2(), stack.view1())
                stack.pop1()
                action = 'right-arc'
            else:
                # no arc applies yet: move the next buffer item onto the stack
                stack.add(word_list.pop(0))
                action = 'shift'
        elif len(word_list) > 0:
            # single item on the stack and input remaining
            stack.add(word_list.pop(0))
            action = 'shift'
        else:
            # single item on the stack and nothing left to read: finish
            stack.pop1()
            action = 'done'

        train_states.append(snapshot)
        train_targets.append(action)

    return train_states, train_targets, graph, true_graph

train_states, train_targets, graph, true_graph = unravel_sentence(sentence, head, deprel)

In [20]:
graph.mat

array([[0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 1., 0.]])

In [22]:
head

[0, 1, 5, 5, 1]

In [21]:
true_graph.mat

array([[0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 1., 0.]])

In [24]:
import torch
np.where(graph.mat > 0)

(array([0, 1, 1, 5, 5]), array([1, 2, 5, 3, 4]))

In [27]:
torch.nonzero(torch.tensor(graph.mat))

tensor([[0, 1],
        [1, 2],
        [1, 5],
        [5, 3],
        [5, 4]])

In [43]:
# Recover each token's head from the adjacency matrix: column ix + 1 belongs to
# token ix (column 0 is the root) and the row index of its non-zero entry is
# the head. Tokens whose column is all zeros get -1.
mat = torch.tensor(graph.mat)
preds = torch.zeros(len(head))
for ix in range(len(preds)):
    nonzero = torch.where(mat[:, ix + 1] != 0)[0]
    preds[ix] = -1 if len(nonzero) == 0 else nonzero.item()

In [46]:
preds == torch.tensor(head)

tensor([True, True, True, True, True])

In [19]:
word_vocab.transform(sentence)

[707, 1391, 1970, 1432, 1082]

In [17]:
train_targets

['shift',
 'shift',
 'right-arc',
 'shift',
 'shift',
 'shift',
 'left-arc',
 'left-arc',
 'right-arc',
 'right-arc',
 'done']

In [101]:

def unravel_sentence(sentence, head, deprel, verbose=False):
    """Replay a sentence's gold dependency tree as a sequence of parser transitions.

    The stack starts with the root (index 0) and the buffer with the token
    positions 1..len(sentence). Each iteration copies the current
    configuration, chooses the transition that agrees with the gold graph,
    applies it, and records the (configuration, transition) pair.

    Args:
        sentence: tokens of the sentence; only its length is used here.
        head: gold heads, forwarded to ``RelationGraph.set_true_labels``.
        deprel: gold relations, forwarded to ``RelationGraph.set_true_labels``.
        verbose: when true, print stack, buffer and graph before every transition.

    Returns:
        ``(train_states, train_targets, graph, true_graph)``. ``train_states[i]``
        is a ``(stack, buffer, graph)`` copy taken before ``train_targets[i]``
        was applied; ``graph`` holds the arcs added by the transitions and
        ``true_graph`` is the gold graph.

    Raises:
        IndexError: if neither arc transition applies and the buffer is already
            empty, i.e. no transition can make progress for this gold tree.
    """
    true_graph = RelationGraph(len(sentence))
    true_graph.set_true_labels(head, deprel)

    stack = TwoStack()
    # token positions still waiting to be shifted; position 0 is the root
    word_list = list(range(1, len(sentence) + 1))
    graph = RelationGraph(len(sentence))

    train_states = []
    train_targets = []

    # the root node is the initial stack content
    stack.add(0)

    while stack:
        # copy the configuration before the chosen transition mutates it
        state = (copy.deepcopy(stack), copy.copy(word_list), copy.deepcopy(graph))

        if verbose:
            print(stack, word_list, graph)

        if len(stack) == 1:
            if len(word_list) > 0:
                stack.add(word_list.pop(0))
                target = 'shift'
            else:
                # only one item left and nothing to read: finish
                stack.pop1()
                target = 'done'
        else:
            # left-arc: gold graph has view1 -> view2
            if true_graph.contains(stack.view1(), stack.view2()):
                graph.add_relation(stack.view1(), stack.view2())
                stack.pop2()
                target = 'left-arc'

            # right-arc: gold graph has view2 -> view1, and every gold child
            # of view1 has already been attached in `graph`
            elif true_graph.contains(stack.view2(), stack.view1()) and \
                    all([graph.contains(stack.view1(), child) for child in true_graph.get_children(stack.view1())]):
                graph.add_relation(stack.view2(), stack.view1())
                stack.pop1()
                target = 'right-arc'

            elif word_list:
                stack.add(word_list.pop(0))
                target = 'shift'

            else:
                # Previously this fell through to word_list.pop(0) and failed
                # with an opaque "pop from empty list"; report the stuck
                # configuration instead (same exception type).
                raise IndexError(
                    f"no transition available: buffer is empty but stack {stack} "
                    "cannot be reduced (the gold tree may be non-projective or malformed)"
                )

        train_states.append(state)
        train_targets.append(target)
    return train_states, train_targets, graph, true_graph


In [245]:
unravel_sentence(sentence, head, deprel)

[0] [1, 2, 3, 4, 5] {}
[0, 1] [2, 3, 4, 5] {}
[0, 1, 2] [3, 4, 5] {}
[0, 1] [3, 4, 5] {1->2,}
[0, 1, 3] [4, 5] {1->2,}
[0, 1, 3, 4] [5] {1->2,}
[0, 1, 3, 4, 5] [] {1->2,}
[0, 1, 3, 5] [] {1->2,5->4,}
[0, 1, 5] [] {1->2,5->3,5->4,}
[0, 1] [] {1->2,1->5,5->3,5->4,}
[0] [] {0->1,1->2,1->5,5->3,5->4,}


In [269]:
train_states, train_targets, graph, true_graph = unravel_sentence(df.iloc[4]['sentence'], df.iloc[4]['head'], df.iloc[4]['deprel'])

[0] [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {}
[0, 1] [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {}
[0, 1, 2] [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {}
[0, 2] [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {2->1,}
[0, 2, 3] [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {2->1,}
[0, 2, 3, 4] [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {2->1,}
[0, 2, 4] [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36] {2

IndexError: pop from empty list

In [259]:
print(true_graph)

{0->1,1->2,1->3,1->4,1->7,1->29,6->5,7->6,7->8,7->18,8->9,8->10,8->11,8->12,8->13,8->15,15->14,18->16,18->17,18->21,21->19,21->20,21->23,21->24,21->28,23->22,28->25,28->26,28->27,}


In [270]:
df.shape

(12543, 3)

In [258]:
print(graph)

{0->1,1->2,1->3,1->4,1->7,1->29,6->5,7->6,7->8,7->18,8->9,8->10,8->11,8->12,8->13,8->15,15->14,18->16,18->17,18->21,21->19,21->20,21->23,21->24,21->28,23->22,28->25,28->26,28->27,}


In [246]:
def view_states(sentence, states, targets):
    """Print recorded parser configurations with token positions replaced by words.

    For every (state, target) pair one line is printed: the words on the
    stack, the words remaining in the buffer, and the target for that state.
    Each state is a (stack, buffer, graph) triple; the graph is not printed.
    """
    lookup = word_map(sentence)
    for (stack, buffer, _graph), target in zip(states, targets):
        stack_words = list(map(lookup.__getitem__, stack.stack))
        buffer_words = list(map(lookup.__getitem__, buffer))
        print(stack_words, buffer_words, target)


In [261]:
view_states(df.iloc[0]['sentence'], train_states, train_targets)

['root'] ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'] shift
['root', 'Al'] ['-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'] shift
['root', 'Al', '-'] ['Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'] right-arc
['root', 'Al'] ['Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'] shift
['root', 'Al', 'Zaman'] [':', 'American',

In [11]:
class Oracle:
    """Stub oracle that replays a fixed, hard-coded sequence of parser actions."""

    def __init__(self, context):
        """Store ``context`` and reset the scripted action sequence to its start."""
        self.context = context

        # Action names use the 'left-arc' / 'right-arc' spelling that
        # unravel_sentence emits as targets, so the two can be compared directly.
        self.seq = ['shift', 'shift', 'shift', 'right-arc', 'left-arc', 'shift']
        self.idx = 0

    def query(self, stack):
        """Return the next scripted action; ``stack`` is accepted but not consulted.

        Raises:
            IndexError: if every scripted action has already been returned.
        """
        if self.idx >= len(self.seq):
            raise IndexError('oracle action sequence exhausted')
        action = self.seq[self.idx]
        self.idx += 1
        return action


In [12]:
oracle = Oracle(None)


In [None]:
# # def apply_action
# #word_list
# #stack

# parse = [] # or a nxn adj matrix


# if action == 'shift':
#     token = word_list.pop()
#     stack.add(token)
    
# elif action == 'left-arc':
#     parent = stack.view1()
#     child = stack.pop2()

# elif action == 'right-arc':
#     parent = stack.view2()
#     child = stack.pop1()
    
#     # parent -> child
    
    

In [None]:
# class LabeledRelation:
    
#     if relation == 'left-arc'
#         ('left-arc', 'nobj')
#     if relation == 'shift':
#         ('shift', None)