# One-shot learning (UNFINISHED, expected to be finished before September 5th 2016).

Imagine we have a bunch of classes, say 10, and we'd like to perform a classification task. However, we only have a good amount of data for 7 of the 10 classes. For the other 3 classes, there is only a very limited number of examples (say, 1 or 2). The idea of *one-shot learning* is to train a network on the classes for which we have a lot of data and to use this trained network to classify examples from the classes it wasn't trained on. Here, we mostly follow the approach described in *Siamese Neural Networks for One-shot Image Recognition* by Koch et al.

We use a siamese architecture that we train on the MNIST data set. More specifically, we only train the network on the digits 0 to 6. The network takes two images and answers the following question: **do the two inputs belong to the same class?** Once training is complete, we try to classify the digits 7, 8 and 9 by comparing the test examples to the very limited labeled data we have for these classes.
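The class split itself is simple. As a minimal sketch, assuming the labels are available as an integer NumPy array (the names `split_seen_unseen`, `SEEN_CLASSES` and `UNSEEN_CLASSES` are hypothetical and not part of `siamese_data`):

```python
import numpy as np

SEEN_CLASSES = list(range(7))  # digits 0-6: used to train the siamese network
UNSEEN_CLASSES = [7, 8, 9]     # digits 7-9: never shown during training

def split_seen_unseen(labels, seen=SEEN_CLASSES):
    """Return (indices of examples whose label is in `seen`, indices of all the others)."""
    labels = np.asarray(labels)
    seen_mask = np.isin(labels, seen)
    return np.flatnonzero(seen_mask), np.flatnonzero(~seen_mask)

labels = np.array([0, 7, 3, 9, 6, 8, 1])
train_idx, unseen_idx = split_seen_unseen(labels)
print(train_idx)   # [0 2 4 6]
print(unseen_idx)  # [1 3 5]
```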

For more details on the *siamese architecture*, I refer the interested reader to the implementation of a siamese network in the notebook **siamese**.

In [None]:
import sys
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
sys.path.insert(0, '../data_processing/')
from siamese_data import MNIST # load the data and process it
%matplotlib inline

We load the data.

In [None]:
data = MNIST()

## Learning a similarity metric with a siamese network

We are going to implement an architecture similar to the one described in the **siamese** notebook. If you've already read it, you can skip this part: the architecture and the training are exactly the same, the only difference being that we completely disregard the digits 7, 8 and 9 during training.

In [None]:
n_inputs = 28 # dimension of each input vector (one row of a 28x28 image)
n_steps = 28 # sequence length (number of rows per image)
n_hidden = 128 # number of hidden units in each LSTM cell (forward and backward)
n_classes = 2 # two possible classes, either `same` or `different`

In [None]:
x1 = tf.placeholder(tf.float32, shape=[None, n_steps, n_inputs]) # placeholder for the first network (image 1)
x2 = tf.placeholder(tf.float32, shape=[None, n_steps, n_inputs]) # placeholder for the second network (image 2)

# placeholder for the label. `[1, 0]` for `same` and `[0, 1]` for `different`.
y = tf.placeholder(tf.float32, shape=[None, n_classes])

# placeholder for the dropout keep probability (we could use different dropout rates for different parts of the architecture)
keep_prob = tf.placeholder(tf.float32)
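Each training example is therefore a pair of images together with a one-hot `same`/`different` target. The notebook relies on `data.get_next_batch(..., one_shot=True)` to build these batches and does not show how; the following is one hypothetical way to sample balanced pairs from the digits 0 to 6 only. `make_pairs`, `SAME` and `DIFFERENT` are illustrative names, and alternating same/different pairs is an assumption, not necessarily what `siamese_data` does. It needs NumPy 1.17+ for `default_rng`.

```python
import numpy as np

SAME, DIFFERENT = [1.0, 0.0], [0.0, 1.0]  # one-hot convention used for the placeholder `y`

def make_pairs(images, labels, classes, n_pairs, rng):
    """Sample `n_pairs` image pairs whose classes all belong to `classes`.

    Even-indexed pairs are same-class and odd-indexed pairs are different-class,
    so the batch is balanced. Returns (x1, x2, y) with `y` one-hot as above.
    """
    labels = np.asarray(labels)
    by_class = {c: np.flatnonzero(labels == c) for c in classes}  # example indices per class
    x1, x2, y = [], [], []
    for k in range(n_pairs):
        c1 = int(rng.choice(classes))
        if k % 2 == 0:
            c2, target = c1, SAME
        else:
            c2, target = int(rng.choice([c for c in classes if c != c1])), DIFFERENT
        x1.append(images[rng.choice(by_class[c1])])
        x2.append(images[rng.choice(by_class[c2])])
        y.append(target)
    return np.stack(x1), np.stack(x2), np.array(y, dtype=np.float32)

# Toy data: 10 images per digit 0-9, each image filled with its own label
labels = np.repeat(np.arange(10), 10)
images = (labels[:, None, None] * np.ones((1, 28, 28))).astype(np.float32)

rng = np.random.default_rng(0)
b1, b2, by = make_pairs(images, labels, classes=list(range(7)), n_pairs=8, rng=rng)
print(b1.shape, b2.shape, by.shape)  # (8, 28, 28) (8, 28, 28) (8, 2)
```

Because each toy image is filled with its own label, `b1[:, 0, 0]` recovers the class of every sampled image, which makes it easy to check that the digits 7, 8 and 9 never appear in a pair.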

In [None]:
def reshape_input(x_):
    """
    Reshape the inputs to match the shape requirements of the function
    `tf.nn.bidirectional_rnn`
    
    Args:
        x_: a tensor of shape `(batch_size, n_steps, n_inputs)`
        
    Returns:
        A `list` of length `n_steps` with its elements being tensors of shape `(batch_size, n_inputs)`
    """
    x_ = tf.transpose(x_, [1, 0, 2]) # shape: (n_steps, batch_size, n_inputs)
    x_ = tf.split(0, n_steps, x_) # a list of `n_steps` tensors of shape (1, batch_size, n_inputs)
    return [tf.squeeze(z, [0]) for z in x_] # remove the size-1 dimension --> (batch_size, n_inputs)


def siamese_model(x1_, x2_, keep_prob):
    """
    Create the siamese network.
    
    Args:
        x1_: a tensor of shape `(batch_size, n_steps, n_inputs)` containing a batch of images
            for the first network.
        x2_: a tensor of the same shape as `x1_` containing a batch of images for the second network.
        keep_prob: a scalar tensor, the probability of keeping each LSTM output (dropout).
        
    Returns:
        A tensor of shape `(batch_size, n_classes)` containing the unscaled predictions.
    """

    # We reshape the inputs to match the shape requirements of `tf.nn.bidirectional_rnn`
    x1_, x2_ = reshape_input(x1_), reshape_input(x2_)
    
    # A bidirectional RNN consists of a forward cell and a backward cell, each with its own weights
    lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, state_is_tuple=True) # Forward cell
    lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, state_is_tuple=True) # Backward cell
    
    # We add dropout to the LSTM's cells
    lstm_fw_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_fw_cell, output_keep_prob=keep_prob)
    lstm_bw_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_bw_cell, output_keep_prob=keep_prob)
    
    with tf.variable_scope('siamese_network') as scope:
        with tf.name_scope('Bi_LSTM_1'):
            outputs1, last_state_fw1, last_state_bw1 = tf.nn.bidirectional_rnn(
                                            lstm_fw_cell, lstm_bw_cell, x1_,
                                            dtype=tf.float32)
        with tf.name_scope('Bi_LSTM_2'):
            scope.reuse_variables() # tied weights (reuse the weights from `Bi_LSTM_1` for `Bi_LSTM_2`)
            outputs2, last_state_fw2, last_state_bw2 = tf.nn.bidirectional_rnn(
                                            lstm_fw_cell, lstm_bw_cell, x2_,
                                            dtype=tf.float32)
    
    
    # Weights and biases for the layer that connects the outputs from the two networks
    weights = tf.get_variable('weigths_out', shape=[4 * n_hidden, n_classes],
                    initializer=tf.random_normal_initializer(stddev=1.0/float(n_hidden)))
    biases = tf.get_variable('biases_out', shape=[n_classes])
    
    # We concatenate the different states of the cells for the first network and
    # for the second network independently, and we compute the absolute difference
    # between the states from the first network and the second network.
    last_state1 = tf.concat(1, [last_state_fw1[0], last_state_bw1[0],
                                  last_state_fw1[1], last_state_bw1[1]])
    last_state2 = tf.concat(1, [last_state_fw2[0], last_state_bw2[0],
                                  last_state_fw2[1], last_state_bw2[1]])
    last_states_diff = tf.abs(last_state1 - last_state2)
    return tf.matmul(last_states_diff, weights) + biases
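To see what `reshape_input` does to the shapes, here is the same transpose/split/squeeze sequence mirrored in NumPy (an illustration only, not the TensorFlow code; `np.split` plays the role of the old-style `tf.split(0, n_steps, x_)` call). Time step `t` ends up holding row `t` of every image in the batch, so the LSTM reads each image one row of 28 pixels at a time.

```python
import numpy as np

batch_size, n_steps, n_inputs = 4, 28, 28
x = np.random.default_rng(0).random((batch_size, n_steps, n_inputs))

x_t = np.transpose(x, (1, 0, 2))                 # (n_steps, batch_size, n_inputs)
pieces = np.split(x_t, n_steps, axis=0)          # n_steps arrays of shape (1, batch_size, n_inputs)
steps = [np.squeeze(p, axis=0) for p in pieces]  # n_steps arrays of shape (batch_size, n_inputs)

print(len(steps), steps[0].shape)  # 28 (4, 28)
```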

In [None]:
logits = siamese_model(x1, x2, keep_prob) # Unscaled logits. They are scaled using softmax in the loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, y))
optimizer = tf.train.AdamOptimizer().minimize(loss)

correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
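The head of the model and the loss can be mirrored in NumPy to make the computation explicit: the element-wise absolute difference of the two concatenated state vectors goes through a linear layer that produces two logits, then softmax cross-entropy and argmax accuracy are computed as in the cell above. This is an illustrative sketch with random stand-ins for the LSTM states, not the TensorFlow code; `siamese_head`, `np_softmax_cross_entropy` and `np_accuracy` are hypothetical names (chosen so they don't overwrite the notebook's `logits`, `loss` and `accuracy`).

```python
import numpy as np

def siamese_head(state1, state2, weights, biases):
    """Unscaled logits from the element-wise absolute difference of two state vectors."""
    return np.abs(state1 - state2) @ weights + biases

def np_softmax_cross_entropy(logits_, labels):
    """Mean softmax cross-entropy over the batch (numerically stable log-softmax)."""
    shifted = logits_ - logits_.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1).mean()

def np_accuracy(logits_, labels):
    """Fraction of rows where the predicted class matches the one-hot label."""
    return np.mean(np.argmax(logits_, axis=1) == np.argmax(labels, axis=1))

rng = np.random.default_rng(0)
n_hidden, n_classes, batch = 128, 2, 5
s1 = rng.normal(size=(batch, 4 * n_hidden))  # stand-in for `last_state1`
s2 = rng.normal(size=(batch, 4 * n_hidden))  # stand-in for `last_state2`
W = rng.normal(scale=1.0 / n_hidden, size=(4 * n_hidden, n_classes))
b = np.zeros(n_classes)
labels = np.eye(n_classes)[[0, 1, 0, 1, 1]]  # rows: [1, 0] = same, [0, 1] = different

toy_logits = siamese_head(s1, s2, W, b)
print(np.allclose(toy_logits, siamese_head(s2, s1, W, b)))  # True: swapping the inputs changes nothing
print(np_softmax_cross_entropy(toy_logits, labels), np_accuracy(toy_logits, labels))
```

One consequence of using `tf.abs(last_state1 - last_state2)` is that the logits do not depend on the order of the two images, which is a sensible property for a similarity measure.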

In [None]:
import os

init = tf.initialize_all_variables()

max_iter = 1000 # maximum number of iterations for training
batch_train = 128 # batch size for training
batch_test = 512 # batch size for testing
display = 50 # display the training loss and accuracy every `display` steps
n_test = 200 # test the network every `n_test` steps
snapshot_n = 500 # save a snapshot of the weights every `snapshot_n` steps
checkpoint_dir = 'models/one_shot_learning/' # directory for the snapshots (restored from below)

saver = tf.train.Saver() # to save the trained model
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

with tf.Session() as sess:
    sess.run(init) # initialize all variables
    print('Network training begins.')
    for i in range(1, max_iter + 1):
        # We retrieve a batch of data from the training set (digits between 0 and 6)
        batch_x1, batch_x2, batch_y = data.get_next_batch(batch_train, phase='train', one_shot=True)
        # We feed the data to the network for training
        feed_dict = {x1: batch_x1, x2: batch_x2, y: batch_y, keep_prob: .9}
        _, loss_, accuracy_ = sess.run([optimizer, loss, accuracy], feed_dict=feed_dict)
        
        if i % display == 0:
            print('step %i, training loss: %.5f, training accuracy: %.3f' % (i, loss_, accuracy_))
        
        # Testing the network (dropout disabled with keep_prob = 1.0)
        if i % n_test == 0:
            # Retrieving data from the test set
            batch_x1, batch_x2, batch_y = data.get_next_batch(batch_test, phase='test', one_shot=True)
            feed_dict = {x1: batch_x1, x2: batch_x2, y: batch_y, keep_prob: 1.0}
            accuracy_test = sess.run(accuracy, feed_dict=feed_dict)
            print('testing step %i, accuracy %.3f' % (i, accuracy_test))
        
        # We save a snapshot of the weights in `checkpoint_dir`
        if i % snapshot_n == 0:
            save_path = saver.save(sess, os.path.join(checkpoint_dir, 'snapshot'), global_step=i)
            print('Snapshot saved in file: %s' % save_path)

    print('********************************')
    print('Training finished.')

## One-shot learning: using the pretrained similarity metric on new classes

We now want to see how the network performs on images from unseen classes, i.e. sevens, eights and nines.

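Following the approach described by Koch et al., we choose 10 images ($i_0, i_1, \dots, i_9$), one per class. We then classify an image by comparing it pairwise with each of the images $i_0, \dots, i_9$ and assigning it the class of the reference image that the network judges most similar.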
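The decision rule can be sketched as follows. `one_shot_classify` and `toy_same_prob` are hypothetical helpers; `toy_same_prob` is a distance-based stand-in for the trained network, used only so the sketch runs on its own. With the TensorFlow graph above, one plausible `same_prob` would feed the query image (repeated) as `x1`, the stacked reference images as `x2`, `keep_prob` set to 1.0, and read column 0 (the `same` class) of `tf.nn.softmax(logits)`.

```python
import numpy as np

def one_shot_classify(query, references, same_prob):
    """Return the class whose reference image is most likely to be 'same' as `query`.

    references: dict mapping class label -> one reference image (i_0, ..., i_9)
    same_prob:  callable (image_a, image_b) -> probability that both share a class
    """
    classes = sorted(references)
    scores = [same_prob(query, references[c]) for c in classes]
    return classes[int(np.argmax(scores))]

def toy_same_prob(a, b):
    # Stand-in for the trained network: closer images get a higher "same" score
    return float(np.exp(-np.sum((a - b) ** 2)))

references = {c: np.full((28, 28), c / 9.0) for c in range(10)}
query = np.full((28, 28), 8 / 9.0 + 0.01)
print(one_shot_classify(query, references, toy_same_prob))  # 8
```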

But first, let's choose the 10 reference images.

In [None]:
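# Pick one reference image at random for each digit class
one_example_per_class = {}
for digit in data.digits:
    class_images = getattr(data, digit) # all the images of class `digit`, stored as an attribute of `data`
    one_example_per_class[digit] = class_images[np.random.randint(len(class_images))]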

We retrieve the model trained above and classify images of 7, 8 and 9 by comparing them with the reference images $i_0, \dots, i_9$. We report the classification accuracy on the *unseen* classes and on the *seen* classes.
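Once predicted labels are available for a set of test images, the two accuracies could be obtained by splitting the examples on their true class. A minimal sketch, with the hypothetical helper `accuracy_by_group` and made-up labels used only to show the computation (a group with no examples would give `nan`):

```python
import numpy as np

def accuracy_by_group(true_labels, predicted_labels, unseen=(7, 8, 9)):
    """Return (accuracy on seen classes, accuracy on unseen classes)."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    correct = true_labels == predicted_labels
    unseen_mask = np.isin(true_labels, unseen)
    return correct[~unseen_mask].mean(), correct[unseen_mask].mean()

true = [0, 3, 7, 8, 9, 9]
pred = [0, 5, 7, 1, 9, 9]
seen_acc, unseen_acc = accuracy_by_group(true, pred)
print(seen_acc, unseen_acc)  # 0.5 0.75
```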

In [None]:
checkpoint_dir = 'models/one_shot_learning/'
with tf.Session() as sess:
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
    saver.restore(sess, latest_checkpoint)
    ...