__Probing Language Models__

This notebook serves as a start for your NLP2 assignment on probing Language Models. This notebook will become part of the contents that you will submit at the end, so make sure to keep your code (somewhat) clean :-)

__note__: This assignment is not dependent on big fancy GPUs. I run all this stuff on my own 3 year old CPU, without any Colab hassle. So it's up to you to decide how you want to run it.

# Models

For the Transformer models you are advised to make use of the `transformers` library of Huggingface: https://github.com/huggingface/transformers
Their library is well documented, and they provide great tools to easily load in pre-trained models.

In [1]:
#
## Your code for initializing the transformer model(s)
#
# Note that transformer models use their own `tokenizer`, that should be loaded in as well.
#
from transformers import *

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
model = GPT2LMHeadModel.from_pretrained('distilgpt2')

  from .autonotebook import tqdm as notebook_tqdm
Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm
Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm
Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm
loading file vocab.json from cache at /root/.cache/huggingface/hub/models--distilgpt2/snapshots/38cc92ec43315abd5136313225e95acc5986876c/vocab.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--distilgpt2/snapshots/38cc92ec43315abd5136313225e95acc5986876c/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading configuration file config.json from cache at /root/.cache/h

In [2]:
#
## Your code for initializing the rnn model(s)
#
# The Gulordava LSTM model can be found here: 
# https://drive.google.com/file/d/19Lp3AM4NEPycp_IBgoHfLc_V456pmUom/view?usp=sharing
# You can read more about this model in the original paper here: https://arxiv.org/pdf/1803.11138.pdf
#
# N.B: I have altered the RNNModel code to only output the hidden states that you are interested in.
# If you want to do more experiments with this model you could have a look at the original code here:
# https://github.com/facebookresearch/colorlessgreenRNNs/blob/master/src/language_models/model.py
#
from collections import defaultdict
from lstm.model import RNNModel
import torch


model_location = './state_dict.pt'  # <- point this to the location of the Gulordava .pt file
lstm = RNNModel('LSTM', 50001, 650, 650, 2)
lstm.load_state_dict(torch.load(model_location))


# This LSTM does not use a Tokenizer like the Transformers, but a Vocab dictionary that maps a token to an id.
with open('lstm/vocab.txt') as f:
    w2i = {w.strip(): i for i, w in enumerate(f)}

vocab = defaultdict(lambda: w2i["<unk>"])
vocab.update(w2i)

It is a good idea that before you move on, you try to feed some text to your LMs; and check if everything works accordingly. 

# Data

For this assignment you will train your probes on __treebank__ corpora. A treebank is a corpus that has been *parsed*, and stored in a representation that allows the parse tree to be recovered. Next to a parse tree, treebanks also often contain information about part-of-speech tags, which is exactly what we are after now.

The treebank you will use for now is part of the Universal Dependencies project. I provide a sample of this treebank as well, so you can test your setup on that before moving on to larger amounts of data.

Make sure you accustom yourself to the format that is created by the `conllu` library that parses the treebank files before moving on. For example, make sure you understand how you can access the pos tag of a token, or how to cope with the tree structure that is formed using the `to_tree()` functionality.

In [3]:
# READ DATA
from typing import List
from conllu import parse_incr, TokenList


# If stuff like `: str` and `-> ..` seems scary, fear not! 
# These are type hints that help you to understand what kind of argument and output is expected.
def parse_corpus(filename: str) -> List[TokenList]:
    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses


ud_parses = parse_corpus('data/sample/en_ewt-ud-train.conllu')


# Generating Representations

We now have our data all set, our models are running and we are good to go!

The next step is now to create the model representations for the sentences in our corpora. Once we have generated these representations we can store them, and train additional diagnostic (/probing) classifiers on top of the representations.

There are a few things you should keep in mind here. Read these carefully, as these tips will save you a lot of time in your implementation.
1. Transformer models make use of Byte-Pair Encodings (BPE), that chunk up a piece of next in subword pieces. For example, a word such as "largely" could be chunked up into "large" and "ly". We are interested in probing linguistic information on the __word__-level. Therefore, we will follow the suggestion of Hewitt et al. (2019a, footnote 4), and create the representation of a word by averaging over the representations of its subwords. So the representation of "largely" becomes the average of that of "large" and "ly".


2. Subword chunks never overlap multiple tokens. In other words, say we have a phrase like "None of the", then the tokenizer might chunk that into "No"+"ne"+" of"+" the", but __not__ into "No"+"ne o"+"f the", as those chunks overlap multiple tokens. This is great for our setup! Otherwise it would have been quite challenging to distribute the representation of a subword over the 2 tokens it belongs to.


3. **Important**: If you closely examine the provided treebank, you will notice that some tokens are split up into multiple pieces, that each have their own POS-tag. For example, in the first sentence the word "Al-Zaman" is split into "Al", "-", and "Zaman". In such cases, the conllu `TokenList` format will add the following attribute: `('misc', OrderedDict([('SpaceAfter', 'No')]))` to these tokens. Your model's tokenizer does not need to adhere to the same tokenization. E.g., "Al-Zaman" could be split into "Al-"+"Za"+"man", making it hard to match the representations with their correct pos-tag. Therefore I recommend you to not tokenize your entire sentence at once, but to do this based on the chunking of the treebank. <br /><br />
Make sure to still incoporate the spaces in a sentence though, as these are part of the BPE of the tokenizer. That is, the tokenizer uses a different token id for `"man"`, than it does for `" man"`: the former could be part of `" woman"`=`" wo`"+`"man"`, whereas the latter would be the used in case *man* occurs at the start of a word. The tokenizer for GPT-2 adds spaces at the start of a token (represented as a `Ġ` symbol). This means that you should keep track whether the previous token had the `SpaceAfter` attribute set to `'No'`: in case it did not, you should manually prepend a `" "` ahead of the token.


4. The LSTM LM does not have the issues related to subwords, but is far more restricted in its vocabulary. Make sure you keep the above points in mind though, when creating the LSTM representations. You might want to write separate functions for the LSTM, but that is up to you.


5. The huggingface transformer models don't return the hidden state by default. To achieve this you can pass `output_hidden_states=True` to a model forward pass. The hidden states are then returned for all intermediate layers as well, the latest entry in this list corresponds to the top layer.


6. **N.B.**: Make sure that when you run a sentence through your model, you do so within a `with torch.no_grad():` block, and that you have run `model.eval()` beforehand as well (to disable dropout).


7. **N.B.**: Make sure to use a token's ``["form"]`` attribute, and not the ``["lemma"]``, as the latter will stem any relevant morphological information from the token. We don't want this, because we want to feed well-formed, grammatical sentences to our model.


I would like to stress that if you feel hindered in any way by the simple code structure that is presented here, you are free to modify it :-) Just make sure it is clear to an outsider what you're doing, some helpful comments never hurt.

In [4]:
# FETCH SENTENCE REPRESENTATIONS
from torch import Tensor
import pickle


# Should return a tensor of shape (num_tokens_in_corpus, representation_size)
# Make sure you correctly average the subword representations that belong to 1 token!

def get_lstm_representations(ud_parses, model, tokenizer):
    model.eval()

    with torch.no_grad():
        hidden = model.init_hidden(1)
        arr = []
        for sent in ud_parses:
            for token in sent:
                form = token["form"]
                if form in tokenizer:
                    arr.append(tokenizer[form])
                else:
                    arr.append(tokenizer["<unk>"])
        ids = torch.tensor(arr)
        rep = model(ids.unsqueeze(0), hidden)
        rep = rep.squeeze(0)
        return rep
    

