In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:
# Define the transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [5]:
# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:

# Define the CNN model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


In [8]:
# Initialize the model, loss function, and optimizer
model = Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [9]:
# Training loop
epochs = 20
for epoch in range(epochs):
    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # Evaluate the model on the test set
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.0068, Accuracy: 8678/10000 (87%)


Test set: Average loss: 0.0048, Accuracy: 9099/10000 (91%)


Test set: Average loss: 0.0040, Accuracy: 9186/10000 (92%)


Test set: Average loss: 0.0037, Accuracy: 9289/10000 (93%)


Test set: Average loss: 0.0033, Accuracy: 9366/10000 (94%)


Test set: Average loss: 0.0031, Accuracy: 9407/10000 (94%)


Test set: Average loss: 0.0028, Accuracy: 9473/10000 (95%)


Test set: Average loss: 0.0026, Accuracy: 9499/10000 (95%)


Test set: Average loss: 0.0027, Accuracy: 9469/10000 (95%)


Test set: Average loss: 0.0026, Accuracy: 9500/10000 (95%)


Test set: Average loss: 0.0025, Accuracy: 9505/10000 (95%)


Test set: Average loss: 0.0023, Accuracy: 9564/10000 (96%)


Test set: Average loss: 0.0023, Accuracy: 9569/10000 (96%)


Test set: Average loss: 0.0023, Accuracy: 9573/10000 (96%)


Test set: Average loss: 0.0024, Accuracy: 9565/10000 (96%)


Test set: Average loss: 0.0022, Accuracy: 9574/10000 (96%)


Test set: Average loss:

In [10]:
# Save the model
torch.save(model.state_dict(), 'data/mnist_cnn.pt')

In [11]:
model.load_state_dict(torch.load('data/mnist_cnn.pt'))

<All keys matched successfully>

Loading Canvas
----

In [12]:
import cv2
import numpy as np
import torch
from torchvision import transforms

In [13]:
# Define the canvas size
canvas_width = 500
canvas_height = 500
# Define the preprocessing transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [None]:
# Create a window for the canvas
cv2.namedWindow("Canvas", cv2.WINDOW_NORMAL)
canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)

# Define a flag to indicate when to capture the input
capturing = False

# Function to handle mouse events
def draw(event, x, y, flags, param):
    global canvas, capturing

    if event == cv2.EVENT_LBUTTONDOWN:
        capturing = True

    elif event == cv2.EVENT_LBUTTONUP:
        capturing = False
        # Capture the canvas and preprocess it
        canvas_img = cv2.resize(canvas, (28, 28), interpolation=cv2.INTER_AREA)
        canvas_img = cv2.bitwise_not(canvas_img)
        canvas_tensor = transform(canvas_img)
        canvas_tensor = canvas_tensor.unsqueeze(0)

        # Perform inference with your CNN model
        output = model(canvas_tensor)
        prediction = output.max(1, keepdim=True)[1]
        print(f"Predicted: {prediction.item()}")

        # Clear the canvas
        canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)

    elif event == cv2.EVENT_MOUSEMOVE:
        if capturing:
            cv2.circle(canvas, (x, y), 10, (255, 255, 255), -1)

# Set the mouse callback function
cv2.setMouseCallback("Canvas", draw)

# Load your trained CNN model
model.load_state_dict(torch.load('data/mnist_cnn.pt'))
model.eval()

while True:
    cv2.imshow("Canvas", canvas)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cv2.destroyAllWindows()

Predicted: 7
Predicted: 0
Predicted: 0




Predicted: 5
Predicted: 3
Predicted: 7
Predicted: 1
Predicted: 1
Predicted: 1
Predicted: 2
Predicted: 2
Predicted: 3
