In [1]:
import torch
import torch.optim as optim

from utils.MNIST_loaders import MNIST_loaders

from functions.train import train
from functions.evaluate import evaluate

from models.HypernetClassifier import HyperNetClassifier
from models.Classifier import Classifier

In [None]:
# --- Hyperparameters ---
# Every tunable value for the experiment lives in this cell, grouped by concern.

# Reproducibility
random_seed = 42  # fed to torch.manual_seed

# Data
# NOTE(review): batch_size and validation_split are not passed to MNIST_loaders()
# in the visible cells -- confirm the loaders actually pick these values up.
batch_size = 64
validation_split = 0.2  # Part of the training data to use for validation
input_size = 28 * 28  # presumably a flattened 28x28 image -- confirm against the loaders
output_size = 10  # presumably one output per digit class -- confirm

# Architecture
hidden_size = [256, 128, 256]  # forwarded to HyperNetClassifier as `hidden_sizes`
zeroed_weights_in_baseline = 0.5  # forwarded to Classifier as `zeroed_weights_fraction`
hypernet_ensemble_num = 10  # forwarded to HyperNetClassifier as `ensemble_num`
use_previous_weights = False  # forwarded to HyperNetClassifier

# Optimisation
optimizer = 'SGD'  # optimizer name: 'SGD' selects optim.SGD, anything else selects optim.Adam
learning_rate = 0.0003
weight_decay = 0.001
num_epochs = 5

# Evaluation
n_bins = 10 # Number of bins for calculating ECE

In [3]:
# Seed PyTorch's global random number generator with the configured seed.
torch.manual_seed(random_seed)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [4]:
# --- Model / optimizer / data setup, followed by the training loop ---
# Builds the baseline classifier and the hypernetwork classifier, creates one
# optimizer per model, then trains both for `num_epochs` epochs and evaluates
# each of them on the validation loader after every epoch.

model = Classifier(input_size, output_size, zeroed_weights_fraction=zeroed_weights_in_baseline)
# NOTE(review): unlike the hypernet, the baseline is not given `device` here;
# presumably train()/evaluate() handle device placement -- confirm.
hypernet_model = HyperNetClassifier(
    input_size,
    output_size,
    hidden_sizes=hidden_size,
    device=device,
    use_previous_weights=use_previous_weights,
    ensemble_num=hypernet_ensemble_num,
)

# `optimizer` holds the optimizer *name* chosen in the hyperparameter cell.
# It is deliberately left untouched: rebinding it to an optimizer object would
# make a second run of this cell fail the `== 'SGD'` check and silently switch
# both models to Adam. Any name other than 'SGD' selects Adam.
optimizer_class = optim.SGD if optimizer == 'SGD' else optim.Adam
model_optimizer = optimizer_class(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
hypernet_optimizer = optimizer_class(hypernet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# NOTE(review): MNIST_loaders() is called without arguments, so batch_size and
# validation_split from the hyperparameter cell are not forwarded here -- confirm
# the loaders obtain them some other way.
train_loader, val_loader, test_loader = MNIST_loaders()

print("Starting Training [A: Accuracy, L: Loss, ECE: Expected Calibration Error, H: Hypernet]")
for epoch in range(1, num_epochs + 1):

    print(f"\n=== Epoch: [{epoch}|{num_epochs}] ========================")

    # A single train() call updates both models, each with its own optimizer.
    avg_loss, avg_hypernet_loss, accuracy, hypernet_accuracy = train(
        model, hypernet_model, train_loader, model_optimizer, hypernet_optimizer, device=device
    )
    print(f'Train: L: {avg_loss:.4f}, H_L: {avg_hypernet_loss:.4f}, A: {accuracy:.2f}%, H_A: {hypernet_accuracy:.2f}%')

    # Both models are evaluated with the same n_bins so the two ECE values
    # printed side by side use the same number of bins.
    val_accuracy, ece = evaluate(model, val_loader, n_bins=n_bins, device=device)
    hypernet_val_accuracy, hypernet_ece = evaluate(hypernet_model, val_loader, n_bins=n_bins, device=device)
    print(f'Val: A: {val_accuracy:.2f}%, A_H: {hypernet_val_accuracy:.2f}%, ECE: {ece:.5f}, ECE_H: {hypernet_ece:.5f}')

Starting Training [A: Accuracy, L: Loss, ECE: Expected Calibration Error, H: Hypernet]

Train: L: 1.5616, H_L: 11.3475, A: 57.81%, H_A: 60.53%
Val: A: 74.90%, A_H: 87.61%, ECE: 0.32406, ECE_H: 0.07129

Train: L: 0.9127, H_L: 6.6120, A: 77.96%, H_A: 77.39%
Val: A: 80.01%, A_H: 88.79%, ECE: 0.21491, ECE_H: 0.06259

Train: L: 0.7057, H_L: 6.4858, A: 81.57%, H_A: 78.15%


KeyboardInterrupt: 

In [None]:
# --- Testing the Model ---
# Evaluate the baseline and the hypernet on the test loader with the same
# settings, printing accuracy and ECE for each in turn.
print("\nStarting Testing...")

models_under_test = (
    ("Normal Model:", model),
    ("Hypernet Model:", hypernet_model),
)
for heading, net in models_under_test:
    print(heading)
    # After the loop, test_accuracy and ece hold the metrics of the last model (the hypernet).
    test_accuracy, ece = evaluate(net, test_loader, n_bins=n_bins, device=device)
    print(f'Test set: Accuracy: {test_accuracy:.2f}%, ECE: {ece:.5f}')


Starting Testing...
Normal Model:
Test set: Accuracy: 91.31%, ECE: 0.04334
Hypernet Model:
Test set: Accuracy: 91.07%, ECE: 0.01071
