In [1]:
import numpy as np
import os
from scipy.io import loadmat
from sklearn.datasets import get_data_home
from sklearn.neighbors import LargeMarginNearestNeighbor as LMNN

In [2]:
# Remote location of the deskewed MNIST .mat file and the path where it is
# cached locally, inside the directory returned by sklearn's get_data_home().
MNIST_DESKEWED_URL = 'https://www.dropbox.com/s/mhsnormwt5i2ba6/mnist-deskewed-pca164.mat?dl=1'
MNIST_DESKEWED_PATH = os.path.join(get_data_home(), 'mnist-deskewed-pca164.mat')

# Download only when no cached copy exists; later runs reuse the file on disk.
if not os.path.exists(MNIST_DESKEWED_PATH):
    from urllib import request
    print('Downloading deskewed MNIST from {} . . .'.format(MNIST_DESKEWED_URL), end='', flush=True)
    # Fetch into a temporary sibling file and rename it only on success, so an
    # interrupted download can never leave a truncated file at the cache path
    # (which would pass the existence check above and break loadmat forever).
    partial_path = MNIST_DESKEWED_PATH + '.part'
    try:
        request.urlretrieve(MNIST_DESKEWED_URL, partial_path)
        os.replace(partial_path, MNIST_DESKEWED_PATH)
    finally:
        # Clean up the leftover temporary file if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print('done.')

mnist_mat = loadmat(MNIST_DESKEWED_PATH)

# Features are converted to float64 arrays; labels are flattened to 1-D
# integer vectors.
X_train = np.asarray(mnist_mat['X_train'], dtype=np.float64)
X_test = np.asarray(mnist_mat['X_test'], dtype=np.float64)
# The `np.int` alias was removed in NumPy 1.24; the builtin `int` is the type
# it used to alias, so the resulting dtype is unchanged.
y_train = np.asarray(mnist_mat['y_train'], dtype=int).ravel()
y_test = np.asarray(mnist_mat['y_test'], dtype=int).ravel()

print('Loaded deskewed MNIST from {}.'.format(MNIST_DESKEWED_PATH))

Loaded deskewed MNIST from /work/chiotell/scikit_learn_data/mnist-deskewed-pca164.mat.


In [3]:
# Load the memory_profiler IPython extension; it provides the %memit magic
# that is used further down to measure memory while fitting.
%load_ext memory_profiler

In [23]:
# Build the LMNN estimator. store_opt_result=True is requested so that the
# optimizer result can be inspected through `lmnn.opt_result_` after fitting.
lmnn = LMNN(
    n_neighbors=3,
    max_iter=35,
    random_state=42,
    store_opt_result=True,
    verbose=1,
    n_jobs=-1,
)

In [24]:
# Fit LMNN on the training data, wrapped in %memit to report memory usage
# for the call.
%memit lmnn.fit(X_train, y_train)

Finding principal components... done in  0.46s.
Finding the target neighbors... done in 16.96s.
Computing static part of the gradient... done.

 Iteration      Objective Value    Time(s)
------------------------------------------
         1         1.044149e+07      11.08
         1         4.217784e+06      11.49
         2         3.781575e+06      11.01
         3         2.983367e+06      11.20
         4         2.504095e+06      10.95
         5         1.990753e+06      11.00
         6         1.393434e+06      11.11
         7         1.355413e+06      11.65
         7         1.128655e+06      10.97
         8         8.773327e+05      11.79
         9         7.368591e+05      11.30
        10         6.927031e+05      11.66
        11         6.746426e+05      11.83
        11         6.018972e+05      11.42
        12         5.397134e+05      11.92
        13         4.944894e+05      12.89
        14         4.488239e+05      12.21
        15         4.275367e+05      12

In [25]:
from sklearn.neighbors import KNeighborsClassifier as KNN

In [26]:
# Project both splits through the fitted LMNN transformation first.
X_train_lmnn = lmnn.transform(X_train)
X_test_lmnn = lmnn.transform(X_test)

# Classify in the transformed space with k-NN, reusing the neighbor count
# stored on the fitted LMNN model.
knn = KNN(n_neighbors=lmnn.n_neighbors_, n_jobs=-1)
knn.fit(X_train_lmnn, y_train)

test_acc = knn.score(X_test_lmnn, y_test)
accuracy_message = 'LMNN accuracy on MNIST test set is {:5.2f}%.'.format(100 * test_acc)
print(accuracy_message)

LMNN accuracy on MNIST test set is 98.68%.


In [27]:
# Show the `nfev` field of the stored optimizer result (kept because the
# estimator was created with store_opt_result=True).
# NOTE(review): presumably a scipy OptimizeResult, where nfev is the number of
# objective function evaluations — confirm.
lmnn.opt_result_.nfev

41

In [28]:
# Show the `nit` field of the stored optimizer result (kept because the
# estimator was created with store_opt_result=True).
# NOTE(review): presumably a scipy OptimizeResult, where nit is the number of
# optimizer iterations — confirm.
lmnn.opt_result_.nit

36