In [None]:
import os
import pickle as pkl
import time

import matplotlib.pyplot as plt
import numpy as np

from mps_classifier import MPSclassifier

In [None]:
def save(filename, obj):
    """Pickle ``obj`` to ``filename``.

    Missing parent directories (e.g. ``results/``) are created first, so an
    expensive training run is not lost to a ``FileNotFoundError`` at save time.

    Parameters
    ----------
    filename : str
        Destination path of the pickle file; an existing file is overwritten.
    obj : object
        Any picklable object.
    """
    parent = os.path.dirname(filename)
    if parent:
        # A bare filename has no parent directory, and os.makedirs('') would raise.
        os.makedirs(parent, exist_ok=True)
    with open(filename, 'wb') as f:
        pkl.dump(obj, f)

def load(filename):
    """Unpickle and return the object stored in ``filename``.

    Security note: unpickling can execute arbitrary code, so only call this on
    files this notebook wrote itself (e.g. via ``save``), never on untrusted input.
    """
    with open(filename, mode='rb') as handle:
        restored = pkl.load(handle)
    return restored

def load_data(filename, image_shape=(28, 28)):
    """Load a label-first, comma-separated image file into images and labels.

    Each row is ``label, pixel_0, ..., pixel_{k-1}``. There must be no text
    header row, because ``np.loadtxt`` cannot parse non-numeric fields.

    Parameters
    ----------
    filename : str
        Path to the CSV file.
    image_shape : tuple of int, optional
        Shape every row's pixels are reshaped into; defaults to MNIST's 28x28.

    Returns
    -------
    X : np.ndarray
        Float array of shape ``(n_rows, *image_shape)``.
    Y : np.ndarray
        Float array of shape ``(n_rows,)`` holding the first column.
    """
    # ndmin=2 keeps a single-row file 2-D; otherwise data[:, 1:] raises IndexError.
    data = np.loadtxt(filename, delimiter=',', ndmin=2)
    X = data[:, 1:].reshape(-1, *image_shape)
    Y = data[:, 0]
    return X, Y

In [None]:
# Load the train/test splits from CSV (label in column 0, pixels in the rest).
train_csv = 'data/mnist_train.csv'
test_csv = 'data/mnist_test.csv'
X, Y = load_data(train_csv)
X_test, Y_test = load_data(test_csv)

In [None]:
# chi is passed straight to MPSclassifier (presumably the MPS bond dimension).
# Extend the list with e.g. 20, 30, 40, 50, 60 to sweep more values.
chi_list = [10]
for chi in chi_list:
    t0 = time.time()
    model = MPSclassifier(chi)
    model.train(X, Y)
    elapsed_min = (time.time() - t0) / 60
    print('chi=%d' % chi, 'training time [min]: ', elapsed_min)
    # Persist each trained model so the evaluation cell can reload it later.
    save('results/chi%d_model.pkl' % (chi), model)

In [None]:
# Reload every saved model, score it on the test set, and cache the result.
for chi in chi_list:
    model = load('results/chi%d_model.pkl' % chi)
    start = time.time()

    # One accuracy per digit, judging by how the aggregation cell stores it in acc[:, i].
    accuracies = model.compute_accuracies(X_test, Y_test)

    # The elapsed time is divided by 60, so label it in minutes like the training cell does.
    print('chi=%d' % chi, 'evaluation time [min]: ', (time.time() - start) / 60)
    save('results/chi%d_accuracies.pkl' % chi, accuracies)

In [None]:
# acc[digit, col] holds the cached test accuracy of `digit` for chi_list[col].
acc = np.zeros((10, len(chi_list)))
for col, chi_value in enumerate(chi_list):
    acc[:, col] = load('results/chi%d_accuracies.pkl' % (chi_value))

# Weight each digit by its share of the test set. If each per-digit accuracy is the
# fraction of that digit's samples classified correctly, this equals overall accuracy.
digit_counts = np.array([np.sum(Y_test == d) for d in range(10)])
weights = digit_counts / len(Y_test)
acc_avg = (acc * weights[:, np.newaxis]).sum(axis=0)

In [None]:
# Per-digit test accuracy for each chi.
fig, ax = plt.subplots()
for digit in range(10):
    ax.plot(chi_list, acc[digit, :], '*', label='digit %d' % digit)
ax.set(
    xticks=chi_list,
    # linspace gives exactly 0.60, 0.65, ..., 1.00 without float-arange endpoint surprises.
    yticks=np.linspace(0.6, 1.0, 9),
    xlabel=r'$\chi$',  # raw string: '\c' is an invalid escape in a normal literal
    ylabel='accuracy per digit',
    title='Test accuracy per digit',
)
ax.grid()
# Anchor the legend's upper-left corner at the axes' top-right so it sits just outside the plot.
ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0));
# fig.savefig('results/accuracy.pdf', bbox_inches='tight')  # 'tight' keeps the outside legend in the file

In [None]:
# Overall test accuracy (per-digit accuracies weighted by digit frequency) for each chi.
fig, ax = plt.subplots()
# 'o-' adds markers, so a sweep over a single chi value is still visible (a bare line is not).
ax.plot(chi_list, acc_avg, 'o-', label='weighted average over all digits')
ax.set(
    xticks=chi_list,
    yticks=np.linspace(0.6, 1.0, 9),
    xlabel=r'$\chi$',  # raw string: '\c' is an invalid escape in a normal literal
    ylabel='accuracy',
    title='Overall test accuracy',
)
ax.grid()
ax.legend();
# fig.savefig('results/accuracy_averaged.pdf')