In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import torch.utils.benchmark as benchmark
from icl.figures.task_vec_viz import *
from tqdm.notebook import tqdm, trange
import os

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

import copy

from icl.utils.train import train_model_with_plot
from icl.latent_markov import * # import the latent markov task and its configuration
from icl.models import Transformer
from icl.figures.attn_plots_beta import visualize_attention
from icl.utils.train_utils import get_attn_base, compute_cross_entropy, ih_score, get_attn_at_layer_base
from icl.figures.head_view import *
from icl.utils.latent_task_vec import *
from icl.utils.markov_conditional_sampler import *
from icl.utils.simple_markov_sampler import get_all_samples_base_only
import icl.utils.notebook_utils as nu
from icl.utils.ultra_latent_task_vec import compute_hiddens_onepos_all_layers_ultra
from icl.utils.kv_latent_task_vec_beta import compute_hiddens_onepos_all_layers_kvcache_beta
from icl.utils.latent_ood_analysis import process_latent_ood_evolve_checkpoints
torch.set_printoptions(precision=3, sci_mode=False)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


To modify experiment configurations, using the `../src/icl/latent_markov/latent_config.py` file. 

In [4]:
config = get_config_base()
config.training.warmup_steps = 30_000
config.training.num_epochs = 60_000
k = 11
config.task.n_minor_tasks = 2**k
model = Transformer(config)
model = model.to(config.device)
train_results = train_model_with_plot(model, config, show=False, verbose=False)

Experiment directory:  ..\results\latent\train_29fe6cb6e0713a98e06c92d04275f800
train_29fe6cb6e0713a98e06c92d04275f800 already completed


In [9]:
config = get_config_base()
config.training.warmup_steps = 30_000
config.training.num_epochs = 60_000
k = 1
config.task.n_minor_tasks = 2**k
model = Transformer(config)
model = model.to(config.device)
train_results = train_model_with_plot(model, config, show=False, verbose=False)

Experiment directory:  ..\results\latent\train_c343cc12ac23ccf7935c1c06de4d943d
train_c343cc12ac23ccf7935c1c06de4d943d already completed


In [6]:
exp_name = "train_29fe6cb6e0713a98e06c92d04275f800"

def get_all_samples(exp_name, n_tasks=61, n_ood=40, B=96):
    model, sampler, config = nu.load_everything("latent", exp_name)
    sampler_clone0 = copy.deepcopy(sampler)

    # Original minor count in the clone (capacity before expansion)
    k_ood = max(n_ood, n_tasks - sampler_clone0.n_minor_tasks)
    sampler_clone0.n_minor_tasks = n_tasks
    k_minor = n_tasks - k_ood

    orig = sampler_clone0.minor_trans_mat
    n_minor = orig.shape[0]

    # Sample new OOD matrices (match device/dtype/shape)
    ood = sampler_clone0._sample_banded_trans_mats(k_ood)
    # Make sure ood is on same device/dtype as orig if needed
    ood = ood.to(device=orig.device, dtype=orig.dtype)

    # Create expanded matrix: (n_tasks, ...)
    new_shape = (n_tasks, *orig.shape[1:])
    new_minor = orig.new_empty(new_shape)  # same dtype/device as orig

    # Fill: first k_ood are new; remaining are old shifted later
    new_minor[:k_ood].copy_(ood)
    new_minor[k_ood:].copy_(orig[:k_minor])

    # Swap in
    sampler_clone0.minor_trans_mat = new_minor
    all_samples = get_all_samples_base_only(n_tasks, sampler_clone0, B)

    return all_samples, k_minor

In [5]:
hiddens = compute_hiddens_onepos_all_layers_kvcache_beta(
        config, model, all_samples, 
        k_step = 32,
        b_step = 32,
        t_step = 4
    )

In [6]:
hiddens_voc_fast_2048 = hiddens.permute(0, 1, 3, 2, 4, 5)
hiddens_voc_fast_2048.shape

torch.Size([6, 64, 7, 128, 96, 128])

In [8]:
nu.all_kl_plot(file_path="train_29fe6cb6e0713a98e06c92d04275f800")

Computing Bayesian baseline for 128 positions...


In [9]:
hiddens = compute_hiddens(config, model, sampler, layer_index=4, return_final = False)
plot_task_vector_modes(hiddens[:3])

