In [None]:
%matplotlib inline
import math, sys, os
import numpy as np
from numpy.linalg import norm
from PIL import Image
from matplotlib import rcParams, rc, pyplot as plt
from scipy.ndimage import imread
from skimage.measure import block_reduce

from scipy.ndimage.filters import correlate, convolve
from ipywidgets import interact, interactive, fixed
from ipywidgets.widgets import *
rc('animation', html='html5')
rcParams['figure.figsize'] = 3, 6
%precision 4
np.set_printoptions(precision=4, linewidth=100)

In [None]:
'''
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
images, labels = mnist.train.images, mnist.train.labels
images = images.reshape((55000, 28, 28))
np.savez_compressed('MNIST_data/train', images=images, labels=labels)
'''
1

In [None]:
def plots(ims, interp=False, titles=None):
    """Show several images side by side in one row, on a shared intensity scale.

    Parameters
    ----------
    ims : sequence of images
        Converted with ``np.array``; the global min/max of the stack are used
        as ``vmin``/``vmax`` for every subplot.
    interp : bool
        When False (default) the images are drawn with ``interpolation='none'``;
        when True matplotlib's default interpolation is used.
    titles : sequence or None
        Optional per-image titles, indexed in the same order as ``ims``.
    """
    ims = np.array(ims)
    # Shared limits so intensities are comparable across the subplots.
    mn, mx = ims.min(), ims.max()
    f = plt.figure(figsize=(12, 24))
    for i, im in enumerate(ims):
        sp = f.add_subplot(1, len(ims), i + 1)
        if titles is not None:
            sp.set_title(titles[i], fontsize=18)
        # add_subplot makes `sp` the current axes, so plt.imshow draws into it.
        plt.imshow(im, interpolation=None if interp else 'none', vmin=mn, vmax=mx)
    
def plot(im, interp=False):
    """Display a single image in a new 3x6 figure.

    im     -- the image to show (anything ``plt.imshow`` accepts)
    interp -- False (default) draws with ``interpolation='none'``;
              True uses matplotlib's default interpolation
    """
    plt.figure(figsize=(3, 6), frameon=True)
    mode = None if interp else 'none'
    plt.imshow(im, interpolation=mode)

# Make gray the default colormap for the image plots that follow.
plt.gray()
# NOTE(review): presumably closes a figure opened as a side effect of plt.gray() — confirm.
plt.close()


In [None]:
# Load the cached MNIST training arrays. The archive is opened in a `with` block so
# its file handle is released as soon as both arrays have been read into memory.
with np.load('MNIST_data/train.npz') as data:
    images = data['images']
    labels = data['labels']
n = len(images)  # number of training images
images.shape

In [None]:
# Show the first training image.
plot(images[0])

In [None]:
# Label stored for the first training image.
labels[0]

In [None]:
# First five training images, each titled with its label.
plots(images[:5], titles=labels[:5])

In [None]:
# 3x3 kernel with a -1 row above a +1 row: under correlation it responds where a
# pixel row is brighter than the row directly above it.
top = [[-1,-1,-1], [1, 1, 1], [0, 0, 0]]
plot(top)

In [None]:
# (min, max) range shared by the four crop-bound widgets.
r = (0, 28)
def zoomin(x1=0, x2=28, y1=0, y2=28):
    """Plot the crop of the first training image given by rows y1:y2 and columns x1:x2."""
    plot(images[0, y1:y2, x1:x2])
# Build the interactive widget that calls zoomin with the chosen bounds, then display it.
w=interactive(zoomin, x1=r, x2=r, y1=r, y2=r)
w

In [None]:
# Turn the arguments currently stored on the widget into an index expression
# (rows y1:y2, columns x1:x2) that selects the same crop from any image-sized array.
# NOTE(review): this reads interactive widget state, so the crop is not reproducible
# from code alone — confirm whether fixed bounds should be used instead.
k = w.kwargs
dims = np.index_exp[k['y1']:k['y2']:1,k['x1']:k['x2']]
print(k, dims)

In [None]:
# Raw pixel values of the selected crop of the first image.
images[0][dims]

In [None]:
# Correlate the first image with the `top` kernel, then show the values in the selected crop.
corrtop = correlate(images[0], top)
corrtop[dims]

In [None]:
# The same crop of the correlation result, shown as an image.
plot(corrtop[dims])

In [None]:
# Full correlation result for the first image.
plot(corrtop)

In [None]:
# The `top` kernel rotated by one quarter turn.
np.rot90(top, 1)

In [None]:
# Convolution flips the kernel, so convolving with `top` rotated by 180 degrees
# should reproduce the correlation with the unrotated `top`.
convtop = convolve(images[0], np.rot90(top, 2))
plot(convtop)
# Compare against the correlation result; comparing convtop with itself is always True.
np.allclose(convtop, corrtop)

