<a href="https://colab.research.google.com/github/nahiim/colab/blob/main/06_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Load and preprocess the MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),           # Convert to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=64, shuffle=False)

# 2. Define the Neural Network
class DigitClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)  # 10 output classes (0 to 9)
        )

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten image
        return self.model(x)

net = DigitClassifier()

# 3. Set loss function and optimizer
criterion = nn.CrossEntropyLoss()  # handles multi-class classification
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 4. Train the model
for epoch in range(5):
    total_loss = 0
    for images, labels in train_loader:
        preds = net(images)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

# 5. Evaluate the model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"\nAccuracy on test set: {100 * correct / total:.2f}%")


100%|██████████| 9.91M/9.91M [00:00<00:00, 11.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 345kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.20MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.62MB/s]


Epoch 1 | Loss: 374.7886
Epoch 2 | Loss: 170.3203
Epoch 3 | Loss: 122.9501
Epoch 4 | Loss: 99.7847
Epoch 5 | Loss: 84.0680

Accuracy on test set: 96.92%
