In [None]:
# Nearest centroid model

# Import libraries
import keras
from keras.datasets import mnist
from numpy import mean
from numpy import std
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.neighbors import NearestCentroid

In [None]:
# Load the MNIST digits; Keras returns them already split into a training
# pair and a test pair of (images, labels).
(x, y), (x_test, y_test) = mnist.load_data()

# Flatten each 28x28 image into a single 784-long row so the estimator gets a
# 2-D feature matrix, and divide by 255 to rescale the pixel intensities
# (presumably uint8 in [0, 255], per the Keras dataset docs) to [0, 1].
X = x.reshape(-1, 28*28)/255.0
X_test = x_test.reshape(-1, 28*28)/255.0

# Sanity checks in place of ad-hoc shape prints: one flattened row per label.
assert X.shape == (len(y), 28*28), X.shape
assert X_test.shape == (len(y_test), 28*28), X_test.shape

In [None]:
# Compare NearestCentroid under two distance metrics using repeated stratified
# k-fold cross-validation on the training data. The mean accuracy per metric
# is collected in res_dict, keyed by metric name.
metric_name = ['euclidean', 'manhattan']
res_dict = {}

# define model evaluation method
# The splitter does not depend on the metric, so it is built once outside the
# loop. With an integer random_state it yields the same folds on every use, so
# both metrics are scored on identical partitions.
# NOTE(review): 60 splits x 5 repeats = 300 fits per metric; confirm that this
# many folds is intentional.
cv = RepeatedStratifiedKFold(n_splits=60, n_repeats=5, random_state=1)

for metric in metric_name:
  # define model
  model = NearestCentroid(metric=metric)

  # evaluate model (n_jobs=-1 runs the folds on all available cores)
  scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
  mean_accuracy = mean(scores)

  # summarize result: mean accuracy with its standard deviation in parentheses
  print("Name: " + metric)
  print('Mean Accuracy: %.3f (%.3f)' % (mean_accuracy, std(scores)))

  res_dict[metric] = mean_accuracy

Name: euclidean
Mean Accuracy: 0.807 (0.012)
Name: manhattan
Mean Accuracy: 0.741 (0.013)


In [None]:
# Fit a nearest centroid model on the full training set and evaluate it on the
# held-out test set.

# define model
# 'euclidean' is scikit-learn's default metric for NearestCentroid; it is
# spelled out here so the choice is visible to the reader.
model = NearestCentroid(metric='euclidean')

# fit model
model.fit(X, y)

# make a prediction (computed once and reused for the report below)
y_pred = model.predict(X_test)

# Printing Accuracy on Training and Test sets
print(f"Training Set Score : {model.score(X, y) * 100} %")
print(f"Test Set Score : {model.score(X_test, y_test) * 100} %")

# Printing classification report of classifier on the test set data
print(f"Model Classification Report : \n{classification_report(y_test, y_pred)}")

Training Set Score : 80.79833333333333 %
Test Set Score : 82.03 %
Model Classification Report : 
              precision    recall  f1-score   support

           0       0.91      0.90      0.90       980
           1       0.77      0.96      0.86      1135
           2       0.88      0.76      0.81      1032
           3       0.77      0.81      0.78      1010
           4       0.80      0.83      0.81       982
           5       0.75      0.69      0.72       892
           6       0.88      0.86      0.87       958
           7       0.91      0.83      0.87      1028
           8       0.79      0.74      0.76       974
           9       0.77      0.81      0.79      1009

    accuracy                           0.82     10000
   macro avg       0.82      0.82      0.82     10000
weighted avg       0.82      0.82      0.82     10000

