# What Does it Mean to be a Transformer? - Insights from a Theoretical Hessian Analysis

## Part 1: Setup

In [None]:
# Import libraries
from pathlib import Path
import pickle
import random
from typing import Callable, Sequence

from curvlinops import GGNLinearOperator, HessianLinearOperator, HutchinsonSquaredFrobeniusNormEstimator
import cv2 as cv
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import seed_everything
import scipy
from sklearn import linear_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm

from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Define constants
# Colours used for the value (V) and query (Q) Hessian blocks in the plots below.
V_COLOR ='#7570b3'
Q_COLOR ='#1b9e77'

# Seed passed to seed_everything, the model configs and the data generator.
SEED=1234
# Number of digits in each operand of the addition task.
N_DIGITS = 5
# Vocabulary size: the ten digit tokens 0-9 plus the '+' and '=' tokens below.
D_VOCAB = 12
PLUS_INDEX = 10
EQUALS_INDEX = 11

In [None]:
# Configure plots
# Route all figure text through LaTeX with a serif font; amsmath is loaded in
# the preamble (the axis labels below use \text{...}).
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
plt.rcParams['text.latex.preamble']=r"\usepackage{amsmath}"

# Legend styling: short handles, markers scaled to zero, small font.
mpl.rcParams['legend.handlelength'] = 1
mpl.rcParams['legend.markerscale'] = 0
mpl.rcParams['legend.fontsize'] = 10

## Part 2: Data Generator and Loss.
This section defines the loss function and the data generator. They are based on the setup from "Understanding Addition in Transformers" by Quirke and Barez.

In [None]:
# Loss functions

# Calculate the per-token probability by comparing a batch of prediction "logits" to answer "tokens"
def logits_to_tokens_loss(logits: np.array, tokens: np.array):

  # Adding 2 five-digit numbers gives a six-digit answer
  n_answer_digits = N_DIGITS+1

  # The addition answer digit token probabilities
  # The "+1" below is needed because each answer digit calculations occurs one token before the that answer digit's token is revealed.
  ans_logits = logits[:, -(n_answer_digits+1):-1]

  # Convert raw score (logits) vector into a probability distribution.
  # Emphasize the largest scores and suppress the smaller ones, to make them more distinguishable.
  ans_probs = F.log_softmax(ans_logits.to(torch.float64), dim=-1)

  max_indices = torch.argmax(ans_probs, dim=-1)

  # The addition answer digit tokens
  ans_tokens = tokens[:, -(n_answer_digits):]

  # Extract values from the ans_probs tensor, based on indices from the ans_tokens tensor
  ans_loss = torch.gather(ans_probs, -1, ans_tokens[:, :, None])[..., 0]
  # ans_loss = torch.gather(ans_logits.to(torch.float64), -1, ans_tokens[:, :, None])[..., 0]

  return ans_loss, max_indices

# Calculate the loss as the negative batch-mean of the per-token log-probabilities
def loss_fn(ans_loss: torch.Tensor) -> torch.Tensor:
    """Return the negated mean of ``ans_loss`` over its first (batch) axis.

    The result keeps the remaining axes, i.e. one value per answer position
    when given the output of ``logits_to_tokens_loss``.
    """
    return -ans_loss.mean(0)

In [None]:
# Define "iterator" data generator function. Invoked using next().
# Batch entries are in format XXXXX+YYYYY=ZZZZZZ e.g. 55003+80002=135005
# Note that answer has one more digit than the question
# Returns characteristics of each batch entry to aid later graphing
def data_generator(batch_size: int, n_digits: int, seed: int, device: str | torch.device = "cuda"):
    """Yield endless batches of tokenised addition problems.

    Args:
        batch_size: Number of problems per batch.
        n_digits: Number of digits in each operand.
        seed: Seeds both the torch RNG (digits, permutation) and a private
            Python RNG (the augmentation coin flips), so the stream of batches
            is a function of ``seed`` alone.
        device: Device the yielded tensors are moved to. Defaults to "cuda",
            matching the previous hard-coded ``.cuda()`` calls.

    Yields:
        A 6-tuple of int64 tensors on ``device``:
        ``(batch, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s)``.
        ``batch`` has ``3 * n_digits + 3`` columns (x, '+', y, '=', answer);
        the remaining tensors have ``n_digits`` columns and describe each
        digit column of the sum.
    """
    torch.manual_seed(seed)
    # Private RNG: the original drew these decisions from the global `random`
    # module, which `seed` did not control.
    rng = random.Random(seed)
    while True:
        # generate a batch of addition questions (answers calculated below)
        batch = torch.zeros((batch_size, 3 * n_digits + 3), dtype=torch.int64)
        x = torch.randint(0, 10, (batch_size, n_digits))
        y = torch.randint(0, 10, (batch_size, n_digits))

        # The UseSum9 task is compound and rare (6%) and so the hardest to learn.
        # For 20% of batches, we increase the MakeSum9 cases by 20%
        # UseSum9 also relies on MakeCarry1 (50%) from previous column.
        # So UseSum9 frequency is increased by 20% * 20% * 50% = 2%
        if rng.randint(1, 5) == 1:
            # Flat views share storage with x and y, so the in-place writes
            # below update x and y directly.
            x_flat = x.view(-1)
            y_flat = y.view(-1)

            num_elements_to_modify = int(0.20 * x.numel())
            indices_to_modify = torch.randperm(x_flat.numel())[:num_elements_to_modify]
            if rng.randint(1, 2) == 1:
                x_flat[indices_to_modify] = 9 - y_flat[indices_to_modify]
            else:
                y_flat[indices_to_modify] = 9 - x_flat[indices_to_modify]

        batch[:, :n_digits] = x
        batch[:, n_digits] = PLUS_INDEX
        batch[:, 1 + n_digits:1 + n_digits * 2] = y
        batch[:, 1 + n_digits * 2] = EQUALS_INDEX

        # These attributes are used for testing the model training progress
        base_adds = torch.zeros((batch_size, n_digits), dtype=torch.int64)
        make_carry1s = torch.zeros((batch_size, n_digits), dtype=torch.int64)
        sum9s = torch.zeros((batch_size, n_digits), dtype=torch.int64)
        use_carry1s = torch.zeros((batch_size, n_digits), dtype=torch.int64)
        use_sum9s = torch.zeros((batch_size, n_digits), dtype=torch.int64)

        # generate the addition question answers & other info for testing,
        # walking the digit columns from least to most significant
        for i in range(n_digits):
            # the column in the test attributes being updated
            test_col = n_digits - 1 - i

            base_add = batch[:, n_digits - 1 - i] + batch[:, 2 * n_digits - i]
            base_adds[:, test_col] = base_add

            sum9 = (base_add == 9)
            sum9s[:, test_col] = sum9

            # The carry consumed here is the one produced by the column to the right
            if i > 0:
                use_carry1s[:, test_col] = make_carry1s[:, test_col + 1]
            use_carry = use_carry1s[:, test_col]

            use_sum9s[:, test_col] = sum9 & use_carry

            digit_sum = base_add + use_carry1s[:, test_col]

            make_carry = (digit_sum >= 10)
            make_carry1s[:, test_col] = make_carry

            batch[:, -1 - i] = (digit_sum % 10)

        # Final (possible) carry to highest digit of the sum
        batch[:, -1 - n_digits] = make_carry1s[:, 0]

        yield tuple(
            t.to(device)
            for t in (batch, base_adds, make_carry1s, sum9s, use_carry1s, use_sum9s)
        )

