# In [3]:
import gzip
import struct

import matplotlib.pyplot as plt
import numpy as np

def _read_idx(path, ndim):
    """Read one IDX file into a writable ``uint8`` array shaped by its header.

    The file may be gzip-compressed or raw; compression is detected from the
    two-byte gzip signature instead of being guessed from the file name.

    Args:
        path: Location of the IDX file.
        ndim: Number of dimensions the file is expected to declare.

    Returns:
        A ``uint8`` array whose shape is the dimension list stored in the
        header.

    Raises:
        ValueError: If the header is not that of a ``uint8`` IDX file with
            ``ndim`` dimensions, or the payload size does not match it.
    """
    with open(path, 'rb') as f:
        compressed = f.read(2) == b'\x1f\x8b'
    opener = gzip.open if compressed else open
    with opener(path, 'rb') as f:
        payload = f.read()

    if len(payload) < 4:
        raise ValueError('{}: too short to hold an IDX header'.format(path))

    # The 4-byte big-endian magic number is: two zero bytes, a data-type code
    # (0x08 = unsigned byte) and the number of dimensions.  One big-endian
    # uint32 per dimension follows.
    zeros, dtype_code, header_ndim = struct.unpack_from(">HBB", payload, 0)
    if zeros != 0 or dtype_code != 0x08 or header_ndim != ndim:
        raise ValueError(
            '{}: expected a uint8 IDX file with {} dimension(s), '
            'got magic bytes {!r}'.format(path, ndim, payload[:4])
        )
    dims = struct.unpack_from(">" + "I" * ndim, payload, 4)
    data = np.frombuffer(payload, dtype=np.uint8, offset=4 + 4 * ndim)

    # frombuffer returns a read-only view of ``payload``; copy so that callers
    # receive an ordinary writable array, as np.fromfile used to give them.
    return data.reshape(dims).copy()


def load_data():
    """Load the MNIST train/test images and labels from the working directory.

    Reads ``train-images-idx3-ubyte.gz``, ``train-labels-idx1-ubyte.gz``,
    ``t10k-images-idx3-ubyte.gz`` and ``t10k-labels-idx1-ubyte.gz``.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.  Image arrays are ``uint8`` of
        shape ``(num, rows * cols)``; label arrays are ``uint8`` of shape
        ``(num,)``.

    Raises:
        ValueError: If any file is not a well-formed IDX file of the expected
            rank.
    """
    train_images = _read_idx('train-images-idx3-ubyte.gz', ndim=3)
    y_train = _read_idx('train-labels-idx1-ubyte.gz', ndim=1)
    test_images = _read_idx('t10k-images-idx3-ubyte.gz', ndim=3)
    y_test = _read_idx('t10k-labels-idx1-ubyte.gz', ndim=1)

    # Flatten each rows x cols image into a single feature vector.
    X_train = train_images.reshape(
        train_images.shape[0], train_images.shape[1] * train_images.shape[2])
    X_test = test_images.reshape(
        test_images.shape[0], test_images.shape[1] * test_images.shape[2])

    return X_train, y_train, X_test, y_test

def distance(X, Y, p=2):
    """Return the distance between ``X`` and ``Y`` along their last axis.

    ``X`` and ``Y`` are broadcast against each other, so a ``(n, d)`` array can
    be compared with a ``(n, d)`` array, a single ``(d,)`` vector, or a
    ``(m, 1, d)`` array (giving an ``(m, n)`` result).

    Args:
        X: Array whose last axis holds the features.
        Y: Array broadcastable against ``X``.
        p: ``1`` selects the Manhattan distance and ``2`` the Euclidean
            distance.  Any other value selects the Chebyshev (max-norm)
            distance.

    Returns:
        A floating-point array of distances with the last axis reduced away.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)

    # Subtract in floating point: with unsigned inputs such as uint8 pixels,
    # ``X - Y`` would wrap around (0 - 5 == 251) and corrupt every distance.
    work_dtype = np.result_type(X.dtype, Y.dtype, np.float32)
    gap = np.abs(X.astype(work_dtype, copy=False)
                 - Y.astype(work_dtype, copy=False))

    # Reduce over the last axis, not axis 1, so broadcast (3-D) inputs collapse
    # the feature axis; for plain 2-D inputs the two are the same axis.
    if p == 1:
        return np.sum(gap, axis=-1)
    if p == 2:
        return np.sqrt(np.sum(np.square(gap), axis=-1))
    return np.max(gap, axis=-1)

def knn_predict(X_train, y_train, X_test, k=3, p=2):
    """Classify each row of ``X_test`` by majority vote of its k nearest rows.

    Distances are computed one query row at a time, so peak memory is on the
    order of one extra copy of ``X_train`` instead of an
    ``(n_test, n_train, n_features)`` array.

    Args:
        X_train: ``(n_train, n_features)`` training samples.
        y_train: ``(n_train,)`` non-negative integer class labels.
        X_test: ``(n_test, n_features)`` samples to classify.
        k: Number of neighbours that vote; must satisfy
            ``1 <= k <= n_train``.
        p: Metric selector forwarded to ``distance``.

    Returns:
        ``(n_test,)`` integer array of predicted labels.  Ties are resolved in
        favour of the smallest label.

    Raises:
        ValueError: If ``k`` is outside ``1..n_train``.
    """
    X_train = np.asarray(X_train)
    X_test = np.asarray(X_test)
    y_train = np.asarray(y_train)

    # Promote once, up front: unsigned pixel data would otherwise wrap around
    # when one sample is subtracted from another.
    work_dtype = np.result_type(X_train.dtype, X_test.dtype, np.float32)
    X_train = X_train.astype(work_dtype, copy=False)
    X_test = X_test.astype(work_dtype, copy=False)

    n_train = X_train.shape[0]
    if not 1 <= k <= n_train:
        raise ValueError(
            'k must be between 1 and the number of training samples '
            '({}), got {}'.format(n_train, k)
        )

    # Keep at least 10 vote bins (the digit classes), more if labels need it.
    n_classes = max(10, int(y_train.max()) + 1)

    y_pred = np.empty(X_test.shape[0], dtype=np.intp)
    for i, query in enumerate(X_test):
        # (n_train, d) against (d,) keeps the feature axis at position 1.
        dists = distance(X_train, query, p=p)
        # kth = k - 1 puts the k smallest distances in the first k slots.
        nearest = np.argpartition(dists, k - 1)[:k]
        votes = np.bincount(y_train[nearest], minlength=n_classes)
        y_pred[i] = np.argmax(votes)
    return y_pred

def accuracy(y_pred, y_true):
    """Return the percentage of predictions that equal the true labels."""
    hits = y_pred == y_true
    hit_rate = np.mean(hits)
    return hit_rate * 100

# Load the four MNIST arrays from the IDX files in the working directory.
X_train, y_train, X_test, y_test = load_data()

# Evaluate every combination of neighbour count k and metric selector p.
ks = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
ps = [1, 2, 3]

# distance() special-cases only p=1 and p=2; every other value, including the
# 3 used here, selects the Chebyshev (max-norm) distance.  Say so in the
# legend instead of implying a Minkowski distance of order 3.
metric_labels = {
    1: 'p=1 (Manhattan)',
    2: 'p=2 (Euclidean)',
    3: 'p=3 (Chebyshev / max-norm)',
}

# accuracies[i, j] is the test accuracy in percent for ks[i] and ps[j].
accuracies = np.array([
    [accuracy(knn_predict(X_train, y_train, X_test, k=k, p=p), y_test)
     for p in ps]
    for k in ks
])

# Plot one accuracy-vs-k curve per metric, placed at y = p on a 3-D axis.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for j, p in enumerate(ps):
    ax.plot(ks, [p] * len(ks), accuracies[:, j],
            label=metric_labels.get(p, 'p={}'.format(p)))
ax.set_xlabel('k')
ax.set_ylabel('p')
ax.set_zlabel('Accuracy (%)')
plt.legend()
plt.show()


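# Recorded output of the original run. The .gz files were parsed without
# being decompressed, so the header fields read from them are garbage:
# ValueError: cannot reshape array of size 9912406 into shape (2055376946,370034809926666)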