In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import torch.utils.benchmark as benchmark
from icl.figures.task_vec_viz import *
from tqdm.notebook import tqdm, trange
import os

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

import copy

from icl.utils.train import train_model_with_plot
from icl.coin.coin import * 
from icl.coin.coin_config import get_config_coin
from icl.models import Transformer
from icl.figures.attn_plots_beta import visualize_attention
from icl.utils.train_utils import get_attn_base, compute_cross_entropy, ih_score, get_attn_at_layer_base
from icl.figures.head_view import *
#import icl.utils.task_vec as task_vec
from icl.utils.latent_task_vec import *
#from icl.utils.fast_latent_task_vec import compute_hiddens_onepos_all_layers_fast
import icl.utils.notebook_utils as nu
from icl.utils.ultra_latent_task_vec import compute_hiddens_onepos_all_layers_ultra
from icl.utils.kv_latent_task_vec_beta import compute_hiddens_onepos_all_layers_kvcache_beta
# from icl.utils.fast_latent_task_vec_kvcache import compute_hiddens_onepos_all_layers_kvcached
torch.set_printoptions(precision=3, sci_mode=False)

%load_ext autoreload
%autoreload 2

In [5]:
# Build the coin-task configuration and set the training schedule and
# vocabulary size used for this sweep.
config = get_config_coin()
config.training.warmup_steps = 10_000
config.training.num_epochs = 20_000
config.vocab_size = 4
# Sweep the number of minor tasks over 2**k for k = 1, 3, 5, 7, 9
# (i.e. 2, 8, 32, 128, 512), constructing a fresh Transformer for each value.
for k in range(1, 10, 2):
    config.task.n_minor_tasks = 2**k
    model = Transformer(config)
    model = model.to(config.device)
    # `train_results` (like `model`) is rebound on every iteration, so only
    # the values from the last iteration remain after the loop.
    train_results = train_model_with_plot(model, config, show=False, verbose=False)

Experiment directory:  ..\results\coin\train_3f17a09db4719418877265b9a8582ab3
train_3f17a09db4719418877265b9a8582ab3 already completed
Experiment directory:  ..\results\coin\train_7599f2078faa51921b8aff140db931c1
train_7599f2078faa51921b8aff140db931c1 already completed
Experiment directory:  ..\results\coin\train_cd0cdd990ac9b3b8e1a0c83b1f486bdb
train_cd0cdd990ac9b3b8e1a0c83b1f486bdb already completed
Experiment directory:  ..\results\coin\train_de294a0cd8f599e46f1886de966bcbe7
train_de294a0cd8f599e46f1886de966bcbe7 already completed
Experiment directory:  ..\results\coin\train_02221ecdf0b6a4c1ddd41a5ae61ef3f9
train_02221ecdf0b6a4c1ddd41a5ae61ef3f9 already completed


In [10]:
# Experiment to analyse. This name matches the last experiment directory
# printed in the recorded output of the training sweep cell.
exp_name = "train_02221ecdf0b6a4c1ddd41a5ae61ef3f9"
# Rebinds the notebook-level `model`, `sampler` and `config` to whatever
# nu.load_everything returns for that experiment.
model, sampler, config = nu.load_everything("coin", exp_name)

