In [1]:
import numpy as np
import pandas as pd


with open('en_ewt-ud-train.conllu') as f:
    lines = f.readlines()
    
with open('en_ewt-ud-dev.conllu') as f:
    lines_dev = f.readlines()


In [40]:
def read_syntax_tree(lines):
    sentences = []
    heads = []
    deprels = []
    
    count_word = False
    for l in lines:
        if l[:6] == '# text':
            count_word = True
            sentence = []
            head = []
            deprel = []
        elif count_word:
            if l[0] != '\n':
                cells = l.split('\t')
                
                # skip inferred words that are not part of the original sentence
                if '.1' in cells[0]:
                    continue
                    
                sentence.append(cells[1])
                head.append(int(cells[6]))
                deprel.append(cells[7])
            else:
                count_word = False
                sentences.append(sentence)
                heads.append(head)
                deprels.append(deprel)
        else:
            continue
    return pd.DataFrame({'sentence': sentences, 'head': heads, 'deprel': deprels})

df = read_syntax_tree(lines)
            

In [101]:
class TwoStack:
    def __init__(self):
        self.stack = ['root']
        
    def add(self, x):
        self.stack.append(x)
        
    def view1(self):
        if len(self.stack) == 0:
            return None
        return self.stack[-1]
    
    def view2(self):
        if len(self.stack) <= 1:
            return None
        return self.stack[-2]
    
    def pop1(self):
        if len(self.stack) == 0:
            return None
        return self.stack.pop(-1)
        
    def pop2(self):
        if len(self.stack) <= 1:
            return None
        return self.stack.pop(-2)

    def __bool__(self):
        return len(self.stack) > 0


class RelationGraph:
    def __init__(self, sentence, head, deprel):
        self.sentence = sentence
        self.head = head
        self.deprel = deprel
        
        # lookup for deprel entry to matrix value
        self.dep_to_idx = {
            dep: i for i, dep in enumerate(set(deprel))
        }
        # lookup for matrix value to deprel type
        self.idx_to_dep = {
            v: k for k, v in self.dep_to_idx.items()
        }

        n = len(sentence)
        self.mat = np.zeros((n + 1, n + 1))        # 0th entry is root
        for child, parent in enumerate(head):
            self.mat[parent, child + 1] = self.dep_to_idx[deprel[child]]     
    
    def contains(self, from_idx, to_idx):
        return self.mat[from_idx, to_idx] > 0
    
    def get_children(self, idx):
        return list(np.where(self.mat[idx] > 0)[0])
    
    def get_parent(self, idx):
        return np.where(self.mat[:, idx] > 0)[0][0]
    
    def get_word(self, idx):
        return self.sentence[idx]
    

In [81]:
rg = RelationGraph(df.iloc[2]['sentence'], df.iloc[2]['head'], df.iloc[2]['deprel'])

In [82]:
rg.dep_to_idx

{'acl': 0,
 'aux': 1,
 'ccomp': 2,
 'case': 3,
 'amod': 4,
 'mark': 5,
 'obj': 6,
 'root': 7,
 'obl': 8,
 'parataxis': 9,
 'nsubj': 10,
 'punct': 11,
 'nummod': 12,
 'compound:prt': 13}

In [83]:
df.iloc[2]['head']

[0, 1, 4, 5, 1, 9, 9, 9, 5, 9, 13, 13, 9, 13, 16, 14, 1]

In [97]:
rg.get_children(5)
rg.get_word(0)

'DPA'

In [None]:
stack = TwoStack()
word_list = df.iloc[0]['sentence']
graph = RelationGraph(df.iloc[0]['sentence'], df.iloc[0]['head'], df.iloc[0]['deprel'])

# state = (stack, word_list, graph)

# unravel sentence

# make train set

# sentence    [Al, -, Zaman, :, American, forces, killed, Sh...
# head        [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8,...
# deprel      [root, punct, flat, punct, amod, nsubj, parata...
import copy

train_states = []
train_targets = []


# while stack:
stack.add(word_list.pop())

state = (copy.copy(stack), word_list, graph)


# check for left-arc
if graph.contains(stack.view1(), stack.view2()):
    train_targets.append('left-arc')
    
    
# check for right-arc
    
# else shift



In [8]:
tmp = df.iloc[0]
tmp

sentence    [Al, -, Zaman, :, American, forces, killed, Sh...
head        [0, 1, 1, 1, 6, 7, 1, 7, 8, 8, 8, 8, 8, 15, 8,...
deprel      [root, punct, flat, punct, amod, nsubj, parata...
Name: 0, dtype: object

In [10]:
# tmp['sentence']

In [11]:
class Oracle:
    def __init__(self, context):
        self.context = context
    
        self.seq = ['shift', 'shift', 'shift', 'arc-right', 'arc-left', 'shift']
        self.idx = 0
    
    def query(self, stack):
        action = self.seq[self.idx]
        self.idx += 1
        return action


In [12]:
oracle = Oracle(None)


In [None]:
# def apply_action
#word_list
#stack

parse = [] # or a nxn adj matrix


if action == 'shift':
    token = word_list.pop()
    stack.add(token)
    
elif action == 'left-arc':
    parent = stack.view1()
    child = stack.pop2()

elif action == 'right-arc':
    parent = stack.view2()
    child = stack.pop1()
    
    # parent -> child
    
    

In [None]:
class RunState:
    self.stack 
    
    self.list
    
    self.parse
    
    # def __init__

In [None]:
class LabeledRelation:
    
    if relation == 'left-arc'
        ('left-arc', 'nobj')
    if relation == 'shift':
        ('shift', None)