In [1]:
# Uncomment when running on Colab to install the dependencies into the kernel's environment:
#%pip install transformers datasets torchinfo

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn.functional as F

from datasets import load_dataset
from torchinfo import summary
from torch.nn import Identity
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [3]:
def get_random_input(dataset, tokenizer, min_text_len=300, max_attempts=None):
    """Sample a tokenised input that is longer than ``min_text_len`` tokens.

    Two sampling modes:

    * ``dataset`` given: pick a uniformly random row of ``dataset["train"]`` and
      tokenise its ``"text"`` field.
    * ``dataset is None``: decode ``tokenizer.model_max_length`` token ids drawn
      uniformly from ``[0, tokenizer.vocab_size)`` and re-tokenise that text.

    Candidates are tokenised with ``truncation=True`` and re-drawn until one has
    more than ``min_text_len`` tokens. Randomness comes from torch's global RNG.

    Args:
        dataset: object indexable as ``dataset["train"][i]["text"]`` (e.g. a
            ``datasets.DatasetDict``), or ``None`` for random-token inputs.
        tokenizer: Hugging Face tokenizer (callable, with ``decode``,
            ``vocab_size`` and ``model_max_length``).
        min_text_len: strict lower bound on the number of tokens returned.
        max_attempts: maximum number of candidates to draw; ``None`` (default)
            keeps trying forever, as the original implementation did.

    Returns:
        Whatever ``tokenizer(..., return_tensors="pt", return_attention_mask=False)``
        returns for the accepted text; ``["input_ids"].shape[1]`` is its length.

    Raises:
        RuntimeError: if ``max_attempts`` candidates were drawn without success
            (e.g. ``min_text_len`` is not below the truncation length).
    """
    if dataset is not None:
        train = dataset["train"]
        num_rows = len(train)

        def draw_text():
            return train[torch.randint(num_rows, (1,)).item()]["text"]
    else:
        def draw_text():
            token_ids = torch.randint(tokenizer.vocab_size, (tokenizer.model_max_length,))
            return tokenizer.decode(token_ids)

    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        encoded = tokenizer(draw_text(), return_tensors="pt", truncation=True, return_attention_mask=False)
        if encoded["input_ids"].shape[1] > min_text_len:
            return encoded
    raise RuntimeError(
        f"no input longer than {min_text_len} tokens found in {max_attempts} attempts"
    )

In [4]:
def compute_correlations(hidden_states, queries, keys):
    """Per-layer cosine similarities between projected query and key vectors.

    For each ``(hidden_state, query, key)`` triple, the leading batch dimension is
    squeezed out, ``query`` and ``key`` are applied to the token representations,
    each projected row is L2-normalised, and all query/key dot products are taken.

    Args:
        hidden_states: sequence of tensors, presumably ``(1, seq_len, hidden)``;
            dim 0 is squeezed, so batch > 1 is not handled (TODO confirm).
        queries: sequence of callables (e.g. ``nn.Linear`` or ``nn.Identity``)
            applied row-wise to the squeezed hidden state.
        keys: same as ``queries``, for the key side.

    Returns:
        List with one flattened ``seq_len * seq_len`` similarity tensor per triple.
        ``zip`` stops at the shortest input, so surplus hidden states are ignored.
        Returned tensors carry no autograd history.
    """
    corrs = []
    # Pure analysis: without no_grad, nn.Linear projections would build and keep
    # an autograd graph for every layer.
    with torch.no_grad():
        for hidden_state, query, key in zip(hidden_states, queries, keys):
            tokens = hidden_state.squeeze(0).detach()
            q_unit = F.normalize(query(tokens), dim=1)
            k_unit = F.normalize(key(tokens), dim=1)
            corrs.append((q_unit @ k_unit.T).flatten())
    return corrs

In [5]:
def extract_queries_and_keys(model, replace_with_identity=False):
    """Return ``(queries, keys)``: one projection module per hidden layer.

    Both lists have ``model.config.num_hidden_layers`` entries.

    * ``replace_with_identity=True``: every entry is an ``Identity`` (one shared
      instance per list), so correlations are taken on the raw hidden states.
    * ``"albert"``: the query/key of ``albert_layer_groups[0].albert_layers[0]``
      repeated for every layer (ALBERT reuses that layer's weights).
    * ``"bert"``: each encoder layer's own self-attention query/key.

    Raises:
        NotImplementedError: for any other ``model.config.model_type``.
    """
    depth = model.config.num_hidden_layers
    if replace_with_identity:
        return [Identity()] * depth, [Identity()] * depth

    kind = model.config.model_type
    if kind == "albert":
        shared_attention = model.encoder.albert_layer_groups[0].albert_layers[0].attention
        return [shared_attention.query] * depth, [shared_attention.key] * depth
    if kind == "bert":
        blocks = model.encoder.layer
        query_modules = [block.attention.self.query for block in blocks]
        key_modules = [block.attention.self.key for block in blocks]
        return query_modules, key_modules
    raise NotImplementedError("Unsupported model type", kind)

