In [3]:
import copy
import ast

import torch
import torch.nn as nn

import numpy as np
from datasets import load_dataset, get_dataset_config_names
from transformers import BertTokenizer, BertModel

import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [69]:
torch.random.manual_seed(34)

<torch._C.Generator at 0x28f7eebb530>

In [4]:
name = "universal_dependencies"
ud_config = get_dataset_config_names(name)
ud_ewt_train = load_dataset(name, 'en_ewt', split="train")
ud_ewt_dev = load_dataset(name, 'en_ewt', split="validation")
ud_ewt_test = load_dataset(name, 'en_ewt', split="test")

In [95]:
#ud_ewt_train[34] is good default
ud_ewt_train[34]['tokens']
ud_ewt_train[34]['upos'] #Part of speech tag as an integer index of ["NOUN","PUNCT","ADP","NUM","SYM","SCONJ","ADJ","PART","DET","CCONJ","PROPN","PRON","X","_","ADV","INTJ","VERB","AUX"]
upos_map = ["NOUN","PUNCT","ADP","NUM","SYM","SCONJ","ADJ","PART","DET","CCONJ","PROPN","PRON","X","_","ADV","INTJ","VERB","AUX"]
ud_ewt_train[34]['xpos'] #Other POS tag, might be more accurate?
ud_ewt_train[34]['deprel'] #arc labels by themselves. https://universaldependencies.org/docs/u/dep/index.html
ud_ewt_train[34]['deps'] #arc labels and the index they interact with. NOTE: index starts at one, ROOT is assumed
ud_ewt_train[34]

