In [None]:
import comet_ml
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from torch.utils.data import DataLoader, random_split

from datasets import Poly
from models import ModelSmall, ModelMicro
from ensemble import predict, evaluate

# Load variables from a local .env file into the process environment before
# logging in (presumably the Comet API key lives there — keeps secrets out of
# the notebook itself). Order matters: login must see the loaded environment.
load_dotenv()
comet_ml.login()
# Comet REST API client; `get_ensemble_df` below queries past experiments with it.
api = comet_ml.API()

# Ensemble Member Summary

In [None]:
def _last_metric_value(experiment, metric: str) -> float:
    """Return the last logged value of ``metric`` for one Comet experiment.

    Returns NaN when the experiment has no values for ``metric`` (e.g. a run
    that crashed before logging anything), so one bad ensemble member does not
    abort the whole summary with an ``IndexError``.
    """
    values = experiment.get_metrics(metric)
    if not values:
        return float("nan")
    # "Last" = final entry of the list Comet returns; presumably the most
    # recently logged value — NOTE(review): confirm ordering in the Comet docs.
    return float(values[-1]["metricValue"])


def get_ensemble_df(
    name: str,
    project: str = "nichlas-jacobs/pdl-hw2",
    client=None,
) -> pd.DataFrame:
    """Collect the final train/val loss of every Comet experiment named ``name``.

    Args:
        name: Experiment name shared by all ensemble members.
        project: Comet ``"workspace/project"`` path to search.
        client: Object with a ``get(project)`` method returning experiments
            (normally a ``comet_ml.API``). Defaults to the notebook's global
            ``api``; passing one explicitly makes the function testable.

    Returns:
        DataFrame with columns ``key`` (experiment key), ``train_loss`` and
        ``val_loss`` (last logged value, NaN if never logged); one row per
        matching experiment, in the order the client returns them.
    """
    client = api if client is None else client

    data = {"key": [], "train_loss": [], "val_loss": []}
    for experiment in client.get(project):
        if experiment.name != name:
            continue
        data["key"].append(experiment.key)
        data["train_loss"].append(_last_metric_value(experiment, "train_loss"))
        data["val_loss"].append(_last_metric_value(experiment, "val_loss"))

    return pd.DataFrame(data)


# Comet experiment name shared by every member of the ensemble under study.
ENSEMBLE_NAME = "Poly_Trad_Ensemble_Micro_Member_1000_16_dim"

df = get_ensemble_df(ENSEMBLE_NAME)
# Fail fast: with no matching experiments `keys` would be empty and the
# notebook would only break later, inside `evaluate`/`predict`.
if df.empty:
    raise ValueError(f"No Comet experiments named {ENSEMBLE_NAME!r} were found")
keys = df["key"].to_list()

print("All Ensembles")
print(df)

print("\ntrain_loss description:")
print(df["train_loss"].describe())

print("\nval_loss description:")
print(df["val_loss"].describe())

# Load Dataset

In [None]:
SEED = 0  # NOTE(review): use the seed the members were trained with, if any
N_SAMPLES = 10000
BATCH_SIZE = 2000

# Seed the global torch RNG in case `Poly` samples its points with it (not
# visible here — TODO confirm), and give the split its own seeded generator so
# train/test membership is identical on every Restart & Run All.
# Caveat: unless training used this same dataset + split, "test" here may
# overlap the members' training data, which would bias the test metrics.
torch.manual_seed(SEED)
ds = Poly(N_SAMPLES)
train_ds, test_ds = random_split(
    ds, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Load the Models

In [None]:
CHECKPOINT_DIR = "../checkpoints"

# Prefer the second GPU (the original hard-coded choice) but fall back so the
# notebook still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

models = []
for key in keys:
    model = ModelMicro(input_dim=1, hidden_dim=16)
    # map_location: without it, tensors are restored onto the device they were
    # saved from (which may not exist here). weights_only: the file is used only
    # as a state dict, so refuse arbitrary pickled objects.
    state_dict = torch.load(
        f"{CHECKPOINT_DIR}/{key}.pt", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    # Inference only: disables dropout / batch-norm updates if the model has any.
    model.eval()
    models.append(model.to(device))

# Ensemble Performance

In [None]:
# Report ensemble metrics on each split; `evaluate` returns (MSE, average SD)
# per the printed header — format of the returned object not visible here.
for split_label, split_ds in (("Train", train_ds), ("Test", test_ds)):
    print(f"{split_label} performance: MSE, average SD")
    print(evaluate(split_ds, device, models))

# Visualize Output

In [None]:
# Evenly spaced evaluation grid, shaped (100, 1) as the models' input_dim=1
# implies. Built on the CPU so the numpy copy for plotting needs no device
# round trip.
x_grid = torch.linspace(-2, 2, 100).reshape(-1, 1)
x_tensor = x_grid.to(device)
x = x_grid.squeeze().numpy()

# Plotting only: no autograd graph is needed, and tensors that require grad
# cannot be converted with `.numpy()` later.
with torch.no_grad():
    results = predict(x_tensor, models)

In [None]:
# Ensemble mean together with the min/max envelope of the members' predictions.
fig, ax = plt.subplots()
for label, values in (
    ("mean", results.mean),
    ("min", results.min),
    ("max", results.max),
):
    # detach() guards against tensors that still track gradients.
    ax.plot(x, values.detach().cpu().squeeze().numpy(), label=label)
ax.set_xlabel("x")
ax.set_ylabel("f(x)")
ax.set_title(f"Ensemble on Poly Dataset {len(keys)} Members")
ax.legend()
plt.show()

In [None]:
# Ensemble SD over the input grid; labelled so the figure stands on its own.
fig, ax = plt.subplots()
# detach() guards against tensors that still track gradients.
ax.plot(x, results.sd.detach().cpu().squeeze().numpy())
ax.set_xlabel("x")
ax.set_ylabel("SD")
ax.set_title(f"Ensemble SD on Poly Dataset {len(keys)} Members")
plt.show()