In [None]:
class LossModule(nn.Module):
    """Loss module wrapping ``logits_to_tokens_loss`` and ``loss_fn``.

    ``forward`` returns a scalar: the negative log-probability of the correct
    answer tokens, averaged over the batch and over the answer positions.
    """

    def __init__(self):
        # Initialise nn.Module state before assigning any attribute.
        super().__init__()
        # NOTE(review): presumably read by the curvlinops linear operators to
        # learn how the loss is normalised across the batch - confirm.
        self.reduction = 'mean'

    def forward(self, outputs, targets):
        """Return the scalar loss for model ``outputs`` and token ``targets``."""
        per_token_log_probs, _ = logits_to_tokens_loss(outputs, targets)
        return loss_fn(per_token_log_probs).mean()

class MSELossModule(nn.Module):
    def __init__(self):
        self.reduction = 'mean'
        super(MSELossModule, self).__init__()

    def forward(self, outputs, targets):

        n_answer_digits = N_DIGITS+1

        # The addition answer digit token probabilities
        # The "+1" below is needed because each answer digit calculations occurs one token before the that answer digit's token is revealed.
        outputs = outputs[:, -(n_answer_digits+1):-1]
        targets = targets[:, -(n_answer_digits+1):-1]
        targets = torch.nn.functional.one_hot(targets, D_VOCAB)
        return torch.linalg.norm(outputs - targets) ** 2
        

## Part 3: Heterogeneity

In this section we demonstrate the heterogeneity of the Transformer Hessian and show how softmax influences it. The presented results correspond to Figures 1 and 4.

### Visualize the Hessian

In [None]:
# Seed the RNGs, then build a single-layer, single-head transformer without
# normalization layers for the Hessian visualisation.
seed_everything(SEED)

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 1,
    d_model = 16,
    d_head = 16,
    d_mlp = 4 * 16,
    act_fn = 'relu',
    normalization_type = None,
    d_vocab=D_VOCAB,
    d_vocab_out=D_VOCAB,
    # Sequence length of one problem: two operands, '+', '=', and the answer.
    n_ctx=3 * N_DIGITS + 3,
    init_weights = True,
    device="cuda",
    seed = SEED,
)
model = HookedTransformer(cfg)

# Number of batches, and problems per batch, used to evaluate the Hessian.
NUM_BATCH = 1 
BATCH_SIZE = 64

In [None]:
# Parameter names in the order in which their blocks should appear in the
# Hessian. Names that the model does not have are filtered out below.
params_order = ['embed.W_E', 'pos_embed.W_pos', 'blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V',
                'blocks.0.attn.W_O', 'blocks.0.ln2.w', 'blocks.0.ln2.b', 'blocks.0.mlp.W_in',  'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out',
                'blocks.0.mlp.b_out', 'ln_final.w', 'ln_final.b', 'unembed.W_U', 'unembed.b_U']
