In [None]:
# Global imports
import random
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn
import mpld3
import numpy as np
import pescador

# Notebook-wide display configuration.
# Apply seaborn's default plot styling to matplotlib.
seaborn.set()
# Print numpy floats with 4 digits of precision and without scientific notation.
np.set_printoptions(precision=4, suppress=True)
# Have matplotlib figures rendered in the notebook through mpld3.
mpld3.enable_notebook()

import optimus
import datatools
import models as M

# Keyword arguments shared by every imshow() call in this notebook.
pltargs = {
    'interpolation': 'nearest',
    'aspect': 'equal',
    'cmap': plt.cm.gray_r,
    'origin': 'lower',
}

In [None]:
# Load the data splits and take a look at a few randomly chosen training digits.
# NOTE(review): the path below is absolute and machine-specific; point it at
# your own copy of mnist.npz before running.
train, valid, test = datatools.load_mnist_npz("/Users/ejhumphrey/mnist/mnist.npz")
num_imgs = 5
fig = plt.figure(figsize=(num_imgs*2, 2))
for n, idx in enumerate(np.random.permutation(len(train[1]))[:num_imgs]):
    # Use the explicit (rows, cols, index) form: the packed three-digit integer
    # form is only valid while each of the three fields is a single digit, so
    # it would break as soon as num_imgs reaches 10.
    ax = fig.add_subplot(1, num_imgs, n + 1)
    # train[0][idx, 0] is drawn as the image; train[1][idx] is used as its caption.
    ax.imshow(train[0][idx, 0], **pltargs)
    ax.set_xlabel("{0}".format(train[1][idx]))
    # Hide the tick marks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()

In [None]:
# M.pwrank() returns a pair, unpacked here as (trainer, predictor).
# The trainer is later handed to optimus.Driver as its graph; the predictor is
# later called directly to obtain an 'embedding' output.
trainer, predictor = M.pwrank()

In [None]:
# Build one stream per class index 0..9. Each stream is given only the rows of
# train[0] whose entry in train[1] equals that index, plus the index itself.
streams = []
for cidx in range(10):
    streams.append(datatools.shuffle_stream(train[0][train[1] == cidx], cidx))

In [None]:
# Combine the per-class streams into one comparative stream and pull a single
# sample from it to inspect visually.
cstream = datatools.comparative_stream(streams)
xs = next(cstream)

# Draw every (name, array) entry of the sample side by side.
# NOTE(review): assumes the sample holds at most num_imgs entries - confirm
# against datatools.comparative_stream.
num_imgs = 3
fig = plt.figure(figsize=(num_imgs*2, 2))
for n, (name, x) in enumerate(xs.items()):
    # Use the explicit (rows, cols, index) form: the packed three-digit integer
    # form is only valid while each of the three fields is a single digit.
    ax = fig.add_subplot(1, num_imgs, n + 1)
    # squeeze() drops singleton dimensions so imshow receives a plain image array.
    ax.imshow(x.squeeze(), **pltargs)
    ax.set_xlabel("{0}".format(name))
    # Hide the tick marks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()

In [None]:
# Wrap the comparative stream with pescador.buffer_batch; the result is what
# the driver consumes as its training source.
# NOTE(review): 50 is presumably the number of samples per batch - confirm
# against the pescador documentation.
batch = pescador.buffer_batch(cstream, 50)

In [None]:
# Wrap the trainer graph in an optimus Driver and fit it on the batched stream.
driver = optimus.Driver(name='test', graph=trainer)
res = driver.fit(
    source=batch,
    max_iter=500,
    print_freq=25,
    hyperparams={'learning_rate': 0.02, 'margin': 1, 'alpha': 4})

In [None]:
# Hand the trainer's current parameter values to the predictor before using it.
predictor.param_values = trainer.param_values

# Pick up to 500 validation rows at random; the same indices select from both
# valid[0] and valid[1] so inputs and targets stay aligned.
idx = np.random.permutation(len(valid[0]))[:500]
x_in, y_true = valid[0][idx], valid[1][idx]

# Run the predictor and keep only its 'embedding' output.
z_out = predictor(x_in=x_in)['embedding']

In [None]:
# Scatter the first two embedding coordinates of the sampled validation points,
# using one palette colour per class index 0..9.
palette = seaborn.color_palette("Set2", 10)

fig = plt.figure(figsize=(8, 8))
ax = fig.gca()
for cidx in range(10):
    # Boolean mask selecting the points whose y_true equals this class index.
    i = (y_true == cidx)
    # Pass the class colour through `color=` rather than `c=`: a bare RGB
    # 3-tuple given as `c` is ambiguous and is interpreted as per-point values
    # to colour-map whenever the class happens to contain exactly 3 points.
    ax.scatter(z_out[i].T[0], z_out[i].T[1], color=palette[cidx])

plt.tight_layout()