__Probing Language Models__

This notebook serves as a start for your NLP2 assignment on probing Language Models. This notebook will become part of the contents that you will submit at the end, so make sure to keep your code (somewhat) clean :-)

__note__: This is only the second time anyone is doing this assignment. That's exciting! But it might well be the case that certain aspects are too unclear. Do not hesitate at all to reach to me once you get stuck, I'd be grateful to help you out.

__note 2__: This assignment is not dependent on big fancy GPUs. I run all this stuff on my own 3 year old CPU, without any Colab hassle. So it's up to you to decide how you want to run it.

# Models

For the Transformer models you are advised to make use of the `transformers` library of Huggingface: https://github.com/huggingface/transformers
Their library is well documented, and they provide great tools to easily load in pre-trained models.

In [1]:
#############################
## INITIALIZING ALL MODELS ##
#############################

# Manual:

# BertLM & tokenizer_bert use the BERT Transformer infrastructure (small / uncased)
# ElmoLM & tokenizer_elmo use the ELMo LSTM
# GulordavaLM & tokenizer_gulordava use Gulordava's LSTM
# GPT2LM & tokenizer_gpt2 use the GPT2 transformer

###########
##### BERT
###########
from transformers import BertTokenizer, BertForMaskedLM, BertModel
from torch.nn import functional as F
import torch

modelBertForFunTesting = BertForMaskedLM.from_pretrained('bert-base-uncased', output_hidden_states=True)
modelBertForFunTesting.eval()

tokenizer_bert = BertTokenizer.from_pretrained('bert-base-uncased')

# Just a fun test with a masked LM

text = "From the early 20th century onward " + tokenizer_bert.mask_token + " again became the capital of Russia."

input = tokenizer_bert.encode_plus(text, return_tensors = "pt")
mask_index = torch.where(input["input_ids"][0] == tokenizer_bert.mask_token_id)
output = modelBertForFunTesting(**input)
logits = output.logits
softmax = F.softmax(logits, dim = -1)
mask_word = softmax[0, mask_index, :]
top_10 = torch.topk(mask_word, 10, dim = 1)[1][0]

solutions = []

for token in top_10:
   word = tokenizer_bert.decode([token])
   solutions.append(word)
   #new_sentence = text.replace(tokenizer.mask_token, word)

print(solutions) # Those are the possible fill-ins for the mask! Cool!

# Now the real stuff we need for the experiments

BertLM = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
BertLM.eval()

# Let's still check if it's all working (also because it's just fun to do)
sentences = [['In', 'present', 'times', 'the', 'capital', 'of', 'Russia', 'is', 'still', 'Moscow', '.'],
             ['However', ',', 'maybe', 'at', 'some', 'point', 'in', 'time', ',', 'it', 'will', 'again', 'become', 'Kiev', '.']]

sentences_tokenized = []
segments = []

i = 0
for sentence in sentences:
    sentences_tokenized += tokenizer_bert.tokenize(" ".join(sentence))
    segments += [i] * (len(sentence) + 1)
    i += 1
    
sentences_tokenized = ['[CLS]'] + sentences_tokenized + ['[SEP]']
indexes = tokenizer_bert.convert_tokens_to_ids(sentences_tokenized)

for tup in zip(sentences_tokenized, indexes):
    print('{:<12} {:>6,}'.format(tup[0], tup[1]))

with torch.no_grad():
    outputs = BertLM(torch.tensor([indexes]), torch.tensor([segments]))
    print(len(outputs[2])) # dims: 13 (input + 12 layers) x 1 (batch) x 28 (words) x 768 features
    
###########
##### ELMo
###########

# The model itself

from allennlp.modules.elmo import Elmo, batch_to_ids
from sacremoses import MosesTokenizer

options_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

tokenizer_elmo = MosesTokenizer('en')
ElmoLM = Elmo(options_file, weight_file, 1, dropout=0, requires_grad=False)

# And again some fun with testing
sentences = [['In', 'present', 'times', 'the', 'capital', 'of', 'Russia', 'is', 'still', 'Moscow', '.'],
             ['However', ',', 'maybe', 'at', 'some', 'point', 'in', 'time', ',', 'it', 'will', 'again', 'become', 'Kiev', '.'],
             ['Currently', ',', 'Kiev', 'is', 'the', 'capital', 'of', 'Ukraine', '.']]
character_ids = batch_to_ids(sentences)

embeddings = ElmoLM(character_ids)
print(embeddings['elmo_representations'][0].shape) # dims: 3 (batches/sentences) x 15 (max words) x 1024 features

###########
##### Gulordava's LSTM
###########
from collections import defaultdict
from lstm.model import RNNModel
import torch

model_location = 'lstm/state_dict.pt'  # <- point this to the location of the Gulordava .pt file
GulordavaLM = RNNModel('LSTM', 50001, 650, 650, 2)
GulordavaLM.load_state_dict(torch.load(model_location))
GulordavaLM.eval()

# This LSTM does not use a Tokenizer like the Transformers, but a Vocab dictionary that maps a token to an id.
with open('lstm/vocab.txt', encoding='utf-8') as f:
    w2igd = {w.strip(): i for i, w in enumerate(f)}

tokenizer_gulordava = defaultdict(lambda: w2igd["<unk>"])
tokenizer_gulordava.update(w2igd)

###########
##### GPT2 Transformer
###########
from transformers import GPT2Model, GPT2Tokenizer

