In [30]:
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [7]:
# Preprocessing applied to every MNIST sample: convert the image to a float
# tensor, then standardise its single channel.
# (0.1307 / 0.3081 look like the usual MNIST mean / std -- confirm if the data changes.)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [8]:
# Mini-batch size shared by the training and the test loader
batch_size = 64

# Training split (downloaded into ./data when missing), reshuffled on every pass
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split, always iterated in the same order
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [9]:

# Define the CNN model
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax`` on
    dim 1). The flatten step hard-codes 320 features per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional part; the second convolution is followed by channel-wise dropout
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Fully connected part: 320 -> 50 -> 10
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return per-class log-probabilities."""
        # conv -> 2x2 max-pool -> ReLU
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # conv -> channel dropout -> 2x2 max-pool -> ReLU
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))

        # Flatten to rows of 320 features for the linear layers
        flat = stage2.view(-1, 320)

        hidden = F.relu(self.fc1(flat))
        # Functional dropout has to be told explicitly whether we are training
        hidden = F.dropout(hidden, training=self.training)

        scores = self.fc2(hidden)
        return F.log_softmax(scores, dim=1)


In [10]:
# Loss: Net.forward ends in log_softmax, so the negative log-likelihood loss
# is applied directly to its output.
criterion = nn.NLLLoss()

# Freshly initialised network and the SGD-with-momentum optimiser that updates it
model = Net()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.5,
)

In [11]:
# Training loop (disabled: a saved checkpoint is loaded further down instead of training from scratch)
# epochs = 2
# for epoch in range(epochs):
#     train_loss = 0.0
#     for batch_idx, (data, target) in enumerate(train_loader):
#         optimizer.zero_grad()
#         output = model(data)
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
#         train_loss += loss.item()
#         if batch_idx % 100 == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                        100. * batch_idx / len(train_loader), loss.item()))
# 
#     # Evaluate the model on the test set
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             output = model(data)
#             test_loss += criterion(output, target).item()
#             pred = output.max(1, keepdim=True)[1]
#             correct += pred.eq(target.view_as(pred)).sum().item()
# 
#     test_loss /= len(test_loader.dataset)
#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))

In [12]:
# Save the model (disabled; re-enable after running the training loop to write a new checkpoint)
# torch.save(model.state_dict(), 'data/mnist_cnn.pt')

In [13]:
# Restore the trained weights and switch the network to inference mode.
# eval() matters: Net has two dropout layers, and leaving them active makes the
# predictions for a drawing vary randomly from one call to the next.
# map_location='cpu' keeps the load working on a CPU-only machine even if the
# checkpoint was written from a GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.eval()
model.load_state_dict(torch.load('data/mnist_cnn.pt', map_location='cpu'))

<All keys matched successfully>

Loading Canvas
----

In [14]:

from torchvision.transforms import transforms

In [15]:
# Define the canvas size (pixels)
canvas_width = 500
canvas_height = 500

# Define the preprocessing transforms applied to a drawing before inference.
# The normalisation statistics have to be the ones the network was trained
# with -- the MNIST loaders use Normalize((0.1307,), (0.3081,)). The previous
# (0.5, 0.5) put the inputs on a different scale than the one the checkpoint
# had learned.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [31]:

def load_dataset(save_dir="saved_drawings"):
    """
    Load the saved drawings together with their labels.

    Each ``<stem>.png`` in ``save_dir`` is paired with ``<stem>.txt`` (a file
    holding one integer label). Only complete, readable pairs are returned, so
    ``images[i]`` always belongs to ``labels[i]``; drawings without a label
    file, label files without a drawing and unreadable images are skipped.

    :param save_dir: directory containing the ``.png`` / ``.txt`` files
    :return: ``(images, labels)`` -- a list of grayscale images as returned by
        ``cv2.imread`` and a list of ints of the same length, ordered by file name
    :raises FileNotFoundError: if ``save_dir`` does not exist
    :raises ValueError: if a label file does not contain an integer
    """
    images = []
    labels = []
    for name in sorted(os.listdir(save_dir)):
        if not name.endswith(".png"):
            continue

        label_path = os.path.join(save_dir, name[:-len(".png")] + ".txt")
        if not os.path.isfile(label_path):
            continue  # unlabeled drawing: keeping it would shift every later label

        img = cv2.imread(os.path.join(save_dir, name), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # cv2.imread signals an unreadable file with None instead of raising

        # Distinct name for the handle: the original reused the loop variable here
        with open(label_path, "r") as label_file:
            label = int(label_file.read().strip())

        images.append(img)
        labels.append(label)
    return images, labels

In [32]:
def save_drawing_and_maybe_retrain(canvas, model, save_dir="saved_drawings", retrain_threshold=50):
    """
    Save ``canvas`` as the next numbered PNG in ``save_dir`` and retrain periodically.

    The file is named ``drawing_<k>.png`` where ``k`` starts at (number of
    PNGs already present + 1) and is advanced until the name is unused, so an
    existing drawing is never overwritten when earlier files were deleted.
    When the number of PNGs after saving is a multiple of
    ``retrain_threshold``, ``train_model(model, save_dir)`` is called.

    ``train_model`` is looked up when the threshold is reached, so it has to be
    defined by then (in this notebook it lives in the last cell).

    :param canvas: image to write; must be something ``cv2.imwrite`` accepts
        (a NumPy array), not a torch tensor
    :param model: network handed to ``train_model``
    :param save_dir: target directory, created if missing
    :param retrain_threshold: retrain every this many saved drawings;
        ``0``, ``None`` or a negative value disables retraining
    :return: None
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(save_dir, exist_ok=True)

    file_count = len([name for name in os.listdir(save_dir) if name.endswith('.png')])

    # Counting files is not enough to get a free name once files have been removed
    index = file_count + 1
    file_path = os.path.join(save_dir, f"drawing_{index}.png")
    while os.path.exists(file_path):
        index += 1
        file_path = os.path.join(save_dir, f"drawing_{index}.png")

    # cv2.imwrite reports failure through its return value
    if not cv2.imwrite(file_path, canvas):
        print(f"Could not save drawing to {file_path}")
        return

    # Check if it's time to retrain (guarding against a zero / disabled threshold)
    saved_count = file_count + 1
    if retrain_threshold and retrain_threshold > 0 and saved_count % retrain_threshold == 0:
        print("Initiating re-training...")
        train_model(model, save_dir)


