In [2]:
import os
import torch
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from datasets import load_dataset # Hugging Face Datasets
from transformers import AlbertTokenizer, AlbertModel, AlbertConfig

In [3]:
# Load the WikiText-103 corpus through Hugging Face Datasets, using the
# "wikitext-103-v1" configuration. Later cells draw random documents from
# its "train" split.
wikitext = load_dataset("wikitext", "wikitext-103-v1")

### Normal Model (24 layers)

In [4]:
# Tokenizer and encoder come from the same pretrained checkpoint, so the name
# is spelled once and shared by both loads.
ALBERT_CHECKPOINT = "albert-xlarge-v2"
al_tkz = AlbertTokenizer.from_pretrained(ALBERT_CHECKPOINT)
al_model = AlbertModel.from_pretrained(ALBERT_CHECKPOINT)

In [5]:
def compute_correlations(hidden_states):
    """Return, for every layer, the flattened row-by-row cosine similarities.

    Each entry of ``hidden_states`` is squeezed on dim 0, detached from the
    autograd graph, and its rows are scaled to unit L2 norm. The product of
    that matrix with its own transpose therefore holds the cosine similarity
    between every pair of rows.

    Args:
        hidden_states: iterable of tensors, one per layer.
            (assumes each is (1, tokens, features) -- TODO confirm with caller)

    Returns:
        A list with one 1-D CPU tensor per layer, containing the flattened
        similarity matrix of that layer.
    """

    def _flat_cosine_gram(layer_output):
        # Work on a detached copy so the caller's tensor is never touched.
        rows = layer_output.squeeze(0).detach().clone()
        unit_rows = torch.nn.functional.normalize(rows, dim=1)
        gram = unit_rows @ unit_rows.transpose(0, 1)
        return gram.flatten().cpu()

    return [_flat_cosine_gram(layer_output) for layer_output in hidden_states]

In [6]:
def get_random_input(
    dataset,
    tokenizer,
    min_tokens=300,
    split="train",
    max_tries=10_000,
    generator=None,
):
    """Draw a random text from ``dataset[split]`` that tokenizes to a long input.

    Rows are sampled uniformly (with replacement) until one of them encodes to
    strictly more than ``min_tokens`` token ids.

    Args:
        dataset: mapping from split name to an indexable collection of rows,
            where each row exposes its text under the ``"text"`` key.
        tokenizer: callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True)`` whose
            result exposes ``["input_ids"]`` with the token count on axis 1.
        min_tokens: exclusive lower bound on the number of token ids.
        split: name of the split to sample from.
        max_tries: maximum number of samples to draw before giving up;
            ``None`` retries indefinitely (the historical behaviour).
        generator: optional ``torch.Generator`` that makes the draw
            reproducible; ``None`` uses torch's global RNG.

    Returns:
        The tokenizer output for the first sampled text that is long enough.

    Raises:
        ValueError: if the selected split is empty.
        RuntimeError: if no sufficiently long text is found in ``max_tries`` draws.
    """
    rows = dataset[split]
    n_rows = len(rows)
    if n_rows == 0:
        raise ValueError(f"dataset split {split!r} is empty")

    attempts = 0
    # Bounded retry loop: an unbounded `while True` hangs the kernel when the
    # split contains no text longer than the threshold.
    while max_tries is None or attempts < max_tries:
        attempts += 1
        idx = torch.randint(n_rows, (1,), generator=generator).item()
        encoded = tokenizer(rows[idx]["text"], return_tensors="pt", truncation=True)
        if encoded["input_ids"].shape[1] > min_tokens:
            return encoded

    raise RuntimeError(
        f"no text with more than {min_tokens} tokens found in "
        f"{max_tries} draws from split {split!r}"
    )

In [7]:
def plot_histograms(al_tkz=al_tkz, al_model=al_model):
    """
    Plot the cosine-similarity histogram of every layer in one grid figure.

    A random long text is drawn from ``wikitext``, printed, and passed through
    ``al_model`` once. One histogram is drawn per entry of the returned hidden
    states. The grid is the smallest near-square layout that fits all of them
    (5x5 for 25 entries, 7x7 for 49), so any layer count is supported.

    Args:
        al_tkz: tokenizer used to encode and decode the sampled text.
        al_model: model called with ``output_hidden_states=True``.
    """
    ei = get_random_input(wikitext, al_tkz)
    print(al_tkz.batch_decode(ei["input_ids"]))
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        of = al_model(**ei, output_hidden_states=True)
    correls = compute_correlations(of["hidden_states"])

    # Size the grid from the number of histograms actually produced instead of
    # hard-coding it per model depth.
    n_plots = len(correls)
    n_cols = max(1, int(np.ceil(np.sqrt(n_plots))))
    n_rows = max(1, int(np.ceil(n_plots / n_cols)))
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    axes = axes.flatten()

    for i, data in enumerate(correls):
        axes[i].hist(data, bins=100, density=True, histtype="step")
        axes[i].set_title(f"Layer {i}")
        axes[i].set_xlim(-0.3, 1)

    # Hide grid cells that have no layer to show.
    for unused_ax in axes[n_plots:]:
        unused_ax.set_visible(False)

