Example. As an example we apply least squares classification to the MNIST data set described in §4.3. The (training) data set contains 60,000 images of size 28 by 28. The number of examples per digit varies between 5421 (for digit five) and 6742 (for digit one). The pixel intensities are scaled to lie between 0 and 1. There is also a separate test set containing 10,000 images.

In [3]:
import struct
import gzip
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib notebook

ModuleNotFoundError: No module named 'pandas'

In [None]:
# Empty containers for the two MNIST splits; the loading cells below fill
# in the 'image' and 'label' entries of each.
train = {}
test = {}

In [None]:
def get_images(filename, directory='mnist'):
    """Read a gzip-compressed IDX image file into a uint8 array.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` archive inside *directory*.
    directory : str or os.PathLike, optional
        Folder holding the MNIST archives (default ``'mnist'``).

    Returns
    -------
    numpy.ndarray
        Array of dtype uint8 with shape ``(count, rows, cols)`` taken from
        the file header.

    Raises
    ------
    ValueError
        If the magic number is not 2051 (IDX, unsigned bytes, 3 dimensions)
        or the pixel payload does not match the header dimensions.
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # Header: four big-endian uint32 values.
        magic, count, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(
                f'{filename}: magic number {magic} is not an IDX image file '
                f'(expected 2051)')
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    expected = count * rows * cols
    if pixels.size != expected:
        raise ValueError(
            f'{filename}: header announces {expected} pixels, '
            f'file contains {pixels.size}')
    return pixels.reshape(count, rows, cols)

# Load the pixel arrays for the training and test splits.
train.update(image=get_images('train-images-idx3-ubyte.gz'))
test.update(image=get_images('t10k-images-idx3-ubyte.gz'))

In [None]:
def get_labels(filename, directory='mnist'):
    """Read a gzip-compressed IDX label file into a 1-D uint8 array.

    Parameters
    ----------
    filename : str
        Name of the ``.gz`` archive inside *directory*.
    directory : str or os.PathLike, optional
        Folder holding the MNIST archives (default ``'mnist'``).

    Returns
    -------
    numpy.ndarray
        One uint8 label per item, in file order.

    Raises
    ------
    ValueError
        If the magic number is not 2049 (IDX, unsigned bytes, 1 dimension)
        or the number of labels disagrees with the header count.
    """
    with gzip.GzipFile(Path(directory, filename), 'rb') as f:
        # Header: two big-endian uint32 values (magic, item count).
        magic, count = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(
                f'{filename}: magic number {magic} is not an IDX label file '
                f'(expected 2049)')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    if labels.size != count:
        raise ValueError(
            f'{filename}: header announces {count} labels, '
            f'file contains {labels.size}')
    return labels

In [None]:
# Attach the digit labels to the training and test splits.
train.update(label=get_labels('train-labels-idx1-ubyte.gz'))
test.update(label=get_labels('t10k-labels-idx1-ubyte.gz'))

For each digit, we can define a Boolean classifier that distinguishes the digit from the other nine digits. Here we consider the classifier that distinguishes the digit zero from all other digits. In a first experiment, we use the n = 28 × 28 = 784 pixel intensities as features in the least squares classifier (12.1).

In [None]:
# Flatten every 28x28 image into a 784-long row and scale the uint8
# intensities to [0, 1]. The row count comes from the data, not a constant.
x = train['image'].reshape(len(train['image']), -1) / 255
# One-vs-rest targets for the digit zero, using all training images:
# +1 where the label is 0, -1 for every other digit.
y = np.where(train['label'] == 0, 1, -1)

In [None]:
from sklearn import linear_model as slm

# Ordinary least squares on the flattened pixel features. fit() returns the
# estimator itself, so construction and fitting can be chained.
lm = slm.LinearRegression().fit(x, y)
# Real-valued scores; taking their sign gives the Boolean classifier.
yhat = lm.predict(x)

In [None]:
fig, (ax, ax2) = plt.subplots(ncols=2, figsize=(10, 5))

# Left: the 784 least-squares coefficients shown as a 28x28 image. The
# colour scale is clipped to [-.3, .3], so the end ticks read "at most -.3"
# and "at least .3".
im = ax.imshow(lm.coef_.reshape(28, 28), cmap="RdBu", vmin=-.3, vmax=.3)
cb = fig.colorbar(im, ax=ax, fraction=.045)
cb.set_ticks([-.3, -.15, 0, .15, .3])
# Raw strings keep the TeX backslashes intact without triggering an
# invalid-escape warning; the lower bound is negative, matching vmin.
cb.set_ticklabels([r"$\leq-.3$", "-.15", "0", ".15", r"$\geq.3$"])

# Right: training image number 1000, for comparison.
im2 = ax2.imshow(x[1000].reshape(28, 28), cmap='gray')
cb2 = fig.colorbar(im2, ax=ax2, fraction=.045)

In [None]:
# Sweep the decision offset alpha. For each value predict sign(yhat + alpha)
# and record the true-positive rate and the false-positive rate.
alphas = np.arange(-1.1, 1.1, .1)
is_pos = (y == 1)
is_neg = (y == -1)
tpr = []
fpr = []
for alpha in alphas:
    pred = np.sign(yhat + alpha)
    tp = (is_pos & (pred == 1)).sum()
    fn = (is_pos & (pred == -1)).sum()
    fp = (is_neg & (pred == 1)).sum()
    tn = (is_neg & (pred == -1)).sum()
    tpr.append(tp / (tp + fn))
    fpr.append(fp / (fp + tn))

In [None]:
%%capture
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 4),
                               gridspec_kw={'width_ratios': [1.5, 1]})

# Left panel: empty placeholder histograms. They give the legend one entry
# per class; the animation frames draw the real histograms.
_, _, patches1 = ax1.hist([], label="y = +1", color="tab:blue",
                          bins=50, density=True, alpha=.75)
_, _, patches2 = ax1.hist([], label="y = -1", color="tab:orange",
                          bins=50, density=True, alpha=.75)

# Shaded background regions mark the predicted class on either side of 0.
_ = ax1.axvspan(0, 3, facecolor='lavender', label=r"$\hat{y}$ = +1", zorder=-6)
_ = ax1.axvspan(-3, 0, facecolor='cornsilk', label=r"$\hat{y}$ = -1", zorder=-6)
_ = ax1.axvline(x=0, color='k')
_ = ax1.set_xlim(-3, 3)
_ = ax1.legend(ncol=2, loc=1, facecolor='white', framealpha=.95)

# Right panel: the full ROC curve from the alpha sweep.
_ = ax2.plot(fpr, tpr, color='darkseagreen')
_ = ax2.set_xlabel("False Positive")
_ = ax2.set_ylabel("True Positive")

# Label that each frame moves next to the current ROC point. Axes.text takes
# scalar coordinates, so start at the origin; frames set the real position.
label = ax2.text(0, 0, "", color='mediumseagreen')

def init():
    """Hide the placeholder histogram bars before the first frame is drawn.

    Returns the two placeholder bar containers as a list.
    """
    for bar_pair in zip(patches1, patches2):
        for bar in bar_pair:
            bar.set_visible(False)
    return [patches1, patches2]

# Histogram bars drawn by the most recent frame. They are removed at the
# start of the next frame so that artists do not accumulate on ax1.
_frame_bars = []


def animate(i):
    """Draw animation frame ``i`` for the decision offset ``alphas[i]``.

    Left panel: class-conditional histograms of the shifted scores
    ``yhat + alphas[i]``. Bars to the right of the vertical line at 0 are
    exactly the examples predicted +1 by ``sign(yhat + alpha)``, the same
    rule used to compute ``tpr``/``fpr``. Right panel: mark the matching
    ROC point and move the alpha label next to it.

    Returns the two bar containers drawn for this frame.
    """
    shift = alphas[i]

    # Remove the previous frame's bars rather than hiding every Rectangle on
    # ax1. Hidden bars would pile up, and recent Matplotlib versions also
    # return a Rectangle from axvspan, so the shaded backgrounds would vanish.
    for bar in _frame_bars:
        bar.remove()
    _frame_bars.clear()
    _, _, patches1 = ax1.hist(yhat[y == 1] + shift, color='tab:blue',
                              bins=50, density=True, alpha=.75)
    _, _, patches2 = ax1.hist(yhat[y == -1] + shift, color="tab:orange",
                              bins=50, density=True, alpha=.75)
    _frame_bars.extend([*patches1, *patches2])

    # Axes.collections cannot be reassigned in current Matplotlib. Remove
    # the previous ROC marker explicitly, then draw the new one.
    for marker in list(ax2.collections):
        marker.remove()
    _ = ax2.scatter(fpr[i], tpr[i], s=30, c='seagreen', zorder=5)

    label.set_text(f'alpha: {alphas[i]:.2f}')
    offset = .015
    # A single if/elif chain. Previously the first branch was always
    # overwritten by the separate if/elif/else that followed it.
    if shift <= .1:
        position = (fpr[i] + offset, tpr[i])
    elif shift > .9:
        position = (fpr[i] - 10 * offset, tpr[i] - 5 * offset)
    elif shift >= .6:
        position = (fpr[i] - 4 * offset, tpr[i] - 5 * offset)
    elif shift >= .3:
        position = (fpr[i], tpr[i] - 4 * offset)
    else:
        position = (fpr[i] + 2 * offset, tpr[i] - 2 * offset)
    label.set_position(position)
    return [patches1, patches2]

In [None]:
# One frame per alpha value, shown every 500 ms, rendered as an embeddable
# HTML/JavaScript player (the cell's displayed result).
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=len(alphas),
    init_func=init,
    interval=500,
)
HTML(ani.to_jshtml())

In [None]:
# Export the animation as a GIF using Matplotlib's 'imagemagick' writer.
ani.save(filename='alpha.gif', writer='imagemagick', dpi=80)