In [24]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Hyperparameters
input_size = 784  # 28x28 images, flattened
hidden_size = 500
output_size = 10  # 10 classes for MNIST
num_epochs = 50
batch_size = 64
learning_rate = 1e-3
momentum = 0.9
seed = 42  # fixed seed so a fresh "Restart & Run All" draws the same random numbers

# Seed the global RNG; it drives parameter initialisation and DataLoader shuffling.
torch.manual_seed(seed)

# MNIST dataset
# Normalize((0.5,), (0.5,)) applies (x - 0.5) / 0.5 to the single image channel.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# Split dataset into training and validation sets (8:2)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
# A dedicated generator keeps the split identical no matter how much of the
# global RNG stream was consumed before this line.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training batches are shuffled; validation order does not affect the metric.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# Define the neural network




In [25]:
class NeuralNet(nn.Module):
    """Fully connected classifier with one ReLU hidden layer.

    Layout: ``Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, output_size)``.
    ``forward`` returns the raw output of the last linear layer; no softmax is
    applied inside the module.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the order they are applied in forward().
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run ``x`` through fc1, ReLU and fc2 and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [26]:
# Place the model on the GPU when one is available, otherwise on the CPU,
# instead of unconditionally requesting "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuralNet(input_size, hidden_size, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Per-epoch error percentages; train() appends one entry to each list per epoch.
train_errors = []
val_errors = []

In [27]:
def train(num_epochs):
    """Train the global ``model`` for ``num_epochs`` epochs and plot the error curves.

    Every epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. The classification error (in percent)
    of each pass is appended to the module-level ``train_errors`` and
    ``val_errors`` lists and printed. After the last epoch both curves are
    plotted, saved to ``./new_fullconnected.png`` and shown.

    Args:
        num_epochs: Number of passes over the training data.
    """
    # Use the device the model's parameters already live on rather than a
    # hard-coded "cuda", so batches and weights always share a device.
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        model.train()
        correct = 0
        total = 0
        for images, labels in train_loader:
            # Flatten each 28x28 image into a 784-long vector.
            images = images.view(-1, 28*28).to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest output; detach() replaces
            # the legacy `.data` access. These predictions come from the
            # forward pass made before this batch's optimizer step.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_error = 100 * (1 - correct / total)
        train_errors.append(train_error)

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.view(-1, 28*28).to(device)
                labels = labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_error = 100 * (1 - correct / total)
        val_errors.append(val_error)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Error: {train_error:.2f}%, Validation Error: {val_error:.2f}%')

    # The lists are module-level, so the plot covers every epoch recorded so
    # far, including epochs from earlier calls to train().
    plt.plot(train_errors, label='Train Error')
    plt.plot(val_errors, label='Validation Error')
    plt.xlabel('Epoch')
    plt.ylabel('Error (%)')
    plt.legend()
    plt.savefig("./new_fullconnected.png")
    plt.show()


In [28]:
# Run training for the configured number of epochs; per-epoch errors are printed below.
train(num_epochs)

Epoch [1/50], Train Error: 19.56%, Validation Error: 12.20%
Epoch [2/50], Train Error: 10.83%, Validation Error: 10.22%
Epoch [3/50], Train Error: 9.51%, Validation Error: 9.25%
Epoch [4/50], Train Error: 8.71%, Validation Error: 8.77%
Epoch [5/50], Train Error: 8.10%, Validation Error: 8.76%
Epoch [6/50], Train Error: 7.57%, Validation Error: 7.76%
Epoch [7/50], Train Error: 7.02%, Validation Error: 7.40%
Epoch [8/50], Train Error: 6.57%, Validation Error: 7.11%
Epoch [9/50], Train Error: 6.07%, Validation Error: 6.78%
Epoch [10/50], Train Error: 5.77%, Validation Error: 6.13%
Epoch [11/50], Train Error: 5.37%, Validation Error: 5.99%
