# Installing packages and downloading pretrained model

In [None]:
#!pip install transformers
#!pip install datasets

In [None]:
#import google
from pprint import pprint

from collections import defaultdict

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, BertTokenizer
#from datasets import list_datasets, list_metrics, load_dataset, load_metric

torch.set_grad_enabled(False)

import itertools
import functools
import pandas as pd
from scipy import stats 
from scipy import corrcoef
from scipy.spatial.distance import cosine, euclidean, pdist, squareform, is_valid_dm, cdist
from sklearn.metrics import pairwise_distances
from scipy.stats import spearmanr
from scipy.spatial import distance_matrix

#Visualization packages
import seaborn as sns
import matplotlib.pylab as plt

In [None]:
class CachingWrapperLayer(torch.nn.Module):
    """
        A pytorch layer that wraps an existing layer and
        stores its output.

        Every forward pass appends the wrapped layer's output, converted to
        a detached CPU numpy array, to ``data_store`` and then returns the
        original (unconverted) output, so the surrounding model keeps
        working on the value it would have received without the wrapper.

        ``data_store`` is never emptied by this class; callers that run many
        forward passes are responsible for clearing it.
    """
    def __init__(self, wrapped, *args, **kwargs):
        """
            wrapped: the layer whose output should be recorded.
            *args, **kwargs: passed on unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # One numpy array per forward pass, oldest first.
        self.data_store = []
        self.wrapped = wrapped

    def forward(self, *inputs, **kwargs):
        """
            Run the wrapped layer on the given arguments, record a numpy
            copy of its output and return the output itself.
        """
        # Call the module instance rather than its ``forward`` method so that
        # hooks registered on the wrapped layer still run.
        result = self.wrapped(*inputs, **kwargs)
        self.data_store.append(result.detach().cpu().numpy())
        return result


In [None]:
# Code copied from captum/attr/_models/base.py

def _get_deep_layer_name(obj, layer_names):
    r"""
    Traverses through the layer names that are separated by
    dot in order to access the embedding layer.
    """
    return functools.reduce(getattr, layer_names.split("."), obj)


def _set_deep_layer_value(obj, layer_names, value):
    r"""
    Traverses through the layer names that are separated by
    dot in order to access the embedding layer and update its value.
    """
    layer_names = layer_names.split(".")
    setattr(functools.reduce(getattr, layer_names[:-1], obj), layer_names[-1], value)


In [None]:
# Name of the pretrained checkpoint used throughout the notebook
MODEL_NAME = "bert-base-uncased" #@param
#MODEL_NAME = "gpt2" # doesn't quite work?

# Load the model (asked to also return hidden states and attentions)
# and the tokenizer that belongs to the same checkpoint
model = AutoModel.from_pretrained(
    MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Accessing the hidden states (contextualized embeddings)

In [None]:
def get_head_data(wrapper_cache, layer, head, num_heads=12):
    """
        Return one head's slice of the most recently cached output of a layer.

        wrapper_cache: mapping from layer number to a
            ``(layer_name, wrapper, orig)`` tuple; only ``wrapper.data_store``
            is read here.
        layer: key into `wrapper_cache`.
        head: 1-based head number, ``1 <= head <= num_heads``.
        num_heads: number of equal parts the last cached array is split into
            along axis 2 (defaults to 12, the value that was previously
            hard-coded).

        Raises IndexError if `head` is outside ``1..num_heads``; without the
        explicit check ``head=0`` or a negative head would silently wrap
        around and return a head counted from the end.
    """
    if not 1 <= head <= num_heads:
        raise IndexError(f"head must be in 1..{num_heads}, got {head}")
    _layer_name, wrapper, _orig = wrapper_cache[layer]
    # Only the newest entry is used: the output of the latest forward pass.
    value = wrapper.data_store[-1]
    return np.split(value, num_heads, axis=2)[head - 1]

def combine_output_for_layers(model, inputs, wrapper_cache, word_groups, head_group):
    """
        Build one vector per word from the cached outputs of the requested
        (layer, head) pairs.

        model, inputs: not used; kept so existing callers do not break.
        wrapper_cache: mapping passed through to ``get_head_data``.
        word_groups: iterable of index arrays; each array selects, along
            axis 1 of the cached data, the tokens that make up one word.
        head_group: iterable of ``(layer, head)`` pairs whose data is summed.

        For every word, the selected tokens are averaged (axis 1) for each
        (layer, head) pair, the per-pair results are summed, and size-1
        dimensions are squeezed away. The per-word results are stacked
        along a new first axis.
        Assumes the cached arrays are laid out as (batch, tokens, features)
        with a batch of one - TODO confirm against the wrapped layer.
    """
    # The head slices do not depend on the word, so fetch them only once.
    head_slices = [
        get_head_data(wrapper_cache, layer, head)
        for layer, head in head_group
    ]

    word_vectors = []
    for token_ids_word in word_groups:
        # Mean over the word's tokens, separately for every requested head
        per_head = np.stack([
            head_slice[:, token_ids_word].mean(axis=1)
            for head_slice in head_slices
        ])
        # Sum the requested heads into a single vector for this word
        word_vectors.append(per_head.sum(axis=0).squeeze())

    # Stack all words in the sentence
    return np.stack(word_vectors)

# Retrieving word representations from separate or combined layers

In [None]:
# Example inputs that the cells below run through the model
sentences = [
    "Do not compare apples and oranges and apples and apples and apples and apples.",
    "another sentence.",
    "space before .",
]

In [None]:

vecs = []
group_result = defaultdict(list)
sent_words = []

# Each inner list is one group of (layer, head) pairs, both counted from 1.
# The heads of one group are summed into a single vector per word.
of_interest = [
        [(1,1),(1,2),(1,3)],
        [(4,1)],
        [(8,2)],
        [(12,3)]
    ]

is_gpt2 = MODEL_NAME in ["gpt2", "gpt2-medium", "gpt2-large"]

wrapper_cache = {}
try:
    # Replace the module at `...attention.self.value` of every requested layer
    # with a caching wrapper, remembering the original so it can be put back.
    for layer in set(t[0] for t in itertools.chain.from_iterable(of_interest)):
        layer_name = f"encoder.layer.{layer - 1}.attention.self.value"
        orig = _get_deep_layer_name(model, layer_name)
        wrapper = CachingWrapperLayer(orig)
        _set_deep_layer_value(model, layer_name, wrapper)
        wrapper_cache[layer] = (layer_name, wrapper, orig)

    for sent in sentences:
        encoded = tokenizer(sent, return_tensors="pt")
        inputs = encoded.input_ids
        attention_mask = encoded['attention_mask']
        output = model(input_ids=inputs, attention_mask=attention_mask)
        states = output.hidden_states  # not used by the code below
        token_len = attention_mask.sum().item()
        decoded = np.array(tokenizer.convert_ids_to_tokens(inputs[0], skip_special_tokens=False))

        # Word id per token; tokens without a word id are marked with -1.
        word_ids = np.array([-1 if e is None else e for e in encoded.word_ids()])
        if is_gpt2:
            # All tokens are kept; "Ġ" at the start of a token is stripped.
            first, last, marker = 0, token_len, "Ġ"
        else:
            # The first and last token are skipped (presumably [CLS]/[SEP]);
            # "##" at the start of a token is stripped.
            first, last, marker = 1, token_len - 1, "##"
        word_indices = word_ids[first:last]

        # Split the token positions where a new word id first appears. The
        # first split point is position 0, so the leading chunk is empty and
        # is dropped by the trailing [1:].
        word_starts = np.unique(word_indices, return_index=True)[1]
        word_groups = np.split(np.arange(word_indices.shape[0]) + first, word_starts)[1:]
        sent_words.append([
            "".join(t[len(marker):] if t.startswith(marker) else t for t in decoded[g])
            for g in word_groups
        ])

        for n, head_group in enumerate(of_interest):
            sent_vec = combine_output_for_layers(model, inputs, wrapper_cache, word_groups, head_group)
            group_result[n].append(sent_vec)

        # Only the latest forward pass is ever read, so drop the cached
        # outputs instead of letting them pile up across sentences.
        for layer_name, wrapper, orig in wrapper_cache.values():
            wrapper.data_store = []
finally:
    # Always put the original layers back, even if something above failed,
    # so the model is never left patched.
    for layer_name, wrapper, orig in wrapper_cache.values():
        _set_deep_layer_value(model, layer_name, orig)

# One array per head group: the word vectors of all sentences, concatenated
vecs = [np.concatenate(r) for r in group_result.values()]

# Compute and visualize dissimilarity matrices

In [None]:
# For every entry of `vecs`: Euclidean distance between each pair of its rows,
# rounded to one decimal place
distance_matrices = [
    np.round(cdist(vec, vec, metric="euclidean"), decimals=1)
    for vec in vecs
]

In [None]:
# Tick labels: the words of all sentences, in order
labels = [word for words in sent_words for word in words]

# (axes that stay visible, index into distance_matrices, subplot title)
plot_data = [
    ("y", 0, "transformer layer 1, heads 1-3"),
    ("", 1, "transformer layer 4, head 1"),
    ("xy", 2, "transformer layer 8, head 2"),
    ("x", 3, "transformer layer 12, head 3"),
]


def _format_cell(value):
    """Cell text: no decimals for zero or for values whose string form is
    longer than four characters, one decimal otherwise."""
    if value == 0 or len(str(value)) > 4:
        return f"{value:.0f}"
    return f"{value:.1f}"


fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(18, 18))
fig.suptitle('(euclidean) distance matrices between the embeddings for each word in the sentence')

for subplot, (label_axes, matrix_index, title) in zip(axes.flat, plot_data):
    matrix = distance_matrices[matrix_index]
    annotations = np.array([[_format_cell(v) for v in row] for row in matrix])
    sns.heatmap(
        matrix,
        ax=subplot,
        linewidth=1,
        annot=annotations,
        annot_kws={"size":8},
        fmt="",
        cmap='magma_r',
        xticklabels=labels,
        yticklabels=labels,
    )
    subplot.set_title(title)
    # Hide every axis that is not listed for this subplot
    for axis_name in "xy":
        if axis_name not in label_axes:
            getattr(subplot, f"{axis_name}axis").set_visible(False)
fig.tight_layout()
plt.show()


The visualization nicely shows that the occurrences of 'apples' in the example sentence receive very similar embeddings in layer 1, and that these are very different from the embedding of 'oranges' (in fact, the embeddings for 'apples' are of course identical in the input embeddings). The 'apples' at the end of the sentence then behaves differently as we move higher up through the BERT layers, and the difference between the other occurrences of 'apples' and 'oranges' becomes smaller.

Using these distance matrices (DMs), it should be rather straightforward to run the Mantel test, and hopefully also to run this on a larger set of input sentences.