<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#MNIST-Hand-Writing-Digits-Predictor" data-toc-modified-id="MNIST-Hand-Writing-Digits-Predictor-1">MNIST Hand Writing Digits Predictor</a></span></li></ul></div>

## MNIST Hand Writing Digits Predictor

In [None]:
!pip3 install ipycanvas

In [40]:
# model
import tensorflow as tf 
model = tf.keras.models.load_model('mnist_digit_model.h5')

def model_predict(data):
    """
    data: numpy.array shape: (784,)
    """
    data = np.expand_dims(data, axis=0)
    #return np.argmax(model.predict(data))
    return model.predict(data)

In [71]:
from ipycanvas import Canvas, hold_canvas
import numpy as np

In [91]:
_PIXEL_SIZE = 12
_NUM_PIXELS = 28
_WIDTH = _PIXEL_SIZE * _NUM_PIXELS
_HEIGHT = _PIXEL_SIZE * _NUM_PIXELS
_BG_COLOR = "#ffffff"
_FG_COLOR = "#000000"
_GRID_COLOR = "#999999"
_FRAME_COLOR = "#222299"

In [116]:
def create_canvas() -> Canvas:
    return Canvas(width=_WIDTH, height=_HEIGHT, sync_image_data=True)

def redraw(canvas: Canvas):
    canvas.clear()
    canvas.fill_style = _BG_COLOR
    canvas.fill_rect(0, 0, canvas.width, canvas.height)
    draw_grid(canvas)
    
def draw_grid(canvas: Canvas):
    with hold_canvas(canvas):
        canvas.stroke_style =_GRID_COLOR
        canvas.stroke_rect(0, 0, _WIDTH, _HEIGHT)    
        for i in range(_NUM_PIXELS):
            canvas.stroke_line(i * _PIXEL_SIZE, 0, i * _PIXEL_SIZE, _HEIGHT)
            canvas.stroke_line(0, i * _PIXEL_SIZE, _WIDTH, i * _PIXEL_SIZE)

        canvas.stroke_style =_FRAME_COLOR
        offset = 4 * _PIXEL_SIZE
        canvas.stroke_rect(offset, offset, _WIDTH - 2 * offset, _HEIGHT - 2 * offset)

def canvas_mouse_move(x, y):
    canvas.fill_style = _FG_COLOR
    with hold_canvas(canvas):
        #px = ((x // _PIXEL_SIZE) - 1) * _PIXEL_SIZE
        #py = ((y // _PIXEL_SIZE) - 1) * _PIXEL_SIZE
        #canvas.fill_rect(px, py, _PIXEL_SIZE * 3, _PIXEL_SIZE * 3)
        px = x // _PIXEL_SIZE * _PIXEL_SIZE
        py = y // _PIXEL_SIZE * _PIXEL_SIZE
        canvas.fill_rect(px, py, _PIXEL_SIZE, _PIXEL_SIZE)   
        blur_edges(px, py)

def blur_edges(px, py):
    dirs = [[-1,-1], [0,-1], [1,-1], [1,0], [1,1], [0,1], [-1,1], [-1,0]]
    for direction in dirs:
        x = px + direction[0]*_PIXEL_SIZE
        y = py + direction[1]*_PIXEL_SIZE
        canvas.fill_style = get_blur_color(x, y)
        canvas.fill_rect(x, y, _PIXEL_SIZE, _PIXEL_SIZE)   

def get_blur_color(px, py):
    canvas_data = canvas.get_image_data()    
    dirs = [[-1,-1], [0,-1], [1,-1], [1,0], [1,1], [0,1], [-1,1], [-1,0]]
    acc = 0
    for direction in dirs:
        x = px + direction[0]*_NUM_PIXELS + _NUM_PIXELS//2
        y = py + direction[1]*_NUM_PIXELS + _NUM_PIXELS//2
        acc += canvas_data[int(x), int(y), 0]
    acc = int(acc/18)
    return f"#{acc:02x}{acc:02x}{acc:02x}"
    
def canvas_mouse_down(x, y):
    pass
    
def canvas_mouse_up(x, y):
    pass    

In [117]:
from ipywidgets import Output, Button, Label

# widgets
out = Output()
bt_clear = Button(description='Clear')
bt_predict = Button(description='Predict')
lb_predict = Label(value="Prediction: ")

@out.capture()
def handle_mouse_down(x, y):
    global painting
    if not painting:
        painting = True        
    #print(f'{painting} Mouse down event:', x, y)
    #print(canvas.get_image_data(x,y,1,1))
        
@out.capture()
def handle_mouse_up(x, y):
    global painting
    if painting:
        painting = False
        #blur_edges()
    #print(f'{painting} Mouse up event:', x, y)

