# Explore residuals from encoder layers' residual connections

__Objective:__ take a trained `TransformerClassifier` model and, for any of its transformer encoder layers, extract the residuals that pass through the layer's residual connection as well as the output of its attention block, then compare the two.

In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('../modules/')

from utilities import read_data
from pytorch_utilities import load_checkpoint
from models import TransformerClassifier, get_encoder_layer_residuals

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sns.set_theme()

%load_ext autoreload
%autoreload 2

In [None]:
# Data-loading settings.
SEED = 0  # forwarded to read_data when loading the dataset
N_VAL_SAMPLES = 1_000  # number of trailing samples of `leaves` kept for validation

# Checkpoint and dataset locations, relative to the working directory
# (normally the notebook's own folder).
MODEL_PATH = '../models/inverse_pretraining/classification_model_epoch_200.pt'
DATA_PATH = '../data/inverse_pretraining/labeled_data_fixed_4_4_1.0_0.00000.npy'

Load data.

In [None]:
# read_data returns seven items; of these, this notebook only uses `q`
# (passed as vocab_size to the model) and `leaves`.
q, k, sigma, epsilon, roots, leaves, rho = read_data(DATA_PATH, SEED)

# Validation split: the last N_VAL_SAMPLES rows of `leaves`, placed on the
# target device as int64 in a single .to() call (presumably int64 because
# the model uses them as token indices -- confirm against the model).
val_leaves = (
    torch.from_numpy(leaves[-N_VAL_SAMPLES:])
    .to(device=device, dtype=torch.int64)
)

Load model.

In [None]:
# Rebuild the architecture. These hyperparameters must match the checkpoint:
# load_state_dict is strict by default and raises on missing/unexpected keys
# or parameter shape mismatches.
model = TransformerClassifier(
    seq_len=leaves.shape[-1],
    embedding_size=128,
    n_tranformer_layers=4,
    n_heads=1,
    vocab_size=q,
    encoder_dim_feedforward=2048,
    positional_encoding=True,
    n_special_tokens=0,
    embedding_agg='flatten',
    decoder_hidden_sizes=[64],
    decoder_activation='relu',
    decoder_output_activation='identity',
)

# map_location remaps stored tensors onto `device`. Without it, a checkpoint
# saved on a GPU cannot be deserialized on a CPU-only machine, even though
# `device` falls back to the CPU when CUDA is unavailable.
torch_checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(torch_checkpoint['model_state_dict'])

model = model.to(device=device)

# Switch to evaluation mode (e.g. disables dropout); the trailing semicolon
# suppresses the notebook's echo of the module repr.
model.eval();

In [None]:
# Quick probe of the first encoder layer (index 0) on the validation leaves.
# The per-layer loop further down recomputes these for every layer.
residuals, attention_output = get_encoder_layer_residuals(
    leaves=val_leaves,
    layer_number=0,
    model=model,
)

In [None]:
import pandas as pd
import numpy as np

In [None]:
fig = plt.figure(figsize=(14, 6))

# One dict of summary statistics per encoder layer; converted to a DataFrame below.
norm_ratio_rows = []

for i in range(len(model.transformer_encoder.layers)):
    # Inference only: under no_grad no autograd graph is built, and the
    # tensors should not require grad when converted with .numpy() below,
    # independently of how get_encoder_layer_residuals handles gradients.
    with torch.no_grad():
        residuals, attention_output = get_encoder_layer_residuals(model=model, layer_number=i, leaves=val_leaves)

        # Ratio of L2 norms taken over the last dimension, flattened to a 1-D
        # array of ratios (assumes both tensors have the same shape -- TODO
        # confirm against get_encoder_layer_residuals).
        norm_ratios = (attention_output.norm(dim=-1) / residuals.norm(dim=-1)).cpu().numpy().ravel()

    norm_ratios_statistics = {
        'layer_number': i,
        'mean': norm_ratios.mean(),
        # ndarray.std uses ddof=0, i.e. the population standard deviation.
        'std': norm_ratios.std(),
        'median': np.median(norm_ratios),
        'percentile_5': np.percentile(norm_ratios, 5.),
        'percentile_95': np.percentile(norm_ratios, 95.)
    }

    norm_ratio_rows.append(norm_ratios_statistics)

    # Overlay one histogram per layer on the same axes.
    sns.histplot(
        norm_ratios,
        stat='density',
        alpha=0.6,
        label=f'Encoder layer {i}',
        bins=10
    )

plt.legend()
plt.title('Distribution of norm(attention output)/norm(residual) ratios', fontsize=14)

norm_ratios_data = pd.DataFrame(norm_ratio_rows)

# Last expression of the cell: Jupyter renders it as a table.
norm_ratios_data

In [None]:
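# Old way of computing the relative importance of residuals vs
# attention outputs: the element-wise ratio residual / attention output,
# averaged over the hidden dimension. The norm-ratio approach above is
# probably better - element-wise ratios blow up whenever an attention-output
# component is close to zero (presumably why this version drops values
# outside the 5th-95th percentile range before plotting).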
# component_ratios = residuals / attention_output

# fig = plt.figure(figsize=(14, 6))

# component_ratios_data = []

# for i in range(len(model.transformer_encoder.layers)):
#     residuals, attention_output = get_encoder_layer_residuals(model=model, layer_number=i, leaves=val_leaves)

#     component_ratios = (residuals / attention_output).mean(dim=-1).cpu().numpy().ravel()

#     component_ratios_statistics = {
#         'layer_number': i,
#         'mean': component_ratios.mean(),
#         'std': component_ratios.std(),
#         'median': np.median(component_ratios),
#         'percentile_5': np.percentile(component_ratios, 5.),
#         'percentile_95': np.percentile(component_ratios, 95.)
#     }

#     component_ratios_data.append(component_ratios_statistics)

#     component_ratios_filtered = component_ratios[
#         (component_ratios > component_ratios_statistics['percentile_5'])
#         & (component_ratios < component_ratios_statistics['percentile_95'])
#     ]
    
#     sns.histplot(
#         component_ratios_filtered[:1000],
#         stat='density',
#         alpha=0.6,
#         label=f'Encoder layer {i}'
#     )

#     plt.legend()
#     plt.title('Distribution of average [residual]/[attention output] ratios\n(averaged over components in the hidden dimension)', fontsize=14)

# component_ratios_data = pd.DataFrame(component_ratios_data)

# component_ratios_data