In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

os.environ['THEANO_FLAGS'] = "device=cuda0"

In [None]:
import numpy as np
import theano
import theano.tensor as T
import lasagne
from lproc import rmap, subset, chunk_load
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

from experiments.siamese_oneshot.a_data import durations, labels, transformations, \
    train_subset, val_subset
from experiments.siamese_oneshot.b_preprocess import feat_seqs
from experiments.siamese_oneshot.c_model import build_model

In [None]:
# Build model

max_time = 128    # sequences are cropped / zero-padded to this many frames (see pad_seq)
batch_size = 16

model = build_model(feat_seqs[0][0].shape, batch_size, max_time)
l_linout = model['l_linout']
l_in_left, l_in_right = model['l_in']
l_duration_left, l_duration_right = model['l_duration']
linout = lasagne.layers.get_output(l_linout)


def _contrastive_loss(dist, tgt):
    """Element-wise contrastive loss on the model output `dist`.

    Pairs with ``tgt > .1`` (same-label pairs) are penalised by ``.5 * dist**2``
    (pulled toward 0); the others by ``.5 * max(0, 1 - dist)**2`` (pushed past
    a margin of 1).
    """
    return T.switch(tgt > .1,
                    .5 * dist ** 2,
                    .5 * T.maximum(0, 1 - dist) ** 2)


# Symbolic model inputs, in the order both compiled functions expect them.
_pair_input_vars = [l_in_left.input_var, l_duration_left.input_var,
                    l_in_right.input_var, l_duration_right.input_var]

# Build training routines
targets = T.vector('targets')
l_rate_var = T.scalar('l_rate')
loss = _contrastive_loss(linout, targets).sum()
params = lasagne.layers.get_all_params(l_linout, trainable=True)
updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=l_rate_var)
update_fn = theano.function(_pair_input_vars + [targets, l_rate_var],
                            outputs=loss, updates=updates)

# Deterministic pass (e.g. dropout disabled) for evaluation; per-pair losses.
linout2 = lasagne.layers.get_output(l_linout, deterministic=True)
loss2 = _contrastive_loss(linout2, targets)
predict_fn = theano.function(_pair_input_vars + [targets],
                             outputs=[linout2, loss2])

running_loss = 0

In [None]:
# Load dataset

def pad_seq(seq, length=None):
    """Crop or zero-pad `seq` along its first axis to exactly `length` rows.

    :param seq: numpy array; the first axis (presumably time frames) is
        cropped/padded, trailing axes are preserved.
    :param length: target number of rows; defaults to the notebook-level
        ``max_time`` so existing ``rmap(pad_seq, ...)`` calls are unchanged.
    :return: new array of shape ``(length,) + seq.shape[1:]`` with ``seq.dtype``;
        rows beyond ``len(seq)`` are zeros.
    """
    if length is None:
        length = max_time
    n_pad = max(0, length - len(seq))
    return np.concatenate((seq[:length],
                           np.zeros((n_pad,) + seq.shape[1:], dtype=seq.dtype)))

# Training split: select the train items, then crop/pad each sequence to max_time.
feat_seqs_train = subset(feat_seqs, train_subset)
# Pad the *subset*: pairs index feat_seqs_train with positions in the train split,
# so padding the full `feat_seqs` here would silently pair the wrong samples.
feat_seqs_train = rmap(pad_seq, feat_seqs_train)
transformations_train = subset(transformations, train_subset)
labels_train = labels[train_subset].astype(np.int32)
durations_train = durations[train_subset]

# Validation split, prepared the same way.
feat_seqs_val = subset(feat_seqs, val_subset)
feat_seqs_val = rmap(pad_seq, feat_seqs_val)
transformations_val = subset(transformations, val_subset)
labels_val = labels[val_subset].astype(np.int32)
durations_val = durations[val_subset]

# Per-frame feature shape (same expression as the one passed to build_model).
input_shape = feat_seqs[0][0].shape

# Each split's per-item sequences must line up with its labels.
assert len(feat_seqs_train) == len(labels_train) == len(durations_train)
assert len(feat_seqs_val) == len(labels_val) == len(durations_val)