@out.capture()
def handle_mouse_move(x, y):
    global painting
    if painting:
        canvas_mouse_move(x, y)
        
def bt_clear_on_click(b):
    painting = False
    out.clear_output()    
    lb_predict.value = "Prediction:"
    redraw(canvas)
    data = canvas.get_image_data(0,0,_WIDTH,_HEIGHT)
    print(data)
    print(data.shape)
    
@out.capture()    
def bt_predict_on_click(b):
    canvas_data = canvas.get_image_data(0,0,_WIDTH,_HEIGHT)
    # nos quedamos SÓLO el pixel central de cada celda de la cuadrícula
    # nos quedamos SÓLO con la primera componente del color (grayscale)
    # tenemos que invertir los valores    
    pixel_data = np.array([255 - (canvas_data[(i*_PIXEL_SIZE) + _PIXEL_SIZE//2, (j*_PIXEL_SIZE) + _PIXEL_SIZE//2, 0]) for i in range(_NUM_PIXELS) for j in range(_NUM_PIXELS)])
    
    print(f"\nData:")
    for i in range(_NUM_PIXELS):
        for j in range(_NUM_PIXELS):
            print(f"{pixel_data[i*_NUM_PIXELS + j]:^5d}", end=" ")
            #print(f"{255 - canvas_data[(i*_PIXEL_SIZE) + _PIXEL_SIZE//2, (j*_PIXEL_SIZE) + _PIXEL_SIZE//2, 0]:^5d}", end=" ")
        print()
        
    prediction = model_predict(pixel_data)
    lb_predict.value = f"Prediction: {np.argmax(prediction)}"
    print(f"Prediction: {prediction}")

"""
def blur_edges():
    canvas_data = canvas.get_image_data(0,0,_WIDTH,_HEIGHT)
    pixel_data = np.array([canvas_data[(i*_PIXEL_SIZE) + _PIXEL_SIZE//2, (j*_PIXEL_SIZE) + _PIXEL_SIZE//2, 0] \
                           for i in range(_NUM_PIXELS) for j in range(_NUM_PIXELS)])
    pixel_data = np.reshape(pixel_data,(-1, _NUM_PIXELS))
    new_image = np.zeros((_NUM_PIXELS, _NUM_PIXELS), dtype=int)
    for i in range(_NUM_PIXELS):
        for j in range(_NUM_PIXELS):
            if i==0 or j==0 or i==(_NUM_PIXELS-1) or j==(_NUM_PIXELS-1):
                continue
            new_image[i, j] = pixel_data[i-1:i+2, j-1:j+2].sum()//9
    
    with hold_canvas(canvas):
        for i in range(_NUM_PIXELS):
            for j in range(_NUM_PIXELS):
                canvas.fill_style = f"#{new_image[i,j]:02x}{new_image[i,j]:02x}{new_image[i,j]:02x}"
                canvas.fill_rect(i*_PIXEL_SIZE, j*_PIXEL_SIZE, _PIXEL_SIZE, _PIXEL_SIZE)
""" 
    
# Painting flag
painting = False

canvas = create_canvas()
redraw(canvas)

canvas.on_mouse_down(handle_mouse_down)
canvas.on_mouse_up(handle_mouse_up)
canvas.on_mouse_move(handle_mouse_move)
bt_clear.on_click(bt_clear_on_click)
bt_predict.on_click(bt_predict_on_click)

display(canvas, bt_clear, bt_predict, lb_predict, out)

Canvas(height=336, sync_image_data=True, width=336)

Button(description='Clear', style=ButtonStyle())

Button(description='Predict', style=ButtonStyle())

Label(value='Prediction: ')

Output()

[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 ...

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [204 204 204 255]
  [204 204 204 255]
  ...
  [204 204 204 255]
  [204 204 204 255]
  [204 204 204 255]]]
(336, 336, 4)
[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255

[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 ...

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [204 204 204 255]
  [204 204 204 255]
  ...
  [204 204 204 255]
  [204 204 204 255]
  [204 204 204 255]]]
(336, 336, 4)
[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255

[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 ...

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255]
  [255 255 255 255]
  [204 204 204 255]]

 [[178 178 178 255]
  [204 204 204 255]
  [204 204 204 255]
  ...
  [204 204 204 255]
  [204 204 204 255]
  [204 204 204 255]]]
(336, 336, 4)
[[[165 165 165 255]
  [178 178 178 255]
  [178 178 178 255]
  ...
  [178 178 178 255]
  [178 178 178 255]
  [178 178 178 255]]

 [[178 178 178 255]
  [255 255 255 255]
  [255 255 255 255]
  ...
  [255 255 255 255