In [1]:
import numpy as np
import pandas as pd


# CoNLL-U files are UTF-8 by specification; pass the encoding explicitly so
# reading does not depend on the platform's default locale encoding.
with open('en_ewt-ud-train.conllu', encoding='utf-8') as f:
    lines = f.readlines()

with open('en_ewt-ud-dev.conllu', encoding='utf-8') as f:
    lines_dev = f.readlines()


In [40]:
def read_syntax_tree(lines):
    """Parse CoNLL-U lines into one DataFrame row per sentence.

    A sentence starts at its ``# text`` comment and ends at the next blank
    line (or at end of input). For every syntactic word the FORM (column 2),
    HEAD (column 7, converted to int) and DEPREL (column 8) are collected.

    Multiword-token range lines (ID like ``3-4``) and empty nodes (ID like
    ``8.1``, ``8.2``) are skipped: they are not part of the original word
    sequence of the basic tree, and their HEAD field is ``_``, which ``int``
    cannot parse. Other ``#`` comment lines inside a sentence are ignored.

    Parameters
    ----------
    lines : iterable of str
        Raw lines of a CoNLL-U file, e.g. from ``f.readlines()``.

    Returns
    -------
    pandas.DataFrame
        Columns ``sentence`` (list of str), ``head`` (list of int, 0 = root)
        and ``deprel`` (list of str), one row per sentence.
    """
    sentences = []
    heads = []
    deprels = []

    def _close_sentence():
        # store the sentence currently being collected
        sentences.append(sentence)
        heads.append(head)
        deprels.append(deprel)

    in_sentence = False
    sentence, head, deprel = [], [], []
    for line in lines:
        if line.startswith('# text'):
            in_sentence = True
            sentence, head, deprel = [], [], []
        elif not in_sentence or line.startswith('#'):
            # outside a sentence, or another comment line inside one
            continue
        elif line.strip():
            cells = line.split('\t')
            token_id = cells[0]
            # skip multiword-token ranges ("3-4") and empty nodes ("8.1")
            if '-' in token_id or '.' in token_id:
                continue
            sentence.append(cells[1])
            head.append(int(cells[6]))
            deprel.append(cells[7])
        else:
            # blank line terminates the sentence
            in_sentence = False
            _close_sentence()

    # input that does not end with a blank line still has a pending sentence
    if in_sentence:
        _close_sentence()

    return pd.DataFrame({'sentence': sentences, 'head': heads, 'deprel': deprels})

# Parse the training split: one row per sentence with its words, heads and labels.
df = read_syntax_tree(lines=lines)
            

In [206]:
class TwoStack:
    """List-backed stack that exposes its top two slots.

    ``view1``/``pop1`` address the top element and ``view2``/``pop2`` the
    element directly beneath it. Every accessor returns ``None`` instead of
    raising when the requested slot does not exist.
    """

    def __init__(self):
        self.stack = []

    def add(self, x):
        """Push ``x`` on top."""
        self.stack.append(x)

    def view1(self):
        """Top element, or ``None`` if the stack is empty."""
        return self.stack[-1] if self.stack else None

    def view2(self):
        """Second element from the top, or ``None`` if fewer than two."""
        return self.stack[-2] if len(self.stack) >= 2 else None

    def pop1(self):
        """Remove and return the top element (``None`` if empty)."""
        return self.stack.pop() if self.stack else None

    def pop2(self):
        """Remove and return the second element from the top (``None`` if fewer than two)."""
        return self.stack.pop(-2) if len(self.stack) >= 2 else None

    def __bool__(self):
        return bool(self.stack)

    def __len__(self):
        return len(self.stack)

    def __str__(self):
        return str(self.stack)


