# Evaluate ICL Model on the Omniglot Classification Dataset

## Imports

In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.datasets.wrappers import (
    ContextDataset,
    StandardSupervisedDataset,
    FixedLengthContextDataset,
    RepeatedContextDataset,
)
from jaxl.models import load_config, load_model, get_model, get_activation
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict, get_device

import _pickle as pickle
import copy
import jax
import jax.random as jrandom
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision.datasets as torch_datasets

from collections import OrderedDict
from functools import partial
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from types import SimpleNamespace

# Setup

In [None]:
# Pick the compute device and hand it to jaxl's `get_device`; use "cpu" on machines without a GPU.
device = "gpu:0"  # alternative: "cpu"
get_device(device)

In [None]:
# Width of the target document in points, passed to `set_size` when sizing figures.
doc_width_pt = 750.0

# Filesystem layout: data and logs are resolved relative to the jaxl checkout.
base_path = "/home/bryanpu1/projects/icl/jaxl/"
data_path, log_path = (
    os.path.join(base_path, "data"),
    os.path.join(base_path, "jaxl/logs"),
)

# Experiment family and ablation sub-directory containing the run to evaluate.
project_name = "icl-omniglot"
# Alternatives: "bursty_ablation", "bursty_ablation-fixed_length", "" (no ablation sub-directory).
ablation_name = "pixel_noise_ablation/all_omniglot-pixel_noise_0.1"
# Name of the run directory (joined under log_path/project_name/ablation_name below) to evaluate.
# Names appear to end with a timestamp and a UUID, e.g. "...-03-19-24_11_07_54-<uuid>".
# Keep exactly ONE literal uncommented: adjacent string literals inside the parentheses
# would be concatenated into a single, non-existent run name.
# NOTE(review): entries are grouped by the ablation they presumably belong to, so
# `ablation_name` likely has to be switched alongside — confirm against the log directory.
run_name = (
    # "bursty_0.0-context_len_8-03-04-24_11_06_22-7a6e6aca-b77d-401d-842e-9bdddc5eeaa1"
    # "bursty_0.5-context_len_8-03-04-24_11_06_15-b2563eae-7e83-4b45-aa00-eb6d770c5fc5"
    # "bursty_1.0-context_len_8-03-04-24_11_06_28-33b976d7-422f-4d66-ae9f-3bd86f1bd451"

    # "bursty_0.0-context_len_8-batch_norm-03-05-24_07_56_37-1c4d12c3-4363-4cb6-bcda-6b25932cc9ea"
    # "bursty_0.5-context_len_8-batch_norm-03-05-24_07_55_58-a4d55433-45e2-406f-869a-ccd082c2b983"
    # "bursty_1.0-context_len_8-batch_norm-03-05-24_07_57_05-f318fcb7-b332-4719-8917-028876194a02"

    # "bursty_0.0-context_len_8-batch_norm-no_aug-03-06-24_09_04_24-b772c27c-1592-4be2-a24e-b1b1c56e4761"
    # "bursty_0.5-context_len_8-batch_norm-no_aug-03-06-24_09_04_50-25fc1859-2f1c-4758-8d6a-a893d59425d7"
    # "bursty_1.0-context_len_8-batch_norm-no_aug-03-06-24_09_04_16-feec38cd-486e-4885-a108-961c93061a7b"

    # "bursty_1.0-context_len_8-batch_norm-num_blocks_12-03-08-24_11_59_04-52544705-d03f-4781-b2ce-98a98f8d487d"
    # "bursty_1.0-context_len_8-batch_norm-num_blocks_12-larger_dataset-03-08-24_12_59_55-f4498c2d-df47-45d0-a817-41db45c76d7a"
    # "bursty_1.0-context_len_8-batch_norm-num_blocks_12-no_aug-larger_dataset-03-08-24_13_25_46-1c0e931a-8989-4a4a-8db5-ceba093418e4"

    # "bursty_0.0-context_len_8-batch_norm-larger_dataset-03-09-24_22_11_33-6da36b91-6f7e-4950-98d3-f81c0ad471e5"
    # "bursty_0.5-context_len_8-batch_norm-larger_dataset-03-09-24_22_11_22-f752093b-3f9f-413c-956a-8965896f52a4"
    # "bursty_1.0-context_len_8-batch_norm-larger_dataset-03-09-24_22_11_11-5ac8c8f6-ad00-46c9-92eb-469b22d7979d"

    # Pixel Noise
    # "bursty_0.0-pixel_noise_0.1-03-14-24_12_01_47-93f2a221-43a3-4cf5-870c-627608e759b6"
    # "bursty_0.5-pixel_noise_0.1-03-14-24_11_53_05-d84ee8c5-24ca-486e-ae76-9553b15a35a6"
    # "bursty_1.0-pixel_noise_0.1-03-14-24_11_53_08-91577252-aa44-4133-8ec7-a5fcda220d3a"

    # all_omniglot-pixel_noise_0.1
    "bursty_1.0-diff_pos_enc-03-19-24_11_07_54-6f37e06f-569d-427c-a14c-9c9d54587c5b"
)