In [30]:
# To load checkpoints, use the following function. It takes the experiment config and step.
# It will output the closest checkpoint to the given step. 
model = nu.load_checkpoint(config, step=10000)

Auto-detected checkpoint directory: ..\results\latent\train_2124d32afa2b5fd088015a249920000c\checkpoints
Loading checkpoint: model_10000.pt (step 10000)
Checkpoint info: step=10000


In [10]:
hiddens_voc_fast_2048 = compute_hiddens_onepos_all_layers_ultra(config, model, sampler, all_samples)
hiddens_voc_fast_2048 = hiddens_voc_fast_2048.permute(0, 1, 3, 2, 4, 5)
hiddens_voc_fast_2048.shape

torch.Size([6, 43, 7, 128, 64, 128])

In [7]:
L, K, V, T, B, D = hiddens_voc_fast_2048.shape
plot_task_vocab_vector_modes(hiddens_voc_fast_2048[3,:3])

In [9]:
from icl.linear.linear_utils import estimate_lambda_with_r2

# We concatenate all hidden representations for different vocabularies together and then do the projection.
hiddens_voc = hiddens_voc_fast_2048.to(torch.float32)[5]
def plot_hidden_proj(hiddens_voc):
    K, _, T, B, _ = hiddens_voc.shape
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(K, T, B, -1) # (K, T, B, D*V)
    global_mean = hiddens[:3].mean(dim=(0,2)) # (T, D*V)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0) # (K, T, D*V)
    final_task_vecs = task_vecs_over_all_time[:3, -1]
    lambdas_voc_mlp0, r2_scores, task_norms, ortho_norms = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas_voc_mlp0, n_minors=k_minor)
    fig.show()

plot_hidden_proj(hiddens_voc)

In [8]:
# We concatenate all hidden representations for different vocabularies together and then do the projection.
hiddens_voc = hiddens_voc_fast_2048.to(torch.float32)[2]
def plot_hidden_proj(hiddens_voc):
    K, _, T, B, _ = hiddens_voc.shape
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(K, T, B, -1) # (K, T, B, D*V)
    global_mean = hiddens[:3].mean(dim=(0,2)) # (T, D*V)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0) # (K, T, D*V)
    final_task_vecs = task_vecs_over_all_time[:3, -1]
    lambdas_voc_mlp0, r2_scores, task_norms, ortho_norms = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas_voc_mlp0, n_minors=k_minor)
    fig.show()

plot_hidden_proj(hiddens_voc)

In [20]:
# We concatenate all hidden representations for different vocabularies together and then do the projection.
hiddens_voc = hiddens_voc_fast_2048.to(torch.float32)[3, :, 5]
def plot_hidden_proj_voc(hiddens_voc):
    hiddens = hiddens_voc
    global_mean = hiddens[:3].mean(dim=(0,2)) # (T, D)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0) # (K, T, D)
    final_task_vecs = task_vecs_over_all_time[:3, -1]
    lambdas_voc_mlp0, r2_scores, task_norms, ortho_norms = estimate_lambda_with_r2(final_task_vecs, task_vecs_over_all_time)
    fig = project_with_r2_size(task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas_voc_mlp0, n_minors=k_minor)
    fig.show()

plot_hidden_proj_voc(hiddens_voc)

In [20]:
# This will take a few minutes to run
r2_scores, fig = nu.plot_sufficient_stat(
    model=model,
    sampler=sampler,
    config=config,
    num_samples=1024,
    T0=10,
    max_layers=8,
    mode="ood"
)

In [7]:
from icl.utils.latent_ood_analysis import compute_latent_ood_metrics, process_latent_ood_evolve_checkpoints

# Example usage: Process checkpoints for OOD evolution analysis
# Make sure you have all_samples and k_minor defined before running this

steps_to_process = list(range(0, 30001, 400))
exp_name = "train_29fe6cb6e0713a98e06c92d04275f800"

all_samples, k_minor = get_all_samples(exp_name)

In [8]:


# Run the analysis (set forced=True to recompute even if cache exists)
results = process_latent_ood_evolve_checkpoints(
    exp_group="latent",
    exp_name=exp_name,
    steps=steps_to_process,
    all_samples=all_samples,
    k_minor=k_minor,
    layer_indices=list(range(6)),
    device="cuda" if torch.cuda.is_available() else "cpu",
    k_step=32,
    b_step=32,
    t_step=4,
    forced=False  # Set to True to bypass cache and recompute
)