def get_new_sampler(exp_name, n_tasks=61, n_ood=40, B=96):
    """Clone an experiment's sampler and rebuild its minor pool with random rows in front.

    The sampler returned by ``nu.load_everything("coin", exp_name)`` is deep-copied,
    so the loaded object itself is not modified. On the copy, ``n_minor_tasks`` is
    set to ``n_tasks`` and ``minor_p`` is replaced by a tensor with ``n_tasks`` rows
    (same dtype, device and trailing shape as the original ``minor_p``), laid out as:

    * rows ``[0, k_ood)``: freshly drawn random rows,
    * rows ``[k_ood, k_ood + take)``: the first ``take`` rows of the original ``minor_p``,
    * rows ``[k_ood + take, n_tasks)``: more random rows, only if the original pool
      has fewer than ``k_minor`` rows.

    where ``k_ood = max(n_ood, n_tasks - <original n_minor_tasks>)``,
    ``k_minor = n_tasks - k_ood`` and ``take = min(k_minor, <rows in original minor_p>)``.

    Random rows are drawn with ``torch.rand`` (global RNG, no explicit seed). When
    ``minor_p`` has two or more dimensions each row is divided by its sum over the
    last dimension, so it sums to 1 there.

    Args:
        exp_name: Experiment name passed to ``nu.load_everything``.
        n_tasks: Total number of rows in the rebuilt ``minor_p``.
        n_ood: Minimum number of random rows placed at the front.
        B: Unused. Kept only so existing callers that pass it keep working.

    Returns:
        ``(sampler_clone, k_minor)``: the modified copy and the number of rows
        that follow the leading random block.

    Raises:
        ValueError: If ``n_ood`` exceeds ``n_tasks`` (the random block would not fit),
            or if the loaded sampler's ``minor_p`` is ``None``.
    """
    # Fail fast with a clear message; without this check the first copy_ below
    # dies with an opaque shape-mismatch error.
    if n_ood > n_tasks:
        raise ValueError(
            f"n_ood ({n_ood}) cannot exceed n_tasks ({n_tasks}): "
            "the rebuilt minor pool only has n_tasks rows."
        )

    # Only the sampler is needed here; the model and config are discarded.
    _, sampler, _ = nu.load_everything("coin", exp_name)
    sampler_clone0 = copy.deepcopy(sampler)

    orig = sampler_clone0.minor_p
    if orig is None:
        raise ValueError("sampler.minor_p is None; cannot expand minor pool.")

    # Number of random rows inserted up front: at least n_ood, and enough that
    # the old rows kept afterwards never exceed the original n_minor_tasks.
    k_ood = max(n_ood, n_tasks - sampler_clone0.n_minor_tasks)
    k_minor = n_tasks - k_ood  # rows that follow the random block
    sampler_clone0.n_minor_tasks = n_tasks

    # () for 1-D minor_p, (K,) for 2-D minor_p, etc.
    trailing_shape = orig.shape[1:]

    def _random_rows(count):
        # Uniform random rows matching minor_p's dtype/device/trailing shape,
        # normalised over the last dim when minor_p is at least 2-D.
        rows = torch.rand((count, *trailing_shape), device=orig.device, dtype=orig.dtype)
        if orig.ndim >= 2:
            rows = rows / rows.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        return rows

    new_minor = orig.new_empty((n_tasks, *trailing_shape))  # same dtype/device
    new_minor[:k_ood].copy_(_random_rows(k_ood))

    # Reuse as many original rows as are actually available.
    take = min(k_minor, orig.shape[0])
    if take > 0:
        new_minor[k_ood:k_ood + take].copy_(orig[:take])

    # Original pool too small: pad the tail with additional random rows.
    remain = k_minor - take
    if remain > 0:
        new_minor[k_ood + take:].copy_(_random_rows(remain))

    sampler_clone0.minor_p = new_minor
    return sampler_clone0, k_minor


In [11]:
# Build a modified copy of the experiment's sampler with a 61-row minor pool.
# NOTE(review): `B=96` is accepted by get_new_sampler but never used in its body.
sampler_clone0, k_minor = get_new_sampler(exp_name, n_tasks=61, n_ood=40, B=96)
# Uses the notebook-level `config` and `model` together with the modified sampler.
# NOTE(review): `compute_hiddens` is not imported by name in this notebook;
# presumably it comes from one of the star imports - confirm.
hiddens = compute_hiddens(config, model, sampler_clone0, layer_index=5, return_final = False)
# Only the first three entries along the leading axis of `hiddens` are plotted.
plot_task_vector_modes(hiddens[:3])

In [8]:
from icl.linear.linear_utils import estimate_lambda_with_r2