In [6]:
def disable_dense_layers(model):
    model_type = model.config.model_type
    match model_type:
        case "albert":
            with torch.no_grad():
                model.encoder.albert_layer_groups[0].albert_layers[0].attention.dense.weight.fill_(0.0)
                model.encoder.albert_layer_groups[0].albert_layers[0].attention.dense.bias.fill_(0.0)
        case "bert":
            for i in range(len(model.encoder.layer)):
                with torch.no_grad():
                    model.encoder.layer[i].attention.output.dense.weight.fill_(0.0)
                    model.encoder.layer[i].attention.output.dense.bias.fill_(0.0)

In [18]:
def _freedman_diaconis_bins(values, fallback_bins=100):
    """Return a histogram bin count for ``values`` via the Freedman–Diaconis rule.

    The rule's bin width is ``2 * IQR / n**(1/3)``. If the IQR or the data range is
    zero the width degenerates, so ``fallback_bins`` is returned; otherwise the
    result is clamped to at least 1 (``plt.hist`` rejects 0 bins).
    """
    q25, q75 = np.percentile(values, [25, 75])
    iqr = q75 - q25
    span = values.max() - values.min()
    if iqr <= 0 or span <= 0:
        return fallback_bins
    bin_width = 2 * iqr / len(values) ** (1 / 3)
    return max(1, int(span / bin_width))


def _histogram_dir(model_name, dataset, use_queries_and_keys, no_dense_layers):
    """Build the output directory name encoding the input source and ablations."""
    out_dir = f"histograms/{model_name}"
    if dataset is None:
        out_dir += "_uniformtokens"
    else:
        # "_"-prefixed like every other suffix, so the dataset name is not glued
        # directly onto the model name.
        out_dir += "_" + dataset["train"].info.dataset_name
    if use_queries_and_keys:
        out_dir += "_querieskeys"
    if no_dense_layers:
        out_dir += "_nodense"
    return out_dir


def plot_histograms(dataset, model_id, use_queries_and_keys=False, no_dense_layers=False,
                    num_attention_heads=1, xlim=(-0.3, 1.05)):
    """Save one histogram of query/key cosine similarities per layer as a PDF.

    Loads tokenizer and model for ``model_id`` and optionally zeroes the attention
    output projections. It then runs one sampled input (``get_random_input``)
    through the model. For each layer, it saves a step histogram of the
    ``compute_correlations`` values to ``<out_dir>/histogram_layer_{i}.pdf``.

    Args:
        dataset: dataset with a ``"train"`` split, or ``None`` for uniformly
            random tokens.
        model_id: Hugging Face model id/path for both tokenizer and model.
        use_queries_and_keys: if True, project hidden states with each layer's
            query/key modules; otherwise compare raw hidden states (Identity).
        no_dense_layers: if True, call ``disable_dense_layers`` on the model first.
        num_attention_heads: forwarded to ``AutoModel.from_pretrained`` (default 1).
        xlim: x-axis limits of every plot. Cosine similarities lie in [-1, 1], so
            the default lower limit of -0.3 may crop negative values.

    All plots share one y-axis limit (the largest density over all layers, with
    the same bins that are plotted) so layers are visually comparable.
    """
    tokeniser = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, num_attention_heads=num_attention_heads)
    if no_dense_layers:
        disable_dense_layers(model)
    model_input = get_random_input(dataset, tokeniser)
    # Pure analysis: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(**model_input, output_hidden_states=True)
    queries, keys = extract_queries_and_keys(model, not use_queries_and_keys)
    correls = [
        corr.cpu().numpy()
        for corr in compute_correlations(output["hidden_states"], queries, keys)
    ]

    # Use the same bins for the y-limit and for plotting; previously the limit
    # came from 100-bin histograms and could clip the plotted peaks.
    layer_bins = [_freedman_diaconis_bins(values) for values in correls]
    max_density = max(
        (np.histogram(values, bins=bins, density=True)[0].max()
         for values, bins in zip(correls, layer_bins)),
        default=0.0,
    )

    out_dir = _histogram_dir(model.config._name_or_path, dataset, use_queries_and_keys, no_dense_layers)
    os.makedirs(out_dir, exist_ok=True)

    for i, (values, bins) in enumerate(zip(correls, layer_bins)):
        fig, ax = plt.subplots()
        ax.hist(values, bins=bins, density=True, histtype="step", color="#3658bf", linewidth=1.5)
        ax.set_title(f"Layer {i}", fontsize=16)
        ax.set_xlim(*xlim)
        ax.set_ylim(0, max_density)
        fig.savefig(f"{out_dir}/histogram_layer_{i}.pdf")
        plt.close(fig)  # many figures are produced in a sweep; free each one