# Check what's in results
print("Results keys:", results.keys())
if results:
    # Access results
    summary_r2_ood = results["summary_r2_ood"]
    lambda_dispersion_ood = results["lambda_dispersion_ood"]
    print(f"Processed {len(results['steps'])} checkpoints")
    print(f"Layers analyzed: {results['layers']}")
else:
    print("No results returned - check for errors above")

Loading cached results from ..\results\latent\train_29fe6cb6e0713a98e06c92d04275f800\latent_ood_evolve_ckpt_kminor_21_layers_0_1_2_3_4_5.pkl
Results keys: dict_keys(['steps', 'layers', 'summary_r2_ood', 'lambda_dispersion_ood', 'k_minor'])
Processed 76 checkpoints
Layers analyzed: [0, 1, 2, 3, 4, 5]


In [11]:
import plotly.graph_objects as go
import numpy as np
import json

def plot_all_layers_plotly(results_dict, metric="summary_r2_ood", exp_name=None, exp_group="latent", include_eval_losses=True):
    """
    Plot all layers on a single interactive Plotly figure.
    
    Args:
        results_dict : dict
            Output of process_ood_evolve_checkpoints
        metric : str
            "summary_r2_ood" or "lambda_dispersion_ood"
        exp_name : str, optional
            Experiment name to load eval losses from log.json
        exp_group : str
            Experiment group (default: "latent")
        include_eval_losses : bool
            Whether to include eval/IDLoss and eval/OODLoss (default: True)
    """
    layers = sorted(results_dict["layers"])
    metric_dict = results_dict[metric]
    
    # Collect all step values
    all_steps = sorted(set().union(*(metric_dict[L].keys() for L in layers)))

    fig = go.Figure()

    for L in layers:
        # Align values for consistent x-axis
        y_vals = [metric_dict[L].get(s, np.nan) for s in all_steps]

        fig.add_trace(go.Scatter(
            x=all_steps,
            y=y_vals,
            mode='lines+markers',
            name=f"Layer {L}",
            marker=dict(size=6),
            line=dict(width=2),
        ))

    # Add eval losses if requested and exp_name is provided
    if include_eval_losses and exp_name:
        try:
            log_path = f"../results/{exp_group}/{exp_name}/log.json"
            with open(log_path, 'r') as f:
                log_data = json.load(f)
            
            eval_steps = log_data["eval/step"]
            id_loss = log_data["eval/IDLoss"]
            ood_loss = log_data["eval/OODLoss"]
            
            # Filter to match checkpoint range: steps > 300 and <= 30000
            filtered_indices = [i for i, step in enumerate(eval_steps) if 300 < step <= 30000]
            filtered_steps = [eval_steps[i] for i in filtered_indices]
            filtered_id_loss = [id_loss[i] for i in filtered_indices]
            filtered_ood_loss = [ood_loss[i] for i in filtered_indices]
            
            # Add eval losses on secondary y-axis
            fig.add_trace(go.Scatter(
                x=filtered_steps,
                y=filtered_id_loss,
                mode='lines',
                name='eval/IDLoss',
                line=dict(width=2, dash='dash'),
                yaxis='y2'
            ))
            
            fig.add_trace(go.Scatter(
                x=filtered_steps,
                y=filtered_ood_loss,
                mode='lines',
                name='eval/OODLoss',
                line=dict(width=2, dash='dash'),
                yaxis='y2'
            ))
        except Exception as e:
            print(f"Could not load eval losses: {e}")

    title = "OOD R² over training (all layers)" if metric == "summary_r2_ood" \
            else "OOD lambda dispersion over training (all layers)"

    layout_update = {
        "title": title,
        "xaxis_title": "Checkpoint step",
        "yaxis_title": "Value",
        "legend_title": "Metrics",
        "hovermode": "closest",
        "template": "plotly_white",
        "width": 900,
        "height": 600,
    }
    
    # Add secondary y-axis if eval losses were added
    if include_eval_losses and exp_name:
        layout_update["yaxis2"] = dict(
            title="Loss",
            overlaying='y',
            side='right'
        )
    
    fig.update_layout(**layout_update)
    fig.show()

plot_all_layers_plotly(results, metric="summary_r2_ood", exp_name=exp_name)

In [12]:
plot_all_layers_plotly(results, metric="lambda_dispersion_ood", exp_name=exp_name)