# Exercise: Transition-based Neural Network Parser

## Resources 

* Chen and Manning (2014) [A Fast and Accurate Dependency Parser using Neural Networks](https://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf)
* CS224n Assignment #2 (Winter 2018) [Question 2](http://web.stanford.edu/class/cs224n/assignment2/index.html)

In [None]:
import os
import pickle
import random
import sys
import time
import urllib
import urllib.request
import zipfile
from collections import Counter
from datetime import datetime
from tempfile import gettempdir

import numpy as np
import tensorflow as tf

## Helper functions for downloading and processing data 

The parsing data is derived from the Penn Treebank-3 (https://catalog.ldc.upenn.edu/ldc99t42) and is distributed in CoNLL format with CS224n Assignment 2.

In [None]:
def maybe_download(url, filename, expected_bytes):
    """Download ``url + filename`` into the temp dir unless it is already there.

    A freshly downloaded archive is size-checked and, if it matches, extracted
    into the current working directory. A failed or wrong-size download is
    deleted before the error propagates, so the next call retries instead of
    trusting a truncated file.

    Args:
        url: base URL (must end with '/').
        filename: archive name, appended to ``url`` and used as the local name.
        expected_bytes: exact size the downloaded archive must have.
    Returns:
        Path of the local archive.
    Raises:
        Exception: if the downloaded file does not have ``expected_bytes``.
    """
    local_filename = os.path.join(gettempdir(), filename)
    if not os.path.exists(local_filename):
        try:
            local_filename, _ = urllib.request.urlretrieve(url + filename, local_filename)
        except BaseException:
            # urlretrieve writes in place; don't leave a partial archive behind
            if os.path.exists(local_filename):
                os.remove(local_filename)
            raise
        statinfo = os.stat(local_filename)
        if statinfo.st_size == expected_bytes:
            with zipfile.ZipFile(local_filename) as f:
                f.extractall()
        else:
            print(statinfo.st_size)
            os.remove(local_filename)
            raise Exception('Failed to verify ' + local_filename +
                            '. Can you get to it with a browser?')
    return local_filename

In [None]:
def read_conll(in_file, lowercase=False, max_example=None):
    """ Returns a list of examples where each example is a dict of lists:
        'word' : list of str repr. the words in the sentence
        'pos'  : list of str repr. the XPOS tags (language specific)
        'head' : list of int repr. the position of each word's head word
        'label': list of str repr. the dependency label of each word

    A token line has exactly 10 tab-separated columns. Any other line
    (blank line, comment) terminates the current sentence. Multi-word
    token ranges (ids like "1-2") and empty nodes (ids like "8.1") are
    skipped. Empty nodes have no integer head.

    Args:
        in_file: path of the CoNLL file, read as UTF-8.
        lowercase: lowercase every word form.
        max_example: maximum number of sentences to return (None = all).
    """
    examples = []
    with open(in_file, encoding='utf-8') as f:
        word, pos, head, label = [], [], [], []
        for line in f:
            if (max_example is not None) and (len(examples) >= max_example):
                break
            sp = line.strip().split('\t')
            if len(sp) == 10:
                if '-' not in sp[0] and '.' not in sp[0]:
                    word.append(sp[1].lower() if lowercase else sp[1])
                    pos.append(sp[4])
                    head.append(int(sp[6]))
                    label.append(sp[7])
            elif len(word) > 0:
                examples.append({'word': word, 'pos': pos,
                                 'head': head, 'label': label})
                word, pos, head, label = [], [], [], []
        # file may not end with a blank line; keep the pending sentence
        if len(word) > 0:
            examples.append({'word': word, 'pos': pos,
                             'head': head, 'label': label})
    return examples


In [None]:
def build_dict(keys, n_max=None, offset=0):
    """Helper function for building the mapping dicts.

    Assigns consecutive integer ids to the distinct ``keys``, most frequent
    first. Ties keep first-occurrence order.

    Args:
        keys: iterable of hashable items to count.
        n_max: keep only the ``n_max`` most frequent keys (None keeps all).
        offset: id given to the most frequent key.
    Returns:
        dict mapping key -> id in ``[offset, offset + number_of_kept_keys)``.
    """
    # Counter.most_common(None) returns every element, so one call covers both cases
    ranked = Counter(keys).most_common(n_max)
    return {key: offset + rank for rank, (key, _count) in enumerate(ranked)}


In [None]:
def punct(language, pos):
    """Tell whether ``pos`` is a punctuation tag for the given tag set.

    Args:
        language: 'english' (Penn Treebank style tags) or 'universal'.
        pos: POS tag string to test.
    Returns:
        True if the tag counts as punctuation.
    Raises:
        ValueError: for any other ``language``.
    """
    if language == 'universal':
        return pos == 'PUNCT'
    if language != 'english':
        raise ValueError('language: %s is not supported.' % language)
    english_punct_tags = ("''", ",", ".", ":", "``", "-LRB-", "-RRB-")
    return pos in english_punct_tags

In [None]:
def load_and_preprocess_data(reduced=True):
    """Download (if needed), load, and vectorize the CS224n parsing data.

    Args:
        reduced: if True, keep only the first 1000 training sentences and
            the first 500 dev and test sentences (for quick debugging runs).
    Returns:
        (parser, embeddings_matrix, train_examples, dev_set, test_set), where:
        - embeddings_matrix is a float32 array of shape (parser.n_tokens, 50).
          Rows of tokens found in the pretrained vectors are copied from
          them; the others are drawn from N(0, 0.9).
        - train_examples are the oracle instances from
          Parser.create_instances.
        - dev_set/test_set are the vectorized examples.
    """
    config = ParserConfig()
    print("Downloading data...")
    start = time.time()
    # Result intentionally unused: the archive is extracted into the cwd.
    # (Don't bind it to ``zipfile`` -- that would shadow the module.)
    maybe_download('http://web.stanford.edu/class/cs224n/assignment2/',
                   'assignment2.zip', 38866004)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Loading data...")
    start = time.time()
    train_set = read_conll(os.path.join(config.data_path, config.train_file),
                           lowercase=config.lowercase)
    dev_set = read_conll(os.path.join(config.data_path, config.dev_file),
                         lowercase=config.lowercase)
    test_set = read_conll(os.path.join(config.data_path, config.test_file),
                          lowercase=config.lowercase)
    if reduced:
        train_set = train_set[:1000]
        dev_set = dev_set[:500]
        test_set = test_set[:500]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Building parser...")
    start = time.time()
    parser = Parser(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Loading pretrained embeddings...")
    start = time.time()
    word_vectors = {}
    # one "word v1 v2 ..." entry per line; the with-block closes the file
    with open(os.path.join(config.data_path, config.embedding_file),
              encoding='utf-8') as f:
        for line in f:
            sp = line.strip().split()
            word_vectors[sp[0]] = [float(x) for x in sp[1:]]
    embeddings_matrix = np.asarray(np.random.normal(0, 0.9, (parser.n_tokens, 50)), dtype='float32')

    for token, i in parser.tok2id.items():
        # exact match first, then fall back to the lowercased form
        if token in word_vectors:
            embeddings_matrix[i] = word_vectors[token]
        elif token.lower() in word_vectors:
            embeddings_matrix[i] = word_vectors[token.lower()]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Vectorizing data...")
    start = time.time()
    train_set = parser.vectorize(train_set)
    dev_set = parser.vectorize(dev_set)
    test_set = parser.vectorize(test_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Preprocessing training data...")
    start = time.time()
    train_examples = parser.create_instances(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    return parser, embeddings_matrix, train_examples, dev_set, test_set

## Classes for the dependency parser

In [None]:
class ModelWrapper(object):
    """Adapter exposing ``predict(partial_parses)`` for ``minibatch_parse``.

    Args:
        parser: a Parser whose ``model`` and ``session`` attributes are set.
        dataset: the vectorized examples being parsed.
        sentence_id_to_idx: maps ``id(sentence)`` of each PartialParse's
            sentence list to that sentence's index in ``dataset``.
    """

    def __init__(self, parser, dataset, sentence_id_to_idx):
        self.parser = parser
        self.dataset = dataset
        self.sentence_id_to_idx = sentence_id_to_idx

    def predict(self, partial_parses):
        """Return the highest-scoring legal transition ('S', 'LA' or 'RA') per parse."""
        mb_x = [self.parser.extract_features(p.stack, p.buffer, p.dependencies,
                                             self.dataset[self.sentence_id_to_idx[id(p.sentence)]])
                for p in partial_parses]
        mb_x = np.array(mb_x).astype('int32')
        mb_l = [self.parser.legal_labels(p.stack, p.buffer) for p in partial_parses]
        pred = self.parser.model.predict_on_batch(self.parser.session, mb_x)
        # A large bonus on legal transitions forces argmax to pick a legal one
        pred = np.argmax(pred + 10000 * np.array(mb_l).astype('float32'), 1)
        # Transition ids are laid out as [left-arcs | right-arcs | shift]
        # (see Parser.__init__); map by range so labeled parsers work too.
        shift_id = self.parser.n_trans - 1
        n_deprel = self.parser.n_deprel
        return ["S" if t == shift_id else ("LA" if t < n_deprel else "RA") for t in pred]


In [None]:
# Prefixes that namespace POS-tag and dependency-label entries in Parser.tok2id
P_PREFIX, L_PREFIX = '<p>:', '<l>:'
# Special tokens: out-of-vocabulary word, padding for missing feature
# positions, and the artificial ROOT prepended to every sentence
UNK, NULL, ROOT = '<UNK>', '<NULL>', '<ROOT>'

In [None]:
class Parser(object):
    "Contains everything needed for transition-based dependency parsing except for the model"

    def __init__(self, dataset):
        # Check that there is a unique label for root
        root_labels = list([l for ex in dataset
                           for (h, l) in zip(ex['head'], ex['label']) if h == 0])
        counter = Counter(root_labels)
        if len(counter) > 1:
            print('Warning: more than one root label')
            print(counter)
        self.root_label = counter.most_common()[0][0]
        
        # list of all unique dependency labels
        deprel = [self.root_label] + list(set([w for ex in dataset
                                               for w in ex['label']
                                               if w != self.root_label]))
        
        # DEP labels such as <l>:acl, ... , <l>:xcomp, <l>:<NULL>
        tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}
        tok2id[L_PREFIX + NULL] = self.L_NULL = len(tok2id)

        config = ParserConfig()
        self.unlabeled = config.unlabeled
        self.with_punct = config.with_punct
        self.use_pos = config.use_pos
        self.use_dep = config.use_dep
        self.language = config.language
        
        # dictionaries for transitions (left and right) incl. S for shift
        if self.unlabeled:
            trans = ['L', 'R', 'S']
            self.n_deprel = 1
        else: 
            trans = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['S']
            self.n_deprel = len(deprel)

        self.n_trans = len(trans) # number of unique arc-transitions
        self.tran2id = {t: i for (i, t) in enumerate(trans)}
        self.id2tran = {i: t for (i, t) in enumerate(trans)}
        
        # POS tags such as <p>:$, ... , <p>:<UNK>, <p>:<NULL>, <p>:<ROOT>
        tok2id.update(build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
                                  offset=len(tok2id)))
        tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
        tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
        tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)
        
        # Words including <UNK>, <NULL>, <ROOT>
        tok2id.update(build_dict([w for ex in dataset for w in ex['word']],
                                  offset=len(tok2id)))
        tok2id[UNK] = self.UNK = len(tok2id)
        tok2id[NULL] = self.NULL = len(tok2id)
        tok2id[ROOT] = self.ROOT = len(tok2id)

        self.tok2id = tok2id
        self.id2tok = {v: k for (k, v) in tok2id.items()}

        self.n_features = 18 + (18 if config.use_pos else 0) + (12 if config.use_dep else 0)
        self.n_tokens = len(tok2id)

        
    def vectorize(self, examples):
        " Numericalizes examples and adds ROOT to the front"
        vec_examples = []
        for ex in examples:
            word = [self.ROOT] + [self.tok2id[w] if w in self.tok2id
                                  else self.UNK for w in ex['word']]
            pos = [self.P_ROOT] + [self.tok2id[P_PREFIX + w] if P_PREFIX + w in self.tok2id
                                   else self.P_UNK for w in ex['pos']]
            head = [-1] + ex['head']
            label = [-1] + [self.tok2id[L_PREFIX + w] if L_PREFIX + w in self.tok2id
                            else -1 for w in ex['label']]
            vec_examples.append({'word': word, 'pos': pos,
                                 'head': head, 'label': label})
        return vec_examples

    
    def extract_features(self, stack, buf, arcs, ex):
        "Returns list of features described in sec 3.1 of Chen and Manning"
        if stack[0] == "ROOT":
            stack[0] = 0

        def get_lc(k):
            # sorted indices of all left children of word at position k 
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] < k])

        def get_rc(k):
            # sorted indices of all right children of word at position k
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] > k],
                          reverse=True)
        p_features = []
        l_features = []
        # top 3 words on the stack (s1, s2, s3)
        features = [self.NULL] * (3 - len(stack)) + [ex['word'][x] for x in stack[-3:]]
        # top 3 words in the buffer
        features += [ex['word'][x] for x in buf[:3]] + [self.NULL] * (3 - len(buf))
        if self.use_pos: # add POS tags for top 3 words on stack and in buffer
            p_features = [self.P_NULL] * (3 - len(stack)) + [ex['pos'][x] for x in stack[-3:]]
            p_features += [ex['pos'][x] for x in buf[:3]] + [self.P_NULL] * (3 - len(buf))

        # first and second leftmost/rightmost children of top 2 words on stack
        for i in range(2):
            if i < len(stack):   # check that there are sufficient words on stack
                k = stack[-i-1]  # index into top 2 words on stack
                lc = get_lc(k)   # all left children
                rc = get_rc(k)   # all right children
                llc = get_lc(lc[0]) if len(lc) > 0 else [] # left children of leftmost child
                rrc = get_rc(rc[0]) if len(rc) > 0 else [] # right children of rightmost child
                # first and second leftmost/rightmost children
                features.append(ex['word'][lc[0]] if len(lc) > 0 else self.NULL)
                features.append(ex['word'][rc[0]] if len(rc) > 0 else self.NULL)
                features.append(ex['word'][lc[1]] if len(lc) > 1 else self.NULL)
                features.append(ex['word'][rc[1]] if len(rc) > 1 else self.NULL)
                # leftmost child of leftmost child
                features.append(ex['word'][llc[0]] if len(llc) > 0 else self.NULL)
                # rightmost child of rightmost child
                features.append(ex['word'][rrc[0]] if len(rrc) > 0 else self.NULL)

                if self.use_pos:
                    p_features.append(ex['pos'][lc[0]] if len(lc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rc[0]] if len(rc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][lc[1]] if len(lc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][rc[1]] if len(rc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][llc[0]] if len(llc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rrc[0]] if len(rrc) > 0 else self.P_NULL)

                if self.use_dep:
                    l_features.append(ex['label'][lc[0]] if len(lc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rc[0]] if len(rc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][lc[1]] if len(lc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][rc[1]] if len(rc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][llc[0]] if len(llc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rrc[0]] if len(rrc) > 0 else self.L_NULL)
            else:
                features += [self.NULL] * 6
                if self.use_pos:
                    p_features += [self.P_NULL] * 6
                if self.use_dep:
                    l_features += [self.L_NULL] * 6

        features += p_features + l_features
        assert len(features) == self.n_features
        return features
    

    def get_oracle(self, stack, buf, ex):
        """ Implements 'shortest stack oracle' in sec 3.2 of Chen and Manning
            Returns:
            (1) 2 (if unlabeled) or index for S (shift operation)
            (2) 0 (if unlabeled) or DEP label of second word on stack, s2 (left-arc)
            (3) 1 (if unlabeled) or DEP label of first word on stack, s1 (right arc)
        """
        if len(stack) < 2:            # stack contans only one word
            return self.n_trans - 1   # return shift operation

        i0 = stack[-1]                # position of top two words on stack 
        i1 = stack[-2]
        h0 = ex['head'][i0]           # position of their respective heads
        h1 = ex['head'][i1]
        l0 = ex['label'][i0]          # their respective dependency labels
        l1 = ex['label'][i1]

        if self.unlabeled:
            if (i1 > 0) and (h1 == i0):  # second on stack is not root  
                return 0                 # second on stack is dep on first
            elif (i1 >= 0) and (h0 == i1) and \
                 (not any([x for x in buf if ex['head'][x] == i0])):
                return 1
            else:
                return None if len(buf) == 0 else 2
        else:
            if (i1 > 0) and (h1 == i0):
                return l1 if (l1 >= 0) and (l1 < self.n_deprel) else None
            elif (i1 >= 0) and (h0 == i1) and \
                 (not any([x for x in buf if ex['head'][x] == i0])):
                return l0 + self.n_deprel if (l0 >= 0) and (l0 < self.n_deprel) else None
            else:
                return None if len(buf) == 0 else self.n_trans - 1

            
    def create_instances(self, examples):
        ""
        all_instances = []
        succ = 0
        for id, ex in enumerate(examples):
            n_words = len(ex['word']) - 1
            stack = [0]
            buf = [i + 1 for i in range(n_words)]
            arcs = []
            instances = []
            for i in range(n_words * 2):
                gold_t = self.get_oracle(stack, buf, ex)
                if gold_t is None:
                    break
                legal_labels = self.legal_labels(stack, buf)
                assert legal_labels[gold_t] == 1
                instances.append((self.extract_features(stack, buf, arcs, ex),
                                  legal_labels, gold_t))
                if gold_t == self.n_trans - 1:
                    stack.append(buf[0])
                    buf = buf[1:]
                elif gold_t < self.n_deprel: #
                    arcs.append((stack[-1], stack[-2], gold_t))
                    stack = stack[:-2] + [stack[-1]]
                else:
                    arcs.append((stack[-2], stack[-1], gold_t - self.n_deprel))
                    stack = stack[:-1]
            else:
                succ += 1
                all_instances += instances
        return all_instances

    
    def legal_labels(self, stack, buf):
        ""
        labels = ([1] if len(stack) > 2 else [0]) * self.n_deprel
        labels += ([1] if len(stack) >= 2 else [0]) * self.n_deprel
        labels += [1] if len(buf) > 0 else [0]
        return labels

    
    def parse(self, dataset, eval_batch_size=5000):
        ""
        sentences = []
        sentence_id_to_idx = {}
        for i, example in enumerate(dataset):
            n_words = len(example['word']) - 1
            sentence = [j + 1 for j in range(n_words)]
            sentences.append(sentence)
            sentence_id_to_idx[id(sentence)] = i

        model = ModelWrapper(self, dataset, sentence_id_to_idx)
        dependencies = minibatch_parse(sentences, model, eval_batch_size)

        UAS = all_tokens = 0.0
        for i, ex in enumerate(dataset):
            head = [-1] * len(ex['word'])
            for h, t, in dependencies[i]:
                head[t] = h
            for pred_h, gold_h, gold_l, pos in \
                    zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                    assert self.id2tok[pos].startswith(P_PREFIX)
                    pos_str = self.id2tok[pos][len(P_PREFIX):]
                    if (self.with_punct) or (not punct(self.language, pos_str)):
                        UAS += 1 if pred_h == gold_h else 0
                        all_tokens += 1
        UAS /= all_tokens
        return UAS, dependencies


