In [1]:
import joblib

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

In [2]:
# Download the "mnist_784" dataset (version 1) from OpenML as plain arrays
# (as_frame=False) and unpack it directly into features X and labels y.
X, y = fetch_openml(
    name="mnist_784",
    version=1,
    return_X_y=True,
    as_frame=False,
    parser="pandas",
)

# Rescale the features by 255.0.
# NOTE(review): presumably raw pixel intensities are 0-255, so this maps them
# into [0, 1] -- confirm against the dataset description.
X = X / 255.0

# Sanity-check the array shapes (one row per sample).
print(f"{X.shape = }", f"{y.shape = }", sep="\n")

X.shape = (70000, 784)
y.shape = (70000,)


In [3]:
# Partition the samples into a training set and a held-out test set.
# The fixed random_state keeps the shuffle identical across runs.
# NOTE(review): test_size=0.7 holds out 70% of the samples for testing and
# leaves only 30% for training -- confirm this is intended (e.g. to keep
# training fast) rather than a mix-up with train_size=0.7.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.7,
    random_state=0,
)

# Report the resulting shapes: training split first, then test split.
print(f"{X_train.shape = }", f"{y_train.shape = }", sep="\n")
print(f"{X_test.shape = }", f"{y_test.shape = }", sep="\n")

X_train.shape = (21000, 784)
y_train.shape = (21000,)
X_test.shape = (49000, 784)
y_test.shape = (49000,)


In [4]:
# Build the model: a multi-layer perceptron with two hidden layers of 512
# units each.
#
# random_state is pinned so that weight initialisation and minibatch
# shuffling are reproducible; without it every "Restart & Run All" produces a
# different loss curve, different scores and a different saved model.
# The seed value matches the one used for the train/test split.
classifier = MLPClassifier(
    hidden_layer_sizes=(512, 512),
    random_state=0,
    verbose=True,  # print the training loss after every iteration
)

# Last expression of the cell: display the (still unfitted) estimator.
classifier

In [5]:
# Train the network on the training split. Because the estimator was created
# with verbose=True, fit() logs the loss after each iteration and reports why
# it stopped (here: loss improvement below tol for 10 consecutive epochs).
classifier.fit(X=X_train, y=y_train)

Iteration 1, loss = 0.41124624
Iteration 2, loss = 0.14635527
Iteration 3, loss = 0.09125043
Iteration 4, loss = 0.05712700
Iteration 5, loss = 0.03806716
Iteration 6, loss = 0.02284707
Iteration 7, loss = 0.02058435
Iteration 8, loss = 0.01111803
Iteration 9, loss = 0.00804372
Iteration 10, loss = 0.00430449
Iteration 11, loss = 0.00200957
Iteration 12, loss = 0.00125787
Iteration 13, loss = 0.00102021
Iteration 14, loss = 0.00089952
Iteration 15, loss = 0.00082325
Iteration 16, loss = 0.00076793
Iteration 17, loss = 0.00072050
Iteration 18, loss = 0.00068483
Iteration 19, loss = 0.00064728
Iteration 20, loss = 0.00062035
Iteration 21, loss = 0.00059855
Iteration 22, loss = 0.00057463
Iteration 23, loss = 0.00055814
Iteration 24, loss = 0.00054120
Iteration 25, loss = 0.00052776
Training loss did not improve more than tol=0.000100 for 10 consecutive epochs. Stopping.


In [6]:
# Evaluate the fitted classifier on both splits and report each score as a
# percentage with two decimals. A large gap between the two numbers would
# indicate overfitting to the training data.
train_accuracy = classifier.score(X_train, y_train)
print(f"Train score: {train_accuracy * 100:.2f}%")

test_accuracy = classifier.score(X_test, y_test)
print(f"Test score: {test_accuracy * 100:.2f}%")

Train score: 100.00%
Test score: 97.40%


In [7]:
# Persist the fitted classifier to disk in the current working directory.
# joblib.dump returns the list of written file names, which the cell displays.
# NOTE: joblib files are pickle-based -- only load them from trusted sources.
joblib.dump(value=classifier, filename="mnist_classifier.pkl")

['mnist_classifier.pkl']