attention_params_order = ['blocks.0.attn.W_Q', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.W_O']

# Draw NUM_BATCH batches; the token tensor serves as both input and target.
ds = data_generator(BATCH_SIZE, N_DIGITS, SEED)
dataset_sample = []
for _ in range(NUM_BATCH):
  tokens = next(ds)[0]
  dataset_sample.append((tokens, tokens))

param_dict = {n:p for (n, p) in model.named_parameters()}
params = [param_dict[n] for n in params_order if n in param_dict]
params_attention = [param_dict[n] for n in attention_params_order if n in param_dict]
num_params = sum(p.numel() for p in params)
num_params_attention = sum(p.numel() for p in params_attention)
# Per-parameter sizes; used later to draw the block boundaries.
num_params_layer_all = [
    p.numel() for p in params
]
num_params_layer_attention = [
    p.numel() for p in params_attention
]


Hessian_linop = HessianLinearOperator(model, LossModule(), params, dataset_sample)
Hessian_linop_attention = HessianLinearOperator(model, LossModule(), params_attention, dataset_sample)

# Materialise the dense Hessians by applying the operators to the identity.
Hessian_mat = Hessian_linop @ np.eye(num_params).astype(Hessian_linop.dtype)
Hessian_mat_attention = Hessian_linop_attention @ np.eye(num_params_attention).astype(Hessian_linop_attention.dtype)

In [None]:
# Module-level inputs read by plot() below: one entry per panel.
matrices = [Hessian_mat, Hessian_mat_attention]
titles = ["Transformer Block", "Self-Attention"]
num_params_layer_group = [num_params_layer_all, num_params_layer_attention]

# Subplot grid and the width used for each square panel's figsize.
rows, columns = 1, 2
img_width = 7

plt.rcParams['font.size'] = 20

def logabs(mat: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
    """Return ``log10(|mat| + epsilon)`` element-wise.

    ``epsilon`` keeps exactly-zero entries finite (they map to
    ``log10(epsilon)``).
    """
    return np.log10(np.abs(mat) + epsilon)

def plot(
    transform: Callable[[np.ndarray], np.ndarray], transform_title: str | None = None
):
    """Visualize transformed curvature matrices using a shared domain.

    Reads the module-level ``matrices``, ``titles``, ``num_params_layer_group``,
    ``rows``, ``columns`` and ``img_width``.

    Args:
        transform: A transformation that will be applied to the matrices. Must
            accept a matrix and return a matrix of the same shape.
        transform_title: An optional string describing the transformation.
            Default: `None` (empty).

    Returns:
        Figure and axes of the created subplot.
    """
    # Transform every matrix once and reuse the result for the colour range
    # and for the image itself.
    transformed = [transform(mat) for mat in matrices]
    min_value = min(t.min() for t in transformed)
    max_value = max(t.max() for t in transformed)

    fig, axes = plt.subplots(
        nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
    )

    # Dashed white separators between parameter blocks
    style = {"color": "w", "lw": 0.5, "ls": "--", "alpha": 0.8}

    for idx, (ax, mat_t, title, num_params_layer) in enumerate(
        zip(axes.flat, transformed, titles, num_params_layer_group)
    ):
        ax.set_title(title, pad=20)
        # Smooth the image slightly before display.
        dst = cv.GaussianBlur(mat_t, (5, 5), sigmaX=100, sigmaY=100)
        img = ax.imshow(dst, vmin=min_value, vmax=max_value)
        ax.axis('off')

        # layer structure: skip boundaries that coincide with the matrix edge.
        # The edge is taken from the matrix being drawn, not from the global
        # parameter count, so it is also correct for the attention-only panel.
        matrix_dim = mat_t.shape[0]
        for pos in np.cumsum(num_params_layer):
            if pos not in (0, matrix_dim):
                # axhline/axvline span the full axes by default.
                ax.axhline(y=pos - 1, **style)
                ax.axvline(x=pos - 1, **style)

        # colorbar (shared by all panels, attached once after the last one)
        last = idx == len(matrices) - 1
        if last:
            cb = fig.colorbar(
                img, ax=axes.ravel().tolist(), label=transform_title, shrink=0.8
            )
            cb.set_label(transform_title, labelpad=20)

    return fig, axes

In [None]:
# savefig does not create missing directories, so make sure it exists first.
Path('./figures').mkdir(parents=True, exist_ok=True)
plot(logabs, transform_title="Logarithmic Absolute Entries")
plt.savefig('./figures/heterogeneity.pdf')
plt.show()

### Compute block histogram

In [None]:
FONTSIZE = 18
plt.rcParams['font.size'] = FONTSIZE

def plot_hists(
    matrix1: np.ndarray,
    matrix2: np.ndarray,
    key: str,
    matrix3: np.ndarray | None = None,
    matrix4: np.ndarray | None = None,
):
    """Plot log-binned histograms of two matrices' entries and save the figure.

    Args:
        matrix1: Entries for the first panel (drawn in ``V_COLOR``).
        matrix2: Entries for the second panel (drawn in ``Q_COLOR``).
        key: Layout selector. ``'classical'`` stacks the panels vertically with
            a shared x-axis; ``'linear'`` places them side by side with a
            shared y-axis. Also used in the output file name.
        matrix3: Optional entries drawn semi-transparently behind ``matrix1``.
        matrix4: Optional entries drawn semi-transparently behind ``matrix2``.

    Raises:
        ValueError: If ``key`` is neither ``'classical'`` nor ``'linear'``.
    """
    if key == 'classical':
        # 2 rows, 1 column, shared x-axis
        n_rows, n_cols, figsize, share = 2, 1, (5.5, 4.5), {'sharex': True}
        # NOTE(review): bins_opaque is -1 here, which np.logspace rejects, so
        # a matrix4 overlay is effectively unsupported in this layout - confirm.
        bins, bins_opaque = 30, -1
    elif key == 'linear':
        # 1 row, 2 columns, shared y-axis
        n_rows, n_cols, figsize, share = 1, 2, (9.3, 3), {'sharey': True}
        bins, bins_opaque = 50, 30
    else:
        raise ValueError(f"Unknown key {key!r}; expected 'classical' or 'linear'")

    plt.close()

    def h(vals, color, alpha, ax, b):
        # Log-spaced bins from max(1e-8, min) to max; a zero minimum gives
        # log10 = -inf, which the max() clamps to -8.
        with np.errstate(divide='ignore'):
            lowest = np.log10(np.min(vals))
        logbins = np.logspace(max(-8, lowest), np.log10(np.max(vals)), b)
        ax.hist(vals, edgecolor='black', color=color, bins=logbins, alpha=alpha)

    fig, (ax1, ax2) = plt.subplots(n_rows, n_cols, figsize=figsize, **share)

    # Plot histogram for V block (optional faded overlay first)
    if matrix3 is not None:
        h(matrix3.flatten(), color=V_COLOR, alpha=0.4, ax=ax1, b=50)
    h(matrix1.flatten(), color=V_COLOR, alpha=1, ax=ax1, b=50)
    ax1.set_xscale('log')

    # Plot histogram for Q block (optional faded overlay first)
    if matrix4 is not None:
        h(matrix4.flatten(), color=Q_COLOR, alpha=0.4, ax=ax2, b=bins_opaque)
    h(matrix2.flatten(), color=Q_COLOR, alpha=1.0, ax=ax2, b=bins)
    ax2.set_xscale('log')

    if key == 'classical':
        ax2.set_xlabel('Absolute Entries')
        fig.supylabel('Frequency', x=0.07, fontsize=FONTSIZE)
    else:
        ax1.set_ylabel('Frequency')
        fig.supxlabel('Absolute Entries', y=0.17, fontsize=FONTSIZE)

    # Adjust layout to prevent overlapping
    plt.tight_layout()

    # Save and show the plot (create the output directory if needed)
    Path('./figures').mkdir(parents=True, exist_ok=True)
    plt.savefig(f'./figures/histogram_{key}.pdf', bbox_inches='tight')
    plt.show()

#### Classical self-attention

In [None]:
# Re-seed and rebuild the single-layer model (softmax attention) for the
# block histograms.
seed_everything(SEED)

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 1,
    d_model = 16,
    d_head = 16,
    d_mlp = 4 * 16,
    act_fn = 'relu',
    normalization_type = None,
    d_vocab=D_VOCAB,
    d_vocab_out=D_VOCAB,
    n_ctx=3 * N_DIGITS + 3,
    init_weights = True,
    device="cuda",
    seed = SEED,
)
model = HookedTransformer(cfg)

# 8 batches of 8 problems each.
NUM_BATCH = 8
BATCH_SIZE = 8
ds = data_generator(BATCH_SIZE, N_DIGITS, SEED)
dataset_sample = []

# The token tensor serves as both input and target.
for _ in range(NUM_BATCH):
  tokens = next(ds)[0]
  dataset_sample.append((tokens, tokens))

param_dict = {n:p for (n, p) in model.named_parameters()}

In [None]:
# Query histogram: dense Hessian with respect to W_Q only, materialised by
# applying the operator to the identity.
q_hessian_linop = HessianLinearOperator(model, LossModule(), [param_dict['blocks.0.attn.W_Q']], dataset_sample)
q_hessian_mat = q_hessian_linop @ np.eye(param_dict['blocks.0.attn.W_Q'].numel()).astype(q_hessian_linop.dtype)

# Value histogram: the same for W_V.
v_hessian_linop = HessianLinearOperator(model, LossModule(), [param_dict['blocks.0.attn.W_V']], dataset_sample)
v_hessian_mat = v_hessian_linop @ np.eye(param_dict['blocks.0.attn.W_V'].numel()).astype(v_hessian_linop.dtype)

In [None]:
# Histogram of absolute Hessian entries: value block first, query block second.
plot_hists(np.abs(v_hessian_mat), np.abs(q_hessian_mat), 'classical')

#### Linear

In [None]:
# Same setup as the classical cell, but with linear_attention=True.
seed_everything(SEED)

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 1,
    d_model = 16,
    d_head = 16,
    d_mlp = 4 * 16,
    act_fn = 'relu',
    normalization_type = None,  # alternatives tried previously: "LNPre", 'LN'
    # post_embedding_ln=True,
    d_vocab=D_VOCAB,
    d_vocab_out=D_VOCAB,
    n_ctx=3 * N_DIGITS + 3,
    init_weights = True,
    device="cuda",
    seed = SEED,
    # init_mode="muP",
    linear_attention=True
)
model = HookedTransformer(cfg)