In [None]:
class ParserConfig(object):
    """Static settings for the Parser and for locating the data files."""
    # tag set used by punct() when scoring
    language = 'english'
    # whether punctuation tokens count towards UAS
    with_punct = True
    # unlabeled: only L/R/S transitions; labeled: one L and R per dependency label
    unlabeled = True
    lowercase = True
    use_pos = True
    use_dep = True
    use_dep = use_dep and not unlabeled   # label features only exist when parsing labeled
    data_path = os.path.join('assignment2', 'data')
    train_file, dev_file, test_file = 'train.conll', 'dev.conll', 'test.conll'
    embedding_file = 'en-cw.txt'


In [None]:
class PartialParse(object):
    """State of an in-progress arc-standard parse: stack, buffer and arcs so far."""

    def __init__(self, sentence):
        """Initializes this partial parse.

        Args:
            sentence: list of tokens (here: word positions). The buffer is a
                copy, so parsing never mutates ``sentence`` itself.
        """
        self.sentence = sentence
        self.stack = ['ROOT']
        self.buffer = sentence.copy()
        self.dependencies = []  # list of (head, dependent) tuples

    def parse(self, transitions):
        """Apply ``transitions`` in order and return the resulting dependencies."""
        for transition in transitions:
            self.parse_step(transition)
        return self.dependencies

    def parse_step(self, transition):
        """Apply one transition: 'S' (shift), 'LA' (left-arc) or 'RA' (right-arc).

        - 'S' moves the first buffer item onto the stack.
        - 'LA' makes the second stack item a dependent of the top item.
        - 'RA' makes the top stack item a dependent of the second item.

        Raises:
            ValueError: for any other transition value.
        """
        if transition == 'S':
            self.stack.append(self.buffer.pop(0))
        elif transition == 'LA':
            dependent = self.stack.pop(-2)
            self.dependencies.append((self.stack[-1], dependent))
        elif transition == 'RA':
            dependent = self.stack.pop()
            self.dependencies.append((self.stack[-1], dependent))
        else:
            raise ValueError('unknown transition: %r' % (transition,))
                