GPT2LM = GPT2Model.from_pretrained("distilgpt2", output_hidden_states=True)
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained("distilgpt2", output_hidden_states=True, is_split_into_words=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['moscow', 'kazan', 'it', 'kiev', 'leningrad', 'odessa', 'petersburg', 'novgorod', 'sofia', 'riga']
[CLS]           101
in            1,999
present       2,556
times         2,335
the           1,996
capital       3,007
of            1,997
russia        3,607
is            2,003
still         2,145
moscow        4,924
.             1,012
however       2,174
,             1,010
maybe         2,672
at            2,012
some          2,070
point         2,391
in            1,999
time          2,051
,             1,010
it            2,009
will          2,097
again         2,153
become        2,468
kiev         12,100
.             1,012
[SEP]           102
13
torch.Size([3, 15, 1024])


Some weights of GPT2Model were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.5.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


It is a good idea that before you move on, you try to feed some text to your LMs; and check if everything works accordingly. 

# Data

For this assignment you will train your probes on __treebank__ corpora. A treebank is a corpus that has been *parsed*, and stored in a representation that allows the parse tree to be recovered. Next to a parse tree, treebanks also often contain information about part-of-speech tags, which is exactly what we are after now.

The treebank you will use for now is part of the Universal Dependencies project. I provide a sample of this treebank as well, so you can test your setup on that before moving on to larger amounts of data.

Make sure you accustom yourself to the format that is created by the `conllu` library that parses the treebank files before moving on. For example, make sure you understand how you can access the pos tag of a token, or how to cope with the tree structure that is formed using the `to_tree()` functionality.

In [2]:
# READ DATA
from typing import List
from conllu import parse_incr, TokenList

# If stuff like `: str` and `-> ..` seems scary, fear not! 
# These are type hints that help you to understand what kind of argument and output is expected.
def parse_corpus(filename: str) -> List[TokenList]:
    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

# Generating Representations

We now have our data all set, our models are running and we are good to go!

The next step is now to create the model representations for the sentences in our corpora. Once we have generated these representations we can store them, and train additional diagnostic (/probing) classifiers on top of the representations.

There are a few things you should keep in mind here. Read these carefully, as these tips will save you a lot of time in your implementation.
1. Transformer models make use of Byte-Pair Encodings (BPE), that chunk up a piece of next in subword pieces. For example, a word such as "largely" could be chunked up into "large" and "ly". We are interested in probing linguistic information on the __word__-level. Therefore, we will follow the suggestion of Hewitt et al. (2019a, footnote 4), and create the representation of a word by averaging over the representations of its subwords. So the representation of "largely" becomes the average of that of "large" and "ly".


2. Subword chunks never overlap multiple tokens. In other words, say we have a phrase like "None of the", then the tokenizer might chunk that into "No"+"ne"+" of"+" the", but __not__ into "No"+"ne o"+"f the", as those chunks overlap multiple tokens. This is great for our setup! Otherwise it would have been quite challenging to distribute the representation of a subword over the 2 tokens it belongs to.


3. **Important**: If you closely examine the provided treebank, you will notice that some tokens are split up into multiple pieces, that each have their own POS-tag. For example, in the first sentence the word "Al-Zaman" is split into "Al", "-", and "Zaman". In such cases, the conllu `TokenList` format will add the following attribute: `('misc', OrderedDict([('SpaceAfter', 'No')]))` to these tokens. Your model's tokenizer does not need to adhere to the same tokenization. E.g., "Al-Zaman" could be split into "Al-"+"Za"+"man", making it hard to match the representations with their correct pos-tag. Therefore I recommend you to not tokenize your entire sentence at once, but to do this based on the chunking of the treebank. <br /><br />
Make sure to still incoporate the spaces in a sentence though, as these are part of the BPE of the tokenizer. That is, the tokenizer uses a different token id for `"man"`, than it does for `" man"`: the former could be part of `" woman"`=`" wo`"+`"man"`, whereas the latter would be the used in case *man* occurs at the start of a word. The tokenizer for GPT-2 adds spaces at the start of a token (represented as a `Ġ` symbol). This means that you should keep track whether the previous token had the `SpaceAfter` attribute set to `'No'`: in case it did not, you should manually prepend a `" "` ahead of the token.


4. The LSTM LM does not have the issues related to subwords, but is far more restricted in its vocabulary. Make sure you keep the above points in mind though, when creating the LSTM representations. You might want to write separate functions for the LSTM, but that is up to you.


5. **N.B.**: Make sure that when you run a sentence through your model, you do so within a `with torch.no_grad():` block, and that you have run `model.eval()` beforehand as well (to disable dropout).


6. **N.B.**: Make sure to use a token's ``["form"]`` attribute, and not the ``["lemma"]``, as the latter will stem any relevant morphological information from the token. We don't want this, because we want to feed well-formed, grammatical sentences to our model.


I would like to stress that if you feel hindered in any way by the simple code structure that is presented here, you are free to modify it :-) Just make sure it is clear to an outsider what you're doing, some helpful comments never hurt.

In [3]:
# FETCH SENTENCE REPRESENTATIONS

from torch import Tensor
import pickle
import re
import numpy as np

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')

# Should return a tensor of shape (num_tokens_in_corpus, representation_size)
# Make sure you correctly average the subword representations that belong to 1 token!

