In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
# Define the transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [3]:
# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:

# Define the CNN model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


In [5]:
# Initialize the model, loss function, and optimizer
model = Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [6]:
# Training loop
# epochs = 2
# for epoch in range(epochs):
#     train_loss = 0.0
#     for batch_idx, (data, target) in enumerate(train_loader):
#         optimizer.zero_grad()
#         output = model(data)
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
#         train_loss += loss.item()
#         if batch_idx % 100 == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                        100. * batch_idx / len(train_loader), loss.item()))
# 
#     # Evaluate the model on the test set
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             output = model(data)
#             test_loss += criterion(output, target).item()
#             pred = output.max(1, keepdim=True)[1]
#             correct += pred.eq(target.view_as(pred)).sum().item()
# 
#     test_loss /= len(test_loader.dataset)
#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))

In [7]:
# Save the model
# torch.save(model.state_dict(), 'data/mnist_cnn.pt')

In [8]:
model.load_state_dict(torch.load('data/mnist_cnn.pt'))

<All keys matched successfully>

Loading Canvas
----

In [9]:
import numpy as np
import cv2
import torch
from torchvision.transforms import transforms

In [10]:
# Define the canvas size
canvas_width = 500
canvas_height = 500
# Define the preprocessing transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [14]:
# Function to handle mouse events
def draw(event, x, y, flags, param):
    global canvas, capturing, model

    if event == cv2.EVENT_LBUTTONDOWN:
        capturing = True

    elif event == cv2.EVENT_LBUTTONUP:
        capturing = False
        # Capture the canvas and preprocess it
        canvas_img = cv2.resize(canvas, (28, 28), interpolation=cv2.INTER_AREA)
        canvas_img = cv2.cvtColor(canvas_img, cv2.COLOR_BGR2GRAY)
        canvas_img = cv2.bitwise_not(canvas_img)
        canvas_tensor = transform(canvas_img)
        canvas_tensor = canvas_tensor.unsqueeze(0)

        # Perform inference with your CNN model
        output = model(canvas_tensor)
        prediction = output.max(1, keepdim=True)[1]
        pred_text = f"{prediction.item()}"
        print(pred_text)

        # Clear the canvas
        canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)

        # Calculate text size and position it in the center
        text_size = cv2.getTextSize(pred_text, cv2.FONT_HERSHEY_SIMPLEX, 2, 2)[0]
        text_x = (canvas_width - text_size[0]) // 2
        text_y = (canvas_height + text_size[1]) // 2

        # Write prediction in red back on the canvas
        cv2.putText(canvas, pred_text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 
                    2, (0, 0, 255), 2, cv2.LINE_AA)

    elif event == cv2.EVENT_MOUSEMOVE:
        if capturing:
            cv2.circle(canvas, (x, y), 10, (255, 255, 255), -1)

In [15]:
# Create a window for the canvas
cv2.namedWindow("Canvas", cv2.WINDOW_NORMAL)
canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)

# Define a flag to indicate when to capture the input
capturing = False

In [16]:
# Set the mouse callback function
cv2.setMouseCallback("Canvas", draw)

while True:
    cv2.imshow("Canvas", canvas)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cv2.destroyAllWindows()

1
0
0
8
5
5
0
8
5
7
0
3
4
8
0
0
0
6
9
1
6
5
9
6
8
2
7
7
0
0
2
1
5
5
6
3


KeyboardInterrupt: 