### Start by importing the packages, and setting the GPU index and seed

In [1]:
import argparse
import sys

import seaborn as sns
import torch.nn as nn
import torch.optim as optim

from utils.experiment import *
from utils.function_vector import *
from utils.learnable_task_vector import *
from torch.utils.data import TensorDataset, DataLoader, random_split
from utils.plot import *

# IPython magics: automatically reload edited modules (e.g. the utils.* helpers) before running cells.
%load_ext autoreload
%autoreload 2

# NOTE(review): this runs after the `from utils... import *` lines above, so it only affects
# imports executed later in the notebook, not the utils imports themselves.
sys.path.append("..")  # Adds higher directory to python modules path

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

# NOTE(review): sns.set is an alias of sns.set_theme, so this call supersedes the
# 'notebook'/'darkgrid' theme applied above; the effective theme is paper/ticks/colorblind.
sns.set(context='paper', style='ticks', palette='colorblind')

# Seed passed to set_seed, and index of the CUDA device used when CUDA is available.
SEED = 17
GPU_IDX = 2


def calculate_memory_allocation(tensor):
    """Return the memory taken by the elements of ``tensor`` in mebibytes (MiB).

    The byte count is ``tensor.numel() * tensor.element_size()``, i.e. the
    number of elements in this tensor (or view) times the size of one element;
    it is divided by 1024**2 to convert bytes to MiB.
    """
    bytes_per_mib = 1024 ** 2
    total_bytes = tensor.numel() * tensor.element_size()
    return total_bytes / bytes_per_mib

### Configure the experiment
##### Here, ``seq_len`` is the maximum prompt length the model was trained on, and ``long_seq_len`` is the prompt length used at inference. We therefore run our tests with ``long_seq_len > seq_len``.

In [2]:
# Set the variables
task_name = "linear_regression"

batch_size = 256
# Maximum prompt length the model was trained on; also selects the LTV checkpoint directory/file.
seq_len = 71
# Prompt length used at inference; the tests below require it to exceed seq_len.
long_seq_len = 96

# Activation-function tag; used here only as a component of the LTV checkpoint path
# (it is formatted into that path, so None becomes the literal directory name "None").
act_fn = None

# Enforce the stated experimental setup (long_seq_len > seq_len) early, before any model loading.
if long_seq_len <= seq_len:
    raise ValueError(
        f"long_seq_len ({long_seq_len}) must be greater than seq_len ({seq_len})"
    )

#### Configure the directories for loading/saving the models/figures

In [3]:
#  Set seeds
set_seed(SEED)
# Use the configured GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{GPU_IDX}" if torch.cuda.is_available() else "cpu")

# Construct the running and saving paths
figures_path, preds_path = prepare_save_dirs(os.path.join(task_name, "results"))
run_path = os.path.join("../models", task_name, "pretrained")

# Directory holding the trained LTV layer for this task / activation tag / training length.
experiment_dir = f"./LTV_models/{task_name}/{act_fn}/seq_len_{seq_len}"

if not os.path.exists(experiment_dir):
    # The path is built from task_name, act_fn and seq_len, so any of them may be the culprit;
    # report the exact path instead of blaming only the sequence length.
    raise ValueError(
        f"No trained LTV models found at {experiment_dir!r}; "
        "check task_name, act_fn and seq_len"
    )

#### Initialize the trained model and obtain the necessary variables

In [4]:
# Load the pretrained model along with its config, the task/covariate samplers and a tuple of sizes
model, conf, task_sampler, covariate_sampler, params = prepare_model(
    run_path, batch_size, device
)
# Unpack the sizes consumed by the FV / LTV helpers in the cells below
n_dims, resid_dim, n_layers, n_heads, head_dim = params

#### Sample a batch of prompts of length ``seq_len`` to compute FV and LTV on 

In [5]:
# Draw a task, then a batch of inputs whose length is the curriculum's final number of points
task = task_sampler()
xs_sample = covariate_sampler.sample_xs(
    b_size=batch_size,
    n_points=conf.training.curriculum.points.end,
).to(device)
# Labels are obtained by evaluating the sampled task on these inputs
ys_sample = task.evaluate(xs_sample).to(device)

#### Compute the _indirect effects_ (required for FV, out of scope for us) and use them to compute FV

In [6]:
# Only for visualization
n_top_heads_visual = 10
# Number of heads selected by top_heads_locations to build the function vector
n_FV_heads = 35

# Compute the indirect effect of each attention head on the sampled prompts
indirect_effect_mat = get_top_heads(model, xs_sample, ys_sample, n_layers, n_heads)
indirect_effect_mat_np = indirect_effect_mat.cpu().data.numpy()

# Build the function vector (FV) from the same prompts. The on-distribution and
# distribution-shift evaluations below inject FV, so it must exist before they run.
with torch.no_grad():
    top_heads = top_heads_locations(indirect_effect_mat_np, n_FV_heads)
    universal_mean_activations = get_universal_mean_act(model, xs_sample, ys_sample, n_layers, n_heads, head_dim)
    FV = compute_function_vector(model, universal_mean_activations, top_heads, resid_dim, head_dim)

