In [1]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SubsetRandomSampler
from PIL import Image
import cv2

## Model Architecture 

In [2]:
class Net(nn.Module):
    """Small LeNet-style CNN that maps a 1x28x28 image to 5 class scores.

    Layer attribute names (conv1, pool, conv2, fc1, fc2, fc3) must stay as-is:
    the trained weights are restored with ``load_state_dict`` by name.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 6 maps; a 5x5 kernel shrinks 28x28 to 24x24.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        # Shared 2x2 max-pool, used after each conv: halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 6 -> 16 maps; a 5x5 kernel shrinks 12x12 to 8x8 (then pooled to 4x4).
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head over the flattened 16 x 4 x 4 = 256 features.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        """Return raw (unnormalised) class scores of shape (batch, 5)."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten every sample to a 256-long vector, keeping the batch dimension.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

## Loading Trained Model

The model is instantiated and its trained weights are then loaded from the saved path.

In [3]:
PATH = "image_classifier.pth"
model = Net()
# map_location="cpu": the checkpoint may have been saved from GPU tensors; mapping it
# to the CPU lets it load on machines without CUDA (inference here runs on the CPU).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Switch to inference mode; the returned module is displayed as the cell output.
model.eval()

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=5, bias=True)
)

As in the training process, the image is prepared (converted to a tensor) and wrapped in a DataLoader.

