In [None]:
# Reload edited project modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures with the interactive notebook backend.
%matplotlib notebook

import os
import sys
# Put the directory two levels above the current working directory first on
# the import path (presumably the repository root, so that `sltools` and
# `experiments` resolve -- confirm against the project layout).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

# Theano configuration is set here, before the cell that imports theano.
# NOTE(review): "cuda1" is machine-specific -- adjust to the local GPU setup.
os.environ['THEANO_FLAGS'] = "device=cuda1"

In [None]:
import shelve
from bisect import bisect
import numpy as np
import theano
import theano.tensor as T
import lasagne
import seqtools
import matplotlib.pyplot as plt

from sltools.nn_utils import adjust_length

from experiments.siamese_triplet.a_data import \
    durations, labels, transformations, \
    train_subset, val_subset
from experiments.siamese_triplet.b_preprocess import skel_feat_seqs

np.set_printoptions(linewidth=100)

In [None]:
from experiments.siamese_triplet.a_data import cachedir

# Persistent key/value store for the per-epoch training results, kept in the
# project's cache directory.
report = shelve.open(os.path.join(cachedir, "rnn_report"))
# Start from an empty report: whatever a previous run stored is discarded.
report.clear()

# Load dataset

In [None]:
max_time = 128

# Wrap the feature sequences so that each one goes through
# `adjust_length(s, max_time)` when it is accessed. The wrapped view gets its
# own name: the imported `skel_feat_seqs` is left untouched, so re-running this
# cell neither wraps the data twice nor depends on a name that no longer exists.
skel_feat_seqs_adjusted = seqtools.smap(
    lambda s: adjust_length(s, max_time), skel_feat_seqs)

# Feature sequences are kept in a list with one entry per input modality;
# only the skeleton features are used here.
feat_seqs_train = [
    seqtools.gather(skel_feat_seqs_adjusted, train_subset)
    ]
transformations_train = seqtools.gather(transformations, train_subset)
labels_train = labels[train_subset].astype(np.int32)
durations_train = durations[train_subset].astype(np.int32)

feat_seqs_val = [
    seqtools.gather(skel_feat_seqs_adjusted, val_subset)
    ]
transformations_val = seqtools.gather(transformations, val_subset)
labels_val = labels[val_subset].astype(np.int32)
durations_val = durations[val_subset].astype(np.int32)

# NOTE: the full-dataset names (`transformations`, `labels`, `durations`,
# `skel_feat_seqs`) are deliberately no longer deleted: deleting them made a
# second execution of this cell fail with NameError. Use the `*_train` /
# `*_val` variables from here on.

In [None]:
def sample_triplets(vocabulary, labels, n, test=None):
    """Draw `n` random (anchor, positive, negative) index triplets.

    Anchors cycle through the samples in order (0, 1, ..., len(labels) - 1,
    0, ...). For each anchor, the positive is drawn uniformly among the
    samples carrying the same label (the anchor itself can be drawn) and the
    negative uniformly among the samples carrying a different label. Both are
    redrawn until `test` accepts the triplet.

    :param vocabulary: iterable of the label values found in `labels`.
    :param labels: 1D array-like with the label of each sample.
    :param n: number of triplets to return.
    :param test: optional callable `test(anchor, positive, negative) -> bool`
        used to reject candidates; every triplet is accepted when omitted.
        If no candidate can ever pass, sampling does not terminate.
    :return: array of shape (n, 3) and dtype uint64 holding sample indices.
    """
    if test is None:
        test = lambda *_: True

    labels = np.asarray(labels)
    n_samples = len(labels)
    where_labels = {l: np.where(labels == l)[0] for l in vocabulary}
    where_not_labels = {l: np.where(labels != l)[0] for l in vocabulary}

    triplets = np.empty((n, 3), dtype=np.uint64)

    # Iterate over the requested number of triplets (not over the number of
    # samples), so that every row of the output gets filled and none is
    # written out of bounds when n != len(labels).
    for i in range(n):
        left = i % n_samples
        wl = where_labels[labels[left]]
        wn = where_not_labels[labels[left]]

        # Rejection sampling: redraw both candidates until the test passes.
        middle = np.random.choice(wl)
        right = np.random.choice(wn)
        while not test(left, middle, right):
            middle = np.random.choice(wl)
            right = np.random.choice(wn)

        triplets[i] = [left, middle, right]

    return triplets


