In [1]:
import sys

# Prepend the parent directory and its "src" subdirectory to the import path
# so the `experiments.*` and `src.*` imports used by this notebook resolve.
# NOTE(review): both entries are relative, so this presumably assumes the
# kernel's working directory is one level below the project root — confirm.
sys.path.insert(0, "..")
sys.path.insert(0, "../src")

In [2]:
import os

import dill
import jax
import jax.numpy as jnp
import numpy as np

from tqdm.notebook import tqdm

from experiments.evaluation import get_eval_datasets
from experiments.utils import *

from src.constants import *
from src.dataset import get_data_loader
from src.models import SimpleICLModel
from src.utils import parse_dict, load_config, iterate_models, set_seed

# Seed the run once, before any data or model is created.
# NOTE(review): `set_seed` is a project helper; which RNGs it seeds is not
# visible from this notebook — confirm it covers everything used below.
run_seed = 0
set_seed(run_seed)

2024-08-29 23:38:18.903941: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-29 23:38:18.903978: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-29 23:38:18.904019: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Run directory of the learner whose checkpoints are evaluated below.
# NOTE(review): this is a hard-coded absolute path, so the notebook only runs
# on a machine where it exists. The commented-out line is an alternative run
# directory kept for reference.
# learner_path = "/Users/chanb/research/ualberta/simple_icl/experiments/results/simple_icl_model/seed_46-high_prob_0.5-08-28-24_19_06_10-f70cfe1c-05d3-4ac9-adb2-45174536cbf6"
learner_path = "/home/chanb/scratch/simple_icl/results/simple_icl-g-high_prob_0.5/ground_truth_prob_0.0-seed_0-08-28-24_23_12_31-2f92ffd0-3338-486c-b2cb-370a0e0143ba/"
# The prefetch cell draws num_eval_samples // batch_size batches per dataset.
num_eval_samples = 1000
batch_size = 100
# Seed handed to get_eval_datasets when the evaluation datasets are built.
test_data_seed = 1000

# Cache file for the prefetched evaluation batches (relative to the working
# directory of the kernel).
prefetched_data_path = "eval_prefetched_data.dill"

# Fetch Datasets

In [5]:
# Load the evaluation batches from the on-disk cache when it exists; otherwise
# build every evaluation dataset, draw the batches once, and write the cache.
#
# NOTE(review): the cache is looked up by file name only, so changing
# `learner_path`, `batch_size`, `num_eval_samples` or `test_data_seed` does not
# invalidate it — delete the file by hand after changing any of them.
if os.path.isfile(prefetched_data_path):
    # dill, like pickle, can run arbitrary code while loading: only load cache
    # files that this notebook wrote.
    with open(prefetched_data_path, "rb") as cache_file:
        prefetched_data = dill.load(cache_file)
else:
    config_dict, config = load_config(learner_path)
    # Override the stored batch size with the notebook's `batch_size`, then
    # re-parse so `config` reflects the override.
    config_dict["batch_size"] = batch_size
    config = parse_dict(config_dict)

    context_len = config.dataset_kwargs.num_examples
    fixed_length = True  # not read anywhere in this cell

    datasets, dataset_configs = get_eval_datasets(
        config_dict,
        test_data_seed,
        context_len,
    )

    # Register the loader built from the (overridden) run config as one more
    # evaluation split, under the key "pretraining".
    train_data_loader, train_dataset = get_data_loader(
        config,
    )
    datasets["pretraining"] = (
        train_data_loader,
        train_dataset,
    )
    dataset_configs["pretraining"] = config.dataset_kwargs

    # Integer division: a trailing partial batch is not fetched.
    num_eval_batches = num_eval_samples // batch_size

    prefetched_data = {}
    for eval_name, (data_loader, dataset) in datasets.items():
        data_iter = iter(data_loader)
        prefetched_data[eval_name] = dict(
            samples=[next(data_iter) for _ in range(num_eval_batches)],
            dataset_output_dim=dataset.output_space.n,
        )

    # Context manager so the cache file is flushed and closed even on error.
    with open(prefetched_data_path, "wb") as cache_file:
        dill.dump(prefetched_data, cache_file)

# Evaluation

In [22]:
# Walk the checkpoint iterator and keep the 20th item it yields, then list the
# evaluation splits that were prefetched.
model_iter = iterate_models(learner_path)
num_checkpoints_to_consume = 20
checkpoints_consumed = 0
while checkpoints_consumed < num_checkpoints_to_consume:
    params, model, checkpoint_step = next(model_iter)
    checkpoints_consumed += 1
eval_names = [*prefetched_data.keys()]
print(eval_names)

