In [None]:
%load_ext autoreload
%autoreload 2
from utils import load_model
from argparse import Namespace

from dataset.dataset import SetKnowledgeTrendingSinusoidsDistShift
from dataset.utils import get_dataloader
from evaluation.utils import get_summary_df
from models.loss import NLL

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import seaborn as sns

# Global figure styling for every plot in this notebook.
# NOTE(review): the "science" style is not a Matplotlib built-in; it presumably
# comes from the SciencePlots package, which must be installed (and, in recent
# releases, imported) for this line to work — TODO confirm.
plt.style.use("science")
# Seaborn settings are applied after the Matplotlib style so they take precedence.
sns.set_style("whitegrid")
sns.set_palette("Dark2")
# LaTeX packages made available when text is rendered through LaTeX.
plt.rcParams.update(
    {
        "text.latex.preamble": "\\usepackage{lmodern} \\usepackage{times} \\usepackage{amssymb}"
    }
)

In [None]:
# Load every model under comparison from its checkpoint directory.
# Keys are the display names used in the summary tables and plot legends.
save_dirs = {
    "NP": "../saves/INPs_sinusoids/np_dist_shift_1",
    "INP": "../saves/INPs_sinusoids/inp_b_dist_shift_1",
    # Update the path below to the contrastive run you want to compare.
    "INP (contrastive)": "../saves/INPs_sinusoids/inp_b_dist_shift_contrastive_1",
}

models = list(save_dirs)
model_dict, config_dict = {}, {}

for model_name, save_dir in save_dirs.items():
    # load_model returns the model together with the config it was saved with.
    model_dict[model_name], config_dict[model_name] = load_model(
        save_dir, load_it="best"
    )
    # Evaluation only: switch the module to eval mode.
    model_dict[model_name].eval()

model_names = list(model_dict)



In [None]:
# Build the train/test datasets and their dataloaders.
# This Namespace is only handed to get_dataloader; each model is evaluated with
# the config returned by load_model.
config = Namespace(
    min_num_context=0,
    max_num_context=100,
    num_targets=100,
    noise=0.2,
    batch_size=25,
    x_sampler="uniform",
    test_num_z_samples=32,
    dataset="set-trending-sinusoids-dist-shift",
    device="cuda:0",
)


def _load_split(split):
    """Return ``(dataset, dataloader)`` for one split of the dist-shift sinusoids data."""
    dataset = SetKnowledgeTrendingSinusoidsDistShift(
        root="../data/trending-sinusoids-dist-shift", split=split, knowledge_type="b"
    )
    return dataset, get_dataloader(dataset, config)


# Same construction order as before: test split first, then train.
test_dataset, test_data_loader = _load_split("test")
train_dataset, train_data_loader = _load_split("train")

In [None]:
def uniform_sampler(num_targets, num_context):
    """Draw ``num_context`` distinct indices uniformly from ``range(num_targets)``.

    Uses NumPy's global RNG, so results depend on ``np.random.seed``.
    """
    candidate_indices = [*range(num_targets)]
    return np.random.choice(candidate_indices, size=num_context, replace=False)


def _randperm_no_fixed(bs, device):
    perm = torch.randperm(bs, device=device)
    fixed = perm == torch.arange(bs, device=device)
    if fixed.any():
        perm[fixed] = (perm[fixed] + 1) % bs
    return perm


def _shuffle_knowledge(knowledge, perm):
    if isinstance(knowledge, torch.Tensor):
        return knowledge[perm]
    k_list = list(knowledge)
    perm_cpu = perm.detach().cpu().tolist()
    return [k_list[i] for i in perm_cpu]


def make_fully_mismatched_knowledge(knowledge, device):
    """Permute a batch of knowledge so that no task keeps its own entry.

    ``None`` is passed through, and batches with fewer than two entries are
    returned unchanged because there is nothing to swap them with. Otherwise
    the batch is reordered with a fixed-point-free index tensor created on
    ``device``.
    """
    if knowledge is None:
        return None

    if isinstance(knowledge, torch.Tensor):
        batch_size = knowledge.shape[0]
    else:
        batch_size = len(knowledge)

    if batch_size >= 2:
        mismatch_perm = _randperm_no_fixed(batch_size, device=device)
        return _shuffle_knowledge(knowledge, mismatch_perm)
    return knowledge


