In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pickle

In [3]:
# Hyperparameters
input_size = 28 * 28  # MNIST images are 28x28, flattened to one vector
hidden_size = 128     # width of the single hidden layer
num_classes = 10      # one output logit per class
epochs = 5
batch_size = 64
learning_rate = 0.001
seed = 42

# Seed PyTorch's global RNG so that weight initialisation and the shuffled
# batch order are repeatable across "Restart Kernel -> Run All".
torch.manual_seed(seed)

# 1. Load MNIST dataset
# Convert images to tensors, then shift/scale each pixel with mean 0.5 and
# std 0.5 so values end up in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

# Both splits are downloaded into ./data on first use and share the same
# preprocessing pipeline.
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# Only the training batches are shuffled; the test order is kept fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 2. Define the Neural Network
class MLP(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return unnormalised class scores for a batch.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``input_size`` values (the size given to the constructor).

        Returns:
            Tensor with one row per sample and ``num_classes`` columns.
        """
        # Flatten with the first layer's own input width rather than a
        # module-level global, so the class works for any constructor
        # argument. reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Instantiate the network from the hyperparameters defined at the top of the cell.
model = MLP(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)

# 3. Define Loss and Optimizer
# Cross-entropy is applied to the network's raw outputs; Adam updates every
# parameter of `model` with a single learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# 4. Train the model
def train():
    """Train the module-level ``model`` on ``train_loader`` for ``epochs`` epochs.

    Uses the module-level ``criterion`` and ``optimizer``. Every 100 steps the
    loss of the current batch is printed. Returns nothing.
    """
    model.train()
    # Batches must live on the same device as the model's parameters, so take
    # the device from the model itself.
    device = next(model.parameters()).device
    steps_per_epoch = len(train_loader)
    for epoch in range(epochs):
        for batch_idx, (data, targets) in enumerate(train_loader):
            # Move data to device (CPU/GPU)
            data, targets = data.to(device), targets.to(device)

            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, targets)

            # Backward pass: clear old gradients, backpropagate, update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{batch_idx + 1}/{steps_per_epoch}], Loss: {loss.item():.4f}")

# 5. Evaluate the model
def evaluate():
    """Print the module-level ``model``'s accuracy over ``test_loader``.

    The model is put in eval mode and gradients are disabled. A prediction is
    the index of the largest output along dimension 1. Returns nothing.
    """
    model.eval()
    # Batches must live on the same device as the model's parameters.
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            # argmax on the tensor itself; the legacy `.data` attribute is
            # unnecessary inside no_grad.
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    print(f"Accuracy of the model on the test images: {100 * correct / total:.2f}%")

# Save the model using pickle
def save_model_with_pickle(model, file_path):
    """Write ``model.state_dict()`` to ``file_path`` using :mod:`pickle`.

    Only the state dict is serialised, not the model object. The file is
    opened in binary write mode, so an existing file is overwritten. A
    confirmation line is printed afterwards.
    """
    with open(file_path, mode='wb') as checkpoint:
        weights = model.state_dict()
        pickle.dump(weights, checkpoint)
    print(f"Model saved to {file_path}")

# Load the model using pickle
def load_model_with_pickle(model, file_path):
    """Restore ``model`` from a pickled state dict stored at ``file_path``.

    The file is unpickled and the result is passed to
    ``model.load_state_dict``; a confirmation line is printed afterwards.
    Unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(file_path, mode='rb') as checkpoint:
        restored = pickle.load(checkpoint)
    model.load_state_dict(restored)
    print(f"Model loaded from {file_path}")

if __name__ == "__main__":
    # Run the full pipeline in order: train, write the state dict to
    # "model.pkl", then print accuracy on the test loader.
    train()
    save_model_with_pickle(model, "model.pkl")  # Save the trained model
    # To load later:
    # load_model_with_pickle(model, "model.pkl")
    evaluate()

Epoch [1/5], Step [100/938], Loss: 0.3209
Epoch [1/5], Step [200/938], Loss: 0.3948
Epoch [1/5], Step [300/938], Loss: 0.1543
Epoch [1/5], Step [400/938], Loss: 0.5316
Epoch [1/5], Step [500/938], Loss: 0.3191
Epoch [1/5], Step [600/938], Loss: 0.4533
Epoch [1/5], Step [700/938], Loss: 0.2584
Epoch [1/5], Step [800/938], Loss: 0.1022
Epoch [1/5], Step [900/938], Loss: 0.4521
Epoch [2/5], Step [100/938], Loss: 0.2239
Epoch [2/5], Step [200/938], Loss: 0.1722
Epoch [2/5], Step [300/938], Loss: 0.3737
Epoch [2/5], Step [400/938], Loss: 0.1432
Epoch [2/5], Step [500/938], Loss: 0.1029
Epoch [2/5], Step [600/938], Loss: 0.1490
Epoch [2/5], Step [700/938], Loss: 0.1694
Epoch [2/5], Step [800/938], Loss: 0.1282
Epoch [2/5], Step [900/938], Loss: 0.3129
Epoch [3/5], Step [100/938], Loss: 0.1195
Epoch [3/5], Step [200/938], Loss: 0.0809
Epoch [3/5], Step [300/938], Loss: 0.0965
Epoch [3/5], Step [400/938], Loss: 0.1396
Epoch [3/5], Step [500/938], Loss: 0.1435
Epoch [3/5], Step [600/938], Loss: