In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

import os
import sys
# Prepend the grandparent of the notebook's working directory to the import
# path (presumably the repository root providing `sltools` and `experiments` —
# confirm if the notebook is moved).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

# Theano device selection; set here so it precedes `import theano` in the next
# cell. NOTE(review): presumably ignored if theano was already imported in
# this kernel — restart the kernel to change the device.
os.environ['THEANO_FLAGS'] = "device=cuda1"

In [None]:
import shelve
import collections
import numpy as np
import theano
import theano.tensor as T
import lasagne
import seqtools
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist

from sltools.nn_utils import adjust_length
from sltools.models.siamese import build_predict_fn

from experiments.siamese_triplet.a_data import \
    durations, labels, recordings, \
    train_subset, val_subset
from experiments.siamese_triplet.b_preprocess import skel_feat_seqs

# Print numpy arrays on lines up to 100 characters wide (default is 75).
np.set_printoptions(linewidth=100)

In [None]:
from experiments.siamese_triplet.a_data import cachedir

# Persistent on-disk store for per-epoch training records (loss histories and
# parameter values), keyed by the epoch number converted to a string.
report = shelve.open(os.path.join(cachedir, "rnn_report"))
# NOTE(review): this erases every epoch saved by previous runs. Also, re-running
# this cell opens the shelf again without closing the previous handle, which
# may fail on dbm backends that lock the file — confirm before re-running.
report.clear()

# Load dataset

In [None]:
# Training split. Features are kept in a list with one entry per model input;
# skel_rnn later receives one shape per entry.
feat_seqs_train = [seqtools.gather(skel_feat_seqs, train_subset)]
recordings_train = seqtools.gather(recordings, train_subset)
labels_train, durations_train = (
    arr[train_subset].astype(np.int32) for arr in (labels, durations))

# Validation split (recording ids are only needed for training triplets).
feat_seqs_val = [seqtools.gather(skel_feat_seqs, val_subset)]
labels_val, durations_val = (
    arr[val_subset].astype(np.int32) for arr in (labels, durations))

# Remove the unsplit names from the notebook namespace so later cells can only
# use the split data. Side effect: this cell cannot be re-run without
# re-executing the import cell first.
del recordings, labels, durations, skel_feat_seqs

In [None]:
def sample_triplets(vocabulary, labels, n, test=None):
    """Sample ``n`` (anchor, positive, negative) index triplets.

    Anchors cycle through the sample indices in order (``i % len(labels)``),
    so ``n`` may be smaller or larger than the number of samples. For each
    anchor, the positive is drawn uniformly among samples with the same label
    and the negative among samples with a different label. Both are redrawn
    until ``test(anchor, positive, negative)`` is truthy.

    :param vocabulary: iterable of every label value occurring in ``labels``.
    :param labels: 1-D numpy array with one label per sample.
    :param n: number of triplets to return.
    :param test: optional predicate ``test(anchor, positive, negative)``
        used to reject triplets; by default every triplet is accepted.
    :return: ``(n, 3)`` uint64 array of index triplets, rows shuffled.
    :raises ValueError: if ``n > 0`` but ``labels`` is empty.

    NOTE: if no (positive, negative) pair can satisfy ``test`` for some
    anchor, the rejection loop never terminates.
    """
    test = test or (lambda *_: True)
    if n > 0 and len(labels) == 0:
        raise ValueError("cannot sample triplets from an empty label set")

    where_labels = {l: np.where(labels == l)[0] for l in vocabulary}
    where_not_labels = {l: np.where(labels != l)[0] for l in vocabulary}

    triplets = np.empty((n, 3), dtype=np.uint64)

    # Fill exactly the n requested rows; anchors wrap around the dataset.
    for i in range(n):
        left = i % len(labels)
        wl = where_labels[labels[left]]
        wn = where_not_labels[labels[left]]
        middle = np.random.choice(wl)
        right = np.random.choice(wn)

        # Rejection sampling until the caller's constraint holds.
        while not test(left, middle, right):
            middle = np.random.choice(wl)
            right = np.random.choice(wn)

        triplets[i] = [left, middle, right]

    return np.random.permutation(triplets)


