In [None]:
import torch

In [23]:
def compute_correlations(hidden_states):
    """Return the flattened token-token cosine similarities of each hidden state.

    Args:
        hidden_states: iterable of tensors, one per layer.  Dimension 0 is
            squeezed away when it has size 1, and the rows of the result are
            treated as token vectors.

    Returns:
        A list with one 1-D CPU tensor per input: the row-normalised matrix
        multiplied by its own transpose, flattened.
    """
    def _flat_gram(layer_output):
        # Work on a gradient-free view; nothing below is modified in place.
        tokens = layer_output.detach().squeeze(0)
        unit_rows = torch.nn.functional.normalize(tokens, dim=1)
        gram = torch.matmul(unit_rows, unit_rows.transpose(0, 1))
        return gram.flatten().cpu()

    return [_flat_gram(layer_output) for layer_output in hidden_states]

In [28]:
from datasets import load_dataset

# Load the WikiText-103 (v1) corpus; its 'train' split is the pool that
# get_random_input draws passages from.
wikitext = load_dataset("wikitext", 'wikitext-103-v1')

In [30]:
def get_random_input(dataset, tokenizer, min_tokens=300, max_tries=10_000):
    """Sample a random training passage that encodes to more than `min_tokens` tokens.

    Rows of ``dataset['train']`` are drawn uniformly at random (with
    replacement) and tokenized until one is long enough.

    Args:
        dataset: mapping with a ``'train'`` split whose rows have a ``'text'`` field.
        tokenizer: callable accepting ``(text, return_tensors='pt', truncation=True)``
            and returning a mapping with an ``'input_ids'`` tensor of shape
            ``(1, n_tokens)``.
        min_tokens: the encoding must be strictly longer than this many tokens.
        max_tries: upper bound on the number of draws, so the search cannot
            spin forever when no passage can qualify.

    Returns:
        The tokenizer output for the first sufficiently long passage found.

    Raises:
        RuntimeError: if no draw within `max_tries` attempts is long enough.
    """
    train = dataset['train']
    n_rows = len(train)
    for _ in range(max_tries):
        idx = torch.randint(n_rows, (1,)).item()
        text = train[idx]['text']
        # truncation=True lets the tokenizer cap the length, so a `min_tokens`
        # at or above that cap can never be met -- hence the bounded loop.
        encoded = tokenizer(text, return_tensors='pt', truncation=True)
        if encoded['input_ids'].shape[1] > min_tokens:
            return encoded
    raise RuntimeError(
        f"no passage longer than {min_tokens} tokens found in {max_tries} random draws"
    )

In [34]:
from transformers import AlbertTokenizer, AlbertModel
# Pretrained ALBERT xlarge-v2 tokenizer and model. These module-level objects
# are also used as the default arguments of the plotting functions defined
# below, so this cell has to run before those definitions.
al_tkz = AlbertTokenizer.from_pretrained('albert-xlarge-v2')
al_model = AlbertModel.from_pretrained("albert-xlarge-v2")
import matplotlib.pyplot as plt

In [35]:
def plot_histograms(al_tkz = al_tkz, al_model = al_model):
    """Draw a grid of token-correlation histograms, one panel per hidden state.

    A random passage is sampled with ``get_random_input(wikitext, al_tkz)``,
    printed, and passed through ``al_model`` with ``output_hidden_states=True``.
    The flattened similarity values returned by ``compute_correlations`` are
    then drawn as 100-bin density step histograms.

    Args:
        al_tkz: tokenizer used to encode and decode the passage.
        al_model: model to run; ``config.num_hidden_layers`` selects the grid size.
    """
    ei = get_random_input(wikitext, al_tkz)
    print(al_tkz.batch_decode(ei['input_ids']))

    # Inference only: no need to record an autograd graph for the forward pass.
    with torch.no_grad():
        of = al_model(**ei, output_hidden_states=True)
    correls = compute_correlations(of['hidden_states'])

    # Keep the 5x5 grid (fewer than 25 layers) or the 7x7 grid otherwise, but
    # enlarge it when there are more hidden states than panels, which would
    # previously have raised an IndexError.
    side = 5 if al_model.config.num_hidden_layers < 25 else 7
    while side * side < len(correls):
        side += 1
    fig, axes = plt.subplots(side, side)
    axes = axes.flatten()

    for i, values in enumerate(correls):
        axes[i].hist(values, bins=100, density=True, histtype='step')
        axes[i].set_title(f'Layer {i}')
        axes[i].set_xlim(-.3, 1)