def get_gpt_representations(ud_parses, model, tokenizer):
    model.eval()
    
    with torch.no_grad():
        all_sentence_reps = []
        for sentence in ud_parses:            
            no_of_original_tokens = len(sentence)
            inputs_ids_list = []
            
            for tok_idx in range(no_of_original_tokens):
                token = sentence[tok_idx]['form']
                inputs_ids_list.append(tokenizer(token, return_tensors="pt")['input_ids'])
                
            inputs_tensor = torch.cat(inputs_ids_list, -1)
            
            outputs = model(input_ids=inputs_tensor, output_hidden_states=True)
            final_reps = outputs.hidden_states[-1][0]
            
            combined_reps = []
            idx = 0
            
            for input_ids in inputs_ids_list:
                i_len = input_ids.size(-1)
                combined_reps.append(final_reps[idx:idx+i_len].mean(0))
                idx += i_len
        
            sentence_rep = torch.stack(combined_reps)
            
            all_sentence_reps.append(sentence_rep)
    return torch.cat(all_sentence_reps)


def fetch_sen_reps(ud_parses: List[TokenList], model, tokenizer) -> Tensor:    
    if isinstance(model, GPT2LMHeadModel):
        rep = get_gpt_representations(ud_parses, model, tokenizer)
    elif isinstance(model, RNNModel):
        rep = get_lstm_representations(ud_parses, model, tokenizer)
    else:
        print("NOT A SUPPORTED MODEL!!")
        return None
    
    return rep

To validate your activation extraction procedure I have set up the following assertion function as a sanity check. It compares your representation of the first sentence in the corpus against a pickled version of mine. 

For this I used `distilgpt2`.

In [5]:
def error_msg(model_name, gold_embs, embs, i2w):
    with open(f'{model_name}_tokens1.pickle', 'rb') as f:
        sen_tokens = pickle.load(f)
        
    diff = torch.abs(embs - gold_embs)
    max_diff = torch.max(diff)
    avg_diff = torch.mean(diff)
    
    print(f"{model_name} embeddings don't match!")
    print(f"Max diff.: {max_diff:.4f}\nMean diff. {avg_diff:.4f}")

    print("\nCheck if your tokenization matches with the original tokenization:")
    for idx in sen_tokens.squeeze():
        if isinstance(i2w, list):
            token = i2w[idx]
        else:
            token = i2w.convert_ids_to_tokens(idx.item())
        print(f"{idx:<6} {token}")


def assert_sen_reps(model, tokenizer, lstm, vocab):
    with open('distilgpt2_emb1.pickle', 'rb') as f:
        distilgpt2_emb1 = pickle.load(f)
        
    with open('lstm_emb1.pickle', 'rb') as f:
        lstm_emb1 = pickle.load(f)
       
    
    corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')[:1]
    
    own_distilgpt2_emb1 = fetch_sen_reps(corpus, model, tokenizer)
    own_lstm_emb1 = fetch_sen_reps(corpus, lstm, vocab)
    
    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape, \
        f"Distilgpt2 shape mismatch: {distilgpt2_emb1.shape} (gold) vs. {own_distilgpt2_emb1.shape} (yours)"
    assert lstm_emb1.shape == own_lstm_emb1.shape, \
        f"LSTM shape mismatch: {lstm_emb1.shape} (gold) vs. {own_lstm_emb1.shape} (yours)"

    if not torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1, rtol=1e-3, atol=1e-3):
        error_msg("distilgpt2", distilgpt2_emb1, own_distilgpt2_emb1, tokenizer)
    if not torch.allclose(lstm_emb1, own_lstm_emb1, rtol=1e-3, atol=1e-3):
        error_msg("lstm", lstm_emb1, own_lstm_emb1, list(vocab.keys()))


assert_sen_reps(model, tokenizer, lstm, vocab)

distilgpt2 embeddings don't match!
Max diff.: 57.6748
Mean diff. 0.3720

Check if your tokenization matches with the original tokenization:
2348   Al
12     -
57     Z
10546  aman
1058   Ġ:
1605   ĠAmerican
3386   Ġforces
2923   Ġkilled
19413  ĠSha
13848  ikh
26804  ĠAbdullah
435    Ġal
12     -
2025   An
72     i
11     ,
262    Ġthe
39797  Ġpreacher
379    Ġat
262    Ġthe
18575  Ġmosque
287    Ġin
262    Ġthe
3240   Ġtown
286    Ġof
1195   ĠQ
1385   aim
11     ,
1474   Ġnear
262    Ġthe
6318   ĠSyrian
4865   Ġborder
13     .


Next, we should define a function that extracts the corresponding POS labels for each activation, which we do based on the **``"upostag"``** attribute of a token (so not the ``xpostag`` attribute). These labels will be transformed to a tensor containing the label index for each item.

In [6]:
# FETCH POS LABELS

from typing import Optional, Dict, Tuple
# Should return a tensor of shape (num_tokens_in_corpus,)
# Make sure that when fetching these pos tags for your train/dev/test corpora you share the label vocabulary.
def fetch_pos_tags(ud_parses: List[TokenList], pos_vocab: Optional[Dict[str, int]] = None) -> Tuple[Tensor, Dict[str, int]]:
    pos_tags = []
    for sent in ud_parses:
        for token in sent:
            pos = token["upostag"]
            if pos_vocab is None:
                pos_vocab = {}
            if pos not in pos_vocab:
                new_id =  len(pos_vocab)
                pos_vocab[pos] = new_id
                pos_tags.append(new_id)
            elif pos in pos_vocab:
                old_id = pos_vocab[pos]
                pos_tags.append(old_id)
            else: 
                pass
    return torch.tensor(pos_tags), pos_vocab
                
                

Finally, we can combine all these methods to retrieve the representations (`fetch_sen_reps`) and the corresponding labels (`fetch_pos_tags`). If you are still debugging and testing your setup you can set the `use_sample` variable to `True`, and once everything works you can extract the full corpus by setting it to `False`.

The reason we pass the `train_vocab` to the data creation of the `dev` and `test` data is that we want to use the same label vocabulary across the different train/dev/test splits.

In [7]:
import os

# Function that combines the previous functions, and creates 2 tensors for a .conllu file: 
# 1 containing the token representations, and 1 containing the (tokenized) pos_tags.

def create_data(ud_parses, filename: str, lm, w2i, pos_vocab=None):
    ud_parses = parse_corpus(filename)
    
    sen_reps = fetch_sen_reps(ud_parses, lm, w2i)
    pos_tags, pos_vocab = fetch_pos_tags(ud_parses, pos_vocab=pos_vocab)
    
    return sen_reps, pos_tags, pos_vocab


lm = model  # or `lstm`
w2i = tokenizer  # or `vocab`
# lm = lstm
# w2i = vocab
use_sample = True

train_x, train_y, train_vocab = create_data(
    ud_parses,
    os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-train.conllu'),
    lm, 
    w2i
)

dev_x, dev_y, _ = create_data(
    ud_parses,
    os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-dev.conllu'),
    lm, 
    w2i,
    pos_vocab=train_vocab
)

