In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np
import tkinter as tk
from PIL import Image, ImageOps, ImageDraw

In [9]:
class MyCNN(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 128 channels), each
    followed by batch norm and ReLU, then a single 2x2 max-pool, channel
    dropout, and a two-layer fully connected head. The flattened feature size
    of 128 * 11 * 11 corresponds to single-channel 28x28 inputs
    (28 -> 26 -> 24 -> 22 after the convolutions, 11 after pooling).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; each conv is paired with its own batch norm.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(128)
        # Shared spatial down-sampling and channel-wise dropout applied after conv3.
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout2d(0.25)
        # Classifier head operating on the flattened 128 x 11 x 11 feature map.
        self.fc1 = nn.Linear(128 * 11 * 11, 128)
        self.bn4 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log_softmax over dim 1) for batch `x`."""
        features = F.relu(self.bn1(self.conv1(x)))
        features = F.relu(self.bn2(self.conv2(features)))
        features = F.relu(self.bn3(self.conv3(features)))
        features = self.dropout(self.pool(features))

        # Flatten to (batch, 128 * 11 * 11), the input width fc1 was built with.
        flat = features.view(-1, self.fc1.in_features)

        hidden = F.relu(self.bn4(self.fc1(flat)))
        # Functional dropout (default p=0.5) is only active in training mode.
        hidden = F.dropout(hidden, training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)

# Instantiate the network and print its layer summary as a quick sanity check
# of the architecture defined above.
model = MyCNN()
print(model)

MyCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout2d(p=0.25, inplace=False)
  (fc1): Linear(in_features=15488, out_features=128, bias=True)
  (bn4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [10]:
# Preprocessing: convert each image to a float tensor, then standardize it with
# mean 0.1307 and std 0.3081. Any input fed to the model at inference time must
# go through the same normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Seed the global RNG so weight initialization, shuffling and dropout are
# repeatable across "Restart & Run All".
torch.manual_seed(42)

model = MyCNN()

# MyCNN.forward already returns log-probabilities (log_softmax), so the matching
# criterion is the negative log-likelihood loss. nn.CrossEntropyLoss expects raw
# logits and would apply log_softmax a second time.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 5
log_interval = 100  # report the mean loss every `log_interval` mini-batches

# Make training mode explicit so dropout and batch-norm statistics updates are
# active even if this cell is re-run after an evaluation cell.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print('[Epoch %d, Mini-batch %5d] Loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0


[Epoch 1, Mini-batch   100] Loss: 0.489
[Epoch 1, Mini-batch   200] Loss: 0.190
[Epoch 1, Mini-batch   300] Loss: 0.157
[Epoch 1, Mini-batch   400] Loss: 0.129
[Epoch 1, Mini-batch   500] Loss: 0.125
[Epoch 1, Mini-batch   600] Loss: 0.119
[Epoch 1, Mini-batch   700] Loss: 0.104
[Epoch 1, Mini-batch   800] Loss: 0.112
[Epoch 1, Mini-batch   900] Loss: 0.093
[Epoch 2, Mini-batch   100] Loss: 0.085
[Epoch 2, Mini-batch   200] Loss: 0.090
[Epoch 2, Mini-batch   300] Loss: 0.078
[Epoch 2, Mini-batch   400] Loss: 0.088
[Epoch 2, Mini-batch   500] Loss: 0.075
[Epoch 2, Mini-batch   600] Loss: 0.068
[Epoch 2, Mini-batch   700] Loss: 0.069
[Epoch 2, Mini-batch   800] Loss: 0.072
[Epoch 2, Mini-batch   900] Loss: 0.059
[Epoch 3, Mini-batch   100] Loss: 0.064
[Epoch 3, Mini-batch   200] Loss: 0.052
[Epoch 3, Mini-batch   300] Loss: 0.066
[Epoch 3, Mini-batch   400] Loss: 0.052
[Epoch 3, Mini-batch   500] Loss: 0.069
[Epoch 3, Mini-batch   600] Loss: 0.051
[Epoch 3, Mini-batch   700] Loss: 0.052


In [11]:

# Held-out MNIST split (train=False), preprocessed with the same `transform`
# object that was used for the training data.
val_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Evaluation mode: dropout is disabled and batch norm uses its running statistics.
model.eval()

correct, total = 0, 0

# No gradients are needed for scoring, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        # The predicted class is the index of the largest log-probability.
        predicted = outputs.argmax(dim=1)
        total += len(labels)
        correct += int((predicted == labels).sum())

accuracy = 100 * correct / total
print('Accuracy on the validation set: {:.2f}%'.format(accuracy))

Accuracy on the validation set: 99.30%


In [12]:
# Destination file for the trained weights.
model_save_path = 'my_cnn_model.pth'

# Persist only the state_dict (parameters and buffers) rather than the whole
# module object; loading it back requires constructing MyCNN first.
trained_state = model.state_dict()
torch.save(trained_state, model_save_path)
print(f'Model saved to {model_save_path}')


Model saved to my_cnn_model.pth


In [13]:

# Load the trained model.
# The checkpoint path must be the file written by torch.save() earlier in this
# notebook ('my_cnn_model.pth'); loading any other name either fails or picks up
# an unrelated, stale checkpoint.
model = MyCNN()
# map_location='cpu' keeps the weights on the CPU, matching the CPU tensors the
# drawing app builds from NumPy arrays.
model.load_state_dict(torch.load('my_cnn_model.pth', map_location='cpu'))
# Inference mode: disable dropout and use batch-norm running statistics.
model.eval()

class DrawApp:
    """Tkinter window for drawing a digit and classifying it.

    Strokes are drawn twice: on the visible Tk canvas and on an off-screen
    200x200 grayscale PIL image (white background, black ink). The PIL image is
    what gets preprocessed and passed to the module-level `model` when the
    "Predict" button is pressed.
    """

    # Standardization constants. They must match the
    # transforms.Normalize((0.1307,), (0.3081,)) step applied to the training
    # data, otherwise the network sees inputs on a different scale than it was
    # trained on.
    NORM_MEAN = 0.1307
    NORM_STD = 0.3081

    def __init__(self, root):
        """Build the widgets inside `root` and create the blank backing image."""
        self.root = root
        self.root.title("Draw a digit")
        self.canvas = tk.Canvas(root, width=200, height=200, bg='white')
        self.canvas.pack()
        # Paint while the left mouse button is held down and the pointer moves.
        self.canvas.bind("<B1-Motion>", self.paint)
        self.button_clear = tk.Button(root, text="Clear", command=self.clear)
        self.button_clear.pack()
        self.button_predict = tk.Button(root, text="Predict", command=self.predict)
        self.button_predict.pack()
        self.label = tk.Label(root, text="", font=("Helvetica", 24))
        self.label.pack()
        # Off-screen copy of the drawing: mode "L" (8-bit grayscale), white background.
        self.image = Image.new("L", (200, 200), color=255)
        self.draw = ImageDraw.Draw(self.image)

    def paint(self, event):
        """Stamp a black dot centred on the pointer on both the canvas and the image."""
        x1, y1 = (event.x - 8), (event.y - 8)
        x2, y2 = (event.x + 8), (event.y + 8)
        self.canvas.create_oval(x1, y1, x2, y2, fill="black", width=5)
        self.draw.ellipse([x1, y1, x2, y2], fill=0)

    def clear(self):
        """Erase the canvas, reset the backing image to white and clear the label."""
        self.canvas.delete("all")
        self.image = Image.new("L", (200, 200), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.label.config(text="")

    def predict(self):
        """Preprocess the drawing like the training data and show the predicted digit."""
        # Downscale to the 28x28 input size and invert so that the ink is
        # bright on a dark background.
        img = self.image.resize((28, 28))
        img = ImageOps.invert(img)

        # Scale to [0, 1], then apply the same standardization as the training
        # transform.
        pixels = np.array(img).astype(np.float32) / 255.0
        pixels = (pixels - self.NORM_MEAN) / self.NORM_STD

        # Add batch and channel dimensions: (28, 28) -> (1, 1, 28, 28).
        batch = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        # Make prediction
        with torch.no_grad():
            output = model(batch)
            pred = output.argmax(dim=1, keepdim=True)

        # Display the prediction
        self.label.config(text=f'Predicted Digit: {pred.item()}')

# Create the Tk root window, attach the drawing app to it, and enter the Tk
# event loop. mainloop() blocks this cell until the window is closed.
root = tk.Tk()
app = DrawApp(root)
root.mainloop()