# 8 batches of 8 problems each; the token tensor is both input and target.
NUM_BATCH = 8
BATCH_SIZE = 8
ds = data_generator(BATCH_SIZE, N_DIGITS, SEED)
dataset_sample = []
for _ in range(NUM_BATCH):
  tokens = next(ds)[0]
  dataset_sample.append((tokens, tokens))

param_dict = {n:p for (n, p) in model.named_parameters()}

In [None]:
# Query histogram: dense W_Q Hessian of the linear-attention model.
q_hessian_linop = HessianLinearOperator(model, LossModule(), [param_dict['blocks.0.attn.W_Q']], dataset_sample)
q_hessian_mat_lin = q_hessian_linop @ np.eye(param_dict['blocks.0.attn.W_Q'].numel()).astype(q_hessian_linop.dtype)

# Value histogram: dense W_V Hessian of the linear-attention model.
v_hessian_linop = HessianLinearOperator(model, LossModule(), [param_dict['blocks.0.attn.W_V']], dataset_sample)
v_hessian_mat_lin = v_hessian_linop @ np.eye(param_dict['blocks.0.attn.W_V'].numel()).astype(v_hessian_linop.dtype)

In [None]:
# Linear-attention histograms, with the classical ones (computed in the
# earlier cell) drawn semi-transparently behind them for comparison.
plot_hists(np.abs(v_hessian_mat_lin), np.abs(q_hessian_mat_lin), 'linear', np.abs(v_hessian_mat), np.abs(q_hessian_mat))

## Part 4: Growth Rates in Self-Attention Hessian
In this section we show the growth rates associated with the blocks of the self-attention Hessian. The results correspond to Figures 3, 5, 6, and 7.

### Generate and Save the Results

