In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

# Data preparation: MNIST training images with an 80/20 train/validation split.
# Normalize((0.5,), (0.5,)) shifts and rescales the single image channel
# with mean 0.5 and std 0.5 after ToTensor has converted it to a float tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
total_size = len(dataset)
train_size = int(total_size * 0.8)  # 80% for training
val_size = total_size - train_size  # 20% for validation (remainder, so the sizes always sum to total_size)

# Seed the split with a dedicated generator so the same images land in the
# validation set on every run; without it the split changes on each kernel
# restart and validation accuracies are not comparable between runs.
# A private generator is used so the global RNG state is left untouched.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [train_size, val_size], generator=split_generator)

# Shuffle only the training batches; validation order does not matter.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

# CNN architecture with additional layers and adjustments
class CNN(nn.Module):
    """Three-stage convolutional classifier for 28x28 images.

    Each convolutional stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU
    -> MaxPool2d(2).  The 3x3/padding-1 convolutions keep the spatial size and
    each pooling step halves it (28 -> 14 -> 7 -> 3), while the channel count
    grows in_channels -> 32 -> 64 -> 128.  The flattened 128 * 3 * 3 feature
    vector is fed to a two-layer fully connected head with dropout.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output logits (default 10).

    The defaults reproduce the original fixed architecture.  Submodule names
    and ordering are unchanged, so ``state_dict`` keys stay the same.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv_layer = nn.Sequential(
            # Stage 1: in_channels -> 32 feature maps, spatial size halved.
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Stage 2: 32 -> 64 feature maps, spatial size halved.
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Stage 3: 64 -> 128 feature maps, spatial size halved (floor).
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc_layer = nn.Sequential(
            nn.Dropout(0.5),
            # 128 channels * 3 * 3 spatial positions; tied to 28x28 inputs.
            nn.Linear(128 * 3 * 3, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch of images.

        Args:
            x: Image batch of shape (batch, in_channels, 28, 28).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        x = self.conv_layer(x)
        # Collapse everything except the batch dimension.  Unlike ``view``,
        # ``flatten`` does not require a particular memory layout.
        x = x.flatten(start_dim=1)
        return self.fc_layer(x)

# Choose the compute device (first CUDA GPU when one is available, CPU
# otherwise), then build the model, the loss and the optimizer on it.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for 10 epochs. After every epoch, report the mean of the per-batch
# training losses and the accuracy on the validation loader.
for epoch in range(10):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        # Batches come from the loader on the CPU; move them to the model's device.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

    # ---- validation pass (eval mode, no gradient tracking) ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Index of the largest logit along the class dimension.
            predicted = torch.max(model(images), 1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the validation images: {100 * correct / total}%')

Epoch 1, Loss: 0.18612862056369583
Accuracy of the network on the validation images: 98.41666666666667%
Epoch 2, Loss: 0.0754811595318218
Accuracy of the network on the validation images: 98.81666666666666%
Epoch 3, Loss: 0.057543671413635214
Accuracy of the network on the validation images: 98.71666666666667%
Epoch 4, Loss: 0.050180432066942254
Accuracy of the network on the validation images: 98.93333333333334%
Epoch 5, Loss: 0.04162088596168906
Accuracy of the network on the validation images: 98.95%
Epoch 6, Loss: 0.037204658682419296
Accuracy of the network on the validation images: 98.95833333333333%
Epoch 7, Loss: 0.03421786321358134
Accuracy of the network on the validation images: 99.19166666666666%
Epoch 8, Loss: 0.032015219632109314
Accuracy of the network on the validation images: 99.18333333333334%
Epoch 9, Loss: 0.02943124442495173
Accuracy of the network on the validation images: 99.25%
Epoch 10, Loss: 0.027181052293046376
Accuracy of the network on the validation images: 99.29% (rounded)

In [None]:
# Loader over the complete, unsplit training dataset (the training and
# validation portions together), reshuffled every epoch.
full_train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

In [None]:
# Rebuild the model with freshly initialised weights, together with a new
# loss and a new Adam optimizer bound to the new model's parameters.
criterion = nn.CrossEntropyLoss()
model = CNN().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Number of passes over the full training set.
num_epochs = 10

# Nothing in this loop switches the model to eval mode, so enabling
# training mode once up front is enough.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in full_train_loader:
        # Each batch is an (images, labels) pair; move both to the model's device.
        images = batch[0].to(device)
        labels = batch[1].to(device)

        # Clear old gradients, run forward + backward, then update the weights.
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(full_train_loader)}')

Epoch [1/10], Loss: 0.1618996251936454
Epoch [2/10], Loss: 0.06262055852389369
Epoch [3/10], Loss: 0.05117280608036125
Epoch [4/10], Loss: 0.044390300056741466
Epoch [5/10], Loss: 0.03882307470615418
Epoch [6/10], Loss: 0.03424418139430795
Epoch [7/10], Loss: 0.032400465192048186
Epoch [8/10], Loss: 0.028498316897273085
Epoch [9/10], Loss: 0.026476696431741685
Epoch [10/10], Loss: 0.02643818386201477


In [None]:
# MNIST test split (train=False), preprocessed with the same transform as
# the training data and iterated in a fixed order.
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

# Count correct predictions over the whole test set with the model in
# evaluation mode and gradient tracking disabled.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        images = batch[0].to(device)
        labels = batch[1].to(device)
        # Predicted class = index of the largest logit for each image.
        predicted = torch.max(model(images), 1).indices
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')

Accuracy of the model on the 10000 test images: 99.31%


In [None]:
# Write only the model's state_dict (its parameters and buffers) to disk,
# at a path relative to the current working directory.
model_path = 'mnist_cnn_model.pth'
torch.save(obj=model.state_dict(), f=model_path)