In [None]:
# Sampling of (left, right) item pairs for training and evaluation
def generate_pairs(transformations, labels, vocabulary, n, positive_ratio, rng=None):
    """Sample `n` (left, right) index pairs for the siamese model.

    Left items cycle through ``0 .. len(labels) - 1`` in order. With probability
    `positive_ratio` the right item is drawn among items with the same label
    whose ``transformations[j][0]`` differs from the left item's (positive
    pair); otherwise it is drawn among items with a different label
    (negative pair).

    :param transformations: indexable per-item records; ``transformations[i][0]``
        is compared with ``==`` and positive pairs must differ on it
        (presumably the original item a sample was derived from — TODO confirm).
    :param labels: 1-d numpy array of per-item labels.
    :param vocabulary: iterable of every label that occurs in `labels`.
    :param n: number of pairs to generate.
    :param positive_ratio: probability that each pair is positive.
    :param rng: optional ``numpy.random.Generator`` / ``RandomState``; defaults to
        the global ``numpy.random`` state (same draws as without this argument).
    :return: ``(pairs, positive)`` — an ``(n, 2)`` uint64 array of indices into
        `labels` and an ``(n,)`` bool array flagging positive pairs.
    :raises ValueError: if a positive pair is requested for a label whose items
        all share the same ``transformations[...][0]`` (no valid partner exists).
    """
    if rng is None:
        rng = np.random

    # precompute where to look for positive / negative partners
    where_labels = {l: np.where(labels == l)[0] for l in vocabulary}
    where_not_labels = {l: np.where(labels != l)[0] for l in vocabulary}

    # label -> whether its items span at least two distinct transformation keys,
    # computed lazily; without this the rejection loop below could spin forever.
    has_distinct_keys = {}

    def _can_pair_positively(label):
        if label not in has_distinct_keys:
            idx = where_labels[label]
            first = transformations[idx[0]][0]
            has_distinct_keys[label] = any(
                not (transformations[j][0] == first) for j in idx[1:])
        return has_distinct_keys[label]

    pairs = np.empty((n, 2), dtype=np.uint64)
    # builtin bool: the deprecated `np.bool` alias was removed in NumPy 1.24
    positive = np.empty((n,), dtype=bool)

    for k in range(n):
        i = k % len(labels)  # simply pick left items iteratively
        l = labels[i]
        p = rng.random() < positive_ratio
        if p:
            if not _can_pair_positively(l):
                raise ValueError(
                    "cannot build a positive pair for label {!r}: all of its items "
                    "share the same transformation key".format(l))
            # rejection-sample a same-label partner from a different transformation
            j = rng.choice(where_labels[l])
            while transformations[j][0] == transformations[i][0]:
                j = rng.choice(where_labels[l])
        else:
            j = rng.choice(where_not_labels[l])

        pairs[k] = i, j
        positive[k] = p

    return pairs, positive

In [None]:
# Run training iterations on freshly sampled pairs each epoch.
l_rate = 1e-3
positive_ratio = 0.6   # probability that a sampled training pair is same-label
n_epochs = 10
log_every = 30         # refresh the progress line every `log_every` batches

# Reusable staging buffers for chunk_load, one per input in update_fn's order:
# left features, left durations, right features, right durations, targets.
# NOTE(review): 4 * batch_size rows, presumably so chunk_load can stage several
# batches at once — confirm against lproc.chunk_load.
buffers = [np.zeros((4 * batch_size, max_time) + input_shape, dtype=np.float32),
           np.zeros((4 * batch_size,), dtype=np.int32),
           np.zeros((4 * batch_size, max_time) + input_shape, dtype=np.float32),
           np.zeros((4 * batch_size,), dtype=np.int32),
           # builtin bool: the deprecated `np.bool` alias was removed in NumPy 1.24
           np.zeros((4 * batch_size,), dtype=bool)]

for e in range(n_epochs):
    # Resample 2 * len(train_subset) pairs every epoch.
    pairs, tgts = generate_pairs(transformations_train, labels_train, np.unique(labels),
                                 len(train_subset) * 2, positive_ratio)
    x1 = rmap(lambda pair: feat_seqs_train[pair[0]], pairs)
    x2 = rmap(lambda pair: durations_train[pair[0]], pairs)
    x3 = rmap(lambda pair: feat_seqs_train[pair[1]], pairs)
    x4 = rmap(lambda pair: durations_train[pair[1]], pairs)
    for i, (xl, dl, xr, dr, tgt) in enumerate(chunk_load([x1, x2, x3, x4, tgts], buffers,
                                                          chunk_size=batch_size, pad_last=False)):
        # The model was built for `batch_size`; skip any partial chunk.
        if len(tgt) != batch_size:
            continue
        batch_loss = update_fn(xl, dl, xr, dr, tgt, l_rate)
        # Exponential moving average; `running_loss` is initialised in the
        # model-building cell, so it carries over when this cell is re-run.
        running_loss = .98 * running_loss + .02 * batch_loss
        if i % log_every == 0:
            print("\rloss: {}".format(running_loss), end="", flush=True)

    print("\repoch {:3d} loss: {}".format(e, running_loss))

In [None]:
# Preview results
#
# The loss pulls same-label pair outputs toward 0 and pushes different-label
# outputs above 1, so a pair is accepted as "same label" when its output is
# below a threshold. The threshold is picked on training pairs (equal error
# rate) and then reused on validation pairs.


