In [1]:
import torch
from absl import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.optimize import curve_fit

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

from icl.linear.train_linear import train
from icl.linear.lr_config import get_config
from icl.linear.lr_task import *
from icl.linear.linear_utils import *
from icl.linear.task_vecs import *
from icl.linear.train_linear import get_sharded_batch_sampler
from icl.linear import DiscreteMMSE, Ridge
from icl.linear.lr_models import MixedRidge, UnbalancedMMSE
#from icl.linear.sufficient_stats import get_sufficient_statistics_fit, get_sufficient_statistics_proj_fit, get_betahat_fit
from icl.linear.linear_notebook_helpers import process_sufficient_statistics, plot_r2_curves_plotly, get_eval_task, process_beta_fit, plot_metrics
from icl.utils.linear_visualization_utils import plot_mse
from icl.utils.linear_ood_analysis import process_ood_evolve, process_ood_evolve_checkpoints, process_ood_evolve_task_diversity
from icl.figures.attn_plots_beta import visualize_attention
from icl.figures.task_vec_viz import *
# from icl.utils.experiment_analysis import process_exp
# from icl.utils.linear_processor import process_ood_evolve_lambda_metrics

# Show INFO-level absl log messages (training / analysis progress).
logging.set_verbosity(logging.INFO)
# Print tensors and arrays with 3 decimals and without scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)
np.set_printoptions(precision=3, suppress=True)
# NOTE(review): no random seed is set here, yet later helpers subsample
# ("Randomly sampling 64") — confirm whether they seed internally.

# Reload edited project modules (icl.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

To modify experiment configurations, edit `../src/icl/linear/lr_config.py`. The next cell also overrides a few fields (`p_minor`, warmup steps and total steps) inline.

In [2]:
# Task-diversity sweep: train (or reload) one model per n_minor_tasks = 2**k, k = 1..12.
# NOTE(review): `config` is mutated in place, so after this cell it holds the
# settings of the last run (k = 12).
config = get_config()
config.task.p_minor = 0.1
config.training.warmup_steps = 30_000
config.training.total_steps = 60_000
exp_names = []  # list of (k, experiment directory name)
for k in range(1, 13):
    config.task.n_minor_tasks = 2**k
    model, log = train(config)
    if isinstance(log, tuple):
        log, exp_path = log
        # Normalize separators so this works for both Windows ("\\") and POSIX ("/")
        # paths; the experiment name is the second-to-last path component.
        exp_names.append((k, str(exp_path).replace("\\", "/").split("/")[-2]))
    else:
        logging.warning(
            "train() returned no experiment path for k=%d; it is missing from exp_names.", k)

..\results\linear\train_244beee6c29be8ab4f96771fdb79342f
train_244beee6c29be8ab4f96771fdb79342f already completed
Loaded model from ..\results\linear\train_244beee6c29be8ab4f96771fdb79342f\checkpoint.pt
..\results\linear\train_596c3de4dd0c1f079982dec494417f92
train_596c3de4dd0c1f079982dec494417f92 already completed
Loaded model from ..\results\linear\train_596c3de4dd0c1f079982dec494417f92\checkpoint.pt
..\results\linear\train_eb3b586c12d500ad1d3a3262c313a6a9
train_eb3b586c12d500ad1d3a3262c313a6a9 already completed
Loaded model from ..\results\linear\train_eb3b586c12d500ad1d3a3262c313a6a9\checkpoint.pt
..\results\linear\train_cbcd1c0ddd2348f736f6c6098cb69faa
train_cbcd1c0ddd2348f736f6c6098cb69faa already completed
Loaded model from ..\results\linear\train_cbcd1c0ddd2348f736f6c6098cb69faa\checkpoint.pt
..\results\linear\train_17121002a85d5cddfec36dcbeae81c10
train_17121002a85d5cddfec36dcbeae81c10 already completed
Loaded model from ..\results\linear\train_17121002a85d5cddfec36dcbeae81c10

In [3]:
# Split the (k, experiment_name) pairs into two parallel tuples. Note that `steps`
# holds the k values (n_minor_tasks = 2**k), not training steps.
if not exp_names:
    # zip(*[]) would otherwise fail with an opaque "not enough values to unpack".
    raise ValueError("exp_names is empty: no training run reported an experiment path.")