#### Initialize the trained LTV layer and compute LTV

In [6]:
params_path = os.path.join(experiment_dir, f"ltv_layer_{seq_len}.pth")

ltv_layer = LearnableTaskVector(n_layers, n_heads).to(device)
# map_location remaps tensors saved on another device (e.g. cuda:0) onto the current device,
# so the checkpoint also loads on a different GPU index or on a CPU-only machine.
ltv_layer.load_state_dict(torch.load(params_path, map_location=device))

# Inference only: put both modules in eval mode.
model.eval()
ltv_layer.eval()

In [7]:
# Collect the attention outputs on the sampled prompts and map them to the learned task vector
with torch.no_grad():
    attn_out = get_attn_outs(
        model, xs_sample, ys_sample, n_layers, n_heads, head_dim, resid_dim
    )
    LTV = ltv_layer(attn_out)

In [8]:
attn_out.size()

#### Now, sample a batch of longer prompts with length ``long_seq_len`` 

In [10]:
# Sample the data consisting of long prompts (long_seq_len points each), labelled by the same task
xs_long = covariate_sampler.sample_xs(
    b_size=batch_size, n_points=long_seq_len
).to(device)
ys_long = task.evaluate(xs_long).to(device)

### On-distribution

In [11]:
m = [0.1, 0.25, 0.5, 0.75, 0.9]
scale = 1.0
L = [6, 7, 8]

plt.rcParams['figure.figsize'] = (8, 5)

with torch.no_grad():
    # Baseline transformer: no vector passed (FV=None) and scale=0.0
    pred = evaluate_model(model, model, xs_long, ys_long, L=1, FV=None, dummy=m, scale=0.0)
    loss = distance(pred, ys_long).cpu().data.numpy()

    # Transformer + function vector FV (applied with layers L and the given scale)
    pred_FV = evaluate_model(model, model, xs_long, ys_long, L=L, FV=FV, dummy=m, scale=scale)
    loss_FV = distance(pred_FV, ys_long).cpu().data.numpy()

    # Transformer + learned task vector LTV
    pred_LTV = evaluate_model_on_LTV(model, model, xs_long, ys_long, LTV=LTV, dummy=None, scale=1.0)
    loss_LTV = distance(pred_LTV, ys_long).cpu().data.numpy()

    losses = [loss, loss_FV, loss_LTV]
    legends = ["Transformer", r"Transformer + $v$", r"Transformer + $v_\theta$"]

lambda_label = r"$\lambda$"
title = r"$\mathbf{w}$, $\mathbf{x}$ $\sim$ $\mathcal{N}(0, 1)$"

# Each experiment gets its own file name so later plots do not overwrite this one.
plot_transformer(losses,
                 legends,
                 title,
                 x_label="# in-context examples",
                 y_label="mean-squared error",
                 baseline=None,
                 save_path=os.path.join(figures_path, "result_on_distribution"),
                 font_size=15,
                 y_ticks_max=20,
                 dpi=250,
                 show=True)

### Distributional shift - on prompts with the standard length ``seq_len``
#### $\mathcal{N}(0, 1) \rightarrow \mathcal{N}(-0.75, 1.25)$

In [1]:
mean = -0.75
std_dev = 1.25

# Regression weights drawn from the (CPU) standard normal, then moved to the device
w_b = torch.randn(batch_size, n_dims, 1).to(device)
# Affine transform of the same draws: scaled by std_dev and offset by mean
w_b_shifted = mean + std_dev * w_b

xs = covariate_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end).to(device)
# Drop the trailing singleton dimension of the batched matmul result
ys = (xs @ w_b).squeeze(-1)
ys_shifted = (xs @ w_b_shifted).squeeze(-1)

In [13]:
m = [0.1, 0.25, 0.5, 0.75, 0.9]
scale = 1.0
L = [6, 7, 8]

plt.rcParams['figure.figsize'] = (8, 5)

# Prompts carry labels from the shifted weights (ys_shifted); every variant is scored
# against the labels of the original weights (ys).
with torch.no_grad():
    pred = evaluate_model(model, model, xs, ys_shifted, L=1, FV=None, dummy=m, scale=0.0)
    loss = distance(pred, ys).cpu().data.numpy()

    pred_FV = evaluate_model(model, model, xs, ys_shifted, L=L, FV=FV, dummy=m, scale=scale)
    loss_FV = distance(pred_FV, ys).cpu().data.numpy()

    pred_LTV = evaluate_model_on_LTV(model, model, xs, ys_shifted, LTV=LTV, dummy=None, scale=1.0)
    loss_LTV = distance(pred_LTV, ys).cpu().data.numpy()

    losses = [loss, loss_FV, loss_LTV]
    legends = ["Transformer", r"Transformer + $v$", r"Transformer + $v_\theta$"]