# Directory holding the trained run's config and checkpoints.
learner_path = os.path.join(
    log_path,
    project_name,
    ablation_name,
    run_name,
)

# The experiment name is the run name with its auto-generated suffix removed: the last
# 8 "-"-separated fields (3 timestamp fields, e.g. "03-19-24_11_07_54", plus the 5 UUID
# groups). Example: "bursty_1.0-diff_pos_enc-03-19-24_11_07_54-<uuid>" -> "bursty_1.0-diff_pos_enc".
# If the name has no room for such a suffix, keep it whole; blindly slicing would yield ""
# and make every such experiment save into the same "./evaluation-" directory.
_RUN_SUFFIX_FIELDS = 8
_run_name_fields = run_name.split("-")
exp_name = (
    "-".join(_run_name_fields[:-_RUN_SUFFIX_FIELDS])
    if len(_run_name_fields) > _RUN_SUFFIX_FIELDS
    else run_name
)

## Experiment Configuration

In [None]:
# Load the run's configuration both as a raw dict (for deep-copy-and-edit below)
# and as a parsed object.
config_dict, config = load_config(learner_path)

# True when the training dataset wrapper is FixedLengthContextDataset.
fixed_length = (
    config.learner_config.dataset_config.dataset_wrapper.type
    == "FixedLengthContextDataset"
)

In [None]:
# Show the parsed experiment configuration as this cell's output.
config

## Load Train Dataset and Model

In [None]:
# Rebuild the pretraining dataset from the run's dataset config and its training data seed.
train_dataset = get_dataset(
    config.learner_config.dataset_config, config.learner_config.seeds.data_seed
)

In [None]:
# Restore the trained parameters and model from the run directory.
# The trailing -1 is forwarded to load_model as-is (presumably "latest checkpoint" — TODO confirm).
params, model = load_model(
    train_dataset.input_dim,
    train_dataset.output_dim,
    learner_path,
    -1,
)

In [None]:
# Budget and seed for the held-out evaluations below.
num_test_tasks = 30
test_data_seed = 1000

# Sizes taken from the trained model and the pretraining dataset.
context_len = config.model_config.num_contexts
sequence_length = train_dataset._dataset.sequence_length
num_samples_per_task = sequence_length - 1

# Number of batches evaluated on the pretraining data, and DataLoader workers.
num_tasks = 100
num_workers = 4

print(num_samples_per_task, num_tasks, sequence_length, context_len)

## Helper Functions