steps, exp_names_name = zip(*exp_names)

In [21]:
# Per-layer OOD metrics across the whole task-diversity sweep (one run per k).
# With forced=False a previously computed result appears to be reused (see output).
result_dict_task = process_ood_evolve_task_diversity(exp_names_name, steps, forced=False)

Processing 12 experiments for task diversity analysis...
Already computed. Loading existing results from ..\results\linear\task_diversity_analysis\task_diversity_h_0.0_r_2.0_on_False_n_12_hash_7752bee4.pkl.


In [22]:
# Inspect which fields the task-diversity analysis returned.
result_dict_task.keys()

dict_keys(['exp_names', 'steps', 'layers', 'summary_r2_ood', 'lambda_dispersion_ood', 'include_minor', 'radius', 'K'])

In [4]:
# Presumably the k = 12 run: its preprocessing below reports 4096 = 2**12
# minority tasks — TODO confirm.
exp_name = "train_6ab65809d5e5b5fe12b3488cb7cc0ede"

# Per-layer R² of the beta-hat fit (process_beta_fit) on the eval set, layers 0-15.

# Reference estimator is presumably least squares: (X^T X)^{-1} X^T Y.

fig, r2_dict = plot_r2_curves_plotly(
    process_beta_fit, exp_name=exp_name,
    layer_indices=range(0, 16), is_eval=True, K=1024
)
fig.show()

In [5]:
# OOD analysis of the run above at layer 15, including minority tasks
# (the helper subsamples them when there are too many — see output).
process_ood_evolve(exp_name, K=100, layer_index=15, include_minor=True)

Preprocessing...
Too many minority tasks (4096). Randomly sampling 64.


Given $h_1, h_2, h_3$, center them around their mean:

$$\bar{h} = \tfrac{1}{3}\,(h_1 + h_2 + h_3), \qquad x_i = h_i - \bar{h}, \quad i = 1, 2, 3,$$

and model

$$x = \lambda\,(x_1, x_2, x_3) + \varepsilon.$$

In [8]:
# Compute (or load cached) layer-12 OOD metrics across checkpoints for every run.
# Keep every run's results in `results_by_exp` instead of discarding all but the
# last; `results_dict` still ends up holding the last run's results, as later
# cells expect.
results_by_exp = {}
for exp_name in exp_names_name:
    results_dict = process_ood_evolve_checkpoints(K=100, exp_name=exp_name, layer_indices=[12])
    results_by_exp[exp_name] = results_dict

Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_244beee6c29be8ab4f96771fdb79342f\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_596c3de4dd0c1f079982dec494417f92\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_eb3b586c12d500ad1d3a3262c313a6a9\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_cbcd1c0ddd2348f736f6c6098cb69faa\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_17121002a85d5cddfec36dcbeae81c10\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_1f47a24f3c5b1a0a4a20aee54df9750c\ood

In [10]:
# Fields of the checkpoint results left in `results_dict` by the last loop iteration.
results_dict.keys()

dict_keys(['steps', 'layers', 'summary_r2_ood', 'lambda_dispersion_ood', 'include_minor', 'n_minor_sampled', 'radius'])

In [None]:
import plotly.graph_objects as go
import numpy as np
import json

def load_run_log(exp_name, results_dir="../results/linear"):
    """Load a training run's ``log.json``.

    Args:
        exp_name: Run directory name, e.g. "train_1cb25eea155301b4045c86ff82db64a7".
        results_dir: Directory that contains the run directories.

    Returns:
        dict: the parsed JSON log.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``).
    """
    with open(f"{results_dir}/{exp_name}/log.json") as f:
        return json.load(f)


def transformer_true_means(log_data, eval_key):
    """Per-eval-step means of the ``Transformer | True`` series of ``eval/<eval_key>``.

    Args:
        log_data: Parsed run log (as returned by ``load_run_log``).
        eval_key: e.g. "Latent_false", "Latent_true", "Pretrain_false", "Pretrain_true".

    Returns:
        list[float]: ``np.mean`` of each logged entry, in logged order.

    Raises:
        KeyError: if ``eval/<eval_key>`` or its ``Transformer | True`` series is missing.
    """
    series = log_data[f"eval/{eval_key}"]["Transformer | True"]
    return [float(np.mean(values)) for values in series]