def triplet2minibatches(feat_seqs, labels, durations, transformations):
    """Build a sequence of minibatches made of freshly sampled triplets.

    One triplet is sampled per item of `labels`, keeping only triplets whose
    anchor and positive differ in `transformations[...][0]`; the triplets are
    then shuffled and grouped `batch_size // 3` at a time (incomplete final
    groups are dropped). `batch_size` is read from the notebook namespace.

    :param feat_seqs: list with one indexable collection of feature sequences
        per input modality.
    :param labels: array with the label of each sequence.
    :param durations: array with the duration of each sequence; must support
        indexing by an array of indices.
    :param transformations: indexable collection whose items expose, at
        position 0, the value compared between anchor and positive.
    :return: the `seqtools.collate` of the per-modality feature batches
        followed by the duration batches. Within a batch, items are laid out
        as anchor, positive, negative, anchor, positive, negative, ...
    """
    def _stack_triplet_from(seq):
        # Bind `seq` once per modality. A lambda written directly inside the
        # comprehension below would look `f` up only when called, i.e. after
        # the comprehension has finished, and every modality would then read
        # from the last feature sequence.
        return lambda i, j, k: np.stack([seq[i], seq[j], seq[k]], axis=0)

    triplets = sample_triplets(
        sorted(set(labels)), labels, len(labels),
        test=lambda i, j, k: transformations[i][0] != transformations[j][0])
    triplets = np.random.permutation(triplets)

    feat_triplets = [
        seqtools.starmap(_stack_triplet_from(f), triplets)
        for f in feat_seqs]
    duration_triplets = seqtools.smap(lambda triplet: durations[triplet], triplets)

    triplets_per_batch = batch_size // 3
    feat_batches = [
        seqtools.batch(f, triplets_per_batch, drop_last=True, collate_fn=np.concatenate)
        for f in feat_triplets]
    duration_batches = seqtools.batch(
        duration_triplets,
        triplets_per_batch,
        drop_last=True, collate_fn=np.concatenate)

    return seqtools.collate(feat_batches + [duration_batches])

# Build model

In [None]:
from experiments.siamese_triplet.c_model import skel_rnn

# Hyper-parameters of the network.
batch_size = 12
max_time = 128
encoder_kwargs = {
    "tconv_sz": 15,
    "filter_dilation": 1,
    "num_tc_filters": 256,
    "dropout": 0.2,
}

# Minibatches are made of whole (anchor, positive, negative) triplets.
assert batch_size % 3 == 0, "the model must take triplets"

# One positional shape argument per modality, taken from the first item of the
# first sequence of that modality.
model_dict = skel_rnn(
    *[f[0][0].shape for f in feat_seqs_train],
    batch_size=batch_size,
    max_time=max_time,
    encoder_kwargs=encoder_kwargs)

# Layers needed later: output layer, input layers and duration input.
l_linout, l_in, l_duration = (
    model_dict[name] for name in ('l_linout', 'l_in', 'l_duration'))

# Run training iterations

In [None]:
from sltools.models.siamese import triplet_loss

l_rate = 1e-4
n_epoches = 10

l_rate_var = T.scalar('l_rate')


def _summed_triplet_loss(embeddings):
    # Rows 0, 3, 6, ... / 1, 4, 7, ... / 2, 5, 8, ... are passed as the three
    # members of each triplet; the per-triplet losses are summed.
    return triplet_loss(embeddings[0::3], embeddings[1::3], embeddings[2::3]).sum()


# Training graph: stochastic forward pass and Adam updates, with the learning
# rate supplied as an extra input at call time.
linout = lasagne.layers.get_output(l_linout, deterministic=False)
train_loss = _summed_triplet_loss(linout)
params = lasagne.layers.get_all_params(l_linout, trainable=True)
updates = lasagne.updates.adam(train_loss, params, learning_rate=l_rate_var)
update_fn = theano.function(
    inputs=[*(l.input_var for l in l_in), l_duration.input_var, l_rate_var],
    outputs=train_loss,
    updates=updates)

# Evaluation graph: deterministic forward pass, no parameter update.
linout = lasagne.layers.get_output(l_linout, deterministic=True)
test_loss = _summed_triplet_loss(linout)
loss_fn = theano.function(
    inputs=[*(l.input_var for l in l_in), l_duration.input_var],
    outputs=test_loss)

