In [None]:
%load_ext autoreload
%autoreload 2
from argparse import Namespace

import numpy as np
import pandas as pd
import torch

from config import Config
from dataset.utils import setup_dataloaders
from evaluation.utils import load_model
from models.loss import NLL


In [None]:
# Checkpoint directories for each model variant, keyed by the label used in the
# results table (edit these paths so they point at your own runs).
save_dirs = {
    "NP": "../saves/INPs_atom3d_lba_poc/np_atom3d_0",
    "INP": "../saves/INPs_atom3d_lba_poc/inp_atom3d_0",
    "cINP": "../saves/INPs_atom3d_lba_poc/clinp_atom3d_0",
}

# Filled in below: label -> loaded model, and label -> the config loaded with it.
model_dict, config_dict = {}, {}

for model_name, save_dir in save_dirs.items():
    # load_model hands back a (model, config) pair for the "best" checkpoint.
    restored_model, restored_config = load_model(save_dir, load_it="best")
    model_dict[model_name] = restored_model
    config_dict[model_name] = restored_config
    # Put the freshly loaded model into evaluation mode before it is used.
    restored_model.eval()


In [None]:
# Assemble the evaluation settings and build the dataloaders from them.
# The input dimensionality is read from the metadata stored with the tasks file.
meta = torch.load("../data/atom3d-lba-pocket-poc/tasks.pt")["meta"]

eval_settings = {
    "dataset": "atom3d-lba-pocket-poc",
    "batch_size": 16,
    "min_num_context": 0,
    "max_num_context": 10,
    "x_sampler": "uniform",
    "noise": 0.0,
    "input_dim": meta["x_dim"],
    "output_dim": 1,
    "knowledge_type": "none",
}
eval_cfg = Namespace(**eval_settings)

# setup_dataloaders returns four items; only the third is used in this notebook.
_, _, test_loader, _ = setup_dataloaders(eval_cfg)


In [None]:
# Context-set sizes at which every model is evaluated.
num_context_ls = [0, 1, 3, 5, 10]
# Loss object shared by all evaluations; built with reduction="none" and
# averaged explicitly inside eval_model.
loss_fn = NLL(reduction="none")

def eval_model(model, data_loader, num_context, use_knowledge):
    """Score ``model`` on ``data_loader`` using a fixed number of context points.

    For each batch, ``num_context`` indices are drawn without replacement from
    the batch's target points (using NumPy's global RNG) and those points are
    used as the context set; the same indices are applied to every element of
    the batch. The context pairs shipped with the batch are ignored. When
    ``num_context`` is not positive, an empty context set is used.

    The whole evaluation runs under ``torch.no_grad()`` so that no autograd
    graph is built for forward passes that are never back-propagated.

    Args:
        model: Callable as ``model(x_context, y_context, x_target,
            y_target=..., knowledge=...)`` and exposing ``model.config.device``.
        data_loader: Iterable of batches shaped like
            ``((x_context, y_context), (x_target, y_target), knowledge, extra)``.
        num_context: Number of target points to reuse as context.
        use_knowledge: If true, the batch's ``knowledge`` is passed to the model
            (moved to the model's device when it is a tensor); otherwise the
            model receives ``knowledge=None``.

    Returns:
        tuple[float, float]: ``(nll, mse)``, each the unweighted mean over
        batches of the per-batch mean value.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when ``num_context``
            exceeds the number of target points in a batch.
    """
    device = model.config.device
    nll_vals = []
    mse_vals = []

    # Evaluation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in data_loader:
            # The batch's own context pair is deliberately discarded; context is
            # re-sampled from the targets below.
            (_, _), (x_target, y_target), knowledge, _ = batch
            x_target = x_target.to(device)
            y_target = y_target.to(device)

            if num_context > 0:
                idx = np.random.choice(x_target.shape[1], num_context, replace=False)
                x_context = x_target[:, idx, :]
                y_context = y_target[:, idx, :]
            else:
                # Zero-length slice along the points axis -> empty context set.
                x_context = x_target[:, :0, :]
                y_context = y_target[:, :0, :]

            model_knowledge = None
            if use_knowledge:
                if isinstance(knowledge, torch.Tensor):
                    knowledge = knowledge.to(device)
                model_knowledge = knowledge

            outputs = model(
                x_context,
                y_context,
                x_target,
                y_target=y_target,
                knowledge=model_knowledge,
            )

            loss_val, _, _ = loss_fn.get_loss(
                outputs[0], outputs[1], outputs[2], outputs[3], y_target
            )
            nll_vals.append(loss_val.mean().item())

            # Average the predictive mean over its leading dimension
            # (presumably latent samples -- TODO confirm) before comparing
            # against the targets.
            pred_mean = outputs[0].mean.mean(dim=0)
            mse_vals.append(((pred_mean - y_target) ** 2).mean().item())

    return float(np.mean(nll_vals)), float(np.mean(mse_vals))

# Evaluate every model at every context size and collect one table row per size.
#
# eval_model draws its context indices from NumPy's global RNG, so both RNGs are
# reset to the same seed before each evaluation: every model is then scored on
# the same random draws, and the table is reproducible across runs.
EVAL_SEED = 0

# (label in model_dict, whether the model is given the knowledge input)
eval_specs = (("NP", False), ("INP", True), ("cINP", True))

rows = []
for num_context in num_context_ls:
    row = {"num_context": num_context}
    for model_name, use_knowledge in eval_specs:
        np.random.seed(EVAL_SEED)
        torch.manual_seed(EVAL_SEED)
        nll, mse = eval_model(
            model_dict[model_name], test_loader, num_context, use_knowledge
        )
        row[f"{model_name}_NLL"] = nll
        row[f"{model_name}_MSE"] = mse
    rows.append(row)

results_df = pd.DataFrame(rows)
results_df
