In [2]:
#手書き文字のデータセット(mnist)をインポート
from keras.datasets import mnist
import numpy as np
#データプロットのライブラリをインポート
import matplotlib.pyplot as plt
from collections import Counter

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
class KNearestNeighbors(object):
    """Brute-force k-nearest-neighbours classifier for one query point at a time.

    Training data is simply stored by ``fit``; ``predict`` measures the
    Euclidean distance from the query to every stored sample and returns the
    majority label among the ``k`` closest ones.
    """

    def __init__(self, k=1):
        """Create a classifier that votes among the ``k`` nearest samples."""
        self._train_data = None
        self._target_data = None
        self._k = k

    def fit(self, train_data, target_data):
        """Remember the training samples and their labels.

        k-NN has nothing to precompute, so fitting only stores the data.
        Inputs are converted to NumPy arrays so that plain Python lists work
        too (``predict`` indexes the labels with an index array).
        """
        self._train_data = np.asarray(train_data)
        self._target_data = np.asarray(target_data)

    def predict(self, x):
        """Return the predicted label for the single sample ``x``.

        ``x`` must have the same shape as one training sample.
        """
        # Euclidean distance from the query to every training sample.
        distances = np.array([self._distance(p, x) for p in self._train_data])
        # Indices of the k closest samples. A stable sort keeps equidistant
        # samples in training order, so tie-breaking is reproducible.
        nearest_indexes = distances.argsort(kind="mergesort")[:self._k]
        # Labels belonging to those nearest samples.
        nearest_labels = self._target_data[nearest_indexes]
        # Majority vote among the neighbours.
        c = Counter(nearest_labels)
        return c.most_common(1)[0][0]

    @staticmethod
    def _distance(p0, p1):
        """Euclidean (L2) distance between two samples of identical shape."""
        # This helper was referenced by predict() but missing from the class,
        # which made every prediction fail with AttributeError.
        return np.linalg.norm(np.asarray(p0) - np.asarray(p1))

In [4]:
# Load the MNIST train/test splits (images plus their digit labels).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Convert to float32 and divide by 255 (MNIST pixels are 8-bit, so values end up in [0, 1]),
# then append a trailing channel axis.
# Adapt the target shape if using the `channels_first` image data format.
X_train = (X_train.astype('float32') / 255.).reshape(-1, 28, 28, 1)
X_test = (X_test.astype('float32') / 255.).reshape(-1, 28, 28, 1)

In [5]:
# Flatten each image into one feature vector per sample (row count preserved).
X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)

# Show the resulting (samples, features) shapes, training split first.
for split in (X_train, X_test):
    print(split.shape)

(60000, 784)
(10000, 784)


## sklearnを用いた実装

In [8]:
from sklearn.neighbors import KNeighborsClassifier


# scikit-learn k-NN classifier; each prediction is a vote among the 6 nearest training samples.
model = KNeighborsClassifier(n_neighbors=6)

In [9]:
# Fit on the full flattened training set; the fitted estimator's repr is shown as the cell output.
model.fit(X_train, y_train)

KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=6, p=2,
           weights='uniform')

In [10]:
# Predict the digit for the first 10 test images only (quick sanity check).
model.predict(X_test[:10])

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=uint8)

## フルスクラッチ

In [6]:
class KNN(object):
    """From-scratch k-nearest-neighbours classifier using Euclidean distance.

    ``fit`` stores the training set; ``predict`` classifies each query row by
    majority vote among the labels of its ``k`` closest training samples.
    """

    def __init__(self, k):
        """Create a classifier voting among ``k`` neighbours.

        Raises:
            ValueError: if ``k`` is smaller than 1 (there would be no
                neighbours to vote, which previously surfaced as an opaque
                IndexError inside ``predict``).
        """
        if k < 1:
            raise ValueError("k must be a positive integer, got %r" % (k,))
        self.k = k

    def fit(self, X, y):
        """Store training samples ``X`` (one row per sample) and labels ``y``.

        Both are converted to NumPy arrays so plain lists are accepted;
        ``predict`` indexes the labels with an index array.
        """
        self.X = np.asarray(X)
        self.y = np.asarray(y)

    def predict(self, X_test):
        """Return a list with one predicted label per row of ``X_test``.

        A single 1-D sample is treated as a batch of one row instead of being
        iterated feature by feature.
        """
        queries = np.asarray(X_test)
        if queries.ndim == 1 and queries.size > 0:
            queries = queries[np.newaxis, :]

        ret = []
        for d in queries:
            # Distance from this query to every training row.
            distance = np.linalg.norm(self.X - d, axis=1)

            # Indices of the k closest rows; the stable sort keeps equidistant
            # samples in training order so ties resolve reproducibly.
            nearest = np.argsort(distance, kind="mergesort")[:self.k]

            # Majority vote over only those k labels.
            pred_y = Counter(self.y[nearest]).most_common(1)[0][0]
            ret.append(pred_y)
        return ret


