# Homework 5: Transition-Based Dependency Parser

**Due April 6, 2020 at 11:59pm**

In this homework, you will be implementing parts of a transition-based dependency parser.

**Before beginning, please switch your Colab session to a GPU runtime** 

Go to Runtime > Change runtime type > Hardware accelerator > GPU

## ALSO, REMEMBER TO UPLOAD THE DATASET!

Click the Files icon > Upload > Upload `train.projective.short.conll` and `dev.projective.conll` that you have downloaded from bCourses:Files/HW_5

### Setup

In [0]:
import sys
import re
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim

In [86]:
# if this cell prints "Running on cpu", you must switch runtime environments
# go to Runtime > Change runtime type > Hardware accelerator > GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on {}".format(device))

Running on cuda


### Download pretrained word embeddings

In this assignment, we will still be using [GloVe](https://nlp.stanford.edu/projects/glove/) pretrained word embeddings.

**Note**: this section will take *several minutes*, since the embedding files are large. Files in Colab may be cached between sessions, so you may or may not need to redownload the files each time you reconnect. 

In [87]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove*.zip

--2020-04-05 18:43:58--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2020-04-05 18:43:59--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2020-04-05 18:43:59--  http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


2020-0

### Question 1. Checking for Projectivity
In this question, you are supposed to implement the `is_projective` function below.
* A tree structure is said to be [projective](https://en.wikipedia.org/wiki/Discontinuity_(linguistics)) if there are no crossing dependency edges and/or projection lines. 
* The function should take a sentence as input and returns True if and only if the tree is projective.

In [0]:
def is_projective(toks):
    """
    params: toks is a list of (idd, tok, pos, head, lab) for a sentence
    return True if and only if the sentence has a projective dependency tree
    """

    # Implement your code below
    
    ##################
    graph = {}
    for i in range(len(toks)+1):
          graph[i] = []
    for elem in toks:
      graph.setdefault(elem[3],[]).append(elem[0])

    visited = []

    def dfs(visited, graph, node):
      if node not in visited:
        visited.append(node)
        for neighbour in graph[node]:
          dfs(visited, graph, neighbour)
    
    v = []
    for i in range(len(toks)+1):
      dfs(visited, graph, i)
      v.append(visited)
      visited = []

    for elem in toks:
      dep = elem[0]
      head = elem[3]
      if (head - dep > 1):
        nodes = np.arange(dep+1, head)  
        if (np.any([node not in v[head] for node in nodes])):
          return False
      if (dep - head > 1):
        nodes = np.arange(head+1, dep)  
        if (np.any([node not in v[head] for node in nodes])):
          return False
    return True
    ##################


In [3]:
def sanity_check_is_projective():
    """
    Sanity check for the function is_projective()
    """
    # "From the AP comes this story:" should be projective
    proj_toks = [(1, 'From', 'IN', 3, 'case'), 
                  (2, 'the', 'DT', 3, 'det'), 
                  (3, 'AP', 'NNP', 4, 'obl'), 
                  (4, 'comes', 'VBZ', 0, 'root'), 
                  (5, 'this', 'DT', 6, 'det'), 
                  (6, 'story', 'NN', 4, 'nsubj'), 
                  (7, ':', ':', 4, 'punct')]
    assert is_projective(proj_toks) == True
    
    # "I saw a man today who is tall" should not be projective
    non_proj_toks = [(1, 'I', 'PRP', 2, 'nsubj'), 
                      (2, 'saw', 'VBD', 0, 'root'), 
                      (3, 'a', 'DT', 4, 'det'), 
                      (4, "man", 'NN', 2, 'obj'), 
                      (5, 'today', 'NN', 2, 'nmod'), 
                      (6, 'who', 'WP', 8, 'nsubj'), 
                      (7, 'is', 'VBZ', 8, 'cop'), 
                      (8, 'tall', 'JJ', 4, 'acl:relcl')]
    assert is_projective(non_proj_toks) == False
    print("Congrats! You have passed the basic sanity check of is_projective().")
    
sanity_check_is_projective()    

Congrats! You have passed the basic sanity check of is_projective().


### Question 2.a.
Implement the first helper function `perform_shift` to achieve the SHIFT operation.
* The SHIFT Operation removes the word from the front of the input buffer and pushes it onto stack.

In [0]:
def perform_shift(wbuffer, stack, arcs,
                  configurations, gold_transitions):
    """
    perform the SHIFT operation
    """

    # Implement your code below
    # your code should:
    # 1. append the latest configuration to configurations
    # 2. append the latest action to gold_transitions
    # 3. update wbuffer, stack and arcs accordingly
    # hint: note that the order of operations matters
    # as we want to capture the configurations and transition rules
    # before making changes to the stack, wbuffer and arcs
    
    ##################
    w = wbuffer[:]
    s = stack[:]
    a = arcs[:]

    configurations.append((w,s,a))
    gold_transitions.append('SHIFT')
    n = wbuffer.pop()
    stack.append(n)
    ##################
    

In [5]:
def sanity_check_perform_shift():
    """
    Sanity check for the function perform_shift()
    """    
    # Before perform SHIFT
    wbuffer = [3, 2, 1]
    stack = [0]
    arcs = []
    configurations = []
    gold_transitions = []

    # Perform SHIFT
    perform_shift(wbuffer, stack, arcs, configurations, gold_transitions)

    # After perform SHIFT
    assert wbuffer == [3, 2], "The result for wbuffer is not correct"
    assert stack == [0, 1], "The result for stack is not correct"
    assert arcs == [], "The result for arcs is not correct"
    assert configurations == [([3, 2, 1], [0], [])], "The result for configurations is not correct"
    assert gold_transitions == ['SHIFT'], "The result for gold_transitions is not correct"
    print("Cool! You have passed the basic sanity check of perform_shift().")
    
sanity_check_perform_shift()    

Cool! You have passed the basic sanity check of perform_shift().


### Question 2.b.
Implement the second helper function `perform_arc` to achieve the ARC operation.

* LEFT-ARC (label): assert relation between head at $stack_1$ and dependent at $stack_2$: remove $stack_2$
* RIGHT-ARC (label): assert relation between head at $stack_2$ and dependent at $stack_1$; remove $stack_1$ 

In [0]:
def perform_arc(direction, dep_label,
                wbuffer, stack, arcs,
                configurations, gold_transitions):
    """
    params:
        - direction: {"LEFT", "RIGHT"}
        - dep_label: label for the dependency relations
    Perform LEFTARC_ and RIGHTARC_ operations
    """

    # Implement your code below
    # your code should:
    # 1. append the latest configuration to configurations
    # 2. append the latest action to gold_transitions
    # 3. update wbuffer, stack and arcs accordingly
    # hint: note that the order of operations matters
    # as we want to capture the configurations and transition rules
    # before making changes to the stack, wbuffer and arcs

    ##################
    w = wbuffer[:]
    s = stack[:]
    a = arcs[:]

    configurations.append((w,s,a))
    gold_transitions.append(direction+'ARC_'+dep_label)
    if direction == 'RIGHT':
      n = stack.pop()
      arcs.append((dep_label, stack[len(stack)-1],n))
    elif direction == 'LEFT':
      n = stack.pop(len(stack)-2)
      arcs.append((dep_label, stack[len(stack)-1], n))
    ##################


In [17]:
def sanity_check_perform_arc():
    """
    Sanity check for the function perform_arc()
    """
    # Before perform ARC
    direction = 'RIGHT'
    dep_label = 'punct'
    wbuffer = [5, 4, 3]
    stack = [0, 1, 2]
    arcs = []
    configurations = [([5, 4, 3, 2, 1], [0], []), 
                      ([5, 4, 3, 2], [0, 1], [])]
    gold_transitions = ['SHIFT', 'SHIFT']

    # Perform ARC
    perform_arc(direction, dep_label, wbuffer, stack, arcs, configurations, gold_transitions)

    # After perform ARC
    assert wbuffer == [5, 4, 3], "The result for wbuffer is not correct"
    assert stack == [0, 1], "The result for stack is not correct"
    assert arcs == [('punct', 1, 2)], "The result for arcs is not correct"
    assert configurations == [([5, 4, 3, 2, 1], [0], []), 
                              ([5, 4, 3, 2], [0, 1], []), 
                              ([5, 4, 3], [0, 1, 2], [])], \
            "The result for configurations is not correct"
    assert gold_transitions == ['SHIFT', 'SHIFT', 'RIGHTARC_punct'], "The result for gold_transitions is not correct"
    print("You have passed the basic sanity check of perform_arc().")

sanity_check_perform_arc()    

You have passed the basic sanity check of perform_arc().


### Question 2.c.
Now, since we have implemented the helper functions, let's use them to complete `tree_to_actions`.

`tree_to_actions` takes wbuffer, stack, arcs and deps as input, returns configuration of the parser and action for the parser.

In [0]:
def tree_to_actions(wbuffer, stack, arcs, deps):
    """
    params:
    wbuffer: a list of word indices; the top of buffer is at the end of the list
    stack: a list of word indices; the top of buffer is at the end of the list
    arcs: a list of (label, head, dependent) tuples

    Given wbuffer, stack, arcs and deps
    Return configurations and gold_transitions (actions)
    """

    # configurations:
    # A list of tuples of lists
    # [(wbuffer1, stack1, arcs1), (wbuffer2, stack2, arcs2), ...]
    # Keeps tracks of the states at each step
    configurations=[]

    # gold_transitions:
    # A list of action strings, e.g ["SHIFT", "LEFTARC_nsubj"]
    # Keeps tracks of the actions at each step
    gold_transitions=[]

    # Implement your code below
    # hint:
    # 1. configurations[i] and gold_transitions[i] should
    # correspond to the states of the wbuffer, stack, arcs
    # (before the action was taken) and action to take at step i
    # 2. you should call perform_shift and perform_arc in your code
    

    ##################
    configurations.append((wbuffer,stack,arcs))
    gold_transitions.append('SHIFT')
    perform_shift(wbuffer, stack, arcs,configurations, gold_transitions)

    while (len(stack)>1 or len(wbuffer)!=0):
      stack_1 = stack[len(stack)-1]
      stack_2 = stack[len(stack)-2]
      if (len(stack)>1 and stack_1 in deps.keys() and (stack_1,stack_2) in deps[stack_1]):
        direction = 'LEFT'
        dep_label = deps[stack_1][(stack_1,stack_2)]
        configurations.append((wbuffer,stack,arcs))
        gold_transitions.append(direction+'ARC_'+dep_label)
        perform_arc(direction, dep_label, wbuffer, stack, arcs, configurations, gold_transitions)
      elif (len(stack)>1 and stack_2 in deps.keys() and (stack_2,stack_1) in deps[stack_2] and all([(stack_1, w) not in deps.get(stack_1, []) for w in wbuffer])):
        direction = 'RIGHT'
        dep_label = deps[stack_2][(stack_2,stack_1)]
        configurations.append((wbuffer,stack,arcs))
        gold_transitions.append(direction+'ARC_'+dep_label)
        perform_arc(direction, dep_label, wbuffer, stack, arcs, configurations, gold_transitions)
      elif (len(wbuffer)!=0):
        configurations.append((wbuffer,stack,arcs))
        gold_transitions.append('SHIFT')
        perform_shift(wbuffer, stack, arcs,configurations, gold_transitions)
    return configurations, gold_transitions
    ##################
    

In [108]:
def sanity_check_tree_to_actions():
    """
    Sanity check for the function tree_to_actions()
    """
    # Before tree_to_actions 
    wbuffer = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    stack = [0]
    arcs = []
    deps = {5: {(5, 9): 'punct', (5, 8): 'obl', (5, 4): 'advmod', (5, 3): 'aux:pass', (5, 2): 'nsubj:pass'},
            8: {(8, 7): 'det', (8, 6): 'case'}, 0: {(0, 5): 'root'}, 2: {(2, 1): 'nmod:poss'}}
    
    tree_to_actions(wbuffer, stack, arcs, deps)
    # After tree_to_actions
    assert wbuffer == [], "The result for wbuffer is not correct"
    assert stack == [0], "The result for stack is not correct"
    assert arcs == [('nmod:poss', 2, 1), ('advmod', 5, 4), ('aux:pass', 5, 3), ('nsubj:pass', 5, 2), 
                     ('det', 8, 7), ('case', 8, 6), ('obl', 5, 8), ('punct', 5, 9), ('root', 0, 5)], \
                        "The result for arcs is not correct"
    assert deps == {5: {(5, 9): 'punct', (5, 8): 'obl', (5, 4): 'advmod', (5, 3): 'aux:pass', (5, 2): 'nsubj:pass'}, 
                     8: {(8, 7): 'det', (8, 6): 'case'}, 0: {(0, 5): 'root'}, 2: {(2, 1): 'nmod:poss'}}, \
                    "The result for deps is not correct"
    print("You have passed the basic sanity check of tree_to_actions()! One more function to go.")   
     
sanity_check_tree_to_actions()    

You have passed the basic sanity check of tree_to_actions()! One more function to go.


### Question 3. Tree Parsing with Predictions
Implement action_to_tree, which will update the dependency tree based on the action predictions.
* Don't forget to use `isvalid` to check the validity of the possible actions!

In [0]:
def isvalid(stack, wbuffer, action):
    """
    Helper function that returns True only if an action is
    legal given the current states of the stack and wbuffer
    """
    if action == "SHIFT" and len(wbuffer) > 0:
        return True
    if action.startswith("RIGHTARC") and len(stack) > 1 and stack[-1] != 0:
        return True
    if action.startswith("LEFTARC") and len(stack) > 1 and stack[-2] != 0:
        return True

    return False

In [0]:
def action_to_tree(tree, predictions, wbuffer, stack, arcs, reverse_labels):
    """
    params:
    tree:
    a dictionary of dependency relations (head, dep_label)
        {
            child1: (head1, dep_lebel1),
            child2: (head2, dep_label2), ...
        }

    predictions:
    a numpy column vector of probabilities for different dependency labels
    as ordered by the variable reverse_labels
    predictions.shape = (1, total number of dependency labels)

    wbuffer: a list of word indices; top of buffer is at the end of the list
    stack: a list of word indices; top of stack is at the end of the list
    arcs: a list of (label, head, dependent) tuples

    """

    # Implement your code below
    # hint:
    # 1. the predictions contains the probability distribution for all
    # possible actions for a single step, and you should choose one
    # and update the tree only once
    # 2. some actions predicted are not going to be valid
    # (e.g., shifting if nothing is on the buffer)
    # so sort probs and keep going until you find one that is valid.
    
    ##################
    idxs = (-predictions).argsort()[:3]
    for idx in idxs[0]:
      if (isvalid(stack, wbuffer, reverse_labels[idx])):
        if (reverse_labels[idx] == "SHIFT"):
          n = wbuffer.pop()
          stack.append(n)
          break
        if (reverse_labels[idx].startswith("RIGHTARC")):
          n = stack.pop()
          arcs.append((reverse_labels[idx][9:], stack[len(stack)-1],n))
          tree[n] = (stack[len(stack)-1],reverse_labels[idx][9:])
          break
        elif (reverse_labels[idx].startswith("LEFTARC")):
          n = stack.pop(len(stack)-2)
          arcs.append((reverse_labels[idx][8:], stack[len(stack)-1], n))
          tree[n] = (stack[len(stack)-1],reverse_labels[idx][8:])
          break
    ##################
      

In [110]:
def sanity_check_action_to_tree():
    """
    Sanity check for the function action_to_tree()
    """    
    # Before action
    tree = {}
    predictions = np.array([[ 8.904456  ,  2.1306312 , -0.6716528 , -0.37662476, -0.01239625,-3.3660867 , -2.1345713 ,  1.4581618 , 
                             -0.1688145 , -0.61321   , 0.40860286, -2.7569351 , -0.69548404, -0.7809651 ,  0.7595304 ,-2.770731  , 
                             -0.97373027, -2.70085   , -0.26645675, -1.2353135 ,-1.4289687 , -1.3272284 , -2.4956157 , -1.0178847 , 
                             -1.7484616 , 1.7610879 ,  0.301237  , -0.71727145, -1.9370077 , -1.3722429 , 0.9516849 , -2.6749346 , 
                             -1.4604743 , -1.6903474 , -2.5261753 ,-0.88417345, -0.50328434, -0.21296862, -3.4296887 , -3.3282495 ,
                             -4.300956  , -2.12365   , -3.3637137 , -5.570282  , -3.8983932 ,-3.0985348 , -5.818429  , -1.5155774 , 
                             -3.4247532 , -2.7098398 ,-4.799152  , -4.020282  , -3.5505116 , -2.7114115 , -4.1488724 ,-4.7484784 , 
                             -4.0955606 , -2.994336  , -4.9744525 , -4.3390574 ,-2.782462  , -4.615161  , -4.6250424 , -4.4105268 , 
                             -4.856515  ,-3.5684056 , -4.6808653 , -4.882898  , -4.3673973 , -5.379696  ]])
    
    reverse_labels = ['SHIFT', 'RIGHTARC_punct', 'RIGHTARC_flat', 'LEFTARC_amod', 'LEFTARC_nsubj', 'LEFTARC_det', 'RIGHTARC_appos', 'RIGHTARC_obj', 'LEFTARC_case', 'RIGHTARC_nmod', 'RIGHTARC_obl', 'RIGHTARC_parataxis', 'RIGHTARC_root', 'LEFTARC_aux', 'LEFTARC_punct', 'RIGHTARC_iobj', 'LEFTARC_mark', 'RIGHTARC_acl', 'RIGHTARC_compound:prt', 'LEFTARC_nummod', 'RIGHTARC_ccomp', 'LEFTARC_aux:pass', 'LEFTARC_nsubj:pass', 'LEFTARC_compound', 'LEFTARC_nmod:poss', 'LEFTARC_cc', 'RIGHTARC_conj', 'LEFTARC_advmod', 'RIGHTARC_xcomp', 'LEFTARC_advcl', 'RIGHTARC_advmod', 'RIGHTARC_acl:relcl', 'RIGHTARC_advcl', 'LEFTARC_expl', 'RIGHTARC_nsubj', 'LEFTARC_obl', 'LEFTARC_cop', 'RIGHTARC_fixed', 'RIGHTARC_nummod', 'LEFTARC_det:predet', 'RIGHTARC_obl:npmod', 'RIGHTARC_obl:tmod', 'LEFTARC_obl:tmod', 'RIGHTARC_nmod:tmod', 'RIGHTARC_amod', 'LEFTARC_csubj', 'LEFTARC_csubj:pass', 'RIGHTARC_case', 'RIGHTARC_det', 'LEFTARC_obj', 'LEFTARC_nmod:tmod', 'LEFTARC_nmod', 'RIGHTARC_cop', 'RIGHTARC_expl', 'RIGHTARC_aux', 'RIGHTARC_vocative', 'RIGHTARC_csubj', 'LEFTARC_obl:npmod', 'RIGHTARC_nmod:npmod', 'RIGHTARC_list', 'LEFTARC_ccomp', 'LEFTARC_discourse', 'LEFTARC_parataxis', 'LEFTARC_xcomp', 'RIGHTARC_csubj:pass', 'LEFTARC_cc:preconj', 'RIGHTARC_flat:foreign', 'RIGHTARC_compound', 'LEFTARC_acl:relcl', 'RIGHTARC_discourse']
    wbuffer = [4,3,2,1]
    stack = [0]
    arcs = []

    # Perform action
    action_to_tree(tree, predictions, wbuffer, stack, arcs, reverse_labels)

    # After action (the action is SHIFT for this step)
    assert not tree, "The tree should be {} after the SHIFT"
    assert wbuffer == [4,3,2], "wbuffer should be [4,3,2] after the SHIFT"
    assert stack == [0, 1], "stack should be [0, 1] after the SHIFT"
    assert arcs == [], "arcs should be [] after the SHIFT"
    print("You have passed the basic sanity check of action_to_tree()!")
    
sanity_check_action_to_tree()    

You have passed the basic sanity check of action_to_tree()!


### Implemented for you
Now since you have the configuration $x$ and action $y$, we can now train a supervised model to predict an action $y$ given a configuration $x$. We are using a simplified version model of [A Fast and Accurate Dependency Parser using Neural Networks](https://nlp.stanford.edu/pubs/emnlp2014-depparser.pdf).

* This model is alreadly implemented for you, please `train` the model, and report the evaluation and test results by calling the function `evaluate` and `test`

In [0]:
# ============================================================
# THE FOLLOWING CODE IS PROVIDED
# ============================================================
def get_oracle(toks):
    """
    Return pairs of configurations + gold transitions (actions)
    from training data
    configuration = a list of tuple of:
        - buffer (top of buffer is at the end of the list)
        - stack (top of buffer is at the end of the list)
        - arcs (a list of (label, head, dependent) tuples)
    gold transitions = a list of actions, e.g. SHIFT
    """

    stack = [] # stack
    arcs = [] # existing list of arcs
    wbuffer = [] # input buffer

    # deps is a dictionary of head: dependency relations, where
    # dependency relations is a dictionary of the (head, child): label
    # deps = {head1:{
    #               (head1, child1):dependency_label1,
    #               (head1, child2):dependency_label2
    #              }
    #         head2:{
    #               (head2, child3):dependency_label3,
    #               (head2, child4):dependency_label4
    #              }
    #         }
    deps = {}

    # ROOT
    stack.append(0)

    # initialize variables
    for position in reversed(toks):
        (idd, _, _, head, lab) = position

        dep = (head, idd)
        if head not in deps:
            deps[head] = {}
        deps[head][dep] = lab

        wbuffer.append(idd)

    # configurations:
    # A list of (wbuffer, stack, arcs)
    # Keeps tracks of the states at each step
    # gold_transitions:
    # A list of action strings ["SHIFT", "LEFTARC_nsubj"]
    # Keeps tracks of the actions at each step
    configurations, gold_transitions = tree_to_actions(wbuffer, stack, arcs, deps)
    return configurations, gold_transitions

def featurize_configuration(configuration, tokens, postags, vocab, pos_vocab):

    def get_id(word, vocab):
        word=word.lower()
        if word in vocab:
            return vocab[word]
        return vocab["<unk>"]

    """
    Given configurations of the stack, input buffer and arcs,
    words of the sentence and POS tags of the words,
    return some features

    The current features are the word ID and postag ID at the 
    first three positions of the stack and buffer.
    """

    wbuffer, stack, arcs = configuration

    word_features=[]
    pos_features=[]

    if len(stack) > 0: 
        word_features.append(get_id(tokens[stack[-1]], vocab))
        pos_features.append(get_id(postags[stack[-1]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(stack) > 1: 
        word_features.append(get_id(tokens[stack[-2]], vocab))
        pos_features.append(get_id(postags[stack[-2]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(stack) > 2: 
        word_features.append(get_id(tokens[stack[-3]], vocab))
        pos_features.append(get_id(postags[stack[-3]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(wbuffer) > 0: 
        word_features.append(get_id(tokens[wbuffer[-1]], vocab))
        pos_features.append(get_id(postags[wbuffer[-1]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))
       
    if len(wbuffer) > 1: 
        word_features.append(get_id(tokens[wbuffer[-2]], vocab))
        pos_features.append(get_id(postags[wbuffer[-2]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(wbuffer) > 2: 
        word_features.append(get_id(tokens[wbuffer[-3]], vocab))
        pos_features.append(get_id(postags[wbuffer[-3]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    return word_features, pos_features


def get_oracles(filename, vocab, tag_vocab):
    """
    Get configurations, gold_transitions from all sentences
    """
    with open(filename) as f:
        toks, tokens, postags = [], {}, {}
        tokens[0] = "<ROOT>"
        postags[0] = "<ROOT>"

        # a list of all features for each transition step
        word_feats = []
        pos_feats = []
        # a list of labels, e.g. SHIFT, LEFTARC_DEP_LABEL, RIGHTARC_DEP_LABEL
        labels = []

        for line in f:
            cols = line.rstrip().split("\t")
            
            if len(cols) < 2: # at the end of each sentence
                if len(toks) > 0:
                    if is_projective(toks): # only use projective trees
                        # get all configurations and gold standard transitions
                        configurations, gold_transitions = get_oracle(toks)
                        
                        for i in range(len(configurations)):
                            word_feat, pos_feat = featurize_configuration(configurations[i], tokens, postags, vocab, tag_vocab)
                            label = gold_transitions[i]
                            word_feats.append(word_feat)
                            pos_feats.append(pos_feat)
                            labels.append(label)

                    # reset vars for the next sentence
                    toks, tokens, postags = [], {}, {}
                    tokens[0] = "<ROOT>"
                    postags[0] = "<ROOT>"
                    
                continue

            if cols[0].startswith("#"):
                continue

            # construct the tuple for each word in the sentence
            # for each word in the sentence
            # idd: index of a word in a sentence, starting from 1
            # tok: the word itself
            # pos: pos tag for that word
            # head: parent of the dependency
            # lab: dependency relation label
            idd, tok, pos, head, lab = int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7]
            toks.append((idd, tok, pos, head, lab))

            # feature for training to predict the gold transition
            tokens[idd], postags[idd] = tok, pos

        return word_feats, pos_feats, labels

def load_embeddings(filename):
    # 0 idx is for padding
    # 1 idx is for <UNK>
    # 2 idx is for <NONE>
    # 3 idx is for <ROOT>

    # get the embedding size from the first embedding
    vocab_size=4
    with open(filename, encoding="utf-8") as file:
        for idx, line in enumerate(file):
            if idx == 0:
                word_embedding_dim=len(line.rstrip().split(" "))-1
            vocab_size+=1
        

    vocab={"<pad>":0, "<unk>":1, "<none>":2, "<root>":3}
    print("word_embedding_dim: %s, vocab size: %s" % (word_embedding_dim, vocab_size))

    embeddings=np.zeros((vocab_size, word_embedding_dim))

    with open(filename, encoding="utf-8") as file:
        for idx,line in enumerate(file):

            if idx + 4 >= vocab_size:
                break

            cols=line.rstrip().split(" ")
            val=np.array(cols[1:])
            word=cols[0]
            embeddings[idx+4]=val
            vocab[word]=idx+4

    return torch.FloatTensor(embeddings), vocab

class ShiftReduceParser(nn.Module):

    def __init__(self, embeddings, hidden_dim, tagset_size, num_pos_tags, pos_embedding_dim):
        super(ShiftReduceParser, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_labels=tagset_size

        _, embedding_dim = embeddings.shape

        self.input_size=embedding_dim*6 + pos_embedding_dim*6
        
        self.dropout_layer = nn.Dropout(p=0.25)

        self.word_embeddings = nn.Embedding.from_pretrained(embeddings)
        self.pos_embeddings = nn.Embedding(num_pos_tags, pos_embedding_dim)
        self.tanh = nn.Tanh()
        self.W1 = nn.Linear(self.input_size, self.hidden_dim)
        self.W2 = nn.Linear(self.hidden_dim, self.num_labels)

    def forward(self, words, pos_tags, Y=None):
        
        words=words.to(device)
        pos_tags=pos_tags.to(device)

        if Y is not None:
            Y=Y.to(device)

        word_embeds = self.word_embeddings(words)
        postag_embeds = self.pos_embeddings(pos_tags)

        embeds=torch.cat((word_embeds, postag_embeds), 2)

        embeds=embeds.view(-1, self.input_size)

        embeds=self.dropout_layer(embeds)

        hidden = self.W1(embeds)
        hidden = self.tanh(hidden)
        logits = self.W2(hidden)

        if Y is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), Y.view(-1))
            return loss
        else:
            return logits

def get_batches(W, P, Y, batch_size):
    batch_W=[]
    batch_P=[]
    batch_Y=[]

    i=0
    while i < len(W):
        batch_W.append(torch.LongTensor(W[i:i+batch_size]))
        batch_P.append(torch.LongTensor(P[i:i+batch_size]))
        batch_Y.append(torch.LongTensor(Y[i:i+batch_size]))
        i+=batch_size  

    return batch_W, batch_P, batch_Y

def train(word_feats, pos_feats, labels, embeddings, vocab, postag_vocab, label_vocab):
    """

    Train transition-based parser to predict next action (labels)
    given current configuration (featurized by word_feats and pos_feats)
    Return the classifier trained using Chen and Manning (2014), "A Fast 
    and Accurate Dependency Parser using Neural Networks"

    """

    # dimensionality of linear layer
    HIDDEN_DIM=100
    # dimensionality of POS embeddings
    POS_EMBEDDING_SIZE=50

    # batch size for training
    BATCH_SIZE=32

    # number of epochs to train for
    NUM_EPOCHS=10

    # learning rate for Adam optimizer
    LEARNING_RATE=0.001

    num_labels=[]
    for i, y in enumerate(labels):
        num_labels.append(label_vocab[y])

    batch_W, batch_P, batch_Y = get_batches(word_feats, pos_feats, num_labels, BATCH_SIZE)

    model = ShiftReduceParser(embeddings, HIDDEN_DIM, len(label_vocab), len(postag_vocab), POS_EMBEDDING_SIZE)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        model.train()

        bigloss=0.
        for b in range(len(batch_W)):
            model.zero_grad()

            loss = model.forward(batch_W[b], batch_P[b], Y=batch_Y[b])
            bigloss+=loss.item()

            loss.backward()
            optimizer.step()

        print("loss: ", bigloss)


    return model


def parse(toks, model, vocab, tag_vocab, reverse_labels):
    """
    parse sentence with trained model and return correctness measure
    """
    tokens, postags = {}, {}
    tokens[0] = "<ROOT>"
    postags[0] = "<ROOT>"

    wbuffer, stack, arcs = [], [], []
    stack.append(0)

    for position in reversed(toks):

        (idd, tok, pos, head, lab) = position
        tokens[idd] = tok
        postags[idd] = pos

        # update buffer
        wbuffer.append(idd)

    tree = {}
    while len(wbuffer) >= 0:
        if len(wbuffer) == 0 and len(stack) == 0: break
        if len(wbuffer) == 0 and len(stack) == 1 and stack[0] == 0: break

        word_feats, pos_feats = (featurize_configuration((wbuffer, stack, arcs), tokens, postags, vocab, tag_vocab))

       
        predictions=model.forward(torch.LongTensor([word_feats]), torch.LongTensor([pos_feats]))

        predictions=predictions.detach().cpu().numpy()

        # your function will be called here
        action_to_tree(tree, predictions, wbuffer, stack, arcs, reverse_labels)

    return tree

def parse_and_evaluate(toks, model, vocab, tag_vocab, reverse_labels):
    """
    parse sentence with trained model and return correctness measure
    """

    heads, labels = {}, {}

    for position in reversed(toks):
        (idd, tok, pos, head, lab) = position

        # keep track of gold standards for performance evaluation
        heads[idd], labels[idd] = head, lab

    tree = parse(toks, model, vocab, tag_vocab, reverse_labels)

    # correct_unlabeled: total number of correct (head, child) dependencies
    # correct_labeled: total number of correctly *labeled* dependencies
    correct_unlabeled, correct_labeled, total = 0, 0, 0

    for child in tree:
        (head, label) = tree[child]
        if head == heads[child]:
            correct_unlabeled += 1
            if label == labels[child]: correct_labeled += 1
        total += 1

    return [correct_unlabeled, correct_labeled, total]

def get_label_vocab(labels):
    tag_vocab={}
    num_labels=[]
    for i, y in enumerate(labels):
        if y not in tag_vocab:
            tag_vocab[y]=len(tag_vocab)
        num_labels.append(tag_vocab[y])

    reverse_labels=[None]*len(tag_vocab)
    for y in tag_vocab:
        reverse_labels[tag_vocab[y]]=y

    return tag_vocab, reverse_labels


def get_pos_tag_vocab(filename):
    tag_vocab={"<none>":0, "<unk>":1}
    with open(filename) as file:
        for line in file:
            cols=line.rstrip().split("\t")
            if len(cols) < 3:
                continue
            pos=cols[4].lower()
            if pos not in tag_vocab:
                tag_vocab[pos]=len(tag_vocab)
    return tag_vocab

def test(model, vocab, tag_vocab, reverse_labels):
    """
    Evaluate the performance of a parser against gold standard
    """

    model.eval()

    toks=["I", "bought", "a", "book"]
    pos=["NNP", "VBD", "DT", "NN"]

    data=[]
    # put it in format parser expects
    for i, tok in enumerate(toks):
        data.append((i+1, tok, pos[i], "_", "_"))

    tree=parse(data, model, vocab, tag_vocab, reverse_labels)

    for child in sorted(tree.keys()):
        (head, label) = tree[child]
        headStr="<ROOT>"
        if head > 0: # child and head indexes start at 1; 0 denotes the <ROOT>
            headStr=toks[head-1]

        print("(%s %s) -> (%s %s) %s" % (child, toks[child-1], head, headStr, label))
  


def evaluate(filename, model, vocab, tag_vocab, reverse_labels):
    """
    Evaluate the performance of a parser against gold standard
    """

    model.eval()

    with open(filename) as f:
        toks=[]
        totals = np.zeros(3)
        for line in f:
            cols=line.rstrip().split("\t")

            if len(cols) < 2: # end of a sentence
                if len(toks) > 0:
                    if is_projective(toks):
                        tots = np.array(parse_and_evaluate(toks, model, vocab, tag_vocab, reverse_labels))
                        totals += tots
                        
                    toks = []
                continue

            if cols[0].startswith("#"):
                continue

            idd, tok, pos, head, lab = int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7]
            toks.append((idd, tok, pos, head, lab))
        
        print ("UAS: %.3f, LAS:%.3f" % (totals[0]/totals[2], totals[1]/totals[2]))

### Train and evaluate the model

- NOTICE: Because you are not implementing the model or the training process, You will **NOT** be graded based on the performance of the model!

- You are only graded based on the correctness of each of the implemented functions.

- If all the required functions are implemented correctly, you should expect a UAS in a range of [0.64, 0.67], a LAS in a range of [0.56, 0.59] without changing the parameters of the neural model or the whole training process.

In [111]:
embeddingsFile = "glove.6B.50d.txt"
trainFile = "train.projective.short.conll"
devFile = "dev.projective.conll"

embeddings, vocab=load_embeddings(embeddingsFile)
pos_tag_vocab=get_pos_tag_vocab(trainFile)
word_feats, pos_feats, labels = get_oracles(trainFile, vocab, pos_tag_vocab)

label_vocab, reverse_labels=get_label_vocab(labels)

word_embedding_dim: 50, vocab size: 400004


In [112]:
model = train(word_feats, pos_feats, labels, embeddings, vocab, pos_tag_vocab, label_vocab)
evaluate(devFile, model, vocab, pos_tag_vocab, reverse_labels)
test(model, vocab, pos_tag_vocab, reverse_labels)

loss:  2097.9754111766815
loss:  1803.350136935711
loss:  1734.3115628361702
loss:  1701.0472575426102
loss:  1673.9843604564667
loss:  1657.0750904083252
loss:  1640.5775406360626
loss:  1628.0692977309227
loss:  1618.8904757499695
loss:  1612.0517721772194
UAS: 0.660, LAS:0.582
(1 I) -> (2 bought) nsubj
(2 bought) -> (0 <ROOT>) root
(3 a) -> (4 book) det
(4 book) -> (2 bought) obj
