In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

from utils.MNIST_loaders import MNIST_loaders

from functions.train import train
from functions.evaluate import evaluate

from models.HypernetClassifier import HyperNetClassifier
from models.Classifier import Classifier

In [2]:
# --- Hyperparameters (things you can easily change!) ---

# Optimisation schedule
num_epochs = 5
learning_rate = 1e-5
weight_decay = 1e-3
optimizer = 'Adam'  # optimizer *name*: 'SGD' selects SGD, any other value selects Adam

# Data / evaluation
# NOTE(review): batch_size and validation_split are not passed to MNIST_loaders()
# in any visible cell — confirm the loader's own defaults match these values.
batch_size = 64
validation_split = 0.2  # fraction of the training data held out for validation
random_seed = 42
n_bins = 10  # number of confidence bins used when computing ECE

# Architecture
input_size = 28 * 28  # flattened 28x28 MNIST image
output_size = 10      # MNIST has 10 classes
hidden_size = [256, 128, 256]  # forwarded to HyperNetClassifier as hidden_sizes
zeroed_weights_in_baseline = 0.5  # forwarded to Classifier as zeroed_weights_fraction
# NOTE(review): hypernet_ensemble_num is not referenced by any visible cell.
hypernet_ensemble_num = 10

In [3]:
# Seed torch's global RNG so that model initialisation is repeatable between runs.
torch.manual_seed(random_seed)
# Use CUDA when it is available to this process; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Build the baseline classifier and the hypernet classifier, give each its own
# optimizer, then train both for `num_epochs` and report validation metrics.
#
# `optimizer` (from the hyperparameter cell) is the optimizer *name*. It is only
# read here and never reassigned, so re-running this cell keeps the configured
# optimizer type instead of comparing an optimizer object against 'SGD'.

# Re-seed so that re-running just this cell reproduces the same initial weights.
torch.manual_seed(random_seed)

# NOTE(review): only HyperNetClassifier receives `device`; confirm that `train`
# moves the baseline `model` to `device` (otherwise call `.to(device)` here).
model = Classifier(input_size, output_size, zeroed_weights_fraction=zeroed_weights_in_baseline)
hypernet_model = HyperNetClassifier(input_size, output_size, hidden_sizes=hidden_size, device=device)

# 'SGD' selects SGD; any other name selects Adam (same rule as before).
optimizer_cls = optim.SGD if optimizer == 'SGD' else optim.Adam
baseline_optimizer = optimizer_cls(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
hypernet_optimizer = optimizer_cls(hypernet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# NOTE(review): batch_size / validation_split / random_seed are not passed here;
# confirm MNIST_loaders' defaults agree with the hyperparameter cell.
train_loader, val_loader, test_loader = MNIST_loaders()

print("Starting Training [A: Accuracy, L: Loss, ECE: Expected Calibration Error, H: Hypernet]")
for epoch in range(1, num_epochs + 1):

    print(f"\n=== Epoch: [{epoch}|{num_epochs}] ========================")

    avg_loss, avg_hypernet_loss, accuracy, hypernet_accuracy = train(
        model, hypernet_model, train_loader, baseline_optimizer, hypernet_optimizer, device=device
    )
    print(f'Train: L: {avg_loss:.4f}, H_L: {avg_hypernet_loss:.4f}, A: {accuracy:.2f}%, H_A: {hypernet_accuracy:.2f}%')

    # Both models are evaluated with the same n_bins so their ECE values are comparable.
    val_accuracy, ece = evaluate(model, val_loader, n_bins=n_bins, device=device)
    hypernet_val_accuracy, hypernet_ece = evaluate(hypernet_model, val_loader, n_bins=n_bins, device=device)
    print(f'Val: A: {val_accuracy:.2f}%, A_H: {hypernet_val_accuracy:.2f}%, ECE: {ece:.5f}, ECE_H: {hypernet_ece:.5f}')

Starting Training [A: Accuracy, L: Loss, ECE: Expected Calibration Error, H: Hypernet]

Train: L: 2.2141, H_L: 29.0413, A: 19.79%, H_A: 11.18%
Val: A: 29.06%, A_H: 13.77%, ECE: 0.11408, ECE_H: 0.62075

Train: L: 2.0028, H_L: 20.4021, A: 37.21%, H_A: 25.60%
Val: A: 44.71%, A_H: 42.48%, ECE: 0.25116, ECE_H: 0.39134

Train: L: 1.8275, H_L: 10.5531, A: 49.60%, H_A: 55.80%
Val: A: 54.50%, A_H: 63.50%, ECE: 0.31871, ECE_H: 0.22519

Train: L: 1.6736, H_L: 8.1213, A: 57.87%, H_A: 68.06%
Val: A: 60.85%, A_H: 70.17%, ECE: 0.35196, ECE_H: 0.17156

Train: L: 1.5394, H_L: 7.2429, A: 62.91%, H_A: 72.94%
Val: A: 64.29%, A_H: 73.99%, ECE: 0.35267, ECE_H: 0.14533


In [6]:
# --- Testing the Model ---
# Evaluate both models on the held-out test set with the same ECE binning.
# Per-model metrics are kept in `test_results` so neither model's numbers are
# overwritten, and the validation `ece` from the training cell is left intact.
print("\nStarting Testing...")

test_results = {}
for label, candidate_model in (("Normal Model:", model), ("Hypernet Model:", hypernet_model)):
    print(label)
    test_accuracy, test_ece = evaluate(candidate_model, test_loader, n_bins=n_bins, device=device)
    test_results[label.rstrip(':')] = {'accuracy': test_accuracy, 'ece': test_ece}
    print(f'Test set: Accuracy: {test_accuracy:.2f}%, ECE: {test_ece:.5f}')


Starting Testing...
Normal Model:
Test set: Accuracy: 65.26%, ECE: 0.35863
Hypernet Model:
Test set: Accuracy: 75.23%, ECE: 0.13739