In [36]:
def plot_histograms_new(al_tkz = al_tkz, al_model = al_model, step=1):
    """Draw token-correlation histograms for every `step`-th hidden state.

    A random passage is sampled with ``get_random_input(wikitext, al_tkz)``,
    printed, and passed through ``al_model`` with ``output_hidden_states=True``.
    Hidden states 0, step, 2*step, ... each get one panel showing a 100-bin
    density step histogram of the values from ``compute_correlations``.

    Args:
        al_tkz: tokenizer used to encode and decode the passage.
        al_model: model to run; ``config.num_hidden_layers`` selects the grid size.
        step: stride between plotted hidden states (1 plots all of them).

    Raises:
        ValueError: if `step` is smaller than 1.
    """
    if step < 1:
        raise ValueError("step must be a positive integer")

    ei = get_random_input(wikitext, al_tkz)
    print(al_tkz.batch_decode(ei['input_ids']))

    # Inference only: no need to record an autograd graph for the forward pass.
    with torch.no_grad():
        of = al_model(**ei, output_hidden_states=True)
    correls = compute_correlations(of['hidden_states'])

    # Indices of the hidden states that actually get a panel.
    layer_ids = range(0, len(correls), step)

    # Keep the 5x5 / 7x7 grids, growing only if the panels would not fit.
    side = 5 if al_model.config.num_hidden_layers < 25 else 7
    while side * side < len(layer_ids):
        side += 1
    fig, axes = plt.subplots(side, side)
    axes = axes.flatten()

    for ax, layer in zip(axes, layer_ids):
        ax.hist(correls[layer], bins=100, density=True, histtype='step')
        ax.set_title(f'Layer {layer}')
        ax.set_xlim(-.3, 1)

In [80]:
import matplotlib.pyplot as plt
import os
import numpy as np

def _freedman_diaconis_bins(values, fallback=100):
    """Bin count for the NumPy array `values` by the Freedman-Diaconis rule.

    The bin width is ``2 * IQR / n ** (1/3)`` and the count is the data range
    divided by that width.  `fallback` is returned when the rule is undefined
    (empty input, zero or NaN width, non-finite ratio).
    """
    n = values.size
    if n == 0:
        return fallback
    q75, q25 = np.percentile(values, [75, 25])
    bin_width = 2 * (q75 - q25) / n ** (1/3)
    if not bin_width > 0:
        return fallback
    ratio = (values.max() - values.min()) / bin_width
    if not np.isfinite(ratio):
        return fallback
    return max(1, int(ratio))


def plot_histograms_save(al_tkz = al_tkz, al_model = al_model):
    """Save one token-correlation histogram per hidden state as a PDF.

    A random passage is sampled with ``get_random_input(wikitext, al_tkz)``,
    printed, and passed through ``al_model`` with ``output_hidden_states=True``.
    For every hidden state, the values from ``compute_correlations`` are drawn
    as a density step histogram with Freedman-Diaconis binning and written to
    ``histograms/histogram_layer_<i>.pdf``.  All figures share one y-axis
    limit: the highest density reached by any of the plotted histograms.

    Args:
        al_tkz: tokenizer used to encode and decode the passage.
        al_model: model to run.
    """
    ei = get_random_input(wikitext, al_tkz)
    print(al_tkz.batch_decode(ei['input_ids']))

    # Inference only: no need to record an autograd graph for the forward pass.
    with torch.no_grad():
        of = al_model(**ei, output_hidden_states=True)
    correls = compute_correlations(of['hidden_states'])

    # Create a directory to save the plots
    out_dir = 'histograms'
    os.makedirs(out_dir, exist_ok=True)

    # Convert once to NumPy so min/max/percentile are vectorised instead of
    # iterating over the tensors element by element.
    layers = [np.asarray(data) for data in correls]
    bin_counts = [_freedman_diaconis_bins(values) for values in layers]

    # Global maximum density, measured with the SAME binning that is plotted,
    # so the shared y-limit cannot cut off a peak.
    max_density = 0
    for values, bins in zip(layers, bin_counts):
        if values.size:
            counts, _ = np.histogram(values, bins=bins, density=True)
            max_density = max(max_density, counts.max())

    for i, (values, bins) in enumerate(zip(layers, bin_counts)):
        fig, ax = plt.subplots()
        ax.hist(values, bins=bins, density=True, histtype='step', color='#3658bf', linewidth=1.5)
        ax.set_title(f'Layer {i}', fontsize=16)
        ax.set_xlim(-.3, 1.05)
        if max_density > 0:
            ax.set_ylim(0, max_density)  # Set a consistent y-axis limit

        # Write into the directory created above rather than the working directory.
        fig.savefig(os.path.join(out_dir, f'histogram_layer_{i}.pdf'))
        plt.close(fig)

In [77]:
# Save the per-layer histograms using the default tokenizer and model
# (al_tkz / al_model); the sampled passage is printed as the cell output.
plot_histograms_save()