def fetch_sen_reps(ud_parses: List[TokenList], model, tokenizer, BERTLayerId = -1, GPT2LayerId = -1, stack_all = True, concat_all = True) -> Tensor:
    if not isinstance(model, BertModel) and not isinstance(model, Elmo) and not isinstance(model, RNNModel) and not isinstance(model, GPT2Model):
        raise ValueError('A non-compatible model has been passed (should be either BertModel or Elmo or RNNModel or GPT2Model)')
        
    representations = []
    full_sentences  = []
    
    for sentence in ud_parses:
        representations_sentence = []
        sentence_stack   = []
        tokenizer_stack  = []
        word_stack = []
        words      = []
        
        sentence_length = len(sentence)
            
        put_space_gpt = False
            
        for w, word in enumerate(sentence):
            word_stack.append(word['form'])
            
            #######################################
            ## THE CODE BELOW HANDLES TOKENIZING ##
            #######################################
            
            if isinstance(model, BertModel) or isinstance(model, GPT2Model): # BERT or GPT
                # BERT uses WordPiece encodings rather than Byte-Pair
                tokenizer_single = tokenizer.encode((" " if put_space_gpt and isinstance(model, GPT2Model) else "") + word['form'], add_special_tokens=False)
            elif isinstance(model, Elmo): # ELMo
                tokenizer_single = tokenizer.tokenize(word['form'], escape=False)
            elif isinstance(model, RNNModel): # Gulordava's LSTM
                tokenizer_single = [tokenizer[word['form']]]
                
            tokenizer_stack.append(tokenizer_single)
            sentence_stack += tokenizer_single
                
            if w == sentence_length - 1 or not (word['misc'] is not None and 'SpaceAfter' in word['misc'] and word['misc']['SpaceAfter'] == 'No'):
                if isinstance(model, BertModel) or isinstance(model, GPT2Model): # BERT or GPT
                    # encoding_full MIGHT differ (see remark #3 above); still added for analysis purposes
                    # GPT2 is byte-pair encoding, needs space at the beginning, except first word
                    tokenizer_full = tokenizer.encode((" " if (isinstance(model, GPT2Model) and len(words) > 0) else "") + "".join(word_stack), add_special_tokens=False)             
                elif isinstance(model, Elmo):
                    #sentence_stack.append("".join(word_stack))
                    tokenizer_full = tokenizer.tokenize("".join(word_stack), escape=False)
                elif isinstance(model, RNNModel): # Gulordava's LSTM
                    tokenizer_full = tokenizer["".join(word_stack)]
                    
                words.append((word_stack, tokenizer_full, tokenizer_stack))
                tokenizer_stack = []
                word_stack = []
                put_space_gpt = True
            else:
                put_space_gpt = False
        
        ######################################
        ## THE CODE BELOW HANDLES EMBEDDING ##
        ######################################
        
        if isinstance(model, BertModel): # BERT
            # padding with start/end symbols
            words = [(['[CLS]'], [101], [[101]])] + words + [(['[SEP]'], [102], [[102]])]
            sentence_stack = [101] + sentence_stack + [102]

            segments = [1] * len(sentence_stack)
            
            with torch.no_grad():
                outputs = model(torch.tensor([sentence_stack]), torch.tensor([segments]))
            embeddings = outputs[2][BERTLayerId][-1] # middle index is the layer (it's a tuple: input layer + 12 hidden)

        elif isinstance(model, Elmo): # ELMo
            sentence_tokenized = [tokenizer.tokenize(" ".join(sentence_stack), escape=False)]
            character_ids = batch_to_ids(sentence_tokenized)
            embeddings = model(character_ids)['elmo_representations'][-1][-1]
            
        elif isinstance(model, RNNModel): # Gulordava's LSTM
            encoded_input = torch.as_tensor(sentence_stack).reshape(1, -1)
            
            with torch.no_grad():
                hidden = model.init_hidden(encoded_input.shape[0])
                embeddings = model(encoded_input, hidden).squeeze()

        elif isinstance(model, GPT2Model): # GPT2
            encoding_stack = []
            
            for word in words:
                encoding_stack += word[1]
                
            with torch.no_grad():
                outputs = model(torch.tensor([sentence_stack]))
            embeddings = outputs.hidden_states[GPT2LayerId][0]
        
        #######################################################
        ## THE CODE BELOW HANDLES THE ACTUAL REPRESENTATIONS ##
        #######################################################
        
        i = 0
        for full_word in words: # a 'full word' is an arbitrary sequence before a space
            if full_word[0][0] in ['[CLS]', '[SEP]']:
                i += 1
                continue # we don't need them here (since we're comparing with an LSTM)
                
            for treebank_token in full_word[2]: # chunking according to treebank corpus
                rep = sum(embeddings[i:i+(len(treebank_token))])/len(treebank_token)
                representations.append(rep)
                representations_sentence.append(rep)
                i += len(treebank_token)
        full_sentences.append(torch.stack(representations_sentence))
    if not stack_all: # for structural probing
        # if concat_all is true, sentences won't be grouped (can still be used in structural probing if careful)
        return (representations if concat_all else full_sentences)
    
    corpus_representation = torch.stack(representations)
    
    return corpus_representation

In [4]:
# Testing all models

# ELMo
print(fetch_sen_reps(corpus, ElmoLM, tokenizer_elmo).shape)

# Bert
print(fetch_sen_reps(corpus, BertLM, tokenizer_bert).shape)

# Gulordava's LSTM
print(fetch_sen_reps(corpus, GulordavaLM, tokenizer_gulordava).shape)

# GPT2
print(fetch_sen_reps(corpus, GPT2LM, tokenizer_gpt2).shape)

torch.Size([2301, 1024])
torch.Size([2301, 768])
torch.Size([2301, 650])
torch.Size([2301, 768])


To validate your activation extraction procedure I have set up the following assertion function as a sanity check. It compares your representation against a pickled version of mine. 

For this I used `distilgpt2`.

In [5]:
def error_msg(model_name, gold_embs, embs, i2w):
    with open(f'{model_name}_tokens1.pickle', 'rb') as f:
        sen_tokens = pickle.load(f)
        
    diff = torch.abs(embs - gold_embs)
    max_diff = torch.max(diff)
    avg_diff = torch.mean(diff)
    
    print(f"{model_name} embeddings don't match!")
    print(f"Max diff.: {max_diff:.4f}\nMean diff. {avg_diff:.4f}")

    print("\nCheck if your tokenization matches with the original tokenization:")
    for idx in sen_tokens.squeeze():
        if isinstance(i2w, list):
            token = i2w[idx]
        else:
            token = i2w.convert_ids_to_tokens(idx.item())
        print(f"{idx:<6} {token}")


def assert_sen_reps(model, tokenizer, lstm, vocab):
    with open('distilgpt2_emb1.pickle', 'rb') as f:
        distilgpt2_emb1 = pickle.load(f)
        
    with open('lstm_emb1.pickle', 'rb') as f:
        lstm_emb1 = pickle.load(f)
    
    corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')[:1]
    
    own_distilgpt2_emb1 = fetch_sen_reps(corpus, model, tokenizer)
    own_lstm_emb1 = fetch_sen_reps(corpus, lstm, vocab)
    
    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape, \
        f"Distilgpt2 shape mismatch: {distilgpt2_emb1.shape} (gold) vs. {own_distilgpt2_emb1.shape} (yours)"
    assert lstm_emb1.shape == own_lstm_emb1.shape, \
        f"LSTM shape mismatch: {lstm_emb1.shape} (gold) vs. {own_lstm_emb1.shape} (yours)"

    if not torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1, rtol=1e-3, atol=1e-3):
        error_msg("distilgpt2", distilgpt2_emb1, own_distilgpt2_emb1, tokenizer)
    if not torch.allclose(lstm_emb1, own_lstm_emb1, rtol=1e-3, atol=1e-3):
        error_msg("lstm", lstm_emb1, own_lstm_emb1, list(vocab.keys()))

In [6]:
# This only works for GulordavaLM and GPT2LM

assert_sen_reps(GPT2LM, tokenizer_gpt2, GulordavaLM, tokenizer_gulordava)

Next, we should define a function that extracts the corresponding POS labels for each activation, which we do based on the **``"upostag"``** attribute of a token (so not the ``xpostag`` attribute). These labels will be transformed to a tensor containing the label index for each item.

In [7]:
# FETCH POS LABELS

# Should return a tensor of shape (num_tokens_in_corpus,)
# Make sure that when fetching these pos tags for your train/dev/test corpora you share the label vocabulary.
def fetch_pos_tags(ud_parses: List[TokenList], pos_vocab=None) -> Tensor:
    if pos_vocab is None: # to make it compatible with Jaap's function
        pos_vocab = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
        
    pos_tags = []
    
    for sentence in ud_parses:
        for word in sentence:
            pos_tags.append(pos_vocab.index(word['upostag']))
            
    return torch.tensor(pos_tags, dtype=torch.int), pos_vocab

In [95]:
# control task

def fetch_control_pos_tags(corpus):
    pass

In [96]:
pos_tags, _ = fetch_pos_tags(corpus, None)

