In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from itertools import product
import seaborn as sns
import torch.utils.benchmark as benchmark
from icl.figures.task_vec_viz import *
from tqdm.notebook import tqdm, trange
import pandas as pd
import os
from scipy.optimize import curve_fit

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

import copy

from icl.utils.train import BaseTrainer, train_model_with_plot
from icl.latent_markov import * # import the latent markov task and its configuration
from icl.models import Transformer
from icl.utils import visualize_attention
from icl.utils.train_utils import get_attn_base, compute_cross_entropy, ih_score, get_attn_at_layer_base
from icl.figures.head_view import *
import icl.utils.task_vec as task_vec
from icl.utils.latent_task_vec import *
import icl.utils.notebook_utils as nu

# Print tensors with 3 digits of precision and with scientific notation disabled.
torch.set_printoptions(precision=3, sci_mode=False)

%load_ext autoreload
%autoreload 2

To modify experiment configurations, edit the `../src/icl/latent_markov/latent_config.py` file.

In [6]:
# Start from the base configuration and override the fields specific to this run.
config = get_config_base()
config.task.n_minor_tasks = 2048
config.training.warmup_steps = 30_000
config.training.num_epochs = 60_000

# Build the transformer, place it on the configured device, and hand it to the
# training helper with plotting display and verbose logging switched off.
model = Transformer(config).to(config.device)
train_results = train_model_with_plot(
    model,
    config,
    show=False,
    verbose=False,
)

Experiment directory:  ..\results\latent\train_2124d32afa2b5fd088015a249920000c
train_2124d32afa2b5fd088015a249920000c already completed


In [2]:
# Load the model, sampler and config returned by the project helper for this run id.
model, sampler, config = nu.load_everything("latent", "train_2124d32afa2b5fd088015a249920000c")

# Work on a deep copy so the loaded `sampler` object itself is not modified here.
sampler_clone0 = copy.deepcopy(sampler)
sampler_clone0.n_minor_tasks = 32
k_ood = 16  # number of leading minor transition matrices overwritten below
k_minor = sampler_clone0.n_minor_tasks - k_ood  # 32 - 16 = 16; read later by plot_hidden_proj
# Replace the first k_ood minor transition matrices with freshly sampled banded ones
# (presumably to act as out-of-distribution tasks — TODO confirm).
# NOTE(review): this relies on the private helper `_sample_banded_trans_mats`.
sampler_clone0.minor_trans_mat[:k_ood] = sampler_clone0._sample_banded_trans_mats(k_ood)

In [4]:
# Run the project's KL plotting helper for the same run id used with nu.load_everything.
nu.all_kl_plot(file_path="train_2124d32afa2b5fd088015a249920000c")

Computing Bayesian baseline for 256 positions...


In [6]:
# Collect hidden states at layer index 3 and plot the modes of the first three entries.
# NOTE(review): the meaning of `return_final=False` and of the leading axis of `hiddens`
# is defined by the project helper — confirm there.
hiddens = compute_hiddens(config, model, sampler, layer_index=3, return_final = False)
plot_task_vector_modes(hiddens[:3])

In [4]:
# NOTE(review): the return value of `.to("cpu")` is discarded, so this assumes the
# sampler is moved in place — confirm against the sampler implementation.
sampler_clone0.to("cpu")
# Request n_minor_tasks + 3 tasks (the 3 extra are presumably the major tasks — TODO confirm)
# from the modified sampler, using 20 workers.
samples0 = generate_tokenwise_samples_mp(config=config, sampler=sampler_clone0, n_tasks=sampler_clone0.n_minor_tasks+3, num_workers=20)

Generating samples and saving to ..\results\latent\train_2124d32afa2b5fd088015a249920000c\samples.pt
Number of Workers:  20


generating samples (mp):   0%|          | 0/8960 [00:00<?, ?it/s]

Saved samples to ..\results\latent\train_2124d32afa2b5fd088015a249920000c\samples.pt


In [None]:
samples0 = samples0.to("cpu")
# NOTE(review): `samples0` was generated from `sampler_clone0`, but the unmodified
# `sampler` is passed here — confirm this is intended.
# The trailing positional argument 4 is presumably a layer index — TODO confirm
# against the helper's signature.
hiddens_voc0 = compute_hiddens_from_existing_samples_fast(config, model, sampler, samples0, 4)

In [6]:
from icl.linear.linear_utils import estimate_lambda_with_r2

