# Evaluate ICL Model on the Omniglot Classification Dataset

In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.datasets.wrappers import (
    ContextDataset,
    StandardSupervisedDataset,
    FixedLengthContextDataset,
    RepeatedContextDataset,
)
from jaxl.models import load_config, load_model, get_model, get_activation
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict, get_device

import _pickle as pickle
import copy
import jax
import jax.random as jrandom
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision.datasets as torch_datasets

from collections import OrderedDict
from functools import partial
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from types import SimpleNamespace

In [None]:
# Select the compute device for this notebook. Switch to the commented
# alternative to run on CPU.
# device = "cpu"
device = "gpu:0"
# NOTE(review): presumably configures the JAX backend/device from this
# string -- confirm against jaxl.utils.get_device.
get_device(device)

In [None]:
# Document width in points; passed to set_size() when sizing figures below.
doc_width_pt = 750.0

# Machine-specific absolute paths: adjust base_path when running elsewhere.
base_path = "/home/bryanpu1/projects/icl/jaxl/"
data_path = os.path.join(base_path, "data")
log_path = os.path.join(base_path, "jaxl/logs")
project_name = "icl-omniglot"
# Sub-directory (under the project log directory) that holds the selected run.
# Keep exactly one assignment active.
# ablation_name = "bursty_ablation"
ablation_name = "bursty_ablation-fixed_length"
# ablation_name = ""
# Name of the run directory to evaluate. Exactly ONE string below must be left
# uncommented: adjacent string literals inside the parentheses are implicitly
# concatenated, so enabling two entries silently produces an invalid run name.
run_name = (
    # Input tokenizer
    # "cnn-context_len_16-num_blocks_8-02-12-24_23_14_55-b0846297-59fb-4e96-9eff-b4e05902f09c"
    # "resnet-context_len_16-num_blocks_8-02-13-24_00_22_16-188aedab-34f0-4299-b830-4a62c45dcfe4"
    # "resnet-no_bn-context_len_16-num_blocks_8-02-13-24_01_00_39-26d7ced5-d8ff-4ab3-8cfb-2cc30b069a1d"

    # Output tokenizer
    # "resnet-no_bn-context_len_16-num_blocks_8-02-13-24_14_16_38-8cb7114e-b2f7-4203-9db5-d059cb1a9651"
    # "resnet-no_bn-frozen_output_tokenizer-context_len_16-num_blocks_8-02-14-24_15_57_52-f5648ff4-0443-4bd1-af75-fcc14e2146ff"

    # Noisy pretraining
    # "resnet-no_bn-frozen_output_tokenizer-noisy_pretraining-context_len_16-num_blocks_8-02-15-24_12_37_55-65f7ccfd-f504-46a7-a5e7-df355d08587b"

    # Noisy pretraining + 12 transformer blocks
    # "resnet-no_bn-frozen_output_tokenizer-noisy_pretraining-context_len_16-num_blocks_12-02-15-24_13_07_31-af0da6f1-43a8-4f6b-89f5-1f8117124861"

    # Burstiness experiments
    # Variable length
    # "bursty_0.0-02-24-24_06_55_15-910c422e-36bb-4ff2-995a-1a7e13adf571"
    # "bursty_0.5-02-24-24_06_55_11-af7e29d4-08f3-45c2-8fdd-7d052416f898"
    # "bursty_1.0-02-24-24_06_55_07-3e790c78-6da7-4be3-ba93-1359ba67ab0a"

    # Fixed length
    # "bursty_0.0-02-25-24_15_22_20-b8aa457b-7a3a-498b-848c-0e7b819952e7"
    # "bursty_0.5-02-25-24_15_22_11-483fa86e-36ee-410d-901d-ed1007f7881a"
    # "bursty_0.5-02-26-24_18_11_59-a173f14d-bc0c-4990-9dea-1eedc118e85b" # Longer training
    # "bursty_1.0-02-25-24_15_22_04-4ef72a13-fa8d-49cb-a5de-e09d36e3abaf"

    # Batch norm ablation
    # "bursty_0.0-batch_norm-02-28-24_12_45_42-5b3cd1ab-4768-4f74-9688-5be15bd4f9cd"
    "bursty_0.5-batch_norm-02-27-24_20_56_01-8347c265-2e88-40c6-9cc0-30242f2b8f41"
    # "bursty_0.5-num_blocks_12-batch_norm-02-27-24_20_29_48-7fcb7bcb-5ba1-459e-9129-3cc164b04860"
    # "bursty_0.5-num_blocks_12-batch_norm-no_aug-02-27-24_20_37_07-04e2891c-636b-4610-b9f2-8efbebe2b1db"
)

