In [None]:
import torch
from absl import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.optimize import curve_fit

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

from icl.linear.train_linear import train
from icl.linear.lr_config import get_config
from icl.linear.lr_task import *
from icl.linear.linear_utils import *
from icl.linear.task_vecs import *
from icl.linear.train_linear import get_sharded_batch_sampler
from icl.linear import DiscreteMMSE, Ridge
from icl.linear.lr_models import MixedRidge, UnbalancedMMSE
#from icl.linear.sufficient_stats import get_sufficient_statistics_fit, get_sufficient_statistics_proj_fit, get_betahat_fit
from icl.utils.linear_notebook_helpers import process_sufficient_statistics, plot_r2_curves_plotly, get_eval_task, process_beta_fit, plot_metrics
from icl.utils.linear_visualization_utils import plot_mse
from icl.utils.linear_ood_analysis import process_ood_evolve, process_ood_evolve_checkpoints
from icl.figures.attn_plots_beta import visualize_attention
from icl.figures.task_vec_viz import *
# from icl.utils.experiment_analysis import process_exp
# from icl.utils.linear_processor import process_ood_evolve_lambda_metrics

# Notebook-wide display/logging configuration.
# Show absl log messages at INFO severity and above.
logging.set_verbosity(logging.INFO)
# Print tensors/arrays with 3 decimal places and without scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)
np.set_printoptions(precision=3, suppress=True)

# Reload imported modules automatically before each cell executes, so edits
# to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

To modify the experiment configuration, edit the `../src/icl/linear/lr_config.py` file.

In [2]:
# Load the experiment config (from icl.linear.lr_config) and override the
# task / training settings used for this sweep.
config = get_config()
config.task.p_minor = 0.1  # presumably the share/probability of minor tasks — TODO confirm in lr_config.py
config.training.warmup_steps = 30_000
config.training.total_steps = 60_000
# Sweep the number of minor tasks over 2**10, 2**11, 2**12 (1024, 2048, 4096).
# The same `config` object is mutated on every iteration, and `model` / `log`
# are rebound each time, so after the loop they refer to the last run only.
# In the recorded output, train() reported each run as "already completed"
# and loaded its checkpoint.
for k in range(10, 13):
    config.task.n_minor_tasks = 2**k
    model, log = train(config)

..\results\linear\train_4de979b8d7914639afda3b23ee7d5ba9
train_4de979b8d7914639afda3b23ee7d5ba9 already completed



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



Loaded model from ..\results\linear\train_4de979b8d7914639afda3b23ee7d5ba9\checkpoint.pt
..\results\linear\train_ccf7f6f1a7749d22a6a053f39affa348
train_ccf7f6f1a7749d22a6a053f39affa348 already completed
Loaded model from ..\results\linear\train_ccf7f6f1a7749d22a6a053f39affa348\checkpoint.pt
..\results\linear\train_6ab65809d5e5b5fe12b3488cb7cc0ede
train_6ab65809d5e5b5fe12b3488cb7cc0ede already completed
Loaded model from ..\results\linear\train_6ab65809d5e5b5fe12b3488cb7cc0ede\checkpoint.pt


In [7]:
# Name of the run directory to analyse; this is the last run listed in the
# recorded output of the training sweep.
exp_name = "train_6ab65809d5e5b5fe12b3488cb7cc0ede"

# Plot per-layer R^2 curves using `process_beta_fit` as the processing function.
# NOTE(review): presumably this measures how well the estimate below can be
# fit from each layer — confirm in icl.utils.linear_notebook_helpers.

# Ordinary least-squares estimate: (X^T X)^{-1} X^T Y

fig, r2_dict = plot_r2_curves_plotly(
    process_beta_fit, exp_name=exp_name,
    layer_indices=range(0, 16), is_eval=True, K=1024  # layer indices 0..15; meaning of K is set by the helper — TODO document
)
fig.show()

In [8]:
# Run the OOD-evolve analysis for `exp_name` at layer index 15 with
# include_minor=True.
# NOTE(review): the meaning of K=100 is defined in icl.utils.linear_ood_analysis — confirm.
process_ood_evolve(exp_name, K=100, layer_index=15, include_minor=True)

Preprocessing...
Too many minority tasks (4096). Randomly sampling 64.


Given $h_1, h_2, h_3$, define the mean and the centered vectors:

$$\bar{h} = \frac{h_1 + h_2 + h_3}{3}$$

$$x_i = h_i - \bar{h}, \qquad i = 1, 2, 3$$

$$x = \lambda \, (x_1, x_2, x_3) + \epsilon$$

In [7]:
# Per-checkpoint OOD-evolve results for the same run; the run name is passed
# as a literal here (same value as `exp_name`). In the recorded output the
# helper loaded previously computed results from a .pkl file in the run directory.
results_dict = process_ood_evolve_checkpoints(exp_name="train_6ab65809d5e5b5fe12b3488cb7cc0ede")

Preprocessing...
Already computed. Loading existing results from ..\results\linear\train_6ab65809d5e5b5fe12b3488cb7cc0ede\ood_evolve_ckpt_all_layers_h_0.0_r_2.0_on_False.pkl.


In [25]:
import plotly.graph_objects as go
import numpy as np

def plot_all_layers_plotly(results_dict, metric="summary_r2_ood"):
    """
    Plot one metric for every layer on a single interactive Plotly figure.

    One lines+markers trace is drawn per layer, with checkpoint steps on the
    x-axis. All traces share the sorted union of the steps found across the
    layers; a layer with no value at some step gets NaN at that position.

    Args:
        results_dict : dict
            Output of process_ood_evolve_checkpoints. Must provide
            results_dict["layers"] (iterable of layer ids) and
            results_dict[metric] (mapping layer id -> {step: value}).
        metric : str
            "summary_r2_ood" or "lambda_dispersion_ood". Any other key of
            results_dict with the same layout is accepted as well and gets a
            title derived from the key itself.

    Returns:
        None. The figure is displayed with fig.show().

    Raises:
        KeyError: if "layers" or `metric` is missing from results_dict, or a
            listed layer has no entry under results_dict[metric].
    """
    # Explicit titles for the documented metrics. Previously every metric
    # other than "summary_r2_ood" was titled as lambda dispersion, which
    # mislabelled any additional metric stored in results_dict.
    known_titles = {
        "summary_r2_ood": "OOD R² over training (all layers)",
        "lambda_dispersion_ood": "OOD lambda dispersion over training (all layers)",
    }
    title = known_titles.get(metric, f"{metric} over training (all layers)")

    layers = sorted(results_dict["layers"])
    metric_dict = results_dict[metric]

    # Shared x-axis: union of the checkpoint steps recorded for any layer.
    all_steps = sorted(set().union(*(metric_dict[layer].keys() for layer in layers)))

    fig = go.Figure()

    for layer in layers:
        per_step = metric_dict[layer]
        # Align values to the shared x-axis; missing steps become NaN.
        y_vals = [per_step.get(step, np.nan) for step in all_steps]

        fig.add_trace(go.Scatter(
            x=all_steps,
            y=y_vals,
            mode='lines+markers',
            name=f"Layer {layer}",
            marker=dict(size=6),
            line=dict(width=2),
        ))

    fig.update_layout(
        title=title,
        xaxis_title="Checkpoint step",
        yaxis_title="Value",
        legend_title="Layers",
        hovermode="closest",
        template="plotly_white",
        width=900,
        height=600,
    )

    fig.show()

# OOD R² per layer across checkpoints, read from results_dict["summary_r2_ood"].
plot_all_layers_plotly(results_dict, metric="summary_r2_ood")

In [24]:
# Same plot for the lambda-dispersion metric, read from results_dict["lambda_dispersion_ood"].
plot_all_layers_plotly(results_dict, metric="lambda_dispersion_ood")

In [23]:
# Plot the latent loss for the current experiment
exp_name = "train_6ab65809d5e5b5fe12b3488cb7cc0ede"
# Relative path to the run's log.json under ../results/linear/.
log_path = f"../results/linear/{exp_name}/log.json"
# NOTE(review): plot_latent_loss is not imported by name in the visible import
# cell; presumably it comes from one of the wildcard imports — confirm.
fig = plot_latent_loss(log_path)

Loading log file from: ../results/linear/train_6ab65809d5e5b5fe12b3488cb7cc0ede/log.json
Found metric data with 2400 evaluation points
Metric data is nested. First element has 64 values
Aggregating by taking mean across inner dimension
Using train/step for evaluation steps. Found 2400 evaluation points
Plotting 2400 data points