print(pos_tags.shape)

torch.Size([2301])


# Diagnostic Classification

We now have our models, our data, _and_ our representations all set! Hurray, well done. We can finally move onto the cool stuff, i.e. training the diagnostic classifiers (DCs).

DCs are simple in their complexity on purpose. To read more about why this is the case you could already have a look at the "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

A simple linear classifier will suffice for now, don't bother with adding fancy non-linearities to it.

I am personally a fan of the `skorch` library, that provides `sklearn`-like functionalities for training `torch` models, but you are free to train your dc using whatever method you prefer.

As this is an Artificial Intelligence master and you have all done ML1 + DL, I expect you to use your train/dev/test splits correctly ;-)

In [113]:
# DIAGNOSTIC CLASSIFIER
# Update@Damiaan - I have now changed this to PyTorch code

import torch.nn as nn
import torch
from torch import optim

class DCLogisticRegression(nn.Module):
    def __init__(self, embedding_dimensions, number_of_pos_tags):
        super(DCLogisticRegression, self).__init__()
        self.linear = nn.Linear(embedding_dimensions, number_of_pos_tags)

    def forward(self, x):
        out = self.linear(x)
        return out
    
def evaluate_pos_probe(probe, _data, batch_size = 24):
    loss_scores = []
    criterion = nn.CrossEntropyLoss()
    
    num_samples = _data['x'].shape[0]
    random_indices = torch.randperm(num_samples)
    
    probe.eval()
    
    correct = 0
    
    for i in range(0, num_samples, batch_size):
        if i + batch_size > len(_data['y']):
            break # drop last
            
        labels     = torch.index_select(_data['y'], 0, random_indices[i:i+batch_size]).long()
        embeddings = torch.index_select(_data['x'], 0, random_indices[i:i+batch_size])
            
        with torch.no_grad():
            output      = probe(embeddings)
            predictions = torch.argmax(output, 1)
            loss     = criterion(output, labels)
            
            correct += (predictions == labels).float().sum()
            
            loss_scores.append(loss)
    
    return torch.mean(torch.tensor(loss_scores)), correct / (i + batch_size + 1)

def fit_pos_probe(_data, pos_vocab, rt_graph = False, shift_embeddings = False):
    """
    Note: if shift_embeddings is set to true, the embeddings of the preceding word will be
          used with the label unchanged.
    """
    epochs     = 8
    lr         = 10e-3
    batch_size = 24
    embedding_dimensions = _data['train_x'].shape[1]
    number_of_pos_tags   = len(pos_vocab)
    
    probe     = DCLogisticRegression(embedding_dimensions, number_of_pos_tags)
    optimizer = torch.optim.SGD(probe.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    criterion = nn.CrossEntropyLoss()
    
    num_samples = _data['train_x'].shape[0]
    
    if shift_embeddings:
        num_samples -= 1
        start_index  = 1
    else:
        start_index  = 0
    
    probe.train()
    
    for epoch in range(epochs):
        random_indices = torch.randperm(num_samples)
        
        for i in range(start_index, num_samples, batch_size):
            
            if i + batch_size > num_samples:
                break # drop last
            
            labels     = torch.index_select(_data['train_y'], 0, random_indices[i:i+batch_size]).long()
            
            if shift_embeddings:
                j = i - 1
            else:
                j = i
            
            embeddings = torch.index_select(_data['train_x'], 0, random_indices[j:j+batch_size])
            
            print(labels)
            print(embeddings[1])
            print(random_indices)
            
            return 2
            
            optimizer.zero_grad()
            output = probe(embeddings)
            loss = criterion(output, labels)
            loss.backward(retain_graph=rt_graph)
            optimizer.step()
            
            dev_loss, accuracy = evaluate_pos_probe(probe, {'x': _data['dev_x'], 'y': _data['dev_y']}, batch_size = batch_size)
        
        print('After epoch %d - loss: %.4f - accuracy: %.4f' % (epoch + 1, dev_loss, accuracy))

        scheduler.step(dev_loss)

    test_loss, accuracy = evaluate_pos_probe(probe, {'x': _data['test_x'], 'y': _data['test_y']})
    print('After training - loss: %.3f - accuracy: %.3f' % (test_loss, accuracy))

In [117]:
def init_corpus_pos(path, model, tokenizer, pos_vocab = None, BERTLayerId = -1, GPT2LayerId = -1, cutoff=None, control_task=False):
    corpus = parse_corpus(path)[:cutoff]

    embeddings        = fetch_sen_reps(corpus, model, tokenizer, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId)
    
    if not control_task:
        labels, pos_vocab = fetch_pos_tags(corpus, pos_vocab=pos_vocab)
    else:
        labels, pos_vocab = fetch_control_pos_tags(corpus)
        
    return embeddings, labels, pos_vocab

def load_and_train_pos(model, tokenizer, BERTLayerId = -1, GPT2LayerId = -1, cutoff = None, rt_graph = False, control_task = False, shift_embeddings = False):
    train_x, train_y, pos_vocab = init_corpus_pos('data/en_ewt-ud-train.conllu', model, tokenizer, pos_vocab = None,      BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)
    dev_x,   dev_y,   _         = init_corpus_pos('data/en_ewt-ud-dev.conllu',   model, tokenizer, pos_vocab = pos_vocab, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)
    test_x,  test_y,  _         = init_corpus_pos('data/en_ewt-ud-test.conllu',  model, tokenizer, pos_vocab = pos_vocab, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)
    
    _data = {'train_x': train_x, 'train_y': train_y, 'dev_x': dev_x, 'dev_y': dev_y, 'test_x': test_x, 'test_y': test_y}
    
    fit_pos_probe(_data, pos_vocab, rt_graph=rt_graph, shift_embeddings = False)

In [None]:
# Possible calls:

# load_and_train_pos(GPT2LM, tokenizer_gpt2)
# load_and_train_pos(GPT2LM, tokenizer_gpt2, shift_embeddings = True)

#for k in range(1, 13):
#    print('Now doing layer %d' % k)
#    load_and_train_pos(BertLM, tokenizer_bert, BERTLayerId = k)

# GPT-2 also has 6 layers (can be trained with: for k in range(1, 7)

# for control task, just add: control_task = True as parameter

# Trees

For our gold labels, we need to recover the node distances from our parse tree. For this we will use the functionality provided by `ete3`, that allows us to compute that directly. I have provided code that transforms a `TokenTree` to a `Tree` in `ete3` format.

In [11]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    token = tokentree.token["form"]
    tree_str = f"({token} {' '.join(rec_tokentree_to_nltk(t) for t in tokentree.children)})"

    return tree_str


def tokentree_to_nltk(tokentree):
    from nltk import Tree as NLTKTree

    tree_str = rec_tokentree_to_nltk(tokentree)

    return NLTKTree.fromstring(tree_str)

In [12]:
# !pip install ete3
from ete3 import Tree as EteTree


class FancyTree(EteTree):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, format=1, **kwargs)
        
    def __str__(self):
        return self.get_ascii(show_internal=True)
    
    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    idx = str(tokentree.token["id"])
    children = tokentree.children
    if children:
        return f"({','.join(rec_tokentree_to_ete(t) for t in children)}){idx}"
    else:
        return idx
    
