# Building a Manual Classifier

In this section, we build a graphical interface that lets us quickly label images to use as training data. The user labels images manually, one by one, until there are enough labelled examples to try training a model.

In [None]:
# Set-up for a function that shows N random images and asks the user to assign the corresponding digit label to each one.
# !pip install PyQt5
%matplotlib qt

import numpy as np
import matplotlib.pyplot as plt
from random import randint
from PIL import Image
import os
from skimage.transform import resize
from sklearn.svm import SVC


In [None]:
# The drawings are pretty big.  Let's crop away the white space and then resize them to 12x12 pixels.
from skimage.transform import resize

def findGroup(vec, maxGap=1):
    """Locate the longest group of positive elements in a 1-D sequence.

    Positive elements that are separated by at most ``maxGap`` non-positive
    elements are treated as belonging to the same group (the default of 1
    keeps a single blank row/column from splitting a group in two).

    Parameters
    ----------
    vec : sequence
        1-D sequence of numbers or booleans; an element counts as "set"
        when ``element > 0``.
    maxGap : int, optional
        Largest number of consecutive non-positive elements that may occur
        inside a group without ending it.

    Returns
    -------
    (start, end) : tuple of int
        Half-open index range such that ``vec[start:end]`` is the longest
        group, beginning and ending on a positive element. The first such
        group wins on ties. ``(0, 0)`` is returned when ``vec`` has no
        positive element.
    """
    bestStart, bestEnd = 0, 0
    groupStart = None      # index where the group currently being scanned began
    lastPositive = None    # index of the most recent positive element
    for i, value in enumerate(vec):
        if not value > 0:
            continue
        # Start a new group on the first hit, or when the stretch of
        # non-positive elements since the previous hit is too long to bridge.
        if lastPositive is None or i - lastPositive - 1 > maxGap:
            groupStart = i
        lastPositive = i
        # Strict comparison so the earliest of several equally long groups wins.
        if i + 1 - groupStart > bestEnd - bestStart:
            bestStart, bestEnd = groupStart, i + 1
    return bestStart, bestEnd
def cropAndResize(imgs, size=(12, 12), threshold=200):
    """Crop each image to its drawn content, pad it to a square and resize it.

    Parameters
    ----------
    imgs : sequence of array-like
        Images as 2-D (greyscale) or 3-D (rows, cols, channels) arrays.
    size : tuple of int, optional
        Output shape passed to ``resize``.
    threshold : number, optional
        A row or column counts as containing ink when its darkest
        (channel-averaged) pixel is below this value.

    Returns
    -------
    list
        One resized 2-D array per input image, in the same order.
    """
    imgsOut = []
    for img in imgs:
        # Average the colour channels into a single intensity plane;
        # greyscale input is already 2-D and is used as-is.
        gray = np.asarray(img, dtype=float)
        if gray.ndim == 3:
            gray = gray.mean(axis=2)

        # Find the largest band of rows, then columns, that contain ink.
        rows = np.min(gray, axis=1) < threshold
        min0, max0 = findGroup(rows)

        cols = np.min(gray, axis=0) < threshold
        min1, max1 = findGroup(cols)

        # A blank image produces an empty range, and resize() cannot handle
        # an empty array, so in that case the uncropped image is kept.
        if max0 > min0 and max1 > min1:
            gray = gray[min0:max0, min1:max1]

        # Pad with white on the left and right until the crop is square.
        # Any odd leftover column goes on the right.
        extra = gray.shape[0] - gray.shape[1]
        if extra > 0:
            left = extra // 2
            gray = np.pad(gray, ((0, 0), (left, extra - left)),
                          mode='constant', constant_values=255)

        imgsOut.append(resize(gray, size, anti_aliasing=True))

    return imgsOut

In [None]:
Labels = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
nLabels = len(Labels)

