In [1]:
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

import utils

In [2]:
RANDOM_SEED = 42

In [3]:
(x_train, y_train), (x_test, y_test) = utils.load_mnist()

## SVC

In [4]:
svm_clf = SVC(gamma='auto', random_state=RANDOM_SEED)
svm_clf.fit(x_train[:1000], y_train[:1000])

SVC(gamma='auto', random_state=42)

In [5]:
svm_clf.classes_

array(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype=object)

In [6]:
test_digit = x_test[0]
test_label = y_test[0]
print(f"test_digit.shape = {test_digit.shape}")
print(f"test_label = {test_label}")

test_digit.shape = (784,)
test_label = 7


In [7]:
svm_clf.predict([test_digit])

array(['7'], dtype=object)

In [8]:
scores = svm_clf.decision_function([test_digit])
print(f"scores = {scores}")
np.argmax(scores)

scores = [[ 3.93672053  8.19324105  4.98585718  1.87180719  7.1107916   0.85998001
   2.88509577  9.19744483 -0.18355263  6.01561605]]


7

## OneVsRestClassifier

In [9]:
ovr_clf = OneVsRestClassifier(SVC(gamma='auto', random_state=RANDOM_SEED))
ovr_clf.fit(x_train[:1000], y_train[:1000])

OneVsRestClassifier(estimator=SVC(gamma='auto', random_state=42))

In [10]:
len(ovr_clf.estimators_)

10

In [11]:
ovr_clf.predict([test_digit])

array(['7'], dtype='<U1')

## SGDClassifier

In [12]:
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=RANDOM_SEED)
sgd_clf.fit(x_train[:1000], y_train[:1000])

SGDClassifier(random_state=42)

In [13]:
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=RANDOM_SEED)
sgd_clf.fit(x_train[:1000], y_train[:1000])

SGDClassifier(random_state=42)

In [14]:
sgd_clf.predict([test_digit])

array(['7'], dtype='<U1')

In [15]:
scores = sgd_clf.decision_function([test_digit])
print(f"scores = {scores}")
np.argmax(scores)

scores = [[ -2523798.45804894 -11724109.03203037  -1435338.01340555
   -1949321.83249115  -3680354.32196714  -3946274.93960771
   -8801153.91738122   2872244.10888638  -1546346.83991382
   -1654867.61994809]]


7

In [16]:
cross_val_score(sgd_clf, x_train[:1000], y_train[:1000], cv=3, scoring='accuracy')

array([0.83233533, 0.82882883, 0.79279279])