def tokentree_to_ete(tokentree):
    newick_str = rec_tokentree_to_ete(tokentree)

    return FancyTree(f"{newick_str};")

In [13]:
# Let's check if it works!
# We can read in a corpus using the code that was already provided, and convert it to an ete3 Tree.

def parse_corpus(filename):
    from conllu import parse_incr

    data_file = open(filename, encoding="utf-8")

    ud_parses = list(parse_incr(data_file))
    
    return ud_parses

corpus = parse_corpus('data/sample/en_ewt-ud-train.conllu')

item = corpus[0]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree)


   /-2
  |
  |--3
  |
  |--4
  |
  |   /6 /-5
  |  |
  |  |   /-9
  |  |  |
  |  |  |--10
  |  |  |
  |  |  |--11
  |  |-8|
  |  |  |--12
  |-7|  |
  |  |  |--13
  |  |  |
  |  |   \15/-14
-1|  |
  |  |   /-16
  |  |  |
  |  |  |--17
  |  |  |
  |   \18   /-19
  |     |  |
  |     |  |--20
  |     |  |
  |     |  |-23/-22
  |      \21
  |        |--24
  |        |
  |        |   /-25
  |        |  |
  |         \28--26
  |           |
  |            \-27
  |
   \-29


As you can see we label a token by its token id (converted to a string). Based on these id's we are going to retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.

Please fill in the gap in the following method. I recommend you to have a good look at Hewitt's blog post  about these node distances.

In [None]:
# This function is used for the structural probe control task

import random

def create_control_trees(corpus, depth=2):
    trees = [] * len(corpus)
    for i in range(len((corpus))):
            item = corpus[i]
            tokentree = item.to_tree()
            ete_tree = tokentree_to_ete(tokentree)
            levels = [[] for i in range(depth)]
            root = ete_tree.get_tree_root().detach()
            for node in ete_tree.iter_descendants('preorder'):
                level = random.choice(levels)
                level.append(node.detach())
             
            levels = [level for level in levels if len(level) != 0]

            for i in range(len(levels)):
                if i == 0:
                    for node in levels[0]:
                        root.add_child(node)
                else:
                    for node in levels[i]:
                        parent = random.choice(levels[i-1])
                        parent.add_child(node)
           
            trees.append(ete_tree)
    return trees

In [14]:
from itertools import combinations
import math

def create_gold_distances(corpus, control_task = False):
    all_distances = []
    
    if control_task:
        corpus = create_control_trees(corpus)

    for item in (corpus):
        if not control_task:
            tokentree = item.to_tree()
            ete_tree = tokentree_to_ete(tokentree)
        else:
            ete_tree = item
            
        sen_len = len(ete_tree.search_nodes())
        distances = torch.zeros((sen_len, sen_len))
        
        # Proper way to get all nodes should be a traversal: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#getting-leaves-descendants-and-node-s-relatives
        nodes = ete_tree.traverse("postorder")
        
        # we're dealing with a symmetric matrix, so I came up with this stuff to make it a bit more efficient
        m = o = sen_len - 1 # number of new distances to be put in current column/row (decreases)
        n   = 0 # number of distances already added; if equal to m, then reset n and decrease m
        
        for pair in combinations(nodes, 2):
            # could have also used the tree root as node
            distance = pair[0].get_distance(pair[1], topology_only=False) # http://etetoolkit.org/docs/latest/reference/reference_tree.html#ete3.TreeNode.get_distance
            
            if n == m:
                n = 0
                m  -= 1
            
            j = o - m
            i = n + j + 1
            
            distances[i, j] = distances[j, i] = distance
            distances[j, j] = 0
            
            n += 1
            
        all_distances.append(distances)

    return all_distances

The next step is now to do the previous step the other way around. After all, we are mainly interested in predicting the node distances of a sentence, in order to recreate the corresponding parse tree.

Hewitt et al. reconstruct a parse tree based on a _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree). Fortunately for us, we can simply import a method from `scipy` that retrieves this MST.

In [15]:
from scipy.sparse.csgraph import minimum_spanning_tree

def create_mst(distances):
    distances = torch.triu(distances).detach().numpy()
    
    mst = minimum_spanning_tree(distances).toarray()
    mst[mst>0] = 1.
    
    return mst

Let's have a look at what this looks like, by looking at a relatively short sentence in the sample corpus.

If your addition to the `create_gold_distances` method has been correct, you should be able to run the following snippet. This then shows you the original parse tree, the distances between the nodes, and the MST that is retrieved from these distances. Can you spot the edges in the MST matrix that correspond to the edges in the parse tree?

In [16]:
item = corpus[5]
tokentree = item.to_tree()
ete3_tree = tokentree_to_ete(tokentree)
print(ete3_tree, '\n')


gold_distance = create_gold_distances(corpus[5:6])[0]
print(gold_distance, '\n')

mst = create_mst(gold_distance)
print(mst)


   /2 /-1
  |
  |--3
  |
  |--4
  |
  |   /-6
  |  |
-5|  |--7
  |-8|
  |  |   /-9
  |  |  |
  |   \12--10
  |     |
  |      \-11
  |
   \-13 

tensor([[0., 1., 3., 3., 4., 4., 5., 5., 5., 4., 3., 3., 2.],
        [1., 0., 2., 2., 3., 3., 4., 4., 4., 3., 2., 2., 1.],
        [3., 2., 0., 2., 3., 3., 4., 4., 4., 3., 2., 2., 1.],
        [3., 2., 2., 0., 3., 3., 4., 4., 4., 3., 2., 2., 1.],
        [4., 3., 3., 3., 0., 2., 3., 3., 3., 2., 1., 3., 2.],
        [4., 3., 3., 3., 2., 0., 3., 3., 3., 2., 1., 3., 2.],
        [5., 4., 4., 4., 3., 3., 0., 2., 2., 1., 2., 4., 3.],
        [5., 4., 4., 4., 3., 3., 2., 0., 2., 1., 2., 4., 3.],
        [5., 4., 4., 4., 3., 3., 2., 2., 0., 1., 2., 4., 3.],
        [4., 3., 3., 3., 2., 2., 1., 1., 1., 0., 1., 3., 2.],
        [3., 2., 2., 2., 1., 1., 2., 2., 2., 1., 0., 2., 1.],
        [3., 2., 2., 2., 3., 3., 4., 4., 4., 3., 2., 0., 1.],
        [2., 1., 1., 1., 2., 2., 3., 3., 3., 2., 1., 1., 0.]]) 

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0

