In [1]:
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, BertTokenizer

torch.set_grad_enabled(False)

import itertools
import numpy as np
import pandas as pd
from scipy import stats 
from scipy import corrcoef
from scipy.spatial.distance import cosine, euclidean, pdist, squareform, is_valid_dm, cdist
from sklearn.metrics import pairwise_distances
from scipy.stats import spearmanr
from scipy.spatial import distance_matrix
import torch

#Visualization packages
import seaborn as sns
import matplotlib.pylab as plt

# Implementing [All Bark and No Bite: Rogue Dimensions in Transformer Language Models Obscure Representational Quality](https://arxiv.org/abs/2109.04404)

NumPy implementation of the formulas in the paper. The quoted text is copied from the paper for context.

> [Ethayarajh (2019)](https://aclanthology.org/D19-1006/) defines the anisotropy in layer $\ell$ of model $f$ as the expected cosine similarity of any pair of words in a corpus.
This can be approximated as $\hat{A}(f_{\ell})$ from a sample $S$ of $n$ random token pairs from a corpus $\mathcal{O}$.

> $$S = \{\{x_1,y_1\},...,\{x_n,y_n\}\} \sim \mathcal{O}$$

In [2]:
SEED = 0  # fixed seed so the synthetic sample, and every number derived from it, is reproducible
n = 300000 # number of samples
d = 100 # vector dimension
# Synthetic stand-in for the sample S: n pairs of d-dimensional vectors,
# every component drawn uniformly from [0, 1).
rng = np.random.default_rng(SEED)
S = rng.random((n, 2, d))

> $$\hat{A}(f_{\ell}) = \frac{1}{n} \sum_{\{x_\alpha,y_\alpha\} \in S} \cos( f_{\ell}(x_\alpha), f_{\ell}(y_\alpha) )$$

In [3]:
def anisotropy(S):
    """Estimate the anisotropy as the mean cosine similarity over the pairs in S.

    S is iterated pair by pair; each item must unpack into two vectors
    (fx, fy) that `cos_uv` accepts.

    The average is taken over the pairs actually passed in, not over a
    notebook-global sample count, so the result stays correct for samples of
    any size and is unaffected by other cells rebinding `n`.
    """
    return np.mean([cos_uv(fx, fy) for fx, fy in S])

> $$ \cos(u, v) = \frac{u \cdot v} {{\lVert u \rVert}{\lVert v \rVert}} = \sum_{i=1}^d \frac{ u_i v_i}{{\lVert u \rVert}{\lVert v \rVert}} $$

In [4]:
def cos_uv(u, v):
    """Cosine similarity of two vectors: u·v / (‖u‖ ‖v‖)."""
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return u.dot(v) / norm_product

> $$ CC_i(u, v) = \frac{ u_i v_i}{{\lVert u \rVert}{\lVert v \rVert}} $$

In [5]:
def cci_uv(u, v):
    """Per-dimension contributions to the cosine similarity.

    Returns the elementwise product u * v divided by ‖u‖ ‖v‖, so the entries
    add up to the cosine similarity of u and v.
    """
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return u * v / norm_product

> $$ CC(f^i_{\ell}) = \frac{1}{n} \sum_{\{x_\alpha,y_\alpha\} \in S} CC_i( f_{\ell}(x_\alpha), f_{\ell}(y_\alpha) ) $$

In [6]:
def cc_fl(S):
    """Average per-dimension cosine contribution over the pairs in S.

    S is iterated pair by pair; each item must unpack into two vectors
    (fx, fy) that `cci_uv` accepts. Returns one value per vector dimension.

    The contributions are stacked into an explicit (pairs, d) array. No
    `squeeze` is applied, so a sample with a single pair (or d == 1) keeps
    its pair axis and the sum over axis 0 still yields a length-d vector.
    The average uses the number of pairs actually passed in rather than a
    notebook-global sample count.
    """
    contributions = np.array([cci_uv(fx, fy) for fx, fy in S])
    return (1 / len(S)) * contributions.sum(axis=0)

> Note that $\sum_{i=1}^{d} CC(f^i_{\ell}) = \hat{A}(f_{\ell})$.

In [7]:
# Sanity check: the per-dimension contributions should add up to the anisotropy.
np.isclose(anisotropy(S), cc_fl(S).sum())

True

In [8]:
# Time the loop-based anisotropy estimate on the full sample.
%time anisotropy(S)

CPU times: user 2.9 s, sys: 23.6 ms, total: 2.92 s
Wall time: 2.91 s


0.7506107466202372

In [9]:
%%time
# Time the loop-based per-dimension contributions on the full sample.
cc_fl(S)

CPU times: user 3.5 s, sys: 129 ms, total: 3.63 s
Wall time: 3.51 s


array([0.00749749, 0.00750954, 0.00748971, 0.00751329, 0.00748488,
       0.00750173, 0.00751275, 0.00750584, 0.00750578, 0.00749085,
       0.00750159, 0.0074942 , 0.00750036, 0.00750421, 0.00749298,
       0.00750623, 0.0075318 , 0.00751165, 0.00749315, 0.00751508,
       0.00750017, 0.00751986, 0.00748115, 0.00750584, 0.00749662,
       0.00750519, 0.00751023, 0.00750669, 0.00750762, 0.00750954,
       0.00751803, 0.00749991, 0.00747264, 0.00750444, 0.00753296,
       0.00748529, 0.00750528, 0.00751205, 0.00750904, 0.00750049,
       0.00750696, 0.00748823, 0.00753414, 0.00750516, 0.00751942,
       0.0075053 , 0.00751044, 0.00751752, 0.00750244, 0.00750581,
       0.00751689, 0.00750857, 0.00750777, 0.00750515, 0.0075133 ,
       0.00749719, 0.00750302, 0.00748344, 0.00749711, 0.00749886,
       0.0074961 , 0.00750853, 0.00751339, 0.00749931, 0.00750472,
       0.0075191 , 0.00751134, 0.00751343, 0.00752263, 0.00751289,
       0.0074892 , 0.0074945 , 0.0075044 , 0.00750736, 0.00750

## Vectorized version for speed

Instead of looping over the samples, do the calculations with matrices.

In [10]:
def norm(S):
    """Product of the two vector norms of every pair in S.

    S is indexed as (pair, member, dimension). Norms are taken along the last
    axis and multiplied across the two members; the kept trailing axis of
    length 1 lets the result broadcast against per-dimension values.
    """
    per_vector = np.linalg.norm(S, axis=2, keepdims=True)
    return per_vector.prod(axis=1)


def ccfl(S):
    """Vectorized average per-dimension cosine contribution over all pairs in S.

    S is indexed as (pair, member, dimension). Returns one value per dimension.
    """
    pair_count = S.shape[0]
    elementwise = np.prod(S, axis=1)        # u_i * v_i for every pair
    contributions = elementwise / norm(S)   # divide by ‖u‖ ‖v‖ pair-wise
    return (1 / pair_count) * contributions.sum(axis=0)

In [11]:
# Time the vectorized version; the values are sorted before being displayed.
%time np.sort(ccfl(S))

CPU times: user 235 ms, sys: 44.2 ms, total: 279 ms
Wall time: 278 ms


array([0.00747264, 0.00748115, 0.00748344, 0.00748488, 0.00748529,
       0.00748823, 0.0074892 , 0.00748971, 0.00749085, 0.00749273,
       0.00749298, 0.00749315, 0.0074942 , 0.0074945 , 0.00749575,
       0.0074961 , 0.00749662, 0.00749711, 0.00749719, 0.00749749,
       0.00749819, 0.00749886, 0.00749931, 0.00749991, 0.00750017,
       0.00750036, 0.00750049, 0.00750128, 0.00750159, 0.00750173,
       0.00750239, 0.00750244, 0.00750283, 0.00750302, 0.00750305,
       0.00750386, 0.00750421, 0.0075044 , 0.00750444, 0.00750472,
       0.00750487, 0.00750515, 0.00750516, 0.00750519, 0.00750528,
       0.0075053 , 0.00750559, 0.00750578, 0.00750581, 0.00750584,
       0.00750584, 0.00750623, 0.00750627, 0.00750639, 0.00750669,
       0.00750696, 0.00750736, 0.00750762, 0.00750777, 0.00750853,
       0.00750857, 0.00750898, 0.00750899, 0.00750904, 0.00750954,
       0.00750954, 0.00750994, 0.00751019, 0.00751023, 0.00751044,
       0.00751103, 0.00751134, 0.00751165, 0.00751205, 0.00751

In [12]:
# The vectorized contributions must sum to the loop-based anisotropy ...
print(np.isclose(anisotropy(S), ccfl(S).sum()))
# ... and must match the loop-based contributions dimension by dimension.
print(np.isclose(cc_fl(S),ccfl(S)))

True
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]


## Testing it on some data

In [13]:
# Name of the pretrained checkpoint to analyse
MODEL_NAME = "bert-base-uncased" #@param
#MODEL_NAME = "gpt2" # doesn't quite work?

# Load the matching tokenizer, then the model itself; the model is asked to
# return its hidden states and attention weights alongside the regular output.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model_kwargs = dict(output_hidden_states=True, output_attentions=True)
model = AutoModel.from_pretrained(MODEL_NAME, **model_kwargs)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
def _input_embedding_layer(model):
    """Return the module that maps token ids to the model's input word embeddings."""
    getter = getattr(model, "get_input_embeddings", None)
    if callable(getter):
        return getter()
    # Fallback for objects without that accessor: the two attribute layouts
    # this notebook originally handled by model name.
    if hasattr(model, "wte"):
        return model.wte
    return model.embeddings.word_embeddings


def combine_output_for_layers(model, inputs, states, word_groups, layers):
    """Build one vector per word by pooling sub-word tokens and summing layers.

    Parameters
    ----------
    model : the transformer; only used to look up its input embedding layer.
    inputs : token-id tensor for one sentence
        (assumed shape (1, seq_len), as produced by the tokenizer call in this
        notebook; TODO confirm for batched use).
    states : sequence of hidden-state tensors indexed by layer
        (assumed shape (1, seq_len, hidden) each).
    word_groups : sequence of index collections; each holds the token
        positions that make up one word.
    layers : iterable of layer indices to combine. An index > 0 reads
        `states[i]`; an index <= 0 reads the raw input word embeddings instead.

    Returns
    -------
    Tensor with one row per word: the mean over the word's tokens, summed
    over the requested layers. An empty `word_groups` yields a tensor with
    zero rows instead of raising inside `torch.stack`.
    """
    layer_ids = list(layers)

    if len(word_groups) == 0:
        return states[0].new_empty((0, states[0].shape[-1]))

    # The embedding lookup does not depend on the word, so run it once and
    # only when some requested layer actually needs it.
    static_embeddings = None
    if any(i <= 0 for i in layer_ids):
        static_embeddings = _input_embedding_layer(model)(inputs)

    def layer_output(i):
        return states[i].detach() if i > 0 else static_embeddings

    # Stack all words in the sentence
    sent_tokens_output = torch.stack([
        # Average the tokens of the word per layer, then sum the requested layers
        torch.stack([
            layer_output(i)[:, token_ids_word].mean(axis=1)
            for i in layer_ids
        ]).sum(axis=0).squeeze()
        for token_ids_word in word_groups
    ])
    return sent_tokens_output

In [15]:
# Load the corpus and record the whitespace-delimited word count of every line.
sentences = pd.read_csv("lines.csv")
sentences['length'] = sentences['line'].str.split().map(len)

# Summary statistics of the line lengths, restricted to lines under 100 words.
short_lines = sentences.loc[sentences['length'] < 100]
display(short_lines['length'].describe())

count    3913.000000
mean       20.652952
std         9.943947
min         1.000000
25%        13.000000
50%        20.000000
75%        27.000000
max        80.000000
Name: length, dtype: float64

In [16]:
from collections import defaultdict

# One single-layer group per hidden layer 1..12, as a column: [[1], [2], ..., [12]].
# NOTE(review): 12 is hard-coded; confirm it matches the loaded checkpoint's layer count.
layers = (np.arange(12) + 1).reshape(-1,1)
layer_result = defaultdict(list)   # position in `layers` -> per-sentence word-vector tensors
sent_words = []                    # reconstructed words, one list per sentence

# Tokenizer conventions differ per model family:
#   first_token / trim_end - how many positions to drop at the start / end of
#                            the sequence before grouping tokens into words
#   marker                 - token prefix stripped when gluing tokens back into words
if MODEL_NAME in ["gpt2", "gpt2-medium", "gpt2-large"]:
    first_token, trim_end, marker = 0, 0, "Ġ"
else:
    first_token, trim_end, marker = 1, 1, "##"

for sent in sentences['line']:
    encoded = tokenizer(sent, return_tensors="pt")
    inputs = encoded.input_ids
    attention_mask =  encoded['attention_mask']
    output = model(input_ids=inputs, attention_mask=attention_mask)
    states = output.hidden_states
    token_len = attention_mask.sum().item()
    decoded = tokenizer.convert_ids_to_tokens(inputs[0], skip_special_tokens=False)
    tokens = np.array(decoded)

    # Word id of every kept token; tokens without a word id are mapped to -1.
    word_ids = np.array([-1 if w is None else w for w in encoded.word_ids()])
    word_indices = word_ids[first_token:token_len - trim_end]

    # Split the token positions wherever a new word id first appears. The
    # leading chunk (everything before the first word start) is dropped.
    word_starts = np.unique(word_indices, return_index=True)[1]
    token_positions = np.arange(word_indices.shape[0]) + first_token
    word_groups = np.split(token_positions, word_starts)[1:]

    sent_words.append([
        "".join(t[len(marker):] if t.startswith(marker) else t for t in tokens[g])
        for g in word_groups
    ])

    # The index is deliberately NOT called `n`: that name is the global sample
    # count defined earlier and read by `anisotropy` / `cc_fl`.
    for layer_idx, layer_group in enumerate(layers):
        sent_vec = combine_output_for_layers(model, inputs, states, word_groups, layer_group)
        layer_result[layer_idx].append(sent_vec)

# One array per layer group: the word vectors of all sentences concatenated row-wise.
vecs = [np.concatenate(r) for r in layer_result.values()]


In [17]:
# Draw an even-sized random subset (about a fifth of all word vectors) without
# replacement, so the sampled words can be arranged into pairs.
word_count = vecs[0].shape[0]
sample_size = 2 * (word_count // 10)
sample = np.random.choice(word_count, sample_size, replace=False)

result = []
for layer_data in vecs:
    # Arrange the sampled vectors as (pair, member, dimension).
    data = layer_data[sample].reshape(sample_size // 2, 2, -1)
    cc = ccfl(data)
    ani = cc.sum()
    rel_cc = cc / ani
    # Three largest relative contributions, largest first, then the anisotropy itself.
    top_three = np.sort(rel_cc)[::-1][:3].round(3).tolist()
    result.append(top_three + [ani])

display(pd.DataFrame(result, index=layers.flatten(), columns = ["1", "2", "3", "𝐴̂(𝑓ℓ)"]))

Unnamed: 0,1,2,3,𝐴̂(𝑓ℓ)
1,0.589,0.077,0.026,0.137615
2,0.743,0.013,0.008,0.187472
3,0.716,0.01,0.008,0.186794
4,0.735,0.009,0.008,0.217777
5,0.754,0.014,0.009,0.254179
6,0.719,0.031,0.02,0.243871
7,0.593,0.113,0.025,0.237626
8,0.535,0.184,0.027,0.244794
9,0.487,0.187,0.027,0.229058
10,0.511,0.114,0.022,0.237548