In [None]:
# Exponential moving averages of the losses, used for progress display only.
running_train_loss = 10
running_val_loss = 10

# Use the configured number of epochs instead of a second hard-coded value.
for e in range(n_epoches):
    train_losses = []
    val_losses = []

    # Minibatch iterator (triplets are re-sampled at every epoch)
    train_minibatches = triplet2minibatches(
        feat_seqs_train, labels_train, durations_train, transformations_train)
    val_minibatches = triplet2minibatches(
        feat_seqs_val, labels_val, durations_val, transformations_val)

    # Training iterations
    for i, minibatch in enumerate(seqtools.prefetch(train_minibatches, 2, max_buffered=20)):
        batch_loss = float(update_fn(*minibatch, l_rate))
        if np.isnan(batch_loss):
            raise ValueError(
                "training loss became NaN at epoch {}, minibatch {}".format(e, i))
        running_train_loss = .99 * running_train_loss + .01 * batch_loss

        # Every 30 minibatches: log the current training loss and evaluate
        # 10 randomly chosen validation minibatches.
        if i % 30 == 0:
            train_losses.append(batch_loss)
            batch_losses = [loss_fn(*val_minibatches[j])
                            for j in np.random.choice(len(val_minibatches), 10)]
            val_losses.append(batch_losses[-1])
            running_val_loss = .91 * running_val_loss + .09 * np.mean(batch_losses)
            print("\rloss: {:>2.3f} / {:>2.3f}".format(running_train_loss, running_val_loss),
                  end="", flush=True)

    # Store the logged losses and a snapshot of the parameters for this epoch.
    report[str(e)] = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'epoch_loss': running_train_loss,
        'params': lasagne.layers.get_all_param_values(l_linout)
    }

In [None]:
def _smooth(values, half_width=20):
    """Return the moving mean and standard deviation of `values`.

    The window for index i is `values[max(0, i - half_width):i + half_width]`.
    """
    windows = [values[max(0, i - half_width):i + half_width]
               for i in range(len(values))]
    return (np.array([np.mean(w) for w in windows]),
            np.array([np.std(w) for w in windows]))


# A shelf does not guarantee any iteration order. Its keys are str(epoch), so
# sort them numerically to concatenate the epochs chronologically.
loss_names = ('train_losses', 'val_losses')
loss_chunks = {name: [] for name in loss_names}
for key in sorted(report.keys(), key=int):
    record = report[key]  # loaded once per epoch; only the losses are kept
    for name in loss_names:
        loss_chunks[name].append(record[name])

plt.figure(figsize=(12, 3))

# Training curve first, validation curve second: `y` is left holding the
# smoothed validation losses, which the next cell reads.
for name in loss_names:
    batch_losses = np.concatenate(loss_chunks[name])
    x = np.arange(len(batch_losses))
    y, err = _smooth(batch_losses)
    plt.plot(x, y, label=name)
    plt.fill_between(x, y - err, y + err, alpha=.3)

plt.xlabel("logged step")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Position of the minimum of `y`, the smoothed validation curve left over from
# the plotting cell above. Positions count logged steps (one every 30
# minibatches during training), not epochs.
np.argmin(y)

# Evaluate performance

In [None]:
from sltools.models.siamese import build_predict_fn

# Epoch whose saved parameters are restored; `report` is keyed by str(epoch).
# NOTE(review): hard-coded to epoch 9 (the last one of a 10-epoch run) --
# update it if the number of epochs changes or an earlier epoch is preferred.
iteration = 9
lasagne.layers.set_all_param_values(l_linout, report[str(iteration)]['params'])
predict_fn = build_predict_fn(model_dict, batch_size)

In [None]:
# Run the restored model on the training and validation sets. Inputs follow
# the same layout as for training: feature sequences, then durations.
# NOTE(review): presumably yields one embedding row per sequence (the result is
# later indexed by sample and read through `.shape[1]`) -- confirm against
# build_predict_fn.
train_vects = predict_fn(feat_seqs_train + [durations_train])
val_vects = predict_fn(feat_seqs_val + [durations_val])

In [None]:
from bisect import bisect_left
from scipy.spatial.distance import cdist, pdist, squareform