## Functions to generate minibatches 

In [None]:
def get_minibatches(data, minibatch_size, shuffle=True):
    """ 
    (1) data: there are two possible values:
            - a list or numpy array
            - a list where each element is either a list or numpy array
    (2) minibatch_size: the maximum number of items in a minibatch
    (3) shuffle: whether to randomize the order of returned data
    Returns:
        minibatches: the return value depends on data:
            - If data is a list/array it yields the next minibatch of data.
            - If data a list of lists/arrays it returns the next minibatch of 
              each element in the list. This can be used to iterate through 
              multiple data sources (e.g., features and labels) at the same time.
    An empty ``data`` yields nothing."""
    # guard len(data) so an empty list is not indexed with data[0]
    list_data = (isinstance(data, list) and len(data) > 0
                 and isinstance(data[0], (list, np.ndarray)))
    data_size = len(data[0]) if list_data else len(data)
    indices = np.arange(data_size)
    if shuffle:
        np.random.shuffle(indices)
    for minibatch_start in np.arange(0, data_size, minibatch_size):
        minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
        yield [_minibatch(d, minibatch_indices) for d in data] if list_data \
            else _minibatch(data, minibatch_indices)


def _minibatch(data, minibatch_idx):
    return data[minibatch_idx] if type(data) is np.ndarray else [data[i] for i in minibatch_idx]