def triplet2minibatches(feat_seqs, durations, triplets):
    """Turn index triplets into a lazy sequence of training minibatches.

    Reads the notebook-level globals ``max_time`` and ``batch_size`` (defined
    in the "Build model" cell), which must exist when this is called.

    :param feat_seqs: list with one sequence of feature sequences per model
        input; each is indexed by sample index.
    :param durations: numpy array of per-sample durations, clipped here to
        ``max_time``.
    :param triplets: ``(n, 3)`` array of (anchor, positive, negative) sample
        indices, e.g. from ``sample_triplets``.
    :return: ``seqtools.collate`` of one batched sequence per model input
        followed by the batched durations. Each batch holds
        ``batch_size // 3`` triplets, stacked so that rows come in consecutive
        (anchor, positive, negative) groups. A trailing incomplete batch is
        dropped.
    """
    feat_seqs = [seqtools.smap(lambda s: adjust_length(s, max_time), f)
                 for f in feat_seqs]
    durations = np.fmin(durations, max_time)

    def stack_triplet(f):
        # Factory so each returned function captures its own ``f``: seqtools
        # evaluates lazily, and a lambda referencing the comprehension
        # variable directly would see only the last input's features.
        # NOTE(review): ``f`` was already length-adjusted above; the second
        # adjust_length call is presumably a no-op — confirm and drop one.
        return lambda i, j, k: np.stack([
            adjust_length(f[i], max_time),
            adjust_length(f[j], max_time),
            adjust_length(f[k], max_time)], axis=0)

    feat_triplets = [seqtools.starmap(stack_triplet(f), triplets)
                     for f in feat_seqs]
    duration_triplets = seqtools.smap(lambda triplet: durations[triplet], triplets)

    # Concatenating the (3, ...) stacked triplets keeps the anchor/positive/
    # negative interleaving that the loss slices with [0::3], [1::3], [2::3].
    feat_batches = [
        seqtools.batch(f, batch_size // 3, drop_last=True, collate_fn=np.concatenate)
        for f in feat_triplets]
    duration_batches = seqtools.batch(
        duration_triplets,
        batch_size // 3,
        drop_last=True, collate_fn=np.concatenate)

    return seqtools.collate(feat_batches + [duration_batches])

# Build model

In [None]:
from experiments.siamese_triplet.c_model import skel_rnn

# Sequence length and minibatch size; also read by triplet2minibatches.
max_time = 128
batch_size = 12
encoder_kwargs = dict(
    tconv_sz=15,
    filter_dilation=1,
    num_tc_filters=256,
    dropout=0.2,
)

# Every minibatch must consist of whole (anchor, positive, negative) triplets.
assert batch_size % 3 == 0, "the model must take triplets"

# One positional shape per entry of feat_seqs_train: the shape of seqs[0][0]
# (presumably a single frame of the first sequence — confirm in skel_rnn).
model_dict = skel_rnn(
    *[seqs[0][0].shape for seqs in feat_seqs_train],
    batch_size=batch_size,
    max_time=max_time,
    encoder_kwargs=encoder_kwargs,
)

l_linout, l_in, l_duration = (
    model_dict[name] for name in ('l_linout', 'l_in', 'l_duration'))

# Run training iterations

In [None]:
from sltools.models.siamese import triplet_loss

l_rate = 1e-4  # base learning rate
n_epoches = 10  # number of training epochs

# Symbolic learning rate so it can change between calls without recompiling.
l_rate_var = T.scalar('l_rate')
# Training graph: non-deterministic forward pass (stochastic layers such as
# dropout are active).
linout = lasagne.layers.get_output(l_linout, deterministic=False)
# Minibatch rows come in consecutive (anchor, positive, negative) groups of
# three, hence the stride-3 slices.
train_loss = triplet_loss(linout[0::3], linout[1::3], linout[2::3]).sum()
params = lasagne.layers.get_all_params(l_linout, trainable=True)
updates = lasagne.updates.adam(train_loss, params, learning_rate=l_rate_var)
# update_fn(*model_inputs, durations, l_rate) -> summed training loss; applies
# one Adam step to the trainable parameters.
update_fn = theano.function(
    [l.input_var for l in l_in] + [l_duration.input_var, l_rate_var],
    outputs=train_loss, updates=updates)

# Evaluation graph: deterministic forward pass, no parameter updates.
linout = lasagne.layers.get_output(l_linout, deterministic=True)
test_loss = triplet_loss(linout[0::3], linout[1::3], linout[2::3]).sum()
# loss_fn(*model_inputs, durations) -> summed triplet loss.
loss_fn = theano.function(
    [l.input_var for l in l_in] + [l_duration.input_var],
    outputs=test_loss)

In [None]:
running_train_loss = 2
running_val_loss = 2

for e in range(n_epoches):
    train_losses = []
    val_losses = []
    # Decay the learning rate by 0.3 per epoch. Derive it from the base
    # `l_rate` rather than mutating it, so re-running this cell (or resuming
    # at a later epoch) uses the correct value.
    epoch_l_rate = l_rate * 0.3 ** e

    # Minibatch iterators, re-sampled every epoch. Training positives must come
    # from a different recording than their anchor.
    triplets = sample_triplets(
        sorted(set(labels_train)), labels_train, len(labels_train),
        test=lambda i, j, k: recordings_train[i] != recordings_train[j])
    train_minibatches = triplet2minibatches(
        feat_seqs_train, durations_train, triplets)
    triplets = sample_triplets(
        sorted(set(labels_val)), labels_val, len(labels_val),
        test=lambda i, j, k: True)
    val_minibatches = triplet2minibatches(
        feat_seqs_val, durations_val, triplets)

    # Training iterations
    for i, minibatch in enumerate(seqtools.prefetch(train_minibatches, 2, max_buffered=20)):
        batch_loss = float(update_fn(*minibatch, epoch_l_rate))
        if np.isnan(batch_loss):
            raise ValueError(
                "NaN training loss at epoch {}, iteration {}".format(e, i))
        # Exponential moving average, only used for the progress display.
        running_train_loss = .99 * running_train_loss + .01 * batch_loss

        if i % 3 == 0:
            train_losses.append(batch_loss)

        # Every 30 iterations, score 10 randomly drawn validation minibatches.
        if i % 30 == 0:
            batch_losses = [loss_fn(*val_minibatches[j])
                            for j in np.random.choice(len(val_minibatches), 10)]
            val_losses.extend(batch_losses)
            running_val_loss = .91 * running_val_loss + .09 * np.mean(batch_losses)
            print("\rloss: {:>2.3f} / {:>2.3f}".format(running_train_loss, running_val_loss),
                  end="", flush=True)

    report[str(e)] = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'epoch_loss': running_train_loss,
        'params': lasagne.layers.get_all_param_values(l_linout)
    }
    # Flush to disk so completed epochs survive a kernel crash.
    report.sync()

