In [25]:
import tkinter as tk
from tkinter import Canvas, Button
from PIL import Image, ImageDraw, ImageOps, ImageTk
import os
import random
import numpy as np
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import cv2

In [18]:
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier for single-channel images.

    Every stage is a 3x3 convolution with padding 1 (spatial size is kept),
    a ReLU, and a 2x2 max-pool that halves height and width. The flattened
    feature map then goes through dropout and two fully connected layers.

    Args:
        num_classes: Number of logits produced by the final layer.
        input_size: Spatial size of the images passed to ``forward``. Either
            one number for square images or a ``(height, width)`` pair.
            The default of 240 yields the original 28800 flattened features
            (128 channels * 15 * 15), so existing checkpoints keep loading.

    Raises:
        ValueError: If ``input_size`` is so small that four rounds of 2x2
            pooling leave no pixels (any side below 16).
    """

    # Four 2x2 max-pools in ``forward`` shrink each side by 2 ** 4.
    _POOL_REDUCTION = 16

    def __init__(self, num_classes=5, input_size=240):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.5)

        # Accept either a single side length or an explicit (height, width).
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size
        # MaxPool2d(2, 2) floors odd sizes, so four pools equal ``// 16``.
        pooled_h = int(height) // self._POOL_REDUCTION
        pooled_w = int(width) // self._POOL_REDUCTION
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small: each side must be "
                f"at least {self._POOL_REDUCTION} pixels"
            )

        self.fc1 = nn.Linear(self.conv4.out_channels * pooled_h * pooled_w, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, 1, H, W)`` where ``(H, W)`` matches the
                ``input_size`` given at construction.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding raw logits; no
            softmax is applied.
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))
        x = torch.flatten(x, 1)  # keep the batch axis, merge channel/height/width
        x = self.dropout(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [27]:
# Step 1: Load the saved weights into a fresh model.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('model_k2_fold0.pth', map_location=device)
model = SimpleCNN()
model.load_state_dict(state_dict)
model.to(device)  # run on the same device the checkpoint was mapped to
model.eval()      # inference mode: disables dropout

# Step 2: Load the input image.
image_path = 'shape.png'
# The first convolution takes exactly one channel, so force 8-bit grayscale
# no matter how the PNG was stored (RGB, RGBA, palette, ...).
image = Image.open(image_path).convert('L')

# Per-image mean and (population) standard deviation on the [0, 1] scale,
# i.e. the scale ToTensor produces. Computed from the already-decoded image
# so the file is not read a second time.
img = np.asarray(image, dtype=np.float32) / 255.0
mean = float(img.mean())
std = float(img.std())
if std == 0:
    # Normalize would divide by zero; fail with an explicit message instead.
    raise ValueError(
        f"{image_path!r} has a single uniform intensity; "
        "cannot normalise with a zero standard deviation"
    )

# Step 3: Preprocess the image into a model-ready tensor.
transform = transforms.Compose([
    transforms.Resize((240, 240)),                # match the model input size
    transforms.ToTensor(),                        # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[mean], std=[std])  # one value per channel
])

input_image = transform(image)
input_image = input_image.unsqueeze(0).to(device)  # add batch dimension

# Step 4: Forward pass without tracking gradients.
with torch.no_grad():
    output = model(input_image)

# Step 5: The index of the largest logit is the predicted class id.
predicted_class = output.argmax(dim=1)
labels_map = {
    "circle": 0,
    "square": 1,
    "triangle": 2,
    "pentagon": 3,
    "hexagon": 4,
}
reverse_labels_map = {value: key for key, value in labels_map.items()}
predicted_label = reverse_labels_map[predicted_class.item()]

print("Predicted class:", predicted_label)

Predicted class: triangle


In [29]:
# Shape names listed in class-id order: a name's position is its class id.
labels_map = {
    shape_name: class_id
    for class_id, shape_name in enumerate(
        ("circle", "square", "triangle", "pentagon", "hexagon")
    )
}
# Inverse lookup (class id -> shape name) used to turn a prediction into text.
reverse_labels_map = dict(zip(labels_map.values(), labels_map.keys()))

class DrawingApp():
    """Tkinter sketch pad that classifies the drawn shape with ``SimpleCNN``.

    Every stroke is drawn twice: on the visible Tk canvas and on an
    off-screen PIL image of the same size. The PIL copy is what gets saved
    and fed to the network when "Predict" is pressed.

    Args:
        master: Parent Tk widget hosting the canvas, button and label.
        width: Canvas and image width in pixels.
        height: Canvas and image height in pixels.
        model_path: Checkpoint (state dict) loaded on the first prediction.
        save_path: File the drawing is written to on every prediction.
    """

    # Size every drawing is resized to before it is passed to the network.
    MODEL_INPUT_SIZE = (240, 240)

    def __init__(self, master, width=240, height=240,
                 model_path="model_k2_fold0.pth", save_path="your_shape.png"):

        self.mode = 0
        self.label = "circle"
        self.master = master
        self.width = width
        self.height = height
        self.model_path = model_path
        self.save_path = save_path
        self._model = None  # built lazily by _get_model() and then reused

        self.canvas = tk.Canvas(master, width=self.width, height=self.height, bg="white")
        self.canvas.pack()

        self.canvas.bind("<B1-Motion>", self.draw)

        self.button_predict = tk.Button(master, text="Predict", command=self.predict, bg="cyan")
        self.button_predict.pack()

        self.label_text = tk.StringVar()
        self.label_text.set("Draw Shape")  # Initial label text
        self.label_widget = tk.Label(master, textvariable=self.label_text)
        self.label_widget.pack()

        self.new_canvas()  # Create a new drawing canvas

    def draw(self, event):
        """Paint a small black dot at the pointer position.

        Bound to ``<B1-Motion>``; the dot goes onto both the Tk canvas and
        the off-screen PIL image so the two stay in sync.
        """
        x, y = event.x, event.y
        r = 3
        box = [x - r, y - r, x + r, y + r]
        self.canvas.create_oval(*box, fill="black")
        # The PIL drawing handle lives in ``_pen`` so that it does not shadow
        # this method on the instance.
        self._pen.ellipse(box, fill="black")

    def predict(self):
        """Classify the current drawing, show the result and clear the canvas.

        The drawing is saved to ``save_path``, normalised with its own mean
        and standard deviation, and passed through the cached model. If
        nothing has been drawn the label is reset and no prediction is made.
        """
        self.image.save(self.save_path)

        # Use the in-memory image rather than re-reading the file just saved.
        gray = self.image.convert('L')

        # Per-image statistics on the [0, 1] scale that ToTensor produces.
        pixels = np.asarray(gray, dtype=np.float32) / 255.0
        mean = float(pixels.mean())
        std = float(pixels.std())
        if std < 1e-6:
            # Blank (uniform) canvas: Normalize would divide by zero and there
            # is nothing meaningful to classify.
            self.label_text.set("Draw Shape")
            return

        transform = transforms.Compose([
            transforms.Resize(self.MODEL_INPUT_SIZE),     # match model input size
            transforms.ToTensor(),                        # PIL image -> float tensor
            transforms.Normalize(mean=[mean], std=[std])  # one value per channel
        ])

        input_image = transform(gray).unsqueeze(0).to(device)  # add batch dimension

        with torch.no_grad():
            output = self._get_model()(input_image)

        # The index of the largest logit is the predicted class id.
        predicted_class = output.argmax(dim=1)
        predicted_label = reverse_labels_map[predicted_class.item()]

        print("Predicted class:", predicted_label)
        self.label_text.set(predicted_label)

        self.new_canvas()  # Create a new drawing canvas

    def _get_model(self):
        """Return the classifier, loading the checkpoint on first use only."""
        if self._model is None:
            # NOTE: torch.load unpickles the file; only point model_path at
            # checkpoints you trust.
            state_dict = torch.load(self.model_path, map_location=device)
            model = SimpleCNN()  # Instantiate your model class
            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()  # inference mode: disables dropout
            self._model = model
        return self._model

    def new_canvas(self):
        """Clear the visible canvas and start a fresh white off-screen image."""
        self.canvas.delete("all")  # Clear the canvas
        self.image = Image.new("RGB", (self.width, self.height), "white")
        self._pen = ImageDraw.Draw(self.image)
    


def main():
    """Open the "Draw Shapes" window and run the Tk event loop until it closes."""
    window = tk.Tk()
    window.title("Draw Shapes")
    # Keep the app object referenced for as long as the event loop runs.
    app = DrawingApp(window)
    window.mainloop()


# Launch the GUI when this code is executed directly.
if __name__ == "__main__":
    main()

Predicted class: square
Predicted class: circle
Predicted class: circle
Predicted class: triangle
Predicted class: square
Predicted class: pentagon
Predicted class: pentagon
Predicted class: pentagon
Predicted class: hexagon
Predicted class: circle
Predicted class: circle
Predicted class: circle
Predicted class: triangle
Predicted class: circle
Predicted class: triangle
Predicted class: triangle
Predicted class: square
Predicted class: triangle
Predicted class: circle
Predicted class: circle
Predicted class: hexagon