Now that we are able to map edge distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we will need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, that an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ is denoted by both `mst[i,j] = 1`, or `mst[j,i] = 1`.

You will write code that computes the UUAS score for a matrix of predicted distances, and the corresponding gold distances. I recommend you to split this up into 2 methods: 1 that retrieves the edges that are present in an MST matrix, and one general method that computes the UUAS score.

In [17]:
np.set_printoptions(threshold=np.inf)
def edges(mst, show=False):
    d1, d2 = np.nonzero(mst)
    return tuple(zip(d1, d2)), len(d1)

def calc_uuas(pred_distances, gold_distances):
    uuas = None
    
    # Filtering batch padding
    labels_1s = (gold_distances != -1).float()
    pred_distances = pred_distances * labels_1s
    gold_distances = gold_distances * labels_1s
    
    batch_size = pred_distances.shape[0]
    
    for j in range(batch_size):
        pred_distances_mst = create_mst(pred_distances[j])
        gold_distances_mst = create_mst(gold_distances[j])
        
        
        pred_edges, _     = edges(pred_distances_mst)
        gold_edges, total = edges(gold_distances_mst)
        
        correct = 0
        
        for c in pred_edges:
            if c in gold_edges or (c[1], c[0]) in gold_edges:
                correct += 1
        
        
        if uuas is None:
            uuas = 0
            
        if total != 0:
            uuas += correct/total
        
        
    if uuas is not None:
        uuas /= batch_size
    
    return uuas

# Structural Probes

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier for you, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows you to focus on the training regime from now on.

In [18]:
class StructuralProbe(nn.Module):
    """ Computes squared L2 distance after projection by a matrix.
    For a batch of sentences, computes all n^2 pairs of distances
    for each sentence in the batch.
    """
    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.probe_rank = rank
        self.model_dim = model_dim
        
        self.proj = nn.Parameter(data = torch.zeros(self.model_dim, self.probe_rank))
        
        nn.init.uniform_(self.proj, -0.05, 0.05)
        self.to(device)

    def forward(self, batch):
        """ Computes all n^2 pairs of distances after projection
        for each sentence in a batch.
        Note that due to padding, some distances will be non-zero for pads.
        Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j
        Args:
          batch: a batch of word representations of the shape
            (batch_size, max_seq_len, representation_dim)
        Returns:
          A tensor of distances of shape (batch_size, max_seq_len, max_seq_len)
        """
        transformed = torch.matmul(batch, self.proj)
        
        batchlen, seqlen, rank = transformed.size()
        
        transformed = transformed.unsqueeze(2)
        transformed = transformed.expand(-1, -1, seqlen, -1)
        transposed = transformed.transpose(1,2)
        
        diffs = transformed - transposed
        
        squared_diffs = diffs.pow(2)
        squared_distances = torch.sum(squared_diffs, -1)

        return squared_distances

    
class L1DistanceLoss(nn.Module):
    """Custom L1 loss for distance matrices."""
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """ Computes L1 loss on distance matrices.
        Ignores all entries where label_batch=-1
        Normalizes first within sentences (by dividing by the square of the sentence length)
        and then across the batch.
        Args:
          predictions: A pytorch batch of predicted distances
          label_batch: A pytorch batch of true distances
          length_batch: A pytorch batch of sentence lengths
        Returns:
          A tuple of:
            batch_loss: average loss in the batch
            total_sents: number of sentences in the batch
        """
        
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        squared_lengths = length_batch.pow(2).float()

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1,2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        
        else:
            batch_loss = torch.tensor(0.0)
        
        return batch_loss, total_sents


I have provided a rough outline for the training regime that you can use. Note that the hyper parameters that I provide here only serve as an indication, but should be (briefly) explored by yourself.

As can be seen in Hewitt's code above, there exists functionality in the probe to deal with batched input. It is up to you to use that: a (less efficient) method can still incorporate batches by doing multiple forward passes for a batch and computing the backward pass only once for the summed losses of all these forward passes. (_I know, this is not the way to go, but in the interest of time that is allowed ;-), the purpose of the assignment is writing a good paper after all_).

In [19]:
'''
Similar to the `create_data` method of the previous notebook, I recommend you to use a method 
that initialises all the data of a corpus. Note that for your embeddings you can use the 
`fetch_sen_reps` method again. However, for the POS probe you concatenated all these representations into 
1 big tensor of shape (num_tokens_in_corpus, model_dim). 

The StructuralProbe expects its input to contain all the representations of 1 sentence, so I recommend you
to update your `fetch_sen_reps` method in a way that it is easy to retrieve all the representations that 
correspond to a single sentence.
''' 

def init_corpus(path, model, tokenizer, BERTLayerId = -1, GPT2LayerId = -1, concat=False, cutoff=None, control_task=False):
    """ Initialises the data of a corpus.
    
    Parameters
    ----------
    path : str
        Path to corpus location
    concat : bool, optional
        Optional toggle to concatenate all the tensors
        returned by `fetch_sen_reps`.
    cutoff : int, optional
        Optional integer to "cutoff" the data in the corpus.
        This allows only a subset to be used, alleviating 
        memory usage.
    """
    corpus = parse_corpus(path)[:cutoff]

    embs = fetch_sen_reps(corpus, model, tokenizer, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, stack_all=concat)
    
    gold_distances = create_gold_distances(corpus, control_task=control_task)
    
    return embs, gold_distances # x & y; we map from embeddings to distances :)