In [None]:
def plot_smoothed_losses(losses, label, half_window=50):
    """Plot a centred moving average of `losses` with a +/- 1 std band.

    The window for index i is losses[max(0, i - half_window):i + half_window].
    """
    losses = np.asarray(losses, dtype=float)
    windows = [losses[max(0, i - half_window):i + half_window]
               for i in range(len(losses))]
    x = np.arange(len(losses))
    y = np.array([np.mean(w) for w in windows])
    err = np.array([np.std(w) for w in windows])
    plt.plot(x, y, label=label)
    plt.fill_between(x, y - err, y + err, alpha=.3)


# Shelf keys are epoch numbers as strings and dbm iteration order is not
# guaranteed, so sort numerically before concatenating. Load each record once
# (records also contain the parameter values).
epoch_records = [report[key] for key in sorted(report.keys(), key=int)]

plt.figure()
plot_smoothed_losses(np.concatenate([r['train_losses'] for r in epoch_records]), "train")
plot_smoothed_losses(np.concatenate([r['val_losses'] for r in epoch_records]), "validation")
plt.xlabel("logged loss index")
plt.ylabel("triplet loss")
plt.legend()
plt.show()

# Evaluate performance

In [None]:
# To evaluate the weights saved at a given epoch instead of the current ones,
# uncomment and set `iteration`:
# iteration = 9
# lasagne.layers.set_all_param_values(l_linout, report[str(iteration)]['params'])

# Presumably returns one embedding row per sample, aligned with labels_train /
# labels_val (episode() indexes embeddings and labels with the same indices) —
# confirm in sltools.models.siamese.build_predict_fn.
predict_fn = build_predict_fn(model_dict, batch_size, max_time)
embeddings_train = predict_fn(feat_seqs_train, durations_train)
embeddings_val = predict_fn(feat_seqs_val, durations_val)

