In [None]:
import comet_ml
import torch
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from torch.utils.data import DataLoader, random_split

from datasets import Poly
from models import PolyModelSmall
from ensemble import predict

# Load environment variables from a local .env file (presumably supplies the
# Comet API key used by login — TODO confirm), authenticate with Comet, and
# create the API client that get_experiment_keys uses to list experiments.
load_dotenv()
comet_ml.login()
api = comet_ml.API()

In [None]:
def get_experiment_keys(
    name: str,
    project: str = "nichlas-jacobs/pdl-hw2",
    client=None,
) -> list[str]:
    """
    Returns the keys of all experiments in a Comet project with a given name.

    Several runs (e.g. the members of an ensemble) may share one name, so
    every match is returned, in the order the client lists them.

    Args:
        name (str): experiment name to match exactly
        project (str): "workspace/project" path passed to ``client.get``;
            defaults to the project this notebook was written for
        client: object with a ``get(project)`` method returning experiments
            that expose ``.name`` and ``.key``; defaults to the module-level
            Comet ``api`` client

    Returns:
        list[str]: keys of the matching experiments (empty if none match)
    """
    # Resolve the global client lazily so callers (and tests) can inject one.
    if client is None:
        client = api

    keys = [
        experiment.key
        for experiment in client.get(project)
        if experiment.name == name
    ]

    print(f"Retrieved {len(keys)} keys")
    return keys


# All ensemble members were logged to Comet under this shared run name.
EXPERIMENT_NAME = "Poly_Trad_Ensemble_Member_300_epochs"
keys = get_experiment_keys(EXPERIMENT_NAME)

In [None]:
# Load the models
# Prefer the second GPU (as originally configured), but fall back so the
# notebook still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

if not keys:
    raise ValueError("No experiment keys found; cannot build the ensemble")

models = []
for key in keys:
    model = PolyModelSmall(input_dim=1)
    # map_location makes loading independent of the device the checkpoint
    # was saved from; weights_only restricts unpickling to tensor data.
    state_dict = torch.load(
        f"../checkpoints/{key}.pt", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    # Inference only: disable train-time behavior (e.g. dropout) if the model has any.
    model.eval()
    models.append(model.to(device))

In [None]:
# Seed the split so the train/test partition is identical on every
# Restart & Run All.
SPLIT_SEED = 42
ds = Poly(10000)
train_ds, test_ds = random_split(
    ds, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_loader = DataLoader(train_ds, batch_size=2000, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=2000, shuffle=False)

In [None]:
# Evaluate the ensemble on an evenly spaced grid of inputs over [-2, 2],
# shaped (N, 1) to match the models' input_dim=1.
X_MIN, X_MAX, N_POINTS = -2, 2, 100
x_tensor = torch.linspace(X_MIN, X_MAX, N_POINTS).reshape(-1, 1).to(device)
x = x_tensor.cpu().squeeze().numpy()
# Inference only: no_grad skips building an autograd graph and keeps the
# outputs grad-free, so they can be converted to NumPy for plotting.
with torch.no_grad():
    results = predict(x_tensor, models)

In [None]:
# Plot the ensemble's mean prediction together with the member-wise min/max
# curves returned by predict.
fig, ax = plt.subplots()
for stat in ("mean", "min", "max"):
    # detach() guards against outputs that still track gradients.
    curve = getattr(results, stat).detach().cpu().squeeze().numpy()
    ax.plot(x, curve, label=stat)
ax.set_xlabel("x")
ax.set_ylabel("f(x)")
ax.set_title(f"Ensemble on Poly Dataset {len(keys)} Members")
ax.legend()
plt.show()