test_x, test_y, _ = create_data(
    ud_parses,
    os.path.join('data', 'sample' if use_sample else '', 'en_ewt-ud-test.conllu'),
    lm,
    w2i,
    pos_vocab=train_vocab
)

# Diagnostic Classification

We now have our models, our data, _and_ our representations all set! Hurray, well done. We can finally move onto the cool stuff, i.e. training the diagnostic classifiers (DCs).

DCs are simple in their complexity on purpose. To read more about why this is the case you could already have a look at the "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

A simple linear classifier will suffice for now, don't bother with adding fancy non-linearities to it.

I am personally a fan of the `skorch` library, that provides `sklearn`-like functionalities for training `torch` models, but you are free to train your dc using whatever method you prefer.

As this is an Artificial Intelligence master and you have all done ML1 + DL, I expect you to use your train/dev/test splits correctly ;-)

In [8]:
import torch
from torch import nn
import torch.nn.functional as F

# ! pip install skorch
from skorch import NeuralNetClassifier

In [9]:
# DIAGNOSTIC CLASSIFIER
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=15,
    ):
        super(ClassifierModule, self).__init__()
        self.num_units = num_units
        self.output = nn.Linear(768, num_units)


    def forward(self, X, **kwargs):
        X = self.output(X)
        X = F.softmax(X, dim=-1)
        return X

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=2500,
    lr=0.001)

