In [1]:
# check if python >= 3.5
import sys
assert sys.version_info >= (3, 5)

# check if scikit-learn >= 0.20
import re
import sklearn
# Compare versions numerically: comparing the version *strings* is lexicographic,
# so e.g. '0.9' >= '0.20' would be True even though 0.9 is the older release.
# Only the leading (major, minor) numbers are used, which also copes with
# suffixes such as '1.3.dev0' or '0.24rc1'.
sklearn_version = tuple(int(part) for part in re.findall(r'\d+', sklearn.__version__)[:2])
assert sklearn_version >= (0, 20), 'scikit-learn >= 0.20 required, found ' + sklearn.__version__

# common import
import numpy as np
import os

# set random seed
np.random.seed(42)

# plot pretty figure
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# where to save figure
project_root_dir = '.'
chapter_name = 'classification'
images_path = os.path.join(project_root_dir, 'images', chapter_name)
os.makedirs(images_path, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension='png', fig_resolution=300):
    """Write the current matplotlib figure into ``images_path``.

    Parameters
    ----------
    fig_id : str
        Base file name; the file becomes ``<fig_id>.<fig_extension>``.
    tight_layout : bool, default True
        Call ``plt.tight_layout()`` before saving.
    fig_extension : str, default 'png'
        File extension, also passed to ``plt.savefig`` as the format.
    fig_resolution : int, default 300
        Resolution passed to ``plt.savefig`` as ``dpi``.
    """
    file_name = fig_id + '.' + fig_extension
    fig_path = os.path.join(images_path, file_name)

    print('saving figure', '\t:', fig_id)

    if tight_layout:
        plt.tight_layout()

    plt.savefig(fig_path, format=fig_extension, dpi=fig_resolution)

# MNIST

The MNIST dataset is a set of 70,000 small images of handwritten digits. Each image is 28×28 pixels and is labeled with the digit it shows.


In [1]:
import inspect
from sklearn.datasets import fetch_openml

# Later cells index rows positionally (X[0]) and call .reshape on them, which
# needs NumPy arrays. Since scikit-learn 0.24, fetch_openml defaults to
# as_frame='auto' and returns pandas objects for non-sparse datasets such as
# MNIST, so ask for arrays explicitly. Releases older than 0.22 have no
# `as_frame` parameter at all (and return arrays), hence the signature check.
fetch_kwargs = {}
if 'as_frame' in inspect.signature(fetch_openml).parameters:
    fetch_kwargs['as_frame'] = False

mnist = fetch_openml('mnist_784', version=1, **fetch_kwargs)
mnist.keys()

In [15]:
mnist_feature_names = mnist['feature_names']

# one line per feature (pixel column), numbered from 1;
# enumerate replaces the hand-maintained counter
for feature_number, feature in enumerate(mnist_feature_names, start=1):
    print('feature', feature_number, '\t:', feature)
    

In [3]:
# free-text description of the dataset returned by fetch_openml
mnist_descr = mnist['DESCR']
print(mnist_descr)

In [10]:
# dataset metadata: print every entry as "name :" followed by its value
mnist_details = mnist['details']
for key, value in mnist_details.items():
    print(key, ':\n', value, '\n')
    

In [12]:
# the `categories` entry of the fetch_openml result
mnist_categories = mnist['categories']
print(mnist_categories)


In [13]:
# URL stored with the dataset (presumably its OpenML page)
mnist_url = mnist['url']
print(mnist_url)


In [16]:
# read in the data & target
# np.asarray turns pandas objects (fetch_openml's default output since
# scikit-learn 0.24) into NumPy arrays and returns NumPy arrays unchanged, so
# the positional indexing (X[0], y[0]) and .reshape calls below work either way.
X = np.asarray(mnist['data'])
y = np.asarray(mnist['target'])
assert len(X) == len(y), 'expected exactly one label per image'

print('X shape', '\t:', X.shape)
print('y shape', '\t:', y.shape)

In [20]:
# plot one digit
import matplotlib as mpl
import matplotlib.pyplot as plt

# take the first row of the data and fold its pixels back into a 28 x 28 grid
some_digit = X[0]
some_digit_image = np.reshape(some_digit, (28, 28))
print(some_digit_image)

In [21]:
# draw the digit with the 'binary' colormap and hide the axes
plt.imshow(some_digit_image, cmap='binary')
plt.axis('off')
plt.show()

In [24]:
# label belonging to X[0], followed by its Python type
print('y[0] is:', y[0], sep='\n')
print('which is ', type(y[0]))


In [26]:
# cast the labels to 8-bit unsigned integers and show the result for y[0]
y = y.astype(np.uint8)
print('after converting, y[0] is:', y[0], sep='\n')
print('which is ', type(y[0]))


In [30]:
# function to plot digits

def plot_digits(instances, images_per_row=10, **options):
    """Draw a grid of 28x28 digit images on the current matplotlib axes.

    Parameters
    ----------
    instances : array-like
        Flattened digit images, one per row (each must reshape to 28x28).
        Anything ``np.asarray`` accepts works, including a pandas DataFrame.
    images_per_row : int, default 10
        Maximum number of digits per grid row (capped at ``len(instances)``).
    **options
        Extra keyword arguments forwarded to ``plt.imshow``. ``cmap`` defaults
        to ``mpl.cm.binary`` but may be overridden here.

    Raises
    ------
    ValueError
        If ``instances`` is empty or ``images_per_row`` is smaller than 1.
    """
    image_size = 28
    # A DataFrame would iterate over its column labels; an array iterates over rows.
    instances = np.asarray(instances)
    n_instances = len(instances)
    if n_instances == 0:
        raise ValueError('plot_digits needs at least one image')
    if images_per_row < 1:
        raise ValueError('images_per_row must be at least 1')

    images_per_row = min(n_instances, images_per_row)
    images = [instance.reshape(image_size, image_size) for instance in instances]
    n_rows = (n_instances - 1) // images_per_row + 1  # ceiling division

    # One blank strip pads the last row to full width; when the grid is
    # exactly full it has zero width and falls outside every row slice.
    n_empty = n_rows * images_per_row - n_instances
    images.append(np.zeros((image_size, image_size * n_empty)))

    row_images = [
        np.concatenate(images[row * images_per_row:(row + 1) * images_per_row], axis=1)
        for row in range(n_rows)
    ]
    image = np.concatenate(row_images, axis=0)

    options.setdefault('cmap', mpl.cm.binary)
    plt.imshow(image, **options)
    plt.axis('off')

# plot digits: the first 200 rows of X, 10 per grid row, on a 9x9 figure
example_images = X[:200]

plt.figure(figsize=(9, 9))
plot_digits(example_images, images_per_row=10)
save_fig('the_first_200_digits')
plt.show()

In [36]:
# the first 60,000 rows are used for training, the remaining rows for testing
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]


In [37]:
# print and draw the final image of the training set
the_last_digit_in_X_train = X_train[-1]
the_last_digit_in_X_train_image = np.reshape(the_last_digit_in_X_train, (28, 28))
print('the last digit in X_train is:', the_last_digit_in_X_train_image, sep='\n')

plt.imshow(the_last_digit_in_X_train_image, cmap='binary')
plt.axis('off')
plt.show()

In [39]:
# label of the final training image
the_last_value_in_y_train = y_train[-1]
print('the last digit in y_train is', '\t:',
      the_last_value_in_y_train)


# Training a binary classifier