class RelationGraph:
    """Dependency arcs stored as an ``(n + 1) x (n + 1)`` adjacency matrix.

    Index 0 is the artificial root and the ``i``-th word (0-based) of the
    sentence has index ``i + 1``. ``mat[parent, child] > 0`` means there is
    an arc ``parent -> child``; 0 means no arc.

    A graph filled with ``set_true_labels`` holds the gold tree and is
    read-only afterwards; ``add_relation`` is used to grow a predicted graph.
    """

    def __init__(self, n):
        # n is the number of words; the extra row/column is the root
        self.mat = np.zeros((n + 1, n + 1))
        # label <-> integer lookups, filled by set_true_labels
        self.dep_to_idx = {}
        self.idx_to_dep = {}
        # becomes True once gold arcs are loaded; blocks add_relation
        self.true_graph = False

    def set_true_labels(self, head, deprel):
        """Load a gold tree and freeze the graph.

        Parameters
        ----------
        head : sequence of int
            ``head[i]`` is the head index of word ``i + 1`` (0 = root).
        deprel : sequence of hashable
            ``deprel[i]`` is the relation label of word ``i + 1``.
        """
        # This is temporary: eventually use one universal label lookup built
        # from the training set so indices are consistent across sentences.

        # Number labels from 1 in first-occurrence order: the order is
        # deterministic (iterating a set of strings is not, because of hash
        # randomization) and 0 stays reserved for "no arc" in self.mat.
        self.dep_to_idx = {
            dep: i for i, dep in enumerate(dict.fromkeys(deprel), start=1)
        }
        # reverse lookup: matrix value -> deprel label
        self.idx_to_dep = {v: k for k, v in self.dep_to_idx.items()}

        for child, parent in enumerate(head):
            # arcs are stored unlabeled (value 1); labeled alternative:
            # self.mat[parent, child + 1] = self.dep_to_idx[deprel[child]]
            self.mat[parent, child + 1] = 1

        self.true_graph = True

    def add_relation(self, from_idx, to_idx, deprel=1):
        """Add the arc ``from_idx -> to_idx`` with value ``deprel``.

        Values <= 0 are not reported by ``contains``. Raises ``Exception`` if
        this graph holds gold arcs.
        """
        if self.true_graph:
            raise Exception('Cannot modify true labeled graph')
        self.mat[from_idx, to_idx] = deprel

    def contains(self, from_idx, to_idx):
        """True if the arc ``from_idx -> to_idx`` exists."""
        return self.mat[from_idx, to_idx] > 0

    def get_children(self, idx):
        """Indices of all dependents of ``idx``, in ascending order."""
        return list(np.where(self.mat[idx] > 0)[0])

    def get_parent(self, idx):
        """Head index of ``idx``; raises IndexError if ``idx`` has no head."""
        return np.where(self.mat[:, idx] > 0)[0][0]

    def __str__(self):
        # e.g. "{0->1,1->2,}" listing arcs in row-major order
        arcs = ''.join(f'{a}->{b},' for a, b in zip(*np.where(self.mat > 0)))
        return '{' + arcs + '}'


In [208]:
def word_map(sentence):
    """Map matrix/stack indices to words: 0 -> 'root', i -> i-th word (1-based)."""
    idx_to_word = {0: 'root'}
    idx_to_word.update(enumerate(sentence, start=1))
    return idx_to_word

# Demo on the SLP3 example sentence. A literal is used because the global
# `sentence` is only assigned in a later cell, so referencing it here would
# raise NameError under Restart & Run All.
word_map(['book', 'me', 'the', 'morning', 'flight'])

{0: 'root', 1: 'book', 2: 'me', 3: 'the', 4: 'morning', 5: 'flight'}

In [209]:
import copy

In [210]:
# Worked example from the SLP3 textbook, chapter 14.
# alternative sentence: ['book', 'the', 'flight', 'through', 'houston']

sentence = 'book me the morning flight'.split()

# head index of each word (1-based word positions; 0 = root)
head = [0, 1, 5, 5, 1]

# placeholder relation labels, one per word
deprel = [1] * len(sentence)

# expected transition sequence for this sentence (without the final 'done')
target = [
    'shift', 'shift', 'right-arc',   # book, me; book -> me
    'shift', 'shift', 'shift',       # the, morning, flight
    'left-arc', 'left-arc',          # flight -> morning, flight -> the
    'right-arc', 'right-arc',        # book -> flight, root -> book
]

In [244]:

def unravel_sentence(sentence, head, deprel, verbose=True):
    """Replay the arc-standard oracle on a gold tree and record every step.

    Starting from a stack holding only the root (index 0) and a buffer of
    word indices ``1..n``, repeatedly take the transition the gold tree
    licenses:

    * ``left-arc``  -- the top of the stack is the gold head of the word
      beneath it: attach it and pop the lower word;
    * ``right-arc`` -- the word beneath the top is the gold head of the top
      and every gold dependent of the top is already attached: attach it and
      pop the top;
    * ``shift``     -- otherwise move the next buffer word onto the stack;
    * ``done``      -- only the root is left and the buffer is empty.

    Parameters
    ----------
    sentence : list of str
        Tokens; the ``i``-th token (0-based) has index ``i + 1``.
    head : list of int
        Gold head index for each token (0 = root).
    deprel : list
        Gold relation label for each token.
    verbose : bool, default True
        Print the configuration (stack, buffer, partial graph) before each
        transition.

    Returns
    -------
    (list, list)
        ``train_states``: one ``(stack, buffer, graph)`` snapshot per step,
        taken *before* the transition; ``train_targets``: the transition name
        chosen at that step.

    Raises
    ------
    ValueError
        If the buffer is exhausted while no arc applies, i.e. the gold tree
        cannot be derived with arc-standard transitions (typically because it
        is non-projective).
    """
    true_graph = RelationGraph(len(sentence))
    true_graph.set_true_labels(head, deprel)

    stack = TwoStack()
    word_list = list(range(1, len(sentence) + 1))
    graph = RelationGraph(len(sentence))

    train_states = []
    train_targets = []

    # the stack always starts with just the root
    stack.add(0)

    while stack:
        # Snapshot the configuration before acting; copies are required since
        # stack, word_list and graph are mutated in place below.
        state = (copy.deepcopy(stack), copy.copy(word_list), copy.deepcopy(graph))

        if verbose:
            print(stack, word_list, graph)

        if len(stack) == 1:
            if word_list:
                stack.add(word_list.pop(0))
                action = 'shift'
            else:
                stack.pop1()
                action = 'done'
        else:
            top, below = stack.view1(), stack.view2()
            if true_graph.contains(top, below):
                graph.add_relation(top, below)
                stack.pop2()
                action = 'left-arc'
            elif true_graph.contains(below, top) and all(
                graph.contains(top, child) for child in true_graph.get_children(top)
            ):
                # only reduce `top` once all of its gold dependents are attached
                graph.add_relation(below, top)
                stack.pop1()
                action = 'right-arc'
            elif word_list:
                stack.add(word_list.pop(0))
                action = 'shift'
            else:
                raise ValueError(
                    f'no transition applies with stack {stack} and an empty buffer; '
                    'the gold tree is likely non-projective'
                )

        train_states.append(state)
        train_targets.append(action)

    return train_states, train_targets