lambda_label = r"$\lambda$"
title = r"$\bf{dist}$ $\bf{shift:}$ $\mathcal{N}(0, 1)$ $\rightarrow$ $\mathcal{N}(-0.75, 1.25)$"

# Each experiment gets its own file name so the plots do not overwrite one another.
plot_transformer(losses,
                 legends,
                 title,
                 x_label="# in-context examples",
                 y_label="mean-squared error",
                 baseline=None,
                 save_path=os.path.join(figures_path, "result_dist_shift"),
                 font_size=15,
                 y_ticks_max=20,
                 dpi=250,
                 show=True)

### Out-of-distribution - on prompts with the standard length ``seq_len``
#### $\mathcal{N}(0, 1) \rightarrow U(-2, 2)$

In [28]:
w_b = torch.randn(batch_size, n_dims, 1)
w_b = w_b.to(device)
# Out-of-distribution weights: torch.rand is uniform on [0, 1), so this is uniform on [-2, 2)
w_b_out = torch.rand(size=w_b.size()).to(device) * 4 - 2

xs = covariate_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end).to(device)
# Labels from the original weights: used to build FV/LTV and as the evaluation target below
ys = (xs @ w_b)[:, :, 0]
# Labels from the out-of-distribution weights: used as the prompt labels in the evaluation below
ys_out = (xs @ w_b_out)[:, :, 0]

In [25]:
n_FV_heads = 35

# Rebuild the function vector from this section's (xs, ys) prompts
with torch.no_grad():
    top_heads = top_heads_locations(indirect_effect_mat_np, n_FV_heads)
    universal_mean_activations = get_universal_mean_act(
        model, xs, ys, n_layers, n_heads, head_dim
    )
    FV = compute_function_vector(
        model, universal_mean_activations, top_heads, resid_dim, head_dim
    )

In [26]:
# Recompute the learned task vector from this section's (xs, ys) prompts
with torch.no_grad():
    attn_out = get_attn_outs(
        model, xs, ys, n_layers, n_heads, head_dim, resid_dim
    )
    LTV = ltv_layer(attn_out)

In [29]:
m = [0.1, 0.25, 0.5, 0.75, 0.9]
scale = 1.0
L = [6, 7, 8]

plt.rcParams['figure.figsize'] = (8, 5)

# Prompts carry labels from the out-of-distribution weights (ys_out); every variant is
# scored against the labels of the original weights (ys).
with torch.no_grad():
    pred = evaluate_model(model, model, xs, ys_out, L=1, FV=None, dummy=m, scale=0.0)
    loss = distance(pred, ys).cpu().data.numpy()

    pred_FV = evaluate_model(model, model, xs, ys_out, L=L, FV=FV, dummy=m, scale=scale)
    loss_FV = distance(pred_FV, ys).cpu().data.numpy()

    pred_LTV = evaluate_model_on_LTV(model, model, xs, ys_out, LTV=LTV, dummy=None, scale=1.0)
    loss_LTV = distance(pred_LTV, ys).cpu().data.numpy()

    losses = [loss, loss_FV, loss_LTV]
    legends = ["Transformer", r"Transformer + $v$", r"Transformer + $v_\theta$"]

lambda_label = r"$\lambda$"
title = r"$\bf{out-of-dist:}$ $\mathcal{N}(0, 1)$ $\rightarrow$ $U(-2, 2)$"

# Each experiment gets its own file name so the plots do not overwrite one another.
plot_transformer(losses,
                 legends,
                 title,
                 x_label="# in-context examples",
                 y_label="mean-squared error",
                 baseline=None,
                 save_path=os.path.join(figures_path, "result_out_of_distribution"),
                 font_size=15,
                 y_ticks_max=20,
                 dpi=250,
                 show=True)

### [Optional] Not a direct comparison, but compare the emphasis placed on attention heads by FV and LTV

In [18]:
# Reshape the learned LTV weights onto the same grid as the indirect-effect matrix
ltv_layer_weights_np = ltv_layer.weights.detach().cpu().numpy().reshape(indirect_effect_mat_np.shape)

# Visualize the heatmap of the attention head significance found by FV and LTV on a (n_layers x n_heads) grid
create_heatmap(indirect_effect_mat_np,
               row_labels=list(map(str, range(n_heads))),
               col_labels=list(map(str, range(n_layers))),
               color_label='AIE (CIE)',
               num_top_heads=n_top_heads_visual,
               row_label_overall='Head Index',
               col_label_overall='Layer',
               show=True)

create_heatmap(ltv_layer_weights_np,
               row_labels=list(map(str, range(n_heads))),
               col_labels=list(map(str, range(n_layers))),
               color_label='LTV Layer weights',
               num_top_heads=n_top_heads_visual,
               row_label_overall='Head Index',
               col_label_overall='Layer',
               show=True)