In [10]:
net.fit(train_x, train_y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m3.3967[0m       [32m0.1258[0m        [35m2.8583[0m  0.0560
      2        [36m2.8895[0m       0.1085        3.0174  0.0344
      3        [36m2.8069[0m       0.1193        [35m2.6833[0m  0.0358
      4        [36m2.6844[0m       0.1193        [35m2.6423[0m  0.0358
      5        2.7162       0.1106        2.6760  0.0353
      6        [36m2.5536[0m       0.1106        3.1018  0.0349
      7        2.6291       0.1193        2.8488  0.0348
      8        2.5865       0.1258        2.7091  0.0351
      9        [36m2.4941[0m       [32m0.1302[0m        2.7429  0.0353
     10        2.5496       [32m0.1367[0m        [35m2.5549[0m  0.0350
     11        [36m2.4072[0m       0.1345        2.7042  0.0346
     12        2.4917       [32m0.1562[0m        [35m2.5056[0m  0.0354
     13        [36m2.3709[0m       [32m0.1735[0m      

    107        [36m1.4362[0m       0.4382        [35m1.7520[0m  0.0361
    108        [36m1.4313[0m       [32m0.4403[0m        [35m1.7460[0m  0.0354
    109        [36m1.4265[0m       [32m0.4469[0m        [35m1.7401[0m  0.0355
    110        [36m1.4217[0m       0.4469        [35m1.7343[0m  0.0362
    111        [36m1.4169[0m       [32m0.4490[0m        [35m1.7285[0m  0.0347
    112        [36m1.4122[0m       [32m0.4512[0m        [35m1.7227[0m  0.0351
    113        [36m1.4075[0m       [32m0.4555[0m        [35m1.7170[0m  0.0346
    114        [36m1.4028[0m       [32m0.4599[0m        [35m1.7113[0m  0.0346
    115        [36m1.3982[0m       0.4599        [35m1.7057[0m  0.0349
    116        [36m1.3936[0m       0.4599        [35m1.7001[0m  0.0346
    117        [36m1.3891[0m       0.4599        [35m1.6946[0m  0.0344
    118        [36m1.3846[0m       [32m0.4642[0m        [35m1.6891[0m  0.0348
    119        [36m1.3801[0m       

    211        [36m1.0823[0m       0.5727        [35m1.3186[0m  0.0361
    212        [36m1.0800[0m       0.5727        [35m1.3158[0m  0.0354
    213        [36m1.0777[0m       0.5727        [35m1.3130[0m  0.0355
    214        [36m1.0754[0m       0.5727        [35m1.3103[0m  0.0359
    215        [36m1.0731[0m       [32m0.5748[0m        [35m1.3076[0m  0.0362
    216        [36m1.0709[0m       0.5748        [35m1.3049[0m  0.0355
    217        [36m1.0686[0m       [32m0.5770[0m        [35m1.3022[0m  0.0359
    218        [36m1.0664[0m       0.5770        [35m1.2996[0m  0.0362
    219        [36m1.0642[0m       [32m0.5792[0m        [35m1.2969[0m  0.0363
    220        [36m1.0620[0m       0.5792        [35m1.2943[0m  0.0363
    221        [36m1.0598[0m       0.5792        [35m1.2917[0m  0.0358
    222        [36m1.0576[0m       [32m0.5813[0m        [35m1.2892[0m  0.0356
    223        [36m1.0555[0m       0.5813        [35m1.2866[

    317        [36m0.8970[0m       0.6529        [35m1.1083[0m  0.0347
    318        [36m0.8957[0m       0.6529        [35m1.1068[0m  0.0344
    319        [36m0.8944[0m       [32m0.6551[0m        [35m1.1054[0m  0.0357
    320        [36m0.8930[0m       0.6551        [35m1.1040[0m  0.0369
    321        [36m0.8917[0m       [32m0.6573[0m        [35m1.1026[0m  0.0368
    322        [36m0.8904[0m       0.6573        [35m1.1012[0m  0.0369
    323        [36m0.8891[0m       0.6573        [35m1.0999[0m  0.0363
    324        [36m0.8878[0m       0.6573        [35m1.0985[0m  0.0365
    325        [36m0.8865[0m       [32m0.6594[0m        [35m1.0971[0m  0.0358
    326        [36m0.8852[0m       0.6594        [35m1.0957[0m  0.0356
    327        [36m0.8840[0m       [32m0.6616[0m        [35m1.0944[0m  0.0349
    328        [36m0.8827[0m       [32m0.6638[0m        [35m1.0930[0m  0.0352
    329        [36m0.8814[0m       [32m0.6681[0m   

    424        [36m0.7808[0m       [32m0.7050[0m        [35m0.9877[0m  0.0357
    425        [36m0.7799[0m       0.7050        [35m0.9868[0m  0.0374
    426        [36m0.7790[0m       0.7050        [35m0.9859[0m  0.0350
    427        [36m0.7782[0m       0.7050        [35m0.9850[0m  0.0352
    428        [36m0.7773[0m       0.7050        [35m0.9842[0m  0.0345
    429        [36m0.7764[0m       0.7050        [35m0.9833[0m  0.0347
    430        [36m0.7755[0m       0.7050        [35m0.9824[0m  0.0350
    431        [36m0.7747[0m       0.7050        [35m0.9815[0m  0.0347
    432        [36m0.7738[0m       0.7050        [35m0.9807[0m  0.0343
    433        [36m0.7729[0m       0.7050        [35m0.9798[0m  0.0353
    434        [36m0.7721[0m       [32m0.7072[0m        [35m0.9789[0m  0.0362
    435        [36m0.7712[0m       0.7072        [35m0.9781[0m  0.0365
    436        [36m0.7704[0m       0.7072        [35m0.9772[0m  0.0362
    437

    532        [36m0.6993[0m       [32m0.7354[0m        [35m0.9070[0m  0.0359
    533        [36m0.6987[0m       0.7354        [35m0.9064[0m  0.0360
    534        [36m0.6980[0m       0.7354        [35m0.9058[0m  0.0363
    535        [36m0.6974[0m       0.7354        [35m0.9052[0m  0.0357
    536        [36m0.6968[0m       0.7354        [35m0.9045[0m  0.0355
    537        [36m0.6961[0m       0.7354        [35m0.9039[0m  0.0373
    538        [36m0.6955[0m       0.7354        [35m0.9033[0m  0.0386
    539        [36m0.6949[0m       0.7354        [35m0.9027[0m  0.0378
    540        [36m0.6942[0m       0.7354        [35m0.9021[0m  0.0382
    541        [36m0.6936[0m       0.7354        [35m0.9015[0m  0.0379
    542        [36m0.6930[0m       [32m0.7375[0m        [35m0.9009[0m  0.0354
    543        [36m0.6923[0m       0.7375        [35m0.9003[0m  0.0353
    544        [36m0.6917[0m       0.7375        [35m0.8997[0m  0.0356
    545

    641        [36m0.6383[0m       0.7549        [35m0.8484[0m  0.0367
    642        [36m0.6378[0m       0.7549        [35m0.8480[0m  0.0364
    643        [36m0.6373[0m       0.7549        [35m0.8475[0m  0.0365
    644        [36m0.6368[0m       0.7549        [35m0.8470[0m  0.0361
    645        [36m0.6363[0m       0.7549        [35m0.8466[0m  0.0353
    646        [36m0.6358[0m       [32m0.7570[0m        [35m0.8461[0m  0.0352
    647        [36m0.6353[0m       0.7570        [35m0.8457[0m  0.0350
    648        [36m0.6349[0m       0.7570        [35m0.8452[0m  0.0348
    649        [36m0.6344[0m       0.7570        [35m0.8448[0m  0.0351
    650        [36m0.6339[0m       0.7570        [35m0.8443[0m  0.0361
    651        [36m0.6334[0m       0.7570        [35m0.8438[0m  0.0366
    652        [36m0.6329[0m       0.7570        [35m0.8434[0m  0.0382
    653        [36m0.6325[0m       0.7570        [35m0.8429[0m  0.0371
    654        

    750        [36m0.5908[0m       0.7636        [35m0.8038[0m  0.0365
    751        [36m0.5904[0m       0.7636        [35m0.8034[0m  0.0368
    752        [36m0.5900[0m       0.7636        [35m0.8031[0m  0.0366
    753        [36m0.5896[0m       0.7636        [35m0.8027[0m  0.0379
    754        [36m0.5892[0m       0.7636        [35m0.8023[0m  0.0381
    755        [36m0.5888[0m       0.7636        [35m0.8020[0m  0.0372
    756        [36m0.5884[0m       0.7636        [35m0.8016[0m  0.0370
    757        [36m0.5881[0m       0.7636        [35m0.8013[0m  0.0358
    758        [36m0.5877[0m       0.7614        [35m0.8009[0m  0.0377
    759        [36m0.5873[0m       0.7614        [35m0.8005[0m  0.0366
    760        [36m0.5869[0m       0.7614        [35m0.8002[0m  0.0360
    761        [36m0.5865[0m       0.7614        [35m0.7998[0m  0.0359
    762        [36m0.5861[0m       0.7614        [35m0.7995[0m  0.0362
    763        [36m0.585

    859        [36m0.5524[0m       0.7701        [35m0.7682[0m  0.0367
    860        [36m0.5521[0m       0.7701        [35m0.7679[0m  0.0371
    861        [36m0.5518[0m       0.7701        [35m0.7676[0m  0.0363
    862        [36m0.5515[0m       0.7701        [35m0.7673[0m  0.0367
    863        [36m0.5512[0m       0.7701        [35m0.7671[0m  0.0371
    864        [36m0.5509[0m       0.7701        [35m0.7668[0m  0.0386
    865        [36m0.5505[0m       0.7701        [35m0.7665[0m  0.0375
    866        [36m0.5502[0m       0.7701        [35m0.7662[0m  0.0377
    867        [36m0.5499[0m       0.7701        [35m0.7659[0m  0.0374
    868        [36m0.5496[0m       [32m0.7722[0m        [35m0.7656[0m  0.0373
    869        [36m0.5493[0m       0.7722        [35m0.7653[0m  0.0385
    870        [36m0.5490[0m       0.7722        [35m0.7650[0m  0.0381
    871        [36m0.5487[0m       0.7722        [35m0.7647[0m  0.0383
    872        

    968        [36m0.5207[0m       0.7831        [35m0.7390[0m  0.0367
    969        [36m0.5204[0m       0.7831        [35m0.7387[0m  0.0369
    970        [36m0.5201[0m       0.7831        [35m0.7385[0m  0.0366
    971        [36m0.5199[0m       0.7831        [35m0.7382[0m  0.0378
    972        [36m0.5196[0m       0.7831        [35m0.7380[0m  0.0376
    973        [36m0.5193[0m       0.7831        [35m0.7378[0m  0.0363
    974        [36m0.5191[0m       0.7831        [35m0.7375[0m  0.0369
    975        [36m0.5188[0m       0.7831        [35m0.7373[0m  0.0375
    976        [36m0.5185[0m       0.7831        [35m0.7370[0m  0.0368
    977        [36m0.5183[0m       0.7831        [35m0.7368[0m  0.0364
    978        [36m0.5180[0m       0.7831        [35m0.7365[0m  0.0362
    979        [36m0.5177[0m       0.7831        [35m0.7363[0m  0.0363
    980        [36m0.5175[0m       0.7831        [35m0.7361[0m  0.0377
    981        [36m0.517

   1078        [36m0.4935[0m       0.7809        [35m0.7141[0m  0.0367
   1079        [36m0.4932[0m       0.7809        [35m0.7139[0m  0.0367
   1080        [36m0.4930[0m       0.7809        [35m0.7137[0m  0.0361
   1081        [36m0.4928[0m       0.7809        [35m0.7135[0m  0.0367
   1082        [36m0.4926[0m       0.7809        [35m0.7133[0m  0.0367
   1083        [36m0.4923[0m       0.7809        [35m0.7131[0m  0.0366
   1084        [36m0.4921[0m       0.7809        [35m0.7129[0m  0.0372
   1085        [36m0.4919[0m       0.7809        [35m0.7127[0m  0.0380
   1086        [36m0.4917[0m       0.7809        [35m0.7125[0m  0.0387
   1087        [36m0.4914[0m       0.7809        [35m0.7123[0m  0.0386
   1088        [36m0.4912[0m       0.7809        [35m0.7121[0m  0.0388
   1089        [36m0.4910[0m       0.7809        [35m0.7119[0m  0.0374
   1090        [36m0.4907[0m       0.7809        [35m0.7116[0m  0.0380
   1091        [36m0.490

   1187        [36m0.4702[0m       0.7874        [35m0.6930[0m  0.0464
   1188        [36m0.4700[0m       0.7874        [35m0.6928[0m  0.0462
   1189        [36m0.4698[0m       0.7874        [35m0.6926[0m  0.0426
   1190        [36m0.4696[0m       0.7874        [35m0.6925[0m  0.0389
   1191        [36m0.4694[0m       0.7874        [35m0.6923[0m  0.0372
   1192        [36m0.4692[0m       0.7874        [35m0.6921[0m  0.0372
   1193        [36m0.4690[0m       0.7874        [35m0.6919[0m  0.0367
   1194        [36m0.4688[0m       0.7874        [35m0.6917[0m  0.0361
   1195        [36m0.4687[0m       0.7874        [35m0.6916[0m  0.0367
   1196        [36m0.4685[0m       0.7874        [35m0.6914[0m  0.0366
   1197        [36m0.4683[0m       0.7874        [35m0.6912[0m  0.0368
   1198        [36m0.4681[0m       0.7874        [35m0.6910[0m  0.0369
   1199        [36m0.4679[0m       0.7874        [35m0.6908[0m  0.0374
   1200        [36m0.467

   1296        [36m0.4499[0m       0.7939        [35m0.6746[0m  0.0369
   1297        [36m0.4497[0m       0.7939        [35m0.6744[0m  0.0373
   1298        [36m0.4495[0m       0.7939        [35m0.6743[0m  0.0378
   1299        [36m0.4494[0m       0.7939        [35m0.6741[0m  0.0375
   1300        [36m0.4492[0m       0.7939        [35m0.6739[0m  0.0385
   1301        [36m0.4490[0m       0.7939        [35m0.6738[0m  0.0377
   1302        [36m0.4488[0m       0.7939        [35m0.6736[0m  0.0386
   1303        [36m0.4487[0m       0.7939        [35m0.6735[0m  0.0364
   1304        [36m0.4485[0m       0.7939        [35m0.6733[0m  0.0364
   1305        [36m0.4483[0m       0.7939        [35m0.6732[0m  0.0360
   1306        [36m0.4481[0m       0.7939        [35m0.6730[0m  0.0365
   1307        [36m0.4480[0m       0.7939        [35m0.6728[0m  0.0384
   1308        [36m0.4478[0m       0.7939        [35m0.6727[0m  0.0383
   1309        [36m0.447

   1406        [36m0.4317[0m       0.7939        [35m0.6582[0m  0.0372
   1407        [36m0.4315[0m       0.7939        [35m0.6580[0m  0.0371
   1408        [36m0.4314[0m       0.7939        [35m0.6579[0m  0.0371
   1409        [36m0.4312[0m       0.7939        [35m0.6578[0m  0.0372
   1410        [36m0.4311[0m       0.7939        [35m0.6576[0m  0.0372
   1411        [36m0.4309[0m       0.7939        [35m0.6575[0m  0.0367
   1412        [36m0.4308[0m       0.7939        [35m0.6574[0m  0.0372
   1413        [36m0.4306[0m       0.7939        [35m0.6572[0m  0.0366
   1414        [36m0.4305[0m       0.7939        [35m0.6571[0m  0.0370
   1415        [36m0.4303[0m       0.7939        [35m0.6569[0m  0.0371
   1416        [36m0.4302[0m       0.7939        [35m0.6568[0m  0.0374
   1417        [36m0.4300[0m       0.7939        [35m0.6567[0m  0.0386
   1418        [36m0.4298[0m       0.7939        [35m0.6565[0m  0.0396
   1419        [36m0.429

   1516        [36m0.4154[0m       0.7983        [35m0.6436[0m  0.0377
   1517        [36m0.4153[0m       0.7983        [35m0.6435[0m  0.0371
   1518        [36m0.4152[0m       0.7983        [35m0.6434[0m  0.0380
   1519        [36m0.4150[0m       0.7983        [35m0.6432[0m  0.0385
   1520        [36m0.4149[0m       0.7983        [35m0.6431[0m  0.0385
   1521        [36m0.4148[0m       0.7983        [35m0.6430[0m  0.0376
   1522        [36m0.4146[0m       0.7983        [35m0.6429[0m  0.0366
   1523        [36m0.4145[0m       0.7983        [35m0.6427[0m  0.0376
   1524        [36m0.4143[0m       0.7983        [35m0.6426[0m  0.0366
   1525        [36m0.4142[0m       0.7983        [35m0.6425[0m  0.0365
   1526        [36m0.4141[0m       0.7983        [35m0.6424[0m  0.0366
   1527        [36m0.4139[0m       0.7983        [35m0.6422[0m  0.0370
   1528        [36m0.4138[0m       0.7983        [35m0.6421[0m  0.0368
   1529        [36m0.413

   1626        [36m0.4008[0m       0.8004        [35m0.6305[0m  0.0366
   1627        [36m0.4007[0m       0.8004        [35m0.6304[0m  0.0360
   1628        [36m0.4005[0m       0.8004        [35m0.6303[0m  0.0356
   1629        [36m0.4004[0m       0.8004        [35m0.6302[0m  0.0369
   1630        [36m0.4003[0m       0.8004        [35m0.6301[0m  0.0364
   1631        [36m0.4002[0m       0.8004        [35m0.6300[0m  0.0369
   1632        [36m0.4000[0m       0.8004        [35m0.6299[0m  0.0377
   1633        [36m0.3999[0m       0.8004        [35m0.6298[0m  0.0380
   1634        [36m0.3998[0m       0.8004        [35m0.6296[0m  0.0377
   1635        [36m0.3997[0m       0.8004        [35m0.6295[0m  0.0369
   1636        [36m0.3995[0m       0.8004        [35m0.6294[0m  0.0370
   1637        [36m0.3994[0m       0.8004        [35m0.6293[0m  0.0374
   1638        [36m0.3993[0m       0.8004        [35m0.6292[0m  0.0376
   1639        [36m0.399

   1735        [36m0.3876[0m       0.8069        [35m0.6188[0m  0.0379
   1736        [36m0.3875[0m       0.8069        [35m0.6187[0m  0.0373
   1737        [36m0.3874[0m       0.8069        [35m0.6186[0m  0.0384
   1738        [36m0.3873[0m       0.8069        [35m0.6185[0m  0.0380
   1739        [36m0.3871[0m       0.8069        [35m0.6184[0m  0.0384
   1740        [36m0.3870[0m       0.8069        [35m0.6183[0m  0.0379
   1741        [36m0.3869[0m       0.8069        [35m0.6182[0m  0.0388
   1742        [36m0.3868[0m       0.8069        [35m0.6181[0m  0.0394
   1743        [36m0.3867[0m       0.8069        [35m0.6180[0m  0.0387
   1744        [36m0.3866[0m       0.8069        [35m0.6179[0m  0.0383
   1745        [36m0.3865[0m       0.8069        [35m0.6178[0m  0.0385
   1746        [36m0.3863[0m       0.8069        [35m0.6177[0m  0.0464
   1747        [36m0.3862[0m       0.8069        [35m0.6176[0m  0.0390
   1748        [36m0.386

   1845        [36m0.3754[0m       0.8091        [35m0.6081[0m  0.0396
   1846        [36m0.3753[0m       0.8091        [35m0.6080[0m  0.0396
   1847        [36m0.3752[0m       0.8091        [35m0.6079[0m  0.0394
   1848        [36m0.3751[0m       0.8091        [35m0.6078[0m  0.0390
   1849        [36m0.3750[0m       0.8091        [35m0.6077[0m  0.0378
   1850        [36m0.3749[0m       0.8091        [35m0.6076[0m  0.0367
   1851        [36m0.3748[0m       0.8091        [35m0.6075[0m  0.0369
   1852        [36m0.3747[0m       0.8091        [35m0.6075[0m  0.0375
   1853        [36m0.3746[0m       0.8091        [35m0.6074[0m  0.0370
   1854        [36m0.3745[0m       0.8091        [35m0.6073[0m  0.0368
   1855        [36m0.3744[0m       0.8091        [35m0.6072[0m  0.0365
   1856        [36m0.3743[0m       0.8091        [35m0.6071[0m  0.0368
   1857        [36m0.3742[0m       0.8091        [35m0.6070[0m  0.0375
   1858        [36m0.374

   1955        [36m0.3643[0m       0.8113        [35m0.5983[0m  0.0369
   1956        [36m0.3642[0m       0.8113        [35m0.5982[0m  0.0372
   1957        [36m0.3641[0m       0.8113        [35m0.5981[0m  0.0372
   1958        [36m0.3640[0m       0.8113        [35m0.5980[0m  0.0368
   1959        [36m0.3639[0m       0.8113        [35m0.5980[0m  0.0377
   1960        [36m0.3638[0m       0.8113        [35m0.5979[0m  0.0376
   1961        [36m0.3637[0m       0.8113        [35m0.5978[0m  0.0374
   1962        [36m0.3636[0m       0.8113        [35m0.5977[0m  0.0376
   1963        [36m0.3635[0m       0.8113        [35m0.5976[0m  0.0375
   1964        [36m0.3634[0m       0.8113        [35m0.5975[0m  0.0374
   1965        [36m0.3633[0m       0.8113        [35m0.5974[0m  0.0369
   1966        [36m0.3632[0m       0.8113        [35m0.5974[0m  0.0376
   1967        [36m0.3631[0m       0.8113        [35m0.5973[0m  0.0378
   1968        [36m0.363

   2064        [36m0.3540[0m       0.8178        [35m0.5894[0m  0.0372
   2065        [36m0.3540[0m       0.8178        [35m0.5893[0m  0.0383
   2066        [36m0.3539[0m       0.8178        [35m0.5892[0m  0.0382
   2067        [36m0.3538[0m       0.8178        [35m0.5891[0m  0.0384
   2068        [36m0.3537[0m       0.8178        [35m0.5890[0m  0.0400
   2069        [36m0.3536[0m       0.8178        [35m0.5890[0m  0.0382
   2070        [36m0.3535[0m       0.8178        [35m0.5889[0m  0.0380
   2071        [36m0.3534[0m       0.8178        [35m0.5888[0m  0.0383
   2072        [36m0.3533[0m       0.8178        [35m0.5887[0m  0.0385
   2073        [36m0.3532[0m       0.8178        [35m0.5887[0m  0.0380
   2074        [36m0.3531[0m       0.8178        [35m0.5886[0m  0.0382
   2075        [36m0.3531[0m       0.8178        [35m0.5885[0m  0.0381
   2076        [36m0.3530[0m       0.8178        [35m0.5884[0m  0.0382
   2077        [36m0.352

   2174        [36m0.3445[0m       0.8178        [35m0.5811[0m  0.0373
   2175        [36m0.3444[0m       0.8178        [35m0.5810[0m  0.0374
   2176        [36m0.3443[0m       0.8178        [35m0.5809[0m  0.0382
   2177        [36m0.3442[0m       0.8178        [35m0.5808[0m  0.0379
   2178        [36m0.3441[0m       0.8178        [35m0.5808[0m  0.0381
   2179        [36m0.3441[0m       0.8178        [35m0.5807[0m  0.0380
   2180        [36m0.3440[0m       0.8178        [35m0.5806[0m  0.0369
   2181        [36m0.3439[0m       0.8178        [35m0.5805[0m  0.0368
   2182        [36m0.3438[0m       0.8178        [35m0.5805[0m  0.0386
   2183        [36m0.3437[0m       0.8178        [35m0.5804[0m  0.0381
   2184        [36m0.3436[0m       0.8178        [35m0.5803[0m  0.0380
   2185        [36m0.3436[0m       0.8178        [35m0.5803[0m  0.0381
   2186        [36m0.3435[0m       0.8178        [35m0.5802[0m  0.0370
   2187        [36m0.343

   2284        [36m0.3356[0m       0.8200        [35m0.5734[0m  0.0376
   2285        [36m0.3355[0m       0.8200        [35m0.5733[0m  0.0392
   2286        [36m0.3354[0m       0.8200        [35m0.5732[0m  0.0385
   2287        [36m0.3353[0m       0.8200        [35m0.5732[0m  0.0378
   2288        [36m0.3353[0m       0.8200        [35m0.5731[0m  0.0378
   2289        [36m0.3352[0m       0.8200        [35m0.5730[0m  0.0380
   2290        [36m0.3351[0m       0.8200        [35m0.5730[0m  0.0379
   2291        [36m0.3350[0m       0.8200        [35m0.5729[0m  0.0387
   2292        [36m0.3349[0m       0.8200        [35m0.5728[0m  0.0386
   2293        [36m0.3349[0m       0.8200        [35m0.5728[0m  0.0384
   2294        [36m0.3348[0m       0.8200        [35m0.5727[0m  0.0387
   2295        [36m0.3347[0m       0.8200        [35m0.5726[0m  0.0370
   2296        [36m0.3346[0m       0.8200        [35m0.5726[0m  0.0372
   2297        [36m0.334

   2394        [36m0.3273[0m       0.8221        [35m0.5662[0m  0.0380
   2395        [36m0.3272[0m       0.8221        [35m0.5662[0m  0.0371
   2396        [36m0.3271[0m       0.8221        [35m0.5661[0m  0.0372
   2397        [36m0.3270[0m       0.8221        [35m0.5660[0m  0.0375
   2398        [36m0.3270[0m       0.8221        [35m0.5660[0m  0.0368
   2399        [36m0.3269[0m       0.8221        [35m0.5659[0m  0.0372
   2400        [36m0.3268[0m       0.8221        [35m0.5659[0m  0.0377
   2401        [36m0.3267[0m       0.8221        [35m0.5658[0m  0.0378
   2402        [36m0.3267[0m       0.8221        [35m0.5657[0m  0.0383
   2403        [36m0.3266[0m       0.8221        [35m0.5657[0m  0.0376
   2404        [36m0.3265[0m       0.8221        [35m0.5656[0m  0.0371
   2405        [36m0.3265[0m       0.8221        [35m0.5655[0m  0.0383
   2406        [36m0.3264[0m       0.8221        [35m0.5655[0m  0.0383
   2407        [36m0.326

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=ClassifierModule(
    (output): Linear(in_features=768, out_features=15, bias=True)
  ),
)

In [11]:
net.score(dev_x,dev_y)

0.807032590051458

In [12]:
net.score(test_x, test_y)

0.8169991326973114

# Trees

For our gold labels, we need to recover the node distances from our parse tree. For this we will use the functionality provided by `ete3`, that allows us to compute that directly. I have provided code that transforms a `TokenTree` to a `Tree` in `ete3` format.

In [13]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"

    return tree_str


def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree

    tree_str = rec_tokentree_to_nltk(tokentree)

    return NLTKTree.fromstring(tree_str)

In [14]:
# !pip install ete3
from ete3 import Tree as EteTree


class FancyTree(EteTree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)

    return FancyTree(f"{newick_str};")

In [15]:
# Let's check if it works!
# We can read in a corpus using the code that was already provided, and convert it to an ete3 Tree.

def parse_corpus(filename):
    from conllu import parse_incr

    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')
item = corpus[0]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree)


   /-2
  |
  |--3
  |
  |--4
  |
  |   /6 /-5
  |  |
  |  |   /-9
  |  |  |
  |  |  |--10
  |  |  |
  |  |  |--11
  |  |-8|
  |  |  |--12
  |-7|  |
  |  |  |--13
  |  |  |
  |  |   \15/-14
-1|  |
  |  |   /-16
  |  |  |
  |  |  |--17
  |  |  |
  |   \18   /-19
  |     |  |
  |     |  |--20
  |     |  |
  |     |  |-23/-22
  |      \21
  |        |--24
  |        |
  |        |   /-25
  |        |  |
  |         \28--26
  |           |
  |            \-27
  |
   \-29


As you can see we label a token by its token id (converted to a string). Based on these id's we are going to retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.

Please fill in the gap in the following method. I recommend you to have a good look at Hewitt's blog post  about these node distances.

In [16]:
def create_gold_distances(corpus):
    all_distances = []

    for item in (corpus):
        tokentree = item.to_tree()
        ete_tree = tokentree_to_ete(tokentree)
 
        sen_len = len(ete_tree.search_nodes())
        distances = torch.zeros((sen_len, sen_len))
        for src_node in range(1, sen_len+1):
            for target_node in range(1, sen_len+1):
                if distances[src_node-1][target_node-1] == 0:
                     s = ete_tree&str(src_node)
                     t  = ete_tree&str(target_node)
                     dist = s.get_distance(t)
                     distances[src_node-1][target_node-1] = dist
                     distances[target_node-1][src_node-1] = dist

        # Your code for computing all the distances comes here.

        all_distances.append(distances)

    return all_distances

The next step is now to do the previous step the other way around. After all, we are mainly interested in predicting the node distances of a sentence, in order to recreate the corresponding parse tree.

Hewitt et al. reconstruct a parse tree based on a _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree). Fortunately for us, we can simply import a method from `scipy` that retrieves this MST.

In [17]:
from scipy.sparse.csgraph import minimum_spanning_tree
import torch


def create_mst(distances):
    distances = torch.triu(distances).detach().numpy()
    mst = minimum_spanning_tree(distances).toarray()
    mst[mst>0] = 1.
    return mst

Let's have a look at what this looks like, by looking at a relatively short sentence in the sample corpus.

If your addition to the `create_gold_distances` method has been correct, you should be able to run the following snippet. This then shows you the original parse tree, the distances between the nodes, and the MST that is retrieved from these distances. Can you spot the edges in the MST matrix that correspond to the edges in the parse tree?

In [18]:
item = corpus[5]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree, '\n')

gold_distance = create_gold_distances(corpus[5:6])[0]

mst = create_mst(gold_distance)


   /2 /-1
  |
  |--3
  |
  |--4
  |
  |   /-6
  |  |
-5|  |--7
  |-8|
  |  |   /-9
  |  |  |
  |   \12--10
  |     |
  |      \-11
  |
   \-13 



Now that we are able to map edge distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we will need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, that an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ is denoted by both `mst[i,j] = 1`, or `mst[j,i] = 1`.

You will write code that computes the UUAS score for a matrix of predicted distances, and the corresponding gold distances. I recommend you to split this up into 2 methods: 1 that retrieves the edges that are present in an MST matrix, and one general method that computes the UUAS score.

In [19]:
def edges(mst):
    edges = set()

    edges_list = []
    for i, row in enumerate(mst):
        for j, val in enumerate(mst[i]):
            if int(val) == 1 and i<j: # i < j ensures that (1,2) is added and not (2,1) undirectional case
                edges_list.append((i,j))
                
    edges = set(edges_list)
    return edges


def calc_uuas(pred_distances, gold_distances):
    num, denom = 0, 0
    uuas = 0
    
    for pred_matrix, gold_matrix in zip(pred_distances, gold_distances):
        pred_mst = create_mst(pred_matrix)
        gold_mst = create_mst(gold_matrix)
        pred_edges = edges(pred_mst)
        gold_edges = edges(gold_mst)
        edges_pg = pred_edges & gold_edges
        num += len(edges_pg)
        denom += len(gold_edges)

    uuas = num/denom
    return uuas


# Structural Probes

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier for you, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows you to focus on the training regime from now on.

In [20]:
import torch.nn as nn
import torch


class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents


I have provided a rough outline for the training regime that you can use. Note that the hyper parameters that I provide here only serve as an indication, but should be (briefly) explored by yourself.

As can be seen in Hewitt's code above, there exists functionality in the probe to deal with batched input. It is up to you to use that: a (less efficient) method can still incorporate batches by doing multiple forward passes for a batch and computing the backward pass only once for the summed losses of all these forward passes. (_I know, this is not the way to go, but in the interest of time that is allowed ;-), the purpose of the assignment is writing a good paper after all_).

In [21]:
from torch import optim
from tqdm import tqdm

'''
Similar to the `create_data` method of the previous notebook, I recommend you to use a method 
that initialises all the data of a corpus. Note that for your embeddings you can use the 
`fetch_sen_reps` method again. However, for the POS probe you concatenated all these representations into 
1 big tensor of shape (num_tokens_in_corpus, model_dim). 

The StructuralProbe expects its input to contain all the representations of 1 sentence, so I recommend you
to update your `fetch_sen_reps` method in a way that it is easy to retrieve all the representations that 
correspond to a single sentence.
''' 


def fetch_sen_reps(ud_parses: List[TokenList], model, tokenizer, concat) -> Tensor:    
    rep = []
    if isinstance(model, GPT2LMHeadModel):
        for ud_parse in ud_parses:
            rep.append(get_gpt_representations([ud_parse], model, tokenizer))
        if concat:
            rep = nn.utils.rnn.pad_sequence(rep, batch_first=True)
    elif isinstance(model, RNNModel):
        for ud_parse in ud_parses:
            rep.append(get_lstm_representations([ud_parse], model, tokenizer))
        if concat:
            rep = nn.utils.rnn.pad_sequence(rep, batch_first=True)
    else:
        print("NOT A SUPPORTED MODEL!!")
        return None
    
    return rep
    

def init_corpus(path, model, tokenizer, concat=False, cutoff=None):
    """ Initialises the data of a corpus.
    
    Parameters
    ----------
    path : str
        Path to corpus location
    concat : bool, optional
        Optional toggle to concatenate all the tensors
        returned by `fetch_sen_reps`.
    cutoff : int, optional
        Optional integer to "cutoff" the data in the corpus.
        This allows only a subset to be used, alleviating 
        memory usage.
    """
    corpus = parse_corpus(path)[:cutoff]

    embs = fetch_sen_reps(corpus, model, tokenizer, concat=concat)    
    gold_distances = create_gold_distances(corpus)

    lengths = [sent.size(0) for sent in gold_distances]
    maxlen = int(max(lengths))
    label_maxshape = [maxlen for _ in gold_distances[0].shape]
    labels = [-torch.ones(*label_maxshape) for _ in range(len(lengths))]

    for idx, gold_dist in enumerate(gold_distances):
        length = lengths[idx]
        labels[idx][:length,:length] = gold_dist
    
    labels = torch.stack(labels)

    return labels, embs, torch.Tensor(lengths)


# I recommend you to write a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, loss_function, _data):
    # YOUR CODE HERE

    probe.eval()
    y, x, sent_lens = _data
    preds = probe(x)
    loss_score, _ = loss_function(preds, y, sent_lens)
    
    preds_new, y_new = [], []
    for i, length in enumerate(sent_lens):
        length = int(length)
        preds_resized, y_resized = preds[i, :length, :length], y[i, :length, :length]
        preds_new.append(preds_resized)
        y_new.append(y_resized)
 
    uuas_score = calc_uuas(preds_new, y_new)
    return loss_score, uuas_score


# Feel free to alter the signature of this method.
def train(_data, _dev_data, _test_data, epochs):
    emb_dim = 768
    rank = 64
    lr = 1e-1
    batch_size = 15

    probe = StructuralProbe(emb_dim, rank)
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,patience=1)
    loss_function =  L1DistanceLoss()

    train_y, train_x, train_sent_lens = _data

    for epoch in range(epochs):

        for i in range(0, len(corpus), batch_size):
            probe.train()
            optimizer.zero_grad()

            # YOUR CODE FOR DOING A PROBE FORWARD PASS

            _train_batch = train_x[i:i+batch_size]
            _train_labels = train_y[i:i+batch_size]
            _train_lengths = train_sent_lens[i:i+batch_size]
            _preds = probe(_train_batch)
            batch_loss, total_sents = loss_function(_preds, _train_labels, _train_lengths)
            batch_loss.backward()
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, loss_function, _dev_data)
        print("Dev Loss: {}, Dev uuas: {}".format(dev_loss, dev_uuas))
        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        scheduler.step(dev_loss)

    test_loss, test_uuas = evaluate_probe(probe,loss_function,_test_data)
    print("Test Loss: {}, Test uuas: {}".format(test_loss, test_uuas))
    return probe, loss_function

In [22]:
train_data_path = 'data/sample/en_ewt-ud-train.conllu'
dev_data_path = 'data/sample/en_ewt-ud-dev.conllu'
test_data_path = 'data/sample/en_ewt-ud-test.conllu'

lm = model  # `model` or `lstm`
w2i = tokenizer  # `tokenizer` or `vocab`

train_data = init_corpus(train_data_path, lm, w2i, concat=True)
dev_data = init_corpus(dev_data_path, lm, w2i, concat=True)
test_data = init_corpus(test_data_path, lm, w2i, concat=True)

In [23]:
probe, loss_fn = train(train_data, dev_data, test_data, 200)

Dev Loss: 589.7552490234375, Dev uuas: 0.17204301075268819
Dev Loss: 407.6564025878906, Dev uuas: 0.18458781362007168
Dev Loss: 318.6963195800781, Dev uuas: 0.1639784946236559
Dev Loss: 281.4398193359375, Dev uuas: 0.1603942652329749
Dev Loss: 189.85659790039062, Dev uuas: 0.15412186379928317
Dev Loss: 129.0977020263672, Dev uuas: 0.15681003584229392
Dev Loss: 83.31909942626953, Dev uuas: 0.1774193548387097
Dev Loss: 52.66570281982422, Dev uuas: 0.17652329749103943
Dev Loss: 35.878395080566406, Dev uuas: 0.17652329749103943
Dev Loss: 24.669536590576172, Dev uuas: 0.1774193548387097
Dev Loss: 20.601089477539062, Dev uuas: 0.15412186379928317
Dev Loss: 20.056547164916992, Dev uuas: 0.16577060931899643
Dev Loss: 14.486572265625, Dev uuas: 0.17025089605734767
Dev Loss: 21.6811466217041, Dev uuas: 0.13978494623655913
Dev Loss: 22.907289505004883, Dev uuas: 0.13799283154121864
Dev Loss: 14.077306747436523, Dev uuas: 0.16129032258064516
Dev Loss: 6.057729721069336, Dev uuas: 0.145161290322580

Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.5992473363876343, Dev uuas: 0.271505376344086
Dev Loss: 1.59

## LaTeX trees

For your report you might want to add some of those fancy dependency tree plots like those of Figure 2 in the Structural Probing paper. For that you can use the following code, that outputs the corresponding LaTeX markup.

**N.B.**: for the latex tikz tree the first token in a sentence has index 1 (instead of 0), so take that into account with the predicted and gold edges that you pass to the method.

In [25]:
def print_tikz(predicted_edges, gold_edges, words):
    """ Turns edge sets on word (nodes) into tikz dependency LaTeX.
    Parameters
    ----------
    predicted_edges : Set[Tuple[int, int]]
        Set (or list) of edge tuples, as predicted by your probe.
    gold_edges : Set[Tuple[int, int]]
        Set (or list) of gold edge tuples, as obtained from the treebank.
    words : List[str]
        List of strings representing the tokens in the sentence.
    """

    string = """\\begin{dependency}[hide label, edge unit distance=.5ex]
    \\begin{deptext}[column sep=0.05cm]
    """

    string += (
        "\\& ".join([x.replace("$", "\$").replace("&", "+") for x in words])
        + " \\\\\n"
    )
    string += "\\end{deptext}" + "\n"
    for i_index, j_index in gold_edges:
        string += "\\depedge[-]{{{}}}{{{}}}{{{}}}\n".format(i_index, j_index, ".")
    for i_index, j_index in predicted_edges:
        string += f"\\depedge[-,edge style={{red!60!}}, edge below]{{{i_index}}}{{{j_index}}}{{.}}\n"
    string += "\\end{dependency}\n"
    print(string)

In [29]:
probe.eval()
y, x, sent_lens = test_data
preds = probe(x)

corpus = parse_corpus(test_data_path)[0]

preds_new, y_new = [], []
for i, length in enumerate(sent_lens):
    length = int(length)
    preds_resized, y_resized = preds[i, :length, :length], y[i, :length, :length]
    preds_new.append(preds_resized)
    y_new.append(y_resized)
    break

for pred_matrix, gold_matrix in zip(preds_new, y_new):
    pred_mst = create_mst(pred_matrix)
    gold_mst = create_mst(gold_matrix)
    pred_edges = edges(pred_mst)
    gold_edges = edges(gold_mst)

words = []
for sent in ud_parses:
    for token in sent:
        words.append(token["form"])

print_tikz(pred_edges, gold_edges, words)

\begin{dependency}[hide label, edge unit distance=.5ex]
    \begin{deptext}[column sep=0.05cm]
    Al\& -\& Zaman\& :\& American\& forces\& killed\& Shaikh\& Abdullah\& al\& -\& Ani\& ,\& the\& preacher\& at\& the\& mosque\& in\& the\& town\& of\& Qaim\& ,\& near\& the\& Syrian\& border\& .\& [\& This\& killing\& of\& a\& respected\& cleric\& will\& be\& causing\& us\& trouble\& for\& years\& to\& come\& .\& ]\& DPA\& :\& Iraqi\& authorities\& announced\& that\& they\& had\& busted\& up\& 3\& terrorist\& cells\& operating\& in\& Baghdad\& .\& Two\& of\& them\& were\& being\& run\& by\& 2\& officials\& of\& the\& Ministry\& of\& the\& Interior\& !\& The\& MoI\& in\& Iraq\& is\& equivalent\& to\& the\& US\& FBI\& ,\& so\& this\& would\& be\& like\& having\& J.\& Edgar\& Hoover\& unwittingly\& employ\& at\& a\& high\& level\& members\& of\& the\& Weathermen\& bombers\& back\& in\& the\& 1960s\& .\& The\& third\& was\& being\& run\& by\& the\& head\& of\& an\& investment\& firm\& .\& You\&