['[CLS] more influential than the 707 in the 698\'s design was wind <unk> -<unk> tunnel testing performed by the royal aircraft establishment at farnborough, which indicated the need for a wing redesign to avoid the onset of compressibility drag which would have restricted the maximum speed. painted gloss white, the 698 prototype vx770 flew for the first time on 30 august 1952 piloted by roly falk flying solo. the prototype 698, then fitted with only the first <unk> -<unk> pilot\'s ejection seat and a conventional control wheel, was powered by four rolls <unk> -<unk> royce<unk> avon engines of 6 <unk>,<unk> 500 lbf ( 29 kn ) thrust ; there were no wing fuel tanks, temporary tankage was carried in the bomb bay. vx770 made an appearance at the 1952 society of british aircraft constructors\'( sbac ) farnborough air show the next month when falk demonstrated an " almost vertical bank ". after its farnborough appearance, the future name of the avro 698 was a subject of speculation ; avro ha

In [40]:
# Display the model configuration. The layer index 24 and the 2048-wide
# tensors used in later cells match num_hidden_layers and hidden_size here.
al_model.config

AlbertConfig {
  "_name_or_path": "albert-xlarge-v2",
  "architectures": [
    "AlbertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 8192,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 16,
  "num_hidden_groups": 1,
  "num_hidden_layers": 24,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.35.2",
  "type_vocab_size": 2,
  "vocab_size": 30000
}

In [41]:
### Check norm of output tokens
## The norm is not exactly the same for every token because the LayerNorm
## applied at the end also has a trainable diagonal matrix \gamma and a
## vector \beta, which act on each token x (a row vector) as
##   \tilde x = (x - mean(x)) / sqrt(var(x)) * \gamma + \beta

ei = get_random_input(wikitext, al_tkz)
print(al_tkz.batch_decode(ei['input_ids']))
of = al_model(**ei, output_hidden_states=True)

# Variance along the last axis of hidden state 24: one value per token.
of['hidden_states'][24].var(dim=2)

['[CLS] good girl gone bad received generally favorable reviews from music critics. at metacritic, which assigns a normalized rating out of 100 to reviews from mainstream critics, the album received an average score of 72 based on 16 reviews. uncut called it a " shiny, trans <unk> -<unk> atlantic blend of europop vim, r & b grit and caribbean bounce. " andy kellman of allmusic deemed it quintessential pop music and said each of its tracks was a potential hit. quentin b. huff of popmatters praised the album, describing it as " more raw, perhaps edgier and more risque " than rihanna\'s previous material. kelefa sanneh of the new york times wrote that the album " sounds as if it were scientifically engineered to deliver hits ". peter robinson of the observer commended her collaborators for " masking her own shortcomings " and commented that, " while rihanna lacks her peers\'charisma, she\'s a great vessel for exhilarating mainstream pop. " pitchfork media\'s tom breihan found the album va

tensor([[0.0592, 0.4956, 0.4641, 0.4537, 0.4165, 0.4996, 0.4574, 0.4767, 0.4346,
         0.4034, 0.4104, 0.4587, 0.4052, 0.3656, 0.4476, 0.5080, 0.3986, 0.4170,
         0.4497, 0.4734, 0.3707, 0.4909, 0.4570, 0.4843, 0.5048, 0.4008, 0.4060,
         0.4881, 0.4060, 0.4376, 0.4798, 0.5249, 0.4535, 0.4052, 0.2066, 0.3566,
         0.4137, 0.5008, 0.4496, 0.5264, 0.4705, 0.3368, 0.5207, 0.5106, 0.4099,
         0.5347, 0.4313, 0.3955, 0.2690, 0.4626, 0.4959, 0.4806, 0.3649, 0.3898,
         0.3426, 0.5175, 0.5143, 0.4293, 0.3444, 0.4788, 0.4385, 0.4711, 0.4582,
         0.4975, 0.4641, 0.4985, 0.3992, 0.3434, 0.4310, 0.4645, 0.4713, 0.4111,
         0.3608, 0.4134, 0.4767, 0.4510, 0.4775, 0.4210, 0.5187, 0.4750, 0.3716,
         0.3880, 0.4137, 0.4524, 0.4256, 0.4855, 0.4672, 0.4891, 0.4497, 0.4760,
         0.4717, 0.3782, 0.4810, 0.5045, 0.3662, 0.4107, 0.3648, 0.4224, 0.4506,
         0.3694, 0.3726, 0.4153, 0.3911, 0.3735, 0.5712, 0.4215, 0.3516, 0.1956,
         0.4956, 0.5215, 0.3

In [78]:
from transformers import AlbertConfig, AlbertModel

# Build a deeper variant from the same checkpoint: the config is overridden to
# 48 hidden layers and a single attention head, then the pretrained weights
# are loaded into a model with that config.
# NOTE(review): presumably this loads cleanly because the layer weights are
# shared across depth (num_hidden_groups is 1 in the config shown earlier) -- confirm.
alm2_config = AlbertConfig.from_pretrained("albert-xlarge-v2", num_hidden_layers=48, num_attention_heads=1);
almodel2 = AlbertModel.from_pretrained("albert-xlarge-v2", config=alm2_config);
# Sanity check that the override took effect.
print(almodel2.config.num_hidden_layers)

48


In [81]:
#%matplotlib inline
# Save the per-layer histograms for the 48-layer, single-head variant
# (the default tokenizer al_tkz is still used).
plot_histograms_save(al_model=almodel2)
# Alternative left disabled: the on-screen grid view of the same model.
#plot_histograms(al_model = almodel2)

["[CLS] on september 15, obama attended a philadelphia fundraising dinner for specter, an unusually public declaration of support so early in the primary season, when the president has the option of remaining neutral until the final outcome becomes clearer. governor rendell said that obama and biden felt obligated to strongly support specter because they so strongly lobbied him to switch parties. philadelphia mayor michael nutter and radio personality michael<unk> also spoke on specter's behalf. senate majority leader reid took the unusual steps of scheduling no senate votes that day so both specter and pennsylvania senator bob casey, jr. could attend the fundraiser. that move drew criticism from republicans, as well as from sestak, who felt specter was skirting his senate responsibilities, yet hypocritically criticizing sestak at the same time for missing more than 100 votes in the u.s. house. the event was expected to raise about $ 2 <unk>.<unk> 5 million, which was to be split betwe

In [50]:
# Decomposing AlbertModel
al_transfo = al_model.encoder;
al_layer = al_transfo.albert_layer_groups[0].albert_layers[0];
al_attention = al_layer.attention; (al_attention.pruned_heads, al_attention.num_attention_heads, al_attention.all_head_size, al_attention.hidden_size)

(set(), 16, 2048, 2048)

In [51]:
# Sample a fresh passage, echo its decoded text, and keep the model outputs
# (including every hidden state) in `of` for the inspection cells below.
ei = get_random_input(wikitext, al_tkz)
print(al_tkz.batch_decode(ei['input_ids']))
of = al_model(**ei, output_hidden_states=True)

['[CLS] grunge is generally characterized by a sludgy guitar sound with a " thick midrange " and rolled <unk> -<unk> off treble tone and a high level of distortion and fuzz, typically created with small<unk> pedals, with some guitarists chaining several fuzz pedals together and plugging them into a tube amplifier. the use of pedals by grunge guitarists was a move away from the expensive, studio <unk> -<unk> grade<unk> effects units used in other rock genres. grunge guitarists played loud, with kurt cobain\'s early guitar sound coming from an unusual set <unk> -<unk> up of four 800 watt pa system power amplifiers. guitar feedback effects were used. grunge guitarists were influenced by the raw, primitive sound of punk, and they favored "... energy and lack of finesse over technique and precision " ; key guitar influences included the sex pistols, the dead boys, neil young ( rust never sleeps, side two ), the replacements, husker du, black flag and the melvins. grunge guitarists often dow

In [52]:
# Shape of one intermediate hidden state (index 2). For the passage sampled
# above this is (1, 413, 2048): one sequence, 413 tokens, 2048 features.
of['hidden_states'][2].shape

torch.Size([1, 413, 2048])

In [53]:
# Probe transpose_for_scores with a tensor whose last axis is simply 0..2047,
# so the way features are assigned to heads can be read off the values.
test = torch.arange(2048).reshape(1, 1, -1)
print(test.shape)
# No trailing semicolon: it would suppress the cell output, and the reshaped
# size is exactly what this cell is meant to show.
al_attention.transpose_for_scores(test).shape

torch.Size([1, 1, 2048])


In [54]:
# Entries seen by head index 1: the output below is 128..255, i.e. the second
# contiguous block of 128 features of the 0..2047 probe vector.
al_attention.transpose_for_scores(test)[0,1,0,:]

tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])

In [55]:
# Display the weight parameter of the attention module's `value` sub-module.
al_attention.value.weight

Parameter containing:
tensor([[ 0.0182,  0.0070, -0.0499,  ...,  0.0107, -0.0472,  0.0081],
        [-0.0077, -0.0219,  0.0191,  ...,  0.0136, -0.0058,  0.0144],
        [-0.0141, -0.0141,  0.0194,  ...,  0.0275,  0.0075, -0.0503],
        ...,
        [ 0.0342,  0.0207,  0.0073,  ...,  0.0331, -0.0137,  0.0200],
        [ 0.0521,  0.0029,  0.0261,  ...,  0.0122,  0.0592, -0.0423],
        [ 0.0113,  0.0855, -0.0034,  ..., -0.0119,  0.0079, -0.0166]],
       requires_grad=True)