# We concatenate all hidden representations for different vocabularies together and then do the projection.
def plot_hidden_proj(hiddens, k_minor, sampler=None, n_ref=3):
    """Project mean-centred task vectors onto the reference tasks' final vectors and show the figure.

    Steps:
      1. Average ``hiddens[:n_ref]`` over axes 0 and 2 to get a per-position mean.
      2. Average ``hiddens`` over axis 2 and subtract that mean to get one task
         vector per task and position.
      3. Take the last-position vectors of the first ``n_ref`` tasks as the reference set.
      4. Pass both to ``estimate_lambda_with_r2`` and plot with ``project_with_r2_size``.

    Args:
        hiddens: 4-D tensor; per the original comment its layout is
            ``(n_tasks, seq_len-1, B, n_embd)``.
        k_minor: Forwarded to ``project_with_r2_size`` as ``n_minors``.
        sampler: Object whose ``major_p`` and ``minor_p`` are concatenated (in that
            order) to form the hover data. Defaults to the notebook-level
            ``sampler_clone0`` so existing two-argument calls behave as before.
        n_ref: Number of leading tasks used for the mean and as reference vectors
            (previously hard-coded to 3; presumably the major tasks, since
            ``major_p`` comes first in the hover data - confirm).

    Returns:
        None. The figure is displayed via ``fig.show()``; it is deliberately not
        returned so that a trailing call in a cell does not render it twice.
    """
    if sampler is None:
        # Backward-compatible fallback to the notebook global used originally.
        sampler = sampler_clone0

    # (n_tasks, T, B, D) -> (T, D): mean over the reference tasks and axis 2.
    global_mean = hiddens[:n_ref].mean(dim=(0, 2))
    # (K, T, D): per-task mean over axis 2, centred on the reference mean.
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0)
    # Last-position vectors of the reference tasks.
    final_task_vecs = task_vecs_over_all_time[:n_ref, -1]

    # The last two return values are not needed for the plot.
    lambdas, r2_scores, _task_norms, _ortho_norms = estimate_lambda_with_r2(
        final_task_vecs, task_vecs_over_all_time
    )

    fig = project_with_r2_size(
        task_vecs_over_all_time,
        final_task_vecs,
        r2_scores,
        lambdas,
        n_minors=k_minor,
        hover_data=torch.concat([sampler.major_p, sampler.minor_p]),
        hover_name="p",
    )
    fig.show()

# NOTE(review): when called with two arguments, plot_hidden_proj reads the
# notebook-level `sampler_clone0` for its hover data, so that variable (and
# `hiddens`, `k_minor`) must already exist. This cell's execution count (8) is
# lower than that of the cell that defines them (11).
plot_hidden_proj(hiddens, k_minor)

In [12]:
# Same call as the last line of the previous cell, with the same arguments.
plot_hidden_proj(hiddens, k_minor)

In [13]:
import json
import plotly.graph_objects as go

def plot_eval_losses(exp_name, exp_group="coin", step_range=None):
    """
    Build a Plotly line chart of an experiment's ID and OOD evaluation losses.

    The data is read from ``../results/{exp_group}/{exp_name}/log.json``, using the
    keys ``eval/step``, ``eval/IDLoss`` and ``eval/OODLoss``.

    Args:
        exp_name: Experiment name (e.g., "train_6b9b1e239ad66222ef2361ad413c390b")
        exp_group: Experiment group directory (default: "coin")
        step_range: Optional ``(min_step, max_step)`` pair. Only entries with
            ``min_step < step <= max_step`` are kept (lower bound exclusive,
            upper bound inclusive). If None, all steps are plotted.

    Returns:
        The figure, or None if anything went wrong (a message is printed instead
        of raising).
    """
    log_path = f"../results/{exp_group}/{exp_name}/log.json"

    try:
        with open(log_path, 'r') as handle:
            history = json.load(handle)

        # Look the series up in a fixed order so a missing key is reported
        # for the first absent one.
        steps = history["eval/step"]
        curves = {}
        for key in ("eval/IDLoss", "eval/OODLoss"):
            curves[key] = history[key]

        if step_range is not None:
            lower, upper = step_range
            # Positions of the steps inside (lower, upper]; the same positions
            # are then picked from every series.
            keep = [pos for pos, step in enumerate(steps) if lower < step <= upper]
            steps = [steps[pos] for pos in keep]
            curves = {
                key: [values[pos] for pos in keep]
                for key, values in curves.items()
            }

        fig = go.Figure()

        # One line trace per loss series, named after its log key.
        for key, values in curves.items():
            fig.add_trace(
                go.Scatter(
                    x=steps,
                    y=values,
                    mode='lines',
                    name=key,
                    line=dict(width=2),
                )
            )

        fig.update_layout(
            title=f'Evaluation Losses over Training Steps<br>Experiment: {exp_name}',
            xaxis_title='Training Step',
            yaxis_title='Loss',
            hovermode='x unified',
            template='plotly_white',
            width=1000,
            height=600,
            legend=dict(x=0.02, y=0.98),
        )

        return fig

    except FileNotFoundError:
        print(f"Could not find log file at {log_path}")
        return None
    except KeyError as e:
        print(f"Missing key in log data: {e}")
        return None
    except Exception as e:
        # Best-effort: any other failure is reported and turned into None.
        print(f"Error loading eval losses: {e}")
        return None

# Example usage:
# plot_eval_losses returns None when the log could not be read or parsed,
# so only call show() when a figure actually came back.
fig = plot_eval_losses(exp_name, exp_group="coin")
if fig:
    fig.show()

# With step range filter:
# fig = plot_eval_losses(exp_name, exp_group="coin", step_range=(300, 30000))
# if fig:
#     fig.show()