In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from itertools import product
import seaborn as sns
import torch.utils.benchmark as benchmark
from icl.figures.task_vec_viz import *
from tqdm.notebook import tqdm, trange
import pandas as pd
import os
from scipy.optimize import curve_fit

import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.graph_objects as go

import copy

from icl.utils.train import BaseTrainer, train_model_with_plot
from icl.latent_markov import * # import the latent markov task and its configuration
from icl.models import Transformer
from icl.utils import visualize_attention
from icl.utils.train_utils import get_attn_base, compute_cross_entropy, ih_score, get_attn_at_layer_base
from icl.figures.head_view import *
import icl.utils.task_vec as task_vec
from icl.utils.latent_task_vec import *
import icl.utils.notebook_utils as nu

# Print tensors with 3 decimal places and without scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

To modify experiment configurations, edit the `../src/icl/latent_markov/latent_config.py` file.

In [5]:
# Sweep over the number of minor tasks in powers of two (2**1 ... 2**12),
# building and training one Transformer per setting. The same `config`
# object is reused and only `task.n_minor_tasks` changes between runs.
config = get_config_base()
config.training.warmup_steps = 40_000
config.training.num_epochs = 80_000
for k in range(1, 13):
    config.task.n_minor_tasks = 2**k
    model = Transformer(config).to(config.device)
    train_results = train_model_with_plot(
        model,
        config,
        show=False,
        verbose=False,
    )

Experiment directory:  ..\results\latent\train_fe2016b76902e71aa89f90fcf5bcbcd0
train_fe2016b76902e71aa89f90fcf5bcbcd0 already completed
Experiment directory:  ..\results\latent\train_f2ed84eae56e677a90a54bbeac6c9248
train_f2ed84eae56e677a90a54bbeac6c9248 already completed
Experiment directory:  ..\results\latent\train_b9944bd677a37e8608b766a6c103f5a7
train_b9944bd677a37e8608b766a6c103f5a7 already completed
Experiment directory:  ..\results\latent\train_a379f1ad62eb5131672e49b55b7552da
train_a379f1ad62eb5131672e49b55b7552da already completed
Experiment directory:  ..\results\latent\train_343460a645410a3ac5d325539ea38e5e
train_343460a645410a3ac5d325539ea38e5e already completed
Experiment directory:  ..\results\latent\train_80079f091a945b4ba45fd800c74555b5
train_80079f091a945b4ba45fd800c74555b5 already completed
Experiment directory:  ..\results\latent\train_ec8e5cd4e2f7f75a77aad74aaa6d3eeb
train_ec8e5cd4e2f7f75a77aad74aaa6d3eeb already completed
Experiment directory:  ..\results\latent\

In [2]:
# Load the model, sampler and config of one experiment run via nu.load_everything.
model, sampler, config = nu.load_everything("latent", "train_5050083a5f64fb641e9ca0d7a34622d6")