def episode(embeddings, labels, k, n_shots, n_adversaries, n_classes):
    """Run one retrieval episode and rank the target class for each query.

    `n_classes` classes are drawn at random among the values of `labels`; the
    last one drawn is the target class. The episode pool holds `n_adversaries`
    samples drawn (with replacement) from the other `n_classes - 1` classes,
    followed by `n_shots + 1` samples drawn from the target class.

    Every target sample is used as a query: its `k` nearest pool entries vote
    for their class, the closest entry (normally the query itself, at distance
    zero) being skipped. The rank of a query is the number of classes that
    received strictly more votes than the target class, so 0 means the target
    class won and ties count in favour of the target class.

    :param embeddings: 2D array with one embedding per row.
    :param labels: 1D array with the label of each embedding.
    :param k: number of neighbours that vote.
    :param n_shots: number of target samples in the pool besides the query.
    :param n_adversaries: number of pool samples from the other classes.
    :param n_classes: number of classes in the episode, target included.
    :return: list of `n_shots + 1` integer ranks, one per target sample.
    """
    vocabulary = np.sort(np.unique(labels))
    voca_subset = np.random.permutation(vocabulary)[:n_classes]
    # Classes are referred to by their position in `voca_subset` from here on.
    idx_where = {i: np.where(labels == l)[0] for i, l in enumerate(voca_subset)}
    target = n_classes - 1

    n_samples = n_adversaries + n_shots + 1
    ep_embeddings = np.empty([n_samples, embeddings.shape[1]])
    # `np.int` was an alias of the builtin `int` and no longer exists in
    # recent NumPy releases.
    ep_labels = np.empty([n_samples], dtype=int)
    for i in range(n_samples):
        # Adversaries first (random non-target class), target samples last.
        l = np.random.randint(n_classes - 1) if i < n_adversaries else target
        m = np.random.choice(idx_where[l])
        ep_embeddings[i] = embeddings[m]
        ep_labels[i] = l

    # Queries are the target samples; column 0 of the sorted neighbours is
    # skipped because it is expected to be the query itself.
    dists = cdist(ep_embeddings[n_adversaries:], ep_embeddings)
    neighbours = np.argsort(dists, axis=1)[:, 1:k+1]
    counts = np.stack([
        np.bincount(votes, minlength=n_classes)
        for votes in ep_labels[neighbours]])

    # Counting the classes that beat the target is deterministic, whereas the
    # position of the target in an argsort depends on how ties are ordered.
    ranks = (counts > counts[:, [target]]).sum(axis=1)
    return ranks.tolist()

In [None]:
# Episode settings
k = 1
n_shots = 5
n_adversaries = 200
n_classes = 15
n_episodes = 1000

# Ranks greater than or equal to `cutoff` are merged into one bar.
cutoff = 5


def _run_episodes(vects, vect_labels, n_runs):
    """Concatenate the ranks returned by `n_runs` independent episodes."""
    ranks = []
    for _ in range(n_runs):
        ranks.extend(episode(
            vects, vect_labels,
            k, n_shots, n_adversaries, n_classes))
    return ranks


def _plot_rank_histogram(ax, ranks, title):
    """Draw the normalised rank histogram of `ranks` on `ax`."""
    # One unit-wide bin per rank below `cutoff`, then a single bin that
    # gathers every rank from `cutoff` upwards.
    bins = list(np.arange(-0.5, cutoff + .5)) + [n_classes]
    h, _ = np.histogram(ranks, bins=bins)
    ax.bar(np.arange(cutoff + 1), h / h.sum())
    ax.set_xlim((-.5, cutoff + .5))
    ax.set_ylim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("rank")
    # Stop the regular ticks before `cutoff` so that its tick is not
    # duplicated when `cutoff` is even.
    ticks = list(range(0, cutoff, 2))
    ax.set_xticks(ticks + [cutoff])
    ax.set_xticklabels(ticks + [">={}".format(cutoff)])


train_episodes = _run_episodes(train_vects, labels_train, n_episodes)
val_episodes = _run_episodes(val_vects, labels_val, n_episodes)

# Both histograms share one figure so they can be compared directly.
fig, (ax_train, ax_val) = plt.subplots(2, 1)
_plot_rank_histogram(ax_train, train_episodes, "train, {}-shot".format(n_shots))
_plot_rank_histogram(ax_val, val_episodes, "validation, {}-shot".format(n_shots))

plt.tight_layout()
plt.show()