In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [6]:
# Preprocessing pipeline handed to the datasets: tensor conversion is the only step
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)



In [None]:
# Prepare the datasets: build the MNIST train split first, then the test split.
# Both share the same root directory, download flag and transform; only the
# `train` flag differs between the two constructor calls.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

100%|██████████| 9.91M/9.91M [00:03<00:00, 2.94MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 94.8kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 971kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.46MB/s]


In [None]:
# Build a simple CNN model
class SimpleCNN(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    Architecture (in forward order):
        conv1 (1 -> 32 channels, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        conv2 (32 -> 64 channels, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        flatten -> fc1 (64*7*7 -> 128) -> ReLU -> dropout(p=0.25) -> fc2 (128 -> 10)
        log_softmax over the class dimension

    The first linear layer is sized for a 64 x 7 x 7 feature map, i.e. the
    input's spatial size must reduce to 7x7 after the two pooling stages
    (28x28 inputs do).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Run the network on a batch of single-channel images.

        Args:
            x: image tensor with 1 channel whose spatial size pools down to 7x7.

        Returns:
            Tensor of shape (batch, 10) holding log-probabilities
            (output of ``F.log_softmax`` along dim 1).

        Raises:
            RuntimeError: if the pooled feature map is not 64 x 7 x 7.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Guard the flatten step: with `-1` as the leading dimension, a larger
        # feature map whose element count happens to be a multiple of
        # fc1.in_features would otherwise be silently split into extra "samples".
        expected_shape = (self.conv2.out_channels, 7, 7)
        if tuple(x.shape[-3:]) != expected_shape:
            raise RuntimeError(
                f"SimpleCNN expected a pooled feature map of shape {expected_shape}, "
                f"got {tuple(x.shape[-3:])}; check the input image size"
            )
        # reshape (rather than view) so non-contiguous feature maps are accepted too
        x = x.reshape(-1, self.fc1.in_features)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [13]:
# model training setup
# Hyperparameters are collected here so each value is defined exactly once.
NUM_EPOCHS = 5
BATCH_SIZE = 64
LEARNING_RATE = 0.001
SEED = 42

# Seed the global torch RNG before the model is built and the loader shuffles,
# so weight initialisation and batch order start from the same state each run.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
# SimpleCNN.forward already ends in F.log_softmax, so the matching criterion is
# NLLLoss. CrossEntropyLoss would apply log-softmax a second time on top of the
# model's log-probabilities.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training the model
model.train()  # make sure dropout is active even if the model was put in eval mode
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss/len(train_loader):.4f}')

Epoch [1/5], Loss: 0.2037
Epoch [2/5], Loss: 0.0633
Epoch [3/5], Loss: 0.0462
Epoch [4/5], Loss: 0.0378
Epoch [5/5], Loss: 0.0283


In [14]:
# Save the model: write only the state_dict (not the whole module object) to disk
torch.save(obj=model.state_dict(), f='mnist_cnn.pth')