In [None]:
# The `top` kernel at 0, 1, 2 and 3 quarter turns.
straights = [np.rot90(top, quarter_turns) for quarter_turns in range(4)]
plots(straights)

In [None]:
# Diagonal kernel and its four quarter-turn rotations.
br = [
    [0, 0, 1],
    [0, 1, -1.5],
    [1, -1.5, 0],
]
diags = [np.rot90(br, quarter_turns) for quarter_turns in range(4)]
plots(diags)

In [None]:
# All kernels (straight ones first, then diagonal) and the response of the
# first image to each of them.
rots = [*straights, *diags]
corrs = [correlate(images[0], kernel) for kernel in rots]
plots(corrs)

In [None]:
def pool(im):
    """Max-pool `im` over 7x7 blocks using ``block_reduce``."""
    return block_reduce(im, block_size=(7, 7), func=np.max)
# Pooled version of each kernel response for the first image.
plots([pool(im) for im in corrs])

In [None]:
# Split out the training images labelled 8 and those labelled 1.
eights = [img for img, lab in zip(images, labels) if lab == 8]
ones = [img for img, lab in zip(images, labels) if lab == 1]


In [None]:
# First five 8s, then first five 1s.
plots(eights[:5])
plots(ones[:5])

In [None]:
# One entry per kernel: the stacked pooled responses of every image of an 8.
pool8 = [
    np.array([pool(correlate(digit, kernel)) for digit in eights])
    for kernel in rots
]

In [None]:
# Number of kernels, and the shape of the response stack for the first kernel.
len(pool8), pool8[0].shape

In [None]:
# Pooled first-kernel responses of the first five 8s.
plots(pool8[0][0:5])

In [None]:
def normalize(arr):
    """Return `arr` shifted by its overall mean and scaled by its overall standard deviation."""
    centered = arr - arr.mean()
    return centered / arr.std()


In [None]:
# Template for 8s: per-kernel mean of the pooled responses, normalized as one array.
filt8 = normalize(np.array([stack.mean(axis=0) for stack in pool8]))

In [None]:
# The normalized per-kernel templates for 8s.
plots(filt8)

In [None]:
# Template for 1s: pooled responses per kernel, averaged over the images and normalized.
pool1 = [
    np.array([pool(correlate(digit, kernel)) for digit in ones])
    for kernel in rots
]
filt1 = normalize(np.array([stack.mean(axis=0) for stack in pool1]))
plots(filt1)

In [None]:
def pool_corr(im):
    """Return the pooled correlation of `im` with every kernel in `rots`, stacked into one array."""
    responses = []
    for kernel in rots:
        responses.append(pool(correlate(im, kernel)))
    return np.array(responses)


In [None]:
# Pooled kernel responses for the first image of an 8.
plots(pool_corr(eights[0]))

In [72]:
def see(a, b):
    """Sum of squared element-wise differences between `a` and `b`."""
    return np.square(a - b).sum()

def is8_n2(im):
    """Classify `im` by squared-error distance to the two templates.

    Returns 1 when the pooled kernel responses of `im` are strictly closer
    (by `see`) to `filt8` than to `filt1`, otherwise 0.
    """
    feats = pool_corr(im)  # computed once and reused for both distances
    return 1 if see(feats, filt1) > see(feats, filt8) else 0


In [73]:
# Squared-error distance of the first 8 to the 8-template and to the 1-template.
see(pool_corr(eights[0]), filt8), see(pool_corr(eights[0]), filt1)

(126.77776, 181.26105)

In [74]:
# Number of images classified as 8 among the true 8s and among the true 1s.
[np.array([is8_n2(im) for im in ims]).sum() for ims in [eights, ones]]

[5223, 287]

In [None]:
# Number of images classified as not-8 among the true 8s and among the true 1s.
# NOTE(review): as laid out here this is not the last expression of its cell, so its
# value would not be displayed — confirm the intended cell boundaries.
[np.array([(1-is8_n2(im)) for im in ims]).sum() for ims in [eights, ones]]

def n1(a, b):
    """Sum of absolute element-wise differences between `a` and `b`."""
    gap = a - b
    return np.fabs(gap).sum()

def is8_n1(im):
    """Classify `im` by absolute-difference distance to the two templates.

    Returns 1 when the pooled kernel responses of `im` are strictly closer
    (by `n1`) to `filt8` than to `filt1`, otherwise 0.
    """
    feats = pool_corr(im)  # computed once and reused for both distances
    return 1 if n1(feats, filt1) > n1(feats, filt8) else 0

# Counts under the absolute-difference classifier, for the true 8s and the true 1s:
# first the images classified as 8, then the images classified as not-8.
# NOTE(review): only the last expression of a cell is displayed, so the first count
# would not be shown as laid out here — confirm the intended cell boundaries.
[np.array([is8_n1(im) for im in ims]).sum() for ims in [eights, ones]]
[np.array([(1 - is8_n1(im)) for im in ims]).sum() for ims in [eights, ones]]