In [1]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Seaborn defaults applied to every figure: 'notebook' context with the 'darkgrid' style.
sns.set_theme(context='notebook', style='darkgrid')
# Colour-blind-friendly palette, kept at notebook scope for plotting code to use.
palette = sns.color_palette(palette='colorblind')

# Directory holding run outputs.
run_dir = "src/output"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run directories of the two trained models, given explicitly.
run_path_time = "output/c6b14be6-5a5a-42c9-ad1d-504622d36218"
run_path_freq = "output/c80936b9-3dd5-4e38-bb14-2501c4c34084"

# Only the configs are requested here (only_conf=True); the first return value is discarded.
_, conf_time = get_model_from_run(run_path_time, only_conf=True)
_, conf_freq = get_model_from_run(run_path_freq, only_conf=True)

# Sanity check: print the data domain and n_dims recorded in each run's config.
for banner, run_conf in (
    ("Time domain model:", conf_time),
    ("Freq domain model:", conf_freq),
):
    print(banner, run_conf.training.data_kwargs.domain, "n_dims:", run_conf.model.n_dims)

Time domain model: time n_dims: 50
Freq domain model: freq n_dims: 100


In [4]:
# Choose which model to evaluate; every later cell reads `run_path`.
run_path = run_path_time  # or run_path_freq

# Compute metrics (this will save to metrics.json in the run directory)
# NOTE(review): the recorded run of this cell ended in a RuntimeError about deserializing a
# CUDA object while torch.cuda.is_available() is False, i.e. it was run on a CPU-only machine.
# The failing load is not in this notebook -- presumably it happens inside get_run_metrics;
# confirm that the checkpoint is loaded with map_location before re-running on CPU.
# `metrics` is only bound here when recompute_metrics is True.
recompute_metrics = True
if recompute_metrics:
    metrics = get_run_metrics(run_path)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

# Plot metrics

In [None]:
# Read metrics.json from the selected run directory and plot its "standard" entry.
import json

metrics_file = os.path.join(run_path, "metrics.json")
with open(metrics_file) as fh:
    metrics = json.load(fh)

# Re-read the config of the selected run (config only, no model weights).
_, conf = get_model_from_run(run_path, only_conf=True)

# No filtering by model name: plot whatever is stored under "standard".
standard_metrics = metrics["standard"]
basic_plot(standard_metrics)
plt.show()

In [None]:
# Plot every entry of `metrics` other than "standard" (the OOD evaluations, if any),
# one figure per entry, titled with the entry's key.
for name, metric in metrics.items():
    if name != "standard":
        basic_plot(metric)
        plt.title(name)
        plt.show()

# Interactive setup

We will now directly load the model and measure its in-context learning ability on a batch of random inputs. (In the paper we average over multiple such batches to obtain better estimates.)

In [6]:
from samplers import get_data_sampler
from tasks import get_task_sampler

In [None]:
# Load the trained model together with its config for the selected run.
model, conf = get_model_from_run(run_path)

# Resolve the device exactly once so the model and both samplers always agree:
# CUDA when available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Moving to "cpu" is a no-op for a model that already lives there.
model = model.to(device)
model.eval()

n_dims = conf.model.n_dims
batch_size = conf.training.batch_size

# Copy the configured kwargs into plain dicts and override the device the samplers use.
data_kwargs = dict(conf.training.data_kwargs)
task_kwargs = dict(conf.training.task_kwargs)
data_kwargs['device'] = device
task_kwargs['device'] = device

# Build the input sampler and the task sampler from the training configuration.
data_sampler = get_data_sampler(conf.training.data, n_dims, **data_kwargs)
task_sampler = get_task_sampler(
    conf.training.task,
    n_dims,
    batch_size,
    **task_kwargs
)

In [8]:
# Draw one task, then one batch of inputs, then evaluate the task on those inputs.
# The task is sampled before the inputs, so the order of random draws is fixed.
n_context_points = conf.training.curriculum.points.end
task = task_sampler()
xs = data_sampler.sample_xs(n_points=n_context_points, b_size=batch_size)
ys = task.evaluate(xs)

In [None]:
# Run the model on the sampled batch without tracking gradients.
# Inputs are sent to model.device when the model exposes one, otherwise to `device`.
target_device = getattr(model, 'device', device)
with torch.no_grad():
    xs_device = xs.to(target_device)
    ys_device = ys.to(target_device)
    pred = model(xs_device, ys_device).cpu()

In [None]:
# The task's own metric is fetched but not used below; the error is computed by hand.
metric = task.get_metric()