['pretrain-sample_high_prob_class_only-start_pos_1-flip_label', 'pretrain-sample_low_prob_class_only-start_pos_1-flip_label', 'test-sample_high_prob_class_only-start_pos_1-flip_label', 'test-sample_low_prob_class_only-start_pos_1-flip_label', 'pretraining']


In [23]:
# Show which checkpoint step the selection loop stopped on.
checkpoint_step

1900

In [8]:
# Use the first prefetched evaluation split for the inspection cells.
eval_name = eval_names[0]

In [24]:
# Evaluate the selected checkpoint parameters on the chosen split.
# NOTE(review): a fresh SimpleICLModel(1.0, "l2") is built here and the `model`
# returned by iterate_models is not used — confirm this is intentional.
eval_model = SimpleICLModel(1.0, "l2")
eval_batches = prefetched_data[eval_name]
preds, labels, outputs, model_auxes = get_preds_labels(
    eval_model,
    params,
    eval_batches,
    None,
)

In [25]:
# Sanity check on the number of predictions; the recorded output is (1000,),
# which equals num_eval_samples.
preds.shape

(1000,)

In [26]:
# First entry of the "h" and "p_iwl" auxiliary outputs, shown as a pair.
model_auxes["h"][0], model_auxes["p_iwl"][0]

(array([[-3.962206  ],
        [-3.9725022 ],
        [-3.9662254 ],
        [-3.9660006 ],
        [-3.96836   ],
        [-3.97082   ],
        [-3.964296  ],
        [-0.03334247]], dtype=float32),
 array([0.36500788], dtype=float32))

In [28]:
# Softmax-weight the context targets of the first prefetched batch with `h`
# and show the result for the first sequence.
first_batch_targets = prefetched_data[eval_name]["samples"][0]["target"][:, :-1]
# Take as many rows of `h` as the first batch holds rather than hard-coding
# 100, so the cell still lines up if the evaluation batch size changes.
first_batch_h = model_auxes["h"][: first_batch_targets.shape[0]]
softmax_temperature = 0.1
jnp.sum(
    jax.nn.softmax(first_batch_h / softmax_temperature, axis=1) * first_batch_targets,
    axis=1,
)[0]

Array([5.765116e-17, 1.000000e+00], dtype=float32)

In [31]:
# Pull two fields out of the first prefetched batch: every target except the
# last position, and the flip_label field with a trailing axis added.
first_batch = prefetched_data[eval_name]["samples"][0]
context_targets = first_batch["target"][:, :-1]
flip_labels = first_batch["flip_label"][:, None]

In [32]:
# Context targets of the first sequence in the batch, one row per position.
context_targets[0]