In [4]:
# Inference only needs tensor conversion: ToTensor turns the PIL image built by
# MSTAR_test into a tensor with a leading channel dimension.
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
class MSTAR_test(Dataset):
    """Map-style dataset over an in-memory, indexable collection of images.

    Args:
        data: sequence/array of images; ``data[i]`` is one image.
        transform: optional callable. When set, each image is first wrapped
            with ``Image.fromarray`` and the transform is applied to the result.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Samplers may pass a 0-d tensor as the index; turn it into a plain int.
        if torch.is_tensor(index):
            index = int(index.item())

        sample = self.data[index]
        if not self.transform:
            return sample
        return self.transform(Image.fromarray(sample))

## Function to Classify a Single Image
This function reshapes and prepares the user-uploaded image as required, wraps it in a DataLoader, passes it through the model and returns the predicted class name.

In [6]:
def predict(img):
    """Classify a single image and return the predicted shape name.

    Args:
        img: numpy array whose first dimension is the number of images and whose
            remaining elements reshape to 28x28 (the GUI callers pass a float32
            array of shape (1, 28, 28)). Exactly one image is expected: the
            final ``.item()`` call fails when more than one prediction is made.

    Returns:
        One of 'Circle', 'Diamond', 'Square', 'Star', 'Triangle'.
    """
    classes = ['Circle', 'Diamond', 'Square', 'Star', 'Triangle']

    img_cnn = img.reshape(img.shape[0], 28, 28)
    dataset = MSTAR_test(img_cnn, transform=transform)
    # A single batch holding every image, so one next() yields all of them.
    loader = DataLoader(dataset, batch_size=len(dataset))
    batch = next(iter(loader))

    # Pure inference: skip autograd bookkeeping (outputs are identical).
    with torch.no_grad():
        outputs = model(batch)
    _, y_pred = torch.max(outputs, 1)

    return classes[y_pred.item()]

# GUIs
The code below builds the GUIs: a drawing canvas window and the main window used to upload, draw and classify images.

In [7]:
from tkinter import ttk, colorchooser
from PIL import Image, ImageGrab
from tkinter import messagebox

class canvas:
    """Tkinter drawing pad whose contents can be saved to disk and classified.

    Builds a pen-width slider, "Save Image" / "Classify Image" buttons and a
    Colors/Options menu bar inside ``master``, plus a canvas the user draws on
    by dragging with the left mouse button.
    """

    def __init__(self, master):
        self.master = master
        self.color_fg = 'black'
        self.color_bg = 'white'
        self.old_x = None
        self.old_y = None
        self.penwidth = 3
        # File written by the most recent capture(); classify() reads it.
        self.saved_path = None
        # Result labels are created on the first classify() and reused afterwards.
        self.result_caption = None
        self.result_value = None
        self.drawWidgets()
        self.c.bind('<B1-Motion>', self.paint)        # draw while dragging
        self.c.bind('<ButtonRelease-1>', self.reset)  # lift the pen on release

    def capture(self):
        """Screenshot the drawing area, save it to 'my_drawing.png' and remember the path."""
        # winfo_rootx/rooty give the canvas' own top-left corner in screen coordinates,
        # whichever window it lives in (the main window's position is irrelevant here).
        x0 = self.c.winfo_rootx()
        y0 = self.c.winfo_rooty()
        x1 = x0 + self.c.winfo_width()
        y1 = y0 + self.c.winfo_height()
        # Also publish the path globally so the main window's classifier can use it.
        global path
        path = 'my_drawing.png'  # dummy path while managing to save image
        ImageGrab.grab(bbox=(x0, y0, x1, y1)).save(path)
        self.saved_path = path
        messagebox.showinfo("Information", "Image Saved Sucessfully", parent=self.master)

    def classify(self):
        """Classify the most recently saved drawing and show the result in the controls."""
        if self.saved_path is None:
            messagebox.showwarning("Information", "Save the drawing before classifying it.",
                                   parent=self.master)
            return
        img = cv2.imread(self.saved_path, 0)  # flag 0: load as single-channel grayscale
        if img is None:
            messagebox.showerror("Error", "Could not read the saved drawing: " + self.saved_path,
                                 parent=self.master)
            return
        img = cv2.resize(img, (28, 28))
        # NOTE(review): img is uint8 here, so `img - 255` wraps modulo 256 (it equals
        # (img + 1) % 256); it does NOT invert the image (that would be `255 - img`).
        # Kept unchanged because it must match the preprocessing used in training —
        # TODO confirm against the training code.
        arr = np.array(img - 255)
        arr = np.array(arr / 255.)
        img = arr.reshape(1, 28, 28).astype('float32')

        x = predict(img)
        if self.result_value is None:
            self.result_caption = Label(self.controls, text="This image is a: ", font=("Helvetica", 14))
            self.result_caption.grid(row=4, column=0)
            self.result_value = Label(self.controls, fg='blue', font=("Helvetica", 16))
            self.result_value.grid(row=4, column=1)
        self.result_value.config(text=x)

    def paint(self, e):
        """Extend the current stroke from the previous mouse position to event ``e``."""
        # Compare against None: 0 is a valid canvas coordinate.
        if self.old_x is not None and self.old_y is not None:
            # 'round' is the value of tkinter's ROUND constant.
            self.c.create_line(self.old_x, self.old_y, e.x, e.y,
                               width=self.penwidth, fill=self.color_fg,
                               capstyle='round', smooth=True)

        self.old_x = e.x
        self.old_y = e.y

    def reset(self, e):
        """End the current stroke so the next drag starts a new line."""
        self.old_x = None
        self.old_y = None

    def changeW(self, e):
        """Slider callback: set the pen width from the slider value (may arrive as a string)."""
        self.penwidth = float(e)

    def clear(self):
        """Remove everything drawn on the canvas."""
        self.c.delete(ALL)

    def change_fg(self):
        """Ask for a new pen color; keep the current one if the dialog is cancelled."""
        chosen = colorchooser.askcolor(color=self.color_fg, parent=self.master)[1]
        if chosen is not None:
            self.color_fg = chosen

    def change_bg(self):
        """Ask for a new canvas background color; keep the current one if cancelled."""
        chosen = colorchooser.askcolor(color=self.color_bg, parent=self.master)[1]
        if chosen is not None:
            self.color_bg = chosen
            self.c['bg'] = self.color_bg

    def drawWidgets(self):
        """Create the control panel, the drawing canvas and the menu bar in ``self.master``."""
        self.controls = Frame(self.master, padx=5, pady=5)
        Label(self.controls, text='Pen Width:', font=("Helvetica", 12)).grid(row=0, column=0)

        self.slider = ttk.Scale(self.controls, from_=3, to=100, command=self.changeW, orient=VERTICAL)
        self.slider.set(self.penwidth)
        self.slider.grid(row=0, column=1, ipadx=30)
        self.controls.pack(side=LEFT)

        self.c = Canvas(self.master, width=400, height=300, bg=self.color_bg)
        self.c.pack(fill=BOTH, expand=True)
        Button(self.controls, text='Save Image', command=self.capture).grid(row=1, column=2)
        Button(self.controls, text='Classify Image', command=self.classify).grid(row=1, column=3)

        menu = Menu(self.master)
        self.master.config(menu=menu)
        colormenu = Menu(menu)
        menu.add_cascade(label='Colors', menu=colormenu)
        colormenu.add_command(label='Brush Color', command=self.change_fg)
        colormenu.add_command(label='Background Color', command=self.change_bg)
        optionmenu = Menu(menu)
        menu.add_cascade(label='Options', menu=optionmenu)
        optionmenu.add_command(label='Clear Canvas', command=self.clear)
        optionmenu.add_command(label='Exit', command=self.master.destroy)

In [11]:
from tkinter import*
from tkinter import filedialog
import tkinter as tk
from PIL import Image, ImageTk

def draw():
    """Open the drawing-pad window as a child of the main ``root`` window.

    Called from the main window's "Draw Image" button, i.e. while
    ``root.mainloop()`` is already running, so no extra event loop is started.
    """
    # A Toplevel shares root's Tcl interpreter and event loop; a second Tk()
    # would create an independent interpreter plus a nested mainloop.
    window = Toplevel(root)
    canvas(window)
    window.title('Drawing Classifier')


def select_image():
    """Let the user pick an image file, preview it at 200x200 and remember its path.

    Updates the module-level ``path`` (read by classify_image) and ``panelA``
    (the preview label), and clears the previous result shown in ``lbl2``.
    Cancelling the dialog, or choosing a file OpenCV cannot read, leaves the
    previous selection untouched.
    """
    global panelA
    global path
    selected = filedialog.askopenfilename()
    # An empty result means the dialog was cancelled.
    if not selected:
        return

    image = cv2.imread(selected)  # BGR order; None if the file is not a readable image
    if image is None:
        messagebox.showerror("Error", "Could not read the selected file as an image.")
        return

    path = selected
    lbl2.config(text=" ", fg='black', font=("Helvetica", 16))

    image = cv2.resize(image, (200, 200))
    # OpenCV loads BGR; PIL expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = ImageTk.PhotoImage(Image.fromarray(image))

    if panelA is None:
        panelA = Label(image=image)
        panelA.image = image  # keep a Python reference to the PhotoImage alongside the label
        panelA.pack(side="left", padx=10, pady=10)
    else:
        panelA.configure(image=image)
        panelA.image = image

def classify_image():
    """Classify the image at the module-level ``path`` and show the result in the main window.

    Warns instead of crashing when no image has been uploaded/saved yet or the
    file cannot be read.
    """
    if not path:
        messagebox.showwarning("Information", "Upload or save an image before classifying.")
        return
    img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
    if img is None:
        messagebox.showerror("Error", "Could not read the image at: " + str(path))
        return
    img = cv2.resize(img, (28, 28))
    # NOTE(review): img is uint8 here, so `img - 255` wraps modulo 256 (it equals
    # (img + 1) % 256); it does NOT invert the image (that would be `255 - img`).
    # Kept unchanged because it must match the preprocessing used in training —
    # TODO confirm against the training code.
    arr = np.array(img - 255)
    arr = np.array(arr / 255.)
    img = arr.reshape(1, 28, 28).astype('float32')

    x = predict(img)

    # The caption never changes: create it on the first call and reuse it afterwards
    # instead of stacking a new Label on every click.
    lbl = getattr(classify_image, "_caption", None)
    if lbl is None:
        lbl = Label(text="This image is a: ", fg='black', font=("Helvetica", 16))
        classify_image._caption = lbl
    lbl2.config(text=x, fg='blue', font=("Helvetica", 20))
    lbl.place(x=230, y=120)
    lbl2.place(x=245, y=150)
         

# Main window. The callbacks defined above read `panelA`, `path` and `lbl2`
# as module-level globals, so these names must exist before the buttons are used.
root = Tk()
frm = Frame(root)
frm.pack(side=BOTTOM, padx=15, pady=15)  # button bar along the bottom edge
panelA = None  # image preview label; created by select_image() on the first upload
path = None    # image to classify; set by select_image() or canvas.capture()
lbl2=Label()   # shows the predicted class; configured by select_image()/classify_image()

btn1 = Button(frm, text="Upload Image", command=select_image)
btn1.pack(side=tk.LEFT)

btn2 = Button(frm, text='Draw Image', command=draw)
btn2.pack(side=tk.LEFT, padx=10)

btn = Button(frm, text='Classify Image', command=classify_image)
btn.pack(side=tk.LEFT, padx=10)

# NOTE: `btn` is rebound here; afterwards only the Exit button is reachable by that name.
btn = Button(frm, text='Exit', command=root.destroy)
btn.pack(side=tk.LEFT, padx=20)

root.title("Image Classifier")
root.geometry("450x300")
root.mainloop()  # blocks this cell until the main window is closed