# Ch3.1 MNIST

In [1]:
from sklearn.datasets import fetch_openml
import numpy as np

# Fetch version 1 of the 'mnist_784' dataset from OpenML.
mnist = fetch_openml(name='mnist_784', version=1)

# The result is dict-like; list the fields it exposes.
mnist.keys()


dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])

In [2]:
# Dataset metadata as reported by OpenML (id, version, source URL, tags, checksum, ...).
mnist['details']

{'id': '554',
 'name': 'mnist_784',
 'version': '1',
 'format': 'ARFF',
 'upload_date': '2014-09-29T03:28:38',
 'licence': 'Public',
 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff',
 'file_id': '52667',
 'default_target_attribute': 'class',
 'tag': ['AzurePilot',
  'OpenML-CC18',
  'OpenML100',
  'study_1',
  'study_123',
  'study_41',
  'study_99',
  'vision'],
 'visibility': 'public',
 'status': 'active',
 'processing_date': '2018-10-03 21:23:30',
 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}

In [3]:
# Features and labels as plain NumPy arrays.
# np.asarray is a no-op when fetch_openml already returns ndarrays, and it converts
# the pandas DataFrame/Series that newer scikit-learn releases return by default,
# so the positional indexing used later (X[i], X_train[train_index]) keeps working.
X, y = np.asarray(mnist['data']), np.asarray(mnist['target'])


In [4]:
# One row per image, one column per pixel (784 values = a flattened 28x28 image).
X.shape

(70000, 784)

In [5]:
# One label per image.
y.shape

(70000,)

In [6]:
# reshape one row of data to a 28x28 image
import matplotlib as mpl
import matplotlib.pyplot as plt

# Display the first five samples: each 784-value row is laid out as a
# 28x28 image, and the matching label is printed right after the figure.
for i in range(5):
    some_digit = X[i]
    some_digit_image = np.reshape(some_digit, (28, 28))

    plt.imshow(some_digit_image, cmap='binary')
    plt.axis('off')  # hide ticks and frame; only the pixels matter
    plt.show()

    print(y[i])


<Figure size 640x480 with 1 Axes>

5


<Figure size 640x480 with 1 Axes>

0


<Figure size 640x480 with 1 Axes>

4


<Figure size 640x480 with 1 Axes>

1


<Figure size 640x480 with 1 Axes>

9


In [7]:
# Inspect a raw label: it is loaded as a string, not a number.
y[0]

'5'

In [8]:
# Cast the string labels to unsigned 8-bit integers so they can be
# compared with numeric digits later (e.g. y_train == 5).
y = y.astype(np.uint8)

In [9]:
# Same label after the cast: now an integer.
y[0]

5

In [10]:
# Split into training and test sets by position: the first N_TRAIN samples
# are used for training, everything after them for testing.
# NOTE(review): the usual MNIST split is 60,000 train / 10,000 test; this
# notebook trains on only the first 6,000 samples -- confirm that is intended.
N_TRAIN = 6000
X_train, X_test = X[:N_TRAIN], X[N_TRAIN:]
y_train, y_test = y[:N_TRAIN], y[N_TRAIN:]

# 3.2 training a binary classifier

In [11]:
# Binary targets for the "5-detector": True where the digit is 5,
# False for every other digit.
y_train_5 = y_train == 5
y_test_5 = y_test == 5

In [12]:
# Peek at the boolean target array.
y_train_5

array([ True, False, False, ..., False, False, False])

In [13]:
# The first training label is 5, so the first entry of the mask is True.
y_train_5[0]

True

In [14]:
# One boolean per training sample.
y_train_5.shape

(6000,)

In [15]:
# apply SGD classifier
from sklearn.linear_model import SGDClassifier

# Linear classifier trained with stochastic gradient descent; random_state is
# fixed so that re-running the cell yields the same model.
sgd_clf = SGDClassifier(random_state=42)
# Train the 5-vs-not-5 detector. fit's return value is the cell's last
# expression, so the estimator repr is displayed as the output.
sgd_clf.fit(X_train, y_train_5)




SGDClassifier(alpha=0.0001, average=False, class_weight=None,
       early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
       l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=None,
       n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
       power_t=0.5, random_state=42, shuffle=True, tol=None,
       validation_fraction=0.1, verbose=0, warm_start=False)

In [16]:
# Classify the first image (its label is 5). predict works on a 2-D batch,
# so the single sample is reshaped into a one-row array.
some_digit = X[0]
sgd_clf.predict(some_digit.reshape(1, -1))

array([ True])

# 3.3 performance measures

# 3.4 measuring accuracy using cross-validation

In [17]:
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

# Hand-rolled 3-fold stratified cross-validation of the 5-detector.
# random_state is deliberately not passed: without shuffle=True it has no
# effect on the folds, and recent scikit-learn releases raise a ValueError
# when random_state is set while shuffle is False.
skfolds = StratifiedKFold(n_splits=3)

for train_index, test_index in skfolds.split(X_train, y_train_5):
    # Fresh, unfitted copy with the same hyperparameters for every fold.
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]

    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]

    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)

    # Vectorised count of correct predictions (the builtin sum would
    # iterate over the NumPy array element by element in Python).
    n_correct = np.sum(y_pred == y_test_fold)

    # Fold accuracy: fraction of held-out samples predicted correctly.
    print(n_correct / len(y_pred))




0.9565217391304348
0.969
0.9459729864932466




In [18]:
from sklearn.model_selection import cross_val_score

# Same 3-fold accuracy evaluation, delegated to scikit-learn's helper.
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')




array([0.95652174, 0.969     , 0.94597299])

In [19]:
from sklearn.base import BaseEstimator

class Never5Classifier(BaseEstimator):
    """Baseline estimator that classifies every sample as "not 5".

    It learns nothing from the data; it exists only as a reference point
    for judging the accuracy of a real classifier on the 5-vs-not-5 task.
    """

    def fit(self, X, y=None):
        """Do nothing (there is nothing to learn) and return the estimator.

        Returning ``self`` follows the scikit-learn estimator convention
        and allows call chaining such as ``clf.fit(X).predict(X)``.
        """
        return self

    def predict(self, X):
        """Return an all-False boolean array of shape ``(len(X), 1)``."""
        return np.zeros((len(X), 1), dtype=bool)


In [20]:
# Evaluate the "never 5" baseline with the same 3-fold accuracy protocol.
# It scores above 90% without learning anything, because only roughly 9% of
# the training labels are 5 -- accuracy is a weak metric on a skewed target.
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring='accuracy')


array([0.91 , 0.915, 0.918])

# 3.5 confusion matrix

In [21]:
from sklearn.model_selection import cross_val_predict

# Collect a cross-validated prediction for every training sample (3 folds),
# to be compared against the true labels below.
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)




In [22]:
from sklearn.metrics import confusion_matrix

# Cross-tabulate the true labels against the cross-validated predictions
# (scikit-learn convention: rows = actual class, columns = predicted class).
confusion_matrix(y_train_5, y_train_pred)


array([[5374,  112],
       [ 145,  369]])