In [1]:
import numpy as np
from keras.datasets import mnist
import cv2
import math
from scipy import ndimage




In [2]:
class NaiveBayes:
    """Gaussian Naive Bayes classifier.

    Every feature is modelled, per class, as an independent normal
    distribution. Prediction picks the class with the largest log-posterior
    (log-prior plus the sum of per-feature log-likelihoods).

    Attributes set by ``fit``:
        classes:   sorted array of the distinct labels seen in ``y``.
        priors:    list of class frequencies, aligned with ``classes``.
        means:     list of per-feature mean vectors, aligned with ``classes``.
        variances: list of per-feature (smoothed) variance vectors, aligned
                   with ``classes``.
    """

    def __init__(self, var_smoothing=0.01575):
        """Create an unfitted model.

        Args:
            var_smoothing: constant added to every per-feature variance in
                ``fit``. With a positive value the variance stays strictly
                positive even for features that are constant within a class.
        """
        self.var_smoothing = var_smoothing
        self.means = []
        self.variances = []
        self.priors = []

    def fit(self, x, y):
        """Estimate class priors and per-feature Gaussian parameters.

        Args:
            x: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like of n_samples class labels.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        self.classes = np.unique(y)

        # Start from scratch so that calling fit() again does not append to
        # the parameters of a previous fit.
        self.means = []
        self.variances = []
        self.priors = []

        for label in self.classes:
            mask = y == label
            samples = x[mask]
            self.priors.append(np.mean(mask))
            self.means.append(np.mean(samples, axis=0))
            self.variances.append(np.var(samples, axis=0) + self.var_smoothing)

    def predict(self, x):
        """Return the most probable class label for every row of ``x``.

        Args:
            x: 2-D array-like of shape (n_samples, n_features).

        Returns:
            1-D array of n_samples labels taken from ``self.classes``.
        """
        x = np.asarray(x)
        posteriors = []

        # Iterate by position: priors/means/variances are aligned with
        # self.classes, so they must not be indexed by the label value itself.
        for idx in range(len(self.classes)):
            log_prior = np.log(self.priors[idx])
            log_likelihood = np.sum(
                self._log_gaussian(x, self.means[idx], self.variances[idx]),
                axis=1,
            )
            posteriors.append(log_likelihood + log_prior)

        # argmax yields a position into self.classes; map it back to a label.
        return self.classes[np.argmax(posteriors, axis=0)]

    def _log_gaussian(self, x, mean, variance):
        """Log of the normal density, evaluated element-wise.

        Computed directly in log-space: taking log() of the density itself
        turns into log(0) = -inf once the density underflows.
        """
        return -0.5 * np.log(2 * np.pi * variance) - ((x - mean) ** 2) / (2 * variance)

    def gaussian(self, x, mean, variance):
        """Normal probability density, evaluated element-wise."""
        numerator = np.exp(-((x - mean) ** 2) / (2 * variance))
        denominator = np.sqrt(2 * np.pi * variance)

        return numerator / denominator

In [3]:
# Fetch the MNIST train/test split through keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten each image into one feature row and divide by 255.0
# (presumably 8-bit pixel intensities, giving values in [0, 1]).
x_train = x_train.reshape(len(x_train), -1) / 255.0
x_test = x_test.reshape(len(x_test), -1) / 255.0

In [5]:
# Fit the classifier on the training split.
model = NaiveBayes()
model.fit(x_train, y_train)

# Score it on the held-out split: fraction of test samples predicted correctly.
y_predicted = model.predict(x_test)
accuracy = (y_predicted == y_test).mean()
print("Accuracy: ", accuracy)

Accuracy:  0.8156


In [5]:
class ProcessImage:
    """Convert an image file of a single digit into a flat feature vector.

    The pipeline inverts the image, binarises it, crops it to the digit,
    rescales the digit so its longer side is 20 pixels, pads it to 28x28,
    centres it on its centre of mass and finally flattens it to 784 values
    scaled by 1/255.
    """

    def __init__(self, image_path):
        """Remember the path of the image; nothing is read until preprocess()."""
        self.path = image_path

    def preprocess(self):
        """Load the image and return it as a 1-D float array of 784 values.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``self.path``.
            ValueError: if the binarised image contains no foreground pixels.
        """
        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(self.path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError("Could not read image: " + str(self.path))

        # Invert (so the digit is bright on a dark background, like the
        # training data) and scale down to 20x20.
        img = cv2.resize(255 - img, (20, 20), interpolation=cv2.INTER_AREA)

        # Binarise so the background is uniformly black. With THRESH_OTSU the
        # threshold is chosen by OpenCV and the 128 argument is ignored.
        _, img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Crop to the digit, rescale and pad to 28x28.
        img = self.trim(img)

        # Centre the digit on its centre of mass.
        shiftx, shifty = self.getBestShift(img)
        img = self.shift(img, shiftx, shifty)

        # Scale pixel values by 1/255 and flatten to match the model input.
        img = img / 255.0

        return img.reshape(-1)

    def trim(self, img):
        """Crop empty borders, fit the digit into 20x20 and pad to 28x28.

        Args:
            img: 2-D array whose background pixels are 0.

        Returns:
            A 28x28 array with the rescaled digit surrounded by zero padding.

        Raises:
            ValueError: if ``img`` has no non-zero pixel (nothing to crop to).
        """
        used_rows = np.flatnonzero(np.any(img, axis=1))
        used_cols = np.flatnonzero(np.any(img, axis=0))
        if used_rows.size == 0 or used_cols.size == 0:
            raise ValueError("Image contains no foreground pixels; cannot trim")

        # Bounding box of the non-zero pixels; copy so OpenCV receives a
        # contiguous buffer rather than a strided view.
        img = np.ascontiguousarray(
            img[used_rows[0]:used_rows[-1] + 1, used_cols[0]:used_cols[-1] + 1]
        )

        rows, cols = img.shape

        # Scale so the longer side becomes 20 while keeping the aspect ratio.
        # max(1, ...) keeps a very thin digit from rounding down to size 0.
        if rows > cols:
            factor = 20.0 / rows
            rows = 20
            cols = max(1, int(round(cols * factor)))
        else:
            factor = 20.0 / cols
            cols = 20
            rows = max(1, int(round(rows * factor)))
        img = cv2.resize(img, (cols, rows))

        # Pad to 28x28; when the missing amount is odd the extra pixel goes
        # to the top/left side (ceil before, floor after).
        colsPadding = (int(math.ceil((28 - cols) / 2.0)), int(math.floor((28 - cols) / 2.0)))
        rowsPadding = (int(math.ceil((28 - rows) / 2.0)), int(math.floor((28 - rows) / 2.0)))
        img = np.pad(img, (rowsPadding, colsPadding), 'constant')

        return img

    def getBestShift(self, img):
        """Return the integer (shiftx, shifty) that moves the centre of mass
        of ``img`` to the geometric centre of the image."""
        cy, cx = ndimage.center_of_mass(img)
        rows, cols = img.shape
        shiftx = np.round(cols / 2.0 - cx).astype(int)
        shifty = np.round(rows / 2.0 - cy).astype(int)

        return shiftx, shifty

    def shift(self, img, sx, sy):
        """Translate ``img`` by ``sx`` pixels in x and ``sy`` pixels in y,
        keeping the original image size."""
        rows, cols = img.shape
        # 2x3 affine matrix: identity plus a translation column.
        M = np.float32([[1, 0, sx], [0, 1, sy]])
        shifted = cv2.warpAffine(img, M, (cols, rows))

        return shifted

In [7]:
def digits():
    """Classify the images under ``digits/`` and print the accuracy.

    Each file ``digits/digit<suffix>.png`` is preprocessed with ProcessImage,
    the batch is passed to the global ``model`` and the predictions are
    compared against the hand-entered labels (same order as the suffixes).
    """
    suffixes = ["1", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "2", "3", "4", "5", "6", "7", "8", "9"]
    actual_digits = [7, 7, 0, 5, 3, 2, 1, 0, 8, 7, 4, 2, 9, 8, 5, 1, 1, 1, 7]

    # One preprocessed feature vector per file, in suffix order.
    digits_imgs = [
        ProcessImage('digits/digit' + (num) + '.png').preprocess()
        for num in suffixes
    ]

    predicted_digits = model.predict(digits_imgs)
    print("Predicted Digit: ", predicted_digits)
    print("Actual Digits: ", actual_digits)

    # Element-wise comparison; the mean of the matches is the accuracy.
    matches = predicted_digits == actual_digits
    accuracy = np.mean(matches)
    print("Accuracy: ", accuracy)

digits()

[[8.42044210e-167 0.00000000e+000 1.42765746e-169 9.05882739e-143
  2.14122870e-111 1.07652274e-168 7.71056334e-260 1.00000000e+000
  6.43436705e-127 2.24726607e-107]
 [3.72409627e-167 0.00000000e+000 1.87873468e-090 1.28198357e-075
  1.56512480e-043 1.70531158e-070 8.85588473e-130 1.00000000e+000
  1.40764176e-056 1.52182204e-021]
 [1.00000000e+000 0.00000000e+000 2.29963505e-077 5.08676794e-057
  1.53469950e-013 2.71354991e-028 1.30909334e-061 2.29096769e-036
  6.82442211e-035 1.93527933e-011]
 [9.98997319e-054 1.46921978e-142 7.83575554e-053 7.81475269e-059
  1.36150062e-110 1.00000000e+000 2.74032923e-060 4.31306605e-234
  3.65788631e-019 3.81343115e-205]
 [5.72559689e-133 2.23047668e-067 1.99132086e-082 1.00000000e+000
  1.43747958e-070 4.02119188e-047 1.43168164e-190 1.05929369e-111
  4.10931678e-033 4.00471302e-093]
 [6.13229426e-110 2.37955544e-177 1.00000000e+000 8.01914367e-080
  1.71158213e-110 2.59275345e-071 3.74745063e-033 2.74197112e-226
  1.34193367e-071 3.50948569e-157

In [43]:
def r_digits():
    """Classify the images under ``r_digits/`` and print the accuracy.

    Each file ``r_digits/r_image_<suffix>.png`` is preprocessed with
    ProcessImage, the batch is passed to the global ``model`` and the
    predictions are compared against the hand-entered labels (same order as
    the suffixes).
    """
    digits_imgs = []
    suffixes = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "11", "22", "44", "88", "99", "111"]

    for num in suffixes:
        img_processor = ProcessImage('r_digits/r_image_' + (num) + '.png')
        img = img_processor.preprocess()
        digits_imgs.append(img)

    # model.predict already returns one class per image (a 1-D array), so it
    # must not be passed through np.argmax(..., axis=1): a 1-D array has no
    # axis 1 and numpy raises AxisError.
    predicted_digits = model.predict(digits_imgs)
    print("Predicted Digit: ", predicted_digits)

    actual_digits = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 4, 8, 9, 1]
    print("Actual Digits: ", actual_digits)

    # Compare as arrays so the match is element-wise regardless of the
    # container type predict() returns.
    accuracy = np.mean(np.asarray(predicted_digits) == np.asarray(actual_digits))
    print("Accuracy: ", accuracy)

r_digits()

Predicted Digit:  [1 6 3 9 9 6 7 8 9 1 2 4 3 9 2]
Actual Digits:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 4, 8, 9, 1]
Accuracy:  0.6666666666666666
