In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Load the datasets
train_dataset = datasets.MNIST(
    root = "data",
    train = True,
    transform = transforms.ToTensor(),
    download = True
)

test_dataset = datasets.MNIST(
    root = "data",
    train = False,
    transform = transforms.ToTensor(),
    download = True
)

In [3]:
# Wrap the datasets in dataloader for batching
train_loader = DataLoader(dataset = train_dataset, batch_size = 64, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size = 64, shuffle = False)

In [4]:
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)  #nn.Linear(input_size, output_size)
        self.fc2 = nn.Linear(128,10)
    def forward(self,x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

**So the flow is:
Pixels (784) → Features (128) → Digit Scores (10).**

In [6]:
# Initialize model, loss, optimizer
model = SimpleNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [7]:
# Training loop
for epoch in range(10):   
    for images, labels in train_loader:
        outputs = model(images)         
        loss = loss_fn(outputs, labels)   

        optimizer.zero_grad()             
        loss.backward()                   
        optimizer.step()                  

    print(f"Epoch {epoch+1}, Loss = {loss.item():.4f}")

Epoch 1, Loss = 0.1303
Epoch 2, Loss = 0.1749
Epoch 3, Loss = 0.0333
Epoch 4, Loss = 0.1395
Epoch 5, Loss = 0.1421
Epoch 6, Loss = 0.0163
Epoch 7, Loss = 0.0027
Epoch 8, Loss = 0.1560
Epoch 9, Loss = 0.0148
Epoch 10, Loss = 0.0053


In [8]:
# Test accuracy
correct, total = 0, 0
with torch.no_grad():  # no gradients during testing
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy = {100 * correct / total:.2f}%")

Test Accuracy = 97.58%