def get_summary_df_shuffled(
    model_dict,
    config_dict,
    data_loader,
    eval_type_ls,
    model_names,
    sampler=uniform_sampler,
):
    """Evaluate models without knowledge ("raw") and with mismatched knowledge ("shuffled").

    For every model and every batch, the context set is re-drawn from the
    batch's targets: ``sampler`` picks ``max(num_context_ls)`` target indices
    and the first ``num_context`` of them form the context. This is repeated
    three times per context size. For each draw the model is run once per
    evaluation type:

    * ``"raw"``: ``knowledge=None``.
    * ``"shuffled"``: the batch's knowledge permuted by
      ``make_fully_mismatched_knowledge``; only run when the model's config has
      ``use_knowledge`` set.
    * any other value (or ``"shuffled"`` for a model without knowledge) is
      skipped, leaving empty lists and NaN statistics for that combination.

    Args:
        model_dict: Mapping from model name to model.
        config_dict: Mapping from model name to that model's config; must
            provide ``device`` and ``use_knowledge``.
        data_loader: Iterable of batches shaped
            ``(context, (x_target, y_target), knowledge, extras)``.
        eval_type_ls: Evaluation types to run.
        model_names: Names (keys of ``model_dict``) to evaluate.
        sampler: Callable ``(num_targets, num_context) -> indices``.

    Returns:
        ``(loss_summary_df, losses, outputs_dict)`` where ``loss_summary_df``
        has one row per (model, eval type, context size) with ``median``,
        ``mean`` and ``std`` columns, and ``losses`` / ``outputs_dict`` are
        nested ``[model_name][eval_type][num_context]`` dicts of per-run lists.
    """
    loss = NLL(reduction="none")

    num_context_ls = [0, 1, 3, 5, 10, 15]
    max_num_context = max(num_context_ls)

    losses = {
        model_name: {
            eval_type: {num_context: [] for num_context in num_context_ls}
            for eval_type in eval_type_ls
        }
        for model_name in model_names
    }
    outputs_dict = {
        model_name: {
            eval_type: {num_context: [] for num_context in num_context_ls}
            for eval_type in eval_type_ls
        }
        for model_name in model_names
    }

    for model_name in model_names:
        model, config = model_dict[model_name], config_dict[model_name]

        for batch in data_loader:
            # The context shipped with the batch is not used: contexts are
            # re-drawn from the targets below, so it is not moved to the device.
            _batch_context, (x_target, y_target), knowledge, _extras = batch
            x_target = x_target.to(config.device)
            y_target = y_target.to(config.device)

            if isinstance(knowledge, torch.Tensor):
                knowledge = knowledge.to(config.device)

            # One mismatched assignment per batch, shared by all context sizes.
            shuffled_knowledge = make_fully_mismatched_knowledge(knowledge, config.device)

            for num_context in num_context_ls:
                for _ in range(3):
                    # Always draw the maximum number of indices and take a prefix.
                    sample_idx = sampler(x_target.shape[1], max_num_context)
                    x_context = x_target[:, sample_idx[:num_context], :]
                    y_context = y_target[:, sample_idx[:num_context], :]

                    for eval_type in eval_type_ls:
                        with torch.no_grad():
                            if eval_type == "raw":
                                outputs = model(
                                    x_context,
                                    y_context,
                                    x_target,
                                    y_target=y_target,
                                    knowledge=None,
                                )
                            elif config.use_knowledge and eval_type == "shuffled":
                                outputs = model(
                                    x_context,
                                    y_context,
                                    x_target,
                                    y_target=y_target,
                                    knowledge=shuffled_knowledge,
                                )
                            else:
                                continue
                            # Only plain tensors are moved to the CPU; other
                            # output objects are stored as returned.
                            outputs = tuple(
                                [o.cpu() if isinstance(o, torch.Tensor) else o for o in outputs]
                            )
                            loss_value, _, _ = loss.get_loss(
                                outputs[0], outputs[1], outputs[2], outputs[3], y_target
                            )
                            losses[model_name][eval_type][num_context].append(loss_value)
                            outputs_dict[model_name][eval_type][num_context].append(
                                {
                                    "outputs": outputs,
                                    "x_context": x_context.cpu(),
                                    "y_context": y_context.cpu(),
                                    "x_target": x_target.cpu(),
                                    "y_target": y_target.cpu(),
                                    "knowledge": knowledge,
                                    "shuffled_knowledge": shuffled_knowledge,
                                }
                            )

    loss_summary = {}
    for model_name in model_names:
        loss_summary[model_name] = {}
        for eval_type in eval_type_ls:
            loss_summary[model_name][eval_type] = {}
            for num_context in num_context_ls:
                loss_values = losses[model_name][eval_type][num_context]
                if len(loss_values) == 0:
                    # Skipped combination: same keys as the evaluated case so
                    # every row has an identical schema.
                    stats = {"median": np.nan, "mean": np.nan, "std": np.nan}
                else:
                    # Only the first entry along dim 0 of each stored loss is
                    # summarised.
                    stacked = torch.stack([lv[0] for lv in loss_values], dim=0)
                    stats = {
                        "median": torch.median(stacked).item(),
                        "mean": torch.mean(stacked).item(),
                        "std": torch.std(stacked).item(),
                    }
                loss_summary[model_name][eval_type][num_context] = stats

    loss_summary_df = pd.DataFrame.from_dict(
        {(i, j, k): loss_summary[i][j][k] for i in loss_summary for j in loss_summary[i] for k in loss_summary[i][j]},
        orient="index",
    )
    loss_summary_df["model_name"] = [idx[0] for idx in loss_summary_df.index]
    loss_summary_df["eval_type"] = [idx[1] for idx in loss_summary_df.index]
    loss_summary_df["num_context"] = [idx[2] for idx in loss_summary_df.index]
    loss_summary_df.reset_index(drop=True, inplace=True)

    return loss_summary_df, losses, outputs_dict


In [None]:
# The evaluation is stochastic: context indices come from NumPy's global RNG
# (uniform_sampler) and the knowledge permutation from PyTorch's RNG
# (torch.randperm). Seed both so a fresh kernel reproduces the same numbers.
EVAL_SEED = 0
np.random.seed(EVAL_SEED)
torch.manual_seed(EVAL_SEED)

eval_type_ls = ["raw", "shuffled"]

train_summary_df, _, train_output_dict = get_summary_df_shuffled(
    model_dict, config_dict, train_data_loader, eval_type_ls, model_names
)
test_summary_df, _, test_output_dict = get_summary_df_shuffled(
    model_dict, config_dict, test_data_loader, eval_type_ls, model_names
)


In [None]:
train_summary_df["split"] = "train"
test_summary_df["split"] = "test"

# Both summary frames carry their own 0..N-1 index; renumber on concat so the
# combined frame has no duplicate index labels.
plot_df = pd.concat([train_summary_df, test_summary_df], ignore_index=True)

# Keep the shuffled-knowledge curves for the informed models and the
# no-knowledge curve for the plain NP baseline.
is_shuffled_inp = plot_df["model_name"].isin(["INP", "INP (contrastive)"]) & (
    plot_df["eval_type"] == "shuffled"
)
is_raw_np = (plot_df["model_name"] == "NP") & (plot_df["eval_type"] == "raw")
# .copy() makes the filtered frame independent, so the column assignments below
# do not trigger pandas' SettingWithCopyWarning.
plot_df = plot_df[is_shuffled_inp | is_raw_np].copy()

# NOTE(review): the mean is negated here while the y-axis is labelled
# "Negative Log-likelihood" — confirm the sign convention of NLL.get_loss so the
# label matches what is plotted.
plot_df["mean"] = -plot_df["mean"]
label_map = {
    "NP": r"NP: $\mathcal{K} = \varnothing$",
    "INP": r"INP: shuffled $\mathcal{K}$",
    "INP (contrastive)": r"INP (contrastive): shuffled $\mathcal{K}$",
}
plot_df["model_label"] = plot_df["model_name"].map(label_map)

fig, ax = plt.subplots(figsize=(3.6, 3.5))
sns.lineplot(
    data=plot_df,
    x="num_context",
    y="mean",
    hue="model_label",
    style="split",
    palette=["C2", "C4", "C1"],
    markers=True,
    ax=ax,
)

ax.set_ylabel("Negative Log-likelihood")
ax.set_xlabel("Number of context datapoints")

plt.tight_layout()
# plt.savefig("../figures/exp-2-shuffled.pdf", bbox_inches="tight")
plt.show()


In [None]:
loss = NLL()


