In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# --- Hyperparameters (things you can easily change!) ---
num_epochs = 10
learning_rate = 0.01
weight_decay = 0.001
batch_size = 64
validation_split = 0.2  # Percentage of the training data to use for validation
random_seed = 42      # For making sure our splits are the same each time

In [4]:
# Set the random seed for reproducibility
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1af884cae50>

In [5]:
# --- Preparing the Dataset and DataLoaders ---

# Define the transformations to apply to the images
# Here, we convert the images to PyTorch tensors and normalize the pixel values
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download the MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Split the training dataset into training and validation sets
train_size = int((1 - validation_split) * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create DataLoaders. These help us load the data in batches during training.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# --- Implementing the Simple Classifier ---

class Classifier(nn.Module):
    def __init__(self, input_size, output_size):
        super(Classifier, self).__init__()
        # This is a single linear layer (like a simple connection of all inputs to all outputs)
        self.linear = nn.Linear(input_size, output_size, bias=False)

    def forward(self, x):
        # The input images are 28x28 pixels, so we need to flatten them into a single vector of 784 elements
        x = x.view(-1, 28 * 28)
        # Pass the flattened vector through the linear layer
        x = self.linear(x)
        return x

In [7]:
class HyperNetClassifier(nn.Module):
    def __init__(self, input_size, output_size, hidden_sizes=[2400, 1200, 2400]):
        super(HyperNetClassifier, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        
        self.num_weights = input_size * output_size
        
        self.input_layer = nn.Linear(self.num_weights, hidden_sizes[0], bias=False)
        self.hidden_1 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.hidden_2 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.output_layer = nn.Linear(hidden_sizes[2], self.num_weights, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)

        mask = (torch.rand(self.num_weights, requires_grad=False) >= 0.5).int().to(device)
        hypernet_input = torch.randn(self.num_weights, requires_grad=True).to(device) * mask
        
        hypernet_output = torch.relu(self.input_layer(hypernet_input))
        hypernet_output = torch.relu(self.hidden_1(hypernet_output))
        hypernet_output = torch.relu(self.hidden_2(hypernet_output))
        hypernet_output = self.output_layer(hypernet_output)

        weights = torch.reshape(hypernet_output, (self.input_size, self.output_size))

        return x @ weights

In [8]:
# --- Training Function ---

def train(model, hypernet_model, train_loader, optimizer, hypernet_optimizer, epoch):
    model.train()  # Set the model to training mode
    hypernet_model.train()
    
    model = model.to(device)
    hypernet_model = hypernet_model.to(device)
    
    total_loss = 0
    correct = 0
    
    hypernet_total_loss = 0
    hypernet_correct = 0
    
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):

        optimizer.zero_grad()
        hypernet_optimizer.zero_grad()
        
        data = data.to(device)
        target = target.to(device)
        
        output = model(data)
        hypernet_output = hypernet_model(data)

        loss = nn.CrossEntropyLoss()(output, target)
        hypernet_loss = nn.CrossEntropyLoss()(hypernet_output, target)

        loss.backward()
        hypernet_loss.backward()

        optimizer.step()
        hypernet_optimizer.step()

        total_loss += loss.item()
        hypernet_total_loss += hypernet_loss.item()
        
        _, predicted = torch.max(output.data, 1)
        _, hypernet_predicted = torch.max(hypernet_output.data, 1)
        
        total += target.size(0)
        
        correct += (predicted == target).sum().item()
        hypernet_correct += (hypernet_predicted == target).sum().item()

        if (batch_idx + 1) % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}. Hypernet Loss: {hypernet_loss.item():.6f}')

    avg_loss = total_loss / len(train_loader)
    avg_hypernet_loss = hypernet_total_loss / len(train_loader)
    
    accuracy = 100. * correct / total
    hypernet_accuracy = 100. * hypernet_correct / total
    
    print(f'Train Epoch: {epoch} Average Loss: {avg_loss:.4f}, Average Hypernet Loss: {avg_hypernet_loss:.4f}, Accuracy: {accuracy:.2f}%, Hypernet Accuracy: {hypernet_accuracy:.2f}%')
    return avg_loss, avg_hypernet_loss, accuracy, hypernet_accuracy

In [9]:
# --- Evaluating Function ---

def evaluate(model, data_loader):
    model.eval()
    model = model.to(device)
    
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient calculations during evaluation
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    avg_loss = total_loss / len(data_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [10]:
# --- Testing Function ---

def test(model, test_loader):
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%')

In [14]:
num_epochs = 10

In [None]:
# Initialize the model
input_size = 28 * 28  # 784 input features (28x28 pixels)
output_size = 10     # 10 output classes (digits 0-9)

model = Classifier(input_size, output_size)
hypernet_model = HyperNetClassifier(input_size, output_size, hidden_sizes=[1024, 512, 1024])

# --- Initialize the Optimizer ---
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
hypernet_optimizer = optim.Adam(hypernet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# --- Training and Evaluation Loop ---
print("Starting Training...")
for epoch in range(1, num_epochs + 1):
    _, _, _, _ = train(model, hypernet_model, train_loader, optimizer, hypernet_optimizer, epoch)
    
    val_loss, val_accuracy = evaluate(model, val_loader)
    print(f'Validation Epoch: {epoch} Average Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')
    
    val_loss, val_accuracy = evaluate(hypernet_model, val_loader)
    print(f'Hypernet Validation Epoch: {epoch} Average Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

Starting Training...
Train Epoch: 1 Average Loss: 0.5199, Average Hypernet Loss: 44.6137, Accuracy: 87.22%, Hypernet Accuracy: 75.58%
Validation Epoch: 1 Average Loss: 0.5403, Accuracy: 87.55%
Hypernet Validation Epoch: 1 Average Loss: 1.1493, Accuracy: 64.64%
Train Epoch: 2 Average Loss: 0.4854, Average Hypernet Loss: 592.7261, Accuracy: 88.84%, Hypernet Accuracy: 68.01%
Validation Epoch: 2 Average Loss: 0.6323, Accuracy: 86.94%
Hypernet Validation Epoch: 2 Average Loss: 183.1502, Accuracy: 83.24%
Train Epoch: 3 Average Loss: 0.5059, Average Hypernet Loss: 356.5517, Accuracy: 88.58%, Hypernet Accuracy: 69.76%
Validation Epoch: 3 Average Loss: 0.6392, Accuracy: 86.86%
Hypernet Validation Epoch: 3 Average Loss: 0.4109, Accuracy: 87.96%
Train Epoch: 4 Average Loss: 0.5204, Average Hypernet Loss: 13.4896, Accuracy: 88.54%, Hypernet Accuracy: 77.33%
Validation Epoch: 4 Average Loss: 0.7421, Accuracy: 86.82%
Hypernet Validation Epoch: 4 Average Loss: 0.7468, Accuracy: 77.01%


In [None]:
# --- Testing the Model ---
print("\nStarting Testing...")

print("Normal Model:")
test(model, test_loader)

print("Hypernet Model:")
test(hypernet_model, test_loader)


Starting Testing...
Normal Model:
Test set: Average loss: 0.5011, Accuracy: 89.05%
Hypernet Model:
Test set: Average loss: 1172.1804, Accuracy: 73.84%
