In [1]:
import sys

sys.path.insert(0, "..")
sys.path.insert(0, "../src")

from experiments.utils import *

from src.constants import *
from src.dataset import get_data_loader
from src.utils import parse_dict, load_config, iterate_models, set_seed

from tqdm.notebook import tqdm

import _pickle as pickle
import json
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
import os
import seaborn as sns
import pandas as pd
import timeit

In [2]:
# Local paths used throughout the notebook. Each can be overridden with an
# environment variable so the notebook runs on other machines; the defaults are
# the author's local checkout.
repo_path = os.environ.get("ICL_REPO_PATH", "/Users/chanb/research/ualberta/icl/simple_icl")
results_dir = os.environ.get(
    "ICL_RESULTS_DIR",
    "/Users/chanb/research/ualberta/icl/cc_results/paper_experiments/evaluation_results",
)
templates_dir = os.path.join(repo_path, "cc_utils", "templates")

In [69]:
# Experiment variant to inspect; it names both the JSON template and the
# aggregated-stats feather file.
variant_name = "omniglot-input_noise"
# variant_name = "omniglot-p_relevant"

# Data-generation settings written into the template config further down.
seed, dataset_size = 0, 10_000
p_relevant_context, input_noise_std = 0.9, 0.1

In [70]:
# Aggregated statistics for this variant live in <repo>/plot_utils/plots/agg_stats.
agg_stats_dir = os.path.join(repo_path, "plot_utils/plots/agg_stats")
stats_file = os.path.join(agg_stats_dir, "{}.feather".format(variant_name))
stats = pd.read_feather(stats_file)

In [None]:
# Display the aggregated statistics DataFrame loaded from the feather file.
stats

In [72]:
# Seed via the project's set_seed helper before the dataset is built
# (presumably seeds numpy/random RNGs — confirm in src.utils).
set_seed(seed)

In [73]:
# The JSON template shares its base name with the variant.
template_name = "{}.json".format(variant_name)
template_path = os.path.join(templates_dir, template_name)

In [None]:
# Show the resolved template path.
template_path

In [75]:
# Load the experiment's JSON template. The context manager closes the file
# handle deterministically instead of leaving it open until garbage collection.
with open(template_path) as template_file:
    config_dict = json.load(template_file)

In [None]:
# Inspect the raw template config before the overrides are applied.
config_dict

In [77]:
# Write the chosen data-generation settings into the template, and select the
# "heldout" exemplar setting.
config_dict["seeds"]["data_seed"] = seed
config_dict["dataset_kwargs"].update(
    dataset_size=dataset_size,
    p_relevant_context=p_relevant_context,
    input_noise_std=input_noise_std,
    exemplar="heldout",
)


In [78]:
# Convert the nested dict into the attribute-style config read later
# (e.g. config.batch_size, config.dataset_kwargs.num_high_prob_classes).
config = parse_dict(config_dict)

In [None]:
# Build the data loader and its dataset from the config. `loader` is consumed
# with next() in the replay loop; `dataset.targets` is inspected directly.
loader, dataset = get_data_loader(config)

In [None]:
# Fraction of query targets (last column of dataset.targets) with label >= 20,
# presumably the low-frequency classes — confirm against num_high_prob_classes.
(dataset.targets[:, -1] >= 20).mean()

In [None]:
# Occurrence count of each distinct query-target label, skipping the first 20
# distinct labels (sorted ascending by np.unique).
label_values, class_counts = np.unique(dataset.targets[:, -1], return_counts=True)
class_counts[20:]

In [None]:
# Bar chart of how many query targets each low-frequency class receives.
# np.unique sorts the distinct labels, so [20:] skips the first 20 distinct
# labels (the high-frequency classes, assuming all 20 occur). The x-range is
# derived from the data rather than fixed at 1603 (1623 - 20), so the plot does
# not break when some classes never appear in the sampled dataset.
label_values, class_counts = np.unique(dataset.targets[:, -1], return_counts=True)
low_freq_counts = class_counts[20:]
fig, ax = plt.subplots()
ax.bar(np.arange(len(low_freq_counts)), low_freq_counts)
ax.set_title("{}".format(dataset_size))
ax.set_xlabel("low-frequency class (distinct-label index - 20)")
ax.set_ylabel("number of query targets");

In [None]:
# Smallest and largest per-class count among the distinct labels after the
# first 20; np.unique is computed once and reused for both reductions.
label_values, class_counts = np.unique(dataset.targets[:, -1], return_counts=True)
low_freq_counts = class_counts[20:]
low_freq_counts.min(), low_freq_counts.max()

In [None]:
# Bar chart of how many query targets each low-frequency class receives.
# np.unique sorts the distinct labels, so [20:] skips the first 20 distinct
# labels (the high-frequency classes, assuming all 20 occur). The x-range is
# derived from the data rather than fixed at 1603 (1623 - 20), so the plot does
# not break when some classes never appear in the sampled dataset.
label_values, class_counts = np.unique(dataset.targets[:, -1], return_counts=True)
low_freq_counts = class_counts[20:]
fig, ax = plt.subplots()
ax.bar(np.arange(len(low_freq_counts)), low_freq_counts)
ax.set_title("{}".format(dataset_size))
ax.set_xlabel("low-frequency class (distinct-label index - 20)")
ax.set_ylabel("number of query targets");