In [245]:
# The trailing `;` suppresses display of the cell's return value; the oracle
# trace below is printed by the function itself.
unravel_sentence(sentence, head, deprel);

[0] [1, 2, 3, 4, 5] {}
[0, 1] [2, 3, 4, 5] {}
[0, 1, 2] [3, 4, 5] {}
[0, 1] [3, 4, 5] {1->2,}
[0, 1, 3] [4, 5] {1->2,}
[0, 1, 3, 4] [5] {1->2,}
[0, 1, 3, 4, 5] [] {1->2,}
[0, 1, 3, 5] [] {1->2,5->4,}
[0, 1, 5] [] {1->2,5->3,5->4,}
[0, 1] [] {1->2,1->5,5->3,5->4,}
[0] [] {0->1,1->2,1->5,5->3,5->4,}


In [246]:
def view_states(sentence, states, targets):
    """Print every recorded oracle step using words instead of indices.

    Each element of ``states`` is a ``(stack, buffer, graph)`` snapshot, as
    built by ``unravel_sentence``; the graph is not shown. One line is printed
    per step: stack words, buffer words, then the chosen transition.
    """
    lookup = word_map(sentence)
    for (snap_stack, snap_buffer, _graph), action in zip(states, targets):
        print(
            [lookup[i] for i in snap_stack.stack],
            [lookup[i] for i in snap_buffer],
            action,
        )


In [247]:
# NOTE(review): no cell above assigns `train_states` / `train_targets` at
# global scope (inside unravel_sentence they are local variables), so this
# cell relies on leftover kernel state and fails under Restart & Run All.
view_states(sentence=sentence, states=train_states, targets=train_targets)

['root'] ['book', 'me', 'the', 'morning', 'flight'] shift
['root', 'book'] ['me', 'the', 'morning', 'flight'] shift
['root', 'book', 'me'] ['the', 'morning', 'flight'] right-arc
['root', 'book'] ['the', 'morning', 'flight'] shift
['root', 'book', 'the'] ['morning', 'flight'] shift
['root', 'book', 'the', 'morning'] ['flight'] shift
['root', 'book', 'the', 'morning', 'flight'] [] left-arc
['root', 'book', 'the', 'flight'] [] left-arc
['root', 'book', 'flight'] [] right-arc
['root', 'book'] [] right-arc
['root'] [] done


In [11]:
class Oracle:
    """Placeholder oracle that replays a fixed transition sequence.

    ``query`` ignores the configuration it is given and returns the next
    scripted action. Action names use the spelling used elsewhere in this
    notebook: 'shift', 'left-arc', 'right-arc'.
    """

    # default stub sequence kept from the original prototype
    DEFAULT_SEQ = ('shift', 'shift', 'shift', 'right-arc', 'left-arc', 'shift')

    def __init__(self, context, seq=None):
        """
        Parameters
        ----------
        context : any
            Stored as ``self.context``; not used yet.
        seq : iterable of str, optional
            Actions to replay; defaults to ``DEFAULT_SEQ``. Copied, so later
            changes to the caller's sequence do not affect the oracle.
        """
        self.context = context
        self.seq = list(self.DEFAULT_SEQ if seq is None else seq)
        self.idx = 0

    def query(self, stack):
        """Return the next scripted action; ``stack`` is currently ignored.

        Raises IndexError once the sequence is exhausted.
        """
        action = self.seq[self.idx]
        self.idx += 1
        return action


In [12]:
# Stub oracle; no context object is needed yet.
oracle = Oracle(context=None)


In [None]:
# # def apply_action
# #word_list
# #stack

# parse = [] # or a nxn adj matrix


# if action == 'shift':
#     token = word_list.pop(0)  # shift takes the FRONT of the buffer
#     stack.add(token)
    
# elif action == 'left-arc':
#     parent = stack.view1()
#     child = stack.pop2()

# elif action == 'right-arc':
#     parent = stack.view2()
#     child = stack.pop1()
    
#     # parent -> child
    
    

In [None]:
# class LabeledRelation:
    
#     if relation == 'left-arc':
#         ('left-arc', 'nobj')
#     if relation == 'shift':
#         ('shift', None)