In [None]:
def generate_and_save_growth_rates(
    scale: str,
    norm: str,
    batch_size: int,
    num_batch: int,
    seed: int,
    n_digits: int,
    n_ctx: int,
    d_model: int,
    d_mlp: int,
    repeats: int,
    hvp: int,
    residual_scaling: float,
    num_layers_iter: Sequence,
    linear_attention: bool,
    loss_name: str,
):
    """Estimate Frobenius norms of the W_Q / W_V curvature blocks and pickle them.

    For every depth in ``num_layers_iter``, every repeat, every embedding
    standard deviation ``sigma`` and every layer, the function estimates the
    Frobenius norm of three operators restricted to W_Q and to W_V: the
    Hessian (``frob_*``), the GGN (``frob_*_outer``) and their difference
    GGN - Hessian (``frob_*_func``). Each result is a nested dict
    ``{layer: {repeat: [value per sigma]}}`` written to ``numerical_results/``.

    Args:
        scale: Grid of sigma values: 'log', 'log_short', 'linear' or
            'linear_short'.
        norm: 'pre' builds the model with normalization_type "LN"; any other
            value builds it without normalization. Also part of the file name.
        batch_size: Problems per batch.
        num_batch: Number of batches the curvature is evaluated on.
        seed: Seed for the data generator; repeat ``r`` uses model seed
            ``seed + r``.
        n_digits: Digits per operand in the generated problems.
        n_ctx: Context length of the model.
        d_model: Model width (also used as the head dimension).
        d_mlp: Hidden width of the MLP.
        repeats: Number of independently initialised models per depth.
        hvp: Number of Hutchinson samples averaged per norm estimate.
        residual_scaling: Forwarded to ``HookedTransformerConfig``.
        num_layers_iter: Depths (number of layers) to evaluate.
        linear_attention: Forwarded to ``HookedTransformerConfig``.
        loss_name: 'mse' or 'ce'.

    Raises:
        ValueError: On an unknown ``scale`` or ``loss_name``.
    """
    if scale == 'log':
        sigma = 10 ** np.linspace(-2, 1.0, 20)
    elif scale == 'log_short':
        sigma = 10 ** np.linspace(-1, 0.0, 20)
    elif scale == 'linear':
        sigma = np.linspace(0.1, 10.0, 20)
    elif scale == 'linear_short':
        sigma = np.linspace(0.1, 1.0, 20)
    else:
        raise ValueError('Scale not known')
    if loss_name == 'mse':
        loss_module = MSELossModule()
    elif loss_name == 'ce':
        loss_module = LossModule()
    else:
        raise ValueError('Unknown loss name')

    def estimate_frobenius_norm(linop):
        # Square root of the mean of `hvp` squared-Frobenius-norm samples.
        est = HutchinsonSquaredFrobeniusNormEstimator(linop)
        return np.sqrt(np.mean([est.sample() for _ in range(hvp)]))

    # The token tensor serves as both input and target.
    ds = data_generator(batch_size, n_digits, seed)
    dataset_sample = []
    for _ in range(num_batch):
        tokens = next(ds)[0]
        dataset_sample.append((tokens, tokens))

    # open() does not create missing directories.
    Path('numerical_results').mkdir(parents=True, exist_ok=True)

    for n_layers in num_layers_iter:

        # results[name][layer][repeat] -> list with one value per sigma.
        # Insertion order matches the order in which the files are written.
        results = {
            f'frob_{block}{suffix}': {
                l: {i: [] for i in range(repeats)} for l in range(n_layers)
            }
            for block in ('v', 'q')
            for suffix in ('_outer', '_func', '')
        }

        for r in tqdm.tqdm(range(repeats)):
            cfg = HookedTransformerConfig(
                n_layers = n_layers,
                n_heads = 1,
                d_model = d_model,
                d_head = d_model,
                d_mlp = d_mlp,
                act_fn = 'relu',
                attn_only=False,
                normalization_type = "LN" if norm == 'pre' else None,
                post_embedding_ln=False,
                d_vocab=D_VOCAB,
                d_vocab_out=D_VOCAB,
                n_ctx=n_ctx,
                init_weights = True,
                device="cuda",
                seed = seed+r,
                residual_scaling = residual_scaling,
                linear_attention=linear_attention,
            )
            model = HookedTransformer(cfg)
            for s in sigma:
                # Only the embeddings are re-drawn per sigma; all other
                # weights keep their initial values for this repeat.
                nn.init.normal_(model.embed.W_E, mean=0, std=s)
                nn.init.normal_(model.pos_embed.W_pos, mean=0, std=s)

                param_dict = dict(model.named_parameters())

                for l in range(n_layers):
                    # Query block first, then value block (same order as the
                    # random samples were drawn before).
                    for block, weight in (('q', 'W_Q'), ('v', 'W_V')):
                        block_params = [param_dict[f'blocks.{l}.attn.{weight}']]
                        hessian_linop = HessianLinearOperator(
                            model,
                            loss_module,
                            block_params,
                            dataset_sample,
                            check_deterministic=False,
                        )
                        ggn_linop = GGNLinearOperator(
                            model,
                            loss_module,
                            block_params,
                            dataset_sample,
                            check_deterministic=False,
                        )
                        results[f'frob_{block}'][l][r].append(
                            estimate_frobenius_norm(hessian_linop)
                        )
                        results[f'frob_{block}_outer'][l][r].append(
                            estimate_frobenius_norm(ggn_linop)
                        )
                        results[f'frob_{block}_func'][l][r].append(
                            estimate_frobenius_norm(ggn_linop - hessian_linop)
                        )

        for name, d in results.items():
            f_name = f'numerical_results/norm={norm}_num_layers={n_layers}_scale={scale}_residual_scaling={residual_scaling}_linear_attention={linear_attention}_loss={loss_name}_{name}.pickle'
            with open(f_name, 'wb') as handle:
                pickle.dump(d, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Settings shared by the growth-rate experiments below.
BATCH_SIZE = 64
NUM_BATCH = 1
N_CTX = 3 * N_DIGITS + 3
D_MODEL = 128 
D_MLP = 4 * D_MODEL 
# Number of independently initialised models per configuration.
REPEATS = 20 
# Number of Hutchinson samples averaged per Frobenius-norm estimate.
HVP = 20 

In [None]:
# Without layer norm: softmax attention, cross-entropy loss, log-spaced sigma
# grid, for depths 1 to 5.
generate_and_save_growth_rates(
    scale = 'log',
    norm = 'none',
    batch_size = BATCH_SIZE,
    num_batch = NUM_BATCH,
    seed = SEED,
    n_digits = N_DIGITS,
    n_ctx = N_CTX,
    d_model = D_MODEL,
    d_mlp = D_MLP,
    repeats = REPEATS,
    hvp = HVP,
    residual_scaling = 1.0,
    num_layers_iter = [1, 2, 3, 4, 5],
    linear_attention = False,
    loss_name = 'ce',
    )

In [None]:
# Without layer norm but with linear scale: single layer, sigma on the
# linearly spaced grid from 0.1 to 10.
generate_and_save_growth_rates(
    scale = 'linear',
    norm = 'none',
    batch_size = BATCH_SIZE,
    num_batch = NUM_BATCH,
    seed = SEED,
    n_digits = N_DIGITS,
    n_ctx = N_CTX,
    d_model = D_MODEL,
    d_mlp = D_MLP,
    repeats = REPEATS,
    hvp = HVP,
    residual_scaling = 1.0,
    num_layers_iter = [1],
    linear_attention = False,
    loss_name = 'ce',
)

# Same configuration on the short linear grid (sigma from 0.1 to 1).
generate_and_save_growth_rates(
    scale = 'linear_short',
    norm = 'none',
    batch_size = BATCH_SIZE,
    num_batch = NUM_BATCH,
    seed = SEED,
    n_digits = N_DIGITS,
    n_ctx = N_CTX,
    d_model = D_MODEL,
    d_mlp = D_MLP,
    repeats = REPEATS,
    hvp = HVP,
    residual_scaling = 1.0,
    num_layers_iter = [1],
    linear_attention = False,
    loss_name = 'ce',
)

In [None]:
# With pre-layer norm: single layer, softmax attention, cross-entropy loss.
# The keyword is `residual_scaling` (singular); the previous spelling
# `residual_scalings` made this call fail with a TypeError.
generate_and_save_growth_rates(
    scale = 'log',
    norm = 'pre',
    batch_size = BATCH_SIZE,
    num_batch = NUM_BATCH,
    seed = SEED,
    n_digits = N_DIGITS,
    n_ctx = N_CTX,
    d_model = D_MODEL,
    d_mlp = D_MLP,
    repeats = REPEATS,
    hvp = HVP,
    residual_scaling = 1.0,
    num_layers_iter = [1],
    linear_attention = False,
    loss_name = 'ce',
)

In [None]:
# Without layer norm and softmax: linear attention, MSE loss, residual
# scaling 0.0, short log-spaced sigma grid, for depths 3, 2 and 1.
generate_and_save_growth_rates(
    scale = 'log_short',
    norm = 'none',
    batch_size = BATCH_SIZE,
    num_batch = NUM_BATCH,
    seed = SEED,
    n_digits = N_DIGITS,
    n_ctx = N_CTX,
    d_model = D_MODEL,
    d_mlp = D_MLP,
    repeats = REPEATS,
    hvp = HVP,
    residual_scaling = 0.0,
    num_layers_iter = [3, 2, 1],
    linear_attention = True,
    loss_name = 'mse',
)

### Plot

In [None]:
# Base font size for the growth-rate figures.
plt.rcParams['font.size'] = 16

def get_exponent(title: str, param: str, norm: str|None = None, v: Sequence = [], linear_attention: bool = False, n_layers: int = 1):
    if linear_attention:
        if norm=='pre':
            regr = linear_model.LinearRegression()
            regr.fit(
                np.expand_dims(np.log(np.array(sigma)), axis=1),
                np.expand_dims(np.log(np.mean(np.array([v[r] for r in range(repeats)]), axis=0)), axis=1),
            )
            return np.round(regr.coef_, 1)[0][0]

        return 2 * 3 ** n_layers
    
    if norm=='pre':
        regr = linear_model.LinearRegression()
        regr.fit(
            np.expand_dims(np.log(np.array(sigma)), axis=1),
            np.expand_dims(np.log(np.mean(np.array([v[r] for r in range(repeats)]), axis=0)), axis=1),
        )
        return np.round(regr.coef_, 1)[0][0]
        
    if param == 'q':
        if '{o}' in title:
            return 6
        else:
            return 5

    elif param == 'v':
        if '{f}' in title:
            return 0
        else:
            return 2

def read_numerical_results(path_base: str, scale: str):
    """Load the six pickled norm dictionaries written for one configuration.

    Args:
        path_base: Common file-name prefix; ``'<name>.pickle'`` is appended
            for each of the six result names.
        scale: Name of the sigma grid ('log', 'log_short', 'linear' or
            'linear_short').

    Returns:
        ``(sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func,
        frob_q_func)``.

    Raises:
        ValueError: If ``scale`` is not one of the known grids.
    """
    # (start, stop, is_logarithmic) for each named grid of 20 points
    grid_specs = {
        'log': (-2, 1.0, True),
        'log_short': (-1, 0.0, True),
        'linear': (0.1, 10.0, False),
        'linear_short': (0.1, 1.0, False),
    }
    spec = grid_specs.get(scale)
    if spec is None:
        raise ValueError('Scale not known')
    start, stop, is_logarithmic = spec
    points = np.linspace(start, stop, 20)
    sigma = 10 ** points if is_logarithmic else points

    # Files are read in the same order as before.
    load_order = (
        'frob_v_outer', 'frob_v_func', 'frob_v',
        'frob_q_outer', 'frob_q_func', 'frob_q',
    )
    loaded = {}
    for result_name in load_order:
        with open(f'{path_base}{result_name}.pickle', 'rb') as handle:
            loaded[result_name] = pickle.load(handle)

    return (
        sigma,
        loaded['frob_v'],
        loaded['frob_q'],
        loaded['frob_v_outer'],
        loaded['frob_q_outer'],
        loaded['frob_v_func'],
        loaded['frob_q_func'],
    )


def plot_growth_rate(
    sigma: Sequence,
    frob_v: Sequence,
    frob_q: Sequence,
    frob_v_outer: Sequence,
    frob_q_outer: Sequence,
    frob_v_func: Sequence,
    frob_q_func: Sequence,
    file_name_base: str,
    repeats: int,
    norm: str,
    scale: str,
    linear_attention: bool = False,
    n_layers: int = 1
):
    """Plot block Frobenius norms against sigma on a 2x3 grid and save the figure.

    The top row shows the value block, the bottom row the query block; the
    columns show the ``*_outer``, ``*_func`` and full-Hessian measurements.
    Each panel shows the mean over repeats with standard-error bars, one
    curve per layer, plus dashed sigma**p reference curves where applicable.

    Args:
        sigma: Sigma grid the measurements were taken on.
        frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func:
            Measurements indexed as ``[layer][repeat]`` -> values per sigma.
        file_name_base: Prefix of the output file
            ``figures/<file_name_base>_all.pdf``.
        repeats: Number of repeats to average over.
        norm: 'pre' or 'none'; selects fitted vs. hard-coded exponents and
            some layout details.
        scale: 'log', 'log_short', 'linear' or 'linear_short'.
        linear_attention: Whether the results come from linear attention.
        n_layers: Number of layers in the measured model.
    """
    fig, axss = plt.subplots(2, 3, figsize=(12, 4), sharex=True)

    frobs_v = [frob_v_outer, frob_v_func, frob_v]
    frobs_q = [frob_q_outer, frob_q_func, frob_q]
    titles = [r'$\mathbf{H}_{\text{o}}$', r'$\mathbf{H}_{\text{f}}$', r'$\mathbf{H}$']

    for k, (axs, p, fs, c) in enumerate(zip(axss, ['v', 'q'], [frobs_v, frobs_q], [V_COLOR, Q_COLOR])):
        for j, (ax, v) in enumerate(zip(axs, fs)):

            # This panel (value block, middle column, single layer, softmax
            # attention, no layer norm) is left blank.
            if not linear_attention and p == 'v' and j == 1 and norm == 'none' and n_layers == 1:
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['left'].set_visible(False)
                ax.spines['bottom'].set_visible(False)

                ax.xaxis.set_ticks([])  # Remove x-axis ticks
                ax.yaxis.set_ticks([])  # Remove y-axis ticks
                ax.set_xticks([], minor=True)
                ax.xaxis.set_ticks_position('none')
                continue

            # Running min/max of the plotted means, used for the y-limits.
            # (np.infty was removed in NumPy 2.0; np.inf is the same value.)
            mi = np.inf
            ma = 0
            title = titles[j]

            for l in range(n_layers):
                mean_to_plot = np.mean(np.array([v[l][r] for r in range(repeats)]), axis=0)
                mi = min(mi, np.nanmin(mean_to_plot))
                ma = max(ma, np.nanmax(mean_to_plot))

                ax.errorbar(
                    sigma,
                    mean_to_plot,
                    scipy.stats.sem(
                        np.array([v[l][r] for r in range(repeats)]),
                        axis=0),
                    color=c,
                    alpha=(l+2) / (n_layers+1),
                    label = l+1
                )
            if (p != 'v' or j != 1 or norm == 'pre') and not (linear_attention and j == 1 and n_layers == 1):
                # The exponent is derived from the last layer's measurements
                # (this was previously expressed through the leaked loop
                # variable `l`).
                exponent = get_exponent(title, p, norm, v[n_layers - 1], linear_attention, n_layers)
                if linear_attention:
                    # Family of reference curves c * sigma**exponent
                    for i in range(-30, 30):
                        ax.plot(sigma, [1000 ** i * s**exponent for s in sigma], color='gray', linestyle='dashed', alpha=0.2)

                    ax.text(sigma[-1] - 0.2, max(10 ** -16, mi), rf'$\sigma^{{{exponent}}}$', fontsize=12, alpha=0.8)

                elif scale == 'log':
                    for i in range(-30, 30):
                        ax.plot(
                            sigma,
                            [10 ** i * s**exponent for s in sigma],
                            color='gray',
                            linestyle='dashed',
                            alpha=0.2)
                    if norm == 'pre':
                        ax.text(sigma[0] + 0.003, 0.5 * mi, rf'$\sigma^{{{exponent:.1f}}}$', fontsize=12, alpha=0.8)
                    else:
                        ax.text(0.5 * sigma[-1], mi, rf'$\sigma^{{{exponent}}}$', fontsize=12, alpha=0.8)
                elif scale in ['linear', 'linear_short']:
                    for i in range(-30, 5):
                        ax.plot(sigma, [2 ** i * s**exponent for s in sigma], color='gray', linestyle='dashed', alpha=0.2)
                    if norm == 'pre':
                        ax.text(sigma[0] + 0.001, 0.5 * mi, rf'$\sigma^{{{exponent}}}$', fontsize=12, alpha=0.8)
                    else:
                        ax.text(sigma[-1]- 1, mi + 0.1 * (ma - mi), rf'$\sigma^{{{exponent}}}$', fontsize=12, alpha=0.8)

            if scale in ['log', 'log_short']:
                ax.set_yscale('log')
                ax.set_xscale('log')
            # Pad the y-range around the data.
            mi = 0.1 * mi
            if scale in ['linear', 'linear_short']:
                ma = 1.1 * ma
            elif scale == 'log':
                ma = 10 * ma

            if linear_attention:
                ax.set_ylim((max(mi, 10**-17), min(ma, 10**19)))
            else:
                ax.set_ylim((max(mi, 10**-15), ma))

            # Layer legend only in the first column and only for deep models.
            if j == 0:
                handles, labels = ax.get_legend_handles_labels()
                # Keep only the first element of each errorbar handle.
                handles = [h[0] for h in handles]
                if n_layers > 1:
                    ax.legend(handles, labels, loc='upper left', title='Layer:', title_fontsize=10)
            if k == 1:
                ax.set_xlabel(r'$\sigma$', labelpad=-5)
                if scale == 'log':
                    ax.set_xticks([0.01, 0.1, 1, 10])
                elif scale == 'linear_short':
                    ax.set_xticks([0.01, 0.5, 1.0])

            # Titles go on the top row, plus the bottom-middle panel when the
            # panel above it was left blank.
            if k == 0 or (p == 'q' and j == 1 and norm == 'none' and n_layers == 1):
                ax.set_title(title)
            ax.axvline(x=1, color='black', linestyle='dotted', linewidth=1, alpha = 1.0)
            ax.set_xlim((sigma[0], sigma[-1]))

            if linear_attention:
                ax.set_xticks([0.1, 0.2, 0.5, 1.0])
                ax.set_xticklabels([0.1, 0.2, 0.5, 1.0])
                ax.xaxis.set_ticks([0.1, 0.2, 0.5, 1.0])
                ax.xaxis.set_ticks_position('none')
                ax.tick_params(
                    axis='x',          # changes apply to the x-axis
                    which='minor',     # only the minor ticks are affected
                    bottom=False,      # ticks along the bottom edge are off
                    top=False,         # ticks along the top edge are off
                    labelbottom=False)
            elif scale == 'linear_short':
                ax.set_xticks([0.01, 0.5, 1.0])
                ax.xaxis.set_ticks([0.01, 0.5, 1.0])
            elif scale == 'linear':
                ax.set_xticks([0, 1, 4, 7, 10])
                ax.xaxis.set_ticks([0, 1, 4, 7, 10])

    fig.set_tight_layout(False)
    fig.supylabel(r'Block $\|\cdot\|_\text{F}$', x=0.04)
    plt.subplots_adjust(wspace=0.25)
    # savefig does not create missing directories.
    Path('figures').mkdir(parents=True, exist_ok=True)
    plt.savefig(f"figures/{file_name_base}_all.pdf", bbox_inches='tight')
    plt.show()

    plt.close()

In [None]:
# multilayer: softmax attention, no layer norm, cross-entropy loss, depths 1-5.
# These settings must match the ones used when the results were generated,
# because they are encoded in the result file names.
NORM = 'none'
SCALE = 'log'
RESIDUAL_SCALING = 1.0
LINEAR_ATTENTION = False
LOSS_NAME = 'ce'
REPEATS = 20

for n_layers in [1, 2, 3, 4, 5]:

    file_name_base = f'norm={NORM}_num_layers={n_layers}_scale={SCALE}_residual_scaling={RESIDUAL_SCALING}_linear_attention={LINEAR_ATTENTION}_loss={LOSS_NAME}_'
    path_base = f'numerical_results/{file_name_base}'
    
    # `sigma` is assigned at module level here.
    sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func = read_numerical_results(path_base, SCALE)
    plot_growth_rate(sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func, file_name_base, REPEATS, NORM, SCALE, LINEAR_ATTENTION, n_layers)

In [None]:
# linear scale: single-layer results on the two linearly spaced sigma grids.
NORM = 'none'
RESIDUAL_SCALING = 1.0
LINEAR_ATTENTION = False
LOSS_NAME = 'ce'
REPEATS = 20
N_LAYERS = 1

for scale in ['linear', 'linear_short']:

    # The file name encodes the full configuration used at generation time.
    file_name_base = f'norm={NORM}_num_layers={N_LAYERS}_scale={scale}_residual_scaling={RESIDUAL_SCALING}_linear_attention={LINEAR_ATTENTION}_loss={LOSS_NAME}_'
    path_base = f'numerical_results/{file_name_base}'
    
    
    sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func = read_numerical_results(path_base, scale)
    plot_growth_rate(sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func, file_name_base, REPEATS, NORM, scale, LINEAR_ATTENTION, N_LAYERS)

In [None]:
# single layer pre-norm: with NORM == 'pre', get_exponent fits the exponent to
# the data, reading the module-level `sigma` assigned below.
REPEATS = 20
NORM = 'pre'
SCALE = 'log'
RESIDUAL_SCALING = 1.0
N_LAYERS = 1
LINEAR_ATTENTION = False
LOSS_NAME = 'ce'

file_name_base = f'norm={NORM}_num_layers={N_LAYERS}_scale={SCALE}_residual_scaling={RESIDUAL_SCALING}_linear_attention={LINEAR_ATTENTION}_loss={LOSS_NAME}_'
path_base = f'numerical_results/{file_name_base}'
    
sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func = read_numerical_results(path_base, SCALE)
plot_growth_rate(sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func, file_name_base, REPEATS, NORM, SCALE, LINEAR_ATTENTION, N_LAYERS)

In [None]:
# multilayer linear: linear attention with MSE loss and residual scaling 0.0,
# for depths 1 to 3 on the short log-spaced sigma grid.
REPEATS = 20
NORM = 'none'
SCALE = 'log_short'
RESIDUAL_SCALING = 0.0
LINEAR_ATTENTION = True
LOSS_NAME = 'mse'

for n_layers in [1, 2, 3]:

    # The file name encodes the full configuration used at generation time.
    file_name_base = f'norm={NORM}_num_layers={n_layers}_scale={SCALE}_residual_scaling={RESIDUAL_SCALING}_linear_attention={LINEAR_ATTENTION}_loss={LOSS_NAME}_'
    path_base = f'numerical_results/{file_name_base}'
    
    sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func = read_numerical_results(path_base, SCALE)
    plot_growth_rate(sigma, frob_v, frob_q, frob_v_outer, frob_q_outer, frob_v_func, frob_q_func, file_name_base, REPEATS, NORM, SCALE, LINEAR_ATTENTION, n_layers)

## Part 5: Growth Rates in MLP Hessian
In this section we show the growth rates associated with the layers of the MLP Hessian. The results correspond to Figure 8.

In [None]:
class MLPNoActivations(nn.Module):
    """Token embedding followed by a stack of bias-free linear layers.

    No activation function is placed between the layers. The first
    ``num_layers - 1`` layers map ``hidden_size -> hidden_size``; the last
    layer maps ``hidden_size -> vocab_size``.

    Args:
        hidden_size: Width of the embedding and of every hidden layer.
        vocab_size: Number of embedding rows and width of the output.
        num_layers: Total number of linear layers (including the output layer).
        device: Device on which the parameters are created.
    """

    def __init__(self, hidden_size, vocab_size, num_layers, device='cuda'):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, hidden_size, device=device)

        # Output width of each linear layer, in order: hidden layers first,
        # then the projection to the vocabulary.
        out_widths = [hidden_size] * (num_layers - 1) + [vocab_size]
        self.layers = nn.Sequential(*(
            nn.Linear(hidden_size, width, bias=False, device=device)
            for width in out_widths
        ))

    def forward(self, x):
        """Embed the tokens in ``x`` and apply the linear stack.

        ``x`` must embed to a 3-D tensor, since the result of ``self.emb(x)``
        is unpacked into three dimensions. The output has the same two leading
        dimensions and ``vocab_size`` as its last dimension.
        """
        embedded = self.emb(x)
        n_seqs, n_tokens, width = embedded.shape

        # Merge the two leading dimensions so the linear stack sees a 2-D
        # input, then restore them on the way out.
        flat = embedded.reshape((n_seqs * n_tokens, width))
        out = self.layers(flat)
        return out.reshape((n_seqs, n_tokens, self.vocab_size))