In [None]:
# Smallest and largest per-class count among the distinct labels after the
# first 20; np.unique is computed once and reused for both reductions.
label_values, class_counts = np.unique(dataset.targets[:, -1], return_counts=True)
low_freq_counts = class_counts[20:]
low_freq_counts.min(), low_freq_counts.max()

In [None]:
# Fraction of query targets (last column of dataset.targets) with label >= 20,
# presumably the low-frequency classes — confirm against num_high_prob_classes.
(dataset.targets[:, -1] >= 20).mean()

In [14]:
# Settings for replaying the data stream. The step count and recording interval
# are fixed here on purpose instead of being read from config.num_epochs /
# config.logging_config.checkpoint_interval.
num_epochs, checkpoint_interval = 10000, 1
batch_size = config.batch_size
num_high_freq_class = config.dataset_kwargs.num_high_prob_classes

In [15]:
# TODO: document what these numbers refer to: 1292 5 20 1192 19280 13180


In [None]:
# Replay the data stream for num_epochs steps. Every checkpoint_interval steps,
# record each query's label and how many context labels match it.
batches = []
for epoch_i in tqdm(range(num_epochs)):
    batch = next(loader)
    if (epoch_i + 1) % checkpoint_interval != 0:
        continue
    target = batch["target"]
    # assumes target is one-hot over classes, shape (batch, context_len + 1, num_classes) — TODO confirm
    labels = np.argmax(target, axis=-1)
    # Count of context positions whose label equals the query (last) label.
    num_relevant_contexts = np.sum(labels[:, :-1] == labels[:, [-1]], axis=-1)
    batches.append(dict(
        num_relevant_contexts=num_relevant_contexts.astype(np.uint8),
        targets=labels[:, -1].astype(np.uint16),
    ))

In [None]:
# Summarize the recorded steps instead of rendering every record, which would
# bloat the notebook output (one record per recorded step).
len(batches), batches[:2]

In [18]:
# out_dir = os.path.join(os.path.dirname(results_dir), "training_info")
# os.makedirs(out_dir, exist_ok=True)
# pickle.dump(
#     batches,
#     open(
#         os.path.join(out_dir, "{}-seed_{}-dataset_size_{}-p_relevant_context_{}-input_noise_std_{}.pkl".format(
#             variant_name,
#             seed,
#             dataset_size,
#             p_relevant_context,
#             input_noise_std,
#         )),
#         "wb"
#     )
# )

In [27]:
# Transpose the list of per-step dicts into a dict of per-key lists so each
# field can be stacked into one array. The isinstance guard makes the cell safe
# to re-run: on an already-transposed dict, batches[0] would raise KeyError.
if isinstance(batches, list):
    batches = {key: [step[key].tolist() for step in batches] for key in batches[0]}

In [33]:
# Stack the per-step lists into 2-D arrays of shape (num_recorded_steps, batch_size).
num_relevant_contexts, targets = (
    np.asarray(batches[key]) for key in ("num_relevant_contexts", "targets")
)

In [None]:
# Sanity-check the stacked array dimensions.
np.shape(num_relevant_contexts)

In [39]:
# Mask of query targets that fall in low-frequency classes. The boundary is
# inclusive (>=), matching the `>= 20` / `[20:]` checks on dataset.targets; a
# strict `> 20` would count label 20 as high-frequency.
# NOTE(review): assumes high-frequency classes are labels 0..num_high_freq_class-1
# and num_high_prob_classes == 20 — confirm against the dataset config.
low_freq_classes = targets >= num_high_freq_class

In [40]:
# Low-frequency query count per recorded step. targets is
# (num_recorded_steps, batch_size), so the reduction runs over the batch axis (1).
# This yields one count out of batch_size samples per step, which is what the
# running-fraction cell's `k * batch_size` denominator requires. axis=0 would
# instead give per-batch-position totals over all steps.
num_low_freq_classes = np.sum(low_freq_classes, axis=1)

In [None]:
# Running fraction of low-frequency query targets after each step: cumulative
# count divided by the number of samples seen so far (k * batch_size). Expects
# num_low_freq_classes to hold one count per step out of batch_size samples.
samples_seen = batch_size * np.arange(1, len(num_low_freq_classes) + 1)
np.cumsum(num_low_freq_classes) / samples_seen

In [None]:
# Query-target labels per recorded step, stacked from `batches`.
targets

In [48]:
# Distinct query-target labels across all recorded steps, with how often each occurs.
unique_targets, counts = np.unique(targets.ravel(), return_counts=True)

In [49]:
# Mask over unique_targets selecting low-frequency classes. The boundary is
# inclusive (>=), matching the `>= 20` checks on dataset.targets; a strict
# `> 20` would drop label 20.
# NOTE(review): assumes high-frequency classes are labels 0..num_high_freq_class-1 — confirm.
low_freq_class_idxes = unique_targets >= num_high_freq_class

In [None]:
# Occurrence counts for the labels selected by low_freq_class_idxes.
counts[low_freq_class_idxes]