In [33]:
import cv2


def process_image_for_model(canvas, invert=True):
    """
    Convert a drawing into a normalised single-channel 28x28 tensor.

    https://discuss.pytorch.org/t/image-processing-functions-on-pytorch-tensor/37203
    :param canvas: image as a NumPy array: grayscale ``(H, W)``, BGR
        ``(H, W, 3)`` or BGRA ``(H, W, 4)``
    :param invert: flip the intensities after the grayscale conversion. Keep
        the default for dark-on-light drawings; pass ``False`` when the
        drawing is already light-on-dark (e.g. white strokes on a black canvas)
    :return: float tensor of shape ``(1, 28, 28)``
    """
    # Resize the image to the size expected by the model (e.g., 28x28 for MNIST)
    processed_img = cv2.resize(canvas, (28, 28), interpolation=cv2.INTER_AREA)

    # Convert to grayscale if the canvas image is in color. A grayscale input
    # is 2-D after the resize, so the channel axis must not be indexed blindly
    # (the original `shape[2]` raised IndexError for it).
    if processed_img.ndim == 3:
        channels = processed_img.shape[2]
        if channels == 3:
            processed_img = cv2.cvtColor(processed_img, cv2.COLOR_BGR2GRAY)
        elif channels == 4:
            processed_img = cv2.cvtColor(processed_img, cv2.COLOR_BGRA2GRAY)

    # Invert the image colors (dark strokes on light become light strokes on dark)
    if invert:
        processed_img = cv2.bitwise_not(processed_img)

    # Convert the image to a PyTorch tensor and normalize with the statistics
    # used for the MNIST training data, Normalize((0.1307,), (0.3081,)).
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    return to_model_input(processed_img)

In [37]:
# Function to handle mouse events
def draw(event, x, y, flags, param):
    """
    OpenCV mouse callback for the drawing canvas.

    Pressing the left button starts a stroke on a blank canvas; moving with the
    button held paints white dots. Releasing the button classifies the
    drawing, writes the predicted digit back onto the canvas and saves the
    drawing through ``save_drawing_and_maybe_retrain``.

    Reads and rebinds the globals ``canvas`` and ``capturing``; uses the
    globals ``model`` and ``transform``.

    :param event: OpenCV mouse event code
    :param x: cursor column on the canvas
    :param y: cursor row on the canvas
    :param flags: unused (required by the callback signature)
    :param param: unused (required by the callback signature)
    """
    global canvas, capturing, model

    if event == cv2.EVENT_LBUTTONDOWN:
        # Start from a blank canvas: the prediction text written after the
        # previous stroke must not become part of the next digit.
        canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)
        capturing = True

    elif event == cv2.EVENT_LBUTTONUP:
        capturing = False

        # Shrink the drawing to the 28x28 single-channel image the model expects
        digit_img = cv2.resize(canvas, (28, 28), interpolation=cv2.INTER_AREA)
        digit_img = cv2.cvtColor(digit_img, cv2.COLOR_BGR2GRAY)

        # No inversion before inference: the canvas is already white strokes on
        # a black background, which is the polarity the model input is meant to
        # have. Inverting here fed the network the opposite.
        canvas_tensor = transform(digit_img).unsqueeze(0)

        # Perform inference with the CNN: eval() disables dropout so the
        # prediction is deterministic, no_grad() skips building the autograd graph.
        model.eval()
        with torch.no_grad():
            output = model(canvas_tensor)
        prediction = output.max(1, keepdim=True)[1]
        pred_text = f"{prediction.item()}"

        # Clear the canvas
        canvas = np.zeros((canvas_height, canvas_width, 3), np.uint8)

        # Calculate text size and position it in the center
        text_size = cv2.getTextSize(pred_text, cv2.FONT_HERSHEY_SIMPLEX, 2, 2)[0]
        text_x = (canvas_width - text_size[0]) // 2
        text_y = (canvas_height + text_size[1]) // 2

        # Write prediction in red back on the canvas
        cv2.putText(canvas, pred_text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX,
                    2, (0, 0, 255), 2, cv2.LINE_AA)

        # Save the drawing and possibly retrain. cv2.imwrite needs a uint8 NumPy
        # image (the original passed a normalised torch tensor, which cannot be
        # written). The file is stored dark-on-light because DrawingDataset
        # inverts every image when loading it. Done last so a failure while
        # saving or retraining does not hide the prediction.
        save_drawing_and_maybe_retrain(cv2.bitwise_not(digit_img), model)

    if capturing:
        cv2.circle(canvas, (x, y), 10, (255, 255, 255), -1)

