# KNN for the MNIST dataset

In [1]:
from sklearn import metrics
from sklearn.model_selection import train_test_split

In [2]:
from sklearn.datasets import fetch_openml
x, y = fetch_openml(name='mnist_784', return_X_y=True)

In [3]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, shuffle=False)
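With `shuffle=False`, `train_test_split` does not sample at random. The first 60% of rows become the training set and the remaining 40% the test set. Assuming MNIST's 70,000 samples, that gives 42,000 training and 28,000 test images, which matches the `support` total of 28000 in the reports below. Here is a minimal sketch that uses a stand-in array of row indices instead of the images:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for MNIST: 70,000 row indices instead of 784-pixel images.
idx = np.arange(70_000)
idx_train, idx_test = train_test_split(idx, test_size=0.4, shuffle=False)

print(len(idx_train), len(idx_test))  # 42000 28000
print(idx_train[-1], idx_test[0])     # 41999 42000: a contiguous, ordered split
```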

Intel Extension for Scikit-learn (previously known as daal4py) contains drop-in replacement functionality for the stock scikit-learn package. You can take advantage of its performance optimizations by adding just two lines of code before importing the scikit-learn estimators you want to accelerate:

In [4]:
from sklearnex import patch_sklearn
patch_sklearn()

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
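The global `patch_sklearn()` call is one option. scikit-learn-intelex also exposes its estimators directly under `sklearnex`, for example `sklearnex.neighbors.KNeighborsClassifier`. The sketch below assumes the package may be absent. It prefers the optimized class, falls back to stock scikit-learn otherwise, and patches nothing globally. The flag name `ACCELERATED` is illustrative.

```python
# Prefer the Intel-optimized estimator when scikit-learn-intelex is installed;
# otherwise fall back to the stock scikit-learn class. Importing from sklearnex
# directly leaves the rest of scikit-learn unpatched.
try:
    from sklearnex.neighbors import KNeighborsClassifier
    ACCELERATED = True
except ImportError:
    from sklearn.neighbors import KNeighborsClassifier
    ACCELERATED = False

print("sklearnex available:", ACCELERATED)
print("estimator module:", KNeighborsClassifier.__module__)
```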


Choose the parameters for the algorithm. The full list of algorithms and parameters supported by Intel Extension for Scikit-learn is available [here](https://intel.github.io/scikit-learn-intelex/algorithms.html). If a parameter is not supported, the extension falls back to the original scikit-learn implementation.

In [6]:
params = {
    'n_neighbors': 8,
    'algorithm': 'brute',
}
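With `algorithm='brute'`, KNN computes the distance from each test sample to every training sample, keeps the `n_neighbors` closest ones, and predicts their majority label. The NumPy function below is an illustrative re-implementation of that idea. The name `knn_predict_brute` is hypothetical, and this is not the scikit-learn or oneDAL code. It is checked against `KNeighborsClassifier` on scikit-learn's bundled 8x8 digits dataset, used here as an assumed small stand-in for MNIST.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def knn_predict_brute(x_train, y_train, x_query, k=8):
    """Hypothetical helper: exhaustive distances plus a majority vote."""
    # Squared Euclidean distance for every (query, train) pair:
    # ||q - t||^2 = ||q||^2 - 2 q.t + ||t||^2
    d2 = (
        (x_query ** 2).sum(axis=1)[:, None]
        - 2.0 * x_query @ x_train.T
        + (x_train ** 2).sum(axis=1)[None, :]
    )
    # Column indices of the k smallest distances in each row (unordered).
    neigh = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
    classes, y_codes = np.unique(y_train, return_inverse=True)
    n_query = x_query.shape[0]
    votes = np.zeros((n_query, classes.size), dtype=int)
    rows = np.arange(n_query)
    for j in range(k):  # one vote per neighbor
        votes[rows, y_codes[neigh[:, j]]] += 1
    # argmax returns the first maximum, i.e. the smallest label on ties.
    return classes[votes.argmax(axis=1)]


x, y = load_digits(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, shuffle=False)

params = {"n_neighbors": 8, "algorithm": "brute"}
sk_pred = KNeighborsClassifier(**params).fit(x_train, y_train).predict(x_test)
np_pred = knn_predict_brute(x_train, y_train, x_test, k=params["n_neighbors"])

agreement = (sk_pred == np_pred).mean()
accuracy = (np_pred == y_test).mean()
print(f"agreement with scikit-learn: {agreement:.4f}, accuracy: {accuracy:.4f}")
```

This sketch materializes the full distance matrix. For the MNIST split above, that would be 28,000 × 42,000 float64 values, roughly 9.4 GB. Production implementations therefore process the queries in chunks.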

Train the KNN classifier with Intel(R) Extension for Scikit-learn:

In [7]:
%%time
from sklearn.neighbors import KNeighborsClassifier
classifier = KNeighborsClassifier(**params).fit(x_train, y_train)

CPU times: user 341 ms, sys: 293 ms, total: 633 ms
Wall time: 371 ms
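In stock scikit-learn, fitting a brute-force KNN model builds no search structure. `fit` validates and stores the training set, and the neighbor search happens later in `predict`. This is consistent with fitting taking well under a second here while prediction dominates the runtime. The sketch below inspects what `fit` recorded, using the bundled digits dataset as an assumed stand-in for MNIST.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Bundled 8x8 digits dataset as a small stand-in for MNIST.
x, y = load_digits(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, shuffle=False)

clf = KNeighborsClassifier(n_neighbors=8, algorithm="brute").fit(x_train, y_train)

# No tree is built for algorithm="brute": fit() mainly stores the training data.
print(clf.n_samples_fit_)  # 1078 stored training samples
print(clf.classes_)        # [0 1 2 3 4 5 6 7 8 9]
```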


Predict and get the classification report of the KNN algorithm with Intel(R) Extension for Scikit-learn:

In [8]:
%%time
predicted = classifier.predict(x_test)
report = metrics.classification_report(y_test, predicted)
print(f"Classification report for KNN:\n{report}\n")

Classification report for KNN:
              precision    recall  f1-score   support

           0       0.97      0.99      0.98      2760
           1       0.94      1.00      0.97      3078
           2       0.98      0.95      0.97      2843
           3       0.96      0.97      0.96      2873
           4       0.98      0.96      0.97      2725
           5       0.96      0.96      0.96      2529
           6       0.98      0.99      0.98      2696
           7       0.96      0.96      0.96      2963
           8       0.99      0.92      0.95      2785
           9       0.95      0.95      0.95      2748

    accuracy                           0.97     28000
   macro avg       0.97      0.96      0.97     28000
weighted avg       0.97      0.97      0.97     28000


CPU times: user 36.2 s, sys: 395 ms, total: 36.6 s
Wall time: 1.91 s
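Each row of the report uses the standard per-class definitions: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 = 2PR / (P + R). `macro avg` is the unweighted mean over classes, and `weighted avg` weights each class by its support. The sketch below checks these definitions against `precision_recall_fscore_support` on six hypothetical labels, which are not MNIST data. The helper `per_class` is illustrative.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Hypothetical toy labels (strings, like the MNIST targets from fetch_openml).
y_true = np.array(["0", "0", "1", "1", "1", "2"])
y_pred = np.array(["0", "1", "1", "1", "2", "2"])
labels = ["0", "1", "2"]


def per_class(label):
    """Precision, recall, F1 and support for one class from TP/FP/FN counts."""
    tp = np.sum((y_pred == label) & (y_true == label))
    fp = np.sum((y_pred == label) & (y_true != label))
    fn = np.sum((y_pred != label) & (y_true == label))
    p, r = tp / (tp + fp), tp / (tp + fn)
    return p, r, 2 * p * r / (p + r), np.sum(y_true == label)


# Rows: classes; columns: precision, recall, f1, support.
manual = np.array([per_class(c) for c in labels])
p, r, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=labels)

macro_precision = manual[:, 0].mean()                                # unweighted mean
weighted_precision = np.average(manual[:, 0], weights=manual[:, 3])  # support-weighted
print(round(macro_precision, 4), round(weighted_precision, 4))
```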


To cancel the optimizations, we call *unpatch_sklearn* and re-import the KNeighborsClassifier class.

In [9]:
from sklearnex import unpatch_sklearn
unpatch_sklearn()
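`unpatch_sklearn()` restores the original classes in the `sklearn` modules. A name that was already imported, however, still refers to the class it was bound to, so the estimator has to be imported again. The sketch below shows the round trip. It assumes scikit-learn-intelex may not be installed, and the aliases `PatchedKNN` and `StockKNN` are illustrative.

```python
# Optional dependency: without sklearnex the patch/unpatch round trip is skipped.
try:
    from sklearnex import patch_sklearn, unpatch_sklearn
except ImportError:
    patch_sklearn = unpatch_sklearn = None

if patch_sklearn is not None:
    patch_sklearn()
    # Bound while patched: this name keeps pointing at the optimized class.
    from sklearn.neighbors import KNeighborsClassifier as PatchedKNN
    unpatch_sklearn()
    print("bound during patching:", PatchedKNN.__module__)

# Imported after unpatch_sklearn() (or without sklearnex): stock scikit-learn.
from sklearn.neighbors import KNeighborsClassifier as StockKNN
print("imported after unpatching:", StockKNN.__module__)
```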

In [10]:
%%time
from sklearn.neighbors import KNeighborsClassifier
classifier = KNeighborsClassifier(**params).fit(x_train, y_train)

CPU times: user 91 ms, sys: 1.18 ms, total: 92.2 ms
Wall time: 91.6 ms


In [11]:
%%time
predicted = classifier.predict(x_test)
report = metrics.classification_report(y_test, predicted)
print(f"Classification report for KNN:\n{report}\n")

Classification report for KNN:
              precision    recall  f1-score   support

           0       0.97      0.99      0.98      2760
           1       0.94      1.00      0.97      3078
           2       0.98      0.95      0.97      2843
           3       0.96      0.97      0.96      2873
           4       0.98      0.96      0.97      2725
           5       0.96      0.96      0.96      2529
           6       0.98      0.99      0.98      2696
           7       0.96      0.96      0.96      2963
           8       0.99      0.92      0.95      2785
           9       0.95      0.95      0.95      2748

    accuracy                           0.97     28000
   macro avg       0.97      0.96      0.97     28000
weighted avg       0.97      0.97      0.97     28000


CPU times: user 1min 24s, sys: 2min 58s, total: 4min 23s
Wall time: 29.1 s
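The `%%time` outputs can be turned into speedup ratios by dividing the stock wall time by the patched wall time. These are single-run measurements copied from the cells above, so treat the ratios as indicative only.

```python
# Wall-clock times reported by %%time in the cells above, in seconds.
wall = {
    "fit":     {"patched": 0.371, "stock": 0.0916},
    "predict": {"patched": 1.91,  "stock": 29.1},
}


def speedup(stage):
    """Stock time / patched time; a value > 1 means the patched run was faster."""
    return wall[stage]["stock"] / wall[stage]["patched"]


total = sum(w["stock"] for w in wall.values()) / sum(w["patched"] for w in wall.values())
for stage in wall:
    print(f"{stage}: {speedup(stage):.2f}x")
print(f"fit + predict: {total:.2f}x")
# fit: 0.25x
# predict: 15.24x
# fit + predict: 12.80x
```

With these numbers, the patched fit is slower. Because prediction dominates the total time, the end-to-end run is still about 12.8 times faster.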


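With scikit-learn-intelex patching you can:

- Use your scikit-learn code for training and inference without modification.
- Speed up scikit-learn workloads: in this example, KNN prediction wall time dropped from 29.1 s to 1.91 s, while fitting took under 0.4 s with or without the patch.
- Get the same prediction quality: the classification reports with and without patching are identical.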
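The claim about prediction quality can also be checked by comparing the two sets of predictions directly, rather than the printed reports. The sketch below again uses the bundled digits dataset as an assumed stand-in for MNIST. When scikit-learn-intelex is not installed, it falls back to comparing stock scikit-learn with itself. The aliases `StockKNN` and `IntelKNN` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as StockKNN

try:  # optional dependency
    from sklearnex.neighbors import KNeighborsClassifier as IntelKNN
except ImportError:
    IntelKNN = StockKNN  # fallback: the comparison becomes trivially equal

x, y = load_digits(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, shuffle=False)
params = {"n_neighbors": 8, "algorithm": "brute"}

pred_stock = StockKNN(**params).fit(x_train, y_train).predict(x_test)
pred_intel = IntelKNN(**params).fit(x_train, y_train).predict(x_test)

agreement = np.mean(pred_stock == pred_intel)
print(f"prediction agreement: {agreement:.4f}")
```

Brute-force search is exact, so full agreement is expected. Neighbors at exactly equal distances could in principle be ordered differently by the two implementations, which is why the check reports a fraction rather than requiring equality.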