In [8]:
# Natural-language input source for get_random_input. The histogram sweep at the
# end of the notebook passes None (uniform random tokens) instead of this dataset.
wikitext = load_dataset(path="wikitext", name="wikitext-103-v1")

In [9]:
MODEL_ID = "albert-xlarge-v2"  # alternative: "bert-large-uncased"
NUM_ATTN_HEADS = 1
SEED = 0
# get_random_input samples from torch's global RNG; seed it so that a fresh
# Restart & Run All draws the same inputs and reproduces the histograms.
torch.manual_seed(SEED)
tokeniser = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID, num_attention_heads=NUM_ATTN_HEADS)
print(model.config)

AlbertConfig {
  "_name_or_path": "albert-xlarge-v2",
  "architectures": [
    "AlbertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 8192,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 1,
  "num_hidden_groups": 1,
  "num_hidden_layers": 24,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.41.1",
  "type_vocab_size": 2,
  "vocab_size": 30000
}



In [10]:
# For ALBERT (MODEL_ID above), the summary lists a single AlbertLayer: its weights
# are shared, and the encoder calls that same layer num_hidden_layers times.
summary(model, depth=6)

Layer (type:depth-idx)                                  Param #
AlbertModel                                             --
├─AlbertEmbeddings: 1-1                                 --
│    └─Embedding: 2-1                                   3,840,000
│    └─Embedding: 2-2                                   65,536
│    └─Embedding: 2-3                                   256
│    └─LayerNorm: 2-4                                   256
│    └─Dropout: 2-5                                     --
├─AlbertTransformer: 1-2                                --
│    └─Linear: 2-6                                      264,192
│    └─ModuleList: 2-7                                  --
│    │    └─AlbertLayerGroup: 3-1                       --
│    │    │    └─ModuleList: 4-1                        --
│    │    │    │    └─AlbertLayer: 5-1                  --
│    │    │    │    │    └─LayerNorm: 6-1               4,096
│    │    │    │    │    └─AlbertAttention: 6-2         16,789,504
│    │    │    │    │ 

In [11]:
# Sanity check: show one sampled input from each source (dataset text, then random tokens).
for label, source in (("wikitext", wikitext), ("random", None)):
    print(f"{label} input:", get_random_input(source, tokeniser))

wikitext input: {'input_ids': tensor([[    2,    76,    14,   996,    16,    14, 15588,  5149,    23,  1117,
            19,   356,  2325,    13,    15,    21,   284, 15761,    16,  4569,
          1789,    37,   375,  3864,    13,     9,  3074,  6129,   324,   703,
            32,    37,    14, 24522,    13,     9,    14,   996,  3681,    53,
            16,    14,   324,   443,   377, 14757, 14165,    13,    15,    72,
            23,    21,  3372,    13,     1,     8,     1,   159,    13,     1,
             8,     1,   315,  1377,  9541,    13,     9, 14165,   260,    36,
           545,    28,    21,   375,   850,   778,    27,    21,   633,   298,
            19,  4733,   136,    17,    27,  2614, 20948,    13,    22,    18,
          4565, 10653,    13,     9,    36,   461,   201,    23,  2739,    34,
          5408,  4927,    13,    73,    36,  1144,  9542,   206,    89,  2088,
            13,    15,    17,    39,   478,  1269,    19,  4500,  2039,    13,
            15,    47,

In [19]:
# Every (model, projection, dense-ablation) combination, on uniform random tokens.
sweep = [
    (mid, qk, nodense)
    for mid in ["albert-xlarge-v2", "bert-large-uncased"]
    for qk in [True, False]
    for nodense in [True, False]
]
for model_id, use_queries_and_keys, no_dense_layers in sweep:
    plot_histograms(None, model_id, use_queries_and_keys, no_dense_layers)