def plot_all_layers_plotly(results_dict, metric="summary_r2_ood", steps=None,
                           exp_name=None, eval_key=None, include_transformer_true=False,
                           exp_names_with_k=None, results_dir="../results/linear"):
    """
    Plot all layers on a single interactive Plotly figure.

    Args:
        results_dict : dict
            Output of process_ood_evolve_checkpoints (or the task-diversity analysis);
            must contain "layers" and ``metric``.
        metric : str
            "summary_r2_ood" or "lambda_dispersion_ood"
        steps : sequence, optional
            x-axis values (checkpoint steps, or k values for task diversity).
            If None/empty, the union of all steps present in ``results_dict`` is used.
        exp_name : str, optional
            Experiment whose log.json provides the Transformer|True training curve
            (e.g., "train_1cb25eea155301b4045c86ff82db64a7").
        eval_key : str or list of str, optional
            One of "Latent_false", "Latent_true", "Pretrain_false", "Pretrain_true",
            or a list of several keys to overlay.
        include_transformer_true : bool
            If True and exp_name/exp_names_with_k and eval_key are provided, overlay
            Transformer|True from log.json on a secondary (right-hand) y-axis.
        exp_names_with_k : list of (k, exp_name) tuples, optional
            Task-diversity mode: plots each run's final Transformer|True value
            vs k (n_minor_tasks = 2^k) instead of a training curve.
        results_dir : str
            Directory containing the run directories with log.json files.

    Missing log files or keys are reported with a printed warning and skipped.
    """
    layers = sorted(results_dict["layers"])
    metric_dict = results_dict[metric]
    metric_label = "OOD R²" if metric == "summary_r2_ood" else "OOD lambda dispersion"

    # `steps` may be a tuple or numpy array (whose truth value is ambiguous), so test
    # for absence/emptiness explicitly before falling back to the union of keys.
    if steps is None or len(steps) == 0:
        all_steps = sorted(set().union(*(metric_dict[L].keys() for L in layers)))
    else:
        all_steps = list(steps)

    fig = go.Figure()

    for L in layers:
        # Missing (layer, step) pairs become NaN gaps so all traces share one x-axis.
        y_vals = [metric_dict[L].get(s, np.nan) for s in all_steps]
        fig.add_trace(go.Scatter(
            x=all_steps,
            y=y_vals,
            mode='lines+markers',
            name=f"Layer {L}",
            marker=dict(size=6),
            line=dict(width=2),
        ))

    has_transformer_trace = False
    if include_transformer_true and eval_key:
        eval_keys = [eval_key] if isinstance(eval_key, str) else list(eval_key)

        if exp_names_with_k:
            # Case 1: task diversity — final logged value of each run vs k.
            # Each log is read once and reused for every eval key. The loop variable
            # is `run_name` so it does not shadow the `exp_name` parameter.
            run_logs = []
            for k, run_name in exp_names_with_k:
                try:
                    run_logs.append((k, run_name, load_run_log(run_name, results_dir)))
                except Exception as e:
                    print(f"Warning: Could not load data for k={k}, exp={run_name}: {e}")

            for ek in eval_keys:
                k_values, final_values = [], []
                for k, run_name, log_data in run_logs:
                    try:
                        final_values.append(transformer_true_means(log_data, ek)[-1])
                        k_values.append(k)
                    except Exception as e:
                        print(f"Warning: Could not load data for k={k}, exp={run_name}: {e}")

                if final_values:
                    fig.add_trace(go.Scatter(
                        x=k_values,
                        y=final_values,
                        mode='lines+markers',
                        name=f'Transformer | True ({ek})',
                        marker=dict(size=10),
                        line=dict(width=3, dash='dash'),
                        yaxis='y2',
                    ))
                    has_transformer_trace = True

        elif exp_name:
            # Case 2: single experiment — the full Transformer|True training curve.
            try:
                log_data = load_run_log(exp_name, results_dir)
                log_steps = log_data['train/step']
                for ek in eval_keys:
                    fig.add_trace(go.Scatter(
                        x=log_steps,
                        y=transformer_true_means(log_data, ek),
                        mode='lines',
                        name=f'Transformer | True ({ek})',
                        line=dict(width=3, dash='dash'),
                        yaxis='y2',
                    ))
                    has_transformer_trace = True
            except Exception as e:
                print(f"Warning: Could not load Transformer|True data: {e}")

    if exp_names_with_k:
        title = f"{metric_label} vs task diversity (all layers)"
        xaxis_title = "k (where n_minor_tasks = 2^k)"
    else:
        title = f"{metric_label} over training (all layers)"
        xaxis_title = "Checkpoint step"

    layout = dict(
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=metric_label,
        legend=dict(title=dict(text="Layers")),
        hovermode="closest",
        template="plotly_white",
        width=900,
        height=600,
    )
    if has_transformer_trace:
        # Transformer|True is presumably not on the layer metric's scale, so give it
        # its own right-hand axis (and move the legend clear of that axis).
        layout["yaxis2"] = dict(title="Transformer | True", overlaying="y",
                                side="right", showgrid=False)
        layout["legend"] = dict(title=dict(text="Layers"), x=1.1, y=1)

    fig.update_layout(**layout)
    fig.show()