# Directory of the selected run: <log_path>/<project_name>/<ablation_name>/<run_name>.
learner_path = os.path.join(
    log_path,
    project_name,
    ablation_name,
    run_name,
)

# Drop the last 8 dash-separated fields of the run name. In the run names
# listed above these are the date/time stamp (3 fields) and the UUID (5 fields),
# leaving only the human-readable experiment prefix.
exp_name = "-".join(run_name.split("-")[:-8])
# Burstiness probability written into the dataset task config further below.
p_bursty = 0.5

# Experiment Configuration

In [None]:
# Load the run's configuration both as a raw dict and as a parsed object.
config_dict, config = load_config(learner_path)
# Override the burstiness probability in the raw dict; the evaluation dataset
# configs below are deep-copied from `config_dict`.
# NOTE(review): this assignment only touches `config_dict`. Whether the already
# parsed `config` (used to build `train_dataset`) sees the override is not
# visible here -- confirm if the training dataset should also use `p_bursty`.
config_dict["learner_config"]["dataset_config"]["dataset_kwargs"]["task_config"]["p_bursty"] = p_bursty
# True when the run used the fixed-length context wrapper; this disables the
# per-context-length breakdown in print_performance.
fixed_length = config.learner_config.dataset_config.dataset_wrapper.type in ["FixedLengthContextDataset"]

In [None]:
# Display the parsed experiment configuration as the cell output.
config

# Load Dataset and Model

In [None]:
# Rebuild the training dataset from the run's own dataset config and data seed.
train_dataset = get_dataset(
    config.learner_config.dataset_config,
    config.learner_config.seeds.data_seed,
)

In [None]:
# Restore the model and its parameters from the run directory.
# NOTE(review): the trailing -1 presumably selects the latest checkpoint --
# confirm against jaxl.models.load_model.
params, model = load_model(
    train_dataset.input_dim, train_dataset.output_dim, learner_path, -1
)

In [None]:
# Evaluation settings shared by every cell below.
# Each DataLoader batch holds `sequence_length - 1` samples, and
# get_preds_labels evaluates at most `num_tasks` such batches.
sequence_length = train_dataset._dataset.sequence_length
num_samples_per_task = sequence_length - 1
context_len = config.model_config.num_contexts
num_tasks, num_workers = 100, 4

print(num_samples_per_task, num_tasks, sequence_length, context_len)