In [None]:
# Plot dataset example
def plot_examples(dataset, num_examples=2):
    """Plot the context images and the query image for the first few tasks.

    For each of the first ``num_examples`` tasks this fetches the task's last
    sample (index ``task_i * num_samples_per_task + num_samples_per_task - 1``)
    and draws every context image titled with the argmax of its label, plus
    the query image in the last column of the second row, titled with the
    argmax of the query label.

    Args:
        dataset: Dataset wrapper exposing ``_dataset.sequence_length`` whose items
            unpack into ``(context_inputs, context_outputs, query, label)``.
        num_examples: Number of tasks to plot.
    """
    num_samples_per_task = dataset._dataset.sequence_length - 1
    ncols = 8
    for task_i in range(num_examples):
        context_inputs, context_outputs, query, label = dataset[
            task_i * num_samples_per_task + num_samples_per_task - 1
        ]

        # Two rows fit up to 2 * ncols context images; grow the grid for longer
        # contexts instead of indexing past the last row. At least two rows are
        # needed because the query is drawn in row 1.
        nrows = max(2, (len(context_inputs) + ncols - 1) // ncols)
        fig, axes = plt.subplots(
            nrows,
            ncols + 1,
            figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
            layout="constrained",
        )

        # Hide every cell up front so unused cells stay blank instead of showing empty frames.
        for ax in axes.flat:
            ax.axis("off")

        for idx, (img, one_hot) in enumerate(zip(context_inputs, context_outputs)):
            ax = axes[idx // ncols, idx % ncols]
            ax.imshow(img)
            ax.set_title(np.argmax(one_hot))

        # The extra right-most column holds the query (row 1).
        axes[1, -1].imshow(query[0])
        axes[1, -1].set_title(np.argmax(label, axis=-1))
        plt.show()
        plt.close(fig)

# Get model predictions
def get_preds_labels(data_loader, num_tasks, max_label=None):
    """Run the loaded ``model``/``params`` on the first ``num_tasks`` batches.

    Args:
        data_loader: Iterable of ``(context_inputs, context_outputs, queries,
            one_hot_labels)`` batches (torch tensors from a DataLoader).
        num_tasks: Maximum number of batches to evaluate.
        max_label: Restricts which output logits the argmax may choose from:
            ``None`` uses all logits; ``CONST_AUTO`` uses the first
            ``data_loader.dataset._data["num_classes"]`` logits; an int ``k``
            uses the first ``k`` logits.

    Returns:
        Tuple ``(preds, labels, outputs)`` of numpy arrays, each the
        concatenation of the per-batch results along the first axis.

    Raises:
        ValueError: If no batch was evaluated (empty loader or ``num_tasks <= 0``).
    """
    from itertools import islice

    # Resolve the logit cutoff once; `outputs[..., :None]` keeps every logit.
    if max_label == CONST_AUTO:
        max_label = data_loader.dataset._data["num_classes"]

    all_preds = []
    all_labels = []
    all_outputs = []

    # islice stops before fetching batch `num_tasks`, unlike a post-fetch `break`.
    for context_inputs, context_outputs, queries, one_hot_labels in islice(
        data_loader, max(num_tasks, 0)
    ):
        outputs, _, _ = model.forward(
            params[CONST_MODEL_DICT][CONST_MODEL],
            queries.numpy(),
            {
                CONST_CONTEXT_INPUT: context_inputs.numpy(),
                CONST_CONTEXT_OUTPUT: context_outputs.numpy(),
            },
            eval=True,
        )
        # Copy logits to host memory so accumulated batches do not keep device buffers alive.
        outputs = np.asarray(outputs)
        all_outputs.append(outputs)
        all_preds.append(np.argmax(outputs[..., :max_label], axis=-1))
        all_labels.append(np.argmax(np.asarray(one_hot_labels), axis=-1))

    if not all_outputs:
        raise ValueError(
            "No batches were evaluated; check `num_tasks` and the data loader length."
        )

    return (
        np.concatenate(all_preds),
        np.concatenate(all_labels),
        np.concatenate(all_outputs),
    )

# Check model accuracy
def print_performance(
    all_preds,
    all_labels,
    output_dim,
):
    """Format the classification accuracy of ``all_preds`` against ``all_labels``.

    Despite its name this does not print; it returns the formatted string so
    callers can print or store it.

    Args:
        all_preds: Array-like of predicted class indices.
        all_labels: Array-like of true class indices, same number of elements
            as ``all_preds``.
        output_dim: Number of output classes. Kept for interface compatibility;
            it does not filter samples.

    Returns:
        ``"Accuracy: {acc}%\\n"`` where ``acc`` is the percentage of samples
        whose prediction equals the label.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    all_preds = np.asarray(all_preds).ravel()
    all_labels = np.asarray(all_labels).ravel()
    if all_preds.shape != all_labels.shape:
        raise ValueError(
            "Got {} predictions but {} labels.".format(all_preds.size, all_labels.size)
        )
    if all_labels.size == 0:
        raise ValueError("Cannot compute accuracy of an empty prediction set.")

    # Count matches directly. A confusion matrix restricted to range(output_dim)
    # silently drops samples whose prediction or label lies outside that range,
    # shrinking the denominator and inflating the accuracy.
    acc = np.mean(all_preds == all_labels) * 100
    return "Accuracy: {}%\n".format(acc)

# Complete evaluation
def evaluate(exp_name, eval_name, dataset_config, seed, num_tasks, max_label, batch_size, num_workers, visualize=False, save=False):
    """Build a dataset, run the model on it, and return (optionally save) its accuracy.

    Args:
        exp_name: Experiment name; results are saved under ``./evaluation-{exp_name}``.
        eval_name: Name of this evaluation; used as the pickle file stem.
        dataset_config: Parsed dataset config passed to ``get_dataset``.
        seed: Data seed passed to ``get_dataset``.
        num_tasks: Maximum number of batches to evaluate.
        max_label: Logit cutoff forwarded to ``get_preds_labels``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        visualize: If True, plot a few examples with ``plot_examples`` first.
        save: If True, pickle the config, seed, and predictions to
            ``./evaluation-{exp_name}/{eval_name}.pkl``.

    Returns:
        The accuracy string produced by ``print_performance``.
    """
    dataset = get_dataset(
        dataset_config,
        seed,
    )

    if visualize:
        plot_examples(dataset)

    # No shuffling so batch order follows the dataset's order.
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

    preds, labels, outputs = get_preds_labels(data_loader, num_tasks, max_label)
    acc_str = print_performance(
        preds,
        labels,
        dataset.output_dim[0],
    )

    if save:
        save_dir = "./evaluation-{}".format(exp_name)
        os.makedirs(save_dir, exist_ok=True)
        record = {
            "config": dataset_config,
            "seed": seed,
            "max_label": max_label,
            "num_tasks": num_tasks,
            "results": {
                "preds": preds,
                "labels": labels,
                "outputs": outputs,
                "acc_str": acc_str,
            },
        }
        # Context manager guarantees the file is flushed and closed, even if pickling fails.
        with open(os.path.join(save_dir, "{}.pkl".format(eval_name)), "wb") as f:
            pickle.dump(record, f)
    return acc_str

# Pretraining
This evaluates the model on the exact dataset used to train it, so we expect near-perfect performance.

In [None]:
# Evaluate on the training data itself: same dataset config and same data seed as training.
pretrain_acc = evaluate(
    exp_name=exp_name,
    eval_name="pretraining",
    dataset_config=config.learner_config.dataset_config,
    seed=config.learner_config.seeds.data_seed,
    num_tasks=num_tasks,
    max_label=None,  # argmax over every output logit
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    visualize=True,
    save=True,
)
print(pretrain_acc)

# In-distribution Evaluation
This uses the pretraining image classes.  
We consider three types of evaluations:
1. **Same pretraining data distribution accuracy**: This uses the same dataset as the pretraining but with a different seed.  
This may change, for example, the augmentation or the sequences.

1. **In-weight accuracy**: This uses entirely random contexts, so the model must predict from the query alone.  
This follows from Chan et al. (2022).  
Note: When $P(\text{bursty}) = 0$ during pretraining, this is equivalent to the same-pretraining-data-distribution evaluation.

1. **Pretraining N-shot 2-way accuracy**: This uses the same dataset as the pretraining but with a different seed.  
Furthermore, we relabel each task as a binary classification task in which half of the contexts belong to class 1 and half to class 0.  
We constrain the model output to only be the valid classes.

## Same Pretraining Data Distribution

In [None]:
# Same pretraining distribution: the training dataset config, limited to
# `num_test_tasks` sequences and generated from a separate test seed.
same_pretraining_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
same_pretraining_config_dict["dataset_kwargs"].update(num_sequences=num_test_tasks)
same_pretraining_config = parse_dict(same_pretraining_config_dict)

same_pretraining_acc = evaluate(
    exp_name=exp_name,
    eval_name="same_pretraining",
    dataset_config=same_pretraining_config,
    seed=test_data_seed,
    num_tasks=num_test_tasks,
    max_label=None,
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    visualize=True,
    save=True,
)
print(same_pretraining_acc)

## In-weight

In [None]:
# In-weight evaluation: non-bursty sequences (p_bursty = 0) with unique classes, so the
# context is not meant to help predict the query label.
in_weight_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
in_weight_config_dict["dataset_kwargs"]["num_sequences"] = num_test_tasks
in_weight_config_dict["dataset_kwargs"]["task_config"].update(
    p_bursty=0.0,
    unique_classes=True,
)
in_weight_config = parse_dict(in_weight_config_dict)

in_weight_acc = evaluate(
    exp_name=exp_name,
    eval_name="in_weight",
    dataset_config=in_weight_config,
    seed=test_data_seed,
    num_tasks=num_test_tasks,
    max_label=None,
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    visualize=True,
    save=True,
)
print(in_weight_acc)

## Pretraining N-shot 2-way

In [None]:
# N-shot K-way evaluation on the pretraining classes. The same K is used both to build
# the tasks (`k_way`) and to restrict predictions to the first K logits (`max_label`),
# so the two can never drift apart.
pretrain_k_way = 2
pretrain_n_shot_2_way_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
pretrain_n_shot_2_way_config_dict["dataset_kwargs"]["task_name"] = CONST_MULTITASK_OMNIGLOT_N_SHOT_K_WAY
pretrain_n_shot_2_way_config_dict["dataset_kwargs"]["task_config"]["p_bursty"] = 1.0
pretrain_n_shot_2_way_config_dict["dataset_kwargs"]["task_config"]["k_way"] = pretrain_k_way
pretrain_n_shot_2_way_config_dict["dataset_kwargs"]["num_sequences"] = num_test_tasks
pretrain_n_shot_2_way_config = parse_dict(pretrain_n_shot_2_way_config_dict)

pretrain_n_shot_2_way_acc = evaluate(
    exp_name=exp_name,
    eval_name="pretrain_n_shot_2_way",
    dataset_config=pretrain_n_shot_2_way_config,
    seed=test_data_seed,
    num_tasks=num_test_tasks,
    max_label=pretrain_k_way,
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    save=True,
    visualize=True,
)
print(pretrain_n_shot_2_way_acc)

# In-context Evaluation
This uses held-out image classes.  
We consider two types of evaluations:
1. **Complete out-of-distribution accuracy**: We simply use the held-out classes.  
We constrain the model output to only be the valid classes.

1. **N-shot 2-way accuracy**: We treat each task as a binary classification task in which half of the contexts belong to class 1 and half to class 0.  
We constrain the model output to only be the valid classes.

## Complete Out-of-distribution

In [None]:
# Complete out-of-distribution: train=False requests the non-training split (the held-out
# classes). Predictions are restricted to that dataset's number of classes (CONST_AUTO).
ood_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
ood_config_dict["dataset_kwargs"].update(train=False, num_sequences=num_test_tasks)
ood_config = parse_dict(ood_config_dict)

ood_acc = evaluate(
    exp_name=exp_name,
    eval_name="ood",
    dataset_config=ood_config,
    seed=test_data_seed,
    num_tasks=num_test_tasks,
    max_label=CONST_AUTO,
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    save=True,
)
print(ood_acc)

## N-shot 2-way

In [None]:
# N-shot K-way evaluation on the held-out classes (train=False). The same K builds the
# tasks (`k_way`) and restricts predictions to the first K logits (`max_label`), so the
# two can never drift apart.
test_k_way = 2
test_n_shot_2_way_config_dict = copy.deepcopy(config_dict["learner_config"]["dataset_config"])
test_n_shot_2_way_config_dict["dataset_kwargs"]["train"] = False
test_n_shot_2_way_config_dict["dataset_kwargs"]["task_name"] = CONST_MULTITASK_OMNIGLOT_N_SHOT_K_WAY
test_n_shot_2_way_config_dict["dataset_kwargs"]["task_config"]["p_bursty"] = 1.0
test_n_shot_2_way_config_dict["dataset_kwargs"]["task_config"]["k_way"] = test_k_way
test_n_shot_2_way_config_dict["dataset_kwargs"]["num_sequences"] = num_test_tasks
test_n_shot_2_way_config = parse_dict(test_n_shot_2_way_config_dict)

test_n_shot_2_way_acc = evaluate(
    exp_name=exp_name,
    eval_name="test_n_shot_2_way",
    dataset_config=test_n_shot_2_way_config,
    seed=test_data_seed,
    num_tasks=num_test_tasks,
    max_label=test_k_way,
    batch_size=num_samples_per_task,
    num_workers=num_workers,
    save=True,
)
print(test_n_shot_2_way_acc)

# Final Results

In [None]:
# Gather every evaluation's accuracy string, keyed by evaluation name, in the order run above.
all_accs = OrderedDict(
    [
        ("pretraining", pretrain_acc),
        ("same_pretraining", same_pretraining_acc),
        ("in_weight", in_weight_acc),
        ("pretrain_n_shot_2_way", pretrain_n_shot_2_way_acc),
        ("ood", ood_acc),
        ("test_n_shot_2_way", test_n_shot_2_way_acc),
    ]
)

all_accs