In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

from utils.MNIST_loaders import MNIST_loaders

from models.HypernetClassifier import HyperNetClassifier
from models.Classifier import Classifier

In [13]:
# --- Hyperparameters (things you can easily change!) ---
num_epochs = 10
learning_rate = 0.01
weight_decay = 0.001
batch_size = 64
validation_split = 0.2  # Percentage of the training data to use for validation
random_seed = 42      # For making sure our splits are the same each time

In [14]:
torch.manual_seed(random_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
def train(model, hypernet_model, train_loader, optimizer, hypernet_optimizer, epoch):
    model.train()  # Set the model to training mode
    hypernet_model.train()
    
    model = model.to(device)
    hypernet_model = hypernet_model.to(device)
    
    total_loss = 0
    correct = 0
    
    hypernet_total_loss = 0
    hypernet_correct = 0
    
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):

        optimizer.zero_grad()
        hypernet_optimizer.zero_grad()
        
        data = data.to(device)
        target = target.to(device)
        
        output = model(data)
        hypernet_output = hypernet_model(data)

        loss = nn.CrossEntropyLoss()(output, target)
        hypernet_loss = nn.CrossEntropyLoss()(hypernet_output, target)

        loss.backward()
        hypernet_loss.backward()

        optimizer.step()
        hypernet_optimizer.step()

        total_loss += loss.item()
        hypernet_total_loss += hypernet_loss.item()
        
        _, predicted = torch.max(output.data, 1)
        _, hypernet_predicted = torch.max(hypernet_output.data, 1)
        
        total += target.size(0)
        
        correct += (predicted == target).sum().item()
        hypernet_correct += (hypernet_predicted == target).sum().item()

        if (batch_idx + 1) % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}. Hypernet Loss: {hypernet_loss.item():.6f}')

    avg_loss = total_loss / len(train_loader)
    avg_hypernet_loss = hypernet_total_loss / len(train_loader)
    
    accuracy = 100. * correct / total
    hypernet_accuracy = 100. * hypernet_correct / total
    
    print(f'Train Epoch: {epoch} Average Loss: {avg_loss:.4f}, Average Hypernet Loss: {avg_hypernet_loss:.4f}, Accuracy: {accuracy:.2f}%, Hypernet Accuracy: {hypernet_accuracy:.2f}%')
    return avg_loss, avg_hypernet_loss, accuracy, hypernet_accuracy

In [16]:
def evaluate(model, data_loader):
    model.eval()
    model = model.to(device)

    total_loss = 0
    correct = 0
    total = 0

    # For ECE calculation
    n_bins = 10  # Since it's a 10-class classification task
    bin_boundaries = torch.linspace(0, 1, n_bins + 1).to(device)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    bin_corrects = torch.zeros(n_bins).to(device)
    bin_totals = torch.zeros(n_bins).to(device)
    bin_confidences = torch.zeros(n_bins).to(device)

    with torch.no_grad():  # Disable gradient calculations during evaluation
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Calculate probabilities and confidences for ECE
            probabilities = torch.softmax(output, dim=1)
            confidences, predictions = torch.max(probabilities, dim=1)
            accuracies = predictions.eq(target)

            for i in range(n_bins):
                in_bin = (confidences >= bin_lowers[i]) & (confidences < bin_uppers[i])
                bin_totals[i] += in_bin.sum().item()
                bin_corrects[i] += (accuracies[in_bin]).sum().item()
                bin_confidences[i] += (confidences[in_bin]).sum().item()

    avg_loss = total_loss / len(data_loader)
    accuracy = 100. * correct / total

    # Calculate ECE
    ece = torch.zeros(1, device=device)
    for i in range(n_bins):
        if bin_totals[i] > 0:
            bin_accuracy = bin_corrects[i] / bin_totals[i]
            avg_confidence = bin_confidences[i] / bin_totals[i]
            ece += torch.abs(avg_confidence - bin_accuracy) * bin_totals[i]
    ece = ece / total
    ece = ece.item()

    return avg_loss, accuracy, ece

In [17]:
def test(model, test_loader):
    test_loss, test_accuracy, ece = evaluate(model, test_loader)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, ECE: {ece:.5f}')

In [18]:
num_epochs = 3
input_size = 28 * 28
output_size = 10

In [None]:
model = Classifier(input_size, output_size, dropout=0.5)
hypernet_model = HyperNetClassifier(input_size, output_size, hidden_sizes=[1024, 512, 1024], device=device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
hypernet_optimizer = optim.Adam(hypernet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

train_loader, val_loader, test_loader = MNIST_loaders()

print("Starting Training...")
for epoch in range(1, num_epochs + 1):
    _, _, _, _ = train(model, hypernet_model, train_loader, optimizer, hypernet_optimizer, epoch)
    
    val_loss, val_accuracy, ece = evaluate(model, val_loader)
    print(f'Validation Epoch: {epoch} Average Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, ECE: {ece:.5f}')
    
    val_loss, val_accuracy, ece = evaluate(hypernet_model, val_loader)
    print(f'Hypernet Validation Epoch: {epoch} Average Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, ECE: {ece:.5f}')

In [None]:
# --- Testing the Model ---
print("\nStarting Testing...")

print("Normal Model:")
test(model, test_loader)

print("Hypernet Model:")
test(hypernet_model, test_loader)


Starting Testing...
Normal Model:
Test set: Average loss: 0.5409, Accuracy: 87.87%, ECE: 0.06125
Hypernet Model:
Test set: Average loss: 48.8385, Accuracy: 9.91%, ECE: 0.00216
