## k-NN を使った手書き文字認識
    
### MNIST Data のダウンロード

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from sklearn.utils import shuffle
from sklearn.cross_validation import train_test_split
from sklearn.metrics import f1_score

In [None]:
# Download (or load from the local cache) the MNIST dataset; the result is
# read below through its .data and .target attributes.
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed
# in 0.22 (mldata.org is offline); newer versions provide fetch_openml
# instead. Confirm the scikit-learn version this notebook targets.
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')

In [None]:
# Shuffle samples and labels together so they stay aligned.
X, y = shuffle(mnist.data, mnist.target)
# Scale the features by 1/255
# (assumes mnist.data holds 8-bit pixel intensities in 0..255 — TODO confirm).
X = X / 255.0
# First split: hold out 20% of the data as the test set.
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2)
# Second split: carve 20% of the remaining training data off as the dev
# (hold-out) set, i.e. roughly 64% train / 16% dev / 20% test overall.
# NOTE(review): no random_state is passed to shuffle or train_test_split,
# so the partitions (and every score below) differ from run to run.
train_X, dev_X , train_y, dev_y = train_test_split(train_X, train_y, test_size=0.2)

In [None]:
# Preview the first 81 training samples as a 9x9 grid of grayscale images.
fig = plt.figure(figsize=(9,9))
fig.subplots_adjust(left=0, right=1, bottom=0, top=0.5, hspace=0.05, wspace=0.05)
for i in range(81):
    # Subplot positions are 1-based; axis ticks are hidden on every tile.
    ax = fig.add_subplot(9, 9, i + 1, xticks=[], yticks=[])
    # Each sample is a flat vector; reshape it to 28x28 pixels for display.
    ax.imshow(train_X[i].reshape((28,28)), cmap='gray')

### Cosine 類似度

In [None]:
import numpy
# L2 norm of every row of train_X (one value per training sample).
norm = numpy.linalg.norm(train_X, ord=2, axis=1)
# Divide each row by its own norm so every sample becomes a unit vector;
# the dot product of two such rows is then their cosine similarity.
# NOTE(review): a row whose norm is 0 would yield NaN/inf here — confirm
# that no all-zero sample exists in the data.
normalized_train_X = train_X / norm[:,numpy.newaxis]

In [None]:
# Worked example: cosine similarity between the first two training samples.
# Both rows were already divided by their L2 norm, so a plain dot product
# is all that is needed.
sample_1 = normalized_train_X[0]
sample_2 = normalized_train_X[1]
print(numpy.dot(sample_1, sample_2))

### k-NN でテストデータに対する予測

In [None]:
def most_common(lst):
    """Return the element that occurs most often in ``lst``.

    Args:
        lst: A list of hashable elements (they are placed in a ``set``).

    Returns:
        The element with the highest number of occurrences. When several
        elements share the highest count, the one that comes first in the
        iteration order of ``set(lst)`` is returned, so the tie-break is
        implementation-defined.

    Raises:
        ValueError: If ``lst`` is empty.
    """
    # Local import keeps this notebook cell self-contained.
    from collections import Counter

    # Tally every element in a single pass instead of re-scanning the list
    # with lst.count() once per distinct value.
    counts = Counter(lst)
    return max(set(lst), key=counts.__getitem__)

In [None]:
def kNN(ranking, labels, k=10):
    """Predict one label per query by majority vote over its first k neighbours.

    Args:
        ranking: 2-D array sliced as ``ranking[:, :k]``; each row holds
            indices into ``labels`` for one query
            (assumed to list the nearest neighbours first — confirm with caller).
        labels: Array that supports indexing with a row of ``ranking`` and
            whose result provides ``.tolist()``.
        k: Number of leading columns of ``ranking`` that take part in the vote.

    Returns:
        A list with one predicted label per row of ``ranking``.
    """
    predictions = []
    for neighbour_ids in ranking[:, :k]:
        # Look up the labels of this query's k candidate neighbours.
        votes = labels[neighbour_ids].tolist()
        predictions.append(most_common(votes))
    return predictions

#### ホールドアウトセットを利用して最適なkを求める

In [None]:
# Scale every dev sample to unit L2 norm, exactly as done for the training set.
normalized_dev_X = dev_X / numpy.linalg.norm(dev_X, ord=2, axis=1)[:, numpy.newaxis]
# The matrix product is the cosine similarity of every dev sample against
# every training sample. Negating it before argsort yields, for each dev
# row, the training indices ordered from most similar to least similar.
# NOTE(review): this materialises and fully sorts an (n_dev, n_train) matrix,
# which is memory- and time-hungry; only the leading columns are read later.
ranking_dev_idx = numpy.argsort(-normalized_dev_X.dot(normalized_train_X.T))

In [None]:
# Sweep k = 1..99 and keep the k with the highest macro-averaged F1 on the
# dev set. The comparison is strict, so on a tie the smallest k wins.
best_k, best_f1 = -1, -1
for k in range(1, 100):
    pred_dev_y = kNN(ranking_dev_idx, train_y, k)
    f1 = f1_score(dev_y, pred_dev_y, average="macro")
    if f1 > best_f1:
        best_k, best_f1 = k, f1

print(best_k)

#### 求めたkを用いて、テストセットのスコアを計算する

In [None]:
# Scale every test sample to unit L2 norm, exactly as done for the training set.
normalized_test_X = test_X / numpy.linalg.norm(test_X, ord=2, axis=1)[:, numpy.newaxis]
# Partial sort of the negated cosine similarities: argpartition with
# kth=best_k places the indices of the best_k most similar training samples
# in the first best_k columns of each row, in no particular order among
# themselves. That suffices here because kNN only reads ranking[:, :best_k]
# and takes a majority vote, and it avoids fully sorting every row.
ranking_test_idx = numpy.argpartition(-normalized_test_X.dot(normalized_train_X.T), best_k)

In [None]:
# Predict test labels by majority vote over the best_k selected training
# neighbours, using the k chosen on the dev set.
pred_test_y = kNN(ranking_test_idx, train_y, best_k)

In [None]:
# Macro-averaged F1 score of the test-set predictions.
f1_score(test_y, pred_test_y, average="macro")