In [7]:
# From-scratch classifier using only the single nearest neighbour (k=1).
# NOTE(review): this rebinds `model`, the same name the scikit-learn section uses,
# so results depend on which cell ran last.
model=KNN(k=1)

In [8]:
# Train on the first 10,000 samples only — presumably to keep the brute-force
# distance computation affordable; TODO confirm the intended subset size.
model.fit(X_train[:10000],y_train[:10000])

In [9]:
# Classify the first 1,000 test images; KNN.predict returns a plain list of labels.
pred_y = model.predict(X_test[:1000])
# Prints the whole list (1,000 values).
print(pred_y)

[7, 2, 1, 0, 9, 1, 9, 9, 5, 9, 0, 6, 9, 0, 1, 5, 4, 7, 3, 4, 9, 6, 6, 5, 9, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 1, 3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0, 2, 7, 1, 7, 3, 7, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 4, 3, 1, 4, 1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4, 8, 7, 3, 9, 7, 9, 4, 4, 9, 2, 5, 4, 7, 6, 7, 9, 0, 5, 8, 5, 6, 6, 5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1, 7, 1, 8, 2, 0, 1, 9, 8, 5, 5, 1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5, 1, 4, 4, 7, 2, 3, 2, 1, 1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1, 0, 4, 0, 1, 1, 6, 4, 2, 3, 6, 1, 1, 1, 3, 9, 5, 2, 9, 4, 5, 9, 3, 9, 0, 3, 6, 5, 3, 7, 2, 2, 7, 1, 2, 8, 4, 1, 7, 3, 3, 8, 7, 7, 9, 2, 2, 4, 1, 5, 8, 8, 7, 2, 6, 0, 6, 4, 2, 4, 1, 9, 5, 7, 7, 2, 1, 2, 6, 8, 5, 7, 7, 9, 1, 8, 1, 3, 0, 3, 0, 1, 9, 9, 4, 1, 8, 2, 1, 2, 9, 7, 5, 9, 2, 6, 4, 1, 5, 9, 2, 9, 2, 0, 4, 0, 0, 2, 8, 1, 7, 1, 2, 4, 0, 2, 7, 4, 3, 3, 0, 0, 3, 1, 9, 6, 5, 0, 5, 1, 7, 9, 3, 0, 4, 6, 0, 7, 1, 1, 2, 1, 

In [13]:
## Compute precision, recall, and F-score
# accuracy

from sklearn.metrics import accuracy_score

# Fraction of the first 1,000 test labels that pred_y matches exactly.
accuracy_score(y_test[:1000], pred_y)

0.92

In [14]:
from sklearn.metrics import precision_recall_fscore_support

# Per-class precision, recall, F-score and support for the same 1,000 predictions
# (returned as four arrays, one entry per class).
precision_recall_fscore_support(y_test[:1000], pred_y)

(array([0.93406593, 0.9057971 , 0.98076923, 0.91428571, 0.91509434,
        0.95238095, 0.93333333, 0.88461538, 0.93902439, 0.85416667]),
 array([1.        , 0.99206349, 0.87931034, 0.89719626, 0.88181818,
        0.91954023, 0.96551724, 0.92929293, 0.86516854, 0.87234043]),
 array([0.96590909, 0.9469697 , 0.92727273, 0.90566038, 0.89814815,
        0.93567251, 0.94915254, 0.90640394, 0.9005848 , 0.86315789]),
 array([ 85, 126, 116, 107, 110,  87,  87,  99,  89,  94]))