In [9]:
def _freedman_diaconis_bins(values, fallback=100):
    """Bin count from the Freedman-Diaconis rule, or ``fallback`` when the rule is undefined."""
    q25, q75 = np.percentile(values, [25, 75])
    iqr = q75 - q25
    span = values.max() - values.min()
    # A zero IQR or zero range would give a zero bin width (division by zero).
    if not (iqr > 0 and span > 0):
        return fallback
    bin_width = 2 * iqr / len(values) ** (1 / 3)
    n_bins = span / bin_width
    if not np.isfinite(n_bins):
        return fallback
    # A range narrower than one bin width would otherwise truncate to 0 bins.
    return max(1, int(n_bins))


def plot_histograms_save(al_tkz=al_tkz, al_model=al_model, out_dir="histograms"):
    """
    Save one cosine-similarity histogram per layer as a separate PDF.

    A random long text is drawn from ``wikitext``, printed, and passed through
    ``al_model`` once. For every entry of the returned hidden states a figure
    is written to ``<out_dir>/histogram_layer_<i>.pdf``. All figures share the
    same y-axis limit, taken from the tallest bin across all layers using the
    same binning that is plotted.

    Args:
        al_tkz: tokenizer used to encode and decode the sampled text.
        al_model: model called with ``output_hidden_states=True``.
        out_dir: directory receiving the PDF files; created if missing.
    """
    ei = get_random_input(wikitext, al_tkz)

    print('- '*20)
    print("Random Input:")
    print(al_tkz.batch_decode(ei["input_ids"]))
    print('- '*20)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        of = al_model(**ei, output_hidden_states=True)
    # Convert once to NumPy so percentiles, min and max are vectorised instead
    # of iterating over tensor elements in Python.
    layers = [np.asarray(c) for c in compute_correlations(of["hidden_states"])]

    # Create a directory to save the plots
    os.makedirs(out_dir, exist_ok=True)

    bin_counts = [_freedman_diaconis_bins(data) for data in layers]

    # Global maximum density, measured with the per-layer binning used for the
    # plots so that no peak is clipped by the shared y-limit.
    max_density = max(
        (
            np.histogram(data, bins=bins, density=True)[0].max()
            for data, bins in zip(layers, bin_counts)
        ),
        default=0,
    )

    for i, (data, bins) in enumerate(zip(layers, bin_counts)):
        fig, ax = plt.subplots()
        ax.hist(
            data,
            bins=bins,
            density=True,
            histtype="step",
            color="#3658bf",
            linewidth=1.5,
        )
        ax.set_title(f"Layer {i}", fontsize=16)
        ax.set_xlim(-0.3, 1.05)
        ax.set_ylim(0, max_density)  # Set a consistent y-axis limit

        fig.savefig(os.path.join(out_dir, f"histogram_layer_{i}.pdf"))
        # Close explicitly so figures do not accumulate across layers.
        plt.close(fig)


In [10]:
# Export the per-layer histograms using the function's default tokenizer and
# model (the objects bound to al_tkz / al_model when it was defined).
plot_histograms_save()