def _predict_pairs(feats, durs, pairs, tgts):
    """Run `predict_fn` over every full batch of `pairs`.

    :param feats: indexable, padded feature sequences of one split.
    :param durs: per-item durations of the same split.
    :param pairs: (n, 2) array of (left, right) indices into `feats` / `durs`.
    :param tgts: (n,) bool array, True for same-label pairs.
    :return: ``(preds, losses, targets)`` float arrays covering the first
        ``n - n % batch_size`` pairs (partial chunks are skipped, as in training).
    """
    inputs = [rmap(lambda pair: feats[pair[0]], pairs),
              rmap(lambda pair: durs[pair[0]], pairs),
              rmap(lambda pair: feats[pair[1]], pairs),
              rmap(lambda pair: durs[pair[1]], pairs),
              tgts]
    n_full = len(tgts) - len(tgts) % batch_size
    preds = np.empty((n_full,))
    losses = np.empty((n_full,))
    targets_out = np.empty((n_full,))
    i = 0
    for xl, dl, xr, dr, tgt in chunk_load(inputs, buffers,
                                           chunk_size=batch_size, pad_last=False):
        if len(tgt) != batch_size:
            continue
        preds[i:i + len(tgt)], losses[i:i + len(tgt)] = predict_fn(xl, dl, xr, dr, tgt)
        targets_out[i:i + len(tgt)] = tgt
        i += len(tgt)
    return preds, losses, targets_out


def _error_rates(preds, targets, thres):
    """Return ``(fnr, fpr)`` at threshold `thres`.

    fnr: fraction of same-label pairs (targets > .5) whose output exceeds
    `thres` (wrongly rejected). fpr: fraction of different-label pairs whose
    output is below `thres` (wrongly accepted).
    """
    fnr = np.mean(preds[targets > .5] > thres)
    fpr = np.mean(preds[targets < .5] < thres)
    return fnr, fpr


def _plot_output_histograms(preds, targets, title):
    """Overlay output histograms of same-label (red) and different-label (blue) pairs."""
    bins = np.linspace(0, 2, 40)
    pos, _ = np.histogram(preds[targets == 1], bins=bins)
    neg, _ = np.histogram(preds[targets == 0], bins=bins)
    plt.figure()
    plt.bar(bins[:-1], pos, width=.025, color='red', alpha=.5, label='same label')
    plt.bar(bins[:-1], neg, width=.025, color='blue', alpha=.5, label='different label')
    plt.xlabel('model output')
    plt.ylabel('number of pairs')
    plt.title(title)
    plt.legend()


# search for eer threshold on training pairs
eer_positive_ratio = 0.6
pairs, tgts = generate_pairs(transformations_train, labels_train, np.unique(labels),
                             len(train_subset) * 2, eer_positive_ratio)
train_preds, train_losses, train_targets = _predict_pairs(
    feat_seqs_train, durations_train, pairs, tgts)

eer_thres = 0
erdiff = 1
for t in np.linspace(0, np.max(train_preds), 100):
    fnr_t, fpr_t = _error_rates(train_preds, train_targets, t)
    if abs(fnr_t - fpr_t) < erdiff:
        eer_thres = t
        erdiff = abs(fnr_t - fpr_t)

# fnr/fpr use the same definitions as the validation report below
# (these two training labels used to be swapped).
train_fnr, train_fpr = _error_rates(train_preds, train_targets, eer_thres)
print("training fnr = ", train_fnr)
print("training fpr = ", train_fpr)

# observe training results
_plot_output_histograms(train_preds, train_targets, 'training pairs')

# observe validation results
val_positive_ratio = 0.4
pairs, tgts = generate_pairs(transformations_val, labels_val, np.unique(labels),
                             len(val_subset) * 2, val_positive_ratio)
all_preds, all_losses, all_targets = _predict_pairs(
    feat_seqs_val, durations_val, pairs, tgts)

_plot_output_histograms(all_preds, all_targets, 'validation pairs')

# roc_curve expects scores that grow with the likelihood of the positive
# (same-label) class, but same-label pairs get *small* outputs: negate them.
fpr, tpr, thres = roc_curve(all_targets > .5, -all_preds)
plt.figure()
plt.plot(fpr, tpr)
plt.plot([0, 1], [0, 1], 'k--', linewidth=.5)  # chance level
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.title('validation ROC')
plt.gca().set_aspect('equal')
plt.show()

val_fnr, val_fpr = _error_rates(all_preds, all_targets, eer_thres)
print("eer thres = ", eer_thres)
print("fnr at eer = ", val_fnr)
print("fpr at eer = ", val_fpr)