In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, random_split
import glob

from datasets import Poly
from models import ModelSmall, ModelMicro
from ensemble import predict, evaluate, evaluate_ensemble_components

# Ensemble Member Summary

In [None]:
# Experiment whose FGE (fast geometric ensembling) snapshots are analysed here.
experiment_key = "4354b05f29f049fcaafbfcf236ba01e2"

# glob returns matches in arbitrary, filesystem-dependent order. Sorting makes
# the member order -- and therefore slices such as models[:5] / models[:3]
# taken further down -- identical on every run and every machine.
# NOTE(review): the order is lexicographic, so "..._10" sorts before "..._2"
# if the snapshot suffixes are not zero-padded.
checkpoints = sorted(glob.glob(f"../checkpoints/{experiment_key}_fge*"))

print(f"Found {len(checkpoints)} FGE checkpoints")
for checkpoint in checkpoints:
    print(checkpoint)

# Load Dataset

In [None]:
# Seed for the train/test partition so that Restart & Run All reproduces the
# same split (and therefore the same reported metrics).
# NOTE(review): this is not necessarily the split the checkpoints were trained
# on, and Poly itself may draw random samples -- confirm against the training
# script if exact train/test separation matters.
SPLIT_SEED = 0

ds = Poly(10000)
train_ds, test_ds = random_split(
    ds,
    [0.8, 0.2],  # 80 % train / 20 % test
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Large batches: the loaders are only used for evaluation in this notebook.
train_loader = DataLoader(train_ds, batch_size=2000, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=2000, shuffle=False)

# Load the Models

In [None]:
# Prefer the second GPU (the original setting), but degrade gracefully so the
# notebook still runs on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# One ModelSmall per checkpoint, in the same order as `checkpoints`.
models = []

for checkpoint in checkpoints:
    model = ModelSmall(input_dim=1, hidden_dim=64)
    # SECURITY: torch.load unpickles the file -- only load checkpoints from a
    # trusted source. Loading onto the CPU first avoids depending on the
    # device the checkpoint was saved from.
    state_dict = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference only: put any dropout / batch-norm layers (if the model has
    # them) into evaluation behaviour before computing metrics.
    model.eval()
    models.append(model.to(device))

# Performance of each model in ensemble

In [None]:
# Evaluate every ensemble member on its own, on both the train and test
# loaders. The result is indexed by column name below, so it is presumably a
# pandas DataFrame with at least "train_loss" and "test_loss" columns.
df_post = evaluate_ensemble_components(train_loader, test_loader, device, models)

print("All Models")
print(df_post)

print("\ntrain_loss description:")
print(df_post["train_loss"].describe())

# The label names the column that is actually summarised ("test_loss").
print("\ntest_loss description:")
print(df_post["test_loss"].describe())

# Ensemble Performance

In [None]:
# Report the full-ensemble metrics for each split: header first, then the
# tuple returned by evaluate() for that loader.
reports = (
    ("Train performance: MSE_mean, MSE_median, average SD", train_loader),
    ("Test performance: MSE_mean, MSE_median, average SD", test_loader),
)

for header, split_loader in reports:
    print(header)
    print(evaluate(split_loader, device, models))

In [None]:
def ensemble_performance_table(loader, ensemble_models=None, sizes=(None, 5, 3)):
    """Tabulate ensemble metrics for several ensemble sizes.

    Each requested size is evaluated on ``loader`` using the first ``size``
    members of the ensemble, and one table row is produced per size.

    Args:
        loader: Data loader passed straight through to ``evaluate``.
        ensemble_models: Sequence of models to draw sub-ensembles from.
            Defaults to the notebook-level ``models`` list.
        sizes: Ensemble sizes to evaluate, in row order. ``None`` means
            "all members"; an integer ``k`` means the first ``k`` members
            (fewer if the ensemble is smaller than ``k``). The default
            reproduces the original full / 5 / 3 comparison.

    Returns:
        pandas.DataFrame with columns "Size", "MSE with Mean",
        "MSE with Median" and "SD", one row per entry in ``sizes``.
    """
    members = models if ensemble_models is None else ensemble_models

    rows = []
    for size in sizes:
        subset = members if size is None else members[:size]
        # evaluate() yields three values, reported elsewhere in this notebook
        # as: MSE of the mean prediction, MSE of the median prediction, and
        # the average standard deviation across members.
        mse_mean, mse_median, avg_sd = evaluate(loader, device, subset)
        rows.append(
            {
                "Size": len(subset),
                "MSE with Mean": mse_mean,
                "MSE with Median": mse_median,
                "SD": avg_sd,
            }
        )

    # Explicit columns keep the table layout stable even if `sizes` is empty.
    return pd.DataFrame(
        rows, columns=["Size", "MSE with Mean", "MSE with Median", "SD"]
    )


# Size-vs-performance table for the train split, then for the test split.
for split_loader in (train_loader, test_loader):
    print(ensemble_performance_table(split_loader))

# Visualize Output

In [None]:
# 100 evenly spaced inputs on [-2, 2], shaped (100, 1) and placed on `device`
# for the models.
grid = torch.linspace(-2, 2, 100)
x_tensor = grid.reshape(-1, 1).to(device)

# Flat NumPy copy of the same inputs, used as the x-axis when plotting.
x = x_tensor.to("cpu").squeeze().numpy()

# Ensemble prediction over the grid; the plotting cells read .mean, .min,
# .max and .sd from this result.
results = predict(x_tensor, models)

In [None]:
# Bring each summary curve back to the CPU as a flat NumPy array.
mean_curve = results.mean.to("cpu").squeeze().numpy()
lower_curve = results.min.to("cpu").squeeze().numpy()
upper_curve = results.max.to("cpu").squeeze().numpy()

# Reference curve the ensemble is compared against: f(x) = (x - 1)(x + 1)x.
truth_curve = (x - 1) * (x + 1) * x

# Ensemble mean with its min/max envelope (dashed), plus the reference curve.
plt.plot(x, mean_curve, label="mean")
plt.plot(x, lower_curve, "--", label="min")
plt.plot(x, upper_curve, "--", label="max")
plt.plot(x, truth_curve, "k-", label="truth")

plt.xlabel("x")
plt.ylabel("f(x)")
plt.title(f"Ensemble on Poly Dataset {len(checkpoints)} Members")
plt.legend()
plt.show()

In [None]:
# Spread of the ensemble members' predictions at each grid point.
sd_curve = results.sd.to("cpu").squeeze().numpy()

plt.plot(x, sd_curve)
plt.xlabel("x")
plt.ylabel("Standard Deviation")
plt.show()