def plot_hidden_proj(hiddens_voc, n_minors=None):
    """Build and show the projection figure of per-task mean hidden vectors.

    Args:
        hiddens_voc: 5-D tensor unpacked as ``(K, V, T, B, D)``. Axis 1 is moved
            to the end and merged with the last axis, giving ``(K, T, B, D*V)``.
        n_minors: Value forwarded as ``n_minors`` to ``project_with_r2_size``.
            When ``None`` (the default) the notebook-level ``k_minor`` is used,
            which reproduces the previous behaviour of reading that global.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    if n_minors is None:
        # Fall back to the notebook global so existing one-argument calls keep working.
        n_minors = k_minor

    K, _, T, B, _ = hiddens_voc.shape
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(K, T, B, -1)  # (K, T, B, D*V)

    # Reference mean over the first three tasks and the batch axis -> (T, D*V).
    # NOTE(review): the first three entries are presumably the major tasks — confirm.
    global_mean = hiddens[:3].mean(dim=(0, 2))

    # Batch-averaged hidden per task and time step, centred on the reference mean -> (K, T, D*V).
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0)

    # Vectors of the first three tasks at the last time step.
    final_task_vecs = task_vecs_over_all_time[:3, -1]

    # The helper returns four values; the last two are not needed for the figure.
    lambdas, r2_scores, _, _ = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas, n_minors=n_minors)
    fig.show()

# Show the projection figure for the hiddens computed from `samples0`.
plot_hidden_proj(hiddens_voc0)

In [15]:
import torch
import torch.nn.functional as F

def bigram_prefix_counts(x: torch.Tensor, V: int, *, dtype=torch.int64) -> torch.Tensor:
    """Count, for every prefix of each sequence, how often each bigram has occurred.

    Args:
        x: Token ids of shape (B, T). Must be an index (int64) tensor with values
            in ``[0, V)``, as required by ``F.one_hot``.
        V: Number of classes used for the one-hot encoding.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (B, T, V, V). Entry ``[b, t, u, v]`` is the number of
        indices ``s`` in ``1..t`` with ``x[b, s-1] == u`` and ``x[b, s] == v``;
        the slice at ``t == 0`` is therefore all zeros.

    Raises:
        AssertionError: if ``x`` is not 2-D.
    """
    assert x.dim() == 2, "x must be (B, T)"
    B, T = x.shape
    device = x.device

    # Fewer than two tokens means there are no bigrams. This also covers T == 0,
    # where the zero padding below would otherwise add a spurious time step and
    # return (B, 1, V, V) instead of (B, 0, V, V).
    if T <= 1:
        return torch.zeros(B, T, V, V, dtype=dtype, device=device)

    prev = x[:, :-1]               # (B, T-1)
    nxt  = x[:,  1:]               # (B, T-1)

    prev_oh = F.one_hot(prev, num_classes=V).to(torch.float32)  # (B, T-1, V)
    nxt_oh  = F.one_hot(nxt,  num_classes=V).to(torch.float32)  # (B, T-1, V)

    # Outer product of the two one-hots: a single 1 at (prev, next) for each transition.
    edges = torch.einsum('btu,btv->btuv', prev_oh, nxt_oh)      # (B, T-1, V, V)

    # Shift by one step so that position t only counts transitions ending at or before t.
    zero_pad = torch.zeros(B, 1, V, V, dtype=edges.dtype, device=device)
    edges_padded = torch.cat([zero_pad, edges], dim=1)          # (B, T, V, V)

    out = torch.cumsum(edges_padded, dim=1)
    return out.to(dtype)

In [16]:
def compute_hiddens_data(
    config,
    model: torch.nn.Module,
    x,
    layer_index: int = 1,
) -> torch.Tensor:
    """Run ``model`` on ``x`` and return one layer's attention-block output at odd positions.

    Args:
        config: Not used by this function.
        model: Module exposing ``model.layers[layer_index].attn_block``.
        x: Input batch; it is moved to the model's device and ``x.shape[1]`` is
            used as the sequence length.
        layer_index: Index into ``model.layers`` selecting the block to tap.

    Returns:
        The detached output of the tapped ``attn_block`` restricted to sequence
        positions ``1, 3, 5, ...`` along dimension 1.

    Side effects:
        Switches ``model`` to eval mode; the previous mode is not restored.
    """
    model_device = next(model.parameters()).device
    x = x.to(model_device)

    seq_len = x.shape[1]
    pos = torch.arange(1, seq_len, 2, device=model_device)
    cache = {}

    def hook_fn(module, inp, out):
        # Keep only the odd positions and drop the autograd graph.
        cache["vec"] = out[:, pos, :].detach()

    handle = model.layers[layer_index].attn_block.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            _ = model(x)
    finally:
        # Always detach the hook, even if the forward pass raises, so a failed
        # call cannot leave a stale hook on the model.
        handle.remove()

    return cache["vec"]

In [17]:
import numpy as np
from sklearn.linear_model import LinearRegression

B = 1024  # number of sequences drawn from the sampler
T0 = 10   # first step index included in the regression
L = 8     # number of layers probed  # NOTE(review): hard-coded; must not exceed len(model.layers)
x, *_ = sampler.generate(num_samples=B, mode="ood")

# Prefix bigram counts over every second token of x, flattened to one vector per step.
bigram_count = bigram_prefix_counts(x[:, ::2], sampler.num_states)
bigram_count = bigram_count.reshape(B, sampler.seq_len, -1)
bigram_count = bigram_count[:, T0:-1]
Tdim = bigram_count.shape[1]

# The regression targets are identical for every layer, so convert them once
# instead of once per (layer, step).
bigram_count_np = bigram_count.cpu().numpy()

r2 = np.zeros((L, Tdim))

for layer in range(L):
    hiddens = compute_hiddens_data(
        config,
        model,
        x,
        layer_index=layer,
    )
    hiddens = hiddens[:, T0:]
    # One device-to-host transfer per layer rather than one per step.
    hiddens_np = hiddens.cpu().numpy()

    for t in range(Tdim):
        X = np.ascontiguousarray(hiddens_np[:, t])
        y = bigram_count_np[:, t]
        # y = np.log(1+y)
        lr_model = LinearRegression()
        lr_model.fit(X, y)
        # In-sample R^2: scored on the same data the regression was fitted on.
        r2[layer, t] = lr_model.score(X, y)

In [18]:
import plotly.graph_objects as go

# One line per probed layer: regression R^2 (rows of `r2`) against the step index.
fig = go.Figure()

steps = T0 + np.arange(Tdim)
for layer_idx, layer_curve in enumerate(r2):
    trace = go.Scatter(
        x=steps,
        y=layer_curve,
        mode='lines',
        name=f'Layer {layer_idx}',
    )
    fig.add_trace(trace)

layout_options = dict(
    title='Layer Values Over Time',
    xaxis_title='Step t',
    yaxis_title='Value',
    template='plotly_white',
)
fig.update_layout(**layout_options)
fig.show()