folder = 'handDigits/images'
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# image index stable between runs. Hidden files and sub-directories are
# skipped because they cannot be opened as images.
imageNames = sorted(
    name for name in os.listdir(folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
)
nImages = len(imageNames)

# Load every image as a numpy array; the context manager closes the
# underlying file as soon as the pixel data has been copied out.
allDigitList = []
for name in imageNames:
    with Image.open(os.path.join(folder, name)) as img:
        allDigitList.append(np.array(img))

# One label slot per image; None means "not labelled yet".
allImageLabels = [None] * nImages

# Crop and resize images
allDigitListCAR = cropAndResize(allDigitList)

# Per-image slots that are filled in as images get labelled for training.
trainingDigitList = [None] * nImages
trainingLabelList = [None] * nImages

In [None]:
def getLabels(nImg2Show=25, labels=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], allImageLabels=allImageLabels,
              allDigitList=allDigitList, allDigitListCAR=allDigitListCAR, trainingDigitList=trainingDigitList,
              trainingLabelList=trainingLabelList):
    """Show a random sample of images in a grid and label them with mouse clicks.

    A left click on an image advances its label to the next entry of
    ``labels`` (wrapping around after the last one). A right click closes the
    figure. Images that already have a label in ``allImageLabels`` are shown
    with that label and are copied into the training lists immediately.

    Parameters
    ----------
    nImg2Show : int
        Number of images to sample without replacement; capped at
        ``len(allDigitList)``.
    labels : sequence of str
        Labels to cycle through, in click order.
    allImageLabels : mutable sequence
        Current label (or None) for every image; updated in place.
    allDigitList : sequence
        Images that are displayed in the grid.
    allDigitListCAR : sequence
        Images, indexed like ``allDigitList``, that are copied into
        ``trainingDigitList`` when the corresponding image receives a label.
    trainingDigitList, trainingLabelList : list
        Per-image slots for training images and labels; updated in place.

    Returns
    -------
    None
    """
    # Sampling without replacement fails when asked for more images than exist.
    nImg2Show = min(int(nImg2Show), len(allDigitList))
    if nImg2Show <= 0:
        print("No images available to label.")
        return

    # Make a square image grid large enough for the sample
    nRows = int(np.ceil(np.sqrt(nImg2Show)))
    nCols = nRows

    # squeeze=False guarantees a 2-D array of axes even for a 1x1 grid
    fig, axs = plt.subplots(nRows, nCols, figsize=(10, 8), facecolor='white', squeeze=False)
    fig.suptitle('Left click: increment label | Right click: quit')
    axs_flat = axs.ravel()

    # imageID -- indexes of the sampled images (no repeats)
    imageID = np.random.choice(len(allDigitList), nImg2Show, replace=False)

    for k in range(nImg2Show):
        axs_flat[k].imshow(allDigitList[imageID[k]], cmap='gray')
        axs_flat[k].axis('off')

    # Text annotation currently drawn on each subplot, and the index into
    # `labels` currently assigned to each displayed image (None = unlabelled).
    highlight_texts = {}
    image_label_indexes = {}

    def highlight_selection(index, label=None, selected=True):
        """Replace the label text drawn on subplot `index` (or just clear it)."""
        if index < len(axs_flat):
            ax = axs_flat[index]

            # Remove existing annotation if present
            if index in highlight_texts:
                highlight_texts[index].remove()
                del highlight_texts[index]

            if selected and label is not None:
                # Add a label text in the top-left corner
                text = ax.text(0.1, 0.9, label, color='red', fontsize=20,
                               transform=ax.transAxes, ha='center', va='center',
                               bbox=dict(facecolor='white', alpha=0.7, pad=0))
                highlight_texts[index] = text

            fig.canvas.draw_idle()

    # Hide unused subplots and show any label an image already carries.
    for k in range(len(axs_flat)):
        if k >= nImg2Show:
            axs_flat[k].axis('off')
            continue

        existing = allImageLabels[imageID[k]]
        if existing is not None:
            # Unknown pre-existing labels fall back to the first label.
            label_idx = labels.index(existing) if existing in labels else 0
            image_label_indexes[k] = label_idx
            highlight_selection(k, label=labels[label_idx], selected=True)
            # Store the label itself (not its index) so later lookups with
            # `in labels` keep working.
            allImageLabels[imageID[k]] = labels[label_idx]
            trainingDigitList[imageID[k]] = allDigitListCAR[imageID[k]]
            trainingLabelList[imageID[k]] = labels[label_idx]
        else:
            # Start with no label (None)
            image_label_indexes[k] = None
            highlight_selection(k, label=None, selected=False)

    plt.tight_layout()

    def on_click(event):
        """Handle a mouse press: right button quits, left button cycles a label."""
        # Right click closes the figure wherever it lands, as the on-screen
        # instructions promise. Labels are stored on every left click, so
        # nothing needs saving here.
        if event.button == 3:
            plt.close(fig)
            return

        if event.button != 1 or event.inaxes is None:
            return

        # Find which image subplot was clicked (hidden extras are ignored)
        clicked_idx = -1
        for idx in range(nImg2Show):
            if event.inaxes is axs_flat[idx]:
                clicked_idx = idx
                break
        if clicked_idx < 0:
            return

        current_label_idx = image_label_indexes.get(clicked_idx)
        if current_label_idx is None:
            # First time being labeled
            new_label_idx = 0
        else:
            # Increment label, cycling back to beginning if needed
            new_label_idx = (current_label_idx + 1) % len(labels)

        image_label_indexes[clicked_idx] = new_label_idx
        highlight_selection(clicked_idx, label=labels[new_label_idx], selected=True)

        # Record the new label and add the image to the training set
        allImageLabels[imageID[clicked_idx]] = labels[new_label_idx]
        trainingDigitList[imageID[clicked_idx]] = allDigitListCAR[imageID[clicked_idx]]
        trainingLabelList[imageID[clicked_idx]] = labels[new_label_idx]

    # Connect the event handler
    fig.canvas.mpl_connect('button_press_event', on_click)

    # Show instructions
    print("Click on an image to increment its label. Right-click anywhere to finish.")
    print(f"Available labels: {', '.join(labels)}")

    # NOTE(review): with an interactive backend plt.show() may return before
    # the user has clicked anything -- confirm callers do not rely on it blocking.
    plt.show()

    return