# Layer-wise OOD R² across checkpoints, with Transformer|True overlaid from log.json.
# NOTE(review): `results_dict` is whatever the last iteration of the checkpoint loop
# left behind (the last entry of `exp_names_name`). The `exp_name` below must name
# that same run for the overlay to be comparable — presumably the k = 12 run,
# TODO confirm.
plot_all_layers_plotly(
    results_dict, metric="summary_r2_ood",exp_name="train_6ab65809d5e5b5fe12b3488cb7cc0ede",
    eval_key=["Latent_false", "Pretrain_false"],
    include_transformer_true=True)

In [26]:
# Layer-wise OOD lambda dispersion across checkpoints, with Transformer|True overlaid.
# NOTE(review): `results_dict` comes from the last iteration of the checkpoint loop;
# confirm it is the same run as `exp_name` below.
plot_all_layers_plotly(results_dict, metric="lambda_dispersion_ood",
                       exp_name="train_6ab65809d5e5b5fe12b3488cb7cc0ede",
    eval_key=["Latent_false", "Pretrain_false"],
    include_transformer_true=True)

In [24]:
# Task-diversity view of OOD R²: x-axis is k (`steps` holds the k values from
# exp_names), with each run's final Transformer|True value overlaid per eval key.
plot_all_layers_plotly(
    result_dict_task, 
    metric="summary_r2_ood", 
    steps=steps,
    exp_names_with_k=exp_names,
    eval_key=["Latent_false", "Pretrain_false"],
    include_transformer_true=True
)

In [27]:
# Task-diversity view of OOD lambda dispersion: x-axis is k (`steps` holds the
# k values from exp_names), with each run's final Transformer|True value overlaid.
plot_all_layers_plotly(
    result_dict_task, 
    metric="lambda_dispersion_ood", 
    steps=steps,
    exp_names_with_k=exp_names,
    eval_key=["Latent_false", "Pretrain_false"],
    include_transformer_true=True
)

## OOD metrics and `eval/Latent_false` Transformer | True MSE: vs. task diversity $k$ and vs. training step

In [6]:
import json
import plotly.graph_objects as go
import numpy as np

def transformer_true_at_step(log_data, eval_key, step=None):
    """Mean ``Transformer | True`` value of ``eval/<eval_key>`` at one logged step.

    Args:
        log_data: Parsed ``log.json`` of a run.
        eval_key: e.g. "Latent_false", "Latent_true", "Pretrain_false", "Pretrain_true".
        step: Training step to read. ``None`` selects the last logged entry.

    Returns:
        float: ``np.mean`` of the selected entry, or ``nan`` when nothing is logged,
        when ``step`` does not appear in ``log_data["train/step"]``, or when that
        step has no matching eval entry.

    Raises:
        KeyError: if ``eval/<eval_key>`` / ``Transformer | True`` is missing, or if
        ``train/step`` is missing while ``step`` is given.

    Assumes ``train/step`` and the eval series are aligned index-by-index.
    """
    series = log_data[f"eval/{eval_key}"]["Transformer | True"]
    if not series:
        return np.nan
    if step is None:
        return float(np.mean(series[-1]))
    log_steps = list(log_data["train/step"])
    if step not in log_steps:
        return np.nan
    idx = log_steps.index(step)
    return float(np.mean(series[idx])) if idx < len(series) else np.nan


