In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from itertools import product
import seaborn as sns
import torch.utils.benchmark as benchmark
from icl.figures.task_vec_viz import *
from tqdm.notebook import tqdm, trange
import pandas as pd
import os
from scipy.optimize import curve_fit

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

import copy

from icl.utils.train import BaseTrainer, train_model_with_plot
from icl.latent_markov import * # import the latent markov task and its configuration
from icl.models import Transformer
from icl.figures.attn_plots_beta import visualize_attention
from icl.utils.train_utils import get_attn_base, compute_cross_entropy, ih_score, get_attn_at_layer_base
from icl.figures.head_view import *
#import icl.utils.task_vec as task_vec
from icl.utils.latent_task_vec import *
from icl.utils.markov_conditional_sampler import *
#from icl.utils.fast_latent_task_vec import compute_hiddens_onepos_all_layers_fast
import icl.utils.notebook_utils as nu
from icl.utils.ultra_latent_task_vec import compute_hiddens_onepos_all_layers_ultra
# from icl.utils.fast_latent_task_vec_kvcache import compute_hiddens_onepos_all_layers_kvcached
torch.set_printoptions(precision=3, sci_mode=False)

%load_ext autoreload
%autoreload 2

To modify experiment configurations, edit the `../src/icl/latent_markov/latent_config.py` file.

In [9]:
k = 11  # log2 of the number of minor tasks
config = get_config_base()
config.training.warmup_steps, config.training.num_epochs = 30_000, 60_000
config.task.n_minor_tasks = 1 << k
# Train the model (the output below shows an already-completed run is reused).
model = Transformer(config).to(config.device)
train_results = train_model_with_plot(model, config, show=False, verbose=False)

Experiment directory:  ..\results\latent\train_29fe6cb6e0713a98e06c92d04275f800
train_29fe6cb6e0713a98e06c92d04275f800 already completed


In [9]:
k = 1  # log2 of the number of minor tasks
config = get_config_base()
config.training.warmup_steps, config.training.num_epochs = 30_000, 60_000
config.task.n_minor_tasks = 1 << k
# Train the model (the output below shows an already-completed run is reused).
model = Transformer(config).to(config.device)
train_results = train_model_with_plot(model, config, show=False, verbose=False)

Experiment directory:  ..\results\latent\train_c343cc12ac23ccf7935c1c06de4d943d
train_c343cc12ac23ccf7935c1c06de4d943d already completed


In [10]:
model, sampler, config = nu.load_everything("latent", "train_29fe6cb6e0713a98e06c92d04275f800")

n_tasks = 40  # number of minor tasks in the expanded clone
n_ood = 20    # minimum number of freshly sampled (OOD) minor tasks

if not 0 <= n_ood <= n_tasks:
    raise ValueError(f"n_ood={n_ood} must lie in [0, n_tasks={n_tasks}]")

# Work on a copy so the loaded sampler itself is left untouched.
sampler_clone0 = copy.deepcopy(sampler)

# Original minor count in the clone (capacity before expansion).
n_minor_orig = sampler_clone0.n_minor_tasks

# Sample at least n_ood new tasks, and more if the original sampler cannot supply the
# remaining n_tasks - n_ood original minor tasks.
k_ood = max(n_ood, n_tasks - n_minor_orig)
k_minor = n_tasks - k_ood
sampler_clone0.n_minor_tasks = n_tasks

orig = sampler_clone0.minor_trans_mat
if orig.shape[0] < k_minor:
    raise ValueError(
        f"minor_trans_mat holds {orig.shape[0]} tasks but {k_minor} original tasks are needed"
    )

# Sample new OOD matrices and match orig's device/dtype.
ood = sampler_clone0._sample_banded_trans_mats(k_ood).to(device=orig.device, dtype=orig.dtype)
assert ood.shape == (k_ood, *orig.shape[1:]), "OOD matrices must match the original per-task shape"

# Expanded matrix (n_tasks, ...): the first k_ood rows are the new OOD tasks, the
# remaining k_minor rows are the first k_minor original tasks.
sampler_clone0.minor_trans_mat = torch.cat([ood, orig[:k_minor]], dim=0)

In [11]:
%%time

def get_all_samples(n_tasks, sampler_clone0, num_samples, n_major=3, sample_fn=None, dtype=np.float64):
    """Draw interleaved token sequences for the major tasks and the first ``n_tasks`` minor tasks.

    For each transition matrix (the first ``n_major`` major matrices, then the first
    ``n_tasks`` minor matrices, in that order), ``sample_fn(P, seq_len - 1, num_samples)``
    is called. All but the last entry along its first axis is written to the even slots of
    the last output axis. The odd slots hold the separator token, whose id equals
    ``sampler_clone0.num_states``.

    Args:
        n_tasks: number of minor tasks taken from the front of ``minor_trans_mat``.
        sampler_clone0: object exposing ``seq_len``, ``num_states``, ``major_trans_mat``
            and ``minor_trans_mat`` (torch tensors).
        num_samples: number of samples per condition, forwarded to ``sample_fn``.
        n_major: number of major tasks placed first (default 3, the previously hard-coded value).
        sample_fn: sampling function; defaults to ``sample_all_positions_all_vocab_array``.
        dtype: numpy dtype of the output buffer (default float64, as before).

    Returns:
        torch.Tensor of shape
        (n_major + n_tasks, seq_len - 1, num_states, num_samples, 2 * seq_len - 1).

    Raises:
        ValueError: if the sampler holds fewer than ``n_major`` major or ``n_tasks`` minor matrices.
    """
    if sample_fn is None:
        sample_fn = sample_all_positions_all_vocab_array

    majors = sampler_clone0.major_trans_mat[:n_major].cpu().numpy()
    minors = sampler_clone0.minor_trans_mat[:n_tasks].cpu().numpy()
    # Fail loudly instead of leaving uninitialized rows in the np.empty buffer.
    if len(majors) < n_major or len(minors) < n_tasks:
        raise ValueError(
            f"need {n_major} major and {n_tasks} minor transition matrices, "
            f"got {len(majors)} and {len(minors)}"
        )

    seq_len = sampler_clone0.seq_len
    num_states = sampler_clone0.num_states
    n_positions = seq_len - 1

    all_samples = np.empty(
        (n_major + n_tasks, n_positions, num_states, num_samples, 2 * seq_len - 1),
        dtype=dtype,
    )
    # Odd slots are the separator token for every task; fill them once.
    all_samples[..., 1::2] = num_states
    for task_idx, trans_mat in enumerate([*majors, *minors]):
        sample = sample_fn(trans_mat, n_positions, num_samples)
        # Even slots hold the sampled states (last entry along the first axis dropped).
        all_samples[task_idx, ..., ::2] = sample[:-1]

    return torch.from_numpy(all_samples)

# 64 samples per (task, position, vocab) condition.
all_samples = get_all_samples(n_tasks, sampler_clone0, num_samples=64)


CPU times: total: 3.11 s
Wall time: 3.14 s


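Each sample interleaves states with a separator token $z$: $x_1\, z\, x_2\, z \cdots z\, x_T$ with $T = 129$, i.e. 128 separators and $2T - 1 = 257$ tokens in total; a prefix looks like $x_1\, z\, x_2\, z\, x_3$.

Samples are drawn conditioned on a fixed state $v$ at position $t$: $(X_1, \dots, X_T) \mid X_t = v$.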

In [12]:
# (stacked sample shape, number of original minor tasks kept, number of new OOD tasks)
(all_samples.shape, k_minor, k_ood)

(torch.Size([43, 128, 7, 64, 257]), 20, 20)

In [13]:
# NOTE(review): this run id differs from the experiment loaded in the cell above — confirm intended.
kl_run_dir = "train_5cbc25dddd202326cab47ab42f67e888"
nu.all_kl_plot(file_path=kl_run_dir)

Computing Bayesian baseline for 128 positions...


In [14]:
# Layer-4 hidden states from compute_hiddens; plot the first three tasks.
hiddens = compute_hiddens(
    config,
    model,
    sampler,
    layer_index=4,
    return_final=False,
)
plot_task_vector_modes(hiddens[:3])

In [30]:
# To load checkpoints, use nu.load_checkpoint: it takes the experiment config and a step
# and loads the checkpoint closest to that step.
# Load into a separate name so the trained `model` used by the later cells is not
# silently replaced by an early checkpoint when the notebook is run top to bottom.
model_ckpt = nu.load_checkpoint(config, step=10000)

Auto-detected checkpoint directory: ..\results\latent\train_2124d32afa2b5fd088015a249920000c\checkpoints
Loading checkpoint: model_10000.pt (step 10000)
Checkpoint info: step=10000


In [15]:
# Hidden states at every layer for all pre-generated samples. Swapping axes 2 and 3 yields
# the (layer, task, vocab, position, sample, hidden) layout shown in the output.
hiddens_voc_fast = compute_hiddens_onepos_all_layers_ultra(
    config, model, sampler, all_samples
).transpose(2, 3)
hiddens_voc_fast.shape

torch.Size([6, 43, 7, 128, 64, 128])

In [36]:
# Same computation as hiddens_voc_fast, using whatever `model` is currently loaded; the
# "_2048" suffix presumably refers to the 2**11-minor-task run — confirm.
hiddens_voc_fast_2048 = compute_hiddens_onepos_all_layers_ultra(
    config, model, sampler, all_samples
).transpose(2, 3)
hiddens_voc_fast_2048.shape

torch.Size([6, 43, 7, 128, 64, 128])

In [16]:
# Sizes of the (layer, task, vocab, position, sample, hidden) axes.
L, K, V, T, B, D = hiddens_voc_fast.shape
# Layer 3, first three tasks.
plot_task_vocab_vector_modes(hiddens_voc_fast[3][:3])

In [18]:
from icl.linear.linear_utils import estimate_lambda_with_r2

# We concatenate all hidden representations for different vocabularies together and then do the projection.
# Select layer 4 before casting, so only that slice is converted to float32. Casting first
# converts every layer, and the view keeps that whole copy alive.
hiddens_voc = hiddens_voc_fast[4].to(torch.float32)
def plot_hidden_proj(hiddens_voc, n_minors: int = 2, n_anchor: int = 3):
    """Project per-task mean hidden states onto the anchor tasks' final task vectors and plot.

    Features from all vocab conditions are concatenated. Task vectors are centred on the
    mean of the first ``n_anchor`` tasks. Every task's trajectory is passed to
    ``estimate_lambda_with_r2`` against the anchors' final-position vectors, and the
    result is drawn with ``project_with_r2_size``.

    Args:
        hiddens_voc: one layer of hidden states, presumably (K, V, T, B, D) =
            (tasks, vocab, positions, samples, hidden) — TODO confirm against caller.
        n_minors: forwarded to ``project_with_r2_size``.
        n_anchor: number of leading tasks used for the global mean and final vectors.
    """
    n_task, _, n_pos, n_batch, _ = hiddens_voc.shape
    # (K, V, T, B, D) -> (K, T, B, D, V) -> (K, T, B, D*V)
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(n_task, n_pos, n_batch, -1)
    global_mean = hiddens[:n_anchor].mean(dim=(0, 2))  # (T, D*V)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0)  # (K, T, D*V)
    final_task_vecs = task_vecs_over_all_time[:n_anchor, -1]  # (n_anchor, D*V)
    lambdas, r2_scores, _, _ = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(
        task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas, n_minors=n_minors
    )
    fig.show()


# Pass the number of original minor tasks kept in the expanded sampler (k_minor), as the
# other projection cells do for this same sample layout, rather than a hard-coded 2.
plot_hidden_proj(hiddens_voc, n_minors=k_minor)

In [40]:
# We concatenate all hidden representations for different vocabularies together and then do the projection.
# Select layer 4 before casting, so only that slice is converted to float32 (not all layers).
hiddens_voc = hiddens_voc_fast_2048[4].to(torch.float32)
def plot_hidden_proj(hiddens_voc, n_minors=None, n_anchor: int = 3):
    """Project per-task mean hidden states onto the anchor tasks' final task vectors and plot.

    Redefines the same-named helper from an earlier cell. Features from all vocab
    conditions are concatenated. Task vectors are centred on the mean of the first
    ``n_anchor`` tasks. Trajectories are passed to ``estimate_lambda_with_r2`` and drawn
    with ``project_with_r2_size``.

    Args:
        hiddens_voc: one layer of hidden states, presumably (K, V, T, B, D) =
            (tasks, vocab, positions, samples, hidden) — TODO confirm against caller.
        n_minors: forwarded to ``project_with_r2_size``; ``None`` falls back to the
            notebook-level ``k_minor`` (the previous behaviour).
        n_anchor: number of leading tasks used for the global mean and final vectors.
    """
    if n_minors is None:
        n_minors = k_minor
    n_task, _, n_pos, n_batch, _ = hiddens_voc.shape
    # (K, V, T, B, D) -> (K, T, B, D, V) -> (K, T, B, D*V)
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(n_task, n_pos, n_batch, -1)
    global_mean = hiddens[:n_anchor].mean(dim=(0, 2))  # (T, D*V)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0)  # (K, T, D*V)
    final_task_vecs = task_vecs_over_all_time[:n_anchor, -1]  # (n_anchor, D*V)
    lambdas, r2_scores, _, _ = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(
        task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas, n_minors=n_minors
    )
    fig.show()


plot_hidden_proj(hiddens_voc, n_minors=k_minor)

In [19]:
# Hidden states of a single vocab condition (index 5) at layer 3; unlike the cells above,
# vocabularies are NOT concatenated here.
# Index before casting so only the selected slice is converted to float32.
hiddens_voc = hiddens_voc_fast[3, :, 5].to(torch.float32)
def plot_hidden_proj_voc(hiddens_voc, n_minors=None, n_anchor: int = 3):
    """Single-vocab variant of the task-vector projection plot.

    Task vectors are the per-task sample means, centred on the mean of the first
    ``n_anchor`` tasks. They are passed to ``estimate_lambda_with_r2`` against the
    anchors' final-position vectors and drawn with ``project_with_r2_size``.

    Args:
        hiddens_voc: hidden states presumably shaped (K, T, B, D) =
            (tasks, positions, samples, hidden) — TODO confirm against caller.
        n_minors: forwarded to ``project_with_r2_size``; ``None`` falls back to the
            notebook-level ``k_minor`` (the previous behaviour).
        n_anchor: number of leading tasks used for the global mean and final vectors.
    """
    if n_minors is None:
        n_minors = k_minor
    global_mean = hiddens_voc[:n_anchor].mean(dim=(0, 2))  # (T, D)
    task_vecs_over_all_time = hiddens_voc.mean(dim=2) - global_mean.unsqueeze(dim=0)  # (K, T, D)
    final_task_vecs = task_vecs_over_all_time[:n_anchor, -1]  # (n_anchor, D)
    lambdas, r2_scores, _, _ = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(
        task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas, n_minors=n_minors
    )
    fig.show()


plot_hidden_proj_voc(hiddens_voc, n_minors=k_minor)

In [20]:
# Runs nu.plot_sufficient_stat in "ood" mode. This will take a few minutes to run.
r2_scores, fig = nu.plot_sufficient_stat(
    mode="ood",
    model=model,
    sampler=sampler,
    config=config,
    num_samples=1024,
    T0=10,
    max_layers=8,
)

In [None]:
def compute_ood_lambda_metrics_for_model(
    config,
    model,
    sampler,
    samples,
    layer_index: int = 4,
    n_anchor: int = 3,
    n_ood=None,
):
    """
    Compute two metrics for a single checkpoint:
      1) summary_r2: mean OOD R² at the final time point
      2) lambda_dispersion: mean Euclidean distance of the OOD tasks' final-time lambdas
         to their centroid (smaller = more clustered)

    Args:
        config: experiment config
        model: model loaded at some checkpoint; its parameters' device is used for all work
        sampler: sampler forwarded to compute_hiddens_from_existing_samples_fast
        samples: pre-generated samples (reuse the same samples for all checkpoints)
        layer_index: layer to probe
        n_anchor: how many of the first tasks are "anchors" (usually 3)
        n_ood: how many tasks directly after the anchors are OOD. ``None`` (default, the
            previous behaviour) treats every task after the anchors as OOD. If the samples
            come from an expanded sampler clone, where the OOD block is followed by original
            minor tasks, pass the OOD block size (e.g. ``k_ood``). Otherwise those original
            tasks are mixed into the OOD metrics.

    Returns:
        summary_r2 (float), lambda_dispersion (float)

    Raises:
        ValueError: if the selected OOD range contains no tasks.
    """
    device = next(model.parameters()).device
    samples = samples.to(device)

    with torch.no_grad():
        hiddens_voc = compute_hiddens_from_existing_samples_fast(
            config, model, sampler, samples, layer_index=layer_index
        )
        # The other projection cells cast to float32 before estimate_lambda_with_r2; do the
        # same so reduced-precision hidden states don't degrade the means and the fit.
        hiddens_voc = hiddens_voc.to(torch.float32)

        # Layout (K, ?, T, B, D); the second axis is presumably the vocab condition — TODO confirm.
        K, _, T, B, _ = hiddens_voc.shape

        # Resolve the OOD task range up front so an empty selection fails with a clear
        # error instead of silently producing NaN means.
        ood_end = K if n_ood is None else min(K, n_anchor + n_ood)
        if ood_end <= n_anchor:
            raise ValueError(
                f"No OOD tasks to score: K={K}, n_anchor={n_anchor}, n_ood={n_ood}"
            )
        ood_slice = slice(n_anchor, ood_end)

        # Fold the second axis into the features: (K, T, B, D*?)
        hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(K, T, B, -1)

        # Global mean over anchor tasks and batch: (T, D*?)
        global_mean = hiddens[:n_anchor].mean(dim=(0, 2))

        # Task vectors over all time: (K, T, D*?)
        task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(0)

        # Final anchor vectors at the last time point: (n_anchor, D*?)
        final_task_vecs = task_vecs_over_all_time[:n_anchor, -1]

        lambdas, r2_scores, _task_norms, _ortho_norms = estimate_lambda_with_r2(
            final_task_vecs, task_vecs_over_all_time
        )

        # estimate_lambda_with_r2 might return numpy; normalise to float32 tensors.
        lambdas = torch.as_tensor(lambdas, device=device, dtype=torch.float32)
        r2_scores = torch.as_tensor(r2_scores, device=device, dtype=torch.float32)

        # Metric 1: mean final-time R² over the OOD tasks (assumes r2_scores is (K, T)).
        summary_r2 = float(r2_scores[ood_slice, -1].mean())

        # Metric 2: mean distance of the final-time OOD lambdas to their centroid
        # (assumes lambdas is (K, T, n_basis)).
        lambdas_ood_final = lambdas[ood_slice, -1]
        center = lambdas_ood_final.mean(dim=0, keepdim=True)
        lambda_dispersion = float((lambdas_ood_final - center).norm(dim=-1).mean())

    return summary_r2, lambda_dispersion


In [None]:
from typing import Sequence, Dict

def process_all_checkpoints_lambda_metrics(
    exp_group: str,
    exp_name: str,
    steps: Sequence[int],
    samples0,
    layer_index: int = 4,
    device: str = "cuda",
) -> Dict[str, Dict[int, float]]:
    """
    Process all given checkpoints for a latent experiment and compute:
      - summary OOD R² at final time
      - OOD lambda dispersion at final time

    Args:
        exp_group: e.g. "latent"
        exp_name:  e.g. "train_5050..."
        steps: checkpoint steps to process (e.g. [1000, 1500, ...])
        samples0: pre-generated samples compatible with the experiment's sampler
        layer_index: which layer to probe
        device: device string, e.g. "cuda" or "cpu"

    Returns:
        {
          "summary_r2": {step -> float},
          "lambda_dispersion": {step -> float},
        }
    """
    # Load sampler + config once; the final model returned here is not needed.
    _, sampler, config = nu.load_everything(exp_group, exp_name)

    # Move the (potentially large) samples to the target device once, instead of
    # re-copying them host->device inside the metric function for every checkpoint.
    samples_on_device = samples0.to(device)

    summary_r2_dict: Dict[int, float] = {}
    lambda_dispersion_dict: Dict[int, float] = {}

    for step in steps:
        print(f"Processing step {step}...")
        model = nu.load_checkpoint(config, step=step).to(device)
        model.eval()

        summary_r2, lambda_dispersion = compute_ood_lambda_metrics_for_model(
            config=config,
            model=model,
            sampler=sampler,
            samples=samples_on_device,
            layer_index=layer_index,
        )

        summary_r2_dict[step] = summary_r2
        lambda_dispersion_dict[step] = lambda_dispersion

        # Drop this checkpoint before loading the next so two models are never resident at once.
        del model

    return {
        "summary_r2": summary_r2_dict,
        "lambda_dispersion": lambda_dispersion_dict,
    }


In [None]:
# 1) Experiment whose checkpoints are analysed (single source of truth for its id).
EXP_GROUP = "latent"
EXP_NAME = "train_5050083a5f64fb641e9ca0d7a34622d6"

# `samples0` must already be built for this experiment's sampler (e.g. with get_all_samples).
# Check before the expensive load, and fail with a clear message on a fresh kernel.
if "samples0" not in globals():
    raise NameError(
        "samples0 is not defined: build the pre-generated samples for "
        f"{EXP_NAME} before running this cell"
    )

# Load this experiment's sampler/config under distinct names so the globals `sampler`
# and `config` used by the earlier cells are not overwritten.
_, ckpt_sampler, ckpt_config = nu.load_everything(EXP_GROUP, EXP_NAME)

samples0 = samples0.to("cpu")  # moved to the target device inside the processing function

# 2) Checkpoint steps to process: every 500 steps from 1000 to 40000 (inclusive).
steps = list(range(1000, 40001, 500))

# 3) Run the processing
results = process_all_checkpoints_lambda_metrics(
    exp_group=EXP_GROUP,
    exp_name=EXP_NAME,
    steps=steps,
    samples0=samples0,
    layer_index=4,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

summary_r2 = results["summary_r2"]
lambda_dispersion = results["lambda_dispersion"]