{'idx': 'weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0035',
 'text': 'The situation in Iraq is only going to get better this way.',
 'tokens': ['The',
  'situation',
  'in',
  'Iraq',
  'is',
  'only',
  'going',
  'to',
  'get',
  'better',
  'this',
  'way',
  '.'],
 'lemmas': ['the',
  'situation',
  'in',
  'Iraq',
  'be',
  'only',
  'go',
  'to',
  'get',
  'better',
  'this',
  'way',
  '.'],
 'upos': [8, 0, 2, 10, 17, 14, 16, 7, 16, 6, 8, 0, 1],
 'xpos': ['DT',
  'NN',
  'IN',
  'NNP',
  'VBZ',
  'RB',
  'VBG',
  'TO',
  'VB',
  'JJR',
  'DT',
  'NN',
  '.'],
 'feats': ["{'Definite': 'Def', 'PronType': 'Art'}",
  "{'Number': 'Sing'}",
  'None',
  "{'Number': 'Sing'}",
  "{'Mood': 'Ind', 'Number': 'Sing', 'Person': '3', 'Tense': 'Pres', 'VerbForm': 'Fin'}",
  'None',
  "{'Tense': 'Pres', 'VerbForm': 'Part'}",
  'None',
  "{'VerbForm': 'Inf'}",
  "{'Degree': 'Cmp'}",
  "{'Number': 'Sing', 'PronType': 'Dem'}",
  "{'Number': 'Sing'}",
  'None'],
 'head': ['2', '

In [156]:
#test sentence from exmaple in chen and manning
test_sent = {'text': 'He has really good control.',
             'tokens': ['He', 'has', 'good', 'control', '.'],
             'upos': [11, 16, 6, 0, 1], #probably need to use xpos,
             'xpos': ['PRP', 'VBZ', 'JJ', 'NN', '.'],
             'deprel': ['nsubj', 'root', 'amod', 'dobj', 'punct'],
             'deps': ["[('nsubj', 2)]", "[('root', 0)]", "[('amod', 4)]", "[('dobj', 2)]", "[('punct', 2)]"] #these are strings for some reason
             }

In [241]:
#LEFT-ARC(l): adds an arc s1 → s2 with label l and removes s2 from the stack. Precondition: |s| ≥ 2. from Chen and Manning
def left_arc(stack, buffer, arcs, dep, print_output=True):
    if len(stack) < 2:
        print("[@] LEFT-ARC called incorrectly, check stack size")
        # return
    elif stack[-2] == "[ROOT]":
        print("[@] LEFT-ARC called incorrectly, tried to add depepndency to ROOT")
    else:
        arcs.append((dep, (stack[-1],stack.pop(-2))))
    if print_output:
        print("Stack: ", stack, end=" | ")
        print("Buffer: ", buffer, end=" | ")
        print("Arcs:", arcs)

In [242]:
#RIGHT-ARC(l): adds an arc s2 → s1 with label l and removes s1 from the stack. Precondition: |s| ≥ 2. from Chen and Manning
def right_arc(stack, buffer, arcs, dep, print_output=True):
    if len(stack) < 2:
        print("[@] RIGHT-ARC called incorrectly, check stack size")
        # return
    else:
        arcs.append((dep, (stack[-2],stack.pop(-1))))
    if print_output:
        print("Stack: ", stack, end=" | ")
        print("Buffer: ", buffer, end=" | ")
        print("Arcs:", arcs)
    

In [243]:
# SHIFT: moves b1 from the buffer to the stack. Precondition: |b| ≥ 1. from Chen and Manning
def shift(stack, buffer, arcs, print_output=True):
    if len(buffer) < 1:
        print("[@] SHIFT called incorrectly, check buffer size")
        # return
    else:
        stack.append(buffer.pop(0))
    if print_output:
        print("Stack: ", stack, end=" | ")
        print("Buffer: ", buffer, end=" | ")
        print("Arcs:", arcs)

In [244]:
#Goal is to predict the correct transitions at each step aka predict the deps and timing of deps correctly
#Sample of "correct" transitions for "He has good control."
test_stack = ["[ROOT]"]
test_buffer = copy.deepcopy(test_sent['tokens'])
# buffer = [(test_sent['tokens'][i], test_sent['tokens'][i]) for i in range(len(test_sent['tokens']))] #include xpos in buffer in the future
test_arcs = []
shift(test_stack, test_buffer, test_arcs)
shift(test_stack, test_buffer, test_arcs)
left_arc(test_stack, test_buffer, test_arcs, 'nsubj')
shift(test_stack, test_buffer, test_arcs)
shift(test_stack, test_buffer, test_arcs)
left_arc(test_stack, test_buffer, test_arcs, 'amod')
right_arc(test_stack, test_buffer, test_arcs, 'dobj')
shift(test_stack, test_buffer, test_arcs)
right_arc(test_stack, test_buffer, test_arcs, 'punct')
right_arc(test_stack, test_buffer, test_arcs, 'root')

Stack:  ['[ROOT]', 'He'] | Buffer:  ['has', 'good', 'control', '.'] | Arcs: []
Stack:  ['[ROOT]', 'He', 'has'] | Buffer:  ['good', 'control', '.'] | Arcs: []
Stack:  ['[ROOT]', 'has'] | Buffer:  ['good', 'control', '.'] | Arcs: [('nsubj', ('has', 'He'))]
Stack:  ['[ROOT]', 'has', 'good'] | Buffer:  ['control', '.'] | Arcs: [('nsubj', ('has', 'He'))]
Stack:  ['[ROOT]', 'has', 'good', 'control'] | Buffer:  ['.'] | Arcs: [('nsubj', ('has', 'He'))]
Stack:  ['[ROOT]', 'has', 'control'] | Buffer:  ['.'] | Arcs: [('nsubj', ('has', 'He')), ('amod', ('control', 'good'))]
Stack:  ['[ROOT]', 'has'] | Buffer:  ['.'] | Arcs: [('nsubj', ('has', 'He')), ('amod', ('control', 'good')), ('dobj', ('has', 'control'))]
Stack:  ['[ROOT]', 'has', '.'] | Buffer:  [] | Arcs: [('nsubj', ('has', 'He')), ('amod', ('control', 'good')), ('dobj', ('has', 'control'))]
Stack:  ['[ROOT]', 'has'] | Buffer:  [] | Arcs: [('nsubj', ('has', 'He')), ('amod', ('control', 'good')), ('dobj', ('has', 'control')), ('punct', ('has

In [245]:
test_arcs

[('nsubj', ('has', 'He')),
 ('amod', ('control', 'good')),
 ('dobj', ('has', 'control')),
 ('punct', ('has', '.')),
 ('root', ('[ROOT]', 'has'))]

In [246]:
#compare to arcs in UD format
# ["[('nsubj', 2)]", "[('root', 0)]", "[('amod', 4)]", "[('dobj', 2)]", "[('punct', 2)]"]
valid_arcs = []
for i in range(len(test_sent['deps'])):
    curr_dep = ast.literal_eval(test_sent['deps'][i])[0]
    valid_arcs.append((curr_dep[0],(((["[ROOT]"] + test_sent['tokens'])[curr_dep[1]]), test_sent['tokens'][i])))
valid_arcs

[('nsubj', ('has', 'He')),
 ('root', ('[ROOT]', 'has')),
 ('amod', ('control', 'good')),
 ('dobj', ('has', 'control')),
 ('punct', ('has', '.'))]

In [247]:
test_sent2 = ud_ewt_train[34]
valid_arcs2 = []
for i in range(len(test_sent2['deps'])):
    curr_dep = ast.literal_eval(test_sent2['deps'][i])[0]
    valid_arcs2.append((test_sent2['deprel'][i],((((["[ROOT]"] + test_sent2['tokens'])[curr_dep[1]]), (["ROOT"] + test_sent2['xpos'])[curr_dep[1]]), (test_sent2['tokens'][i], test_sent2['xpos'][i])))) #use deprel for the arc name. ignoring enhanced depependencies for now
valid_arcs2

[('det', (('situation', 'NN'), ('The', 'DT'))),
 ('nsubj', (('going', 'VBG'), ('situation', 'NN'))),
 ('case', (('Iraq', 'NNP'), ('in', 'IN'))),
 ('nmod', (('situation', 'NN'), ('Iraq', 'NNP'))),
 ('aux', (('going', 'VBG'), ('is', 'VBZ'))),
 ('advmod', (('going', 'VBG'), ('only', 'RB'))),
 ('root', (('[ROOT]', 'ROOT'), ('going', 'VBG'))),
 ('mark', (('get', 'VB'), ('to', 'TO'))),
 ('xcomp', (('going', 'VBG'), ('get', 'VB'))),
 ('xcomp', (('get', 'VB'), ('better', 'JJR'))),
 ('det', (('way', 'NN'), ('this', 'DT'))),
 ('obj', (('better', 'JJR'), ('way', 'NN'))),
 ('punct', (('going', 'VBG'), ('.', '.')))]

In [248]:
all([i in valid_arcs for i in test_arcs])

True

In [265]:
#To generate training data, need to reverse the valid arcs into operations which produce them
#The "training oracle" as described in slp3 which determines which transition to do. return a list of transitions
#sentence is UD format sentence
#Stack/Buffer format: List of tuple with arc label and tuple of (word, POS) tuples in parent, child order
#[ ( arc-label, ( (parent-word, parent-POS), (child-word, child-POS) ) ), ... ]
def training_oracle(sentence):
    stack = [("[ROOT]", "ROOT")]
    # buffer = copy.deepcopy(sentence['tokens'])
    buffer = [(sentence['tokens'][i], sentence['xpos'][i]) for i in range(len(sentence['tokens']))]
    arcs = []
    transitions = [] #what we return
    labeled_arcs = []
    for i in range(len(sentence['deps'])):
        curr_dep = ast.literal_eval(sentence['deps'][i])[0]
        labeled_arcs.append((sentence['deprel'][i],((((["[ROOT]"] + sentence['tokens'])[curr_dep[1]]), (["ROOT"] + sentence['xpos'])[curr_dep[1]]), (sentence['tokens'][i], sentence['xpos'][i]))))
    labeled_arcs_copy = copy.deepcopy(labeled_arcs)
    unlabeled_arcs = [i[1] for i in labeled_arcs]
    # print(labeled_arcs_copy)
    for i in range(2*len(sentence['tokens'])): #2N transitions
        if len(stack) >= 2:
            # print([j[0] for j in unlabeled_arcs])
            if (stack[-1],stack[-2]) in unlabeled_arcs:
                arc_label = labeled_arcs.pop(unlabeled_arcs.index((stack[-1],stack[-2])))[0]
                transitions.append("left-arc " + arc_label)
                unlabeled_arcs.remove((stack[-1],stack[-2]))
                left_arc(stack, buffer, arcs, arc_label)
            elif (stack[-2],stack[-1]) in unlabeled_arcs and stack[-1] not in [j[0] for j in unlabeled_arcs]: #all of the dependents of the word at the top of the stack must already be assign before right arc,
                arc_label = labeled_arcs.pop(unlabeled_arcs.index((stack[-2],stack[-1])))[0]
                transitions.append("right-arc " + arc_label)
                unlabeled_arcs.remove((stack[-2],stack[-1]))
                right_arc(stack, buffer, arcs, arc_label)
        if len(transitions) <= i: #if neither arc transition has been done do shift
            transitions.append("shift")
            shift(stack, buffer, arcs)
    print("All generated arcs are in original deps list:", all([i in labeled_arcs_copy for i in arcs]))
    return transitions

In [266]:
training_oracle(test_sent)

Stack:  [('[ROOT]', 'ROOT'), ('He', 'PRP')] | Buffer:  [('has', 'VBZ'), ('good', 'JJ'), ('control', 'NN'), ('.', '.')] | Arcs: []
Stack:  [('[ROOT]', 'ROOT'), ('He', 'PRP'), ('has', 'VBZ')] | Buffer:  [('good', 'JJ'), ('control', 'NN'), ('.', '.')] | Arcs: []
Stack:  [('[ROOT]', 'ROOT'), ('has', 'VBZ')] | Buffer:  [('good', 'JJ'), ('control', 'NN'), ('.', '.')] | Arcs: [('nsubj', (('has', 'VBZ'), ('He', 'PRP')))]
Stack:  [('[ROOT]', 'ROOT'), ('has', 'VBZ'), ('good', 'JJ')] | Buffer:  [('control', 'NN'), ('.', '.')] | Arcs: [('nsubj', (('has', 'VBZ'), ('He', 'PRP')))]
Stack:  [('[ROOT]', 'ROOT'), ('has', 'VBZ'), ('good', 'JJ'), ('control', 'NN')] | Buffer:  [('.', '.')] | Arcs: [('nsubj', (('has', 'VBZ'), ('He', 'PRP')))]
Stack:  [('[ROOT]', 'ROOT'), ('has', 'VBZ'), ('control', 'NN')] | Buffer:  [('.', '.')] | Arcs: [('nsubj', (('has', 'VBZ'), ('He', 'PRP'))), ('amod', (('control', 'NN'), ('good', 'JJ')))]
Stack:  [('[ROOT]', 'ROOT'), ('has', 'VBZ')] | Buffer:  [('.', '.')] | Arcs: [('n

['shift',
 'shift',
 'left-arc nsubj',
 'shift',
 'shift',
 'left-arc amod',
 'right-arc dobj',
 'shift',
 'right-arc punct',
 'right-arc root']

In [267]:
training_oracle(ud_ewt_train[34])

Stack:  [('[ROOT]', 'ROOT'), ('The', 'DT')] | Buffer:  [('situation', 'NN'), ('in', 'IN'), ('Iraq', 'NNP'), ('is', 'VBZ'), ('only', 'RB'), ('going', 'VBG'), ('to', 'TO'), ('get', 'VB'), ('better', 'JJR'), ('this', 'DT'), ('way', 'NN'), ('.', '.')] | Arcs: []
Stack:  [('[ROOT]', 'ROOT'), ('The', 'DT'), ('situation', 'NN')] | Buffer:  [('in', 'IN'), ('Iraq', 'NNP'), ('is', 'VBZ'), ('only', 'RB'), ('going', 'VBG'), ('to', 'TO'), ('get', 'VB'), ('better', 'JJR'), ('this', 'DT'), ('way', 'NN'), ('.', '.')] | Arcs: []
Stack:  [('[ROOT]', 'ROOT'), ('situation', 'NN')] | Buffer:  [('in', 'IN'), ('Iraq', 'NNP'), ('is', 'VBZ'), ('only', 'RB'), ('going', 'VBG'), ('to', 'TO'), ('get', 'VB'), ('better', 'JJR'), ('this', 'DT'), ('way', 'NN'), ('.', '.')] | Arcs: [('det', (('situation', 'NN'), ('The', 'DT')))]
Stack:  [('[ROOT]', 'ROOT'), ('situation', 'NN'), ('in', 'IN')] | Buffer:  [('Iraq', 'NNP'), ('is', 'VBZ'), ('only', 'RB'), ('going', 'VBG'), ('to', 'TO'), ('get', 'VB'), ('better', 'JJR'), ('t

['shift',
 'shift',
 'left-arc det',
 'shift',
 'shift',
 'left-arc case',
 'right-arc nmod',
 'shift',
 'shift',
 'shift',
 'left-arc advmod',
 'left-arc aux',
 'left-arc nsubj',
 'shift',
 'shift',
 'left-arc mark',
 'shift',
 'shift',
 'shift',
 'left-arc det',
 'right-arc obj',
 'right-arc xcomp',
 'right-arc xcomp',
 'shift',
 'right-arc punct',
 'right-arc root']

In [268]:
#Featureization. We use sets of elements Sw St and Sl as described in 3.1 of Chen and Manning which can be combined to create features
def featurize_configuration(stack, buffer, arcs):
    S_w = {'s1' : 'NULL', 's2': 'NULL', 's3': 'NULL', 'b1': 'NULL', 'b2': 'NULL', 'b3': 'NULL',
           'lc1s1' : 'NULL', 'lc2s1' : 'NULL', 'lc1s2' : 'NULL', 'lc2s2' : 'NULL', 'rc1s1' : 'NULL', 'rc2s1' : 'NULL', 'rc1s2' : 'NULL', 'rc2s2' : 'NULL', #lc1s1 is leftmost child of s1, lc2s1 is second leftmost
           'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'NULL', 'rc1rc1s2': 'NULL'} #lc1lc1 is leftmost child of leftmost children
    S_t = {'s1' : 'NULL', 's2': 'NULL', 's3': 'NULL', 'b1': 'NULL', 'b2': 'NULL', 'b3': 'NULL',
           'lc1s1' : 'NULL', 'lc2s1' : 'NULL', 'lc1s2' : 'NULL', 'lc2s2' : 'NULL', 'rc1s1' : 'NULL', 'rc2s1' : 'NULL', 'rc1s2' : 'NULL', 'rc2s2' : 'NULL',
           'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'NULL', 'rc1rc1s2': 'NULL'}
    S_l = {'lc1s1' : 'NULL', 'lc2s1' : 'NULL', 'lc1s2' : 'NULL', 'lc2s2' : 'NULL', 'rc1s1' : 'NULL', 'rc2s1' : 'NULL', 'rc1s2' : 'NULL', 'rc2s2' : 'NULL',
           'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'NULL', 'rc1rc1s2': 'NULL'}
    
    #s1 =============================================================
    if len(stack) >= 1:
       S_w['s1'] = stack[-1][0]
       S_t['s1'] = stack[-1][1]
       parent_list = [j[0] for j in [i[1] for i in arcs]]
       arcs_copy = copy.deepcopy(arcs)
       
       #first leftmost child
      #  print(parent_list)
       if stack[-1] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = parent_list.index(stack[-1])
           lc1_arc = arcs_copy[idx]
          #  print(lc1_arc)
           S_w['lc1s1'] = lc1_arc[1][1][0]
           S_t['lc1s1'] = lc1_arc[1][1][1]
           S_l['lc1s1'] = lc1_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
           #leftmost child of leftmost child
           if lc1_arc[1][1] in parent_list:
              idx = parent_list.index(lc1_arc[1][1])
              lc1lc1_arc = arcs_copy[idx]
              S_w['lc1lc1s1'] = lc1lc1_arc[1][1][0]
              S_t['lc1lc1s1'] = lc1lc1_arc[1][1][1]
              S_l['lc1lc1s1'] = lc1lc1_arc[0]
              parent_list.pop(idx)
              arcs_copy.pop(idx)
      #  print(S_w)
       #first rightmost child       
       if stack[-1] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = len(parent_list) - parent_list[-1::-1].index(stack[-1]) - 1 #index of last occurence list slicing nonsense
           rc1_arc = arcs_copy[idx] 
           S_w['rc1s1'] = rc1_arc[1][1][0]
           S_t['rc1s1'] = rc1_arc[1][1][1]
           S_l['rc1s1'] = rc1_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
          #  print(parent_list)
           #rightmost child of rightmost child
           if rc1_arc[1][1] in parent_list:
              idx = len(parent_list) - parent_list[-1::-1].index(rc1_arc[1][1]) - 1
              rc1rc1_arc = arcs_copy[idx]
              S_w['rc1rc1s1'] = rc1rc1_arc[1][1][0]
              S_t['rc1rc1s1'] = rc1rc1_arc[1][1][1]
              S_l['rc1rc1s1'] = rc1rc1_arc[0]
              parent_list.pop(idx)
              arcs_copy.pop(idx)
       
       #second leftmost child
       if stack[-1] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = parent_list.index(stack[-1])
           lc2_arc = arcs_copy[idx]
           S_w['lc2s1'] = lc2_arc[1][1][0]
           S_t['lc2s1'] = lc2_arc[1][1][1]
           S_l['lc2s1'] = lc2_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
       
       #second rightmost child       
       if stack[-1] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = len(parent_list) - parent_list[-1::-1].index(stack[-1]) - 1 #index of last occurence list slicing nonsense
           rc2_arc = arcs_copy[idx] 
           S_w['rc2s1'] = rc2_arc[1][1][0]
           S_t['rc2s1'] = rc2_arc[1][1][1]
           S_l['rc2s1'] = rc2_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
       
       #s2 =================================================================
       if len(stack) >= 2:       
         S_w['s2'] = stack[-2][0]
         S_t['s2'] = stack[-2][1]
           
         #first leftmost child
         if stack[-2] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = parent_list.index(stack[-2])
           lc1_arc = arcs_copy[idx]
       #     print(lc1_arc)
           S_w['lc1s2'] = lc1_arc[1][1][0]
           S_t['lc1s2'] = lc1_arc[1][1][1]
           S_l['lc1s2'] = lc1_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
           #leftmost child of leftmost child
           if lc1_arc[1][1] in parent_list:
              idx = parent_list.index(lc1_arc[1][1])
              lc1lc1_arc = arcs_copy[idx]
              S_w['lc1lc1s2'] = lc1lc1_arc[1][1][0]
              S_t['lc1lc1s2'] = lc1lc1_arc[1][1][1]
              S_l['lc1lc1s2'] = lc1lc1_arc[0]
              parent_list.pop(idx)
              arcs_copy.pop(idx)
       
         #first rightmost child       
         if stack[-2] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = len(parent_list) - parent_list[-1::-1].index(stack[-2]) - 1 #index of last occurence list slicing nonsense
           rc1_arc = arcs_copy[idx] 
           S_w['rc1s2'] = rc1_arc[1][1][0]
           S_t['rc1s2'] = rc1_arc[1][1][1]
           S_l['rc1s2'] = rc1_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
           #rightmost child of rightmost child
           if rc1_arc[1][1] in parent_list:
              idx = len(parent_list) - parent_list[-1::-1].index(rc1_arc[1][1]) - 1
              rc1rc1_arc = arcs_copy[idx]
              S_w['rc1rc1s2'] = rc1rc1_arc[1][1][0]
              S_t['rc1rc1s2'] = rc1rc1_arc[1][1][1]
              S_l['rc1rc1s2'] = rc1rc1_arc[0]
              parent_list.pop(idx)
              arcs_copy.pop(idx)
       
         #second leftmost child
         if stack[-2] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = parent_list.index(stack[-2])
           lc2_arc = arcs_copy[idx]
           S_w['lc2s2'] = lc2_arc[1][1][0]
           S_t['lc2s2'] = lc2_arc[1][1][1]
           S_l['lc2s2'] = lc2_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
       
         #second rightmost child       
         if stack[-2] in parent_list:
           #arcs are added sequentially, so the lowest index child is leftmost/highest index child is rightmost
           idx = len(parent_list) - parent_list[-1::-1].index(stack[-2]) - 1 #index of last occurence list slicing nonsense
           rc2_arc = arcs_copy[idx] 
           S_w['rc2s2'] = rc2_arc[1][1][0]
           S_t['rc2s2'] = rc2_arc[1][1][1]
           S_l['rc2s2'] = rc2_arc[0]
           parent_list.pop(idx)
           arcs_copy.pop(idx)
           
         if len(stack) >= 3:
           S_w['s3'] = stack[-3][0]
           S_t['s3'] = stack[-3][1]

    if len(buffer) >= 1:
       S_w['b1'] = buffer[-1][0]
       S_t['b1'] = buffer[-1][1]
       if len(buffer) >= 2:       
           S_w['b2'] = buffer[-2][0]
           S_t['b2'] = buffer[-2][1]
           if len(buffer) >= 3:
              S_w['b3'] = buffer[-3][0]
              S_t['b3'] = buffer[-3][1]
    
    return S_w, S_t, S_l

In [269]:
#Testing the feature list
test_stack = [("[ROOT]", "NULL")]
test_buffer = [(test_sent['tokens'][i], test_sent['xpos'][i]) for i in range(len(test_sent['tokens']))]
test_arcs = []
shift(test_stack, test_buffer, test_arcs, False)
shift(test_stack, test_buffer, test_arcs, False)
left_arc(test_stack, test_buffer, test_arcs, 'nsubj', False)
shift(test_stack, test_buffer, test_arcs, False)
shift(test_stack, test_buffer, test_arcs, False)
left_arc(test_stack, test_buffer, test_arcs, 'amod', False)
right_arc(test_stack, test_buffer, test_arcs, 'dobj', False)
shift(test_stack, test_buffer, test_arcs, False)
right_arc(test_stack, test_buffer, test_arcs, 'punct')
print(featurize_configuration(test_stack, test_buffer, test_arcs))
# right_arc(test_stack, test_buffer, test_arcs, 'root')

Stack:  [('[ROOT]', 'NULL'), ('has', 'VBZ')] | Buffer:  [] | Arcs: [('nsubj', (('has', 'VBZ'), ('He', 'PRP'))), ('amod', (('control', 'NN'), ('good', 'JJ'))), ('dobj', (('has', 'VBZ'), ('control', 'NN'))), ('punct', (('has', 'VBZ'), ('.', '.')))]
({'s1': 'has', 's2': '[ROOT]', 's3': 'NULL', 'b1': 'NULL', 'b2': 'NULL', 'b3': 'NULL', 'lc1s1': 'He', 'lc2s1': 'control', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': '.', 'rc2s1': 'NULL', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'NULL', 'rc1rc1s2': 'NULL'}, {'s1': 'VBZ', 's2': 'NULL', 's3': 'NULL', 'b1': 'NULL', 'b2': 'NULL', 'b3': 'NULL', 'lc1s1': 'PRP', 'lc2s1': 'NN', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': '.', 'rc2s1': 'NULL', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'NULL', 'rc1rc1s2': 'NULL'}, {'lc1s1': 'nsubj', 'lc2s1': 'dobj', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': 'punct', 'rc2s1': 'NULL', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL

In [271]:
test_stack3 = [('[ROOT]', 'ROOT'), ('going', 'VBG')] 
test_buffer3 = [('.', '.')]
test_arcs3 = [('det', (('situation', 'NN'), ('The', 'DT'))), ('case', (('Iraq', 'NNP'), ('in', 'IN'))), ('nmod', (('situation', 'NN'), ('Iraq', 'NNP'))), ('advmod', (('going', 'VBG'), ('only', 'RB'))), ('aux', (('going', 'VBG'), ('is', 'VBZ'))), ('nsubj', (('going', 'VBG'), ('situation', 'NN'))), ('mark', (('get', 'VB'), ('to', 'TO'))), ('det', (('way', 'NN'), ('this', 'DT'))), ('obj', (('better', 'JJR'), ('way', 'NN'))), ('xcomp', (('get', 'VB'), ('better', 'JJR'))), ('xcomp', (('going', 'VBG'), ('get', 'VB')))]
print(featurize_configuration(test_stack3, test_buffer3, test_arcs3))

({'s1': 'going', 's2': '[ROOT]', 's3': 'NULL', 'b1': '.', 'b2': 'NULL', 'b3': 'NULL', 'lc1s1': 'only', 'lc2s1': 'is', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': 'get', 'rc2s1': 'situation', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'better', 'rc1rc1s2': 'NULL'}, {'s1': 'VBG', 's2': 'ROOT', 's3': 'NULL', 'b1': '.', 'b2': 'NULL', 'b3': 'NULL', 'lc1s1': 'RB', 'lc2s1': 'VBZ', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': 'VB', 'rc2s1': 'NN', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'JJR', 'rc1rc1s2': 'NULL'}, {'lc1s1': 'advmod', 'lc2s1': 'aux', 'lc1s2': 'NULL', 'lc2s2': 'NULL', 'rc1s1': 'xcomp', 'rc2s1': 'nsubj', 'rc1s2': 'NULL', 'rc2s2': 'NULL', 'lc1lc1s1': 'NULL', 'lc1lc1s2': 'NULL', 'rc1rc1s1': 'xcomp', 'rc1rc1s2': 'NULL'})


In [162]:
#get frequency list of training vocab
train_vocab_en = {}
for i in ud_ewt_train:
    for j in i['tokens']:
        if j not in train_vocab_en:
            train_vocab_en[j] = 1
        else:
            train_vocab_en[j] += 1
train_vocab_en = dict(sorted(train_vocab_en.items(), key=lambda item: item[1], reverse=True))

In [216]:
#Use pretrained word embeddings from wikipedia2vec. I recognize that the vocabulary is needlessly large, but why not
file = open("enwiki_20180420_win10_100d.txt", "r", encoding='utf-8')
#nn.embedding uses indexes to map words, so there will be a dictionary of {word:index} and a dictionary of {index:embeddings}
#the _all dicts contain all of the embeddings as initialized, useful for referencing only the prtrained embedding values. tokens not in vocab won't be modified anyways
word_indexes_en = {}
word_indexes_all_en = {}
word_embeddings_en = {} 
word_embeddings_all_en = {}
#word_embeddings_en = nn.Embedding(num_embeddings=4530030, embedding_dim=100)
skip_first = True
has_next=True #use a while with boolean so I can continue through UnicodeDecodeErros and skip weird characters which arent in the vocab anyways
current_index_all = 0
current_index = 0
while has_next:
    if not skip_first:
        try:
            line = file.readline()
            if line != "":
                if not ("ENTITY/" in line and "_" in line): #We dont care about entities // We can try adding single token entities since most of them are proper nouns, can differnciate by case
                    current_line = line.split()
                    #print(current_line)
                    if current_line != []:
                        if "ENTITY/" in current_line[0]:
                            current_token = current_line[0][current_line[0].index("/")+1:]
                        else:
                            current_token = current_line[0]
                        #print(current_token)
                        current_embeds = []
                        if len(current_line) == 101: #some words in the pretrained embeddings have a duplication with a space in them (ex. 8_Flora vs 8 Flora, Channel_4 vs Channel 4). Ignore the ones with a space
                            for i in current_line[1:]:
                                    current_embeds.append(float(i))
                            word_indexes_all_en[current_token] = current_index_all
                            word_embeddings_all_en[current_index_all] = torch.FloatTensor(current_embeds)
                            if current_token in train_vocab_en:
                                word_indexes_en[current_token] = current_index
                                word_embeddings_en[current_index] = torch.FloatTensor(current_embeds)
                                current_index += 1
                            current_index_all += 1
            else:
                has_next=False
        except UnicodeDecodeError:
            continue
    else: #first line is skipped
        file.readline()
        skip_first = False
file.close()
current_index

13654

In [217]:
#why does this happen???
count = 0
for i in word_indexes_en:
    if count >= 11445 and count <= 11448:
        print(count, i, word_indexes_en[i])
    count += 1
print("")
magic_11447 = '2015'
print([i for i in word_indexes_en].index(magic_11447), magic_11447, word_indexes_en[magic_11447])

11445 friendlier 11445
11446 Confirmation 11446
11447 Butcher 11448
11448 chairpersons 11449

225 2015 11447


In [218]:
#randomly initialize non intialized embeddings
#uses current index and current index all from previous block
non_init_word_indexes = {}
non_init_word_embeds = {}
non_init_word_embeds_all = {}
train_vocab_en["NULL"] = 0 #add NULL token
for i in train_vocab_en:
    if i not in word_indexes_en:
        if i not in non_init_word_indexes:
            non_init_word_indexes[i] = current_index
            init_value = (-0.01 - 0.01) * torch.rand(100) + 0.01
            non_init_word_embeds[current_index] = init_value
            non_init_word_embeds_all[current_index_all] = init_value
            current_index += 1 #since its a dict, out indexes can be negative probably
            current_index_all += 1
word_embeddings_en.update(non_init_word_embeds)
word_embeddings_all_en.update(non_init_word_embeds_all)
current_index

20326

In [219]:
word_embeddings_en[word_indexes_en['the']]

tensor([-5.8400e-02,  1.3640e-01,  2.0190e-01, -1.0000e-04,  2.5780e-01,
         5.8900e-02, -3.3300e-01,  8.5000e-02,  4.0000e-04,  9.3100e-02,
         2.1940e-01,  2.1300e-01,  2.7970e-01,  1.6510e-01,  1.2840e-01,
         3.9200e-01,  2.2390e-01, -1.2000e-01,  8.8700e-02,  5.3600e-02,
         1.0650e-01,  5.2000e-03,  3.7510e-01, -4.2400e-02,  2.3600e-02,
         3.2150e-01,  2.0360e-01,  1.8630e-01, -4.3200e-02,  2.0160e-01,
         2.7880e-01,  1.1990e-01, -3.5000e-03, -4.0400e-02,  3.5560e-01,
         1.2460e-01, -1.3600e-01, -5.6900e-02,  6.0600e-02,  2.7280e-01,
         9.6900e-02,  4.6000e-02, -3.1650e-01,  4.2000e-02, -4.2800e-02,
         2.5300e-02,  4.3100e-02,  2.2600e-02, -8.8500e-02, -2.5040e-01,
        -7.8100e-02, -2.4600e-02, -1.8000e-02,  3.9900e-02, -1.1280e-01,
         2.4850e-01, -1.3300e-02,  8.5500e-02,  9.0000e-03,  1.9590e-01,
         2.5100e-02, -1.6200e-02,  2.3690e-01, -7.2000e-03, -1.3850e-01,
        -1.9370e-01,  1.1420e-01,  1.6320e-01, -8.0

In [239]:
word_embeds_tensor_en = torch.zeros(20326, 100)
for i in word_embeddings_en:
    word_embeds_tensor_en[i] = word_embeddings_en[i]
word_embeds_tensor_en.size()
#init_embeds_tensor = torch.stack(tuple([word_embeddings_en[i] for i in word_embeddings_en]))
#init_embeds_tensor = torch.tensor(dtype=torch.long)
#for i in range(len(word_embeddings_en)):
#    init_embeds_tensor.stack
#word_embeds_en = nn.Embedding(len(word_embeddings_en), 100)
#lookup_tensor = torch.tensor([word_indexes_en["the"]], dtype=torch.long)
#word_embeds_en(lookup_tensor)
word_embeds_tensor_en[0]

tensor([-5.8400e-02,  1.3640e-01,  2.0190e-01, -1.0000e-04,  2.5780e-01,
         5.8900e-02, -3.3300e-01,  8.5000e-02,  4.0000e-04,  9.3100e-02,
         2.1940e-01,  2.1300e-01,  2.7970e-01,  1.6510e-01,  1.2840e-01,
         3.9200e-01,  2.2390e-01, -1.2000e-01,  8.8700e-02,  5.3600e-02,
         1.0650e-01,  5.2000e-03,  3.7510e-01, -4.2400e-02,  2.3600e-02,
         3.2150e-01,  2.0360e-01,  1.8630e-01, -4.3200e-02,  2.0160e-01,
         2.7880e-01,  1.1990e-01, -3.5000e-03, -4.0400e-02,  3.5560e-01,
         1.2460e-01, -1.3600e-01, -5.6900e-02,  6.0600e-02,  2.7280e-01,
         9.6900e-02,  4.6000e-02, -3.1650e-01,  4.2000e-02, -4.2800e-02,
         2.5300e-02,  4.3100e-02,  2.2600e-02, -8.8500e-02, -2.5040e-01,
        -7.8100e-02, -2.4600e-02, -1.8000e-02,  3.9900e-02, -1.1280e-01,
         2.4850e-01, -1.3300e-02,  8.5500e-02,  9.0000e-03,  1.9590e-01,
         2.5100e-02, -1.6200e-02,  2.3690e-01, -7.2000e-03, -1.3850e-01,
        -1.9370e-01,  1.1420e-01,  1.6320e-01, -8.0

In [240]:
#get frequency list of training xpos
train_xpos_en = {}
for i in ud_ewt_train:
    for j in i['xpos']:
        if j not in train_xpos_en:
            train_xpos_en[j] = 1
        else:
            train_xpos_en[j] += 1
train_xpos_en = dict(sorted(train_xpos_en.items(), key=lambda item: item[1], reverse=True))
#train_xpos_en

In [241]:
#get frequency list of training upos
train_upos_en = {}
for i in ud_ewt_train:
    for j in i['upos']:
        if upos_map[j] not in train_upos_en:
            train_upos_en[upos_map[j]] = 1
        else:
            train_upos_en[upos_map[j]] += 1
train_upos_en = dict(sorted(train_upos_en.items(), key=lambda item: item[1], reverse=True))
#train_upos_en

In [242]:
#get frequency list of training deps
train_deps_en = {}
for i in ud_ewt_train:
    for j in i['deprel']:
        if j not in train_deps_en:
            train_deps_en[j] = 1
        else:
            train_deps_en[j] += 1
train_deps_en = dict(sorted(train_deps_en.items(), key=lambda item: item[1], reverse=True))
#train_deps_en

In [243]:
#randomly initialize pos and label embeddings (Chen and Manning 3.2) 
#xpos_tags is full list of arcs in conllu documentation. also has NULL appended
xpos_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '$', ':', ',', '.', "``", '’’', '#', '-LRB-', '-RRB-', 'HYPH', 'NFP', 'SYM', 'PUNC', '_', 'NULL']
index = 0
xpos_indexes = {}
for j in set([i for i in train_xpos_en]) | set(xpos_tags): #combine with all xpos found in training for more complete list of xpos
    xpos_indexes[j] = index
    index += 1
xpos_embeddings_en = {}
for i in xpos_indexes:
    xpos_embeddings_en[xpos_indexes[i]] = (-0.01 - 0.01) * torch.rand(100) + 0.01
    
xpos_embeds_tensor_en = torch.zeros(len(xpos_indexes), 100)
for i in xpos_embeddings_en:
    xpos_embeds_tensor_en[i] = xpos_embeddings_en[i]
xpos_embeds_tensor_en.size()

torch.Size([56, 100])

In [244]:
#same as xpos but for arc labels
arc_labels = ['nsubj', 'nsubj:pass', 'nsubj:outer', 'obj', 'iobj', 'csubj', 'csubj:pass', 'csubj:outer', 'ccomp', 'xcomp', 'obl', 'obl:npmod', 'obl:tmod', 'advcl', 'advcl:relcl', 'advmod', 'vocative', 'discourse', 'expl', 'aux', 'aux:pass', 'cop', 'mark', 'nummod', 'appos', 'nmod', 'nmod:npmod', 'nmod:tmod', 'nmod:poss', 'acl', 'acl:relcl', 'amod', 'det', 'det:predet', 'compound', 'compound:prt', 'fixed', 'flat', 'flat:foreign', 'goeswith', 'conj', 'cc', 'cc:preconj', 'case', 'list', 'dislocated', 'parataxis', 'orphan', 'reparandum', 'root', 'punct', 'dep', 'NULL']
#arc_labels is generic list of arcs. also has NULL appended
index = 0
arc_indexes = {}
for j in set([i for i in train_deps_en]) | set(arc_labels):
    arc_indexes[j] = index
    index += 1
arc_embeddings_en = {}
for i in arc_indexes:
    arc_embeddings_en[arc_indexes[i]] = (-0.01 - 0.01) * torch.rand(100) + 0.01
    
arc_embeds_tensor_en = torch.zeros(len(arc_indexes), 100)
for i in arc_embeddings_en:
    arc_embeds_tensor_en[i] = arc_embeddings_en[i]
arc_embeds_tensor_en.size()

torch.Size([54, 100])