# I recommend you to write a method that can evaluate the UUAS & loss score for the dev (& test) corpus.
# Feel free to alter the signature of this method.
def evaluate_probe(probe, _data, emb_dim, batch_size = 24):
    
    uuas_scores = []
    loss_scores = []
    x_pointer = 0
    
    loss_function =  L1DistanceLoss()
    
    probe.eval()
    
    for i in range(0, len(_data['y']), batch_size):
        if i + batch_size > len(_data['y']):
            break # drop last
            
        labels = _data['y'][i:i+batch_size]
        sequence_lengths    = torch.empty(batch_size, dtype=int)
        max_sequence_length = 0

        for j in range(len(labels)):
            sequence_length     = labels[j].shape[0]
            sequence_lengths[j] = sequence_length
            if sequence_length > max_sequence_length:
                max_sequence_length = sequence_length


        data_batch  =  torch.zeros(batch_size, max_sequence_length, emb_dim)
        label_batch = -torch.ones(batch_size, max_sequence_length, max_sequence_length)
            
        for j in range(len(labels)):
            data_batch[j, 0:sequence_lengths[j], :] = torch.stack(_data['x'][x_pointer:x_pointer + sequence_lengths[j]])
            x_pointer += sequence_lengths[j]
            label_batch[j, 0:labels[j].shape[0], 0:labels[j].shape[0]] = torch.Tensor(labels[j])
            
        with torch.no_grad():
            pred_distances = probe(data_batch)
            loss_score, _  = loss_function(pred_distances, label_batch, sequence_lengths)
            uuas_score     = calc_uuas(pred_distances, label_batch)
            
            loss_scores.append(loss_score)
            uuas_scores.append(uuas_score)
    
    
    return torch.mean(torch.tensor(loss_scores)), torch.mean(torch.tensor(uuas_scores))


# Feel free to alter the signature of this method.
def train(_data, emb_dim, rt_graph = False):
    rank = 64
    lr = 10e-3
    batch_size = 24
    epochs = 40

    probe = StructuralProbe(emb_dim, rank)
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,patience=1)
    loss_function =  L1DistanceLoss()

    probe.train()
    
    sequence_lengths = torch.empty(len(_data['train_y']), dtype=int)
    for k in range(len(_data['train_y'])):
        sequence_lengths[k] = _data['train_y'][k].shape[0]
    
    starting_pointers_x = torch.zeros(len(_data['train_y']) + 1, dtype=int)
    starting_pointers_x[1:] = torch.cumsum(sequence_lengths, dim=0)
    
    for epoch in range(epochs):
        # shuffle training data before
        sequence_ids_epoch = np.random.permutation(len(_data['train_y']))
        
        for i in range(0, len(_data['train_y']), batch_size):
            if i + batch_size > len(_data['train_y']):
                break # drop last
                
            optimizer.zero_grad()
            
            labels = [_data['train_y'][y] for y in sequence_ids_epoch[i:i+batch_size]]
            
            sequence_lengths_batch = torch.empty(batch_size, dtype=int)
            max_sequence_length    = 0
            
            for j in range(len(labels)):
                sequence_length           = labels[j].shape[0]
                sequence_lengths_batch[j] = sequence_length
                if sequence_length > max_sequence_length:
                    max_sequence_length = sequence_length
                    
            data_batch  =  torch.zeros(batch_size, max_sequence_length, emb_dim)
            label_batch = -torch.ones(batch_size, max_sequence_length, max_sequence_length)
            
            for j in range(len(labels)):
                starting_index_x = starting_pointers_x[sequence_ids_epoch[i+j]]
                data_batch[j, 0:sequence_lengths_batch[j], :] = torch.stack(_data['train_x'][starting_index_x:starting_index_x + sequence_lengths_batch[j]])
                label_batch[j, 0:labels[j].shape[0], 0:labels[j].shape[0]] = torch.Tensor(labels[j])
            
            
            pred_distances = probe(data_batch)
            
            batch_loss, _ = loss_function(pred_distances, label_batch, sequence_lengths_batch)
            batch_loss.backward(retain_graph=rt_graph)
            optimizer.step()

        dev_loss, dev_uuas = evaluate_probe(probe, {'x': _data['dev_x'], 'y': _data['dev_y']}, emb_dim)
        
        print('After epoch %d - loss: %.4f, uuas: %.3f' % (epoch + 1, dev_loss, dev_uuas))

        # Using a scheduler is up to you, and might require some hyper param fine-tuning
        scheduler.step(dev_loss)

    test_loss, test_uuas = evaluate_probe(probe, {'x': _data['test_x'], 'y': _data['test_y']}, emb_dim)
    print('After training - loss: %.2f, uuas: %.3f' % (test_loss, test_uuas))

In [20]:
# I introduced the layerId parameters so that we could easily loop over all hidden layers overnight
# (Although I already did that for BERT manually)

def load_and_train(model, tokenizer, dim, BERTLayerId = -1, GPT2LayerId = -1, cutoff = None, rt_graph = False, control_task = False):
    train_x, train_y = init_corpus('data/en_ewt-ud-train.conllu', model, tokenizer, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)
    dev_x, dev_y = init_corpus('data/en_ewt-ud-dev.conllu', model, tokenizer, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)
    test_x, test_y   = init_corpus('data/en_ewt-ud-test.conllu', model, tokenizer, BERTLayerId = BERTLayerId, GPT2LayerId = GPT2LayerId, cutoff = cutoff, control_task = control_task)

    _data = {'train_x': train_x, 'train_y': train_y, 'dev_x': dev_x, 'dev_y': dev_y, 'test_x': test_x, 'test_y': test_y}
    
    train(_data, emb_dim=dim, rt_graph=rt_graph)

In [None]:
for k in range(1, 7):
    print('Now doing layer %d' % k)
    load_and_train(GPT2LM, tokenizer_gpt2, 768, GPT2LayerId=k, control_task=True)

Now training for layer 1


In [33]:
load_and_train(GulordavaLM, tokenizer_gulordava, 650, control_task=True)

After epoch 1 - loss: 0.6952, uuas: 0.363
After epoch 2 - loss: 0.6864, uuas: 0.382
After epoch 3 - loss: 0.6964, uuas: 0.355
After epoch 4 - loss: 0.6925, uuas: 0.385
After epoch 5 - loss: 0.6237, uuas: 0.424
After epoch 6 - loss: 0.6310, uuas: 0.407
After epoch 7 - loss: 0.6328, uuas: 0.420
After epoch 8 - loss: 0.5932, uuas: 0.442
After epoch 9 - loss: 0.5973, uuas: 0.420
After epoch 10 - loss: 0.6005, uuas: 0.429
After epoch 11 - loss: 0.5770, uuas: 0.453
After epoch 12 - loss: 0.5802, uuas: 0.438
After epoch 13 - loss: 0.5817, uuas: 0.450
After epoch 14 - loss: 0.5759, uuas: 0.458
After epoch 15 - loss: 0.5756, uuas: 0.460
After epoch 16 - loss: 0.5749, uuas: 0.458
After epoch 17 - loss: 0.5751, uuas: 0.449
After epoch 18 - loss: 0.5755, uuas: 0.452
After epoch 19 - loss: 0.5712, uuas: 0.462
After epoch 20 - loss: 0.5713, uuas: 0.460
After epoch 21 - loss: 0.5705, uuas: 0.458
After epoch 22 - loss: 0.5719, uuas: 0.462
After epoch 23 - loss: 0.5726, uuas: 0.457
After epoch 24 - los