def plot_ood_metrics_with_transformer_true(exp_names_with_k, metric="summary_r2_ood", 
                                            layer_index=12, eval_key="Latent_false",
                                            K=100, checkpoint_step=None,
                                            results_dir="../results/linear"):
    """
    Plot OOD metrics from process_ood_evolve_checkpoints alongside Transformer|True MSE.

    Both series are read at the same training step: ``checkpoint_step`` when given,
    otherwise the final checkpoint (OOD) and the final logged eval (MSE).

    Args:
        exp_names_with_k: List of tuples (k, exp_name) where n_minor_tasks = 2^k
        metric: "summary_r2_ood" or "lambda_dispersion_ood"
        layer_index: Which layer to extract from results_dict
        eval_key: One of "Latent_false", "Latent_true", "Pretrain_false", "Pretrain_true"
        K: Number of batches for process_ood_evolve_checkpoints
        checkpoint_step: Specific checkpoint step to plot. If None, uses final checkpoint.
        results_dir: Directory containing the run directories with log.json files.

    Returns:
        plotly Figure with the OOD metric on the left axis and the MSE on the right.
        Runs that fail to process are reported with a printed warning and skipped.
    """
    k_values = []
    ood_metric_values = []
    transformer_mse_values = []

    for k, exp_name in exp_names_with_k:
        try:
            results_dict = process_ood_evolve_checkpoints(
                K=K, 
                exp_name=exp_name, 
                layer_indices=[layer_index]
            )
            layer_metric = results_dict[metric][layer_index]

            if checkpoint_step is not None:
                ood_value = layer_metric.get(checkpoint_step, np.nan)
            else:
                ckpt_steps = sorted(layer_metric.keys())
                ood_value = layer_metric[ckpt_steps[-1]] if ckpt_steps else np.nan

            with open(f"{results_dir}/{exp_name}/log.json") as f:
                log_data = json.load(f)
            # Read the MSE at the same step as the OOD metric; previously it was
            # always the final value even when checkpoint_step was requested.
            mse_value = transformer_true_at_step(log_data, eval_key, checkpoint_step)
        except Exception as e:
            print(f"Warning: Could not process k={k}, exp={exp_name}: {e}")
            continue

        k_values.append(k)
        ood_metric_values.append(ood_value)
        transformer_mse_values.append(mse_value)

    fig = go.Figure()

    # OOD metric on the primary (left) y-axis.
    fig.add_trace(go.Scatter(
        x=k_values,
        y=ood_metric_values,
        mode='lines+markers',
        name=f'{metric} (Layer {layer_index})',
        marker=dict(size=10),
        line=dict(width=2),
        yaxis='y'
    ))

    # Transformer|True MSE on the secondary (right) y-axis.
    fig.add_trace(go.Scatter(
        x=k_values,
        y=transformer_mse_values,
        mode='lines+markers',
        name=f'Transformer | True MSE ({eval_key})',
        marker=dict(size=10, symbol='square'),
        line=dict(width=2, dash='dash'),
        yaxis='y2'
    ))

    metric_title = "OOD R²" if metric == "summary_r2_ood" else "OOD λ dispersion"
    step_suffix = f", Step: {checkpoint_step}" if checkpoint_step is not None else ""

    fig.update_layout(
        title=f'{metric_title} and Transformer MSE vs Task Diversity<br>Layer {layer_index}, Eval: {eval_key}{step_suffix}',
        xaxis=dict(
            title='k (where n_minor_tasks = 2^k)',
            side='bottom'
        ),
        yaxis=dict(
            title=metric_title,
            side='left'
        ),
        yaxis2=dict(
            title='Transformer | True MSE',
            side='right',
            overlaying='y',
            showgrid=False
        ),
        hovermode='x unified',
        template='plotly_white',
        width=1000,
        height=600,
        legend=dict(x=0.02, y=0.98)
    )

    return fig