# Per-point MSE: squared error averaged over axes 0 and 2, keeping axis 1.
# (assumes the arrays are (batch, n_points, features) -- TODO confirm)
preds_np = pred.cpu().numpy()
ys_np = ys.cpu().numpy()
squared_error = (preds_np - ys_np) ** 2
loss = squared_error.mean(axis=(0, 2))

# Reference level for signal_conv: the zero filter (output = 0).
# NOTE(review): 1.0 presumably assumes unit-variance targets -- confirm.
baseline = 1.0

fig, ax = plt.subplots()
ax.plot(loss, lw=2, label="Transformer")
ax.axhline(baseline, ls="--", color="gray", label="zero filter")
ax.set_xlabel("# in-context examples")
ax.set_ylabel("squared error")
ax.legend()
plt.show()

# Visualize example signals

Before plotting, we first check how the model's error changes when the inputs are doubled. We then plot some example input signals, their true outputs, and the model's predictions to see how well it learns the FIR filter.

In [None]:
# Repeat the forward pass on the same batch with every input scaled by 2;
# targets are recomputed by the same task on the scaled inputs.
xs2 = 2 * xs
ys2 = task.evaluate(xs2)

# Inputs are sent to model.device when the model exposes one, otherwise to `device`.
target_device = getattr(model, 'device', device)
with torch.no_grad():
    xs2_device = xs2.to(target_device)
    ys2_device = ys2.to(target_device)
    pred2 = model(xs2_device, ys2_device).cpu()

In [None]:
# Per-point MSE on the doubled inputs, averaged over axes 0 and 2 as for `loss`.
preds2_np = pred2.cpu().numpy()
ys2_np = ys2.cpu().numpy()
squared_error2 = (preds2_np - ys2_np) ** 2
loss2 = squared_error2.mean(axis=(0, 2))

# Overlay both error curves against the same baseline level.
fig, ax = plt.subplots()
ax.plot(loss, lw=2, label="Transformer")
ax.plot(loss2, lw=2, label="Transformer on doubled inputs")
ax.axhline(baseline, ls="--", color="gray", label="zero filter")
ax.set_xlabel("# in-context examples")
ax.set_ylabel("squared error")
ax.legend()
plt.show()

In [ ]:
# Plot example signals: input, true output, and predicted output
# Pick one example from the batch
example_idx = 0
n_points_to_show = 5  # Number of in-context examples to visualize

# The x-axis labels do not change between rows, so resolve them once.
# Inputs are labelled from the data sampler's domain, outputs from the task's domain.
input_axis_label = 'Time' if conf.training.data_kwargs.domain == 'time' else 'Frequency bin'
output_axis_label = 'Time' if conf.training.task_kwargs.domain == 'time' else 'Frequency bin'

# squeeze=False always returns a 2-D axes array, so a single row needs no special-casing.
fig, axes = plt.subplots(
    n_points_to_show, 3, figsize=(15, 3*n_points_to_show), squeeze=False
)

for point_idx in range(n_points_to_show):
    # Get the signals for this point
    x_signal = xs[example_idx, point_idx].cpu().numpy()
    y_true = ys[example_idx, point_idx].cpu().numpy()
    y_pred = pred[example_idx, point_idx].cpu().numpy()

    ax_input, ax_true, ax_compare = axes[point_idx]

    # Plot input signal
    ax_input.plot(x_signal, 'b-', alpha=0.7)
    ax_input.set_title(f'Input Signal (point {point_idx+1})')
    ax_input.set_xlabel(input_axis_label)
    ax_input.grid(True, alpha=0.3)

    # Plot true output
    ax_true.plot(y_true, 'g-', alpha=0.7, label='True')
    ax_true.set_title(f'True Output (point {point_idx+1})')
    ax_true.set_xlabel(output_axis_label)
    ax_true.grid(True, alpha=0.3)

    # Plot predicted output vs true
    ax_compare.plot(y_true, 'g-', alpha=0.7, label='True')
    ax_compare.plot(y_pred, 'r--', alpha=0.7, label='Predicted')
    ax_compare.set_title(f'Prediction vs True (point {point_idx+1})')
    ax_compare.set_xlabel(output_axis_label)
    ax_compare.legend()
    ax_compare.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Error at the last plotted point, indexed explicitly instead of relying on loop
# variables that happen to survive the loop (which fails when nothing is plotted).
if n_points_to_show > 0:
    last_idx = n_points_to_show - 1
    final_true = ys[example_idx, last_idx].cpu().numpy()
    final_pred = pred[example_idx, last_idx].cpu().numpy()
    final_error = ((final_pred - final_true) ** 2).mean()
    print(f"MSE for final plotted point (point {last_idx + 1}): {final_error:.6f}")