array([[1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [0, 1]], dtype=int32)

In [29]:
# First sample's iw_pred and ic_pred joined along the last axis so both can be
# read side by side.
np.concatenate((model_auxes["iw_pred"], model_auxes["ic_pred"]), axis=-1)[0]

array([1.000000e+00, 0.000000e+00, 1.000000e+00, 5.765116e-17],
      dtype=float32)

In [None]:
# Build per-sample target distributions from `labels`: the labelled class gets
# `ground_truth_prob` and the remaining mass is split evenly over the others.
ground_truth_prob = 1.0
# One-hot encode once (2 classes) instead of rebuilding the array three times.
one_hot_labels = np.eye(2)[labels]
off_target_prob = (1 - ground_truth_prob) / (one_hot_labels.shape[-1] - 1)
# Adding the one-hot pushes the labelled entry above ground_truth_prob; the
# clip brings it back down. Bounds are positional because the `a_min`/`a_max`
# keywords are deprecated in newer JAX releases.
jnp.clip(
    jnp.full_like(one_hot_labels, fill_value=off_target_prob) + one_hot_labels,
    0.0,
    ground_truth_prob,
)

In [None]:
# Elementwise gap between iw_pred and the one-hot encoding of `labels`
# (2 classes hard-coded).
model_auxes["iw_pred"] - np.eye(2)[labels]

In [None]:
# List which auxiliary outputs are available.
model_auxes.keys()

In [None]:
# Inspect the full ic_pred array.
model_auxes["ic_pred"]

In [None]:
# Largest and smallest value of h, taken over index 0 of its last axis.
model_auxes["h"][..., 0].max(), model_auxes["h"][..., 0].min()

In [None]:
# Targets of the sequence at index 43 in the first prefetched batch.
prefetched_data[eval_name]["samples"][0]["target"][43]

In [None]:
# Argmax of the second-to-last target position for every sequence in the
# prefetched batch at index 2.
np.argmax(prefetched_data[eval_name]["samples"][2]["target"][:, -2], axis=-1)

In [None]:
# Inspect the full ic_pred array (same expression as an earlier cell).
model_auxes["ic_pred"]

In [None]:
# h values of the first sample across its second axis, at index 0 of the last.
model_auxes["h"][0, :, 0]

In [None]:
# Shape of the first sequence's "example" array in the first prefetched batch.
prefetched_data[eval_name]["samples"][0]["example"][0].shape

In [None]:
# Split the first sequence of the first batch into everything but its last
# entry (context) and its last entry kept as a length-1 leading axis (query).
first_sequence = prefetched_data[eval_name]["samples"][0]["example"][0]
context_inputs = first_sequence[:-1]
queries = first_sequence[[-1]]

In [None]:
# exp(-squared distance / 1e-5) between each context input and the query.
query_sq_dists = jnp.sum(
    (context_inputs - queries) ** 2,
    axis=-1,
    keepdims=True,
)
jnp.exp(-query_sq_dists / 1e-5)

In [None]:
# Squared Euclidean distance between each context input and the query
# (sum of squared differences over the last axis, axis kept).
jnp.sum((context_inputs - queries) ** 2, axis=-1, keepdims=True)

In [None]:
# Inspect the `outputs` value returned by get_preds_labels.
outputs

In [None]:
# Convex combination of ic_pred and iw_pred, weighted by p_iwl.
iwl_weight = model_auxes["p_iwl"]
iwl_weight * model_auxes["iw_pred"] * 0 + (
    (1 - iwl_weight) * model_auxes["ic_pred"] + iwl_weight * model_auxes["iw_pred"]
) if False else (1 - iwl_weight) * model_auxes["ic_pred"] + iwl_weight * model_auxes["iw_pred"]

In [None]:
# Inspect the full p_iwl array.
model_auxes["p_iwl"]

In [None]:
# Inspect the full ic_pred array (same expression as earlier cells).
model_auxes["ic_pred"]

In [None]:
# Inspect the full iw_pred array.
model_auxes["iw_pred"]

In [None]:
# For every prefetched batch, take the argmax of the second-to-last target
# position of each sequence; the result is one array per batch.
last_context_labels = [
    np.argmax(batch["target"][:, -2], axis=-1)
    for batch in prefetched_data[eval_name]["samples"]
]

In [None]:
# Mean of the last-context labels over all prefetched batches.
np.mean(np.concatenate(last_context_labels))

In [None]:
# Mean of the predictions, for comparison with the mean last-context label.
np.mean(preds)

In [None]:
# Indices where the prediction differs from the last-context label.
# np.where with a single argument returns a tuple of index arrays.
mismatch_inds = np.where(preds != np.concatenate(last_context_labels))

In [None]:
# Last-context labels at the positions selected by `mismatch_inds`.
np.concatenate(last_context_labels)[mismatch_inds]

In [None]:
# Convex combination of ic_pred and iw_pred, weighted by p_iwl.
p_iwl_weight = model_auxes["p_iwl"]
model_outs = (1 - p_iwl_weight) * model_auxes["ic_pred"] + p_iwl_weight * model_auxes["iw_pred"]

In [None]:
# Number of predictions selected by `mismatch_inds` (read off the shape).
preds[mismatch_inds].shape

In [None]:
# NOTE(review): this rebinds `mismatch_inds` with a different meaning — it now
# selects every non-zero prediction, not predictions that disagree with the
# last-context label. Cells that read `mismatch_inds` therefore depend on
# which of the two assignments ran last.
mismatch_inds = np.where(preds != 0)

In [None]:
# Number of entries in p_iwl.
len(model_auxes["p_iwl"])

In [None]:
# Fraction of p_iwl entries that are at least 0.5.
np.mean(model_auxes["p_iwl"] >= 0.5)

In [None]:
# Side-by-side table for the rows selected by `mismatch_inds`. Columns, in
# order: p_iwl, argmax of iw_pred, argmax of ic_pred, final prediction.
p_iwl_column = model_auxes["p_iwl"][mismatch_inds]
iw_choice_column = np.argmax(
    model_auxes["iw_pred"][mismatch_inds], axis=-1, keepdims=True
)
ic_choice_column = np.argmax(
    model_auxes["ic_pred"][mismatch_inds], axis=-1, keepdims=True
)
pred_column = preds[mismatch_inds][:, None]
np.concatenate(
    (p_iwl_column, iw_choice_column, ic_choice_column, pred_column),
    axis=1,
)

In [None]:
# Shape of the per-row iw_pred argmax for the rows selected by `mismatch_inds`
# (no keepdims here, so the last axis is dropped).
np.argmax(model_auxes["iw_pred"][mismatch_inds], axis=-1).shape