In [None]:
def episode(embeddings, labels, voca_size, shots, k):
    """Run one random few-shot k-NN episode; return each query's true-class rank.

    `voca_size` classes are drawn without replacement from the labels present.
    For each class, `shots` randomly chosen samples form the support set and
    the remaining samples become queries. Each query's `k` nearest support
    samples (cdist's default Euclidean metric) vote for their class. Classes
    are ordered by vote count, then by the negated sum of the voting
    neighbours' distances (closer wins), then by class id.

    :return: integer array with one entry per query (grouped by episode
        class). 0 means the true class was ranked first.
    """
    all_classes = np.sort(np.unique(labels))
    episode_classes = np.random.choice(all_classes, size=voca_size, replace=False)

    support_idx, query_idx = [], []
    for cls in episode_classes:
        members = np.random.permutation(np.where(labels == cls)[0])
        support_idx.extend(members[:shots])
        query_idx.extend(members[shots:])

    # k-NN: column indices into support_idx, closest first.
    dists = cdist(embeddings[query_idx], embeddings[support_idx])
    nn_idx = np.argsort(dists, axis=1)[:, :k]
    nn_labels = labels[support_idx][nn_idx]
    nn_dists = np.take_along_axis(dists, nn_idx, axis=1)

    record = [('freq', 'i4'), ('dist_score', 'f4'), ('class', 'i4')]
    stats = np.empty((len(query_idx), voca_size), dtype=record)
    for col, cls in enumerate(episode_classes):
        votes = nn_labels == cls
        stats['freq'][:, col] = votes.sum(axis=1)
        stats['dist_score'][:, col] = -(nn_dists * votes).sum(axis=1)
        stats['class'][:, col] = cls

    # Structured sort is lexicographic over (freq, dist_score, class): the
    # best-scoring class ends up in the last column.
    stats = np.sort(stats, axis=1)
    position = np.argmax(labels[query_idx][:, None] == stats['class'], axis=1)
    return voca_size - 1 - position

In [None]:
def plot_rank_histogram(ranks, cutoff, voca_size, title):
    """Bar plot of the fraction of queries at each true-class rank.

    Ranks 0..cutoff-1 get one bar each; every rank >= cutoff is pooled into a
    final overflow bar at x == cutoff.
    """
    # Edges centred on the integers 0..cutoff-1, plus one overflow bin that
    # extends to voca_size (ranks never exceed voca_size - 1).
    bins = list(np.arange(-0.5, cutoff + .5)) + [voca_size]
    counts, _ = np.histogram(ranks, bins=bins)
    plt.figure()
    ax = plt.subplot(2, 1, 1)
    ax.bar(np.arange(cutoff + 1), counts / max(counts.sum(), 1))
    ax.set_xlim((-.5, cutoff + .5))
    ax.set_ylim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("rank")
    ax.set_ylabel("fraction of queries")
    # Even ranks below the cutoff, then the pooled bar. The range stops before
    # `cutoff` so an even cutoff is not ticked twice.
    ticks = list(range(0, cutoff, 2))
    ax.set_xticks(ticks + [cutoff])
    ax.set_xticklabels([str(t) for t in ticks] + ["\u2265{}".format(cutoff)])
    plt.tight_layout()
    plt.show()


neighbours = 1  # k
shots = 3
voca_size = 20
n_episodes = 50
cutoff = neighbours + 2  # ranks >= cutoff share the last bar

# Keep only the first sample of each training recording (loop-invariant).
_, unique_indices = np.unique(recordings_train, return_index=True)
embeddings_train_unique = embeddings_train[unique_indices]
labels_train_unique = labels_train[unique_indices]

train_ranks = []
for _ in range(n_episodes):
    train_ranks.extend(episode(
        embeddings_train_unique, labels_train_unique,
        voca_size, shots, neighbours))
plot_rank_histogram(train_ranks, cutoff, voca_size,
                    "{}-shot (train)".format(shots))

val_ranks = []
for _ in range(n_episodes):
    val_ranks.extend(episode(
        embeddings_val, labels_val,
        voca_size, shots, neighbours))
plot_rank_histogram(val_ranks, cutoff, voca_size,
                    "{}-shot (validation)".format(shots))