# Example usage: OOD R² at the final checkpoint (layer 12) and the final
# Latent_false Transformer|True MSE for every run in the task-diversity sweep.
fig = plot_ood_metrics_with_transformer_true(
     exp_names, 
     metric="summary_r2_ood", 
     layer_index=12, 
     eval_key="Latent_false",
     K=100
)
fig.show()

Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_244beee6c29be8ab4f96771fdb79342f\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_596c3de4dd0c1f079982dec494417f92\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_eb3b586c12d500ad1d3a3262c313a6a9\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_cbcd1c0ddd2348f736f6c6098cb69faa\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_17121002a85d5cddfec36dcbeae81c10\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_1f47a24f3c5b1a0a4a20aee54df9750c\ood

In [16]:
import json
import plotly.graph_objects as go
import numpy as np

def _collect_training_curves(exp_names_with_k, include_ood_metric, metric, layer_index,
                             K, include_transformer_mse, eval_key, results_dir):
    """Gather per-run OOD-metric and Transformer|True MSE curves, keyed by k.

    Runs that fail to process are reported with a printed warning and skipped.
    Curves that were not requested are stored as None.
    """
    results = {}
    for k, exp_name in exp_names_with_k:
        try:
            if include_ood_metric:
                ood_results = process_ood_evolve_checkpoints(
                    K=K, 
                    exp_name=exp_name, 
                    layer_indices=[layer_index]
                )
                layer_metric = ood_results[metric][layer_index]
                ood_steps = sorted(layer_metric.keys())
                ood_values = [layer_metric[step] for step in ood_steps]
            else:
                ood_steps, ood_values = None, None

            if include_transformer_mse:
                with open(f"{results_dir}/{exp_name}/log.json") as f:
                    log_data = json.load(f)
                train_steps = log_data['train/step']
                transformer_true = log_data[f'eval/{eval_key}']['Transformer | True']
                mse_values = [np.mean(values) for values in transformer_true]
            else:
                train_steps, mse_values = None, None
        except Exception as e:
            print(f"Warning: Could not process k={k}, exp={exp_name}: {e}")
            continue

        results[k] = {
            'n_minor': 2**k,
            'ood_steps': ood_steps,
            'ood_values': ood_values,
            'train_steps': train_steps,
            'mse_values': mse_values
        }
    return results


def _combined_training_figure(results, metric_title, layer_index, eval_key):
    """One figure: OOD curves on the left axis, MSE curves on the right axis."""
    fig = go.Figure()
    for k, run in results.items():
        n_minor = run['n_minor']
        fig.add_trace(go.Scatter(
            x=run['ood_steps'],
            y=run['ood_values'],
            mode='lines+markers',
            name=f'OOD (k={k}, n={n_minor})',
            marker=dict(size=6),
            line=dict(width=2),
            yaxis='y'
        ))
        fig.add_trace(go.Scatter(
            x=run['train_steps'],
            y=run['mse_values'],
            mode='lines',
            name=f'MSE (k={k}, n={n_minor})',
            line=dict(width=2, dash='dash'),
            yaxis='y2'
        ))

    fig.update_layout(
        title=f'{metric_title} and Transformer MSE Across Training<br>Layer {layer_index}, Eval: {eval_key}',
        xaxis_title='Training Step',
        yaxis=dict(title=metric_title, side='left'),
        yaxis2=dict(
            title='Transformer | True MSE',
            side='right',
            overlaying='y',
            showgrid=False
        ),
        hovermode='x unified',
        template='plotly_white',
        width=1400,
        height=700,
        legend=dict(x=1.05, y=1)
    )
    return fig


def _ood_training_figure(results, metric_title, layer_index):
    """OOD metric vs training step, one line per run."""
    fig = go.Figure()
    for k, run in results.items():
        if run['ood_values'] is not None:
            fig.add_trace(go.Scatter(
                x=run['ood_steps'],
                y=run['ood_values'],
                mode='lines+markers',
                name=f'k={k} (n_minor={run["n_minor"]})',
                marker=dict(size=6),
                line=dict(width=2)
            ))
    fig.update_layout(
        title=f'{metric_title} Across Training Steps<br>Layer {layer_index}, All Experiments',
        xaxis_title='Training Step',
        yaxis_title=metric_title,
        hovermode='x unified',
        template='plotly_white',
        width=1200,
        height=600,
        legend_title='Experiment'
    )
    return fig


