# Assignment 3: Graph-based dependency parsing

In this assignment, you will implement a simplified version of the dependency parser used by [Glavaš and Vulić (2021)](http://dx.doi.org/10.18653/v1/2021.eacl-main.270) (Figure&nbsp;1). This parser consists of a transformer encoder followed by a bi-affine layer that computes arc scores for all pairs of words. These scores are then used as logits in a classifier that predicts the syntactic head of each word. In contrast to the parser described in the paper, your parser will only support unlabelled parsing, i.e., you will implement an *arc classifier* but no *relation classifier*. As the encoder, you will use the [uncased BERT base model](https://huggingface.co/bert-base-uncased) from the [Transformers](https://huggingface.co/docs/transformers/main/en/index) library.

We start by importing PyTorch and setting the device we will use for training and evaluation.

In [None]:
import torch

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Dataset

The data for this assignment comes from the English Web Treebank from the [Universal Dependencies Project](http://universaldependencies.org); we distribute it here in the form of two JSON Lines files, with one parsed sentence per line. The code in the next cell defines a PyTorch `Dataset` wrapper for the data.

In [None]:
import json

from torch.utils.data import Dataset

class ParserDataset(Dataset):

    def __init__(self, filename):
        super().__init__()
        with open(filename, 'rt', encoding='utf-8') as fp:
            self.items = [[tuple(x) for x in json.loads(l)] for l in fp]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

We can now load the training data:

In [None]:
TRAIN_DATA = ParserDataset('en_ewt-ud-train.jsonl')

A data set consists of **parsed sentences**. A parsed sentence is represented as a list of pairs. The first component of each pair (a string) represents a word. The second component (an integer) specifies the position of the word’s syntactic head, i.e., its parent in the dependency tree. Note that word positions are numbered starting at&nbsp;1. The special head position&nbsp;0 marks the root of the tree.
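To make this representation concrete, the sketch below parses one line in the same way as `ParserDataset` does and lists the arcs of the tree as (head, dependent) position pairs. The sentence is a hypothetical example, not taken from the treebank. The helper `arcs` is likewise introduced only for illustration.

```python
import json

# A hypothetical JSONL line: a JSON array of [word, head] pairs.
line = '[["She", 2], ["reads", 0], ["books", 2], [".", 2]]'

# Same conversion as in ParserDataset.__init__: each pair becomes a tuple.
sentence = [tuple(x) for x in json.loads(line)]

def arcs(sentence):
    """Return (head, dependent) pairs; word positions start at 1, head 0 is the root."""
    return [(head, dep) for dep, (_, head) in enumerate(sentence, start=1)]

print(sentence)        # [('She', 2), ('reads', 0), ('books', 2), ('.', 2)]
print(arcs(sentence))  # [(2, 1), (0, 2), (2, 3), (2, 4)]
```

Exactly one word (here *reads*) has head&nbsp;0, which makes it the root of the tree.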

Run the following code cell to see an example sentence:

In [None]:
EXAMPLE_SENTENCE = TRAIN_DATA[531]

EXAMPLE_SENTENCE

In this example, the head of the pronoun *I* is the word at position&nbsp;2, the verb *like*. The dependents of *like* are *I* (position&nbsp;1) and the noun *blog* (position&nbsp;4), as well as the final punctuation mark. Note that the pronoun *your* (position&nbsp;3) is misspelled as *yuor*.

## Problem 1: Tokenisation

To feed parsed sentences to BERT, we need to tokenise them and encode the resulting tokens as integers in the model vocabulary. We start by loading the BERT tokeniser using the Auto classes:

In [None]:
from transformers import AutoTokenizer

TOKENIZER = AutoTokenizer.from_pretrained('bert-base-uncased')

We can call the tokeniser on the example sentence as follows:

In [None]:
TOKENIZER([w for w, _ in EXAMPLE_SENTENCE], is_split_into_words=True)

Note that we use the *is_split_into_words* keyword argument to indicate that the input is already pre-tokenised (split on whitespace).

The BERT tokeniser segments each word (pre-token) into one or more subword tokens, and we need to keep track of which tokens were produced by which word. To this end, for each word in a parsed sentence, we compute the corresponding span in the token list.

To illustrate this, consider again the tokenisation of the example sentence. Note that in order to match the default behaviour when calling the tokeniser, we explicitly add the special tokens at the beginning and at the end of the sentence.

In [None]:
TOKENIZER.tokenize([w for w, _ in EXAMPLE_SENTENCE], add_special_tokens=True, is_split_into_words=True)

For this sentence, we would like to compute the following token spans:

In [None]:
[(1, 2), (2, 3), (3, 5), (5, 6), (6, 7)]

Each span `(start, end)` covers the tokens at positions `start` up to but not including `end`, like a Python slice. All spans above cover a single token, except for `(3, 5)`, which covers the tokens `yu` and `##or`. Note that token indices start at&nbsp;1, as position&nbsp;0 is occupied by the special `[CLS]` token.
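The spans can be computed from the number of subword tokens that each word produces. Start at position&nbsp;1, just after `[CLS]`, and advance by each word's token count. The sketch below assumes these counts are already known. One could hypothetically obtain them with `len(TOKENIZER.tokenize(word))`, assuming a word is tokenised the same way on its own as in context. The name `spans_from_counts` is hypothetical.

```python
def spans_from_counts(counts, offset=1):
    """Turn per-word subword counts into half-open (start, end) token spans.

    `offset` is the index of the first real token; 1 skips the [CLS] token.
    """
    spans = []
    start = offset
    for n in counts:
        spans.append((start, start + n))
        start += n
    return spans

# Counts for the example sentence: only 'yuor' is split, into 'yu' and '##or'.
print(spans_from_counts([1, 1, 2, 1, 1]))
# [(1, 2), (2, 3), (3, 5), (5, 6), (6, 7)]
```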

The next cell contains skeleton code for a function `encode` that takes a tokeniser and a batch of sentences and returns the tokeniser&rsquo;s encoded input as well as the corresponding token spans.

In [None]:
def encode(tokenizer, sentences):
    # TODO: Replace the next line with your own code
    raise NotImplementedError

Implement this function to match the following specification:

**encode** (*tokenizer*, *sentences*):

> Uses the specified *tokenizer* to encode a batch of parsed sentences (*sentences*) and returns a pair consisting of a [`BatchEncoding`](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.BatchEncoding) and a matching batch of token spans (as explained above). The `BatchEncoding` is the standard batch encoding, including the token ids to be fed to a model; its contents have been converted to PyTorch tensors. Token indices in the spans start at&nbsp;1.
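One possible sketch of `encode` follows. It assumes a *fast* tokeniser, which is what `AutoTokenizer` returns for BERT by default. A fast tokeniser's `BatchEncoding` provides `word_ids(batch_index)`, which maps each token position to the index of its source word, or to `None` for special and padding tokens. Padding is required so that a batch of sentences of different lengths can be converted to tensors. Treat this as one approach among several, not as a reference solution.

```python
def encode(tokenizer, sentences):
    # Keep only the words; the heads are not needed for tokenisation.
    batch_words = [[w for w, _ in sentence] for sentence in sentences]
    encoded_input = tokenizer(
        batch_words,
        is_split_into_words=True,  # input is already split into words
        padding=True,              # pad to the longest sentence in the batch
        return_tensors='pt',       # convert the contents to PyTorch tensors
    )
    token_spans = []
    for b in range(len(sentences)):
        spans = []
        # Assumes every word yields at least one subword token.
        for pos, word_id in enumerate(encoded_input.word_ids(batch_index=b)):
            if word_id is None:            # [CLS], [SEP] or padding
                continue
            if word_id == len(spans):      # first token of a new word
                spans.append((pos, pos + 1))
            else:                          # continuation token: extend the span
                start, _ = spans[word_id]
                spans[word_id] = (start, pos + 1)
        token_spans.append(spans)
    return encoded_input, token_spans
```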

### 🤞 Test your code

To test your code, call `encode` on the example sentence and check that the output matches your expectations.

## Problem 2: Merging tokens

BERT gives us a representation for each *token* in the input sequence. To compute scores between pairs of *words*, we need to combine the token representations that correspond to each word. A standard strategy for this is to take their element-wise mean.

The next cell contains skeleton code for a function `merge_tokens` that implements this strategy.

In [None]:
def merge_tokens(tokens, token_spans):
    # TODO: Replace the next line with your own code
    raise NotImplementedError

Implement this function to match the following specification:

**merge_tokens** (*tokens*, *token_spans*):

> Takes a batch of token vectors (*tokens*) and a batch of token spans (*token_spans*) and returns a new batch of word-level representations, computed as explained above. The token vectors are a tensor of shape (*batch_size*, *num_tokens*, *hidden_dim*). The token spans are a nested list of integer pairs as computed in Problem&nbsp;1. The result is a tensor of shape (*batch_size*, *max_num_words*, *hidden_dim*), where *max_num_words* denotes the maximum number of words in any sentence in the batch. Entries corresponding to padding are represented by the zero vector of size *hidden_dim*.
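A minimal sketch of `merge_tokens` follows, assuming half-open `(start, end)` spans as in Problem&nbsp;1. It loops over the batch in Python, which favours clarity over speed. Because it starts from a zero tensor, padded word positions automatically stay zero vectors.

```python
import torch

def merge_tokens(tokens, token_spans):
    batch_size, _, hidden_dim = tokens.shape
    max_num_words = max(len(spans) for spans in token_spans)
    # Zeros everywhere, so padded word positions remain the zero vector.
    merged = tokens.new_zeros(batch_size, max_num_words, hidden_dim)
    for b, spans in enumerate(token_spans):
        for i, (start, end) in enumerate(spans):
            # Element-wise mean over the word's subword-token vectors.
            merged[b, i] = tokens[b, start:end].mean(dim=0)
    return merged

tokens = torch.arange(16.0).view(2, 8, 1)  # batch_size=2, num_tokens=8, hidden_dim=1
out = merge_tokens(tokens, [[(1, 2), (2, 3), (3, 5)], [(1, 3)]])
print(out.squeeze(-1).tolist())  # [[1.0, 2.0, 3.5], [9.5, 0.0, 0.0]]
```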

### 🤞 Test your code

To test your code, create a sample input to `merge_tokens` and check that the output matches your expectations.

## Problem 3: Biaffine layer

Your next task is to implement the bi-affine layer. Given matrices $X \in \mathbb{R}^{m \times h}$ and $X' \in \mathbb{R}^{n \times h}$, this layer computes a matrix $Y \in \mathbb{R}^{m \times n}$ as

$$
Y = X W X'{}^\top + b
$$

where $W \in \mathbb{R}^{h \times h}$ and $b \in \mathbb{R}$ are learnable weight and bias parameters. In the context of the dependency parser, the matrices $X$ and $X'$ hold the encodings of all words in the input sentence, and the entries of the matrix $Y$ are interpreted as scores of dependency arcs between words. More specifically, the entry $Y_{ij}$ represents the score of an arc from a head word at position&nbsp;$j$ to a dependent at position&nbsp;$i$.
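Written out entry by entry, the score for dependent&nbsp;$i$ and head&nbsp;$j$ is a bilinear form in the $i$th row of $X$ and the $j$th row of $X'$. The scalar bias is added to every entry, and the dimensions multiply out to an $m \times n$ matrix:

$$
Y_{ij} = \sum_{k=1}^{h} \sum_{l=1}^{h} X_{ik}\, W_{kl}\, X'_{jl} + b,
\qquad
\underbrace{X}_{m \times h}\;\underbrace{W}_{h \times h}\;\underbrace{X'{}^\top}_{h \times n} \in \mathbb{R}^{m \times n}
$$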


The following cell contains skeleton code for the implementation of the bi-affine layer. Implement this layer according to the specification above.

In [None]:
import torch.nn as nn

class Biaffine(nn.Module):

    def __init__(self, in_features):
        super().__init__()
        # TODO: Replace the next line with your own code
        raise NotImplementedError

    def forward(self, x1, x2):
        # TODO: Replace the next line with your own code
        raise NotImplementedError

**⚠️ Note that your implementation should be able to handle *batches* of input sentences.**
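For batches, the same formula can be applied to tensors of shape (*batch_size*, *m*, *h*) and (*batch_size*, *n*, *h*), because `torch.matmul` (the `@` operator) broadcasts over leading batch dimensions. The sketch below shows only this computation, as a plain function with explicitly passed parameters. The name `biaffine_scores` is hypothetical. Inside the module, `weight` and `bias` would be `nn.Parameter`s created in `__init__`. The `einsum` line is an equivalent formulation, included as a cross-check.

```python
import torch

def biaffine_scores(x1, x2, weight, bias):
    # x1: (batch, m, h), x2: (batch, n, h), weight: (h, h), bias: scalar tensor
    # Computes X W X'^T + b for every sentence in the batch at once.
    return x1 @ weight @ x2.transpose(-1, -2) + bias

torch.manual_seed(0)
x1, x2 = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
weight, bias = torch.randn(4, 4), torch.tensor(0.5)
scores = biaffine_scores(x1, x2, weight, bias)
print(scores.shape)  # torch.Size([2, 3, 5])

# Same computation with einsum: Y[n, i, j] = sum over k, l of x1[n, i, k] * W[k, l] * x2[n, j, l]
check = torch.einsum('nik,kl,njl->nij', x1, weight, x2) + bias
print(torch.allclose(scores, check, atol=1e-5))  # True
```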

### 🤞 Test your code

To test your code, create a sample input to the bi-affine layer as well as suitable weights and biases, and check that the output of the `forward` method matches your expectations.

## Problem 4: Parser

We are now ready to put the two main components of the parser together: the encoder (BERT) and the bi-affine layer that computes the arc scores. We also add a dropout layer between the two components. The following code cell contains skeleton code for the parsing model, with the `__init__` method already complete. Your task is to implement the `forward` method. Have another look at the paper to understand how the components need to be wired up.

In [None]:
import torch.nn as nn
import torch

from transformers import BertConfig, BertModel, BertPreTrainedModel

class BertForParsing(BertPreTrainedModel):

    config_class = BertConfig

    def __init__(self, config, dropout=0.1):
        super().__init__(config)
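        # Named `bert` to match BertPreTrainedModel.base_model_prefix, so that
        # from_pretrained() loads the checkpoint's 'bert.*' weights into it
        self.bert = BertModel(config)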
        self.dropout = nn.Dropout(dropout)
        self.biaffine = Biaffine(config.hidden_size)

    def forward(self, encoded_input, token_spans):
        # TODO: Replace the next line with your own code
        raise NotImplementedError

Implement the `forward` method to match the following specification:

**forward** (*encoded_input*, *token_spans*):

> Takes a tokeniser-encoded batch of sentences (of type `BatchEncoding`) and a corresponding batch of token spans and returns a tensor with scores between all words in the input. More specifically, the output tensor $Y$ has shape (*batch_size*, *num_words*, *num_words+1*), where the entry $Y_{bij}$ represents the score of an arc from a head word at position&nbsp;$j$ to a dependent at position&nbsp;$i$ in the $b$th sentence of the batch. Note that the number of possible heads is one greater than the number of possible dependents because the heads include the special token `[CLS]` (at position&nbsp;0), which the paper uses to represent the root vertex.
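One way to wire the pieces together is sketched below as a stand-alone function, so that it can run without downloading BERT. It rests on three assumptions. The encoder returns an object with a `last_hidden_state` attribute, as `BertModel` does. `merge_tokens` behaves as specified in Problem&nbsp;2. The `[CLS]` vector serves as the root, as described above. The name `parser_forward` and its parameter list are hypothetical; in `BertForParsing.forward`, the components would be the module's own attributes. Applying dropout before merging is one choice among others.

```python
import torch

def parser_forward(encoder, dropout, biaffine, merge_tokens, encoded_input, token_spans):
    # Contextual token vectors: (batch_size, num_tokens, hidden_dim)
    hidden = dropout(encoder(**encoded_input).last_hidden_state)
    # Word vectors for the possible dependents: (batch_size, num_words, hidden_dim)
    words = merge_tokens(hidden, token_spans)
    # Possible heads: the [CLS] vector (root, position 0) followed by the words
    heads = torch.cat([hidden[:, :1], words], dim=1)
    # Arc scores: (batch_size, num_words, num_words + 1)
    return biaffine(words, heads)
```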

### 🤞 Test your code

To test your code, instantiate the parsing model and feed it the tokenised example sentence.

## Batching the data

We are now almost ready to train the parser. The missing piece is a data collator that takes a batch of parsed sentences and performs two steps:

* tokenises the sentences and extracts token spans using `encode` (Problem&nbsp;1)
* constructs the ground-truth head tensor needed to compute the loss

The code in the next cell implements these two steps. For pseudo-words introduced through padding, we assign a head index of −100. This value is ignored by PyTorch’s cross-entropy loss function.

In [None]:
import torch

class ParserBatcher(object):

    def __init__(self, tokenizer, device=None):
        self.tokenizer = tokenizer
        self.device = device

    def __call__(self, parser_inputs):
        encoded_input, token_spans = encode(self.tokenizer, parser_inputs)

        # Get the maximal number of words, for padding
        max_num_words = max(len(s) for s in parser_inputs)

        # Construct tensor containing the ground-truth heads
        all_heads = []
        for parser_input in parser_inputs:
            _, heads = zip(*parser_input)  # the words themselves are not needed here
            heads = list(heads)
            heads.extend([-100] * (max_num_words - len(heads)))  # -100 will be ignored
            all_heads.append(heads)
        all_heads = torch.LongTensor(all_heads)

        # Send all data to the specified device
        if self.device is not None:
            encoded_input = encoded_input.to(self.device)
            all_heads = all_heads.to(self.device)

        return encoded_input, token_spans, all_heads
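The `-100` padding works because `F.cross_entropy` uses `ignore_index=-100` by default. Targets with that value contribute neither to the loss nor to the number of terms it is averaged over. The small check below uses made-up scores and compares the loss on a padded batch with the loss on the real positions only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 3)                # 4 dependents, 3 candidate heads (made-up)
gold = torch.tensor([2, 0, -100, -100])   # the last two positions are padding

padded_loss = F.cross_entropy(scores, gold)        # padding ignored by default
real_loss = F.cross_entropy(scores[:2], gold[:2])  # only the real positions
print(torch.allclose(padded_loss, real_loss))      # True
```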

## Training loop

Finally, here is the training loop of the parser. Most of it is quite standard. The training loss of the parser is the cross-entropy between the head scores and the ground truth head positions. In other words, the parser is trained as a classifier that predicts the position of each word&rsquo;s head.

In [None]:
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

def train(dataset, n_epochs=1, lr=1e-5, batch_size=8):
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = BertForParsing.from_pretrained('bert-base-uncased').to(DEVICE)
    data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=ParserBatcher(tokenizer, device=DEVICE))
    optimizer = Adam(model.parameters(), lr=lr)
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0
        n_batches = 0
        with tqdm(total=len(dataset)) as pbar:
            pbar.set_description(f'Epoch {epoch+1}')
            for encoded_input, token_spans, gold_heads in data_loader:
                optimizer.zero_grad()
                head_scores = model.forward(encoded_input, token_spans)
                loss = F.cross_entropy(head_scores.flatten(0, -2), gold_heads.view(-1))
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                n_batches += 1
                pbar.set_postfix(loss=running_loss/n_batches)
                pbar.update(len(token_spans))
    return model
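The loss call flattens the batch and word dimensions, so that every word becomes one classification example over the *num_words+1* candidate heads. The sketch below uses random, made-up scores to check the shapes. It also shows an equivalent formulation that keeps the batch dimension and moves the class dimension to position&nbsp;1, which `F.cross_entropy` accepts for inputs of shape (*N*, *C*, *d*).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_words = 2, 3
head_scores = torch.randn(batch_size, num_words, num_words + 1)
gold_heads = torch.tensor([[2, 0, 2], [0, 1, -100]])  # -100 marks padding

flat_scores = head_scores.flatten(0, -2)   # (batch_size * num_words, num_words + 1)
flat_gold = gold_heads.view(-1)            # (batch_size * num_words,)
print(flat_scores.shape, flat_gold.shape)  # torch.Size([6, 4]) torch.Size([6])

loss = F.cross_entropy(flat_scores, flat_gold)
# Equivalent: class dimension second, i.e. input (N, C, d) and target (N, d)
loss_3d = F.cross_entropy(head_scores.transpose(1, 2), gold_heads)
print(torch.allclose(loss, loss_3d))       # True
```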

We are now ready to train the parser. With a GPU, you should expect training times of approximately 3&nbsp;minutes per epoch.

In [None]:
PARSING_MODEL = train(TRAIN_DATA, n_epochs=1)

## Evaluation

The parser is evaluated using unlabelled attachment score (UAS), which is the percentage of words that have been assigned their correct heads. Note that pseudo-words corresponding to padding (which we marked with the special head index −100 above) must be excluded from this calculation.
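As a plain-Python illustration of the definition, separate from the `evaluate` function, UAS is the fraction of non-padding words whose predicted head equals the gold head. The head lists here are made-up, and `uas` is a hypothetical helper.

```python
def uas(pred_heads, gold_heads, ignore_index=-100):
    """Fraction of words (excluding padding) that received the correct head."""
    n_correct = n_total = 0
    for pred_row, gold_row in zip(pred_heads, gold_heads):
        for pred, gold in zip(pred_row, gold_row):
            if gold == ignore_index:  # padding pseudo-word: skip
                continue
            n_total += 1
            n_correct += int(pred == gold)
    return n_correct / n_total

# Made-up predictions for two sentences; the second one is padded.
pred = [[2, 0, 2, 2], [0, 1, 3, 3]]
gold = [[2, 0, 4, 2], [0, 1, -100, -100]]
print(uas(pred, gold))  # 0.8333333333333334
```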

In [None]:
DEV_DATA = ParserDataset('en_ewt-ud-dev.jsonl')

In [None]:
import torch

def evaluate(model, dataset, batch_size=8):
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=ParserBatcher(tokenizer, device=DEVICE))
    n_correct = 0
    n_total = 0
    model.eval()
    with tqdm(total=len(dataset)) as pbar:
        for encoded_input, token_spans, gold_heads in data_loader:
            with torch.no_grad():
                head_scores = model.forward(encoded_input, token_spans)
                pred_heads = torch.argmax(head_scores, dim=-1)
            mask = gold_heads.ne(-100)
            n_correct += torch.sum(pred_heads[mask] == gold_heads[mask])
            n_total += torch.sum(mask)
            pbar.update(len(token_spans))
    return (n_correct / n_total).item()  # UAS as a plain Python float

In [None]:
evaluate(PARSING_MODEL, DEV_DATA)

**Your notebook must contain output demonstrating at least 88% UAS on the development data.**

That’s it! Congratulations on finishing the last assignment of this course! 🥳