In [None]:
# Unshuffled loader over the training dataset. get_preds_labels treats each
# batch as one task, so the batch size is the number of samples per task.
# NOTE(review): this assumes consecutive dataset indices belong to the same
# sequence -- confirm against the dataset wrapper.
train_loader = DataLoader(
    train_dataset,
    batch_size=num_samples_per_task,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

# Helper Functions

In [None]:
def get_preds_labels(data_loader, num_tasks, max_label=None):
    """Evaluate the model on batches from ``data_loader`` and collect results.

    Uses the notebook-level ``model`` and ``params`` globals and runs the
    forward pass with ``eval=True``.

    Args:
        data_loader: Iterable yielding
            ``(context_inputs, context_outputs, queries, one_hot_labels)``
            batches. The first three elements must expose ``.numpy()``.
        num_tasks: Maximum number of batches to evaluate; batches beyond this
            count are ignored.
        max_label: If given, predictions are the argmax over only the first
            ``max_label`` output dimensions; otherwise over all of them.

    Returns:
        A tuple ``(all_preds, all_labels, all_outputs)`` where each element is
        the concatenation, over evaluated batches, of the predicted class
        indices, the label indices (argmax of the one-hot labels), and the raw
        model outputs respectively.

    Raises:
        ValueError: If no batch was evaluated (empty loader or
            ``num_tasks <= 0``).
    """
    all_preds = []
    all_labels = []
    all_outputs = []

    for batch_i, samples in enumerate(data_loader):
        if batch_i >= num_tasks:
            break

        context_inputs, context_outputs, queries, one_hot_labels = samples

        # Only the model outputs are needed here; the remaining return values
        # of forward() are ignored.
        outputs = model.forward(
            params[CONST_MODEL_DICT][CONST_MODEL],
            queries.numpy(),
            {
                CONST_CONTEXT_INPUT: context_inputs.numpy(),
                CONST_CONTEXT_OUTPUT: context_outputs.numpy(),
            },
            eval=True,
        )[0]

        # Optionally restrict the prediction to the first `max_label` classes.
        candidates = outputs if max_label is None else outputs[..., :max_label]
        all_preds.append(np.argmax(candidates, axis=-1))
        all_labels.append(np.argmax(one_hot_labels, axis=-1))
        all_outputs.append(outputs)

    if not all_outputs:
        raise ValueError(
            "No batches were evaluated: the data loader is empty or num_tasks <= 0."
        )

    return (
        np.concatenate(all_preds),
        np.concatenate(all_labels),
        np.concatenate(all_outputs),
    )


def print_performance(
    all_preds,
    all_labels,
    sequence_length,
    context_len,
    output_dim,
    fixed_length=False,
):
    """Print overall accuracy and, optionally, accuracy per context length.

    Accuracy is computed as the trace of the confusion matrix over classes
    ``0..output_dim-1`` divided by its sum, expressed as a percentage.

    Args:
        all_preds: Flat array of predicted class indices.
        all_labels: Flat array of label class indices.
        sequence_length: Sequence length; predictions are regrouped into rows
            of ``sequence_length - 1`` entries for the breakdown.
        context_len: Number of context-length buckets to report.
        output_dim: Number of classes used to build the confusion matrix.
        fixed_length: When true, only the overall accuracy is printed.
    """
    class_ids = np.arange(output_dim)

    def _score(labels, preds):
        # Returns (accuracy in percent, number of counted samples).
        counts = confusion_matrix(labels, preds, labels=class_ids)
        total = np.sum(counts)
        return np.trace(counts) / total * 100, total

    overall_acc, _ = _score(all_labels, all_preds)
    print("Pretraining Accuracy: {}".format(overall_acc))

    if fixed_length:
        return

    # One row per sequence, one column per query position.
    row_shape = (-1, sequence_length - 1)
    preds_by_seq = all_preds.reshape(row_shape)
    labels_by_seq = all_labels.reshape(row_shape)

    final_bucket = context_len - 1
    for position in range(context_len):
        # Earlier positions are scored one column at a time; the final bucket
        # pools every remaining column.
        if position < final_bucket:
            columns = position
        else:
            columns = slice(position, None)

        bucket_preds = preds_by_seq[:, columns].reshape(-1)
        bucket_labels = labels_by_seq[:, columns].reshape(-1)

        bucket_acc, bucket_size = _score(bucket_labels, bucket_preds)
        print(
            "Pretraining Accuracy with Context Length {} (Num Samples: {}): {}".format(
                position + 1, bucket_size, bucket_acc
            )
        )

In [None]:
# Debug cell: compare the train-mode (eval=False) and eval-mode (eval=True)
# forward passes on the first training batch. The forward passes are issued
# here directly so this cell does not depend on what get_preds_labels returns.
debug_batch = next(iter(train_loader))
debug_queries = debug_batch[2].numpy()
debug_contexts = {
    CONST_CONTEXT_INPUT: debug_batch[0].numpy(),
    CONST_CONTEXT_OUTPUT: debug_batch[1].numpy(),
}

debug_train_result = model.forward(
    params[CONST_MODEL_DICT][CONST_MODEL],
    debug_queries,
    debug_contexts,
    eval=False,
)
train_outputs, train_updates = debug_train_result[0], debug_train_result[2]

debug_eval_result = model.forward(
    params[CONST_MODEL_DICT][CONST_MODEL],
    debug_queries,
    debug_contexts,
    eval=True,
)
outputs, updates = debug_eval_result[0], debug_eval_result[2]

# Check Dataset

In [None]:
# Visual sanity check of the training dataset: for the first two tasks, plot
# the context images with their labels in a 2 x 8 grid and the query image in
# the extra right-hand column.
check = True
if check:
    nrows, ncols = 2, 8
    for task_i in range(2):
        # Last sample index within task `task_i`.
        ci, co, q, l = train_dataset[(task_i + 1) * num_samples_per_task - 1]

        fig, axes = plt.subplots(
            nrows,
            ncols + 1,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        for idx, (img, label) in enumerate(zip(ci, co)):
            ax = axes[divmod(idx, ncols)]
            ax.imshow(img[0])
            ax.set_title(np.argmax(label))
            ax.axis("off")

        # The extra column holds only the query, in the bottom row.
        axes[0, -1].axis("off")
        ax = axes[1, -1]
        ax.axis("off")
        ax.imshow(q[0, 0])
        ax.set_title(np.argmax(l, axis=-1))
        plt.show()
        plt.close()

# Get Training Performance

In [None]:
# Evaluate on the training tasks and persist the raw results.
train_preds, train_labels, train_outputs = get_preds_labels(train_loader, num_tasks)
# Context manager guarantees the file is flushed and closed.
with open("train_prediction_result.pkl", "wb") as result_file:
    pickle.dump([train_preds, train_labels, train_outputs], result_file)

In [None]:
# Report accuracy on the training tasks (overall, plus the per-context-length
# breakdown when the run is not fixed-length).
print_performance(
    train_preds,
    train_labels,
    sequence_length,
    context_len,
    train_dataset.output_dim[0],
    fixed_length=fixed_length,
)

# In-distribution Test Data

In [None]:
# Number of sequences and data seed for the in-distribution test set.
num_in_dist_test_tasks = 30
in_dist_test_data_seed = 1000

In [None]:
# Display the training dataset configuration as the cell output.
config.learner_config.dataset_config

In [None]:
# Derive the in-distribution test config from the training dataset config.
# Deep copy so the overrides below do not leak back into `config_dict`.
in_dist_test_config_dict = copy.deepcopy(
    config_dict["learner_config"]["dataset_config"]
)
in_dist_test_config_dict["dataset_kwargs"]["num_sequences"] = num_in_dist_test_tasks
# Enable augmentation with noise scale 0.1 for the test samples.
in_dist_test_config_dict["dataset_kwargs"]["task_config"]["augmentation"] = True
in_dist_test_config_dict["dataset_kwargs"]["task_config"]["noise_scale"] = 0.1
in_dist_test_config = parse_dict(in_dist_test_config_dict)

In [None]:
# Build the in-distribution test dataset with its own data seed.
in_dist_test_dataset = get_dataset(
    in_dist_test_config,
    in_dist_test_data_seed,
)

## Check Dataset

In [None]:
# Visual sanity check of the in-distribution test dataset: for the first two
# tasks, plot the context images with their labels in a 2 x 8 grid and the
# query image in the extra right-hand column.
check = True
if check:
    nrows, ncols = 2, 8
    for task_i in range(2):
        # Last sample index within task `task_i`.
        ci, co, q, l = in_dist_test_dataset[(task_i + 1) * num_samples_per_task - 1]

        fig, axes = plt.subplots(
            nrows,
            ncols + 1,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        for idx, (img, label) in enumerate(zip(ci, co)):
            ax = axes[divmod(idx, ncols)]
            ax.imshow(img[0])
            ax.set_title(np.argmax(label))
            ax.axis("off")

        # The extra column holds only the query, in the bottom row.
        axes[0, -1].axis("off")
        ax = axes[1, -1]
        ax.axis("off")
        ax.imshow(q[0, 0])
        ax.set_title(np.argmax(l, axis=-1))
        plt.show()
        plt.close()

In [None]:
# Unshuffled loader over the in-distribution test dataset; one batch per task
# as consumed by get_preds_labels.
in_dist_test_loader = DataLoader(
    in_dist_test_dataset,
    batch_size=num_samples_per_task,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
# Evaluate on the in-distribution test tasks and persist the raw results.
in_dist_test_preds, in_dist_test_labels, in_dist_test_outputs = get_preds_labels(
    in_dist_test_loader, num_in_dist_test_tasks
)
# Context manager guarantees the file is flushed and closed.
with open("in_dist_test_prediction_result.pkl", "wb") as result_file:
    pickle.dump(
        [in_dist_test_preds, in_dist_test_labels, in_dist_test_outputs],
        result_file,
    )

In [None]:
# Report accuracy on the in-distribution test tasks.
# NOTE: print_performance labels every line "Pretraining Accuracy", even for
# test data.
print_performance(
    in_dist_test_preds,
    in_dist_test_labels,
    sequence_length,
    context_len,
    in_dist_test_dataset.output_dim[0],
    fixed_length=fixed_length,
)

# Relabel Train Data

In [None]:
# Number of sequences and data seed for the relabelled training-class set.
num_remap_train_tasks = 30
remap_train_data_seed = 1000

In [None]:
# Derive the relabelled-train config: training split with label remapping on.
# Deep copy so the overrides do not leak back into `config_dict`.
remap_train_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
remap_train_config_dict["dataset_kwargs"]["train"] = True
remap_train_config_dict["dataset_kwargs"]["remap"] = True
remap_train_config_dict["dataset_kwargs"]["num_sequences"] = num_remap_train_tasks
remap_train_config = parse_dict(remap_train_config_dict)

In [None]:
# Build the relabelled training-class dataset with its own data seed.
remap_train_dataset = get_dataset(
    remap_train_config,
    remap_train_data_seed,
)

## Check Dataset

In [None]:
# Visual sanity check of the relabelled training dataset: for the first two
# tasks, plot the context images with their labels in a 2 x 8 grid and the
# query image in the extra right-hand column.
check = True
if check:
    nrows, ncols = 2, 8
    for task_i in range(2):
        # Last sample index within task `task_i`.
        ci, co, q, l = remap_train_dataset[(task_i + 1) * num_samples_per_task - 1]

        fig, axes = plt.subplots(
            nrows,
            ncols + 1,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        for idx, (img, label) in enumerate(zip(ci, co)):
            ax = axes[divmod(idx, ncols)]
            ax.imshow(img[0])
            ax.set_title(np.argmax(label))
            ax.axis("off")

        # The extra column holds only the query, in the bottom row.
        axes[0, -1].axis("off")
        ax = axes[1, -1]
        ax.axis("off")
        ax.imshow(q[0, 0])
        ax.set_title(np.argmax(l, axis=-1))
        plt.show()
        plt.close()

In [None]:
# Unshuffled loader over the relabelled training dataset; one batch per task
# as consumed by get_preds_labels.
remap_train_loader = DataLoader(
    remap_train_dataset,
    batch_size=num_samples_per_task,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
# Evaluate on the relabelled training tasks, restricting predictions to the
# first two output dimensions (max_label=2), and persist the raw results.
remap_train_preds, remap_train_labels, remap_train_outputs = get_preds_labels(
    remap_train_loader, num_remap_train_tasks, max_label=2,
)
# Context manager guarantees the file is flushed and closed.
with open("remap_train_prediction_result.pkl", "wb") as result_file:
    pickle.dump(
        [remap_train_preds, remap_train_labels, remap_train_outputs],
        result_file,
    )

In [None]:
# Report accuracy on the relabelled training tasks.
# NOTE: print_performance labels every line "Pretraining Accuracy", even for
# this evaluation set.
print_performance(
    remap_train_preds,
    remap_train_labels,
    sequence_length,
    context_len,
    remap_train_dataset.output_dim[0],
    fixed_length=fixed_length,
)

# Out-of-class Test Data
This checks out-of-class generalization, i.e., performance on held-out classes.

In [None]:
# Number of sequences and data seed for the out-of-class test set.
num_ooc_test_tasks = 30
ooc_test_data_seed = 1000

In [None]:
# Derive the out-of-class test config: non-training split, no label remapping.
# Deep copy so the overrides do not leak back into `config_dict`.
ooc_test_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
ooc_test_config_dict["dataset_kwargs"]["train"] = False
ooc_test_config_dict["dataset_kwargs"]["remap"] = False
ooc_test_config_dict["dataset_kwargs"]["num_sequences"] = num_ooc_test_tasks
ooc_test_config = parse_dict(ooc_test_config_dict)

In [None]:
# Build the out-of-class test dataset with its own data seed.
ooc_test_dataset = get_dataset(
    ooc_test_config,
    ooc_test_data_seed,
)

## Check Dataset

In [None]:
# Visual sanity check of the out-of-class test dataset: for the first two
# tasks, plot the context images with their labels in a 2 x 8 grid and the
# query image in the extra right-hand column.
check = True
if check:
    nrows, ncols = 2, 8
    for task_i in range(2):
        # Last sample index within task `task_i`.
        ci, co, q, l = ooc_test_dataset[(task_i + 1) * num_samples_per_task - 1]

        fig, axes = plt.subplots(
            nrows,
            ncols + 1,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        for idx, (img, label) in enumerate(zip(ci, co)):
            ax = axes[divmod(idx, ncols)]
            ax.imshow(img[0])
            ax.set_title(np.argmax(label))
            ax.axis("off")

        # The extra column holds only the query, in the bottom row.
        axes[0, -1].axis("off")
        ax = axes[1, -1]
        ax.axis("off")
        ax.imshow(q[0, 0])
        ax.set_title(np.argmax(l, axis=-1))
        plt.show()
        plt.close()

In [None]:
# Unshuffled loader over the out-of-class test dataset; one batch per task as
# consumed by get_preds_labels.
ooc_test_loader = DataLoader(
    ooc_test_dataset,
    batch_size=num_samples_per_task,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
# Evaluate on the out-of-class test tasks and persist the raw results.
ooc_test_preds, ooc_test_labels, ooc_test_outputs = get_preds_labels(
    ooc_test_loader, num_ooc_test_tasks
)
# Context manager guarantees the file is flushed and closed.
with open("ooc_test_prediction_result.pkl", "wb") as result_file:
    pickle.dump(
        [ooc_test_preds, ooc_test_labels, ooc_test_outputs],
        result_file,
    )

In [None]:
# Report accuracy on the out-of-class test tasks.
# NOTE: print_performance labels every line "Pretraining Accuracy", even for
# test data.
print_performance(
    ooc_test_preds,
    ooc_test_labels,
    sequence_length,
    context_len,
    ooc_test_dataset.output_dim[0],
    fixed_length=fixed_length,
)

# Relabel Test Data
Remaps the labels of the held-out test classes to a constrained subset of labels.

In [None]:
# Number of sequences and data seed for the relabelled test-class set.
num_remap_test_tasks = 30
remap_test_data_seed = 1000

In [None]:
# Derive the relabelled-test config: non-training split with label remapping on.
# Deep copy so the overrides do not leak back into `config_dict`.
remap_test_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
remap_test_config_dict["dataset_kwargs"]["train"] = False
remap_test_config_dict["dataset_kwargs"]["remap"] = True
remap_test_config_dict["dataset_kwargs"]["num_sequences"] = num_remap_test_tasks
remap_test_config = parse_dict(remap_test_config_dict)

In [None]:
# Build the relabelled test-class dataset with its own data seed.
remap_test_dataset = get_dataset(
    remap_test_config,
    remap_test_data_seed,
)

In [None]:
# Unshuffled loader over the relabelled test dataset; one batch per task as
# consumed by get_preds_labels.
remap_test_loader = DataLoader(
    remap_test_dataset,
    batch_size=num_samples_per_task,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
# Evaluate on the relabelled test tasks, restricting predictions to the first
# two output dimensions (max_label=2), and persist the raw results.
remap_test_preds, remap_test_labels, remap_test_outputs = get_preds_labels(
    remap_test_loader, num_remap_test_tasks, max_label=2
)
# Context manager guarantees the file is flushed and closed.
with open("remap_test_prediction_result.pkl", "wb") as result_file:
    pickle.dump(
        [remap_test_preds, remap_test_labels, remap_test_outputs],
        result_file,
    )

In [None]:
# Report accuracy on the relabelled test tasks (label-constrained predictions).
# NOTE: print_performance labels every line "Pretraining Accuracy", even for
# test data.
print_performance(
    remap_test_preds,
    remap_test_labels,
    sequence_length,
    context_len,
    remap_test_dataset.output_dim[0],
    fixed_length=fixed_length,
)

In [None]:
# Re-evaluate the relabelled test tasks WITHOUT max_label, so the argmax runs
# over all output dimensions. This overwrites the remap_test_* variables from
# the constrained run; results go to a separate "-unconstrained" file.
remap_test_preds, remap_test_labels, remap_test_outputs = get_preds_labels(
    remap_test_loader, num_remap_test_tasks
)
# Context manager guarantees the file is flushed and closed.
with open("remap_test_prediction_result-unconstrained.pkl", "wb") as result_file:
    pickle.dump(
        [remap_test_preds, remap_test_labels, remap_test_outputs],
        result_file,
    )

In [None]:

# Report accuracy for whichever remap_test_* results are current (the
# unconstrained run, if the preceding cell was executed).
# NOTE: print_performance labels every line "Pretraining Accuracy", even for
# test data.
print_performance(
    remap_test_preds,
    remap_test_labels,
    sequence_length,
    context_len,
    remap_test_dataset.output_dim[0],
    fixed_length=fixed_length,
)