n_tasks = 32
# Work on a deep copy so the mutations below do not touch the loaded `sampler`.
sampler_clone0 = copy.deepcopy(sampler)
# Cap the clone's minor-task count at n_tasks.
sampler_clone0.n_minor_tasks = min(n_tasks, sampler.n_minor_tasks)
# k_ood is at least n_tasks//2; k_minor is whatever remains of the clone's minor tasks.
# NOTE(review): when sampler.n_minor_tasks < n_tasks / 2, k_ood = n_tasks - n_minor_tasks
# exceeds n_minor_tasks and k_minor becomes negative — confirm this case cannot occur.
k_ood = max(n_tasks//2, n_tasks - sampler_clone0.n_minor_tasks)
k_minor = sampler_clone0.n_minor_tasks - k_ood
# Overwrite the first k_ood entries of minor_trans_mat with the output of
# _sample_banded_trans_mats(k_ood) (presumably new out-of-distribution tasks — TODO confirm).
# NOTE(review): no seed is set in this cell; if that helper samples randomly,
# re-running the cell yields different matrices.
sampler_clone0.minor_trans_mat[:k_ood] = sampler_clone0._sample_banded_trans_mats(k_ood)

In [3]:
# Set the number of threads PyTorch uses for CPU intra-op parallelism.
torch.set_num_threads(16) 
# The return value is discarded, so this relies on `.to` moving the sampler in place
# — TODO confirm against the sampler implementation.
sampler_clone0.to("cpu")
# n_tasks = 3
# Task indices to sample from and number of samples per condition,
# consumed by the sampling cell below.
task_ids = range(2)
num_samples = 4

In [8]:
%%time

# Draw `num_samples` samples per condition from `sampler_clone0` for every task
# in `task_ids`, forcing the work onto the CPU.
samples = sample_all_positions_all_targets_multi(
    sampler_clone0, task_ids, num_samples_per_condition=num_samples, device_override="cpu"
)
print(samples.shape)
# Layout noted by the author: [n_tasks, seq_len//2, num_states, num_samples, seq_len-1]

Sampling tasks:   0%|          | 0/2 [00:00<?, ?it/s]

Task 0 positions:   0%|          | 0/129 [00:00<?, ?it/s]

Task 1 positions:   0%|          | 0/129 [00:00<?, ?it/s]

torch.Size([2, 129, 9, 4, 257])
CPU times: total: 1min 6s
Wall time: 1min 7s


In [4]:
# NOTE(review): this run id is hard-coded and differs from the one passed to
# nu.load_everything ("train_5050083a5f64fb641e9ca0d7a34622d6") — confirm which
# experiment this plot is meant to show.
nu.all_kl_plot(file_path="train_2124d32afa2b5fd088015a249920000c")

Computing Bayesian baseline for 256 positions...


In [6]:
# Collect hidden states at layer index 3 and plot the modes of the first
# three entries along the leading axis.
hiddens = compute_hiddens(
    config,
    model,
    sampler,
    layer_index=3,
    return_final=False,
)
plot_task_vector_modes(hiddens[:3])

In [30]:
# Load a saved checkpoint with nu.load_checkpoint, which takes the experiment
# config and a target training step. Per the author's note it returns the
# checkpoint closest to the requested step.
# NOTE: this rebinds the global `model`, so every later cell uses the
# checkpointed weights instead of the model loaded earlier.
model = nu.load_checkpoint(config, step=10000)

Auto-detected checkpoint directory: ..\results\latent\train_2124d32afa2b5fd088015a249920000c\checkpoints
Loading checkpoint: model_10000.pt (step 10000)
Checkpoint info: step=10000


In [31]:
# Compute layer-4 hidden states for the pre-drawn samples.
# `samples` is the tensor produced by sample_all_positions_all_targets_multi
# in the sampling cell. Binding `samples0` from it (instead of the previous
# self-referential `samples0 = samples0.to("cpu")`) keeps this cell runnable
# on a fresh kernel, where `samples0` would otherwise be undefined.
samples0 = samples.to("cpu")
hiddens_voc0 = compute_hiddens_from_existing_samples_fast(
    config, model, sampler, samples0, layer_index=4
)

  0%|          | 0/9 [00:00<?, ?it/s]

In [26]:
from icl.linear.linear_utils import estimate_lambda_with_r2

# We concatenate all hidden representations for different vocabularies together and then do the projection.
def plot_hidden_proj(hiddens_voc, n_ref_tasks=3, n_minors=None):
    """Project task vectors onto the final-position task vectors and show the figure.

    The vocabulary axis of ``hiddens_voc`` is merged into the feature axis, a
    global mean over the reference tasks and the sample axis is subtracted, and
    the resulting per-task vectors are passed to ``estimate_lambda_with_r2``
    and ``project_with_r2_size``.

    Args:
        hiddens_voc: 5-D tensor unpacked as (K, V, T, B, D). The axis names
            follow the inline shape comments; presumably K tasks, V vocabulary
            entries, T positions, B samples and D hidden size — TODO confirm
            against the function that produces it.
        n_ref_tasks: Number of leading entries along the first axis used both
            for the global mean and for the final task vectors. Defaults to 3,
            the value that was previously hard-coded.
        n_minors: Value forwarded as ``n_minors`` to ``project_with_r2_size``.
            When None, the notebook-level ``k_minor`` is used, which matches
            the previous behaviour.

    Returns:
        None. The figure is displayed with ``fig.show()``; nothing is returned
        so a call at the end of a cell does not render the figure twice.
    """
    if n_minors is None:
        # Backward-compatible default: read the notebook global at call time.
        n_minors = k_minor

    K, _, T, B, _ = hiddens_voc.shape
    # Move the vocabulary axis last and fold it into the feature axis.
    hiddens = hiddens_voc.permute(0, 2, 3, 4, 1).reshape(K, T, B, -1)  # (K, T, B, D*V)
    # Average over the reference tasks (axis 0) and the sample axis (axis 2).
    global_mean = hiddens[:n_ref_tasks].mean(dim=(0, 2))  # (T, D*V)
    task_vecs_over_all_time = hiddens.mean(dim=2) - global_mean.unsqueeze(dim=0)  # (K, T, D*V)
    # Task vectors of the reference tasks at the last position.
    final_task_vecs = task_vecs_over_all_time[:n_ref_tasks, -1]
    # The last two return values are not used here.
    lambdas, r2_scores, _task_norms, _ortho_norms = estimate_lambda_with_r2(
        final_task_vecs, task_vecs_over_all_time
    )
    fig = project_with_r2_size(
        task_vecs_over_all_time, final_task_vecs, r2_scores, lambdas, n_minors=n_minors
    )
    fig.show()

# Render the projection figure for `hiddens_voc0`.
plot_hidden_proj(hiddens_voc0)

In [32]:
# NOTE(review): this call is repeated verbatim in several cells; presumably each
# was run after `hiddens_voc0` was recomputed for a different checkpoint — confirm.
plot_hidden_proj(hiddens_voc0)

In [23]:
# NOTE(review): this call is repeated verbatim in several cells; presumably each
# was run after `hiddens_voc0` was recomputed for a different checkpoint — confirm.
plot_hidden_proj(hiddens_voc0)

In [29]:
# NOTE(review): this call is repeated verbatim in several cells; presumably each
# was run after `hiddens_voc0` was recomputed for a different checkpoint — confirm.
plot_hidden_proj(hiddens_voc0)

In [33]:
# This will take a few minutes to run
r2_scores, fig = nu.plot_sufficient_stat(
    model=model,
    sampler=sampler,
    config=config,
    num_samples=1024,
    T0=10,
    max_layers=8,
    mode="ood"
)