def _mse_training_figure(results, eval_key):
    """Transformer|True MSE vs training step, one line per run."""
    fig = go.Figure()
    for k, run in results.items():
        if run['mse_values'] is not None:
            fig.add_trace(go.Scatter(
                x=run['train_steps'],
                y=run['mse_values'],
                mode='lines',
                name=f'k={k} (n_minor={run["n_minor"]})',
                line=dict(width=2)
            ))
    fig.update_layout(
        title=f'Transformer | True MSE Across Training Steps<br>Eval: {eval_key}, All Experiments',
        xaxis_title='Training Step',
        yaxis_title='MSE',
        hovermode='x unified',
        template='plotly_white',
        width=1200,
        height=600,
        legend_title='Experiment'
    )
    return fig


def plot_training_curves_all_experiments(exp_names_with_k, 
                                         include_ood_metric=True,
                                         metric="summary_r2_ood",
                                         layer_index=12,
                                         K=100,
                                         include_transformer_mse=True,
                                         eval_key="Latent_false",
                                         combine_plots=False,
                                         results_dir="../results/linear"):
    """
    Unified function to plot OOD metrics and/or Transformer|True MSE across all experiments.
    Each experiment (different k/n_minor_tasks) gets its own line.

    Args:
        exp_names_with_k: List of tuples (k, exp_name) where n_minor_tasks = 2^k
        include_ood_metric: If True, plot OOD metric (R² or λ dispersion)
        metric: "summary_r2_ood" or "lambda_dispersion_ood"
        layer_index: Which layer to extract from results_dict
        K: Number of batches for process_ood_evolve_checkpoints
        include_transformer_mse: If True, plot Transformer|True MSE
        eval_key: One of "Latent_false", "Latent_true", "Pretrain_false", "Pretrain_true"
        combine_plots: If True and both metrics included, plot on same figure with dual y-axes
        results_dir: Directory containing the run directories with log.json files.

    Returns:
        A single figure when combine_plots applies or only one metric is requested;
        otherwise the list [ood_figure, mse_figure].

    Raises:
        ValueError: if neither include_ood_metric nor include_transformer_mse is True
            (previously this failed with IndexError after all data had been loaded).
    """
    if not (include_ood_metric or include_transformer_mse):
        raise ValueError(
            "At least one of include_ood_metric / include_transformer_mse must be True.")

    results = _collect_training_curves(
        exp_names_with_k, include_ood_metric, metric, layer_index,
        K, include_transformer_mse, eval_key, results_dir)
    metric_title = "OOD R²" if metric == "summary_r2_ood" else "OOD λ dispersion"

    if combine_plots and include_ood_metric and include_transformer_mse:
        return _combined_training_figure(results, metric_title, layer_index, eval_key)

    figs = []
    if include_ood_metric:
        figs.append(_ood_training_figure(results, metric_title, layer_index))
    if include_transformer_mse:
        figs.append(_mse_training_figure(results, eval_key))
    return figs if len(figs) > 1 else figs[0]

# Example 1: Both metrics on separate plots.
# With both metrics requested and combine_plots=False the function returns a
# two-element list: [OOD metric figure, Transformer MSE figure].
figs = plot_training_curves_all_experiments(
    exp_names, 
    include_ood_metric=True,
    metric="summary_r2_ood", 
    layer_index=12, 
    K=100,
    include_transformer_mse=True,
    eval_key="Latent_false",
    combine_plots=False
)
figs[0].show()  # OOD metric
figs[1].show()  # Transformer MSE

# Example 2: Both metrics on same plot with dual y-axes
#fig = plot_training_curves_all_experiments(
#     exp_names, 
#     include_ood_metric=True,
#     metric="summary_r2_ood", 
#     layer_index=12, 
#     K=100,
#     include_transformer_mse=True,
#     eval_key="Latent_false",
#     combine_plots=True
#)
#fig.show()

# Example 3: Just Transformer MSE
# fig = plot_training_curves_all_experiments(
#     exp_names, 
#     include_ood_metric=False,
#     include_transformer_mse=True,
#     eval_key="Latent_false"
# )
# fig.show()

Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_244beee6c29be8ab4f96771fdb79342f\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_596c3de4dd0c1f079982dec494417f92\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_eb3b586c12d500ad1d3a3262c313a6a9\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_cbcd1c0ddd2348f736f6c6098cb69faa\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_17121002a85d5cddfec36dcbeae81c10\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.
Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_1f47a24f3c5b1a0a4a20aee54df9750c\ood