# Probing Language Models

Michael Neely & Vanessa Botha

Natural Language Processing 2, Spring 2020

University of Amsterdam

## Introduction

We investigate the extent to which popular recurrent and attention-based neural models trained with an auto-regressive language modeling objective can represent the hierarchical nature of language. Using word representations from the models' hidden layers, we test for linguistic and structural properties by training one set of diagnostic classifiers to predict POS tags, and another to predict tree distances between words within a sentence.

## Setup

### Mount Google Drive storage for persistence

Ensure you have two folders in the root of your Google Drive:

1. `probing_lms` with the NLP2 [code and data](https://github.com/jumelet/nlp2-probing-lms). Ensure the [Gulordava LSTM model](https://drive.google.com/open?id=1w47WsZcZzPyBKDn83cMNd0Hb336e-_Sy) is in `lstm` sub-folder.
2. `probing_lms_data` where sentence representation data and diagnostic classifiers will be saved

In [None]:
# Mount Google Drive so generated representations and classifiers persist across Colab sessions.
# force_remount=True asks Colab to remount even if /content/drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Copy the TA package from Drive into the working directory so `probing_lms` can be imported below.
!cp -r "drive/My Drive/probing_lms/" ./

### Install dependencies

In [None]:
# Install the package requirements, plus skorch (used by the classifier-training code further down).
!pip install -r probing_lms/requirements.txt
!pip install skorch

### Imports

In [None]:
# Standard library
from collections import defaultdict, Counter
from enum import Enum
import itertools
import math
import os
from pathlib import Path
import pickle
from typing import Any, Dict, List, Optional, Set, Tuple, Union

# NLP2: TA Custom Code
from probing_lms.lstm.model import RNNModel

# External imports
from conllu import parse_incr, TokenList
from ete3 import Tree as EteTree
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.stats import ttest_ind
import seaborn as sns
import skorch
import skorch.helper
from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import transformers
from transformers import XLNetModel, XLNetTokenizer, GPT2Model, GPT2Tokenizer

### Global Configuration

In [None]:
# Root of the TA code/data package and root of everything this notebook writes.
PACKAGE_ROOT = "drive/My Drive/probing_lms"
OUT_ROOT = "drive/My Drive/probing_lms_data"


# Input corpora: the full treebank files and a small sample for quick runs.
CORPUS_DATA_ROOT = os.path.join(PACKAGE_ROOT, 'data/')
CORPUS_SAMPLE_ROOT = os.path.join(CORPUS_DATA_ROOT, 'sample/')

# Output locations: cached representations (full and sample corpus), models, results and result images.
OUT_DATA_ROOT = os.path.join(OUT_ROOT, 'data/')
OUT_SAMPLE_DATA_ROOT = os.path.join(OUT_DATA_ROOT, 'sample/')
OUT_MODELS_ROOT = os.path.join(OUT_ROOT, 'models/')
OUT_RESULTS_ROOT = os.path.join(OUT_ROOT, 'results/')
OUT_RESULTS_IMAGES_ROOT = os.path.join(OUT_RESULTS_ROOT, 'images/')

USE_SAMPLE = True # If true use the small sample corpus instead of the full one

TRAIN_CUTOFF = None # The full corpus has 12543 training samples, but you can introduce a cutoff if you want

RUN_SANITY_CHECKS = True # If you want to jump straight into the experiments, you can skip the sanity checks

# You can toggle each diagnostic classification task on/off as needed
RUN_POS_TASK = True 
RUN_STRUCTURAL_TASK = True

# All models and pre-allocated representation buffers are placed on this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility
# NOTE(review): these flags only make cuDNN kernel selection deterministic; no random seeds are set here.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Filepaths

At a minimum you must have these required data files. They can be retrieved from [here](https://github.com/jumelet/nlp2-probing-lms).

In [None]:
# Required inputs: the pre-trained Gulordava LSTM (weights and vocabulary) and the
# train/dev/test splits of both the full UD English EWT treebank and its small sample.
GULORDAVA_LSTM_MODEL_STATE_PATH = os.path.join(PACKAGE_ROOT, 'lstm/gulordava.pt')
LSTM_VOCABULARY_PATH = os.path.join(PACKAGE_ROOT, 'lstm/vocab.txt')

TRAIN_DATA_PATH = os.path.join(CORPUS_DATA_ROOT, 'en_ewt-ud-train.conllu')
VAL_DATA_PATH = os.path.join(CORPUS_DATA_ROOT, 'en_ewt-ud-dev.conllu')
TEST_DATA_PATH = os.path.join(CORPUS_DATA_ROOT, 'en_ewt-ud-test.conllu')

TRAIN_SAMPLE_PATH = os.path.join(CORPUS_SAMPLE_ROOT, 'en_ewt-ud-train.conllu')
VAL_SAMPLE_PATH = os.path.join(CORPUS_SAMPLE_ROOT, 'en_ewt-ud-dev.conllu')
TEST_SAMPLE_PATH = os.path.join(CORPUS_SAMPLE_ROOT, 'en_ewt-ud-test.conllu')

# Fail fast, and name the missing file so the problem is obvious from the traceback.
for path in [
             GULORDAVA_LSTM_MODEL_STATE_PATH,
             LSTM_VOCABULARY_PATH,
             TRAIN_DATA_PATH,
             VAL_DATA_PATH,
             TEST_DATA_PATH,
             TRAIN_SAMPLE_PATH,
             VAL_SAMPLE_PATH,
             TEST_SAMPLE_PATH]:
    assert Path(path).is_file(), f'Required file not found: {path}'

# Models

In the paper, we train selective diagnostic classifiers to examine whether recurrent models and attention-based transformer models encode linguistic and structural properties differently. Recurrent models have a chain-like nature and process sentences word by word while maintaining a hidden state that summarizes past inputs. In contrast, transformer architectures process all words in the sentence at once, using self-attention to learn dependencies between words.

## Transformers
Define a common interface for Transformer models

In [None]:
class TransformerModel:
    """Uniform wrapper around a pre-trained Hugging Face transformer and its tokenizer.

    Attributes:
        id: short identifier used in output paths and log messages.
        model: the transformer, moved to DEVICE and switched to eval mode.
        tokenizer: the tokenizer paired with ``model``.
        hidden_layers: number of hidden-state tensors to probe (transformer layers + embedding output).
        hidden_dim: width of each hidden-state vector.
    """

    def __init__(self, id: str,  model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer):
        self.id = id
        self.model = model.to(DEVICE)
        self.tokenizer = tokenizer

        config = self.model.config
        # The embedding output is returned alongside the per-layer outputs, hence the +1.
        self.hidden_layers = config.n_layer + 1
        # GPT-2 configs name the hidden width ``n_embd``; the other supported configs use ``d_model``.
        self.hidden_dim = config.n_embd if isinstance(model, GPT2Model) else config.d_model

        self.model.eval()

    def clear_model(self):
        """Drop the reference to the wrapped model (e.g. to free memory); the metadata attributes remain."""
        self.model = None

## LSTM

Define interface for the Gulordava et al. pre-trained LSTM from [Colorless green recurrent networks dream hierarchically.](https://arxiv.org/abs/1803.11138)

In [None]:
class GulordavaLSTM:
    """Wrapper around the pre-trained Gulordava et al. (2018) LSTM language model.

    Attributes:
        id: short identifier used in output paths and log messages.
        model: the given RNNModel with the weights from ``state_path`` loaded, on DEVICE, in eval mode.
        hidden_dim: the model's ``nhid``.
        vocab: word -> index mapping read from ``vocab_path``; unknown words map to the index of ``<unk>``.
        hidden_layers: number of representation slices the probes read from the model output.
    """
    def __init__(self, id: str, rnn_model: RNNModel, state_path: str, vocab_path: str):
        self.id = id
        self.model = rnn_model.to(DEVICE)
        # map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only runtime.
        self.model.load_state_dict(torch.load(state_path, map_location=DEVICE))
        self.hidden_dim = self.model.nhid
        self.vocab = self._init_vocab(vocab_path)
        self.hidden_layers = 2 # embedding + output
        self.model.eval()

    def _init_vocab(self, vocab_path: str) -> Dict[str, int]:
        """Read one token per line; a token's index is its (0-based) line number.

        The returned defaultdict maps any unseen word to the index of ``<unk>``
        (looked up lazily, so a vocabulary without ``<unk>`` only fails on an unknown word).
        """
        with open(vocab_path, encoding='utf-8') as f:
            w2i = {w.strip(): i for i, w in enumerate(f)}

        vocab = defaultdict(lambda: w2i["<unk>"])
        vocab.update(w2i)
        return vocab

    def clear_model(self):
        """Drop the reference to the wrapped model (e.g. to free memory); the metadata attributes remain."""
        self.model = None

## Language Model Initialization

Finally, let's write a function to quickly load whatever language model we are interested in.

In [None]:
LanguageModel = Union[TransformerModel, GulordavaLSTM]

def initialize_model(model_id: str) -> LanguageModel:
    """Instantiate the language-model wrapper identified by ``model_id``.

    Args:
        model_id: one of 'xlnet', 'distilgpt2', 'transformer_xl' or 'lstm'.

    Returns:
        A TransformerModel (Hugging Face model loaded with output_hidden_states=True)
        or a GulordavaLSTM.

    Raises:
        ValueError: if ``model_id`` is not one of the supported ids.
    """
    if model_id == 'xlnet':
        return TransformerModel(
            id=model_id,
            model=XLNetModel.from_pretrained('xlnet-base-cased', output_hidden_states=True),
            tokenizer=XLNetTokenizer.from_pretrained('xlnet-base-cased'))
    elif model_id == 'distilgpt2':
        return TransformerModel(
            id=model_id,
            model=GPT2Model.from_pretrained('distilgpt2', output_hidden_states=True),
            tokenizer=GPT2Tokenizer.from_pretrained('distilgpt2'))
    elif model_id == 'transformer_xl':
        # These classes are not part of the notebook's top-level imports; import them here
        # so this branch does not fail with a NameError.
        from transformers import TransfoXLModel, TransfoXLTokenizer
        return TransformerModel(
            id=model_id,
            model=TransfoXLModel.from_pretrained('transfo-xl-wt103', output_hidden_states=True),
            tokenizer=TransfoXLTokenizer.from_pretrained('transfo-xl-wt103'))
    elif model_id == 'lstm':
        return GulordavaLSTM(
            id=model_id,
            # NOTE(review): presumably RNNModel(rnn_type, ntoken, ninp, nhid, nlayers) as in the
            # PyTorch word_language_model example — confirm against probing_lms.lstm.model.
            rnn_model=RNNModel('LSTM', 50001, 650, 650, 2),
            state_path = GULORDAVA_LSTM_MODEL_STATE_PATH,
            vocab_path=LSTM_VOCABULARY_PATH)
    else:
        # Returning None here only postponed the failure to an AttributeError in the caller.
        raise ValueError(
            f"Unknown model_id '{model_id}'; expected one of 'xlnet', 'distilgpt2', 'transformer_xl', 'lstm'")

# Data

We use the [Universal Dependencies English Web Treebank](https://github.com/UniversalDependencies/UD_English-EWT) with the provided train/test/validation splits, which we parse with the [conllu library](https://github.com/EmilStenstrom/conllu/)

In [None]:
def parse_corpus(filename: str) -> List[TokenList]:
    """Parse a CoNLL-U file into a list of sentences (one conllu TokenList per sentence).

    parse_incr reads the file lazily, so the list is materialised inside the ``with`` block;
    the file handle is then closed before returning instead of being leaked.
    """
    with open(filename, encoding="utf-8") as data_file:
        return list(parse_incr(data_file))

class DataFoldNames(Enum):
    """Names of the corpus splits.

    Each value is the key used by Corpus for its per-fold lookups and appears in saved
    representation file names. Iterating the enum visits train, test, validation in that order.
    """
    train = 'train'
    test = 'test'
    validation = 'validation'

class Corpus:
    """The train/validation/test splits of a UD treebank, parsed with conllu.

    Args:
        id: corpus identifier, used in output paths.
        train_path, validation_path, test_path: CoNLL-U files for the three splits.

    When TRAIN_CUTOFF is truthy, only the first TRAIN_CUTOFF training sentences are kept.
    """
    def __init__(self, id: str, train_path: str, validation_path: str, test_path: str):
        self.id = id
        train_parses = parse_corpus(train_path)
        self.train = train_parses[0:TRAIN_CUTOFF] if TRAIN_CUTOFF else train_parses
        self.validation = parse_corpus(validation_path)
        self.test = parse_corpus(test_path)

        # Lookup tables keyed by DataFoldNames values.
        self._folds = {'train': self.train, 'test': self.test, 'validation': self.validation}
        self._tokens_per_fold = {
            fold_key: self._calculate_num_tokens_in_fold(sentences)
            for fold_key, sentences in self._folds.items()
        }

    def _calculate_num_tokens_in_fold(self, fold: List[TokenList]):
        """Total number of token entries over all sentences of ``fold``."""
        return sum(map(len, fold))

    def get_fold_by_name(self, fold_name: DataFoldNames):
        """Return the list of parsed sentences of the requested split."""
        key = fold_name.value
        return self._folds[key]

    def get_num_tokens_in_fold(self, fold_name: DataFoldNames):
        """Return the (pre-computed) token count of the requested split."""
        key = fold_name.value
        return self._tokens_per_fold[key]

## Load Corpus

In [None]:
corpus_id = 'sample' if USE_SAMPLE else 'full'

# Build the corpus from the (train, validation, test) files that match the chosen corpus id.
CORPUS = Corpus(
    corpus_id,
    *((TRAIN_SAMPLE_PATH, VAL_SAMPLE_PATH, TEST_SAMPLE_PATH) if USE_SAMPLE
      else (TRAIN_DATA_PATH, VAL_DATA_PATH, TEST_DATA_PATH)),
)

# Generating Representations

We now have our data all set, our models are running and we are good to go!

The next step is now to create the model representations for the sentences in our corpora. Once we have generated these representations we can store them, and train additional diagnostic (/probing) classifiers on top of the representations.

Transformer models use subword tokenization (e.g. Byte-Pair Encoding, BPE), which chunks a piece of text into subword pieces. For example, a word such as "largely" could be chunked into "large" and "ly". We are interested in probing linguistic information at the __word__ level. Therefore, we follow the suggestion of Hewitt et al. (2019a, footnote 4) and create the representation of a word by averaging the representations of its subwords. So the representation of "largely" becomes the average of the representations of "large" and "ly".

In [None]:
def fetch_sen_reps(ud_parses: List[TokenList], language_model: Union[GulordavaLSTM, TransformerModel], hidden_layer_id: Optional[int] = None) -> List[torch.Tensor]:
    """ Fetches hidden word-level representations from the given model. Representations at the word-level are obtained 
    by averaging the sub-representations that correspond to each particular word, per Hewitt et al. (2019a, footnote 4)

    For transformer models every treebank token is tokenized on its own and the vectors of its sub-word
    pieces are averaged. For the LSTM every treebank token is mapped to exactly one vocabulary id
    (unknown words fall back to the vocabulary's default entry), so no averaging is needed.

    Args:
        - ud_parses (List[conllu.TokenList]): sentences of a UD parsed .conllu file, one TokenList per sentence
        - language_model (Union[GulordavaLSTM, TransformerModel]): Language Model from which hidden representations are fetched
        - hidden_layer_id (Optional[int]) Defaults to None. If specified, only return the hidden representations at the given layer;
          -1 is translated to the last layer (language_model.hidden_layers - 1)

    Returns:
        - Tuple (from torch.split) of per-sentence CPU tensors of shape (sequence_length, representation_size, num_hidden_layers)
          if hidden_layer_id is not specified else (sequence_length, representation_size)
    """         
    language_model.model.eval()

    is_lstm = isinstance(language_model, GulordavaLSTM)

    # Translate -1 ("last layer") into an explicit layer index.
    if hidden_layer_id == -1:
        hidden_layer_id = language_model.hidden_layers -1

    tokens_in_corpus = sum([len(parse) for parse in ud_parses])

    # Keep track of where to split the final tensor
    sentence_splits = []

    # Pre-allocate memory for reps
    # NOTE(review): this buffer spans the whole corpus and lives on DEVICE (GPU memory when available).
    if hidden_layer_id is not None:
        reps = torch.zeros((tokens_in_corpus, language_model.hidden_dim, 1), device=DEVICE)
    else:
        reps = torch.zeros((tokens_in_corpus, language_model.hidden_dim, language_model.hidden_layers), device=DEVICE)

    # Keep track of where to insert each sentence representation into the reps tensor
    reps_index = 0

    # Build reps tensor
    for tokenlist in tqdm(ud_parses):

        # Keep track of where to average BPE sub-representations
        # (number of sub-word ids produced for each treebank token; filled for transformer models only)
        track_words = []
        num_tokens = len(tokenlist)
        encoded_sentence = []

        # Add a space at the start of a new word; only required for transformer models
        # (the first token of a sentence never gets a leading space)
        add_prefix_space = False

        # Encode based on the chunking of the treebank
        for token_info in tokenlist:
            
            token = ' ' + token_info['form'] if add_prefix_space else token_info['form']
            # Any SpaceAfter entry in the MISC column (UD writes SpaceAfter=No) means the next token
            # follows without whitespace, so it must not receive a prefix space.
            is_chunk = token_info.get('misc') and token_info.get('misc').get('SpaceAfter')
            add_prefix_space = False if is_chunk or is_lstm else True
           
            if is_lstm:
                input_id = language_model.vocab[token]
                encoded_sentence.append(input_id)
            else: 
                input_ids = language_model.tokenizer.encode(token, add_special_tokens=False)
                encoded_sentence.extend(input_ids)                          
                # NOTE(review): if the tokenizer ever returns no ids for a token, this records a
                # 0-length span and the mean computed below becomes NaN for that word.
                track_words.append(len(input_ids))
                

        # Batch of one sentence: shape (1, number_of_ids)
        encoded_sentence = torch.tensor([encoded_sentence], dtype=torch.long, device=DEVICE)

        with torch.no_grad():
            if is_lstm:
                # Each sentence starts from a fresh initial hidden state (init_hidden(1)).
                # (batch_size * max_sequence_length * hidden_dim * 2)
                hidden_state_OI = language_model.model(encoded_sentence, hidden=language_model.model.init_hidden(1))
                # remove batch dimension -> (max_sequence_length * hidden_dim * 2)
                sentence_rep = hidden_state_OI.squeeze(0)
                
                if hidden_layer_id is not None:
                    # Keep only the requested slice of the last axis, preserving that axis (size 1).
                    sentence_rep = torch.index_select(sentence_rep, dim=2, index=torch.tensor([hidden_layer_id], device=DEVICE))
            else:
                # [-1] hidden representations are the last element in the tuple returned from a transformer forward pass
                hidden_states = language_model.model(encoded_sentence)[-1]

                if hidden_layer_id is not None:
                    # Wrap the selected layer in a list so the averaging loop below is shared.
                    hidden_states = [hidden_states[hidden_layer_id]]
                
                # Accumulates one (num_words, hidden_dim, 1) slice per layer along dim 2.
                sentence_rep = torch.tensor([], device=DEVICE)
                for hidden in hidden_states:
                    #average the subword representations that belong to 1 token!
                    sentence_rep_ = torch.cat([word.mean(dim=1) for word in torch.split(hidden, track_words, dim=1)])
                    sentence_rep_ = sentence_rep_.unsqueeze(2)
                    sentence_rep = torch.cat((sentence_rep, sentence_rep_), dim=2)

        # Copy this sentence's block into the corpus-wide buffer; assumes one row per treebank token.
        reps[reps_index: reps_index + num_tokens] = sentence_rep
        reps_index += num_tokens
        sentence_splits.append(sentence_rep.shape[0])

    if hidden_layer_id is not None:
        # remove layer dimension
        reps = reps.squeeze(-1)
    # Move to the CPU before splitting, so the returned per-sentence views do not live in GPU memory.
    stacked_reps = reps.to('cpu')
    unstacked_reps = torch.split(stacked_reps, sentence_splits, dim=0)
    return unstacked_reps

def stack_sen_reps(unstacked_sen_reps: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate per-sentence representations along the token axis.

    Args:
        unstacked_sen_reps: non-empty sequence of tensors of shape (sentence_length, hidden_dim)
            or (sentence_length, hidden_dim, hidden_layers), e.g. as returned by fetch_sen_reps.

    Returns:
        A CPU tensor (default dtype) of shape (total_tokens, hidden_dim[, hidden_layers]).
    """
    reference = unstacked_sen_reps[0]
    total_tokens = sum(rep.shape[0] for rep in unstacked_sen_reps)

    # The output buffer carries a layer axis only when the inputs have one.
    if len(reference.shape) == 3:
        buffer_shape = (total_tokens, reference.shape[1], reference.shape[2])
    else:
        buffer_shape = (total_tokens, reference.shape[1])

    return torch.cat(unstacked_sen_reps, dim=0, out=torch.zeros(buffer_shape))

# I provide the following sanity check, that compares your representations against a pickled version of mine.
# Note that I use the DistilGPT-2 LM here. For the LSTM I used 0-valued initial states.
def assert_sen_reps(lstm: GulordavaLSTM, distilgpt2: TransformerModel):
    """Check our last-layer representations of the first training sentence against the TA's pickled ones.

    Raises AssertionError when a shape differs or values differ by more than atol=1e-5.
    """
    def _load_reference(file_name):
        # The reference pickles ship with the (trusted) TA package.
        with open(os.path.join(PACKAGE_ROOT, file_name), 'rb') as f:
            return pickle.load(f)

    distilgpt2_emb1 = _load_reference('distilgpt2_emb1.pickle')
    lstm_emb1 = _load_reference('lstm_emb1.pickle')

    # The comparison always uses the first sentence of the sample training file;
    # when the full corpus is loaded, that sentence is parsed separately.
    first_training_sample = CORPUS.train[:1] if USE_SAMPLE else parse_corpus(TRAIN_SAMPLE_PATH)[:1]

    own_distilgpt2_emb1 = stack_sen_reps(fetch_sen_reps(first_training_sample, distilgpt2, -1))
    own_lstm_emb1 = stack_sen_reps(fetch_sen_reps(first_training_sample, lstm, -1))

    assert distilgpt2_emb1.shape == own_distilgpt2_emb1.shape
    assert lstm_emb1.shape == own_lstm_emb1.shape

    assert torch.allclose(distilgpt2_emb1, own_distilgpt2_emb1, atol=1e-5), "DistilGPT-2 embs don't match!"
    assert torch.allclose(lstm_emb1, own_lstm_emb1, atol=1e-5), "LSTM embs don't match!"

## Saved Representation File Reading/Writing

These representations take up a lot of memory. We can write the sentence representations for a particular model and corpus to disk and then load in the desired stacked or unstacked representation at the particular hidden layer we want.

In [None]:
def ensure_dir(file_path: str) -> None:
    """Create the parent directory of ``file_path`` (with intermediate directories) if it does not exist.

    Only the part before the last path separator is used, so pass a path ending in '/' to create
    that directory itself. A bare file name (no directory component) needs nothing and is a no-op.
    """
    directory = os.path.dirname(file_path)
    # An empty dirname means the current directory, which always exists; os.makedirs('') would raise.
    # exist_ok=True avoids the race between checking for the directory and creating it.
    if directory:
        os.makedirs(directory, exist_ok=True)

def get_saved_rep_filename(corpus_id: str, fold: DataFoldNames, language_model_id: str, hidden_layer_id: int):
    """Return the path of the saved representations for one (corpus, fold, model, layer) combination.

    The 'sample' corpus is stored under OUT_SAMPLE_DATA_ROOT and any other corpus under OUT_DATA_ROOT;
    inside that root the file is '<model id>/<fold>_<layer>.pt'.
    """
    if corpus_id == 'sample':
        data_root = OUT_SAMPLE_DATA_ROOT
    else:
        data_root = OUT_DATA_ROOT
    relative_name = f'{language_model_id}/{fold.value}_{hidden_layer_id}.pt'
    return os.path.join(data_root, relative_name)

def reps_already_exist(corpus_id: str, language_model: LanguageModel):
    """Return True only if a saved representation file exists for every fold and every hidden layer."""
    return all(
        Path(get_saved_rep_filename(corpus_id, fold, language_model.id, layer)).is_file()
        for fold in DataFoldNames
        for layer in range(language_model.hidden_layers)
    )

def write_sen_reps(corpus: Corpus, language_model: LanguageModel, overwrite: Optional[bool] = False):
    """Compute and save sentence representations for every fold and hidden layer of ``language_model``.

    One file per (fold, layer) is written to the path given by get_saved_rep_filename; it holds the
    per-sentence tensors returned by fetch_sen_reps for that single layer.
    NOTE(review): ``overwrite`` is currently unused — existing files are always rewritten.
    """
    model_id = language_model.id
    data_root = OUT_SAMPLE_DATA_ROOT if corpus.id == 'sample' else OUT_DATA_ROOT
    ensure_dir(os.path.join(data_root, model_id + '/'))

    print(f'Generating and saving sentence representations for {model_id}')
    for fold in DataFoldNames:
        print(f'Fetching sentence representations for fold {fold.value}')
        fold_parses = corpus.get_fold_by_name(fold)

        for layer_id in range(language_model.hidden_layers):
            print(f'Fetching sentence representations at layer {layer_id}')

            # fetch_sen_reps is called once per layer with a single layer id.
            layer_reps = fetch_sen_reps(fold_parses, language_model, layer_id)
            save_path = get_saved_rep_filename(corpus.id, fold, model_id, layer_id)

            with open(save_path, 'wb+') as handle:
                torch.save(layer_reps, handle)
                print(f'{model_id} sentence representations at hidden layer {layer_id} saved for {fold.value} fold of {corpus.id} corpus')

    print(f'All {model_id} sentence representations saved for corpus {corpus.id}')

def load_sen_reps(corpus_id: str, fold: DataFoldNames, language_model_id: str, hidden_layer_id: int, stack: Optional[bool] = False):
    """Load the saved sentence representations of one (corpus, fold, model, hidden layer).

    Args:
        corpus_id: corpus identifier used when the representations were written.
        fold: which data split to load.
        language_model_id: id of the model that produced the representations.
        hidden_layer_id: integer index of the hidden layer (it is formatted into the file name).
        stack: if True, concatenate the per-sentence tensors into one tensor with stack_sen_reps.

    Returns:
        The per-sentence tensors exactly as saved, or a single stacked tensor when ``stack`` is True.
    """
    print(f'Loading the {hidden_layer_id} hidden layer {language_model_id} sentence representations for the {fold.value} fold of {corpus_id} corpus')
    saved_sen_reps_path = get_saved_rep_filename(corpus_id, fold, language_model_id, hidden_layer_id)
    with open(saved_sen_reps_path, 'rb') as handle:
        sen_reps = torch.load(handle)
    # Stack after the file is closed; the tensors are already fully in memory.
    if stack:
        sen_reps = stack_sen_reps(sen_reps)
    return sen_reps


def initialize_and_prepare_model(model_id: str, corpus: Corpus, clear_model: Optional[bool] = True):
    """Load a language model and make sure its representations for ``corpus`` are cached on disk.

    Representations (all folds, all layers) are generated only if at least one saved file is missing.
    When ``clear_model`` is set, the wrapped PyTorch model is dropped afterwards and only the
    wrapper's metadata (id, hidden sizes, ...) is kept.
    """
    print(f'Loading {model_id}')
    language_model = initialize_model(model_id)

    if reps_already_exist(corpus.id, language_model):
        print('Representations already exist on disk.')
    else:
        print('Representations are not present on disk. Generating...')
        write_sen_reps(corpus, language_model)

    if clear_model:
        language_model.clear_model()
    return language_model

## Generate the Data and Load the Models

In [None]:
# Keep the LSTM weights in memory only when the sanity checks below need to run the model.
LSTM = initialize_and_prepare_model('lstm', CORPUS, clear_model=not RUN_SANITY_CHECKS)

In [None]:
# Keep the DistilGPT-2 weights in memory only when the sanity checks below need to run the model.
DISTILGPT2 = initialize_and_prepare_model('distilgpt2', CORPUS, clear_model=not RUN_SANITY_CHECKS)

In [None]:
# XLNet is not used by the sanity checks, so its weights are released right after caching its representations.
XLNET = initialize_and_prepare_model('xlnet', CORPUS, clear_model=True)

## Sanity Check: Sentence Representations

In [None]:
def test_transformer(text: str, transformer: TransformerModel) -> torch.Tensor:
    """Run ``transformer`` on ``text`` (with the tokenizer's special tokens) and return its hidden states.

    The result is the last element of the model output: the per-layer hidden states when the
    model was loaded with output_hidden_states=True (a sequence of tensors, one per layer).
    """
    token_ids = transformer.tokenizer.encode(text, add_special_tokens=True)
    batch = torch.tensor([token_ids], device=DEVICE)
    with torch.no_grad():
        outputs = transformer.model(batch)
    return outputs[-1]

def test_lstm(text: str, lstm: GulordavaLSTM) -> torch.Tensor:
    """Run the given LSTM wrapper on the space-separated ``text`` from a fresh initial hidden state.

    Returns the raw model output for a batch containing this single sentence.
    """
    input_ids = [lstm.vocab[token] for token in text.split(' ')]
    # Add the batch dimension -> (1, sequence_length)
    text_data = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)
    # Use the ``lstm`` argument (not the global LSTM) so the function tests the model it is given.
    hidden = lstm.model.init_hidden(1)
    with torch.no_grad():
        hidden_states = lstm.model(text_data, hidden)
        return hidden_states


Ensure models work as intended

In [None]:
if RUN_SANITY_CHECKS:
    # Each word of the test sentence is expected to map to exactly one model position.
    test_sentence = 'I am talking to a computer'
    expected_length = len(test_sentence.split(' '))

    # LSTM output: (batch=1, sequence_length, hidden_dim, hidden_layers)
    lstm_output = test_lstm(test_sentence, LSTM)
    assert list(lstm_output.shape) == [1, expected_length, LSTM.hidden_dim, LSTM.hidden_layers]

    # Transformer output: one (batch=1, sequence_length, hidden_dim) tensor per hidden layer.
    # NOTE(review): relies on every word of the test sentence being a single DistilGPT-2 token.
    distilgpt2_output = test_transformer(test_sentence, DISTILGPT2)
    assert len(distilgpt2_output) == DISTILGPT2.hidden_layers
    assert list(distilgpt2_output[0].shape) == [1, expected_length, DISTILGPT2.hidden_dim]

Compare sentence representations to Jaap's


In [None]:
if RUN_SANITY_CHECKS:
    # Compare our last-layer representations of the first training sentence with the TA's pickled ones.
    assert_sen_reps(LSTM, DISTILGPT2)

Ensure we can load sentence representations

In [None]:
if RUN_SANITY_CHECKS:
    # Layer index 0, stacked: one row per token of the training fold.
    stacked_distilgpt2_first_hidden_layer_reps = load_sen_reps(CORPUS.id, DataFoldNames.train, DISTILGPT2.id, 0, stack=True)

    assert len(stacked_distilgpt2_first_hidden_layer_reps.shape) == 2
    assert stacked_distilgpt2_first_hidden_layer_reps.shape[0] == CORPUS.get_num_tokens_in_fold(DataFoldNames.train)
    assert stacked_distilgpt2_first_hidden_layer_reps.shape[1] == DISTILGPT2.hidden_dim

    del stacked_distilgpt2_first_hidden_layer_reps

    # Layer index 1, unstacked: one tensor per training sentence.
    unstacked_distilgpt2_second_hidden_layer_reps = load_sen_reps(CORPUS.id, DataFoldNames.train, DISTILGPT2.id, 1, stack=False)
    assert len(unstacked_distilgpt2_second_hidden_layer_reps) == len(CORPUS.train)
    assert len(unstacked_distilgpt2_second_hidden_layer_reps[0].shape) == 2

    del unstacked_distilgpt2_second_hidden_layer_reps

Cleanup: remove the PyTorch models from the interfaces. We don't need them anymore.


In [None]:
if RUN_SANITY_CHECKS:
    # The sanity checks are done: drop the PyTorch models and keep only the wrappers' metadata.
    LSTM.clear_model()
    DISTILGPT2.clear_model()

# Linguistic Probe

## Diagnostic Classification

Diagnostic classifiers (DCs) are deliberately kept simple. To read more about why, see "Designing and Interpreting Probes with Control Tasks" by Hewitt and Liang (esp. Sec. 3.2).

In [None]:
# DIAGNOSTIC CLASSIFIER

class Diagnostic_classifier(nn.Module):
    """Linear probe mapping a ``rep_dim``-sized representation to ``n_classes`` unnormalised scores.

    It is a single affine layer on purpose: a low-capacity probe keeps the measured accuracy
    about what the representation encodes rather than what the probe can learn by itself.
    """

    def __init__(self, rep_dim, n_classes):
        super().__init__()
        self.linear = nn.Linear(rep_dim, n_classes)

    def forward(self, x):
        """Return logits of shape (..., n_classes) for inputs of shape (..., rep_dim)."""
        return self.linear(x)

##  POS-tag Prediction

In [None]:
# FETCH POS LABELS

# Returns a tensor of shape (num_tokens_in_corpus,) (split per sentence if stack=False), together with the POS vocabulary.
# Make sure that when fetching these pos tags for your train/dev/test corpora you share the label vocabulary.

def fetch_pos_tags(ud_parses: List[TokenList], pos_vocab=None, stack=True) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]], Dict[str, int]]:
    """ Fetches POS-tag labels for each token in the corpus
    Args:
      - ud_parses (List[conllu.TokenList]): UD parsed corpus, one TokenList per sentence
      - pos_vocab (defaultdict, optional): tag -> index encoding. If not given, a new one is built
        from the tags in ``ud_parses`` plus "X", and unseen tags map to the index of "X".
        Pass the vocabulary built on the training fold when encoding dev/test so labels are shared.
      - stack (bool): if True return one tensor for the whole corpus, otherwise a tuple with one
        tensor per sentence

    Returns: 
      - Tuple (encoded_pos_tags, pos_vocab), where encoded_pos_tags is a LongTensor of shape
        (num_tokens_in_corpus, ) (split per sentence when stack=False)
    """
    sentence_splits = []
    pos_tags = []
    for tokenlist in ud_parses:
        sentence_splits.append(len(tokenlist))
        pos_tags.extend(token_info["upostag"] for token_info in tokenlist)

    # Create pos vocabulary if not given
    if not pos_vocab:
        # Sort the tags so the tag -> index mapping is identical across runs: iterating a set of
        # strings follows hash order, which changes between interpreter runs (hash randomization).
        # key=str keeps the sort valid in case a tag is None.
        unique_pos = sorted(set(pos_tags) | {"X"}, key=str)
        p2i = {pos: i for i, pos in enumerate(unique_pos)}
        pos_vocab = defaultdict(lambda: p2i["X"])
        pos_vocab.update(p2i)

    encoded_pos_tags = torch.tensor([pos_vocab[pos] for pos in pos_tags], dtype=torch.long)

    if not stack:
        encoded_pos_tags = torch.split(encoded_pos_tags, sentence_splits)

    return encoded_pos_tags, pos_vocab

## POS-tag Control Task 

In [None]:
def get_pos_empirical_distr(ud_parses: List[TokenList]):
    """Return the relative frequency of every UPOS tag over all tokens of ``ud_parses``.

    Keys appear in first-occurrence order; the values sum to 1 (an empty corpus gives {}).
    """
    tag_counts = Counter(
        token_info["upostag"] for tokenlist in ud_parses for token_info in tokenlist
    )
    total_tokens = sum(tag_counts.values())
    return {tag: count / total_tokens for tag, count in tag_counts.items()}

def get_C(vocab, pos_vocab, corpus):
    """
    Defines the control behavior for the POS-tagging control task, per Hewitt & Liang 2019:
    every word type gets a label sampled from the empirical POS distribution of the training
    fold, independently of its real tag.
    Args:
        - vocab (dict-like): Word vocabulary; each of its keys receives a sampled label
        - pos_vocab (defaultdict): Encodings of the different POS-tags
        - corpus (Corpus): corpus whose training fold defines the empirical distribution
    Returns:
        - The control behavior (defaultdict) mapping each word in the vocabulary to an encoded label;
          words outside ``vocab`` get the label sampled for "<unk>"
    Sampling uses the global NumPy RNG, so seed np.random beforehand for reproducible labels.
    """
    distribution = get_pos_empirical_distr(corpus.get_fold_by_name(DataFoldNames.train))
    weights = list(distribution.values())
    labels = [pos_vocab[tag] for tag in distribution]

    def draw_label():
        return np.random.choice(labels, p=weights)

    sampled = {word: draw_label() for word in vocab.keys()}
    sampled["<unk>"] = draw_label()

    control = defaultdict(lambda: sampled["<unk>"])
    control.update(sampled)
    return control


def get_control_labels(ud_parses: List[TokenList], C, stack=True):
    """
    Maps the tokens in the corpus using the control behavior C to obtain the labels for the control task
    Args:
        - ud_parses (List[conllu.TokenList]): UD parsed corpus, one TokenList per sentence
        - C (defaultdict): Control behavior (word form -> encoded label)
        - stack (bool): if False, return a tuple with one tensor per sentence
    Returns:
        - LongTensor of shape (num_tokens_in_corpus, ) containing the control labels
          (split per sentence when stack=False)
    """
    lengths, labels = [], []
    for sentence in ud_parses:
        lengths.append(len(sentence))
        labels += [C[token_info["form"]] for token_info in sentence]

    label_tensor = torch.tensor(labels, dtype=torch.long)
    return label_tensor if stack else torch.split(label_tensor, lengths)


## Complexity Control 

In [None]:
def downsample_traindata(train_x, train_y, keep_size, stack=True):
    """Randomly keep ``keep_size`` training sentences, sampled without replacement.

    Args:
        train_x: per-sentence representation tensors.
        train_y: per-sentence label tensors, aligned with ``train_x``.
        keep_size: number of sentences to keep; must not exceed len(train_y).
        stack: if True, concatenate the kept sentences into single tensors.

    Returns:
        (ds_train_x, ds_train_y): lists of the kept per-sentence tensors in their original order,
        or the concatenated tensors when ``stack`` is True.

    Sampling uses the global NumPy RNG.
    """
    sample_idx = np.random.choice(np.arange(len(train_y)), keep_size, replace=False)
    # Set membership is O(1); testing ``i in sample_idx`` against the NumPy array scanned the
    # whole array for every sentence, which is quadratic in the number of sentences.
    keep = set(sample_idx.tolist())

    ds_train_x = [sentence for i, sentence in enumerate(train_x) if i in keep]
    ds_train_y = [sentence for i, sentence in enumerate(train_y) if i in keep]

    if stack:
        ds_train_x = torch.cat(ds_train_x)
        ds_train_y = torch.cat(ds_train_y)

    return ds_train_x, ds_train_y


## Experiments POS-tag Prediction

In [None]:
def append_acc_per_postag(results_frame, predicted_labels, gt_labels, test_y_counts, pos_vocab):
    """Append, for every POS tag, the accuracy on test tokens of that tag to ``results_frame``.

    Args:
        results_frame: dict with one list per POS tag; one value is appended to each list.
        predicted_labels: predicted label ids, aligned with ``gt_labels``.
        gt_labels: NumPy array of gold label ids.
        test_y_counts: Counter of gold label ids in the test set.
        pos_vocab: tag -> label id mapping.

    The appended value is None when the tag never occurs in ``gt_labels``, 1 when all of its
    tokens are predicted correctly, and (count - errors) / count otherwise.
    Returns the same (mutated) ``results_frame``.
    """
    mismatched = np.where(np.not_equal(predicted_labels, gt_labels))
    errors_by_label = Counter(gt_labels[mismatched])

    for tag, label_id in pos_vocab.items():
        if label_id not in gt_labels:
            accuracy = None
        elif label_id in errors_by_label:
            occurrences = test_y_counts[label_id]
            accuracy = (occurrences - errors_by_label[label_id]) / occurrences
        else:
            accuracy = 1
        results_frame[tag].append(accuracy)

    return results_frame



def _build_skorch_callbacks(config):
    """Return a new list of skorch callbacks as requested by the ``config`` flags.

    New callback instances are created for every classifier so that no callback
    object is shared between independent training runs.
    """
    callbacks = []
    if config["early_stopping"]:
        callbacks.append(EarlyStopping())
    if config["lr_scheduler"]:
        callbacks.append(LRScheduler(policy="ReduceLROnPlateau", mode='min', factor=0.5, patience=1))
    return callbacks


def _fit_diagnostic_classifier(config, rep_dim, n_classes, weight_decay, train_x, train_y, val_x, val_y):
    """Train a freshly initialised diagnostic classifier and return the fitted skorch net.

    The module *class* is handed to skorch, with its constructor arguments as
    ``module__*`` parameters, so that every call starts from new random weights.
    When one shared module instance is passed, skorch reuses that instance as-is,
    which would warm-start each run from the previously trained classifier.
    """
    DC_net = skorch.classifier.NeuralNetClassifier(
        module=Diagnostic_classifier,
        module__rep_dim=rep_dim,
        module__n_classes=n_classes,
        batch_size=config["batch_size"],
        max_epochs=config["max_epochs"],
        callbacks=_build_skorch_callbacks(config),
        lr=config["lr"],
        optimizer=config["optimizer"],
        criterion=config["criterion"],
        optimizer__weight_decay=weight_decay,
        train_split=skorch.helper.predefined_split(skorch.dataset.Dataset(val_x, val_y)),
        device=config["device"])
    DC_net.fit(train_x, train_y)
    return DC_net


def run_all_results_postag(corpus, model, config, task: str = 'regular', weight_decay=(0,), seeds=(41, 42, 43)):
    """Train and evaluate POS-tag diagnostic classifiers on every hidden layer of ``model``.

    For every hidden layer and every seed, two sweeps are run:

    * a weight-decay sweep over ``weight_decay`` (``config["weight_decay"]`` is
      always added), trained on the complete training set;
    * a training-set-size sweep over ``config['training_sample_count']``
      (numbers of training sentences), using ``config["weight_decay"]``.

    Args:
        corpus: corpus object with ``id`` and ``train``/``validation``/``test`` folds.
        model: language-model wrapper providing ``id``, ``hidden_layers`` and ``hidden_dim``.
        config: training configuration (batch size, epochs, lr, optimizer, criterion,
            callback flags, device, default weight decay, training sample counts).
        task: ``'regular'`` to predict POS tags, ``'control'`` to predict control labels.
        weight_decay: weight-decay values for the first sweep. The default ``(0,)``
            is what was effectively run before this parameter was honoured.
        seeds: random seeds; every setting is trained once per seed.

    Returns:
        pandas.DataFrame with one row per trained classifier: its settings, the
        test accuracy and one per-POS-tag accuracy column.
    """
    results_frame = {"language_model_id": [],
                     "hidden_layer_id": [],
                     "seed": [],
                     "task": [],
                     "weight_decay": [],
                     "sample_count": [],
                     "test_accuracy": []}

    # Work on a copy so neither the caller's sequence nor a default is mutated.
    # The default setting from the config is always part of the sweep.
    weight_decays = list(weight_decay)
    if config["weight_decay"] not in weight_decays:
        weight_decays.append(config["weight_decay"])

    # Initialize POS labels (training labels kept per sentence for down-sampling)
    train_y_unstacked, pos_vocab = fetch_pos_tags(corpus.train, stack=False)
    train_y = torch.cat(train_y_unstacked)
    val_y, _ = fetch_pos_tags(corpus.validation, pos_vocab)
    test_y, _ = fetch_pos_tags(corpus.test, pos_vocab)

    # Needed to infer the accuracies per POS tag during evaluation
    test_y_counts = Counter(test_y.numpy())

    # Add a column for each POS tag to the result frame
    for pos_tag in pos_vocab:
        results_frame[pos_tag] = []

    n_classes = len(pos_vocab)
    n_train_sentences = len(train_y_unstacked)

    def train_and_evaluate(hidden_layer_id, seed, wd, sample_count, fit_x, fit_y):
        """Fit one classifier, report its test accuracy and append a results row.

        Uses the current ``val_*`` / ``test_*`` tensors of the enclosing function.
        """
        print(f'Model {model.id}')
        print(f'Task {task}')
        print(f'Hidden layer {hidden_layer_id}')
        print(f'Weight decay: {wd}')
        print(f'Training sample count: {sample_count} / {n_train_sentences}')
        print(f'Seed: {seed}')

        DC_net = _fit_diagnostic_classifier(config, model.hidden_dim, n_classes, wd,
                                            fit_x, fit_y, val_x, val_y)
        predicted_labels = DC_net.predict(test_x)
        test_acc = DC_net.score(test_x, test_y)
        print("Accuracy on test set: {:.4f}\n".format(test_acc))

        results_frame["language_model_id"].append(model.id)
        results_frame["hidden_layer_id"].append(hidden_layer_id)
        results_frame["seed"].append(seed)
        results_frame["task"].append(task)
        results_frame["weight_decay"].append(wd)
        results_frame["sample_count"].append(sample_count)
        results_frame["test_accuracy"].append(test_acc)
        # Modifies results_frame in place.
        append_acc_per_postag(results_frame, predicted_labels, test_y.numpy(), test_y_counts, pos_vocab)

    # Run all the settings on all hidden layer indices
    for hidden_layer_id in range(model.hidden_layers):

        # Initialize representations
        train_x_unstacked = load_sen_reps(corpus.id, DataFoldNames.train, model.id, hidden_layer_id, stack=False)
        train_x = torch.cat(train_x_unstacked)
        val_x = load_sen_reps(corpus.id, DataFoldNames.validation, model.id, hidden_layer_id, stack=True)
        test_x = load_sen_reps(corpus.id, DataFoldNames.test, model.id, hidden_layer_id, stack=True)

        for seed in seeds:
            torch.manual_seed(seed)
            np.random.seed(seed)

            # Overwrite the "true" POS labels with control labels in the control task
            if task == "control":
                model_vocab = model.vocab if isinstance(model, GulordavaLSTM) else model.tokenizer.get_vocab()

                # Get a new control behavior every seed
                C = get_C(model_vocab, pos_vocab, corpus)

                train_y_unstacked = get_control_labels(corpus.train, C, stack=False)
                train_y = torch.cat(train_y_unstacked)
                val_y = get_control_labels(corpus.validation, C)
                test_y = get_control_labels(corpus.test, C)
                test_y_counts = Counter(test_y.numpy())

            # Weight-decay sweep on the complete training set
            for wd in weight_decays:
                train_and_evaluate(hidden_layer_id, seed, wd, n_train_sentences, train_x, train_y)

            # Training-set-size sweep (number of sentences) with the default weight decay
            for sc in config['training_sample_count']:
                ds_train_x, ds_train_y = downsample_traindata(train_x_unstacked, train_y_unstacked, keep_size=sc, stack=True)
                train_and_evaluate(hidden_layer_id, seed, config["weight_decay"], sc, ds_train_x, ds_train_y)

    return pd.DataFrame(results_frame)

### Run All Experiments

In [None]:
if RUN_POS_TASK:

    ensure_dir(OUT_RESULTS_ROOT)

    # Training-set sizes (in sentences) for the sample-count sweep; the small
    # values are used when working with the sample corpus.
    training_sample_count = [1, 10, 90] if USE_SAMPLE else [100, 1000, 10000, len(CORPUS.train)]

    # Default hyper-parameters shared by every diagnostic-classifier run.
    config = dict(
        batch_size=64,
        max_epochs=300,
        lr=0.001,
        optimizer=torch.optim.Adam,
        weight_decay=0,
        criterion=nn.CrossEntropyLoss,
        early_stopping=True,
        lr_scheduler=True,
        training_sample_count=training_sample_count,
        device=DEVICE,
    )

    # Regular and control task for every language model; each result frame is
    # pickled into OUT_RESULTS_ROOT as soon as it has been produced.
    LSTM_results_regular = run_all_results_postag(CORPUS, LSTM, config, task='regular')
    LSTM_results_regular.to_pickle(os.path.join(OUT_RESULTS_ROOT, "LSTM_results_regular.pkl"))
    LSTM_results_control = run_all_results_postag(CORPUS, LSTM, config, task='control')
    LSTM_results_control.to_pickle(os.path.join(OUT_RESULTS_ROOT, "LSTM_results_control.pkl"))

    GPT2_results_regular = run_all_results_postag(CORPUS, DISTILGPT2, config, task='regular')
    GPT2_results_regular.to_pickle(os.path.join(OUT_RESULTS_ROOT, "GPT2_results_regular.pkl"))
    GPT2_results_control = run_all_results_postag(CORPUS, DISTILGPT2, config, task='control')
    GPT2_results_control.to_pickle(os.path.join(OUT_RESULTS_ROOT, "GPT2_results_control.pkl"))

    XLNET_results_regular = run_all_results_postag(CORPUS, XLNET, config, task='regular')
    XLNET_results_regular.to_pickle(os.path.join(OUT_RESULTS_ROOT, "XLNET_results_regular.pkl"))
    XLNET_results_control = run_all_results_postag(CORPUS, XLNET, config, task='control')
    XLNET_results_control.to_pickle(os.path.join(OUT_RESULTS_ROOT, "XLNET_results_control.pkl"))

### Load and Merge Results

In [None]:
if RUN_POS_TASK:

    # Load the result frames pickled by the experiment cell
    LSTM_results_regular = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "LSTM_results_regular.pkl"))
    LSTM_results_control = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "LSTM_results_control.pkl"))
    GPT2_results_regular = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "GPT2_results_regular.pkl"))
    GPT2_results_control = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "GPT2_results_control.pkl"))
    XLNET_results_regular = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "XLNET_results_regular.pkl"))
    XLNET_results_control = pd.read_pickle(os.path.join(OUT_RESULTS_ROOT, "XLNET_results_control.pkl"))

    # Merge all frames into one long-format frame (previously referenced the
    # undefined name `all_frames`, which raised a NameError)
    ALL_POS_FRAMES = [LSTM_results_regular, LSTM_results_control, GPT2_results_regular, GPT2_results_control, XLNET_results_regular, XLNET_results_control]
    ALL_POS_RESULTS = pd.concat(ALL_POS_FRAMES)

### Examine Results

Plotting

In [None]:
def get_data_sample_counts(all_results):
    """Summarise accuracy and selectivity per training-sample count.

    For every model, the rows of its last hidden layer (1 for lstm, 6 for
    distilgpt2, 12 for xlnet) trained with weight decay 0 are selected. For
    every sample count, the mean/std of the regular-task test accuracy and of
    the selectivity (regular minus control accuracy, run by run) are collected.

    Returns:
        ``[accuracy_stats, selectivity_stats]``; each maps a model id to
        ``{'mean': [...], 'std': [...]}`` with one entry per sample count.
    """
    last_hidden_layer = {'lstm': 1, 'distilgpt2': 6, 'xlnet': 12}
    sample_counts = [1, 10, 90] if USE_SAMPLE else [100, 1000, 10000, len(CORPUS.train)]

    accuracy_stats = {model_id: {'mean': [], 'std': []} for model_id in last_hidden_layer}
    selectivity_stats = {model_id: {'mean': [], 'std': []} for model_id in last_hidden_layer}

    for model_id, layer_id in last_hidden_layer.items():
        model_rows = all_results.loc[(all_results['language_model_id'] == model_id)
                                     & (all_results['hidden_layer_id'] == layer_id)
                                     & (all_results['weight_decay'] == 0)]
        for sample_count in sample_counts:
            rows = model_rows.loc[model_rows['sample_count'] == sample_count]
            regular = np.array(rows.loc[rows['task'] == 'regular']["test_accuracy"])
            control = np.array(rows.loc[rows['task'] == 'control']["test_accuracy"])

            # Technically generalised selectivity: the control task's
            # performance ceiling in this experiment is 1.0.
            selectivity = regular - control

            accuracy_stats[model_id]['mean'].append(regular.mean())
            accuracy_stats[model_id]['std'].append(regular.std())
            selectivity_stats[model_id]['mean'].append(selectivity.mean())
            selectivity_stats[model_id]['std'].append(selectivity.std())

    return [accuracy_stats, selectivity_stats]

def get_data_hidden_indices(all_results):
    """Summarise accuracy and selectivity per layer index for every model.

    Only runs trained on the complete training set with weight decay 0 are
    used. For every layer index, the mean/std of the regular-task test accuracy
    and of the selectivity (regular minus control accuracy, run by run) are
    collected.

    Returns:
        ``[accuracy_stats, selectivity_stats]``; each maps a model id to
        ``{'mean': [...], 'std': [...]}`` with one entry per layer index.
    """
    # Number of layer indices probed per model (indices 0 .. n-1).
    n_layers = {'lstm': 2, 'distilgpt2': 7, 'xlnet': 13}
    # Use the actual training-set size instead of a hard-coded 12543, so the
    # selection also matches runs made on the sample corpus.
    full_train_count = len(CORPUS.train)

    accuracy_stats = {model_id: {'mean': [], 'std': []} for model_id in n_layers}
    selectivity_stats = {model_id: {'mean': [], 'std': []} for model_id in n_layers}

    for model_id, n in n_layers.items():
        model_rows = all_results.loc[(all_results['language_model_id'] == model_id)
                                     & (all_results['weight_decay'] == 0)
                                     & (all_results['sample_count'] == full_train_count)]
        for h in np.arange(n):
            rows = model_rows.loc[model_rows['hidden_layer_id'] == h]
            regular = np.array(rows.loc[rows['task'] == 'regular']["test_accuracy"])
            control = np.array(rows.loc[rows['task'] == 'control']["test_accuracy"])

            selectivity = regular - control

            accuracy_stats[model_id]['mean'].append(regular.mean())
            accuracy_stats[model_id]['std'].append(regular.std())
            selectivity_stats[model_id]['mean'].append(selectivity.mean())
            selectivity_stats[model_id]['std'].append(selectivity.std())

    return [accuracy_stats, selectivity_stats]


def plot_results(data, lims=((0.78, 0.92), (0.15, 0.3)), step_size=(0.03, 0.03), x_label="Sample count", x_ticks=(100, 1000, 10000, 12543), save_path="sample_count.png"):
    """Plot mean accuracy and mean selectivity (with ±1 std bands) side by side.

    Args:
        data: ``[accuracy_stats, selectivity_stats]``; each maps a model id to
            ``{'mean': [...], 'std': [...]}`` with one value per x position.
        lims: y-axis limits of the two subplots.
        step_size: y-tick spacing of the two subplots.
        x_label: label of the x axis of both subplots.
        x_ticks: tick labels, one per x position of the longest series.
        save_path: file the figure is saved to before it is shown.
    """
    y_labels = ["Mean accuracy", "Mean selectivity"]

    # (marker, colour) per model, assigned in the order the models appear.
    layout = [("s", "#0077b3"), ("o", "#33cc33"), ("v", "#ff1a75")]
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

    for i, d in enumerate(data):
        n_points = 0
        for j, (model_id, res) in enumerate(d.items()):
            marker, color = layout[j]
            y = np.array(res["mean"])
            std = np.array(res["std"])
            x = np.arange(len(y))
            n_points = max(n_points, len(y))

            ax[i].plot(x, y, color=color, marker=marker, markersize=7, label=model_id)
            # Shaded band of one standard deviation around the mean
            ax[i].fill_between(x, y - std, y + std, facecolor=color, alpha=0.2, interpolate=True)

        # Decorate each subplot once, after all models are drawn. The ticks span
        # the longest series so their number matches ``x_ticks`` even when the
        # models have different numbers of points (e.g. numbers of layers).
        ticks = np.arange(n_points)
        ax[i].set_ylabel(y_labels[i], fontsize=15)
        ax[i].set_xlabel(x_label, fontsize=14)
        ax[i].set_ylim(lims[i])
        ax[i].set_xlim((0, max(n_points - 1, 1)))
        ax[i].set_xticks(ticks)
        ax[i].set_xticklabels(x_ticks, fontsize=14)
        ax[i].set_yticks(np.arange(lims[i][0], lims[i][1] + step_size[i], step_size[i]))
        ax[i].tick_params(axis="y", labelsize=13)

    ax[0].legend(prop={'size': 13}, loc="lower right")
    plt.suptitle("Part-of-Speech tag prediction", fontsize=15)
    plt.tight_layout(rect=[0, 0.03, 1, 0.93])
    plt.savefig(save_path)
    plt.show()


def plot_postag_accuracies(all_results, pos_vocab, task="regular", layer="emb"):
    """Bar-plot the per-POS-tag test accuracy (mean ± std over runs) of all models.

    Only runs trained on the complete training set with weight decay 0 on the
    given ``task`` are used, taken from layer 0 when ``layer == "emb"`` and from
    each model's last hidden layer (1 / 6 / 12) otherwise.

    Args:
        all_results: merged long-format results frame with one column per POS tag.
        pos_vocab: POS vocabulary whose tags are plotted. If ``None``, the
            vocabulary of ``CORPUS.train`` is computed and used.
        task: ``"regular"`` or ``"control"``.
        layer: ``"emb"`` for layer 0, anything else for the last hidden layer.
    """
    models = ['lstm', 'distilgpt2', 'xlnet']
    last_hidden_index = {'lstm': 1, 'distilgpt2': 6, 'xlnet': 12}
    colors = {'lstm': "#3392C2", 'distilgpt2': "#5CD65C", 'xlnet': "#FF4891"}

    # Previously the argument was always overwritten; it is now honoured and
    # only recomputed when not supplied.
    if pos_vocab is None:
        _, pos_vocab = fetch_pos_tags(CORPUS.train)
    pos_tags = list(pos_vocab.keys())
    full_train_count = len(CORPUS.train)

    data = {model_id: {'mean': [], 'std': []} for model_id in models}
    for model_id in models:
        hidden_layer_id = 0 if layer == "emb" else last_hidden_index[model_id]
        selected = all_results.loc[(all_results['language_model_id'] == model_id)
                                   & (all_results['hidden_layer_id'] == hidden_layer_id)
                                   & (all_results['weight_decay'] == 0)
                                   & (all_results['sample_count'] == full_train_count)
                                   & (all_results['task'] == task)]
        for pos_tag in pos_tags:
            accs = np.array(selected[pos_tag])
            data[model_id]['mean'].append(accs.mean())
            data[model_id]['std'].append(accs.std())

    bar_width = 0.28
    positions = np.arange(len(pos_tags))

    plt.figure(figsize=(12, 5))
    # One group of bars per POS tag, one bar per model, shifted by bar_width.
    for offset, model_id in enumerate(models):
        plt.bar(positions + offset * bar_width, data[model_id]["mean"], color=colors[model_id],
                width=bar_width, edgecolor='white', yerr=data[model_id]["std"],
                error_kw=dict(lw=0.7, capsize=4, capthick=0.7, ecolor="#404040"), label=model_id)

    # Put the x ticks in the middle of each group of bars
    plt.ylabel("Mean accuracy", fontsize=20)
    plt.xticks(positions + bar_width, pos_tags, fontsize=14, rotation=45)
    plt.yticks(np.arange(0, 1.001, 0.1), fontsize=16)
    plt.xlim([-0.4, len(pos_tags)])
    plt.legend(fontsize=14, framealpha=0.95, loc="lower right")

    plt.tight_layout()
    plt.savefig("pos_accs_" + task + "_" + layer + ".png")
    plt.show()


def plot_empirical_distr(ud_parses: List[TokenList], save_path="pos_distr.png"):
    """Bar-plot the empirical POS-tag distribution of ``ud_parses`` and save it.

    ``get_pos_empirical_distr`` is assumed to return a tag -> probability
    mapping (TODO confirm). Tags are shown in alphabetical order, and the figure
    is written to ``save_path`` before it is shown.
    """
    distribution = get_pos_empirical_distr(ud_parses)
    tags = sorted(distribution)
    plt.bar(tags, [distribution[tag] for tag in tags])

    plt.ylabel("Probability", fontsize=17)
    plt.xticks(np.arange(len(tags)), fontsize=11, rotation=45)
    plt.yticks(np.arange(0, 0.201, 0.05), fontsize=16)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

Plot the results

In [None]:
if RUN_POS_TASK:
    # Tick labels for the sample-count plot: the training-set sizes that were run.
    sample_count_ticks = [1, 10, 90] if USE_SAMPLE else [100, 1000, 10000, len(CORPUS.train)]
    # POS vocabulary of the training fold; its tags are the per-tag columns of the results.
    _, train_pos_vocab = fetch_pos_tags(CORPUS.train)

    print("WITH NORMAL SELECTIVITY")
    data_counts = get_data_sample_counts(ALL_POS_RESULTS)
    data_indices = get_data_hidden_indices(ALL_POS_RESULTS)
    plot_results(data_counts, lims=[(0.78, 0.92), (0.15, 0.3)], step_size=[0.03, 0.03], x_label="Sample count", x_ticks=sample_count_ticks, save_path="sample_count.png")
    plot_results(data_indices, lims=[(0.81, 0.96), (-0.1, 0.3)], step_size=[0.03, 0.08], x_label="Hidden layer indices", x_ticks=np.arange(13), save_path="hidden_indices.png")

    print("EMBEDDING LAYER")
    plot_postag_accuracies(ALL_POS_RESULTS, train_pos_vocab, task='regular', layer="emb")
    print("LAST HIDDEN LAYER")
    plot_postag_accuracies(ALL_POS_RESULTS, train_pos_vocab, task='regular', layer="last")

    print("EMPIRICAL DISTRIBUTION TRAIN")
    plot_empirical_distr(CORPUS.train, save_path="pos_distr_train.png")
    print("EMPIRICAL DISTRIBUTION TEST")
    plot_empirical_distr(CORPUS.test, save_path="pos_distr_test.png")

Performance on last hidden layers

In [None]:
if RUN_POS_TASK:
    # Runs trained on the complete training set (previously hard-coded as 12543).
    full_train_count = len(CORPUS.train)

    # Fixed: the LSTM selection used the undefined name `all_results`.
    results_lstm = ALL_POS_RESULTS.loc[(ALL_POS_RESULTS['language_model_id'] == 'lstm') & (ALL_POS_RESULTS['hidden_layer_id'] == 1) & (ALL_POS_RESULTS['weight_decay'] == 0) & (ALL_POS_RESULTS['sample_count'] == full_train_count)]
    regular_lstm = np.array(results_lstm.loc[(results_lstm['task'] == 'regular')]["test_accuracy"])
    control_lstm = np.array(results_lstm.loc[(results_lstm['task'] == 'control')]["test_accuracy"])
    selectivity_lstm = regular_lstm - control_lstm
    print("LSTM performance regular task:\n mean {}, std {}".format(regular_lstm.mean(), regular_lstm.std()))
    print("LSTM performance control task:\n mean {}, std {}".format(control_lstm.mean(), control_lstm.std()))
    print("LSTM performance selectivity:\n mean {}, std {}".format(selectivity_lstm.mean(), selectivity_lstm.std()))
    print()

    results_distilgpt2 = ALL_POS_RESULTS.loc[(ALL_POS_RESULTS['language_model_id'] == 'distilgpt2') & (ALL_POS_RESULTS['hidden_layer_id'] == 6) & (ALL_POS_RESULTS['weight_decay'] == 0) & (ALL_POS_RESULTS['sample_count'] == full_train_count)]
    regular_distilgpt2 = np.array(results_distilgpt2.loc[(results_distilgpt2['task'] == 'regular')]["test_accuracy"])
    control_distilgpt2 = np.array(results_distilgpt2.loc[(results_distilgpt2['task'] == 'control')]["test_accuracy"])
    selectivity_distilgpt2 = regular_distilgpt2 - control_distilgpt2
    print("DistilGPT2 performance regular task:\n mean {}, std {}".format(regular_distilgpt2.mean(), regular_distilgpt2.std()))
    print("DistilGPT2 performance control task:\n mean {}, std {}".format(control_distilgpt2.mean(), control_distilgpt2.std()))
    print("DistilGPT2 performance selectivity:\n mean {}, std {}".format(selectivity_distilgpt2.mean(), selectivity_distilgpt2.std()))
    print()

    results_xlnet = ALL_POS_RESULTS.loc[(ALL_POS_RESULTS['language_model_id'] == 'xlnet') & (ALL_POS_RESULTS['hidden_layer_id'] == 12) & (ALL_POS_RESULTS['weight_decay'] == 0) & (ALL_POS_RESULTS['sample_count'] == full_train_count)]
    regular_xlnet = np.array(results_xlnet.loc[(results_xlnet['task'] == 'regular')]["test_accuracy"])
    control_xlnet = np.array(results_xlnet.loc[(results_xlnet['task'] == 'control')]["test_accuracy"])
    selectivity_xlnet = regular_xlnet - control_xlnet
    print("XLNET performance regular task:\n mean {}, std {}".format(regular_xlnet.mean(), regular_xlnet.std()))
    print("XLNET performance control task:\n mean {}, std {}".format(control_xlnet.mean(), control_xlnet.std()))
    print("XLNET performance selectivity:\n mean {}, std {}".format(selectivity_xlnet.mean(), selectivity_xlnet.std()))
    print()

Statistical tests

In [None]:
# Two-sample t-tests between the models' last-layer results. The arrays only
# exist when the POS experiments were run, hence the same guard as above.
if RUN_POS_TASK:
    print("Performance on pos tag task (regular)")
    print("LSTM vs. DistilGPT2: {}".format(ttest_ind(regular_lstm, regular_distilgpt2)))
    print("LSTM vs. XLNET: {}".format(ttest_ind(regular_lstm, regular_xlnet)))
    print("DistilGPT2 vs. XLNET: {}".format(ttest_ind(regular_distilgpt2, regular_xlnet)))
    print()

    print("Selectivity")
    print("LSTM vs. DistilGPT2: {}".format(ttest_ind(selectivity_lstm, selectivity_distilgpt2)))
    print("LSTM vs. XLNET: {}".format(ttest_ind(selectivity_lstm, selectivity_xlnet)))
    print("DistilGPT2 vs. XLNET: {}".format(ttest_ind(selectivity_distilgpt2, selectivity_xlnet)))

# Structural Probe

## Tree Distance Task

For our gold labels, we need to recover the node distances from our parse tree. For this we use the functionality provided by `ete3`, which allows us to compute these distances directly.

In [None]:
# In case you want to transform your conllu tree to an nltk.Tree, for better visualisation

def rec_tokentree_to_nltk(tokentree):
    """Recursively render a conllu token tree as an nltk bracket string.

    Each node becomes ``(form child1 child2 ...)``; a leaf becomes ``(form )``.
    """
    children_str = " ".join(map(rec_tokentree_to_nltk, tokentree.children))
    return "({} {})".format(tokentree.token["form"], children_str)


def tokentree_to_nltk(tokentree):
    """Convert a conllu token tree into an ``nltk.Tree`` (e.g. for pretty printing).

    nltk is imported inside the function, so it is only needed when this
    helper is actually called.
    """
    from nltk import Tree as NLTKTree

    return NLTKTree.fromstring(rec_tokentree_to_nltk(tokentree))

In [None]:
class FancyTree(EteTree):
    """ete3 tree built with Newick ``format=1`` that prints as ASCII art.

    Printing (``str``/``repr``) shows internal node names as well as leaves.
    """

    def __init__(self, *args, **kwargs):
        # format=1 is the ete3 Newick flavour that keeps internal node names;
        # here those names are token ids.
        super().__init__(*args, format=1, **kwargs)

    def __str__(self):
        ascii_art = self.get_ascii(show_internal=True)
        return ascii_art

    def __repr__(self):
        return str(self)


def rec_tokentree_to_ete(tokentree):
    """Recursively render a conllu token tree as a Newick string (without ``;``).

    Every node is labelled with its token id; a node with children becomes
    ``(child1,child2,...)id`` and a leaf becomes just ``id``.
    """
    node_label = str(tokentree.token["id"])
    if not tokentree.children:
        return node_label

    subtrees = ",".join(rec_tokentree_to_ete(child) for child in tokentree.children)
    return f"({subtrees}){node_label}"
    
def tokentree_to_ete(tokentree):
    """Build a ``FancyTree`` from a conllu token tree via its Newick string."""
    return FancyTree(rec_tokentree_to_ete(tokentree) + ";")

In [None]:
# Let's check if it works!
# We can read in a corpus using the code that was already provided, and convert it to an ete3 Tree.

# Only runs when sanity checks are enabled: converts the first training
# sentence to a token tree, then to an ete3 tree, and prints it.
if RUN_SANITY_CHECKS:
    item = CORPUS.train[0]
    tokentree = item.to_tree()
    ete3_tree = tokentree_to_ete(tokentree)
    print(ete3_tree)

As you can see, each token is labelled by its token id (converted to a string). Based on these ids we retrieve the node distances.

To create the true distances of a parse tree in our treebank, we are going to use the `.get_distance` method that is provided by `ete3`: http://etetoolkit.org/docs/latest/tutorial/tutorial_trees.html#working-with-branch-distances

We will store all these distances in a `torch.Tensor`.


In [None]:
def create_gold_distances(data_fold: List[TokenList]):
    """Compute the gold tree-distance matrix of every sentence in ``data_fold``.

    Each sentence is converted to an ete3 tree whose node names are 1-based
    token ids. Entry ``[i, j]`` of a sentence's matrix is the ete3
    ``get_distance`` between the nodes of tokens ``i + 1`` and ``j + 1``.

    Args:
        data_fold: parsed sentences (each must provide ``to_tree()``).

    Returns:
        List with one ``(n, n)`` float tensor per sentence, ``n`` being the
        number of nodes in its tree.
    """
    all_distances = []

    for item in tqdm(data_fold):
        ete_tree = tokentree_to_ete(item.to_tree())
        nodes = list(ete_tree.traverse())
        # Node names are 1-based token ids; convert once to 0-based matrix indices.
        indices = [int(node.name) - 1 for node in nodes]

        sen_len = len(nodes)
        distances = torch.zeros((sen_len, sen_len))

        # Tree distances are symmetric and zero on the diagonal, so compute each
        # unordered pair once and mirror it (halves the get_distance calls).
        for a, node_a in enumerate(nodes):
            for b in range(a + 1, sen_len):
                dist = node_a.get_distance(nodes[b])
                distances[indices[a], indices[b]] = dist
                distances[indices[b], indices[a]] = dist

        all_distances.append(distances)

    return all_distances

The next step is to go the other way around: after all, we predict the node distances of a sentence in order to recreate the corresponding parse tree.

Hewitt et al. reconstruct a parse tree based on a _minimum spanning tree_ (MST, https://en.wikipedia.org/wiki/Minimum_spanning_tree). Fortunately for us, we can simply import a method from `scipy` that retrieves this MST.

In [None]:
def create_mst(distances):
    """Return the minimum spanning tree of a distance matrix as a 0/1 numpy matrix.

    Only the upper triangle of ``distances`` is used. scipy treats zero entries
    as missing edges, and every positive entry of the resulting tree is set to 1.
    """
    upper = distances.triu().detach().numpy()
    tree = minimum_spanning_tree(upper).toarray()
    return np.where(tree > 0, 1., tree)

Let's have a look at what this looks like, by looking at a relatively short sentence in the sample corpus.

In [None]:
# Sanity check on training sentence 5: print its ete3 tree, its gold distance
# matrix, and the MST recovered from those gold distances.
if RUN_SANITY_CHECKS:
    item = CORPUS.train[5]
    tokentree = item.to_tree()
    ete3_tree = tokentree_to_ete(tokentree)
    print(ete3_tree, '\n')

    # create_gold_distances takes a list of sentences; slice to a one-item fold.
    gold_distance = create_gold_distances(CORPUS.train[5:6])[0]
    print(gold_distance, '\n')

    mst = create_mst(gold_distance)
    print(mst)

Now that we are able to map edge distances back to parse trees, we can create code for our quantitative evaluation. For this we will use the Undirected Unlabeled Attachment Score (UUAS), which is expressed as:

$$\frac{\text{number of predicted edges that are an edge in the gold parse tree}}{\text{number of edges in the gold parse tree}}$$

To do this, we will need to obtain all the edges from our MST matrix. Note that, since we are using undirected trees, that an edge can be expressed in 2 ways: an edge between node $i$ and node $j$ is denoted by both `mst[i,j] = 1`, or `mst[j,i] = 1`.


In [None]:
def edges(mst: np.ndarray) -> Set[Tuple[int, int]]:
    """Extract the edges from a Minimum Spanning Tree (MST) matrix.

        Args:
            - mst (np.ndarray): MST as an adjacency matrix; an edge is a cell equal to 1

        Returns:
            Set of ``(row, column)`` index tuples, one per cell equal to 1. No
            symmetrisation is done, so an edge stored in both ``[i, j]`` and
            ``[j, i]`` yields two tuples.
    """
    return {(row, col) for row, col in np.argwhere(mst == 1.)}

def calc_uuas(mst: np.ndarray, gold_distances: torch.Tensor) -> float:
    """Calculate the Undirected Unlabeled Attachment Score (UUAS) for a given MST and gold distance matrix

        The score is the number of predicted edges that are also gold edges,
        divided by the number of gold edges. Edges are compared undirected: an
        edge stored at ``[i, j]`` or at ``[j, i]`` counts as the same edge, so
        the result does not depend on which triangle either MST uses.

        Args:
            - mst (np.ndarray): predicted MST as a 0/1 matrix
            - gold_distances (torch.Tensor): Tensor of distances between nodes in the gold parse tree

        Returns:
            UUAS score between 0. and 1. (0. when the gold tree has no edges)
    """
    gold_mst = create_mst(gold_distances)

    def undirected_upper(matrix):
        # Fold [j, i] onto [i, j] and keep the strict upper triangle, so every
        # undirected edge is represented exactly once.
        matrix = np.asarray(matrix)
        return np.triu((matrix == 1.) | (matrix.T == 1.), k=1)

    predicted_edges = undirected_upper(mst)
    gold_edges = undirected_upper(gold_mst)

    n_gold = np.count_nonzero(gold_edges)
    if n_gold == 0:
        return 0.0
    return np.count_nonzero(predicted_edges & gold_edges) / n_gold

## Tree Distance Control Task

In [None]:
def create_empirical_distance_distribution(gold_distances: List[torch.Tensor]) -> Dict[int, float]:
    """Estimate the empirical distribution of the non-zero gold tree distances.

    Only the upper triangle of each matrix is read, so every unordered token
    pair is counted once. Distances are truncated with ``int`` and zeros are
    skipped.

    Args:
        gold_distances: one square distance matrix per sentence.

    Returns:
        Dict mapping each distance to its relative frequency (values sum to 1),
        keyed in order of first occurrence; empty if there is no non-zero
        distance. The old ``Dict[int, int]`` annotation was wrong: the values
        are probabilities.
    """
    counts = Counter()
    for gd in gold_distances:
        upper = torch.triu(gd).detach().numpy()
        counts.update(d for d in (int(elem) for elem in upper.ravel()) if d != 0)

    total = sum(counts.values())
    return {distance: count / total for distance, count in counts.items()}

In [None]:
def create_structural_control_behavior(train_fold: List[TokenList], empirical_distance_distribution: Dict[int, float], seed: int) -> Dict[str, Dict[str, int]]:
    """Sample a control distance for every ordered pair of word forms in ``train_fold``.

    The global NumPy RNG is seeded with ``seed``. Then, sentence by sentence,
    every ordered pair of forms at distinct positions that has no distance yet
    receives one sampled from ``empirical_distance_distribution``. Pairs are
    keyed by form, so a pair keeps the distance it got when first seen.

    Args:
        train_fold: parsed training sentences (tokens must have a ``'form'``).
        empirical_distance_distribution: distance -> probability.
        seed: seed for the global ``np.random`` state.

    Returns:
        Nested dict ``C`` with ``C[form_i][form_j]`` the control distance. (The
        old annotation ``Dict[Tuple[str, str], int]`` did not match this shape.)
    """
    np.random.seed(seed)
    C = {}
    choices = list(empirical_distance_distribution.keys())
    probs = list(empirical_distance_distribution.values())
    for tokenlist in train_fold:
        forms = [token['form'] for token in tokenlist]
        for i, form_i in enumerate(forms):
            row = C.setdefault(form_i, {})
            for j, form_j in enumerate(forms):
                # Distinct positions only; a repeated form may still pair with itself.
                if i != j and form_j not in row:
                    row[form_j] = np.random.choice(choices, p=probs)
    return C

In [None]:
def create_control_distances(ud_parses: List[TokenList], C, empirical_distance_distribution, seed: int) -> Tuple[List[torch.Tensor], float]:
    """Build control-task distance matrices for ``ud_parses`` from control behaviour ``C``.

    The global NumPy RNG is seeded with ``seed`` first. Each ordered pair of
    distinct positions ``(x, y)`` gets ``C[form_x][form_y]`` when that word pair
    is present in ``C``. Otherwise a distance is sampled from
    ``empirical_distance_distribution``.

    Returns:
        ``(matrices, ceiling)``:
        - ``matrices`` has one ``(n, n)`` tensor per sentence, with a zero diagonal.
        - ``ceiling`` is the performance ceiling. It is the fraction of token pairs
          covered by ``C``, plus the uncovered fraction times the largest
          probability in the distribution (biased chance accuracy).
    """
    np.random.seed(seed)
    candidate_distances = list(empirical_distance_distribution.keys())
    candidate_probs = list(empirical_distance_distribution.values())

    control_matrices = []
    n_pairs = 0
    n_unseen = 0
    for tokenlist in tqdm(ud_parses):
        forms = [token['form'] for token in tokenlist]
        n_tokens = len(tokenlist)
        matrix = torch.zeros((n_tokens, n_tokens))
        for x, form_x in enumerate(forms):
            for y, form_y in enumerate(forms):
                if x == y:
                    continue
                n_pairs += 1
                if form_x in C and form_y in C[form_x]:
                    matrix[x, y] = C[form_x][form_y]
                else:
                    n_unseen += 1
                    matrix[x, y] = np.random.choice(candidate_distances, p=candidate_probs)
        control_matrices.append(matrix)

    # Every pair was covered by C (this also avoids dividing by zero pairs).
    if n_unseen == 0:
        return control_matrices, 1.0

    unseen_share = n_unseen / n_pairs
    seen_share = (n_pairs - n_unseen) / n_pairs
    return control_matrices, seen_share + unseen_share * max(empirical_distance_distribution.values())

## Training the Structural Probes

We now have everything in place to start doing the actual exciting stuff: training our structural probe!
    
To make life easier, we will simply take the `torch` code for this probe from John Hewitt's repository. This allows us to focus on the training regime from now on.

In [None]:
class StructuralProbe(nn.Module):
    """Learned linear projection whose squared L2 distances are trained to match tree distances.

    Adapted from John Hewitt's structural-probes code. The probe holds one parameter,
    ``proj``, of shape ``(model_dim, rank)``, initialised uniformly in [-0.05, 0.05].
    Every ordered pair of words (i, j) in a sentence is scored as
    ``||h_i @ proj - h_j @ proj||^2``.
    """

    def __init__(self, model_dim, rank, device="cpu"):
        super().__init__()
        self.model_dim = model_dim
        self.probe_rank = rank

        weights = torch.zeros(model_dim, rank)
        self.proj = nn.Parameter(data=weights)
        nn.init.uniform_(self.proj, -0.05, 0.05)

        self.to(device)

    def forward(self, batch):
        """Return all pairwise squared distances between projected words.

        Padding positions are not masked, so rows/columns that belong to pads
        generally hold non-zero values.

        Args:
            batch: word representations of shape (batch_size, max_seq_len, model_dim).

        Returns:
            Tensor of shape (batch_size, max_seq_len, max_seq_len) whose [b, i, j]
            entry is the squared L2 distance between projected words i and j.
        """
        projected = torch.matmul(batch, self.proj)
        # Unpacking the shape rejects inputs that are not 3-D.
        _batch_size, _seq_len, _rank = projected.size()
        # (B, S, 1, R) - (B, 1, S, R) broadcasts to every ordered pair (i, j).
        pairwise = projected.unsqueeze(2) - projected.unsqueeze(1)
        return pairwise.pow(2).sum(dim=-1)

    
class L1DistanceLoss(nn.Module):
    """L1 loss between predicted and gold distance matrices, normalised per sentence.

    Adapted from John Hewitt's structural-probes code.
    """
    def __init__(self):
        super().__init__()

    def forward(self, predictions, label_batch, length_batch):
        """Compute the length-normalised L1 loss for a padded batch.

        Entries where ``label_batch == -1`` (padding) are ignored. Each sentence's
        summed absolute error is divided by its squared length; the batch loss is
        the sum of those values divided by the number of non-empty sentences.

        Args:
            predictions: predicted distances, shape (batch, max_len, max_len).
            label_batch: gold distances of the same shape, padded with -1.
            length_batch: sentence lengths, shape (batch,).

        Returns:
            A tuple ``(batch_loss, total_sents)``: the scalar loss tensor and the
            number of sentences with non-zero length (a float tensor). If every
            sentence has length 0, the loss is a zero that stays attached to
            ``predictions``' autograd graph, so ``backward()`` still works.
        """
        labels_1s = (label_batch != -1).float()
        predictions_masked = predictions * labels_1s
        labels_masked = label_batch * labels_1s
        total_sents = torch.sum((length_batch != 0)).float()
        # Clamp so an empty (length-0, fully masked) sentence contributes 0 / 1 = 0
        # instead of 0 / 0 = NaN, which would poison the whole batch loss.
        squared_lengths = length_batch.pow(2).float().clamp(min=1.0)

        if total_sents > 0:
            loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=(1, 2))
            normalized_loss_per_sent = loss_per_sent / squared_lengths
            batch_loss = torch.sum(normalized_loss_per_sent) / total_sents
        else:
            # Zero loss derived from predictions: same device/dtype, and backward() does not raise.
            batch_loss = torch.sum(predictions_masked) * 0.0

        return batch_loss, total_sents


### Custom PyTorch Dataset and Collate Function

In [None]:
class TreeDistanceDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each sentence representation with its target distances.

    Args:
        data: indexable collection of per-sentence representations.
        target: indexable collection of per-sentence distance matrices, aligned with ``data``.
        performance_ceiling: best achievable UUAS on this data; test/validation code
            divides UUAS by it to obtain a corrected score (1.0 means no correction).
    """

    def __init__(self, data, target, performance_ceiling: float = 1.0):
        self.data, self.target = data, target
        self.performance_ceiling = performance_ceiling

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.target[index]

def tree_distance_collate(batch):
    """Pad a list of ``(representations, distances)`` samples into dense batch tensors on ``DEVICE``.

    Args:
        batch: list of ``(x, y)`` pairs, where ``x`` is (seq_len, rep_dim) and ``y``
            is (seq_len, seq_len).

    Returns:
        ``(data, targets, lengths)``: data zero-padded to (batch, max_len, rep_dim),
        targets padded with -1 to (batch, max_len, max_len) so the loss can mask
        them out, and a float tensor of the original sequence lengths.
    """
    representations = [sample[0] for sample in batch]
    distance_matrices = [sample[1] for sample in batch]

    n_items = len(batch)
    rep_dim = representations[0].shape[1]
    longest = max(rep.shape[0] for rep in representations)

    padded_data = torch.zeros((n_items, longest, rep_dim), device=DEVICE)
    padded_targets = torch.full((n_items, longest, longest), -1.0, device=DEVICE)
    lengths = torch.zeros(n_items, device=DEVICE)

    for row, (rep, dist) in enumerate(zip(representations, distance_matrices)):
        n_words = rep.shape[0]
        n_targets = dist.shape[0]
        padded_data[row, :n_words] = rep
        padded_targets[row, :n_targets, :n_targets] = dist
        lengths[row] = n_words

    return padded_data, padded_targets, lengths

### Probe and Training Configuration Wrappers

In [None]:
class TrainingConfig:
    """Training hyper-parameters plus the gold and control distance targets they use.

    Construction is eager and touches the disk. For each fold (train/test/validation)
    it loads the gold distances from ``data_root``, or creates and saves them. It then
    does the same for the structural control behaviour (one per seed) and the control
    distances (one per fold and seed). ``overwrite=True`` forces every cached file to
    be recreated.

    Args:
        corpus: corpus providing the folds; a corpus with id ``'sample'`` uses
            ``OUT_SAMPLE_DATA_ROOT`` instead of ``OUT_DATA_ROOT`` as cache root.
        language_model_ids: ids of the language models to probe.
        ranks: probe ranks to train.
        seeds: random seeds; ``None`` (the default) means ``[67, 12484, 12]``.
            The given list is copied.
        device: device used for training.
        learning_rate: Adam learning rate (note ``10e-4 == 1e-3``).
        batch_size: mini-batch size of the training and validation loaders.
        num_epochs: maximum number of training epochs.
        patience: epochs without validation-UUAS improvement before early stopping.
        overwrite: recreate cached distances/control behaviour even if present on disk.
    """
    def __init__(self, corpus: Corpus, language_model_ids: List[str], ranks: List[int], seeds: Optional[List[int]] = None, device = DEVICE,
                 learning_rate = 10e-4, batch_size = 64, num_epochs = 25, patience = 5, overwrite: bool = False):
        self.corpus = corpus
        self.language_model_ids = language_model_ids
        self.ranks = ranks
        # A None default (instead of a list literal) and a copy ensure configs never share one mutable list.
        self.seeds = list(seeds) if seeds is not None else [67, 12484, 12]
        self.device = device
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.patience = patience
        self.overwrite = overwrite

        # Cached files for the 'sample' corpus live under a separate root directory.
        self.data_root = OUT_DATA_ROOT if self.corpus.id != 'sample' else OUT_SAMPLE_DATA_ROOT

        self.train_gold_distances = self._load_or_create_gold_distances(DataFoldNames.train)
        self.test_gold_distances = self._load_or_create_gold_distances(DataFoldNames.test)
        self.validation_gold_distances = self._load_or_create_gold_distances(DataFoldNames.validation)

        # The control task samples from the distance distribution observed in the training fold.
        self.empirical_distance_distribution = create_empirical_distance_distribution(self.train_gold_distances)
        self.control_behavior = self._load_or_create_control_behavior()

        self.train_control_distances = self._load_or_create_control_distances(DataFoldNames.train)
        self.test_control_distances = self._load_or_create_control_distances(DataFoldNames.test)
        self.validation_control_distances = self._load_or_create_control_distances(DataFoldNames.validation)


    def _load_or_create_gold_distances(self, fold: DataFoldNames):
        """Return the gold distances of ``fold``, cached in ``<data_root>/<fold>_gold_distances.pt``."""
        gd_path = os.path.join(self.data_root, f'{fold.value}_gold_distances.pt') 
        if Path(gd_path).is_file() and not self.overwrite:
            print(f'Gold distances present on disk for {fold.value} of {self.corpus.id} corpus')
            with open(gd_path, 'rb') as handle:
                gold_distances = torch.load(handle)
        else:
            print(f'Creating gold distances for {fold.value} of {self.corpus.id} corpus and saving to disk...')
            gold_distances = create_gold_distances(self.corpus.get_fold_by_name(fold))
            with open(gd_path, 'wb+') as handle:
                torch.save(gold_distances, handle)
        return gold_distances


    def _load_or_create_control_behavior(self):
        """Return ``{seed: C}``, each control behaviour table cached as ``control_behavior_<seed>.pkl``.

        Tables are built from the training fold only. The pickle files are produced
        by this class itself; do not point ``data_root`` at untrusted files.
        """
        control_behavior = {}
        for seed in self.seeds:
            cb_path = os.path.join(self.data_root, f'control_behavior_{seed}.pkl')
            if Path(cb_path).is_file() and not self.overwrite:
                print(f'Structural control behavior with seed {seed} present on disk')
                with open(cb_path, 'rb') as handle:
                    C = pickle.load(handle)
            else:
                print(f'Creating structural control behavior with seed {seed} and saving to disk...')
                C = create_structural_control_behavior(self.corpus.train, self.empirical_distance_distribution, seed)
                with open(cb_path, 'wb') as handle:
                    pickle.dump(C, handle)
            control_behavior[seed] = C
        return control_behavior

    def _load_or_create_control_distances(self, fold):
        """Return ``{seed: (control_distances, performance_ceiling)}`` for ``fold``.

        Each entry is cached in ``<fold>_control_distances_<seed>.pt``. The ceiling
        is stored there as a one-element tensor and converted back to ``float``.
        """
        control_distances_for_seed = {}
        for seed in self.seeds:
            cd_path = os.path.join(self.data_root, f'{fold.value}_control_distances_{seed}.pt') 
            if Path(cd_path).is_file() and not self.overwrite:
                print(f'Control distances with seed {seed} present on disk for {fold.value} fold of {self.corpus.id} corpus')
                with open(cd_path, 'rb') as handle:
                    control_distances, performance_ceiling = torch.load(handle)
                    performance_ceiling = float(performance_ceiling)
            else:
                print(f'Creating control distances with seed {seed} for {fold.value} fold of {self.corpus.id} corpus and saving to disk...')
                control_distances, performance_ceiling = create_control_distances(self.corpus.get_fold_by_name(fold), self.control_behavior[seed], self.empirical_distance_distribution, seed)
                with open(cd_path, 'wb+') as handle:
                    torch.save([control_distances, torch.tensor([performance_ceiling])], handle)
            
            control_distances_for_seed[seed] = (control_distances, performance_ceiling)
        return control_distances_for_seed

class ProbeConfig:
    """Identifies one structural-probe run and the on-disk paths of its artefacts.

    The seven fields are joined with underscores into a parameter string. That string
    names both the saved probe weights (``OUT_MODELS_ROOT/<params>.pth``) and the
    pickled test-set results (``OUT_RESULTS_ROOT/<params>.pkl``). Constructing an
    instance calls ``ensure_dir`` on both directories.
    """

    def __init__(self,language_model_id: str, corpus_id: str, task: str, hidden_layer_id: int, dim: int, rank: int, seed: int):
        self.language_model_id = language_model_id
        self.corpus_id = corpus_id
        self.task = task
        self.hidden_layer_id = hidden_layer_id
        self.dim = dim
        self.rank = rank
        self.seed = seed
        self.save_path = self._get_save_path()
        self.results_path = self._get_results_path()

    @staticmethod
    def from_save_path(save_path: str):
        """Rebuild a ProbeConfig from a path created by ``_get_save_path``/``_get_results_path``.

        The numeric fields (hidden layer, dim, rank, seed) are converted back to ``int``.

        Raises:
            ValueError: if the file name does not split into exactly seven
                underscore-separated fields (e.g. an id containing ``'_'``) or a
                numeric field is not an integer.
        """
        # Bug fix: this previously referenced the undefined name `probe_state_path`.
        _, filename = os.path.split(save_path)
        param_string, _extension = os.path.splitext(filename)
        language_model_id, corpus_id, task, hidden_layer_id, dim, rank, seed = param_string.split('_')
        return ProbeConfig(
            language_model_id=language_model_id,
            corpus_id=corpus_id,
            task=task,
            hidden_layer_id=int(hidden_layer_id),
            dim=int(dim),
            rank=int(rank),
            seed=int(seed)
        )

    def was_already_trained(self) -> bool:
        """True when both the saved probe weights and the test-set results exist on disk."""
        has_trained_model = Path(self.save_path).is_file()
        has_test_set_results = Path(self.results_path).is_file()
        return has_trained_model and has_test_set_results

    def print_start_message(self) -> None:
        """Print a one-line description of this run."""
        print(
            f'Training Structural Probe on {self.task} task with rank {self.rank}'
            f' on hidden layer {self.hidden_layer_id} of {self.language_model_id}'
            f' using {self.corpus_id} corpus data and seed {self.seed}'
        )

    def save_test_set_results(self, dataframe):
        """Pickle the per-sentence test-set results DataFrame to ``results_path``."""
        dataframe.to_pickle(self.results_path)

    def load_test_set_results(self):
        """Load the pickled test-set results DataFrame from ``results_path``."""
        return pd.read_pickle(self.results_path)

    def _get_param_string(self):
        """Underscore-joined identifier shared by the model and results file names."""
        param_string = f'{self.language_model_id}_{self.corpus_id}_{self.task}_{self.hidden_layer_id}_{self.dim}_{self.rank}_{self.seed}'
        return param_string

    def _get_save_path(self):
        """Path of the probe's ``state_dict`` file; creates ``OUT_MODELS_ROOT`` if needed."""
        ensure_dir(OUT_MODELS_ROOT)
        param_string = self._get_param_string()
        out_path = os.path.join(OUT_MODELS_ROOT, param_string + '.pth')
        return out_path

    def _get_results_path(self):
        """Path of the pickled test-set results; creates ``OUT_RESULTS_ROOT`` if needed."""
        ensure_dir(OUT_RESULTS_ROOT)
        param_string = self._get_param_string()
        out_path = os.path.join(OUT_RESULTS_ROOT, param_string + '.pkl')
        return out_path


def all_results_for_hidden_layer_exist(language_model, corpus_id, hidden_layer_id, ranks, seeds):
    """Return True only if every regular and control probe of this hidden layer is finished.

    A probe is finished when both its saved weights and its pickled test-set results
    exist, for every rank in ``ranks`` and every seed in ``seeds``. Checking stops
    at the first unfinished probe.
    """
    return all(
        ProbeConfig(
            language_model_id=language_model.id,
            corpus_id=corpus_id,
            task=task,
            hidden_layer_id=hidden_layer_id,
            dim=language_model.hidden_dim,
            rank=rank,
            seed=seed,
        ).was_already_trained()
        for task in ['regular', 'control']
        for rank in ranks
        for seed in seeds
    )

### Custom Training/Validation/Testing Code

In [None]:
def train_batch(probe: StructuralProbe, optimizer: torch.optim.Optimizer, loss_function: L1DistanceLoss, data: torch.Tensor,
                target: torch.Tensor, lengths: torch.Tensor) -> float:
    """Take one optimiser step on a padded batch and return the batch loss as a Python float."""
    optimizer.zero_grad()
    batch_loss, _num_sentences = loss_function.forward(probe(data), target, lengths)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def train_epoch(probe: StructuralProbe, optimizer: torch.optim.Optimizer, loss_function: L1DistanceLoss,
                train_loader: torch.utils.data.DataLoader, device: torch.device, num_batches: int):
    """Train the probe on (at most) ``num_batches`` batches from ``train_loader``.

    Returns:
        NumPy float array of per-batch training losses. It normally has length
        ``num_batches``. If the loader runs out first (e.g. an empty dataset), only
        the losses actually recorded are returned; previously this returned ``None``.
    """
    probe.train()
    losses = np.zeros(num_batches, dtype=float)  # pre-allocate memory
    completed = 0
    for batch_idx, (data, target, lengths) in enumerate(train_loader):
        data, target, lengths = data.to(device), target.to(device), lengths.to(device)
        losses[batch_idx] = train_batch(probe, optimizer, loss_function, data, target, lengths)
        completed = batch_idx + 1
        if completed == num_batches:
            break
    return losses[:completed]

def evaluate_batch(predicted_distances: torch.Tensor, target: torch.Tensor, lengths: torch.Tensor):
    """Compute one UUAS score per sentence of a padded batch.

    For every sentence the padding is cut off (``[:length, :length]``). A tree is
    built from the predicted distances with ``create_mst`` and scored against the
    gold distances with ``calc_uuas``.

    Args:
        predicted_distances: (batch, max_len, max_len) predicted distances.
        target: (batch, max_len, max_len) gold distances.
        lengths: (batch,) sentence lengths.

    Returns:
        List of per-sentence UUAS scores, in batch order.
    """
    sentence_lengths = [int(length) for length in lengths.tolist()]
    # Iterating a tensor walks its first (batch) dimension: one matrix per sentence.
    msts = [create_mst(predicted[:n, :n]) for predicted, n in zip(predicted_distances, sentence_lengths)]
    return [calc_uuas(mst, gold[:n, :n]) for mst, gold, n in zip(msts, target, sentence_lengths)]

def validate(probe: StructuralProbe, loss_function: L1DistanceLoss, split_loader: torch.utils.data.DataLoader, device: torch.device) -> Tuple[float, float]:
    """Evaluate the probe on a data split without updating it.

    Both metrics are averaged over sentences, not over batches. A smaller final
    batch is therefore not over-weighted, and the result does not depend on the
    loader's batch order.

    Returns:
        ``(avg_total_loss, avg_total_uuas)``. The loss is NaN if the split has no
        non-empty sentences; the UUAS is NaN if the loader yields no sentences.
    """
    probe.eval()
    uuas_scores = []
    weighted_loss_sum = 0.0
    sentence_count = 0.0
    with torch.no_grad():
        for data, target, lengths in split_loader:
            data, target, lengths = data.to(device), target.to(device), lengths.to(device)
            predicted_distances = probe(data)
            loss, n_sents = loss_function.forward(predicted_distances, target, lengths)
            # The loss is already a mean over the batch's non-empty sentences; undo that for a global mean.
            weighted_loss_sum += loss.item() * n_sents.item()
            sentence_count += n_sents.item()
            predicted_distances, target, lengths = predicted_distances.to('cpu'), target.to('cpu'), lengths.to('cpu')
            uuas_scores.extend(evaluate_batch(predicted_distances, target, lengths))
    avg_total_loss = weighted_loss_sum / sentence_count if sentence_count else float('nan')
    avg_total_uuas = float(np.mean(uuas_scores)) if uuas_scores else float('nan')
    return avg_total_loss, avg_total_uuas

def test_probe(probe: StructuralProbe, ud_parses: List[TokenList], dataset: TreeDistanceDataset, device: torch.device,
               batch_size: int = 64) -> pd.DataFrame:
    """Score a trained probe on every sentence of a fold and collect per-sentence results.

    Args:
        probe: trained structural probe (switched to eval mode here).
        ud_parses: parsed sentences aligned one-to-one, in order, with ``dataset``;
            only ``metadata['text']`` is read.
        dataset: representations and gold distance matrices. Its
            ``performance_ceiling`` divides the UUAS to give the corrected score.
        device: device the probe lives on.
        batch_size: evaluation batch size. Predictions are trimmed per sentence,
            so it does not change the stored results.

    Returns:
        DataFrame with one row per sentence: sentence text, gold distances, predicted
        distances (trimmed to the sentence length, like the gold matrices), UUAS
        and corrected UUAS.
    """
    test_set_results = {
        'sentence': [token_list.metadata['text'] for token_list in ud_parses],
        'gold_distances': [gold_distance.numpy() for gold_distance in dataset.target],
        'predicted_distances': [],
        'uuas_score': [],
        'corrected_uuas_score': []  # uuas score after dividing by the performance ceiling
    }

    # shuffle=False keeps the rows aligned with ud_parses and dataset.target.
    test_dl = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, collate_fn=tree_distance_collate)
    probe.eval()
    with torch.no_grad():
        for data, target, lengths in test_dl:
            data, target, lengths = data.to(device), target.to(device), lengths.to(device)
            predicted_distances = probe(data)
            predicted_distances, target, lengths = predicted_distances.to('cpu'), target.to('cpu'), lengths.to('cpu')
            batch_uuas_scores = evaluate_batch(predicted_distances, target, lengths)
            corrected_batch_uuas_scores = np.divide(batch_uuas_scores, dataset.performance_ceiling)
            # Strip the batch padding so each prediction has the same shape as its gold matrix.
            test_set_results['predicted_distances'].extend(
                predicted[:int(length), :int(length)].numpy()
                for predicted, length in zip(predicted_distances, lengths.tolist())
            )
            test_set_results['uuas_score'].extend(batch_uuas_scores)
            test_set_results['corrected_uuas_score'].extend(corrected_batch_uuas_scores)

    return pd.DataFrame(test_set_results)

def train_probe(datasets, probe_config: ProbeConfig, training_config: TrainingConfig) -> pd.DataFrame:
    """Train one structural probe with early stopping on validation UUAS, then test it.

    Args:
        datasets: ``[train, test, validation]`` TreeDistanceDatasets, in that order.
        probe_config: identifies the run. The best checkpoint is written to
            ``probe_config.save_path`` and the test results to ``probe_config.results_path``.
        training_config: hyper-parameters, device, and the corpus whose test fold
            supplies the sentence texts of the results frame.

    Returns:
        The per-sentence test-set results DataFrame produced by ``test_probe``.
    """
    train_dataset, test_dataset, validation_dataset = datasets

    train_dl = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=training_config.batch_size, shuffle=True, collate_fn=tree_distance_collate)
    # Validation order carries no information; a fixed order keeps the metrics reproducible.
    validation_dl = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=training_config.batch_size, shuffle=False, collate_fn=tree_distance_collate)

    torch.manual_seed(probe_config.seed)
    np.random.seed(probe_config.seed)

    probe = StructuralProbe(probe_config.dim, probe_config.rank).to(training_config.device)
    optimizer = torch.optim.Adam(probe.parameters(), lr=training_config.learning_rate)
    # Halve the learning rate once validation loss has stopped improving for more than one epoch.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    loss_function = L1DistanceLoss()

    # Starting from -inf (not 0.0) guarantees the first finite score is checkpointed, so the
    # reload below never picks up a stale checkpoint left by an earlier run.
    best_validation_uuas = float('-inf')
    best_epoch_id = None
    counter = 0

    batches_per_epoch = max(1, math.ceil(len(train_dataset.data) / training_config.batch_size))

    for epoch in range(training_config.num_epochs):
        epoch_losses = train_epoch(probe, optimizer, loss_function, train_dl, training_config.device, batches_per_epoch)
        avg_epoch_loss = np.mean(epoch_losses)
        validation_loss, validation_uuas = validate(probe, loss_function, validation_dl, training_config.device)
        adjusted_validation_uuas = validation_uuas / validation_dataset.performance_ceiling
        scheduler.step(validation_loss)
        print(f'Epoch {epoch}: Loss: {avg_epoch_loss}, Validation UUAS: {validation_uuas}, Adjusted Validation UUAS: {adjusted_validation_uuas}')

        if validation_uuas > best_validation_uuas:
            best_validation_uuas = validation_uuas
            best_epoch_id = epoch
            counter = 0
            print('Best validation UUAS so far!')
            torch.save(probe.state_dict(), probe_config.save_path)
        else:
            counter += 1
            print(f'No improvement in validation UUAS. Patience remaining: {training_config.patience - counter}')

        if counter >= training_config.patience:
            print(f'Terminated. Best Model found at epoch {best_epoch_id}')
            break

    if best_epoch_id is None:
        # No epoch produced a comparable score (num_epochs == 0 or NaN UUAS): checkpoint the final state.
        torch.save(probe.state_dict(), probe_config.save_path)

    probe.load_state_dict(torch.load(probe_config.save_path, map_location=training_config.device))
    # Use the configured corpus (not a global) so the sentence texts match the test dataset.
    test_set_results = test_probe(probe, training_config.corpus.test, test_dataset, training_config.device)
    probe_config.save_test_set_results(test_set_results)

    return test_set_results

In [None]:
def _append_probe_result(results_frame: Dict[str, list], language_model_id: str, hidden_layer_id: int, rank: int,
                         seed: int, task: str, uuas_score: float, corrected_uuas_score: float) -> None:
    """Append one probe's scores to the column-oriented results dict (selectivity columns excluded)."""
    results_frame['language_model_id'].append(language_model_id)
    results_frame['hidden_layer_id'].append(hidden_layer_id)
    results_frame['rank'].append(rank)
    results_frame['seed'].append(seed)
    results_frame['task'].append(task)
    results_frame['uuas_score'].append(uuas_score)
    results_frame['corrected_uuas_score'].append(corrected_uuas_score)


def train_all_structural_probes(training_config: TrainingConfig) -> pd.DataFrame:
    """Train (or reload) a regular and a control probe for every seed, model, hidden layer and rank.

    All data (sentence representations, results file names, test sentences) comes
    from ``training_config.corpus``. If ``final_results_<corpus id>.pkl`` already
    exists in ``OUT_RESULTS_ROOT`` and ``overwrite`` is False, it is returned unchanged.

    Returns:
        DataFrame with one row per probe. Columns: language_model_id,
        hidden_layer_id, rank, seed, task, uuas_score, corrected_uuas_score,
        selectivity, generalized_selectivity. Each regular row and its matching
        control row share the selectivity values:

        - ``selectivity`` = regular UUAS - control UUAS
        - ``generalized_selectivity`` = regular UUAS - control UUAS divided by the
          control performance ceiling
    """
    corpus = training_config.corpus

    print('Loading or creating gold and control distances')

    final_results_frame = {
        'language_model_id': [],
        'hidden_layer_id': [],
        'rank': [],
        'seed': [],
        'task': [],
        'uuas_score': [],
        'corrected_uuas_score': [],
        'selectivity': [],
        'generalized_selectivity': []
    }

    final_results_path = os.path.join(OUT_RESULTS_ROOT, f'final_results_{corpus.id}.pkl')

    if Path(final_results_path).is_file() and not training_config.overwrite:
        print('Final results already exist!')
        final_results = pd.read_pickle(final_results_path)
        return final_results

    language_models = {}
    for language_model_id in training_config.language_model_ids:
        print(f'Initializing and preparing {language_model_id} language model')
        language_models[language_model_id] = initialize_and_prepare_model(language_model_id, corpus)

    for seed in training_config.seeds:

        # Performance ceilings depend only on the seed's control behaviour.
        test_performance_ceiling = training_config.test_control_distances[seed][1]
        validation_performance_ceiling = training_config.validation_control_distances[seed][1]

        for language_model_id in training_config.language_model_ids:

            language_model = language_models[language_model_id]

            for hidden_layer_id in range(language_model.hidden_layers):

                if not all_results_for_hidden_layer_exist(language_model, corpus.id, hidden_layer_id, training_config.ranks, training_config.seeds) or training_config.overwrite:

                    layer_train_sen_reps = load_sen_reps(corpus.id, DataFoldNames.train, language_model.id, hidden_layer_id, False)
                    layer_test_sen_reps = load_sen_reps(corpus.id, DataFoldNames.test, language_model.id, hidden_layer_id, False)
                    layer_validation_sen_reps = load_sen_reps(corpus.id, DataFoldNames.validation, language_model.id, hidden_layer_id, False)

                    regular_train_dataset = TreeDistanceDataset(layer_train_sen_reps, training_config.train_gold_distances)
                    regular_test_dataset = TreeDistanceDataset(layer_test_sen_reps, training_config.test_gold_distances, test_performance_ceiling)
                    regular_validation_dataset = TreeDistanceDataset(layer_validation_sen_reps, training_config.validation_gold_distances, validation_performance_ceiling)

                    control_train_dataset = TreeDistanceDataset(layer_train_sen_reps, training_config.train_control_distances[seed][0], training_config.train_control_distances[seed][1])
                    control_test_dataset = TreeDistanceDataset(layer_test_sen_reps, training_config.test_control_distances[seed][0], test_performance_ceiling)
                    control_validation_dataset = TreeDistanceDataset(layer_validation_sen_reps, training_config.validation_control_distances[seed][0], validation_performance_ceiling)

                else:
                    # Datasets are only needed by train_probe, which is not reached when every result exists.
                    print(f'All trained models and results have been saved for hidden layer {hidden_layer_id} of {language_model.id}. Not loading dataset...')

                for rank in training_config.ranks:

                    print('\n\n')
                    print('#'*140)
                    print('\n\n')

                    # REGULAR TASK
                    regular_probe_config = ProbeConfig(
                        language_model_id=language_model.id,
                        corpus_id=corpus.id,
                        task='regular',
                        hidden_layer_id=hidden_layer_id,
                        dim = language_model.hidden_dim,
                        rank=rank,
                        seed=seed)

                    regular_probe_config.print_start_message()

                    if regular_probe_config.was_already_trained() and not training_config.overwrite:
                        print('A probe with the given configuration has already been trained and saved to disk.')
                        regular_results = regular_probe_config.load_test_set_results()
                    else:
                        regular_results = train_probe([regular_train_dataset, regular_test_dataset, regular_validation_dataset], regular_probe_config, training_config)

                    regular_uuas = np.mean(regular_results['uuas_score'])
                    # The regular task needs no correction; repeating the UUAS avoids empty cells.
                    _append_probe_result(final_results_frame, language_model.id, hidden_layer_id, rank, seed,
                                         'regular', regular_uuas, regular_uuas)

                    # CONTROL TASK
                    control_probe_config = ProbeConfig(
                        language_model_id=language_model.id,
                        corpus_id=corpus.id,
                        task='control',
                        hidden_layer_id=hidden_layer_id,
                        dim = language_model.hidden_dim,
                        rank=rank,
                        seed=seed)

                    control_probe_config.print_start_message()

                    if control_probe_config.was_already_trained() and not training_config.overwrite:
                        print('A Probe with the given configuration has already been trained and saved to disk.')
                        control_results = control_probe_config.load_test_set_results()
                    else:
                        control_results = train_probe([control_train_dataset, control_test_dataset, control_validation_dataset], control_probe_config, training_config)

                    control_uuas = np.mean(control_results['uuas_score'])
                    corrected_control_uuas = np.mean(control_results['corrected_uuas_score'])
                    _append_probe_result(final_results_frame, language_model.id, hidden_layer_id, rank, seed,
                                         'control', control_uuas, corrected_control_uuas)

                    selectivity = regular_uuas - control_uuas
                    generalized_selectivity = regular_uuas - corrected_control_uuas

                    # One value each for the regular row and the control row appended above.
                    final_results_frame['selectivity'].extend([selectivity, selectivity])
                    final_results_frame['generalized_selectivity'].extend([generalized_selectivity, generalized_selectivity])

                    print(f'With the given configuration, the final selectivity of the Probe is: {selectivity}, Generalized: {generalized_selectivity}')
    final_results = pd.DataFrame(final_results_frame)
    final_results.to_pickle(final_results_path)
    return final_results

## Run the Structural Probe Experiments

In [None]:
if RUN_STRUCTURAL_TASK:
    # Probe ranks are the powers of two 1, 2, 4, ..., 512; every configuration uses three seeds.
    # NOTE: overwrite=True recreates every cached distance file and retrains every probe on each run.
    STRUCTURAL_PROBE_TRAINING_CONFIG = TrainingConfig(
        corpus=CORPUS,
        language_model_ids=['lstm', 'distilgpt2', 'xlnet'],
        ranks=[2 ** exponent for exponent in range(10)],
        seeds=[67, 12484, 12],
        overwrite=True,
    )

In [None]:
if RUN_STRUCTURAL_TASK:
    # Train every regular and control probe (reusing saved probes/results only when the config's
    # overwrite flag is False). Returns one row per probe with UUAS and selectivity scores.
    STRUCTURAL_PROBE_RESULTS = train_all_structural_probes(STRUCTURAL_PROBE_TRAINING_CONFIG)

## Examine Results

Let's visualize the results.

In [None]:
def graph_y_by_rank(results_frame, y_id, y_label, training_config, agg_function):
    """Plot ``y_id`` against probe rank (log x-axis), one line per language model, and save a PNG.

    Scores are first reduced over hidden layers for each (language model, rank, seed):
    ``'mean'`` averages them, ``'max'`` keeps the best layer, and ``'no_agg'`` plots
    ``results_frame`` unchanged. Seaborn then draws the mean of the remaining rows
    with a standard-deviation band.

    The figure is saved to ``OUT_RESULTS_IMAGES_ROOT/<prefix>_<y_label>_by_rank.png``.
    The prefix names the aggregation: ``average``, ``max`` or ``all``.

    NOTE: the marker and palette lists assume exactly three language models. The first
    legend entry is dropped, presumably the hue title in the seaborn version used (TODO confirm).

    Raises:
        ValueError: for an unknown ``agg_function``.
    """
    filename_prefixes = {'mean': 'average', 'max': 'max', 'no_agg': 'all'}
    if agg_function not in filename_prefixes:
        raise ValueError(f"Unknown agg_function {agg_function!r}; expected one of {sorted(filename_prefixes)}")
    ensure_dir(OUT_RESULTS_IMAGES_ROOT)
    group_keys = ['language_model_id', 'rank', 'seed']
    if agg_function == 'mean':
        grouped = results_frame.groupby(group_keys)[y_id].mean().reset_index()
    elif agg_function == 'max':
        grouped = results_frame.groupby(group_keys)[y_id].max().reset_index()
    else:
        grouped = results_frame
    plt.clf()
    ax = sns.lineplot(x='rank', y=y_id, data=grouped, hue='language_model_id', style='language_model_id', markers=['s', 'v', 'o'], palette=[ "#0077B3", "#FF1A75","#33CC33"], dashes=False, ci='sd')
    ax.set_xscale('log')
    ax.set_xlabel('Probe Rank')
    ax.set_ylabel(y_label)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[1:], labels=labels[1:])
    ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(training_config.ranks))
    ax.xaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(training_config.ranks))
    # The file name reflects the aggregation actually used (it was always 'max_' before).
    out_path = os.path.join(OUT_RESULTS_IMAGES_ROOT, f'{filename_prefixes[agg_function]}_{y_label}_by_rank.png')
    plt.savefig(out_path)

def graph_y_by_hidden_layer_id(results_frame, y_id, y_label, agg_function: str):
    """Plot ``y_id`` against hidden-layer index, one line per language model, and save a PNG.

    Scores are first reduced over probe ranks for each (language model, hidden layer, seed):
    ``'mean'`` averages them, ``'max'`` keeps the best rank, and ``'no_agg'`` plots
    ``results_frame`` unchanged. Seaborn then draws the mean of the remaining rows
    with a standard-deviation band.

    The figure is saved to ``OUT_RESULTS_IMAGES_ROOT/<prefix>_<y_label>_by_hidden_layer_id.png``.
    The prefix names the aggregation: ``average``, ``max`` or ``all``.

    NOTE: the marker and palette lists assume exactly three language models. The first
    legend entry is dropped, presumably the hue title in the seaborn version used (TODO confirm).

    Raises:
        ValueError: for an unknown ``agg_function``.
    """
    filename_prefixes = {'mean': 'average', 'max': 'max', 'no_agg': 'all'}
    if agg_function not in filename_prefixes:
        raise ValueError(f"Unknown agg_function {agg_function!r}; expected one of {sorted(filename_prefixes)}")
    ensure_dir(OUT_RESULTS_IMAGES_ROOT)
    group_keys = ['language_model_id', 'hidden_layer_id', 'seed']
    if agg_function == 'mean':
        grouped = results_frame.groupby(group_keys)[y_id].mean().reset_index()
    elif agg_function == 'max':
        grouped = results_frame.groupby(group_keys)[y_id].max().reset_index()
    else:
        grouped = results_frame
    plt.clf()
    ax = sns.lineplot(x='hidden_layer_id', y=y_id, data=grouped, hue='language_model_id', style='language_model_id', markers=['s', 'v', 'o'], palette=[ "#0077B3", "#FF1A75","#33CC33"], dashes=False, ci='sd')
    ax.set_xlabel('Hidden Layer Index')
    ax.set_ylabel(y_label)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles[1:], labels=labels[1:])
    # The file name reflects the aggregation actually used (it was always 'average_' before).
    out_path = os.path.join(OUT_RESULTS_IMAGES_ROOT, f'{filename_prefixes[agg_function]}_{y_label}_by_hidden_layer_id.png')
    plt.savefig(out_path)

def plot_uuas_by_sentence_length(results_frame):
    """Plot per-sentence UUAS against sentence length, one line per language model.

    Rows are ordered by 'language_model_id' before plotting. Seaborn draws the
    spread of rows sharing a (model, sentence length) as a standard-deviation
    band. The y-axis is fixed to [0.2, 1.0] and the x-axis to [2, 82]. The
    figure is saved under OUT_RESULTS_IMAGES_ROOT.
    """
    ensure_dir(OUT_RESULTS_IMAGES_ROOT)
    ordered = results_frame.sort_values(by='language_model_id')
    plt.clf()
    axis = sns.lineplot(
        x='sentence_length',
        y='uuas_score',
        data=ordered,
        hue='language_model_id',
        style='language_model_id',
        markers=['s', 'v', 'o'],
        palette=["#0077B3", "#FF1A75", "#33CC33"],
        dashes=False,
        ci='sd',
    )
    axis.set_xlabel('Sentence Length')
    axis.set_ylabel('UUAS')
    axis.set_ylim(0.2, 1.0)
    axis.set_xlim(2, 82)
    # Rebuild the legend without its first entry (presumably seaborn's hue title — confirm).
    legend_handles, legend_labels = axis.get_legend_handles_labels()
    axis.legend(handles=legend_handles[1:], labels=legend_labels[1:])
    plt.savefig(os.path.join(OUT_RESULTS_IMAGES_ROOT, 'uuas_by_sentence_length_most_selective_probes.png'))

In [None]:
# Keep only the probes trained on the regular (non-control) task.
STRUCTURAL_PROBE_REGULAR_TASK_RESULTS = STRUCTURAL_PROBE_RESULTS.loc[STRUCTURAL_PROBE_RESULTS['task'] == 'regular']

Scores of each language model's most selective probe (highest mean generalized selectivity) at its final hidden layer

In [None]:
# For each language model: the probe rank with the highest mean generalized
# selectivity at the model's highest-indexed hidden layer. Means/stds are over
# all rows sharing a (model, layer, rank) — presumably the seeds.
FINAL_LAYER_SCORES = (
    STRUCTURAL_PROBE_REGULAR_TASK_RESULTS
    .groupby(['language_model_id', 'hidden_layer_id', 'rank'], as_index=False)
    .agg({'uuas_score': ['mean', 'std'], 'generalized_selectivity': ['mean', 'std']})
    # Highest layer first, then most selective first, so drop_duplicates below keeps that row.
    .sort_values(by=['hidden_layer_id', ('generalized_selectivity', 'mean')], ascending=False)
)
# Flatten the (metric, statistic) column MultiIndex. The UUAS spread column was
# previously mislabelled 'mean_std'; it is the UUAS standard deviation.
FINAL_LAYER_SCORES.columns = ['language_model_id', 'hidden_layer_id', 'rank', 'mean_uuas', 'std_uuas', 'mean_gs', 'std_gs']
FINAL_LAYER_SCORES = FINAL_LAYER_SCORES.drop_duplicates(subset=['language_model_id'])
FINAL_LAYER_SCORES

Now let's graph!

In [None]:
# Best generalized selectivity per hidden layer (max over the remaining rows of each model/layer/seed).
graph_y_by_hidden_layer_id(
    results_frame=STRUCTURAL_PROBE_REGULAR_TASK_RESULTS,
    y_id='generalized_selectivity',
    y_label='Generalized Selectivity',
    agg_function='max',
)

In [None]:
# Best UUAS per hidden layer (max over the remaining rows of each model/layer/seed).
graph_y_by_hidden_layer_id(
    results_frame=STRUCTURAL_PROBE_REGULAR_TASK_RESULTS,
    y_id='uuas_score',
    y_label='UUAS',
    agg_function='max',
)

In [None]:
# Generalized selectivity per probe rank, averaged ('mean') over the rows sharing a
# (language model, rank, seed).
# NOTE(review): graph_y_by_rank prefixes its output file with 'max_' regardless of
# agg_function, so this mean-aggregated plot is saved under a 'max_' name — confirm intended.
graph_y_by_rank(STRUCTURAL_PROBE_REGULAR_TASK_RESULTS, 'generalized_selectivity', 'Generalized Selectivity', STRUCTURAL_PROBE_TRAINING_CONFIG, 'mean')

In [None]:
# Best UUAS per probe rank: max over the rows sharing a (language model, rank, seed),
# presumably i.e. over hidden layers.
graph_y_by_rank(STRUCTURAL_PROBE_REGULAR_TASK_RESULTS, 'uuas_score', 'UUAS', STRUCTURAL_PROBE_TRAINING_CONFIG, 'max')

Display, for each language model, the highest scores and the corresponding hidden layers and probe ranks

In [None]:
# Mean and std of UUAS, selectivity and generalized selectivity for every
# (language model, hidden layer, probe rank) combination.
AGG_SCORES_PER_LANGUAGE_MODEL = (
    STRUCTURAL_PROBE_REGULAR_TASK_RESULTS
    .groupby(['language_model_id', 'hidden_layer_id', 'rank'], as_index=False)
    .agg({
        'uuas_score': ['mean', 'std'],
        'selectivity': ['mean', 'std'],
        'generalized_selectivity': ['mean', 'std'],
    })
)

In [None]:
# Per language model, the (hidden layer, rank) configuration with the highest mean UUAS.
BEST_UUUAS_PER_LANGUAGE_MODEL = AGG_SCORES_PER_LANGUAGE_MODEL.sort_values(by=('uuas_score', 'mean'), ascending=False)
# Flatten the (metric, statistic) column MultiIndex. The UUAS spread column was
# previously mislabelled 'mean_std'; it is the UUAS standard deviation.
BEST_UUUAS_PER_LANGUAGE_MODEL.columns = ['language_model_id', 'hidden_layer_id', 'rank', 'mean_uuas', 'std_uuas', 'mean_s', 'std_s', 'mean_gs', 'std_gs']
# Rows are sorted best-first, so the first row per model is its best configuration.
BEST_UUUAS_PER_LANGUAGE_MODEL = BEST_UUUAS_PER_LANGUAGE_MODEL.drop_duplicates(subset=['language_model_id'])
BEST_UUUAS_PER_LANGUAGE_MODEL

In [None]:
# Per language model, the (hidden layer, rank) configuration with the highest mean selectivity.
BEST_SELECTIVITY_PER_LANGUAGE_MODEL = AGG_SCORES_PER_LANGUAGE_MODEL.sort_values(by=('selectivity', 'mean'), ascending=False)
# Flatten the (metric, statistic) column MultiIndex. The UUAS spread column was
# previously mislabelled 'mean_std'; it is the UUAS standard deviation.
BEST_SELECTIVITY_PER_LANGUAGE_MODEL.columns = ['language_model_id', 'hidden_layer_id', 'rank', 'mean_uuas', 'std_uuas', 'mean_s', 'std_s', 'mean_gs', 'std_gs']
# Rows are sorted best-first, so the first row per model is its best configuration.
BEST_SELECTIVITY_PER_LANGUAGE_MODEL = BEST_SELECTIVITY_PER_LANGUAGE_MODEL.drop_duplicates(subset=['language_model_id'])
BEST_SELECTIVITY_PER_LANGUAGE_MODEL

In [None]:
# Per language model, the (hidden layer, rank) configuration with the highest mean generalized selectivity.
BEST_GSELECTIVITY_PER_LANGUAGE_MODEL = AGG_SCORES_PER_LANGUAGE_MODEL.sort_values(by=('generalized_selectivity', 'mean'), ascending=False)
# Flatten the (metric, statistic) column MultiIndex. The UUAS spread column was
# previously mislabelled 'mean_std'; it is the UUAS standard deviation.
BEST_GSELECTIVITY_PER_LANGUAGE_MODEL.columns = ['language_model_id', 'hidden_layer_id', 'rank', 'mean_uuas', 'std_uuas', 'mean_s', 'std_s', 'mean_gs', 'std_gs']
# Rows are sorted best-first, so the first row per model is its best configuration.
BEST_GSELECTIVITY_PER_LANGUAGE_MODEL = BEST_GSELECTIVITY_PER_LANGUAGE_MODEL.drop_duplicates(subset=['language_model_id'])
BEST_GSELECTIVITY_PER_LANGUAGE_MODEL

Measure probe generalization by plotting the average UUAS of each model's most selective probe against sentence length

In [None]:
# Hidden-state size of each language model, passed to ProbeConfig as `dim`.
LANGUAGE_MODEL_TO_DIM_MAP = {
    'lstm': 650,
    'xlnet': 1024,
    'distilgpt2': 768,
}

# One row per (language model, seed, test sentence) for each model's most
# generalized-selective probe configuration.
to_frame = {column: [] for column in ('language_model_id', 'seed', 'sentence_length', 'uuas_score')}

for tup in BEST_GSELECTIVITY_PER_LANGUAGE_MODEL.itertuples():
    model_id = tup.language_model_id
    for seed in STRUCTURAL_PROBE_TRAINING_CONFIG.seeds:
        best_results = ProbeConfig(
            language_model_id=model_id,
            corpus_id=CORPUS.id,
            task='regular',
            hidden_layer_id=tup.hidden_layer_id,
            dim=LANGUAGE_MODEL_TO_DIM_MAP[model_id],
            rank=tup.rank,
            seed=seed,
        ).load_test_set_results()
        for sentence in best_results.itertuples():
            to_frame['language_model_id'].append(model_id)
            # Sentence length is taken as the first dimension of the gold distance matrix.
            to_frame['sentence_length'].append(sentence.gold_distances.shape[0])
            to_frame['seed'].append(seed)
            to_frame['uuas_score'].append(sentence.uuas_score)

HIGHEST_GS_PROBES_UUAS_BY_SENTENCE_LENGTH = pd.DataFrame(to_frame)

In [None]:
# UUAS of each model's most selective probe as a function of sentence length.
plot_uuas_by_sentence_length(results_frame=HIGHEST_GS_PROBES_UUAS_BY_SENTENCE_LENGTH)

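Finally, let's test whether the UUAS differences between the language models are statistically significant (independent two-sample t-tests)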

In [None]:
# Mean UUAS per sentence length for each model, as (n_lengths, 1) arrays; each
# distinct sentence length becomes one sample for the t-tests below.
lstm_scores = (
    HIGHEST_GS_PROBES_UUAS_BY_SENTENCE_LENGTH
    .query("language_model_id == 'lstm'")
    .groupby(['sentence_length'])
    .agg({'uuas_score': 'mean'})
    .values
)
distilgpt2_scores = (
    HIGHEST_GS_PROBES_UUAS_BY_SENTENCE_LENGTH
    .query("language_model_id == 'distilgpt2'")
    .groupby(['sentence_length'])
    .agg({'uuas_score': 'mean'})
    .values
)
xlnet_scores = (
    HIGHEST_GS_PROBES_UUAS_BY_SENTENCE_LENGTH
    .query("language_model_id == 'xlnet'")
    .groupby(['sentence_length'])
    .agg({'uuas_score': 'mean'})
    .values
)

In [None]:
# Independent two-sample t-test: LSTM vs DistilGPT-2.
ttest_ind(a=lstm_scores, b=distilgpt2_scores)

In [None]:
# Independent two-sample t-test: LSTM vs XLNet.
ttest_ind(a=lstm_scores, b=xlnet_scores)

In [None]:
# Independent two-sample t-test: DistilGPT-2 vs XLNet.
ttest_ind(a=distilgpt2_scores, b=xlnet_scores)

Let's also test the significance of the differences between the models' final-layer results

In [None]:
# Mean UUAS and generalized selectivity per (language model, hidden layer, seed, probe rank).
FINAL_LAYER_SCORES_BY_SEED = (
    STRUCTURAL_PROBE_REGULAR_TASK_RESULTS
    .groupby(['language_model_id', 'hidden_layer_id', 'seed', 'rank'], as_index=False)
    .agg({'uuas_score': ['mean'], 'generalized_selectivity': ['mean']})
)
# Restrict each model to its final hidden layer (indices hard-coded per model).
_model_ids = FINAL_LAYER_SCORES_BY_SEED.language_model_id
_layer_ids = FINAL_LAYER_SCORES_BY_SEED.hidden_layer_id
FINAL_LAYER_SCORES_BY_SEED_XLNET = FINAL_LAYER_SCORES_BY_SEED.loc[(_model_ids == 'xlnet') & (_layer_ids == 12)]
FINAL_LAYER_SCORES_BY_SEED_LSTM = FINAL_LAYER_SCORES_BY_SEED.loc[(_model_ids == 'lstm') & (_layer_ids == 1)]
FINAL_LAYER_SCORES_BY_SEED_DISTILGPT2 = FINAL_LAYER_SCORES_BY_SEED.loc[(_model_ids == 'distilgpt2') & (_layer_ids == 6)]
# Most selective first, so drop_duplicates keeps the best probe rank per (model, seed).
FINAL_LAYER_SCORES_BY_SEED = pd.concat([
    FINAL_LAYER_SCORES_BY_SEED_XLNET,
    FINAL_LAYER_SCORES_BY_SEED_LSTM,
    FINAL_LAYER_SCORES_BY_SEED_DISTILGPT2,
]).sort_values(by=('generalized_selectivity', 'mean'), ascending=False)
FINAL_LAYER_SCORES_BY_SEED.columns = ['language_model_id', 'hidden_layer_id', 'seed', 'rank', 'mean_uuas', 'mean_gs']
FINAL_LAYER_SCORES_BY_SEED = FINAL_LAYER_SCORES_BY_SEED.drop_duplicates(subset=['language_model_id', 'seed'])

In [None]:
# Per-sentence test-set UUAS of the most selective final-layer probe, one group
# of rows per (language model, seed) row of FINAL_LAYER_SCORES_BY_SEED.
to_frame = {
    'language_model_id': [],
    'seed': [],
    'uuas_score': []
}

for tup in FINAL_LAYER_SCORES_BY_SEED.itertuples():
    # Use this row's own seed. The bare name `seed` here used to resolve to a
    # stale global left over from an earlier cell's loop, so every row loaded
    # (and was tagged with) the same seed's probe.
    probe_config = ProbeConfig(
        language_model_id=tup.language_model_id,
        corpus_id=CORPUS.id,
        task='regular',
        hidden_layer_id=tup.hidden_layer_id,
        dim=LANGUAGE_MODEL_TO_DIM_MAP[tup.language_model_id],
        rank=tup.rank,
        seed=tup.seed)
    best_results = probe_config.load_test_set_results()
    for tup2 in best_results.itertuples():
        to_frame['language_model_id'].append(tup.language_model_id)
        to_frame['seed'].append(tup.seed)
        to_frame['uuas_score'].append(tup2.uuas_score)

FINAL_LAYER_SCORES_BY_SEED_BY_SENTENCE = pd.DataFrame(to_frame)

In [None]:
# (model name, per-sentence UUAS array) pairs for the pairwise significance tests.
lstm_uuas, distilgpt2_uuas, xlnet_uuas = (
    (model_id,
     FINAL_LAYER_SCORES_BY_SEED_BY_SENTENCE[
         FINAL_LAYER_SCORES_BY_SEED_BY_SENTENCE['language_model_id'] == model_id
     ]['uuas_score'].values)
    for model_id in ('lstm', 'distilgpt2', 'xlnet')
)

In [None]:
# Pairwise independent two-sample t-tests; a p-value above 0.05 is reported as not significant.
for (first_model, first_scores), (second_model, second_scores) in itertools.combinations(
        [lstm_uuas, distilgpt2_uuas, xlnet_uuas], 2):
    p_value = ttest_ind(first_scores, second_scores).pvalue
    verdict = "not" if p_value > 0.05 else "indeed"
    print(f'The difference between {first_model} and {second_model} is {verdict} significant: {p_value}')

# Generate Final Report Figures

In [None]:
def plot_pos_tag_results_for_paper(data, lims=((0.78, 0.92), (0.15, 0.3)), step_size=(0.03, 0.03), x_label="Sample count", x_ticks=(100, 1000, 10000, 12543)):
    """Draw the paper's two-panel POS-tag figure: mean accuracy (left), mean selectivity (right).

    Args:
        data: one mapping per panel (panel 0 = accuracy, panel 1 = selectivity),
            each from a language model id to a dict with 'mean' and 'std'
            sequences of equal length; points are drawn at x = 0, 1, 2, ...
        lims: (low, high) y-axis limits per panel.
        step_size: y-tick spacing per panel.
        x_label: x-axis label shared by both panels.
        x_ticks: tick labels for the x positions 0, 1, 2, ...

    Each model gets a line with markers and a shaded mean +/- std band. The
    figure is saved to OUT_RESULTS_ROOT/plot_pos_tag_hidden_indices.png and shown.
    Defaults are tuples rather than lists to avoid shared mutable defaults.
    """
    y_label = ["Mean accuracy", "Mean selectivity"]

    # (marker, colour) per model, assigned in each panel mapping's iteration order.
    layout = [("s", "#0077b3"), ("o", "#33cc33"), ("v", "#ff1a75")]
    _, ax = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

    for i, d in enumerate(data):
        if not d:
            continue
        for j, (model_id, res) in enumerate(d.items()):
            y = np.array(res["mean"])
            std = np.array(res["std"])
            marker, color = layout[j]

            x = range(len(y))
            ax[i].plot(x, y, color=color, marker=marker, markersize=7, label=model_id)
            # The band is not clipped to [0, 1].
            ax[i].fill_between(x, y - std, y + std, facecolor=color, alpha=0.2, interpolate=True)

        # Axis styling is per panel, so it is applied once after all models are drawn;
        # ticks use the x positions of the last model plotted, as before.
        ax[i].set_ylabel(y_label[i], fontsize=15)
        ax[i].set_xlabel(x_label, fontsize=14)
        ax[i].set_ylim(lims[i])
        # NOTE(review): x-limits are fixed to (0, 1), so only the first two x positions
        # are visible — confirm this matches the data passed in.
        ax[i].set_xlim((0, 1))

        ax[i].set_xticks(x)
        ax[i].set_xticklabels(x_ticks, fontsize=14)
        ax[i].set_yticks(np.arange(lims[i][0], lims[i][1] + step_size[i], step_size[i]))
        ax[i].tick_params(axis="y", labelsize=13)
    # A single legend on the left panel describes both panels.
    ax[0].legend(prop={'size': 13}, loc="lower right")
    plt.suptitle("Part-of-Speech tag prediction", fontsize=15)
    plt.tight_layout(rect=[0, 0.03, 1, 0.93])
    plt.savefig(os.path.join(OUT_RESULTS_ROOT, "plot_pos_tag_hidden_indices.png"))
    plt.show()

def plot_tree_distance_task_results_for_paper(tree_distance_results_frame):
    """Draw the paper's two-panel tree-distance figure against hidden-layer index.

    For each panel, the metric is maxed over all rows sharing a
    (language model, hidden layer, seed) — presumably across probe ranks — and
    plotted per model with a standard-deviation band over seeds. The left panel
    shows UUAS (with legend) and the right panel generalized selectivity. The
    figure is saved to OUT_RESULTS_ROOT/plot_tree_distance_hidden_indices.png and shown.
    """
    _, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))

    # (metric column, y-axis label, x tick-label font size) per panel, left to right.
    panels = [
        ('uuas_score', 'UUAS', 13),
        ('generalized_selectivity', 'Gen. Selectivity', 12),
    ]
    for panel_idx, (metric, y_axis_label, xtick_fontsize) in enumerate(panels):
        axis = axes[panel_idx]
        best_per_seed = (tree_distance_results_frame
                         .groupby(['language_model_id', 'hidden_layer_id', 'seed'])[metric]
                         .max()
                         .reset_index())

        # Only the left panel keeps a legend; seaborn's legend is suppressed on the right.
        legend_kwargs = {} if panel_idx == 0 else {'legend': None}
        sns.lineplot(
            ax=axis,
            x='hidden_layer_id',
            y=metric,
            data=best_per_seed,
            hue='language_model_id',
            style='language_model_id',
            markersize=8,
            markers=['s', 'v', 'o'],
            palette=["#0077B3", "#FF1A75", "#33CC33"],
            dashes=False,
            ci='sd',
            **legend_kwargs,
        )
        if panel_idx == 0:
            # Drop the first legend entry (presumably seaborn's hue title — confirm).
            legend_handles, legend_labels = axis.get_legend_handles_labels()
            axis.legend(handles=legend_handles[1:], labels=legend_labels[1:], prop={'size': 11}, loc="lower right")

        axis.set_xlabel('Hidden Layer Index', fontsize=14)
        axis.set_ylabel(y_axis_label, fontsize=15)
        # Ticks for hidden-layer indices 0..12.
        axis.set_xticks(np.arange(13))
        axis.set_xticklabels(np.arange(13), fontsize=xtick_fontsize)
        axis.tick_params(axis="y", labelsize=13)

    plt.suptitle("Tree Distance prediction", fontsize=14)
    plt.savefig(os.path.join(OUT_RESULTS_ROOT, "plot_tree_distance_hidden_indices.png"))
    plt.show()

In [None]:
# Paper figure for the POS-tag probes, only when that task was run.
if RUN_POS_TASK:
    plot_pos_tag_results_for_paper(data=ALL_POS_RESULTS)

In [None]:
# Paper figure for the structural (tree-distance) probes, only when that task was run.
if RUN_STRUCTURAL_TASK:
    plot_tree_distance_task_results_for_paper(tree_distance_results_frame=STRUCTURAL_PROBE_REGULAR_TASK_RESULTS)