In [38]:
# Blank (all-zero) 3-channel canvas that the mouse callback paints on
canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)

# Resizable window in which the canvas is shown
cv2.namedWindow("Canvas", cv2.WINDOW_NORMAL)

# Set to True by the mouse callback while the left button is held down
capturing = False

In [None]:
# Set the mouse callback function
cv2.setMouseCallback("Canvas", draw)

# Keep redrawing the canvas until 'q' is pressed. The finally clause closes the
# window even when the loop ends through an exception (for example when the
# kernel is interrupted), instead of leaving an orphaned window behind.
try:
    while True:
        cv2.imshow("Canvas", canvas)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cv2.destroyAllWindows()

In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
import cv2
import numpy as np


class DrawingDataset(Dataset):
    """
    Unlabeled dataset over the ``.png`` drawings stored in one folder.

    Each item is the image read as grayscale, inverted, converted to a tensor
    and normalised. Images are used at their stored size (no resizing), so
    they must already have the resolution the model expects.
    """

    # Normalisation statistics; the same values the MNIST loaders in this
    # notebook use, so fine-tuning sees inputs on the scale the model was
    # trained with (previously 0.5 / 0.5).
    MEAN = (0.1307,)
    STD = (0.3081,)

    def __init__(self, folder_path):
        """
        :param folder_path: directory that is scanned (non-recursively) for ``.png`` files
        :raises FileNotFoundError: if ``folder_path`` does not exist
        """
        self.folder_path = folder_path
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.MEAN, self.STD)
        ])
        # Sorted so that index -> file is stable; os.listdir order is arbitrary
        self.data = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))

    def __len__(self):
        """Number of drawings found when the dataset was created."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        :param idx: position in the sorted file list
        :return: the transformed image tensor
        :raises OSError: if the file cannot be read as an image
        """
        img_name = os.path.join(self.folder_path, self.data[idx])
        image = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread returns None instead of raising; fail with the file name
            raise OSError(f"Could not read image: {img_name}")
        image = cv2.bitwise_not(image)  # Invert: drawings are expected to be stored dark-on-light
        image = self.transform(image)
        return image


def train_model(model, folder_path, epochs=1, batch_size=10):
    """
    Fine-tune ``model`` on the drawings in ``folder_path`` using pseudo-labels.

    The drawings carry no labels, so for each batch the class the model
    currently predicts is used as the target.
    NOTE(review): this can only sharpen predictions the model already makes;
    correcting mistakes would need real labels.

    The model is always left in eval mode when the function returns, including
    when training raises.

    :param model: network to update in place; its output is passed to
        ``CrossEntropyLoss``
    :param folder_path: directory with the ``.png`` drawings
    :param epochs: number of passes over the drawings
    :param batch_size: drawings per optimisation step
    :return: None
    """
    dataset = DrawingDataset(folder_path)
    if len(dataset) == 0:
        # Nothing to learn from, and a shuffling DataLoader may reject an empty dataset
        print("No drawings found; skipping training.")
        return

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters())
    # CrossEntropyLoss applies log_softmax itself; on outputs that are already
    # log-probabilities that is a no-op, so it works for either kind of model.
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    try:
        for _ in range(epochs):
            for data in dataloader:
                optimizer.zero_grad()
                outputs = model(data)
                # Pseudo-labels: the currently predicted class, detached so the
                # targets are plain constants for the loss
                pseudo_labels = torch.max(outputs.detach(), 1)[1]
                loss = criterion(outputs, pseudo_labels)
                loss.backward()
                optimizer.step()
        print("Training complete.")
    finally:
        # Back to inference mode even if a batch failed
        model.eval()

