In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST images are converted to float tensors; no other preprocessing is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Training and test splits share the same on-disk root and transform.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of features per sample after flattening.
            Defaults to 28 * 28.
        hidden_dim: Width of the hidden layer. Defaults to 128.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalized) class logits for ``x``.

        ``x`` is flattened to ``(-1, input_dim)`` before the first layer, so
        its total element count must be a multiple of ``input_dim``.
        """
        # reshape, unlike view, also works when x is not contiguous in memory
        # (e.g. after a transpose/permute); it only copies when it has to.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Train a network with the selected optimizer (SGD or Adam)
def train_model(optimizer_name, *, epochs=3, loader=None, model=None):
    """Train a model with cross-entropy loss and return it.

    Args:
        optimizer_name: ``'SGD'`` selects ``optim.SGD`` with ``lr=0.01``.
            Any other value selects ``optim.Adam`` with ``lr=0.001``.
        epochs: Number of full passes over ``loader``. Defaults to 3.
        loader: Iterable yielding ``(inputs, labels)`` batches. Defaults to
            the module-level ``train_loader``.
        model: Module to train in place. Defaults to a fresh ``SimpleNN()``.

    Returns:
        The trained model, left in training mode.
    """
    # Defaults are resolved at call time so the module-level objects are only
    # touched when the caller did not supply replacements.
    if model is None:
        model = SimpleNN()
    if loader is None:
        loader = train_loader

    # Choose optimizer
    if optimizer_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=0.01)
    else:
        optimizer = optim.Adam(model.parameters(), lr=0.001)

    criterion = nn.CrossEntropyLoss()
    model.train()

    # Training loop
    for _ in range(epochs):
        for inputs, labels in loader:
            optimizer.zero_grad()  # Clear gradients left from the previous step
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Calculate loss
            loss.backward()  # Backpropagate
            optimizer.step()  # Update weights

    return model

# Evaluate the model's performance
def evaluate_model(model, loader=None):
    """Return the classification accuracy of ``model`` as a percentage.

    Args:
        model: Module mapping a batch of inputs to per-class scores along
            dimension 1. It is switched to evaluation mode and left there.
        loader: Iterable yielding ``(inputs, labels)`` batches. Defaults to
            the module-level ``test_loader``.

    Returns:
        ``correct / total * 100`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    if loader is None:
        loader = test_loader

    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)  # Index of the highest score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total * 100

# Run 1: plain SGD. Train first, then measure test accuracy.
sgd_model = train_model(optimizer_name='SGD')
sgd_accuracy = evaluate_model(model=sgd_model)
print(f'SGD Accuracy: {sgd_accuracy:.2f}%')

# Run 2: Adam, on a freshly initialized network, for comparison.
adam_model = train_model(optimizer_name='Adam')
adam_accuracy = evaluate_model(model=adam_model)
print(f'Adam Accuracy: {adam_accuracy:.2f}%')


SGD Accuracy: 90.20%
Adam Accuracy: 96.80%