In [26]:
for k in range(1, 13):
    print('Now doing layer %d' % k)
    load_and_train(BertLM, tokenizer_bert, 768, BERTLayerId = k, control_task=True)

Now doing layer 6
After epoch 1 - loss: 1.8241, uuas: 0.286
After epoch 2 - loss: 1.5207, uuas: 0.311
After epoch 3 - loss: 2.3648, uuas: 0.297
After epoch 4 - loss: 2.1517, uuas: 0.289
After epoch 5 - loss: 0.7670, uuas: 0.370
After epoch 6 - loss: 0.7871, uuas: 0.343
After epoch 7 - loss: 1.1114, uuas: 0.305
After epoch 8 - loss: 0.6450, uuas: 0.414
After epoch 9 - loss: 0.7100, uuas: 0.397
After epoch 10 - loss: 0.6872, uuas: 0.368
After epoch 11 - loss: 0.5805, uuas: 0.456
After epoch 12 - loss: 0.6031, uuas: 0.432
After epoch 13 - loss: 0.5972, uuas: 0.438
After epoch 14 - loss: 0.5558, uuas: 0.468
After epoch 15 - loss: 0.5505, uuas: 0.468
After epoch 16 - loss: 0.5485, uuas: 0.478
After epoch 17 - loss: 0.5500, uuas: 0.459
After epoch 18 - loss: 0.5438, uuas: 0.465
After epoch 19 - loss: 0.5498, uuas: 0.463
After epoch 20 - loss: 0.5496, uuas: 0.473
After epoch 21 - loss: 0.5324, uuas: 0.475
After epoch 22 - loss: 0.5339, uuas: 0.481
After epoch 23 - loss: 0.5339, uuas: 0.481
Af

After epoch 27 - loss: 0.5224, uuas: 0.484
After epoch 28 - loss: 0.5157, uuas: 0.491
After epoch 29 - loss: 0.5165, uuas: 0.489
After epoch 30 - loss: 0.5163, uuas: 0.486
After epoch 31 - loss: 0.5151, uuas: 0.492
After epoch 32 - loss: 0.5136, uuas: 0.491
After epoch 33 - loss: 0.5140, uuas: 0.491
After epoch 34 - loss: 0.5152, uuas: 0.491
After epoch 35 - loss: 0.5143, uuas: 0.492
After epoch 36 - loss: 0.5140, uuas: 0.491
After epoch 37 - loss: 0.5141, uuas: 0.492
After epoch 38 - loss: 0.5141, uuas: 0.492
After epoch 39 - loss: 0.5140, uuas: 0.492
After epoch 40 - loss: 0.5140, uuas: 0.493
After training - loss: 0.50, uuas: 0.483
Now doing layer 4
After epoch 1 - loss: 1.4558, uuas: 0.330
After epoch 2 - loss: 1.7064, uuas: 0.282
After epoch 3 - loss: 2.3154, uuas: 0.291
After epoch 4 - loss: 0.7413, uuas: 0.390
After epoch 5 - loss: 0.9018, uuas: 0.339
After epoch 6 - loss: 0.9923, uuas: 0.356
After epoch 7 - loss: 0.6347, uuas: 0.428
After epoch 8 - loss: 0.6615, uuas: 0.420
Aft

After epoch 12 - loss: 0.8730, uuas: 0.318
After epoch 13 - loss: 0.6651, uuas: 0.359
After epoch 14 - loss: 0.6699, uuas: 0.373
After epoch 15 - loss: 0.6526, uuas: 0.380
After epoch 16 - loss: 0.6913, uuas: 0.366
After epoch 17 - loss: 0.6944, uuas: 0.363
After epoch 18 - loss: 0.6247, uuas: 0.400
After epoch 19 - loss: 0.6533, uuas: 0.377
After epoch 20 - loss: 0.6162, uuas: 0.386
After epoch 21 - loss: 0.6297, uuas: 0.394
After epoch 22 - loss: 0.6199, uuas: 0.387
After epoch 23 - loss: 0.5902, uuas: 0.416
After epoch 24 - loss: 0.5826, uuas: 0.422
After epoch 25 - loss: 0.5842, uuas: 0.413
After epoch 26 - loss: 0.5870, uuas: 0.407
After epoch 27 - loss: 0.5750, uuas: 0.428
After epoch 28 - loss: 0.5710, uuas: 0.436
After epoch 29 - loss: 0.5764, uuas: 0.426
After epoch 30 - loss: 0.5717, uuas: 0.434
After epoch 31 - loss: 0.5686, uuas: 0.437
After epoch 32 - loss: 0.5689, uuas: 0.440
After epoch 33 - loss: 0.5681, uuas: 0.445
After epoch 34 - loss: 0.5681, uuas: 0.442
After epoch

In [27]:
load_and_train(ElmoLM, tokenizer_elmo, 1024, rt_graph=True, control_task=True)

After epoch 1 - loss: 1.5056, uuas: 0.327
After epoch 2 - loss: 2.0521, uuas: 0.296
After epoch 3 - loss: 1.5555, uuas: 0.308
After epoch 4 - loss: 0.8121, uuas: 0.344
After epoch 5 - loss: 0.8114, uuas: 0.373
After epoch 6 - loss: 0.9090, uuas: 0.375
After epoch 7 - loss: 0.8923, uuas: 0.342
After epoch 8 - loss: 0.6466, uuas: 0.411
After epoch 9 - loss: 0.6874, uuas: 0.380
After epoch 10 - loss: 0.7025, uuas: 0.365
After epoch 11 - loss: 0.6145, uuas: 0.413
After epoch 12 - loss: 0.6105, uuas: 0.411
After epoch 13 - loss: 0.6054, uuas: 0.431
After epoch 14 - loss: 0.6150, uuas: 0.415
After epoch 15 - loss: 0.6141, uuas: 0.423
After epoch 16 - loss: 0.5818, uuas: 0.441
After epoch 17 - loss: 0.5841, uuas: 0.443
After epoch 18 - loss: 0.5928, uuas: 0.452
After epoch 19 - loss: 0.5744, uuas: 0.453
After epoch 20 - loss: 0.5716, uuas: 0.456
After epoch 21 - loss: 0.5767, uuas: 0.458
After epoch 22 - loss: 0.5739, uuas: 0.456
After epoch 23 - loss: 0.5699, uuas: 0.465
After epoch 24 - los