# Wasserstein Example (Handwritten Digits)
In this notebook we use the Wasserstein (Earth Mover's) distance to compare handwritten digits from the UCI Machine Learning Repository.
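
To build some intuition for the Earth Mover's distance, here is a minimal one-dimensional sketch. It is an illustration only, not the implementation used by annchor: the helper name emd_1d is hypothetical, and it assumes histograms on unit-spaced bins, where the optimal transport cost reduces to the area between the two cumulative distributions. The digits in this notebook are 2-D images, which need a general cost matrix instead.

```python
import numpy as np

def emd_1d(p, q):
    """Earth mover's distance between two histograms on unit-spaced bins.

    Hypothetical helper for illustration; valid for the 1-D case only.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Normalise so both histograms carry the same total mass.
    p = p / p.sum()
    q = q / q.sum()
    # In 1-D the transport cost equals the area between the two CDFs.
    return np.abs(np.cumsum(p - q)).sum()

a = [1, 0, 0, 0]
b = [0, 0, 0, 1]
c = [0, 1, 0, 0]
print(emd_1d(a, b))  # all mass moves three bins
print(emd_1d(a, c))  # all mass moves one bin
```

Moving the same mass further costs more, which is why this distance reflects how far the ink must be shifted to turn one image into another, rather than only whether pixels differ.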

In [None]:
# Import modules

import numpy as np
import time
import matplotlib.pyplot as plt

from annchor import Annchor, BruteForce, compare_neighbor_graphs
from annchor.datasets import load_digits, load_digits_large


In [None]:
# View the data set, and set some parameters

k=25 # n_neighbours parameter (i.e. the k in k-NN)

data = load_digits()
X = data['X']
y = data['y']
neighbor_graph = data['neighbor_graph']
M = data['cost_matrix']

nx = X.shape[0]
print('Data set contains %d digits' % nx)

fig,axs = plt.subplots(2,5)
axs = axs.flatten()
for i,ax in enumerate(axs):
    ax.imshow(X[y==i][0].reshape(8,8))
    ax.axis('off')

plt.tight_layout(h_pad=0.1, w_pad=0.3)
plt.show()

As shown above, each image is an 8x8 grid of pixels, and there are 1797 digits in total.


## Using Annchor

Let's see how we use annchor to find the k-NN graph for this data set.

Specifically, we will use the Wasserstein distance, which is a natural metric for comparing images. This metric requires a cost function (here given as the cost matrix M), which is supplied as a keyword argument. We will use 25 anchor points, a sample size of 5000, and aim to use only 16% of the work required by the brute-force solution.
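
As an illustration of what a cost matrix for 8x8 images can look like, the sketch below builds one from the Euclidean distance between pixel positions. This is an assumption for illustration: the helper grid_cost_matrix is hypothetical, and the matrix shipped with load_digits may be defined differently.

```python
import numpy as np

def grid_cost_matrix(height=8, width=8):
    """Pairwise Euclidean distances between pixel positions of a grid.

    Hypothetical helper; entry (i, j) is the cost of moving one unit of
    mass from flattened pixel i to flattened pixel j.
    """
    # Recover (row, col) coordinates from the flattened pixel index.
    rows, cols = np.divmod(np.arange(height * width), width)
    coords = np.stack([rows, cols], axis=1).astype(float)
    # Broadcast to all pairs and take the Euclidean norm.
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

M_sketch = grid_cost_matrix()
print(M_sketch.shape)  # one row and one column per pixel
```

A matrix like this has one row and column per pixel, is symmetric, and has zeros on the diagonal, since leaving mass where it is costs nothing.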



(Remember that the first run of annchor takes longer than usual because of the numba.jit compile-time overhead, so run this cell twice to get a representative timing.)
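
One way to separate the compile-time overhead from the steady-state timing is to time the same call more than once. The helper time_runs below is a hypothetical sketch, and the workload is a plain Python stand-in rather than an annchor call.

```python
import time

def time_runs(fn, repeats=2):
    """Call fn() `repeats` times and return the duration of each call.

    Hypothetical helper. For a numba-jitted workload the first entry
    would include compilation; later entries would not.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations

# Stand-in workload; replace the lambda with the call you want to time.
durations = time_runs(lambda: sum(i * i for i in range(100_000)))
print(['%5.3f s' % d for d in durations])
```

Wrapping the construction and fit of the index in the lambda would give one timing per run, with the last entry being the figure to quote.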

In [None]:
start_time = time.time()

# Call ANNchor
ann = Annchor(X, # Input our data set
              'wasserstein', # Use the wasserstein metric
              func_kwargs = {'cost_matrix': M}, # Supply the cost function
              n_anchors=25,
              n_neighbors=k,
              n_samples=5000,
              p_work=0.16)

ann.fit()
print('ANNchor Time: %5.3f seconds' % (time.time()-start_time))


# Test accuracy
error = compare_neighbor_graphs(neighbor_graph,
                                ann.neighbor_graph,
                                k)
print('ANNchor Accuracy: %d incorrect NN pairs (%5.3f%%)' % (error,100*error/(k*nx)))
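
To make the accuracy figure concrete, the sketch below counts how many true neighbours are missing from an approximate neighbour list and reports the count as a percentage of all k*nx pairs. The helper count_missing_neighbors is hypothetical and only compares indices; annchor's compare_neighbor_graphs may differ, for example in how it treats neighbours at tied distances.

```python
import numpy as np

def count_missing_neighbors(true_idx, approx_idx):
    """Count true neighbours that are absent from the approximate rows.

    Hypothetical helper: compares index sets row by row and ignores
    distances entirely.
    """
    missing = 0
    for t_row, a_row in zip(true_idx, approx_idx):
        missing += len(set(t_row.tolist()) - set(a_row.tolist()))
    return missing

# Toy graphs: 3 points with 3 neighbours each.
true_idx = np.array([[0, 1, 2], [1, 0, 2], [2, 1, 0]])
approx_idx = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 3]])

errors = count_missing_neighbors(true_idx, approx_idx)
print('%d incorrect NN pairs (%5.3f%%)' % (errors, 100 * errors / true_idx.size))
```

Note that the order within a row does not matter here: the second row lists the same neighbours in a different order and contributes no errors.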

## Comparison with other techniques
We now compare this with brute force (Annchor comes with a built-in brute-force option) and with the pynndescent library.

### Brute Force
The next cell uses annchor's brute-force implementation (which is parallelised by default).

In [None]:
start_time = time.time()

bruteforce = BruteForce(X,
                        'wasserstein',
                        func_kwargs = {'cost_matrix': M}
                       )
bruteforce.fit()

print('Brute Force Time: %5.3f seconds' % (time.time()-start_time))

# Only the first 10 neighbours are compared here, so normalise by 10*nx
error = compare_neighbor_graphs(neighbor_graph,
                                bruteforce.neighbor_graph,
                                10)

print('Brute Force Accuracy: %d incorrect NN pairs (%5.3f%%)' % (error,100*error/(10*nx)))
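
For reference, a brute-force k-NN search evaluates the metric on every pair of points. The sketch below is a minimal, unparallelised illustration with a hypothetical helper brute_force_knn and a Euclidean stand-in metric; it is not annchor's BruteForce implementation.

```python
import numpy as np

def brute_force_knn(X, metric, k):
    """Exact k-NN graph by evaluating the metric on every pair of points.

    Hypothetical helper. Returns (indices, distances, n_evals); each point
    appears as its own first neighbour at distance zero.
    """
    n = len(X)
    D = np.zeros((n, n))
    n_evals = 0
    for i in range(n):
        for j in range(i + 1, n):
            # The metric is symmetric, so each pair is evaluated once.
            D[i, j] = D[j, i] = metric(X[i], X[j])
            n_evals += 1
    idx = np.argsort(D, axis=1, kind='stable')[:, :k]
    dist = np.take_along_axis(D, idx, axis=1)
    return idx, dist, n_evals

rng = np.random.default_rng(0)
pts = rng.random((50, 2))
idx, dist, n_evals = brute_force_knn(pts, lambda a, b: np.linalg.norm(a - b), k=5)
print(n_evals)  # 50 * 49 / 2 = 1225 metric evaluations
```

The number of evaluations grows quadratically with the number of points, which is the cost that an approximate method tries to avoid when the metric is expensive.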

### Pynndescent

In [None]:
from pynndescent import NNDescent
from pynndescent.distances import kantorovich
from numba import njit

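# Wrap pynndescent's kantorovich distance so that it uses our cost matrix.
# Note: numba treats the global M as a compile-time constant.
@njit()
def wasserstein(x, y):
    return kantorovich(x,y,cost=M)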

start_time = time.time()

# Call nearest neighbour descent
nndescent = NNDescent(X,n_neighbors=k,metric=wasserstein,random_state=1)
print('PyNND Time: %5.3f seconds' % (time.time()-start_time))

# Test accuracy
error = compare_neighbor_graphs(neighbor_graph,
                                nndescent.neighbor_graph,
                                k)
print('PyNND Accuracy: %d incorrect NN pairs (%5.3f%%)' % (error,100*error/(k*nx)))

## A Larger Example
This example uses the load_digits_large data set, which is similar to the previous data set but with more digits (5620 to be exact). We compare Annchor to Pynndescent.
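
To get a feel for the scale, the sketch below counts the pairwise distance evaluations that brute force would need for each data set, together with the budget implied by the p_work values used in this notebook. Reading p_work as a fraction of the n(n-1)/2 pairwise evaluations is an assumption for illustration, and brute_force_pairs is a hypothetical helper.

```python
def brute_force_pairs(n):
    """Number of unordered pairs among n points."""
    return n * (n - 1) // 2

# (data set name, number of digits, p_work used in this notebook)
settings = [('load_digits', 1797, 0.16), ('load_digits_large', 5620, 0.1)]

for name, n, p_work in settings:
    total = brute_force_pairs(n)
    budget = int(p_work * total)
    print('%s: %d pairs, budget of roughly %d evaluations' % (name, total, budget))
```

The larger data set has about three times as many digits but nearly ten times as many pairs, so the saving from an approximate method matters more here.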

In [None]:
# Load the data

k=25

data = load_digits_large()
X = data['X']
y = data['y']
neighbor_graph = data['neighbor_graph']
M = data['cost_matrix']
nx = X.shape[0]


# ANNchor


start_time = time.time()

ann = Annchor(X,
              wasserstein, # Reuse the njit-compiled function from the Pynndescent cell
              n_anchors=30,
              n_neighbors=k,
              n_samples=5000,
              p_work=0.1)

ann.fit()
print('ANNchor Time: %5.3f seconds' % (time.time()-start_time))


# Test accuracy
error = compare_neighbor_graphs(neighbor_graph,
                                ann.neighbor_graph,
                                k)
print('ANNchor Accuracy: %d incorrect NN pairs (%5.3f%%)' % (error,100*error/(k*nx)))


# Pynndescent

start_time = time.time()

# Call nearest neighbour descent
nndescent = NNDescent(X,n_neighbors=k,metric=wasserstein,random_state=1)
print('PyNND Time: %5.3f seconds' % (time.time()-start_time))

# Test accuracy
error = compare_neighbor_graphs(neighbor_graph,
                                nndescent.neighbor_graph,
                                k)
print('PyNND Accuracy: %d incorrect NN pairs (%5.3f%%)' % (error,100*error/(k*nx)))


