<a href="https://colab.research.google.com/github/dbamman/nlp23/blob/main/HW4/HW4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 4: Transition-Based Dependency Parser

**Due Tuesday March 14th, 2023 at 11:59pm**

In this homework, you will be implementing components of a transition-based dependency parser.

Before diving into any code, please read through the associated [PDF](https://github.com/dbamman/nlp23/blob/main/HW4/HW4.pdf) for an overview of the assignment and specific instructions on how to submit.

As in prior homeworks, please don't remove the `BEGIN / END SOLUTION` flags, and type your logic *entirely* within them (including any helper functions).

The graded deliverables for this assignment are found in Questions: [**1, 2.a, 2.b, 2.c, 3**]

The *optional* Bonus sections do not contain graded deliverables. They allow you to further explore how a transition-based parser can be configured and utilized for a real-world task.

## Question 1. Checking for Projectivity
In this question, you are supposed to implement the `is_projective` function below.
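* A dependency tree is said to be [projective](https://en.wikipedia.org/wiki/Discontinuity_(linguistics)) if it has no crossing dependency edges, i.e. no edge crosses another edge or a projection line. Equivalently, for every arc, the head dominates (is an ancestor of) every word that lies between the head and its dependent.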
* The function should take a sentence as input and return `True` if and only if the tree is projective.
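To make the "no crossing edges" definition concrete, here is a minimal sketch that tests projectivity directly by looking for pairs of crossing arcs. This is a different technique from the reachability procedure recommended in the code comments below, and `has_crossing_arcs` / `is_projective_by_crossings` are hypothetical names. It assumes the tree is given as a `{dependent: head}` dict with the root attached to position 0; when that root arc is included, a well-formed tree has no crossing arcs exactly when it is projective.

```python
from itertools import combinations


def has_crossing_arcs(heads):
    """Return True if any two arcs cross when drawn above the sentence.

    heads: hypothetical {dependent_index: head_index} dict, with 0 as ROOT.
    """
    # Normalize each arc to a (left, right) span over word positions.
    spans = [(min(dep, head), max(dep, head)) for dep, head in heads.items()]
    for (a, b), (c, d) in combinations(spans, 2):
        # Two spans cross if they overlap without nesting or sharing an endpoint.
        if a < c < b < d or c < a < d < b:
            return True
    return False


def is_projective_by_crossings(heads):
    return not has_crossing_arcs(heads)


# The two sentences used in sanity_check_is_projective, as {dependent: head}
proj = {1: 3, 2: 3, 3: 4, 4: 0, 5: 6, 6: 4, 7: 4}               # "From the AP comes this story:"
non_proj = {1: 2, 2: 0, 3: 4, 4: 2, 5: 2, 6: 8, 7: 8, 8: 4}     # "I saw a man today who is tall"
print(is_projective_by_crossings(proj))      # True
print(is_projective_by_crossings(non_proj))  # False: arc 2->5 crosses arc 4->8
```

The pairwise check is quadratic in sentence length, which is cheap for typical sentences; the reachability approach in the starter code avoids comparing arcs pairwise by walking up the head chain instead.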

In [None]:
"""
objective: Determine if a list of 5-element tuples represents a sentence with a projective dependency tree

params: tokens is a list of (idd, tok, pos, head, lab) for a sentence
return True if and only if the sentence has a projective dependency tree
"""
# instructions
# there is more than one way to carry this out correctly. 
# We recommend the following procedure:
# Part 1
# Iterate over the tuples and store their relevant components in a dictionary,
# with each tuple's id ("idd") as the key and its head ("head") as the value;
# this dictionary should thus contain dependency relations in the form {idd: head}
# Part 2
# For every dependency relation in that dictionary, extract the head and
# establish the bounds for your dependency-path search, with
# the left bound being the smaller of (head, idd) and
# the right bound, conversely, being the bigger of (head, idd)
# Part 3
# Complete the reachable method: every word strictly between a token and its head
# must be reachable from that head, i.e. following the chain of heads upward from
# that word must eventually arrive at the head

# hints
# -the search range starts at left + 1 (see range(left + 1, right) below) because only
#   the words strictly between a token and its head need checking; a token directly
#   beside its head has nothing in between
# - a token with '0' value for its head denotes that it is the root of the sentence
# -if any of those in-between words is not reachable, the tree is non-projective,
#   which is when your method needs to return False.
# -otherwise, return True!
# -you are welcome to remove any and all of our starter code and create your own logic for this
# if you do so, please *do not* remove the BEGIN / END Solution flags!
def is_projective(tokens):

    # BEGIN SOLUTION
    def reachable(i, head, heads): # populate this (Part 3)
        pass
    
    heads = {} # populate this (Part 1)

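    for dep_to_check in heads: # populate this (Part 2)
        # look up this token's head, then set left and right to the
        # smaller and bigger of (head, dep_to_check)

        for i in range(left + 1, right):
            if not reachable(i, head, heads):
                return False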
            
    return True
    # END SOLUTION

### Check `is_projective`

In [None]:
"""
objective: Sanity check the is_projective() function
"""
def sanity_check_is_projective():

    # setup for checking projectivity --
    # to test the method, input a projective sentence from the CONLL dataset
    # e.g. "From the AP comes this story:" should be projective
    proj_toks = [(1, 'From', 'IN', 3, 'case'), 
                  (2, 'the', 'DT', 3, 'det'), 
                  (3, 'AP', 'NNP', 4, 'obl'), 
                  (4, 'comes', 'VBZ', 0, 'root'), 
                  (5, 'this', 'DT', 6, 'det'), 
                  (6, 'story', 'NN', 4, 'nsubj'), 
                  (7, ':', ':', 4, 'punct')]
    
    assert is_projective(proj_toks) == True
    
    # test the method by inputting a non-projective sentence
    # "I saw a man today who is tall" should not be projective
    non_proj_toks = [(1, 'I', 'PRP', 2, 'nsubj'), 
                      (2, 'saw', 'VBD', 0, 'root'), 
                      (3, 'a', 'DT', 4, 'det'), 
                      (4, "man", 'NN', 2, 'obj'), 
                      (5, 'today', 'NN', 2, 'nmod'), 
                      (6, 'who', 'WP', 8, 'nsubj'), 
                      (7, 'is', 'VBZ', 8, 'cop'), 
                      (8, 'tall', 'JJ', 4, 'acl:relcl')]
    assert is_projective(non_proj_toks) == False
    print("You have cleared the sanity check for is_projective()!")
    
sanity_check_is_projective()    

## Question 2.a Shift Operation
Implement the first helper function `perform_shift` to achieve the SHIFT operation.
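* The SHIFT operation removes the word at the front of the input buffer and pushes it onto the stack. In this notebook the buffer is stored in reverse order, so its front is the last element of the list.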
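A minimal sketch of the list conventions assumed here, mirroring how `get_oracle` fills the buffer in reverse and how `action_to_tree` performs a shift later in the notebook. It illustrates the data layout only; it is not `perform_shift`, which must also record the configuration and the transition.

```python
# The buffer is stored in reverse order, so its front (the next word to read)
# is the *last* element of the list and can be removed with list.pop() in O(1).
words = [1, 2, 3]                 # word indices in sentence order
wbuffer = list(reversed(words))   # [3, 2, 1]; the front of the buffer is wbuffer[-1] == 1
stack = [0]                       # 0 is the artificial ROOT; the top of the stack is stack[-1]

# One SHIFT: move the front of the buffer onto the top of the stack.
stack.append(wbuffer.pop())
print(wbuffer, stack)  # [3, 2] [0, 1]
```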

In [None]:
"""
objective: perform the SHIFT operation

params:
    - wbuffer: input buffer
    - stack: the stack on which we're building our parse
    - arcs: the dependency arcs asserted so far, as a list of (label, head, dependent) tuples
    - configurations: state of the parse at a given decision point
    - gold_transitions: the operations we're applying at each decision point
"""

# instructions
# 1. update configurations
# 2. update gold_transitions
# 3. remove word from front of buffer and push it onto stack
# hints
# -this can be completed in just a few lines.
# -we have provided the code for updating configurations, so you just need to
# update gold_transitions, then the stack and wbuffer accordingly
# -since this is a SHIFT operation, arcs does not come into play because no head-dependent relationships are being asserted
# -the .pop() method removes and returns the item at the given index of a list,
# defaulting to the last element when no index is provided

def perform_shift(wbuffer, stack, arcs, configurations, gold_transitions):    
    
    # BEGIN SOLUTION
    # update configurations (Part 1)
    configurations.append((list(wbuffer), list(stack), list(arcs)))

    # update gold_transitions (Part 2)

    # remove word from front of buffer and push it onto stack (Part 3)
    
    # END SOLUTION

### Check `perform_shift`

In [None]:
"""
objective: Sanity check the perform_shift() function
"""  

def sanity_check_perform_shift():

    # setup for performing SHIFT
    wbuffer = [3, 2, 1]
    stack = [0]
    arcs = []
    configurations = []
    gold_transitions = []

    # Perform SHIFT by invoking student-function
    perform_shift(wbuffer, stack, arcs, configurations, gold_transitions)

    # validate outputs after performing SHIFT
    assert wbuffer == [3, 2], "The result for wbuffer is not correct"
    assert stack == [0, 1], "The result for stack is not correct"
    assert arcs == [], "The result for arcs is not correct"
    assert configurations == [([3, 2, 1], [0], [])], "The result for configurations is not correct"
    assert gold_transitions == ['SHIFT'], "The result for gold_transitions is not correct"
    print("You cleared the sanity check for perform_shift().")
    
sanity_check_perform_shift()    

## Question 2.b ARC Operations
Implement the second helper function `perform_arc` to carry out the two ARC operations. Here $stack_1$ is the top of the stack and $stack_2$ is the item directly beneath it.

* LEFT-ARC (label): assert a relation with the head at $stack_1$ and the dependent at $stack_2$; remove $stack_2$
* RIGHT-ARC (label): assert a relation with the head at $stack_2$ and the dependent at $stack_1$; remove $stack_1$

Your addition to `gold_transitions` should be formatted as follows (case-sensitive): the direction (`LEFT` or `RIGHT`), then `ARC`, then `_`, then the dependency label.

So if you are performing a right arc and `dep_label` is `punct`, the string appended to `gold_transitions` is:

`RIGHTARC_punct`
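A small sketch of the two details above that are easy to get backwards: the transition-string format and which stack item is the head. The helper names (`make_action`, `arc_head_and_dependent`, `split_action`) are hypothetical and are not part of the notebook; `split_action` simply shows how such strings can be taken apart again, much as `action_to_tree` does with `re.sub`.

```python
def make_action(direction, dep_label):
    """Build a transition string such as 'RIGHTARC_punct' (direction is 'LEFT' or 'RIGHT')."""
    return f"{direction}ARC_{dep_label}"


def arc_head_and_dependent(direction, stack):
    """Return (head, dependent) for an arc between the top two stack items.

    stack[-1] is stack_1 (the top) and stack[-2] is stack_2.
    """
    s1, s2 = stack[-1], stack[-2]
    if direction == "LEFT":
        return s1, s2    # LEFT-ARC: head is stack_1, dependent is stack_2
    if direction == "RIGHT":
        return s2, s1    # RIGHT-ARC: head is stack_2, dependent is stack_1
    raise ValueError(f"unknown direction: {direction!r}")


def split_action(action):
    """Invert make_action for arc actions: 'LEFTARC_nsubj:pass' -> ('LEFT', 'nsubj:pass')."""
    kind, _, label = action.partition("_")   # split on the first underscore only
    return kind[:-len("ARC")], label


print(make_action("RIGHT", "punct"))               # RIGHTARC_punct
print(arc_head_and_dependent("RIGHT", [0, 1, 2]))  # (1, 2)
print(split_action("LEFTARC_nsubj:pass"))          # ('LEFT', 'nsubj:pass')
```

The `(1, 2)` result matches the `('punct', 1, 2)` arc expected by `sanity_check_perform_arc`.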

In [None]:
"""
objective: perform LEFTARC_ and RIGHTARC_ operations

params:
    - direction: {"LEFT", "RIGHT"}
    - dep_label: label for the dependency relations
    - wbuffer: input buffer
    - stack: the stack on which we're building our parse
    - arcs: the dependency arcs asserted so far, as a list of (label, head, dependent) tuples
    - configurations: state of the parse at a given decision point
    - gold_transitions: the operations we're applying at each decision point
"""

# instructions
# 1. update configurations
# 2. update gold_transitions
# 3. use the top 2 elements of the stack to create a (head, child) pairing, based on the arc direction
# 4. update arcs
# 5. update the stack
# hints
# -updating configurations is identical to the perform_shift function
# -be sure the information you're inserting into gold_transitions is formatted properly (see above)
# -be sure that the orientation of the (head,child) pairing is correct

def perform_arc(direction, dep_label, wbuffer, stack, arcs, configurations, gold_transitions):

    # BEGIN SOLUTION
    # update configurations (Part 1)
    
    # update transitions (Part 2)

    # setup head and child items (Part 3)

    # update arcs (Part 4)

    # update stack (Part 5)
    pass
    # END SOLUTION

### Check `perform_arc`

In [None]:
"""
Sanity check for the function perform_arc()
"""
def sanity_check_perform_arc():

    # setup for performing ARC
    direction = 'RIGHT'
    dep_label = 'punct'
    wbuffer = [5, 4, 3]
    stack = [0, 1, 2]
    arcs = []
    configurations = [([5, 4, 3, 2, 1], [0], []), 
                      ([5, 4, 3, 2], [0, 1], [])]
    gold_transitions = ['SHIFT', 'SHIFT']

    # Perform ARC by invoking student-function 
    perform_arc(direction, dep_label, wbuffer, stack, arcs, configurations, gold_transitions)

    # Validate outputs after performing ARC
    assert wbuffer == [5, 4, 3], "The result for wbuffer is not correct"
    assert stack == [0, 1], "The result for stack is not correct"
    assert arcs == [('punct', 1, 2)], "The result for arcs is not correct"
    assert configurations == [([5, 4, 3, 2, 1], [0], []), 
                              ([5, 4, 3, 2], [0, 1], []), 
                              ([5, 4, 3], [0, 1, 2], [])], \
            "The result for configurations is not correct"
    assert gold_transitions == ['SHIFT', 'SHIFT', 'RIGHTARC_punct'], "The result for gold_transitions is not correct"
    print("You cleared the sanity check for perform_arc()!")

sanity_check_perform_arc()    

## Question 2.c Tree to Actions
Now that we have implemented the helper functions, let's use them to complete `tree_to_actions`.

`tree_to_actions` takes `wbuffer`, `stack`, `arcs` and `deps` as input, then returns the sequence of parser configurations and the gold action taken at each of them (`configurations` and `gold_transitions`, respectively).

Parts of this method have been filled in for you -- your job is to fill in how the right arc transitions should be invoked.
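The decision order used by `tree_to_actions` is that of an arc-standard static oracle. Below is a minimal standalone sketch of such an oracle, written against a hypothetical `{dependent: (head, label)}` dict rather than the notebook's `deps` structure, and tracking each word's not-yet-attached dependents with a counter instead of `dep_arcs`. It is meant to illustrate the order of checks (LEFT-ARC; else RIGHT-ARC, but only once the top word has collected all of its own dependents; else SHIFT), not to stand in for `tree_to_actions`.

```python
def arc_standard_oracle(heads):
    """Static arc-standard oracle (illustrative sketch).

    heads: hypothetical {dependent: (head, label)} dict for a projective tree,
           with word indices 1..n and 0 as the artificial ROOT.
    Returns (transitions, arcs), where arcs are (label, head, dependent) tuples.
    """
    n = len(heads)
    wbuffer = list(range(n, 0, -1))  # [n, ..., 1]: the front of the buffer is at the end
    stack = [0]
    # Number of dependents each word has not yet collected.
    pending = {i: 0 for i in range(n + 1)}
    for head, _ in heads.values():
        pending[head] += 1

    transitions, arcs = [], []
    while wbuffer or len(stack) > 1:
        if len(stack) >= 2:
            s1, s2 = stack[-1], stack[-2]
            # LEFT-ARC: stack_1 is the gold head of stack_2 (ROOT never gets a head).
            if s2 != 0 and heads[s2][0] == s1:
                label = heads[s2][1]
                transitions.append(f"LEFTARC_{label}")
                arcs.append((label, s1, s2))
                pending[s1] -= 1
                stack.pop(-2)
                continue
            # RIGHT-ARC: stack_2 is the gold head of stack_1, and stack_1 has
            # already collected all of its own dependents.
            if heads[s1][0] == s2 and pending[s1] == 0:
                label = heads[s1][1]
                transitions.append(f"RIGHTARC_{label}")
                arcs.append((label, s2, s1))
                pending[s2] -= 1
                stack.pop()
                continue
        if not wbuffer:
            raise ValueError("no valid transition; the tree may be non-projective")
        transitions.append("SHIFT")
        stack.append(wbuffer.pop())
    return transitions, arcs


# The tree from sanity_check_tree_to_actions, rewritten as {dependent: (head, label)}
heads = {1: (2, "nmod:poss"), 2: (5, "nsubj:pass"), 3: (5, "aux:pass"),
         4: (5, "advmod"), 5: (0, "root"), 6: (8, "case"), 7: (8, "det"),
         8: (5, "obl"), 9: (5, "punct")}
transitions, arcs = arc_standard_oracle(heads)
print(transitions[:3])  # ['SHIFT', 'SHIFT', 'LEFTARC_nmod:poss']
print(arcs[-1])         # ('root', 0, 5)
```

On this tree the sketch produces the same `arcs` list that `sanity_check_tree_to_actions` expects, and each of the 9 words is shifted exactly once, for 2n = 18 transitions. The `pending[s1] == 0` test is what stops `obl` from being attached before word 8 has taken its `det` and `case` dependents.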

In [None]:
"""
objective: based on inputs, return the correct configurations and actions for the parser.

params:
wbuffer: a list of word indices; the front of the buffer is at the end of the list
stack: a list of word indices; the top of the stack is at the end of the list
arcs: a list of (label, head, dependent) tuples
deps: the gold dependency tree, as a dict {head: {(head, child): label}}

Return configurations and gold_transitions
"""

# instructions
# Initialize return lists
# Check for simple cases (that we've translated the tree entirely, or that we need to get more words onto the stack)
# Check for left-arc, if yes carry out
# Check for right-arc; if its conditions hold, carry it out (your job is to fill in this portion of the code!)
# else, perform a shift
# hints
# -checking the conditions for the right arc operation is fairly similar to the left arc check,
#  with the order of stack1/stack2 swapped
# -a right arc should only be performed once all the dependents of s1 have been assigned;
#  to check this, you can look for the presence of its children in dep_arcs
# -don't forget to update dep_arcs after you call your perform_arc function

def tree_to_actions(wbuffer, stack, arcs, deps):

    # Initialize return lists
    
    # A list of 3-element tuples of lists
    # [(wbuffer1, stack1, arcs1), (wbuffer2, stack2, arcs2), ...]
    # Keeps track of the states at each step
    configurations=[]

    # gold_transitions:
    # A list of action strings, e.g ["SHIFT", "LEFTARC_nsubj"]
    # Keeps track of the actions at each step
    gold_transitions=[]

    # Keeps track of the dependents of each word in the tree
    # that have already been assigned
    dep_arcs = {}
    
    while len(wbuffer) >= 0:
        # Check for base-cases

        # firstly, check whether the whole tree has been translated into transitions
        if len(wbuffer) == 0 and len(stack) == 1 and stack[0] == 0:
            return configurations, gold_transitions

        # also, if there are fewer than 2 words on the stack
        # and more than 0 left on the buffer, 
        # we need to perform a shift operation
        if len(stack) < 2 and len(wbuffer) > 0:
            # shift operations
            perform_shift(wbuffer, stack, arcs, configurations, gold_transitions)
            continue

        # grab s1 and s2
        stack1 = stack[-1]
        stack2 = stack[-2]

        # check against conditions for left arc operation
        if stack1 in deps and (stack1, stack2) in deps[stack1]:
            # perform left arc
            perform_arc("LEFT", deps[stack1][(stack1, stack2)], wbuffer, stack, arcs, configurations, gold_transitions)
            # update dep_arcs
            dep_arcs[stack2] = 1

        # BEGIN SOLUTION
  
        # END SOLUTION

        # perform shift
        else:
            perform_shift(wbuffer, stack, arcs, configurations, gold_transitions)
    
    return configurations, gold_transitions

### Check `tree_to_actions`

In [None]:
"""
objective: sanity check the tree_to_actions() function
"""
def sanity_check_tree_to_actions():

    # Setup for invoking tree_to_actions 
    wbuffer = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    stack = [0]
    arcs = []
    deps = {5: {(5, 9): 'punct', (5, 8): 'obl', (5, 4): 'advmod', (5, 3): 'aux:pass', (5, 2): 'nsubj:pass'},
            8: {(8, 7): 'det', (8, 6): 'case'}, 0: {(0, 5): 'root'}, 2: {(2, 1): 'nmod:poss'}}

    tree_to_actions(wbuffer, stack, arcs, deps)

    # After tree_to_actions
    assert wbuffer == [], "The result for wbuffer is not correct"
    assert stack == [0], "The result for stack is not correct"
    assert arcs == [('nmod:poss', 2, 1), ('advmod', 5, 4), ('aux:pass', 5, 3), ('nsubj:pass', 5, 2), 
                     ('det', 8, 7), ('case', 8, 6), ('obl', 5, 8), ('punct', 5, 9), ('root', 0, 5)], \
                        "The result for arcs is not correct"
    assert deps == {5: {(5, 9): 'punct', (5, 8): 'obl', (5, 4): 'advmod', (5, 3): 'aux:pass', (5, 2): 'nsubj:pass'}, 
                     8: {(8, 7): 'det', (8, 6): 'case'}, 0: {(0, 5): 'root'}, 2: {(2, 1): 'nmod:poss'}}, \
                    "The result for deps is not correct"
    print("You cleared the sanity check for tree_to_actions()!")   
     
sanity_check_tree_to_actions()    

## Question 3. Dependency grammar vs. Phrase-structure grammar [written]

In the space below, please explain some **key differences** between phrase-structure grammars and dependency grammars (~150 words). 

We aren't looking for an exhaustive list, rather, a sufficiently detailed explanation about the situational advantages and major differences that each of these different formalisms offers. See the assignment writeup for tips on getting started and directions to head for further information.

### Q3 response

**type q3 response here**

## Bonus: Tree Parsing with Predictions

**Note**: nothing in this section, or the subsequent "bonus" section, is graded. 

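As a follow-up, the `action_to_tree` function below updates the predicted dependency tree by applying the highest-scoring action (according to the model's predictions) that is valid in the current configuration.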
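The core of that selection step is "sort actions by score, take the first legal one". A minimal sketch of just that step, with hypothetical names (`pick_valid_action`, `is_allowed`) and a toy score vector; in the notebook, `predictions` has shape `(1, num_labels)`, which is why `action_to_tree` takes row `[0]` of the argsort result.

```python
import numpy as np


def pick_valid_action(scores, actions, is_allowed):
    """Return the highest-scoring action for which is_allowed(action) is True.

    scores: 1-D array with one score per entry of `actions` (hypothetical names).
    """
    # argsort on the negated scores lists indices from best to worst score
    for idx in np.argsort(-scores):
        action = actions[idx]
        if is_allowed(action):
            return action
    return None  # no legal action at all


actions = ["SHIFT", "LEFTARC_nsubj", "RIGHTARC_root"]
scores = np.array([0.1, 2.0, 0.5])
# Suppose LEFT-ARC is illegal here (e.g. stack_2 is ROOT), so the next best action wins.
print(pick_valid_action(scores, actions, lambda a: not a.startswith("LEFTARC")))  # RIGHTARC_root
```

Falling back to the next-best legal action, instead of stopping on an illegal one, guarantees the parser always makes progress and ends with a complete tree.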


In [None]:
import re
import numpy as np

In [None]:
def is_valid(stack, wbuffer, action):

    if action == "SHIFT" and len(wbuffer) > 0:
        return True
    if action.startswith("RIGHTARC") and len(stack) > 1 and stack[-1] != 0:
        return True
    if action.startswith("LEFTARC") and len(stack) > 1 and stack[-2] != 0:
        return True

    return False

In [None]:
def action_to_tree(tree, predictions, wbuffer, stack, arcs, reverse_labels):

    sorted_probs = np.argsort(-predictions, kind='quicksort')[0]

    for i in range(len(sorted_probs)):
        action = reverse_labels[sorted_probs[i]]
        if is_valid(stack, wbuffer, action):
            if action == "SHIFT":
                stack.append(wbuffer.pop())

            elif len(stack) > 1 and action.startswith("RIGHTARC"):
                tree[stack[-1]] = (stack[-2], re.sub("RIGHTARC_", "", action))
                stack.pop()

            elif len(stack) > 1 and action.startswith("LEFTARC"):
                tree[stack[-2]] = (stack[-1], re.sub("LEFTARC_", "", action))
                stack.pop(-2)

            break


## Bonus: Neural Dependency Parser
Now that we have configurations $x$ and actions $y$, we can train a supervised model to predict an action $y$ given a configuration $x$. We are using a simplified version of the model from [A Fast and Accurate Dependency Parser using Neural Networks](https://nlp.stanford.edu/pubs/emnlp2014-depparser.pdf).

* This model is already implemented for you; please `train` the model, and report the evaluation and test results by calling the functions `evaluate` and `test`.
* To run the code for this section, you'll need to import PyTorch and related libraries, switch your runtime to GPU, and pull in GloVe embeddings. This is provided for you in the next few cells as well.
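Reading off `ShiftReduceParser` and `featurize_configuration` below, the classifier for one configuration can be summarized as follows. Here $w_1,\dots,w_6$ and $t_1,\dots,t_6$ are the word and POS-tag IDs at the top three stack positions and the first three buffer positions, $E^w$ is the frozen GloVe embedding matrix (dimension $d_w$), and $E^t$ is the trainable POS embedding matrix (dimension $d_t$); the dropout applied to $x$ is omitted for clarity.

$$
x = \left[E^w_{w_1}; E^t_{t_1}; \dots; E^w_{w_6}; E^t_{t_6}\right] \in \mathbb{R}^{6(d_w + d_t)}, \qquad
h = \tanh(W_1 x + b_1), \qquad
\text{logits} = W_2 h + b_2
$$

With the 50-dimensional GloVe vectors and `POS_EMBEDDING_SIZE=50`, the input has $6(50 + 50) = 600$ dimensions, `HIDDEN_DIM=100`, and training minimizes cross-entropy between $\text{logits}$ and the gold transition label. Chen and Manning (2014) use a cube activation and a richer feature set; this simplified version uses $\tanh$ and only these 12 IDs.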

#### Setup

In [None]:
import sys

import torch.nn as nn
import torch
import torch.optim as optim

In [None]:
# if this cell prints "Running on cpu", switch runtime environments
# go to Runtime > Change runtime type > Hardware accelerator > GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on {}".format(device))

As before, we will be using [GloVe](https://nlp.stanford.edu/projects/glove/) pretrained word embeddings.

In [None]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove*.zip

Next, pull in the Universal Dependencies training and development data.
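Each non-comment line of these files is a tab-separated token record. Assuming the standard 10-column CoNLL-U layout (ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC), the indices used later in `get_oracles` and `evaluate` (`cols[0]`, `cols[1]`, `cols[4]`, `cols[6]`, `cols[7]`) select the word index, word form, language-specific POS tag, head index and dependency label. A minimal sketch of that mapping on a made-up line; `parse_conllu_token` is a hypothetical helper, and unlike the notebook's loader (which calls `int(cols[0])` directly) it skips multiword-token ranges and empty nodes.

```python
def parse_conllu_token(line):
    """Map one CoNLL-U token line to the (idd, tok, pos, head, lab) tuple used in this notebook.

    Returns None for comments, blank lines, multiword-token ranges ("1-2")
    and empty nodes ("8.1"), which this sketch simply skips.
    """
    cols = line.rstrip("\n").split("\t")
    if len(cols) < 10 or cols[0].startswith("#"):
        return None
    if not cols[0].isdigit():
        return None
    # cols[4] is XPOS (e.g. Penn Treebank tags like NNP); cols[3] would be UPOS (e.g. PROPN)
    return int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7]


line = "3\tAP\tAP\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:from\t_"
print(parse_conllu_token(line))  # (3, 'AP', 'NNP', 4, 'obl')
print(parse_conllu_token("# text = From the AP comes this story:"))  # None
```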

In [None]:
!wget https://raw.githubusercontent.com/dbamman/nlp23/main/HW4/train.projective.short.conll
!wget https://raw.githubusercontent.com/dbamman/nlp23/main/HW4/dev.projective.conll

#### Class Definition

In [None]:
"""
Return pairs of configurations + gold transitions (actions)
from training data
configurations = a list of tuples, each consisting of:
    - buffer (the front of the buffer is at the end of the list)
    - stack (the top of the stack is at the end of the list)
    - arcs (a list of (label, head, dependent) tuples)
gold transitions = a list of actions, one per configuration, e.g. SHIFT
"""
def get_oracle(toks):

    stack = [] # stack
    arcs = [] # existing list of arcs
    wbuffer = [] # input buffer

    # deps is a dictionary of head: dependency relations, where
    # dependency relations is a dictionary of the (head, child): label
    # deps = {head1:{
    #               (head1, child1):dependency_label1,
    #               (head1, child2):dependency_label2
    #              },
    #         head2:{
    #               (head2, child3):dependency_label3,
    #               (head2, child4):dependency_label4
    #              }
    #         }
    deps = {}

    # ROOT
    stack.append(0)

    # initialize variables
    for position in reversed(toks):
        (idd, _, _, head, lab) = position

        dep = (head, idd)
        if head not in deps:
            deps[head] = {}
        deps[head][dep] = lab

        wbuffer.append(idd)

    # configurations:
    # A list of (wbuffer, stack, arcs)
    # Keeps track of the states at each step
    # gold_transitions:
    # A list of action strings ["SHIFT", "LEFTARC_nsubj"]
    # Keeps track of the actions at each step
    configurations, gold_transitions = tree_to_actions(wbuffer, stack, arcs, deps)
    return configurations, gold_transitions

def featurize_configuration(configuration, tokens, postags, vocab, pos_vocab):
    
    """
    Given a configuration (input buffer, stack and arcs),
    the words of the sentence and their POS tags,
    return lists of word IDs and POS-tag IDs to use as features.

    The current features are the word ID and postag ID at the 
    first three positions of the stack and buffer.
    """

    def get_id(word, vocab):
        word=word.lower()
        if word in vocab:
            return vocab[word]
        return vocab["<unk>"]

    wbuffer, stack, arcs = configuration

    word_features=[]
    pos_features=[]

    if len(stack) > 0: 
        word_features.append(get_id(tokens[stack[-1]], vocab))
        pos_features.append(get_id(postags[stack[-1]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(stack) > 1: 
        word_features.append(get_id(tokens[stack[-2]], vocab))
        pos_features.append(get_id(postags[stack[-2]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(stack) > 2: 
        word_features.append(get_id(tokens[stack[-3]], vocab))
        pos_features.append(get_id(postags[stack[-3]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(wbuffer) > 0: 
        word_features.append(get_id(tokens[wbuffer[-1]], vocab))
        pos_features.append(get_id(postags[wbuffer[-1]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))
       
    if len(wbuffer) > 1: 
        word_features.append(get_id(tokens[wbuffer[-2]], vocab))
        pos_features.append(get_id(postags[wbuffer[-2]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    if len(wbuffer) > 2: 
        word_features.append(get_id(tokens[wbuffer[-3]], vocab))
        pos_features.append(get_id(postags[wbuffer[-3]], pos_vocab))
    else: 
        word_features.append(get_id("<NONE>", vocab))
        pos_features.append(get_id("<NONE>", pos_vocab))

    return word_features, pos_features

"""
Get configurations, gold_transitions from all sentences
"""
def get_oracles(filename, vocab, tag_vocab):

    with open(filename) as f:
        toks, tokens, postags = [], {}, {}
        tokens[0] = "<ROOT>"
        postags[0] = "<ROOT>"

        # a list of all features for each transition step
        word_feats = []
        pos_feats = []
        # a list of labels, e.g. SHIFT, LEFTARC_DEP_LABEL, RIGHTARC_DEP_LABEL
        labels = []

        for line in f:
            cols = line.rstrip().split("\t")
            
            if len(cols) < 2: # at the end of each sentence
                if len(toks) > 0:
                    if is_projective(toks): # only use projective trees
                        # get all configurations and gold standard transitions
                        configurations, gold_transitions = get_oracle(toks)
                        
                        for i in range(len(configurations)):
                            word_feat, pos_feat = featurize_configuration(configurations[i], tokens, postags, vocab, tag_vocab)
                            label = gold_transitions[i]
                            word_feats.append(word_feat)
                            pos_feats.append(pos_feat)
                            labels.append(label)

                    # reset vars for the next sentence
                    toks, tokens, postags = [], {}, {}
                    tokens[0] = "<ROOT>"
                    postags[0] = "<ROOT>"
                    
                continue

            if cols[0].startswith("#"):
                continue

            # construct the tuple for each word in the sentence:
            # idd: index of a word in a sentence, starting from 1
            # tok: the word itself
            # pos: POS tag for that word
            # head: index of the word's head (0 for the root)
            # lab: dependency relation label
            idd, tok, pos, head, lab = int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7]
            toks.append((idd, tok, pos, head, lab))

            # feature for training to predict the gold transition
            tokens[idd], postags[idd] = tok, pos

        return word_feats, pos_feats, labels

def load_embeddings(filename):
    # 0 idx is for padding
    # 1 idx is for <UNK>
    # 2 idx is for <NONE>
    # 3 idx is for <ROOT>

    # first pass: read the embedding size from the first line and count the lines to size the vocabulary
    vocab_size=4
    with open(filename, encoding="utf-8") as file:
        for idx, line in enumerate(file):
            if idx == 0:
                word_embedding_dim=len(line.rstrip().split(" "))-1
            vocab_size+=1

    vocab={"<pad>":0, "<unk>":1, "<none>":2, "<root>":3}
    print(f"word_embedding_dim: {word_embedding_dim}, vocab_size: {vocab_size}")

    embeddings=np.zeros((vocab_size, word_embedding_dim))

    with open(filename, encoding="utf-8") as file:
        for idx,line in enumerate(file):

            if idx + 4 >= vocab_size:
                break

            cols=line.rstrip().split(" ")
            val=np.array(cols[1:])
            word=cols[0]
            embeddings[idx+4]=val
            vocab[word]=idx+4

    return torch.FloatTensor(embeddings), vocab

class ShiftReduceParser(nn.Module):

    def __init__(self, embeddings, hidden_dim, tagset_size, num_pos_tags, pos_embedding_dim):
        super(ShiftReduceParser, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_labels=tagset_size

        _, embedding_dim = embeddings.shape

        self.input_size=embedding_dim*6 + pos_embedding_dim*6
        
        self.dropout_layer = nn.Dropout(p=0.25)

        self.word_embeddings = nn.Embedding.from_pretrained(embeddings)
        self.pos_embeddings = nn.Embedding(num_pos_tags, pos_embedding_dim)
        self.tanh = nn.Tanh()
        self.W1 = nn.Linear(self.input_size, self.hidden_dim)
        self.W2 = nn.Linear(self.hidden_dim, self.num_labels)

    def forward(self, words, pos_tags, Y=None):
        
        words=words.to(device)
        pos_tags=pos_tags.to(device)

        if Y is not None:
            Y=Y.to(device)

        word_embeds = self.word_embeddings(words)
        postag_embeds = self.pos_embeddings(pos_tags)

        embeds=torch.cat((word_embeds, postag_embeds), 2)

        embeds=embeds.view(-1, self.input_size)

        embeds=self.dropout_layer(embeds)

        hidden = self.W1(embeds)
        hidden = self.tanh(hidden)
        logits = self.W2(hidden)

        if Y is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), Y.view(-1))
            return loss
        else:
            return logits

def get_batches(W, P, Y, batch_size):
    batch_W=[]
    batch_P=[]
    batch_Y=[]

    i=0
    while i < len(W):
        batch_W.append(torch.LongTensor(W[i:i+batch_size]))
        batch_P.append(torch.LongTensor(P[i:i+batch_size]))
        batch_Y.append(torch.LongTensor(Y[i:i+batch_size]))
        i+=batch_size  

    return batch_W, batch_P, batch_Y

"""

Train transition-based parser to predict next action (labels)
given current configuration (featurized by word_feats and pos_feats)
Return the classifier trained using Chen and Manning (2014), "A Fast 
and Accurate Dependency Parser using Neural Networks"

"""
def train(word_feats, pos_feats, labels, embeddings, vocab, postag_vocab, label_vocab):


    # dimensionality of linear layer
    HIDDEN_DIM=100
    # dimensionality of POS embeddings
    POS_EMBEDDING_SIZE=50

    # batch size for training
    BATCH_SIZE=32

    # number of epochs to train for
    NUM_EPOCHS=10

    # learning rate for Adam optimizer
    LEARNING_RATE=0.001

    num_labels=[]
    for i, y in enumerate(labels):
        num_labels.append(label_vocab[y])

    batch_W, batch_P, batch_Y = get_batches(word_feats, pos_feats, num_labels, BATCH_SIZE)

    model = ShiftReduceParser(embeddings, HIDDEN_DIM, len(label_vocab), len(postag_vocab), POS_EMBEDDING_SIZE)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        model.train()

        bigloss=0.
        for b in range(len(batch_W)):
            model.zero_grad()

            loss = model.forward(batch_W[b], batch_P[b], Y=batch_Y[b])
            bigloss+=loss.item()

            loss.backward()
            optimizer.step()

        print(f"loss: {bigloss}")

    return model

"""
parse sentence with trained model and return the predicted tree as {child: (head, label)}
"""
def parse(toks, model, vocab, tag_vocab, reverse_labels):

    tokens, postags = {}, {}
    tokens[0] = "<ROOT>"
    postags[0] = "<ROOT>"

    wbuffer, stack, arcs = [], [], []
    stack.append(0)

    for position in reversed(toks):

        (idd, tok, pos, head, lab) = position
        tokens[idd] = tok
        postags[idd] = pos

        # update buffer
        wbuffer.append(idd)

    tree = {}
    while len(wbuffer) >= 0:
        if len(wbuffer) == 0 and len(stack) == 0: break
        if len(wbuffer) == 0 and len(stack) == 1 and stack[0] == 0: break

        word_feats, pos_feats = (featurize_configuration((wbuffer, stack, arcs), tokens, postags, vocab, tag_vocab))

       
        predictions=model.forward(torch.LongTensor([word_feats]), torch.LongTensor([pos_feats]))

        predictions=predictions.detach().cpu().numpy()

        # apply the highest-scoring valid action (see action_to_tree above)
        action_to_tree(tree, predictions, wbuffer, stack, arcs, reverse_labels)

    return tree

"""
parse sentence with trained model and return correctness measure
"""
def parse_and_evaluate(toks, model, vocab, tag_vocab, reverse_labels):

    heads, labels = {}, {}

    for position in reversed(toks):
        (idd, tok, pos, head, lab) = position

        # keep track of gold standards for performance evaluation
        heads[idd], labels[idd] = head, lab

    tree = parse(toks, model, vocab, tag_vocab, reverse_labels)

    # correct_unlabeled: total number of correct (head, child) dependencies
    # correct_labeled: total number of correctly *labeled* dependencies
    correct_unlabeled, correct_labeled, total = 0, 0, 0

    for child in tree:
        (head, label) = tree[child]
        if head == heads[child]:
            correct_unlabeled += 1
            if label == labels[child]: correct_labeled += 1
        total += 1

    return [correct_unlabeled, correct_labeled, total]

def get_label_vocab(labels):
    tag_vocab={}
    num_labels=[]
    for i, y in enumerate(labels):
        if y not in tag_vocab:
            tag_vocab[y]=len(tag_vocab)
        num_labels.append(tag_vocab[y])

    reverse_labels=[None]*len(tag_vocab)
    for y in tag_vocab:
        reverse_labels[tag_vocab[y]]=y

    return tag_vocab, reverse_labels


def get_pos_tag_vocab(filename):
    tag_vocab={"<none>":0, "<unk>":1}
    with open(filename) as file:
        for line in file:
            cols=line.rstrip().split("\t")
            if len(cols) < 3:
                continue
            pos=cols[4].lower()
            if pos not in tag_vocab:
                tag_vocab[pos]=len(tag_vocab)
    return tag_vocab

"""
Parse a short example sentence with the trained model and print the predicted dependencies
"""
def test(model, vocab, tag_vocab, reverse_labels):

    model.eval()

    toks=["I", "bought", "a", "book"]
    pos=["PRP", "VBD", "DT", "NN"]

    data=[]
    # put it in format parser expects
    for i, tok in enumerate(toks):
        data.append((i+1, tok, pos[i], "_", "_"))

    tree=parse(data, model, vocab, tag_vocab, reverse_labels)

    for child in sorted(tree.keys()):
        (head, label) = tree[child]
        headStr="<ROOT>"
        # child and head indexes start at 1; 0 denotes the <ROOT>
        if head > 0: 
            headStr=toks[head-1]
        
        print(f"({child}, {toks[child-1]}) -> ({head}, {headStr}) {label}")
  
"""
Evaluate the performance of a parser against gold standard
"""
def evaluate(filename, model, vocab, tag_vocab, reverse_labels):

    model.eval()

    with open(filename) as f:
        toks=[]
        totals = np.zeros(3)
        for line in f:
            cols=line.rstrip().split("\t")

            if len(cols) < 2: # end of a sentence
                if len(toks) > 0:
                    if is_projective(toks):
                        tots = np.array(parse_and_evaluate(toks, model, vocab, tag_vocab, reverse_labels))
                        totals += tots
                        
                    toks = []
                continue

            if cols[0].startswith("#"):
                continue

            idd, tok, pos, head, lab = int(cols[0]), cols[1], cols[4], int(cols[6]), cols[7]
            toks.append((idd, tok, pos, head, lab))
        
        print(f"UAS: {totals[0]/totals[2]}, LAS: {totals[1]/totals[2]}")

### Train and evaluate the neural model

* To reiterate, because you are not implementing the neural model or its training process, you will **not** be graded based on the performance of the neural model!

* You are only graded based on the correctness of each of the implemented functions from Questions 1 and 2. 

* That said, as a way to check your programming work, if all the functions you wrote are implemented correctly, you should expect a **UAS** in a range of **[0.64, 0.67]**, and a **LAS** in a range of **[0.56, 0.59]** without changing the parameters of the neural model or the training process.
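For reference, UAS (unlabeled attachment score) is the fraction of tokens whose predicted head is correct, and LAS (labeled attachment score) additionally requires the predicted dependency label to match. A minimal per-sentence sketch of the same counting that `parse_and_evaluate` performs, using a hypothetical `attachment_scores` helper and toy `{child: (head, label)}` dicts:

```python
def attachment_scores(gold, predicted):
    """gold, predicted: hypothetical {child: (head, label)} dicts for one sentence.

    Returns (UAS, LAS) over the children that appear in `predicted`.
    """
    correct_unlabeled = correct_labeled = 0
    for child, (head, label) in predicted.items():
        gold_head, gold_label = gold[child]
        if head == gold_head:
            correct_unlabeled += 1
            # a labeled match only counts if the head is also correct
            if label == gold_label:
                correct_labeled += 1
    total = len(predicted)
    if total == 0:
        return 0.0, 0.0
    return correct_unlabeled / total, correct_labeled / total


# "I bought a book": every head is right, but one label is wrong
gold = {1: (2, "nsubj"), 2: (0, "root"), 3: (4, "det"), 4: (2, "obj")}
pred = {1: (2, "nsubj"), 2: (0, "root"), 3: (4, "det"), 4: (2, "obl")}
print(attachment_scores(gold, pred))  # (1.0, 0.75)
```

Note that `evaluate` sums the three counts over all sentences before dividing, so the reported UAS/LAS weight every token equally rather than averaging per-sentence scores.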

In [None]:
# load the embeddings file we downloaded
embeddingsFile = "glove.6B.50d.txt"

# load in trainFile and devFile
trainFile = "train.projective.short.conll"
devFile = "dev.projective.conll"

embeddings, vocab=load_embeddings(embeddingsFile)
pos_tag_vocab=get_pos_tag_vocab(trainFile)
word_feats, pos_feats, labels = get_oracles(trainFile, vocab, pos_tag_vocab)

label_vocab, reverse_labels=get_label_vocab(labels)

In [None]:
model = train(word_feats, pos_feats, labels, embeddings, vocab, pos_tag_vocab, label_vocab)
evaluate(devFile, model, vocab, pos_tag_vocab, reverse_labels)
test(model, vocab, pos_tag_vocab, reverse_labels)