- - - - - - - - - - - - - - - - - - - - 
Random Input:
['[CLS] in march 1940 , hancock \'s directorate of works and buildings was transferred from the office of the chief of the air staff to the newly formed organisation and equipment branch under air marshal richard williams . considered a key part of the air force \'s expansion during the early part of world war ii , " works and bricks " quickly absorbed all staff with civil engineering and building experience in the raaf active reserve . as director , hancock was responsible for surveying and developing a military aerodrome at evans head , near the queensland and new south wales border , which became home to no. 1 bombing and gunnery school ( no. 1<unk> ) . promoted to wing commander , he held command of no. 1<unk> , operating fairey battle single <unk> -<unk> engined bombers , from august 1940 until november 1941 . he was promoted to acting group captain in april 1941 . appointed an officer of the order of the british empire ( obe 

In [24]:
# al_model.config

In [12]:
## Check the norm of the output tokens
# The norms are not exactly identical because the LayerNorm applied at the end
# also has a trainable diagonal matrix \gamma and a trainable vector \beta,
# which are applied to each token x (a row vector) as follows:
# \tilde x = (x - mean(x)) / sqrt(var(x)) * \gamma + \beta

# ei = get_random_input(wikitext, al_tkz)
# print(al_tkz.batch_decode(ei["input_ids"]))
# of = al_model(**ei, output_hidden_states=True)
# of["hidden_states"][24].var(2)

### Larger Model (48 layers)

In [13]:
# Start from the "albert-xlarge-v2" configuration but override the number of
# hidden layers (48) and the number of attention heads (1), then load the
# pretrained checkpoint into a model built with that configuration.
# NOTE(review): num_attention_heads=1 differs from the 16 heads reported for
# the default model in the "Decomposing Internal Structure" cells, while the
# section title only mentions the layer count -- confirm this is intentional.
alm2_config = AlbertConfig.from_pretrained(
    "albert-xlarge-v2", num_hidden_layers=48, num_attention_heads=1
)
almodel2 = AlbertModel.from_pretrained("albert-xlarge-v2", config=alm2_config)
# Sanity check that the layer-count override is reflected in the loaded model.
print(almodel2.config.num_hidden_layers)

48


In [14]:
# %matplotlib inline
plot_histograms_save(al_model=almodel2)
# plot_histograms(al_model = almodel2)

- - - - - - - - - - - - - - - - - - - - 
Random Input:
['[CLS] " i \'m her slave " is a heroin anecdote with lyrics narrated by a subjugate lover . music critic greg kot cites " turn on the water " as an example of when " the twisted narrator is the victim " and " cast adrift " in dulli \'s lyrics . inspired by a paranoid breakup , " conjure me " is told from the perspective of an aggressive predator and obscure object of desire . on " kiss the floor " , the narrator recounts stealing a girl \'s virginity and avoiding her brothers . " this is my confession " has a theme of absolution . the lyrics depicts it as an empty sexual experience : " shove my head against the door , crawl inside and kiss the floor / waiting for the sun again , drink it , smoke it , stick it in . " " the temple " is a cover of the song of the same name from the 1970 rock opera jesus christ superstar . dulli became a fan of the rock opera as a child when his babysitter played it . " let me lie to you " has lyrics 

### Decomposing Internal Structure

This section is an exploratory/debugging pass over the model internals. It covers:

- Understanding the internal structure of the ALBERT model
- Testing how the attention mechanism processes inputs
- Examining the shapes and transformations of tensors through the model
- Looking at the actual weight matrices used in the attention mechanism

This kind of exploration is common when trying to understand or debug transformer models, as it helps to verify that the internal components are working as expected and to understand how the model processes information at a detailed level.

In [15]:
# Decomposing AlbertModel
al_transfo = al_model.encoder
al_layer = al_transfo.albert_layer_groups[0].albert_layers[0]
al_attention = al_layer.attention
(
    al_attention.pruned_heads,
    al_attention.num_attention_heads,
    al_attention.all_head_size,
    al_attention.hidden_size,
)

(set(), 16, 2048, 2048)

In [19]:
# Draw one random long sample, show its decoded tokens, and run the model once
# while keeping the hidden states of every layer in `of`.
ei = get_random_input(wikitext, al_tkz)
print(al_tkz.batch_decode(ei["input_ids"]))
# Inference only: no_grad avoids building (and holding on to) the autograd graph.
with torch.no_grad():
    of = al_model(**ei, output_hidden_states=True)

['[CLS] the subject matter of the carvings of the central brackets as misericords is very varied , but with many common themes recurring in different churches . typically , the themes are less unified and less directly related to the bible and christian theology than are the themes of small sculptures seen elsewhere within churches , such as those on bosses . this is much the case at wells , where none of the misericord carvings is directly based on a bible story . the subjects , chosen either by the woodcarver , or perhaps by the individual paying for the stall , have no over <unk>-<unk> riding theme . the sole unifying element is the roundels on each side of the pictorial subject , which are all elaborately carved foliage , in most cases formal and stylised in the later decorated manner , but with several examples of naturalistic foliage , including roses and<unk> . many of the subjects carry traditional interpretations . the image of the " pelican in her piety " ( believed to feed h

In [20]:
# Shape of the hidden state stored at index 2 of the model output.
# assumes the axes are (batch, tokens, hidden size) -- TODO confirm
of["hidden_states"][2].shape

torch.Size([1, 321, 2048])

In [21]:
# Probe tensor of shape (1, 1, 2048) whose last-axis values are 0..2047, so
# every value identifies the feature position it started from.
test = torch.arange(2048).reshape(1, 1, -1)
print(test.shape)
# The trailing ";" suppresses this expression's output, so only the printed
# shape above is shown.
# NOTE(review): drop the ";" if the rearranged shape was meant to be displayed.
al_attention.transpose_for_scores(test).shape;

torch.Size([1, 1, 2048])


In [22]:
# Slice [0, 1, 0, :] of the rearranged probe tensor. The output lists the
# values 128..255, i.e. index 1 on the second axis holds the second contiguous
# block of 128 features.
# presumably the second axis indexes attention heads -- TODO confirm
al_attention.transpose_for_scores(test)[0, 1, 0, :]

tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])

In [23]:
# Display the weight parameter of the attention module's `value` projection
# (PyTorch abbreviates the printed matrix with "...").
al_attention.value.weight

Parameter containing:
tensor([[ 0.0182,  0.0070, -0.0499,  ...,  0.0107, -0.0472,  0.0081],
        [-0.0077, -0.0219,  0.0191,  ...,  0.0136, -0.0058,  0.0144],
        [-0.0141, -0.0141,  0.0194,  ...,  0.0275,  0.0075, -0.0503],
        ...,
        [ 0.0342,  0.0207,  0.0073,  ...,  0.0331, -0.0137,  0.0200],
        [ 0.0521,  0.0029,  0.0261,  ...,  0.0122,  0.0592, -0.0423],
        [ 0.0113,  0.0855, -0.0034,  ..., -0.0119,  0.0079, -0.0166]],
       requires_grad=True)