In [None]:
# Try the labelling tool on a small batch of four images.
nImg2Show = 4
getLabels(nImg2Show=nImg2Show, labels=Labels)

In [None]:
# Keep only the slots that have been filled in by labelling.
trainingDigitListCAR = [img for img in trainingDigitList if img is not None]
trainingLabelListCAR = [label for label in trainingLabelList if label is not None]

if not trainingDigitListCAR:
    raise RuntimeError("No images have been labelled yet; run getLabels() and label some images first.")

# One row per labelled image, each row being the flattened image.
dataMatrix = np.array([img.flatten() for img in trainingDigitListCAR], dtype=float)

# Preview up to the first 20 labelled images in a 4x5 grid, with the label as
# the title. Grid cells beyond the number of labelled images are left blank.
fig, axs = plt.subplots(4, 5, figsize=(6, 5))
for k, ax in enumerate(axs.ravel()):
    ax.axis('off')
    if k < len(trainingDigitListCAR):
        ax.imshow(trainingDigitListCAR[k], cmap='gray')
        ax.set_title(trainingLabelListCAR[k])

In [None]:
# Fit a 1-nearest-neighbour classifier on every labelled image: row k of
# dataMatrix is paired with entry k of trainingLabelListCAR.
from sklearn.neighbors import KNeighborsClassifier

nTrain = len(trainingLabelListCAR)
classifier = KNeighborsClassifier(n_neighbors=1)
classifier.fit(dataMatrix[:nTrain], trainingLabelListCAR[:nTrain])

# Predict the labels of the same images the classifier was trained on and
# report the fraction that match.
predictedLabels = classifier.predict(dataMatrix)

print('Training Accuracy:', np.mean(predictedLabels == trainingLabelListCAR))

In [None]:
# Let's combine the previous steps (sorting training data, cropping & resizing, and training the classifier) into a single function.
def updateClassifier(nImg2Show=25, classifier=None, allImageLabels=allImageLabels):
    """Run one labelling round and (re)train the classifier on all labelled images.

    Parameters
    ----------
    nImg2Show : int
        Number of images to present for labelling in this round.
    classifier : estimator or None
        A previously fitted classifier whose predictions are used as label
        suggestions, or None to create a new ``SVC``.
    allImageLabels : sequence
        Labels shown as suggestions when no classifier is supplied. When a
        classifier is supplied, its predictions are used instead.

    Returns
    -------
    The classifier fitted on every labelled image. If no image has been
    labelled yet, the classifier that was passed in is returned unchanged
    (None when none was passed).
    """
    isNewClassifier = classifier is None
    if isNewClassifier:
        classifier = SVC()
    else:
        # Predict a label for every image once, on the complete matrix.
        dataMatrixAll = np.array([img.flatten() for img in allDigitListCAR], dtype=float)
        guesses = classifier.predict(dataMatrixAll)
        # Labels already stored in the training set take precedence, so a
        # wrong guess cannot overwrite them when the image is shown again.
        allImageLabels = [
            known if known is not None else guess
            for known, guess in zip(trainingLabelList, guesses)
        ]

    # Get the labels for new images
    getLabels(nImg2Show, Labels, allImageLabels=allImageLabels)

    # Collect every image that has a training label so far
    trainingDigitListCAR = [img for img in trainingDigitList if img is not None]
    trainingLabelListCAR = [label for label in trainingLabelList if label is not None]

    if not trainingDigitListCAR:
        print('No labelled images yet; classifier was not trained.')
        # Do not hand back an unfitted estimator: predict() would fail on it.
        return None if isNewClassifier else classifier

    # One row per labelled image, each row being the flattened image
    dataMatrix = np.array([img.flatten() for img in trainingDigitListCAR], dtype=float)

    print(f'Training classifier using {dataMatrix.shape[0]} images.')
    classifier.fit(dataMatrix, trainingLabelListCAR)

    # Print the training accuracy
    predictedLabels = classifier.predict(dataMatrix)
    print('Training Accuracy:', np.mean(predictedLabels == trainingLabelListCAR))

    return classifier

In [None]:
# First round: no classifier is passed in, and a small batch of four images is labelled.
classifier = updateClassifier(4, None, allImageLabels)

In [None]:
# Next round: pass the classifier from the previous round back in and label a batch of 16 images.
classifier = updateClassifier(16, classifier, allImageLabels)

In [None]:
# Number of images that currently have an entry in the training set.
sum(1 for img in trainingDigitList if img is not None)