def get_loss_bs(output_dict, model_name="INP", eval_type="informed", num_context=0):
    """Collect per-batch losses and the matching knowledge values.

    Walks over every stored run in
    ``output_dict[model_name][eval_type][num_context]``, recomputes the loss
    with the module-level ``loss`` object and reads ``knowledge[:, 0, 3]`` for
    the same batch (presumably the ``b`` parameter of each task — TODO confirm
    against the dataset's knowledge layout).

    Requires a CUDA device: the second model output and the targets are moved
    to the GPU before the loss call. Stored knowledge must be a tensor.

    Returns:
        ``(losses, bs)``: the per-batch results concatenated along dim 0, both
        on the CPU.
    """
    stored_runs = output_dict[model_name][eval_type][num_context]

    per_batch_bs = []
    per_batch_loss = []
    for run in stored_runs:
        outputs = run["outputs"]
        targets = run["y_target"]
        batch_bs = run["knowledge"].cpu()[:, 0, 3]

        batch_loss, _, _ = loss.get_loss(
            outputs[0], outputs[1].cuda(), outputs[2], outputs[3], targets.cuda()
        )
        per_batch_loss.append(batch_loss.cpu())
        per_batch_bs.append(batch_bs)

    all_bs = torch.cat(per_batch_bs)
    all_losses = torch.cat(per_batch_loss)

    return all_losses, all_bs


# The shuffled-knowledge evaluation earlier in this notebook only stores "raw"
# and "shuffled" outputs, so looking up "informed" in train_output_dict /
# test_output_dict raises KeyError. The knowledge-informed predictions are
# therefore produced here with the stock evaluation helper.
# NOTE(review): assumes get_summary_df takes the same arguments and returns the
# same (summary_df, losses, outputs_dict) triple as get_summary_df_shuffled,
# supports eval type "informed", and stores "knowledge" in every output entry —
# TODO confirm against evaluation/utils.py.
_, _, train_informed_output_dict = get_summary_df(
    model_dict, config_dict, train_data_loader, ["informed"], ["INP"]
)
_, _, test_informed_output_dict = get_summary_df(
    model_dict, config_dict, test_data_loader, ["informed"], ["INP"]
)

train_loss_informed, train_bs_informed = get_loss_bs(
    train_informed_output_dict, eval_type="informed"
)
test_loss_informed, test_bs_informed = get_loss_bs(
    test_informed_output_dict, eval_type="informed"
)
# The no-knowledge baseline is already available from the shuffled evaluation.
test_loss_raw, test_bs_raw = get_loss_bs(
    test_output_dict, eval_type="raw", model_name="NP"
)


# Bin edges over b, shared by the binned loss plot and the training histogram.
bins = np.linspace(-0.5, 6, 10)

raw_df = pd.DataFrame({"b": test_bs_raw, "loss": torch.log(test_loss_raw)})
raw_df["bin"] = pd.cut(raw_df["b"], bins=bins)
raw_df["eval_type"] = "raw"
informed_df = pd.DataFrame(
    {"b": test_bs_informed, "loss": torch.log(test_loss_informed)}
)
# Bin the informed frame by its own b values (previously cut from raw_df["b"]).
informed_df["bin"] = pd.cut(informed_df["b"], bins=bins)
informed_df["eval_type"] = "informed"
all_df = pd.concat([raw_df, informed_df]).reset_index(drop=True)


# Top panel: binned test log-loss as a function of b for the NP baseline and
# the informed INP. Bottom panel: distribution of b over the training tasks.
fig, axs = plt.subplots(2, 1, figsize=(3, 3.5), sharex=True, height_ratios=[2, 1])

binned_series = [
    (raw_df, r"NP: $\mathcal{K} = \varnothing$", "C2"),
    (informed_df, r"INP: $\mathcal{K} \neq \varnothing$", "C4"),
]
for frame, series_label, series_color in binned_series:
    # fit_reg=False: only the per-bin means with standard-deviation bars are drawn.
    sns.regplot(
        data=frame,
        label=series_label,
        ax=axs[0],
        x_ci="sd",
        x="b",
        y="loss",
        x_bins=bins,
        fit_reg=False,
        color=series_color,
    )


axs[1].hist(
    train_bs_informed, color="grey", alpha=0.8, bins=bins, density=True, align="left"
)
# A single legend call is enough; a bare legend() before it would be replaced.
axs[0].legend(
    handletextpad=0.05,
    loc="lower right",
    facecolor="white",
    framealpha=0.8,
    frameon=True,
)
axs[0].set_ylabel("test log(loss)")
# Raw string: "\%" is an invalid escape sequence in a normal string literal.
# The backslash is kept so the percent sign survives LaTeX text rendering.
axs[1].set_ylabel(r"\% of training tasks")
axs[0].set_xlabel("")
axs[1].set_xlabel("Value of $b$")
plt.tight_layout()
# plt.savefig('../figures/ood_details.pdf', bbox_inches='tight')
plt.show()