In [None]:
# Settings for the MLP Hessian growth-rate experiment.
BATCH_SIZE = 64  # batch size handed to the data generator
NUM_BATCH = 1  # number of batches drawn to form the dataset sample
D_MODEL = 128  # hidden width of the MLP
REPEATS = 20  # freshly constructed models per depth
HVP = 20  # Hutchinson samples averaged per Frobenius-norm estimate
num_layers_iter = list(range(1, 6))  # network depths to sweep: 1 through 5

In [None]:
# Estimate, for every depth in `num_layers_iter`, the Frobenius norm of the
# Hessian block of each linear layer of `MLPNoActivations` as a function of the
# standard deviation `sigma` used to initialise the embedding. One pickle file
# is written per depth.
loss_module = MSELossModule()
sigma = 10 ** np.linspace(-2, 1.0, 20)

# Fixed sample of NUM_BATCH batches; the same token tensor is passed as both
# input and target of each (input, target) pair.
ds = data_generator(BATCH_SIZE, N_DIGITS, SEED)
dataset_sample = []
for _ in range(NUM_BATCH):
    tokens = next(ds)[0]
    dataset_sample.append((tokens, tokens))

# Make sure the output directory exists before the (long) computation starts.
Path('numerical_results').mkdir(parents=True, exist_ok=True)

