In [None]:
import sys

# Make the repository root (resolved relative to the current working directory)
# importable so the `src.*` imports below resolve. Drop any earlier copy first so
# re-running this cell keeps exactly one entry, at the front of the search path.
REPO_ROOT = "../.."
sys.path[:] = [entry for entry in sys.path if entry != REPO_ROOT]
sys.path.insert(0, REPO_ROOT)

In [None]:
import json
import os
import re

import chex
import dill
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from flax import nnx
from torch.utils.data import DataLoader

from src.dataset import get_iter
from src.datasets.linear_regression import ICLinearRegression
from src.utils import parse_dict

In [None]:
# Location of the trained run to evaluate. Set the ICL_RESULTS_DIR environment
# variable to point at a different results root instead of editing this cell.
base_path = os.environ.get(
    "ICL_RESULTS_DIR",
    "/home/bryanpu1/projects/iclr_2026/icl_architecture/scaling_jax/results",
)
algo_name = "regression_gpt"
run_name = "adamw-06-09-25_15_24_07-07442dc2-cd9b-4467-b15d-849adec762d0"

learner_path = os.path.join(base_path, algo_name, run_name)

In [None]:
# Evaluation settings.
eval_seed = 40  # seeds both the evaluation dataset and the JAX PRNG key below
batch_size = 32
half_precision = True

# dtype handed to `get_iter` when building evaluation batches.
if half_precision:
    dtype = jnp.bfloat16
else:
    dtype = jnp.float32

In [None]:
# Load the run's config and restore the most recent saved train state.
with open(os.path.join(learner_path, "config.json"), "r") as config_file:
    config_dict = json.load(config_file)
embed_dim = config_dict["model_config"]["model_kwargs"]["embed_dim"]


def _natural_sort_key(name):
    """Sort key splitting `name` into text/integer chunks so "ckpt_10" > "ckpt_9"."""
    # re.split with a capture group alternates text (even idx) / digits (odd idx),
    # so chunk types line up position-by-position across names.
    return [
        int(chunk) if idx % 2 else chunk
        for idx, chunk in enumerate(re.split(r"(\d+)", name))
    ]


models_dir = os.path.join(learner_path, "models")
# Plain lexicographic sorting would rank "ckpt_99" after "ckpt_100" when step
# numbers are not zero-padded; a natural sort picks the true latest checkpoint.
checkpoints = sorted(os.listdir(models_dir), key=_natural_sort_key)
if not checkpoints:
    raise FileNotFoundError(f"No checkpoints found in {models_dir}")
last_step = checkpoints[-1]

# NOTE: dill.load can execute arbitrary code -- only load checkpoints you trust.
with open(os.path.join(models_dir, last_step), "rb") as checkpoint_file:
    train_state = dill.load(checkpoint_file)
model = nnx.merge(
    train_state.graphdef,
    train_state.params,
    train_state.rest,
)
# Put the model in inference mode (Flax NNX sets `deterministic` /
# `use_running_average` on submodules that define them, e.g. dropout).
model.eval()

rng = jax.random.PRNGKey(eval_seed)
rng, _ = jax.random.split(rng)

In [None]:
# Rebuild the in-context linear-regression dataset from the run's saved
# dataset settings, seeded with `eval_seed`.
# NOTE(review): `train` comes straight from the saved config -- confirm whether
# evaluation is meant to sample training tasks or held-out tasks.
dataset_kwargs = parse_dict(config_dict["dataset_kwargs"])
dataset_args = (
    dataset_kwargs.num_tasks,
    dataset_kwargs.num_dims,
    dataset_kwargs.context_len,
    eval_seed,
    dataset_kwargs.train,
    dataset_kwargs.input_noise_std,
    dataset_kwargs.label_noise_std,
)
dataset = ICLinearRegression(*dataset_args)

In [None]:
# Single-process loading (num_workers=0). `get_iter` also receives `dtype`
# (presumably the dtype batches are cast to -- confirm in src.dataset).
loader_kwargs = dict(batch_size=batch_size, num_workers=0)
data_loader = DataLoader(dataset, **loader_kwargs)
data_iter = get_iter(data_loader, None, dtype)

In [None]:
# Pull a single evaluation batch from the iterator.
batch = next(data_iter)

In [None]:
# Inspect the layout of the input examples (leading axis is the batch axis
# stacked by DataLoader).
np.shape(batch["example"])

In [None]:
# Forward pass on the whole batch.
# NOTE(review): assumes the model's __call__ takes the batch dict and returns
# predictions with the same shape as batch["target"] -- confirm.
res = model(batch)

In [None]:
# Residuals computed in float32 so the subtraction (and the squaring done
# downstream) does not add bfloat16 rounding on top of the model's own.
targets = jnp.asarray(batch["target"], dtype=jnp.float32)
preds = jnp.asarray(res, dtype=jnp.float32)
# Require identical shapes: e.g. a (B, T) target minus a (B, T, 1) prediction
# would otherwise silently broadcast to (B, T, T) and corrupt every error.
assert targets.shape == preds.shape, (
    f"target shape {targets.shape} != prediction shape {preds.shape}"
)
errors = targets - preds

In [None]:
# Mean squared error at each context position, averaged over the batch and any
# trailing output dims. DataLoader stacks samples on axis 0, so axis 1 is taken
# to be the context position -- TODO confirm against batch["target"].shape.
squared_errors = np.asarray(errors, dtype=np.float32) ** 2
reduce_axes = tuple(axis for axis in range(squared_errors.ndim) if axis != 1)
error_per_context = squared_errors.mean(axis=reduce_axes)

x_range = range(dataset_kwargs.context_len)
fig, ax = plt.subplots()
ax.plot(x_range, error_per_context)
ax.set(
    xlabel="Context position",
    ylabel="Mean squared error",
    title="Per-position MSE on one evaluation batch",
);