def minibatches(data, batch_size, n_classes=None):
    """Yield shuffled ``[features, one_hot_gold]`` minibatches of training instances.

    Args:
        data: list of (features, legal_labels, gold_transition) triples as
            produced by Parser.create_instances.
        batch_size: maximum number of instances per minibatch.
        n_classes: width of the one-hot label matrix. By default this is the
            length of the first instance's legal-label vector (i.e. the
            number of transitions), or 3 when ``data`` is empty.
    """
    if n_classes is None:
        n_classes = len(data[0][1]) if len(data) > 0 else 3
    x = np.array([d[0] for d in data])
    # integer dtype so an empty batch can still be used as an index array
    y = np.array([d[2] for d in data], dtype=int)
    one_hot = np.zeros((y.size, n_classes))
    one_hot[np.arange(y.size), y] = 1
    return get_minibatches([x, one_hot], batch_size)


def minibatch_parse(sentences, model, batch_size):
    """Parses a list of sentences in minibatches using a model.

    Args:
        sentences: list of sentences. Each is a list of word positions that
            becomes a PartialParse buffer.
        model: object whose ``predict(partial_parses)`` returns one
            transition ('S', 'LA' or 'RA') per partial parse.
        batch_size: number of sentences parsed together; must be >= 1.
    Returns:
        dependencies: per sentence, in input order, the list of
        (head, dependent) arcs produced.
    Raises:
        ValueError: if ``batch_size`` < 1; such a size could never make progress.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got %r' % (batch_size,))
    partial_parses = [PartialParse(sentence) for sentence in sentences]

    for start in range(0, len(partial_parses), batch_size):
        # parses in the current chunk that still need transitions
        minibatch = partial_parses[start:start + batch_size]
        while minibatch:
            # one batched model call predicts the next transition for every parse
            transitions = model.predict(minibatch)
            for i, transition in enumerate(transitions):
                minibatch[i].parse_step(transition)
            # a parse is complete once its buffer is empty and only ROOT remains
            minibatch = [pp for pp in minibatch if len(pp.stack) > 1 or len(pp.buffer) > 0]
    return [partial_parse.dependencies for partial_parse in partial_parses]


# START HERE >>

## Exercise: Build and Train the Model 

1. Fill in the hyperparameter values:

In [None]:
class TrainConfig(object):
    "Holds model hyperparams and data information"
    n_features = 36      # must equal Parser.n_features (18 word + 18 POS features)
    n_classes = 3        # unlabeled transitions: left-arc, right-arc, shift
    dropout = 0.5        # probability of dropping a hidden unit (keep_prob = 1 - dropout)
    embed_size = 50      # must match the width of the embeddings matrix built at load time
    hidden_size = 200
    batch_size = 1024
    n_epochs = 10
    lr = 0.0005          # Adam learning rate
    # The four values above are reasonable starting points; tune them if needed.
    

2. Complete the code:

In [None]:
class ParserModel:
    """
    Implements a feedforward neural network with an embedding layer and single hidden layer.
    This network will predict which transition should be applied to a given partial parse
    configuration.
    """

    def add_placeholders(self):
        """Generates placeholder variables to represent the input tensors

        These placeholders are used as inputs by the rest of the model building and will be fed
        data during training.  Note that when "None" is in a placeholder's shape, it's flexible
        (so we can use different batch sizes without rebuilding the model).

        Adds following nodes to the computational graph

        input_placeholder: Input placeholder tensor of  shape (None, n_features), type tf.int32
        labels_placeholder: Labels placeholder tensor of shape (None, n_classes), type tf.float32
        dropout_placeholder: Dropout value placeholder (scalar), type tf.float32

        Add these placeholders to self as the instance variables
            self.input_placeholder
            self.labels_placeholder
            self.dropout_placeholder
        """
        self.input_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.config.n_features), name='input')
        self.labels_placeholder = tf.placeholder(
            tf.float32, shape=(None, self.config.n_classes), name='labels')
        self.dropout_placeholder = tf.placeholder(tf.float32, shape=(), name='dropout')

    def create_feed_dict(self, inputs_batch, labels_batch=None, dropout=0):
        """Creates the feed_dict for the dependency parser.

        A feed_dict takes the form of:

        feed_dict = {
                <placeholder>: <tensor of values to be passed for placeholder>,
                ....
        }

        Arguments that are None are left out of the feed_dict.

        Args:
            inputs_batch: A batch of input data.
            labels_batch: A batch of label data.
            dropout: The dropout rate (0 at prediction time).
        Returns:
            feed_dict: The feed dictionary mapping from placeholders to values.
        """
        feed_dict = {self.input_placeholder: inputs_batch}
        if labels_batch is not None:
            feed_dict[self.labels_placeholder] = labels_batch
        if dropout is not None:
            feed_dict[self.dropout_placeholder] = dropout
        return feed_dict

    def add_embedding(self):
        """Adds an embedding layer that maps from input tokens (integers) to vectors and then
        concatenates those vectors:
            - Creates a tf.Variable and initializes it with self.pretrained_embeddings.
            - Uses the input_placeholder to index into the embeddings tensor, resulting in a
              tensor of shape (None, n_features, embedding_size).
            - Concatenates the embeddings by reshaping the embeddings tensor to shape
              (None, n_features * embedding_size).

        Returns:
            embeddings: tf.Tensor of shape (None, n_features*embed_size)
        """
        # trainable: the pretrained vectors are fine-tuned along with the network
        embedding_matrix = tf.Variable(self.pretrained_embeddings, name='embeddings')
        looked_up = tf.nn.embedding_lookup(embedding_matrix, self.input_placeholder)
        embeddings = tf.reshape(looked_up,
                                [-1, self.config.n_features * self.config.embed_size])
        return embeddings

    def add_prediction_op(self):
        """Adds the 1-hidden-layer NN:
            h = Relu(xW + b1)
            h_drop = Dropout(h, dropout_rate)
            pred = h_dropU + b2

        Note that we are not applying a softmax to pred. The softmax will instead be done in
        the add_loss_op function, which improves efficiency because we can use
        tf.nn.softmax_cross_entropy_with_logits

        W and U use Glorot (Xavier) uniform initialization; b1 and b2 start at zero.
        tf.nn.dropout takes the keep probability, i.e. 1 - self.dropout_placeholder.

        Returns:
            pred: tf.Tensor of shape (batch_size, n_classes)
        """
        x = self.add_embedding()
        n_input = self.config.n_features * self.config.embed_size
        glorot = tf.glorot_uniform_initializer()
        W = tf.get_variable('W', shape=(n_input, self.config.hidden_size), initializer=glorot)
        b1 = tf.get_variable('b1', shape=(self.config.hidden_size,),
                             initializer=tf.zeros_initializer())
        U = tf.get_variable('U', shape=(self.config.hidden_size, self.config.n_classes),
                            initializer=glorot)
        b2 = tf.get_variable('b2', shape=(self.config.n_classes,),
                             initializer=tf.zeros_initializer())

        h = tf.nn.relu(tf.matmul(x, W) + b1)
        h_drop = tf.nn.dropout(h, 1 - self.dropout_placeholder)
        pred = tf.matmul(h_drop, U) + b2
        return pred

    def add_loss_op(self, pred):
        """Adds Ops for the loss function to the computational graph.
        In this case we are using cross entropy loss.
        The loss should be averaged over all examples in the current minibatch.

        Args:
            pred: A tensor of shape (batch_size, n_classes) containing the output of the neural
                  network before the softmax layer.
        Returns:
            loss: A 0-d tensor (scalar)
        """
        # stop_gradient: the one-hot labels are constants, not something to train
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(self.labels_placeholder), logits=pred)
        loss = tf.reduce_mean(cross_entropy)
        return loss

    def add_training_op(self, loss):
        """Sets up the training Ops.

        Creates an Adam optimizer (learning rate self.config.lr) and applies the
        gradients to all trainable variables. The returned Op must be passed to
        `sess.run()` to perform one training step.

        Args:
            loss: Loss tensor, from cross_entropy_loss.
        Returns:
            train_op: The Op for training.
        """
        train_op = tf.train.AdamOptimizer(self.config.lr).minimize(loss)
        return train_op

    def train_on_batch(self, sess, inputs_batch, labels_batch):
        """Run one optimization step on a minibatch and return its loss."""
        feed = self.create_feed_dict(inputs_batch, labels_batch=labels_batch,
                                     dropout=self.config.dropout)
        _, loss = sess.run([self.train_op, self.loss], feed_dict=feed)
        return loss

    def predict_on_batch(self, sess, inputs_batch):
        "Make predictions for the provided batch of data"
        feed = self.create_feed_dict(inputs_batch)
        predictions = sess.run(self.pred, feed_dict=feed)
        return predictions

    def run_epoch(self, sess, parser, train_examples, dev_set):
        "Training loop"
        # ceil(len / batch_size) as an int: exactly the number of minibatches yielded
        n_minibatches = (len(train_examples) + self.config.batch_size - 1) // self.config.batch_size
        prog = tf.keras.utils.Progbar(target=n_minibatches)
        for i, (train_x, train_y) in enumerate(minibatches(train_examples, self.config.batch_size)):
            loss = self.train_on_batch(sess, train_x, train_y)
            prog.update(i + 1, [("train loss", loss)])
        print("\nEvaluating on dev set")
        dev_UAS, _ = parser.parse(dev_set)
        print("- dev UAS: {:.2f}".format(dev_UAS * 100.0))
        return dev_UAS

    def fit(self, sess, saver, parser, train_examples, dev_set):
        """Train for n_epochs, checkpointing whenever dev UAS improves (if saver is given)."""
        best_dev_UAS = 0
        for epoch in range(self.config.n_epochs):
            print("Epoch {:} out of {:}".format(epoch + 1, self.config.n_epochs))
            dev_UAS = self.run_epoch(sess, parser, train_examples, dev_set)
            if dev_UAS > best_dev_UAS:
                best_dev_UAS = dev_UAS
                if saver:
                    print("New best dev UAS! Saving model in ./assignment2/data/weights/parser.weights\n")
                    saver.save(sess, './assignment2/data/weights/parser.weights')

    def build(self):
        """Create placeholders, prediction, loss and training ops in the default graph."""
        self.add_placeholders()
        self.pred = self.add_prediction_op()
        self.loss = self.add_loss_op(self.pred)
        self.train_op = self.add_training_op(self.loss)

    def __init__(self, config, pretrained_embeddings):
        self.pretrained_embeddings = pretrained_embeddings
        self.config = config
        self.build()
        

3. Test your model and iterate if necessary. The goal is a training loss below 0.2 and a dev UAS of at least 65 (%).

In [None]:
# Quick debugging run on the reduced dataset; no checkpoints are written.
debug = True
print("Initializing...")
config = TrainConfig()
parser, embeddings, train_examples, dev_set, test_set = load_and_preprocess_data(debug)
weights_dir = './assignment2/data/weights/'
if not os.path.exists(weights_dir):
    os.makedirs(weights_dir)

with tf.Graph().as_default() as graph:
    print("Building model...")
    start = time.time()
    model = ParserModel(config, embeddings)
    parser.model = model  # ModelWrapper.predict reads parser.model / parser.session
    init_op = tf.global_variables_initializer()
    # a Saver is only needed for full runs that checkpoint the best model
    saver = tf.train.Saver() if not debug else None
    print("took {:.2f} seconds\n".format(time.time() - start))
graph.finalize()

with tf.Session(graph=graph) as session:
    parser.session = session
    session.run(init_op)
    print("Training...")
    model.fit(session, saver, parser, train_examples, dev_set)
    

4. Once the model meets the targets above, train it on the full training set and evaluate it on the test set.

In [None]:
# Run this cell after results above show train loss < 0.2 and dev UAS is >= 65

# Full run: train on all data, checkpoint the best dev model, then evaluate on test.
debug = False
print("Initializing...")
config = TrainConfig()
parser, embeddings, train_examples, dev_set, test_set = load_and_preprocess_data(debug)
weights_dir = './assignment2/data/weights/'
if not os.path.exists(weights_dir):
    os.makedirs(weights_dir)

with tf.Graph().as_default() as graph:
    print("Building model...")
    start = time.time()
    model = ParserModel(config, embeddings)
    parser.model = model  # ModelWrapper.predict reads parser.model / parser.session
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver() if not debug else None
    print("took {:.2f} seconds\n".format(time.time() - start))
graph.finalize()

with tf.Session(graph=graph) as session:
    parser.session = session
    session.run(init_op)
    print("Training...")
    model.fit(session, saver, parser, train_examples, dev_set)
    if not debug:
        print("Testing...")
        print("Restoring the best model weights found on the dev set")
        saver.restore(session, './assignment2/data/weights/parser.weights')
        print("Final evaluation on test set")
        UAS, dependencies = parser.parse(test_set)
        print("- test UAS: {:.2f}".format(UAS * 100.0))
        print("Writing predictions")
        with open('q2_test.predicted.pkl', 'wb') as f:
            pickle.dump(dependencies, f, -1)
        print("Done!")