for n_layers in num_layers_iter:

    # frob_mlp[layer][repeat] -> list of estimates, one per value in `sigma`.
    frob_mlp = {l: {i: [] for i in range(REPEATS)} for l in range(n_layers)}

    for r in tqdm.tqdm(range(REPEATS)):
        model = MLPNoActivations(
            num_layers=n_layers,
            hidden_size=D_MODEL,
            vocab_size=D_VOCAB,
        )
        # The parameter objects stay the same for the lifetime of the model
        # (the re-initialisation below is in place), so look them up once.
        param_dict = dict(model.named_parameters())

        for s in sigma:
            # Only the embedding is re-initialised; the linear layers keep the
            # weights they were constructed with.
            nn.init.normal_(model.emb.weight, mean=0, std=s)

            for l in range(n_layers):
                # Hessian of the loss restricted to the weight of layer `l`.
                Hessian_linop_mlp = HessianLinearOperator(
                    model,
                    loss_module,
                    [param_dict[f'layers.{l}.weight']],
                    dataset_sample,
                    check_deterministic=False,
                )
                est = HutchinsonSquaredFrobeniusNormEstimator(Hessian_linop_mlp)
                # Average HVP samples of the squared norm, then take the root.
                frob_mlp[l][r].append(np.sqrt(np.mean([est.sample() for _ in range(HVP)])))

    # `name` stays a module-level variable so that cells which build file
    # names from it keep working.
    name = 'frob_mlp'
    f_name = f'numerical_results/mlp_num_layers={n_layers}_{name}.pickle'
    with open(f_name, 'wb') as handle:
        pickle.dump(frob_mlp, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Plot the stored per-layer Hessian Frobenius-norm estimates of the linear MLP
# against sigma, one figure per depth in `num_layers_iter`. Each figure is
# saved as a PDF under `figures/`.

# Which stored quantity to plot. Defined here so the cell does not rely on a
# loop variable left behind by another cell.
name = 'frob_mlp'

# Same sigma grid as the one the stored results were computed on.
sigma = 10 ** np.linspace(-2, 1.0, 20)

x_ticks = [0.01, 0.1, 1.0, 10.0]
exp = 2  # exponent of the dashed sigma**exp guide lines

Path('figures').mkdir(parents=True, exist_ok=True)

for n_layers in num_layers_iter:

    base_name = f'mlp_num_layers={n_layers}_{name}'
    f_name = f'numerical_results/{base_name}.pickle'

    # NOTE: pickle must only be used on trusted files; this one is expected to
    # be the output of the experiment cell of this notebook.
    with open(f_name, 'rb') as handle:
        frob_mlp = pickle.load(handle)

    fig, ax = plt.subplots(1, 1, figsize=(4, 2), sharex=True)  # a single panel

    # Running min / max of the plotted means, used for the y-limits below.
    mi = np.inf
    ma = 0

    for l in range(n_layers):
        # Rows: repeats (taken from the stored data itself), columns: sigma.
        runs = np.array([frob_mlp[l][r] for r in sorted(frob_mlp[l])])
        mean_to_plot = np.mean(runs, axis=0)
        mi = min(mi, np.nanmin(mean_to_plot))
        ma = max(ma, np.nanmax(mean_to_plot))

        ax.errorbar(
            sigma,
            mean_to_plot,
            scipy.stats.sem(runs, axis=0),
            color='blue',
            alpha=(l + 2) / (n_layers + 1),  # later layers are drawn more opaque
            label=l + 1,
        )

    # Family of dashed reference lines 10**i * sigma**exp.
    for i in range(-30, 30):
        ax.plot(sigma, [10 ** i * s ** exp for s in sigma], color='gray', linestyle='dashed', alpha=0.2)

    ax.text(sigma[-1] - 5, max(10 ** -16, mi), rf'$\sigma^{{{exp}}}$', fontsize=12, alpha=0.8)

    ax.set_yscale('log')
    ax.set_xscale('log')

    # Pad the data range by one decade on each side, clipped to a fixed window.
    mi = 0.1 * mi
    ma = 10 * ma
    ax.set_ylim((max(mi, 10 ** -17), min(ma, 10 ** 19)))

    # errorbar() registers containers; use their first element (the data line)
    # as the legend handle.
    handles, labels = ax.get_legend_handles_labels()
    handles = [h[0] for h in handles]
    ax.legend(handles, labels, loc='upper left', title='Layer:', title_fontsize=10)

    ax.set_xlabel(r'$\sigma$', labelpad=-5)
    ax.set_title(r'$\mathbf{H}$')

    ax.axvline(x=1, color='black', linestyle='dotted', linewidth=1, alpha=1.0)
    ax.set_xlim((sigma[0], sigma[-1]))

    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks)
    ax.xaxis.set_ticks_position('none')
    ax.tick_params(
        axis='x',           # changes apply to the x-axis
        which='minor',      # only the minor ticks are affected
        bottom=False,       # ticks along the bottom edge are off
        top=False,          # ticks along the top edge are off
        labelbottom=False)

    fig.set_tight_layout(False)
    fig.supylabel(r'Block $\|\cdot\|_\text{F}$', x=-0.05)
    plt.subplots_adjust(wspace=0.25)
    plt.savefig(f"figures/{base_name}.pdf", bbox_inches='tight')
    plt.show()

    plt.close()