In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
sys.path.insert(0, os.path.dirname(os.getcwd()))

os.environ['THEANO_FLAGS'] = "device=cuda0"

In [None]:
import pickle as pkl
import shelve
from pprint import pprint
import numpy as np
import lasagne
import theano
import theano.tensor as T
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.cm
from lproc import subset, rmap
from datasets import ch14dataset as dataset
from datasets.utils import gloss2seq, seq2gloss
from sltools.nn_utils import onehot, jaccard

In [None]:
from sltools.models.rnn import build_predict_fn

# from experiments.ch14_skel.a_data import durations, gloss_seqs, tmpdir, train_subset, val_subset, test_subset
# from experiments.ch14_skel.b_preprocess import feat_seqs
# from experiments.ch14_skel.c_models import build_lstm
# max_time = 128
# batch_size = 32
# multiple_inputs = False

from experiments.ch14_bgr.a_data import durations, gloss_seqs, tmpdir, train_subset, val_subset, test_subset
from experiments.ch14_bgr.b_preprocess import feat_seqs
from experiments.ch14_bgr.c_models import build_lstm
# BGR experiment settings (the skeleton variant above uses the same values).
max_time, batch_size, multiple_inputs = 128, 32, False

# from experiments.ch14_fusion.a_data import durations, gloss_seqs, tmpdir, train_subset, val_subset, test_subset
# from experiments.ch14_fusion.b_preprocess import feat_seqs
# from experiments.ch14_fusion.c_models import build_lstm
# max_time = 128
# batch_size = 12
# multiple_inputs = True

# Number of per-frame classes (0..20); label 0 is the filler value passed to
# gloss2seq for frames outside any gloss.
nlabels = 21
# Override to inspect a specific run:
# tmpdir = "/home/granger/.cache/ch14_skel_rnn_17_1_hinge"

In [None]:
# Training report written during fitting: a shelve keyed by epoch number (as a
# string). Opened read-only: this notebook never writes to it, and a wrong
# `tmpdir` now fails immediately instead of silently creating an empty db.
report = shelve.open(os.path.join(tmpdir, 'rnn_report'), flag='r')

In [None]:
# Restrict every per-sequence collection (features, glosses, durations) to the
# train / validation / test partitions.
feats_seqs_train, gloss_seqs_train, durations_train = (
    subset(collection, train_subset)
    for collection in (feat_seqs, gloss_seqs, durations))

feats_seqs_val, gloss_seqs_val, durations_val = (
    subset(collection, val_subset)
    for collection in (feat_seqs, gloss_seqs, durations))

feats_seqs_test, gloss_seqs_test, durations_test = (
    subset(collection, test_subset)
    for collection in (feat_seqs, gloss_seqs, durations))

# Training

In [None]:
# Gather the training-loss history from the report: one shelve entry per epoch,
# keyed by the epoch number as a string, visited in numeric order.
all_batch_losses = []
all_epoch_losses = []
# Sorting the keys themselves (key=int) avoids the str -> int -> str round trip,
# which would fail on keys that are not in canonical form (e.g. '007').
for epoch_key in sorted(report.keys(), key=int):
    epoch_record = report[epoch_key]
    all_batch_losses.extend(epoch_record['batch_losses'])
    all_epoch_losses.append(epoch_record['epoch_loss'])

In [None]:
# Mean training loss per epoch (red) with a +/- one standard-deviation band
# of the batch losses, assuming batches are spread evenly over the epochs.
epoch_losses = np.asarray(all_epoch_losses, dtype=float)
epochs = np.arange(len(epoch_losses))
# array_split always yields exactly one chunk per epoch, even when the batch
# count is not a multiple of the epoch count (a fixed stride produced an extra
# chunk and a shape mismatch in fill_between).
batch_chunks = np.array_split(np.asarray(all_batch_losses, dtype=float),
                              len(epoch_losses))
error = np.array([np.std(chunk) for chunk in batch_chunks])

plt.figure(figsize=(12, 3))
plt.plot(epochs, epoch_losses, c='red')
# Semi-transparent so the mean-loss curve underneath stays visible.
plt.fill_between(epochs, epoch_losses - error, epoch_losses + error, alpha=0.3)
plt.xlabel("epoch")
plt.ylabel("training loss");

In [None]:
# Training hyper-parameters are available via: pprint(report['0']['fit_args'])

# Select the epoch with the highest validation Jaccard index (ties go to the
# later epoch); epochs without recorded validation scores are skipped.
candidates = []
for key in report.keys():
    entry = report[str(key)]
    if 'val_scores' in entry.keys():
        candidates.append((float(entry['val_scores']['jaccard']), int(key)))
best_epoch = sorted(candidates)[-1][1]
print("best epoch: {}".format(best_epoch))

r = report[str(best_epoch)]
train_scores, val_scores = r['train_scores'], r['val_scores']
pprint(train_scores['jaccard'])
pprint(train_scores['framewise'])
pprint(val_scores['jaccard'])
pprint(train_scores['framewise'] - val_scores['framewise'])  # train/val framewise gap

In [None]:
epoch_report = report[str(best_epoch)]

# Per-stream feature shape (all axes but the first, presumably time); fusion
# models take several input streams per sequence.
first_seq = feats_seqs_train[0]
if multiple_inputs:
    input_shape = tuple(stream.shape[1:] for stream in first_seq)
else:
    input_shape = (first_seq.shape[1:],)

model = build_lstm(*input_shape, batch_size=batch_size, max_time=max_time)

# Restore the parameters checkpointed at the selected epoch.
# NOTE: pickle can execute arbitrary code -- only load checkpoints you produced.
weights_path = os.path.join(tmpdir, "rnn_it{:04d}.pkl".format(best_epoch))
with open(weights_path, 'rb') as f:
    params = pkl.load(f)
all_layers = lasagne.layers.get_all_layers(model['l_linout'])
lasagne.layers.set_all_param_values(all_layers, params)

predict_fn = build_predict_fn(model, batch_size, max_time, nlabels, model['warmup'])

# Score

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def analyse(*data, vocabulary=None):
    """Print scores and plot a row-normalised confusion matrix for each split.

    Parameters
    ----------
    *data : tuples ``(feat_seqs, gloss_seqs, durations)``
        One tuple per split; each split gets its own panel, side by side.
    vocabulary : array of int, optional
        Classes over which the per-sequence Jaccard index is computed.
        Defaults to every class except the filler label 0,
        i.e. ``np.arange(1, nlabels)``.

    For each split, frame labels are rebuilt with ``gloss2seq`` (filler 0) and
    predictions are the per-frame argmax of ``predict_fn``. The mean Jaccard
    index and the framewise accuracy are printed.
    """
    import copy

    if vocabulary is None:
        vocabulary = np.arange(1, nlabels)

    # squeeze=False keeps `axs` 2-D, so a single split works as well.
    fig, axs = plt.subplots(1, len(data), figsize=(15, 3), squeeze=False)
    # Copy so that the globally shared viridis colormap is not mutated.
    cmap = copy.copy(matplotlib.cm.viridis)
    # Masked/NaN cells (zeros under LogNorm, rows of classes without true
    # frames) are drawn with the colour of the lower limit.
    cmap.set_bad(cmap(0.001))

    for ax, (split_feats, split_glosses, split_durations) in zip(axs[0], data):
        labels = [gloss2seq(g_, d_, 0)
                  for g_, d_ in zip(split_glosses, split_durations)]
        pred = [np.argmax(p, axis=1) for p in predict_fn(split_feats)]

        score = np.mean([jaccard(onehot(l, vocabulary), onehot(p, vocabulary))
                         for l, p in zip(labels, pred)])

        pred = np.concatenate(pred)
        labels = np.concatenate(labels)

        # Fixed label set: row/column k is class k even if a class never occurs.
        confusion = confusion_matrix(labels, pred,
                                     labels=np.arange(nlabels)).astype(np.double)
        # Rows of classes absent from the ground truth become NaN (drawn as "bad").
        with np.errstate(invalid='ignore', divide='ignore'):
            confusion /= np.sum(confusion, axis=1)[:, None]

        print("Jaccard index: {:0.3f}".format(score))
        print("Framewise: {:0.3f}".format(np.mean(pred == labels)))

        im = ax.matshow(confusion, interpolation='none', cmap=cmap,
                        norm=LogNorm(vmin=0.001, vmax=1))
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        

# 21 visually distinct RGB colours scaled to [0, 1], indexed by label
# (entry 0 is the colour used for label 0).
cmap = np.array([
    [113, 204, 0],
    [209, 73, 251],
    [243, 255, 52],
    [223, 119, 255],
    [139, 255, 150],
    [255, 66, 189],
    [1, 222, 201],
    [255, 77, 30],
    [0, 149, 225],
    [137, 106, 0],
    [0, 43, 105],
    [255, 230, 180],
    [111, 0, 66],
    [0, 113, 63],
    [251, 177, 255],
    [56, 96, 0],
    [160, 218, 255],
    [74, 0, 6],
    [255, 170, 172],
    [0, 62, 95],
    [93, 43, 0],
]) / 255.0

In [None]:
train_split = (feats_seqs_train, gloss_seqs_train, durations_train)
val_split = (feats_seqs_val, gloss_seqs_val, durations_val)
analyse(train_split, val_split)

# Preview prediction

In [None]:
def preview_seq(proba):
    """Plot the per-class score curves of one sequence where they exceed 0.1.

    ``proba`` is indexed as ``proba[frame, class]`` for classes 0..20. For each
    class, the scores are thresholded at 0.1 and, for every segment returned by
    ``seq2gloss`` whose value is true, the curve is drawn over that segment
    padded by one frame on each side, in colour ``cmap[class]``. Class 0 is
    drawn dotted. The y-axis is limited to [0.1, 1.05].
    """
    n_frames = len(proba)
    for label in range(21):
        extra_style = {'ls': ':'} if label == 0 else {}
        active = proba[:, label] > 0.1
        for is_on, first, last in seq2gloss(active):
            if not is_on:
                continue
            lo = max(0, first - 1)
            hi = min(n_frames, last + 1)
            plt.plot(np.arange(lo, hi), proba[lo:hi, label],
                     c=cmap[label], **extra_style)

    plt.gca().set_ylim((0.1, 1.05))

In [None]:
s = 2  # index of the validation sequence to preview

proba = predict_fn([feats_seqs_val[s]])[0]
frame_labels = gloss2seq(gloss_seqs_val[s], durations_val[s], 0)
labels = onehot(frame_labels, np.arange(0, 21))

# Top: predicted scores; bottom: ground truth drawn as 0/1 curves.
plt.figure(figsize=(13, 4))
for row, curves in enumerate([proba, labels * 1.0], start=1):
    plt.subplot(2, 1, row)
    preview_seq(curves)

# Analyse errors

In [None]:
labels = [gloss2seq(g_, d_, 0)
          for g_, d_ in zip(gloss_seqs_val, durations_val)]
preds = [np.argmax(p, axis=1) for p in predict_fn(feats_seqs_val)]

# Jaccard index over every gesture class 1..nlabels-1 (label 0 is the filler
# passed to gloss2seq); np.arange(1, 20) silently dropped class 20.
gesture_vocab = np.arange(1, nlabels)
score = [jaccard(onehot(l, gesture_vocab), onehot(p, gesture_vocab))
         for l, p in zip(labels, preds)]

# Note: scores below 0.5 fall outside the histogram range.
plt.hist(score, np.linspace(0.5, 1, 40))
plt.xlabel("per-sequence Jaccard index")
plt.ylabel("number of sequences")

# (sequence index, score) for the first 15 validation sequences
list(enumerate(score))[:15]

In [None]:
# Average number of distinct labels predicted in a sequence that never occur in
# its ground truth (false positives outside the sequence vocabulary).
n_unexpected = [len(set(p_).difference(l_)) for p_, l_ in zip(preds, labels)]
np.mean(n_unexpected, axis=0)

In [None]:
# confusion types
# Frame-level confusion over the validation set: rows are true labels,
# columns are predicted labels (scikit-learn convention).
preds_cat = np.concatenate(preds)
labels_cat = np.concatenate(labels)

# An explicit label set guarantees that row/column k is class k (in particular
# index 0 is label 0) even when some class is absent from both arrays.
confusion = confusion_matrix(labels_cat, preds_cat, labels=np.arange(nlabels))

# Misclassified frames per true class.
cum_err = np.sum(confusion, axis=1) - np.diag(confusion)

false_pos = cum_err[0]                      # label 0 frames predicted as a gesture
false_neg = np.sum(confusion[1:, 0])        # gesture frames predicted as label 0
mis_class = np.sum(cum_err[1:]) - false_neg  # gesture frames given the wrong gesture
print("false pos: {}  false neg: {}, mis-class: {}".format(
    false_pos, false_neg, mis_class))

In [None]:
# correlate error with predicted gloss duration
#
# For every predicted gesture segment (label != 0) the aligned ground-truth
# frames are inspected; results are aggregated in 5-frame bins of segment
# duration, as a fraction of all frames falling in that bin:
#   - first series: frames whose true label equals the predicted gesture,
#   - second series (shifted by 2): frames whose true label is 0.

def _rate_by_duration(preds_, labels_, count_fn, bin_width=5):
    """Per duration bin, return sum(count_fn(true_frames, g)) / sum(segment lengths).

    Segments come from ``seq2gloss`` applied to each prediction; segments with
    label 0 are skipped. Bin k covers lengths [k*bin_width, (k+1)*bin_width).
    Empty bins yield NaN; no segments at all yields an empty array.
    """
    counts, lengths = [], []
    for p, l in zip(preds_, labels_):
        for g, start, stop in seq2gloss(p):
            if g != 0:
                counts.append(count_fn(l[start:stop], g))
                lengths.append(stop - start)
    if not lengths:
        return np.zeros((0,))
    bin_idx = np.asarray(lengths) // bin_width
    hits = np.bincount(bin_idx, weights=counts)
    frames = np.bincount(bin_idx, weights=lengths)
    rate = np.full(len(hits), np.nan)
    np.divide(hits, frames, out=rate, where=frames > 0)
    return rate


bin_width = 5
correct_rate = _rate_by_duration(preds, labels,
                                 lambda seg, g: np.sum(seg == g), bin_width)
background_rate = _rate_by_duration(preds, labels,
                                    lambda seg, g: np.sum(seg == 0), bin_width)

plt.figure(figsize=(12, 4))
ax = plt.gca()
# Each bar sits at the upper edge of its duration bin; the second series is
# offset by 2 frames so both stay visible.
ax.bar(bin_width * np.arange(1, len(correct_rate) + 1), correct_rate,
       label="true label = predicted gesture")
ax.bar(bin_width * np.arange(1, len(background_rate) + 1) + 2, background_rate,
       label="true label = 0")
ax.set_xlabel("predicted segment duration (frames)")
ax.set_ylabel("fraction of frames")
ax.legend();

In [None]:
# Score with filtered short segments
#
# Predicted segments whose duration is not strictly between thres1 and thres2
# frames are relabelled 0; both bounds are tuned on the validation set.

def _mean_jaccard(labels_, preds_, vocab):
    """Mean per-sequence Jaccard index restricted to the classes in `vocab`."""
    return np.mean([jaccard(onehot(l, vocab), onehot(p, vocab))
                    for l, p in zip(labels_, preds_)])


def _filter_segments(preds_, keep):
    """Relabel as 0 every predicted segment whose length fails keep(length)."""
    return [gloss2seq([(g, start, stop)
                       for (g, start, stop) in seq2gloss(p)
                       if keep(stop - start)],
                      len(p), 0)
            for p in preds_]


# Every gesture class 1..nlabels-1 (np.arange(1, 20) silently dropped class 20).
gesture_vocab = np.arange(1, nlabels)

# Unfiltered baseline.
ji = _mean_jaccard(labels, preds, gesture_vocab)
preds_cat = np.concatenate(preds)
labels_cat = np.concatenate(labels)

print("Jaccard index: {:0.3f}".format(ji))
print("Framewise: {:0.3f}".format(np.mean(preds_cat == labels_cat)))

# 1) Lower bound alone.
thresholds = np.arange(10, 30)
jis = np.array([_mean_jaccard(labels,
                              _filter_segments(preds, lambda d, t=t: d > t),
                              gesture_vocab)
                for t in thresholds])
thres1 = thresholds[np.argmax(jis)]

# 2) Upper bound, with the lower bound fixed at thres1.
thresholds = np.arange(100, 150, 5)
jis = np.array([_mean_jaccard(labels,
                              _filter_segments(preds, lambda d, t=t: thres1 < d < t),
                              gesture_vocab)
                for t in thresholds])
thres2 = thresholds[np.argmax(jis)]

# Score with both bounds applied.
preds2 = _filter_segments(preds, lambda d: thres1 < d < thres2)
ji = _mean_jaccard(labels, preds2, gesture_vocab)
preds_cat = np.concatenate(preds2)
labels_cat = np.concatenate(labels)
print("Optimal thresholds: {} - {}".format(thres1, thres2))
print("Jaccard index: {:0.3f}".format(ji))
print("Framewise: {:0.3f}".format(np.mean(preds_cat == labels_cat)))

# Testing

In [None]:
# Evaluate on the test split with the segment-duration filter whose bounds
# (thres1, thres2) were tuned on the validation split in the cell above.

labels = [gloss2seq(g_, d_, 0)
          for g_, d_ in zip(gloss_seqs_test, durations_test)]

# Complete model: framewise argmax, then relabel as 0 every segment whose
# duration is not strictly between thres1 and thres2.
preds = [np.argmax(p, axis=1) for p in predict_fn(feats_seqs_test)]
preds2 = [gloss2seq([(g, start, stop)
                     for (g, start, stop) in seq2gloss(p)
                     if thres1 < stop - start < thres2],
                    len(p), 0)
          for p in preds]

# Jaccard index over every gesture class 1..nlabels-1 (label 0 is the filler).
gesture_vocab = np.arange(1, nlabels)
score = np.mean([jaccard(onehot(l, gesture_vocab), onehot(p, gesture_vocab))
                 for l, p in zip(labels, preds2)])

print("testing score: {}".format(score))

# Analyse model

## Filters

In [None]:
from sklearn.manifold import TSNE
from sltools.tconv import TemporalConv

# Locate the first TemporalConv layer and display its weight matrix row by row,
# with rows ordered by a 1-D t-SNE embedding so that similar rows are adjacent.
layers = lasagne.layers.get_all_layers(model['l_linout'])
tc_l = next((layer for layer in layers if isinstance(layer, TemporalConv)), None)
if tc_l is None:
    raise ValueError("the model contains no TemporalConv layer")

W = np.asarray(tc_l.W.eval())
# random_state makes the ordering (and therefore the figures) reproducible.
# NOTE(review): recent scikit-learn versions rename `n_iter` to `max_iter`.
tsne = TSNE(n_components=1, n_iter=5000, n_iter_without_progress=100,
            verbose=True, random_state=0)
filter_order = np.argsort(tsne.fit_transform(W)[:, 0])

# Spread all rows evenly over the panels instead of a fixed 200 per panel,
# which only fit layers with exactly 800 rows.
n_panels = 4
per_panel = int(np.ceil(len(filter_order) / n_panels))
w_max = np.abs(W).max()  # symmetric colour limits centred on 0

f = plt.figure(figsize=(15, 8))

ax = None
for i in range(n_panels):
    # The first panel gets sharey=None (the default); later panels share its y-axis.
    ax = plt.subplot(1, n_panels, i + 1, sharey=ax)
    ax.pcolor(W[filter_order[i * per_panel:(i + 1) * per_panel]],
              vmin=-w_max, vmax=w_max, cmap='RdBu')

In [None]:
# Show activated filters for a category

act_l = layers[layers.index(tc_l) + 2]
X = feats_seqs_val
X = rmap(lambda x_: (x_,), X)
y = [gloss2seq(g_, len(r_), 0)
     for g_, r_ in zip(gloss_seqs_val, feats_seqs_val)]

# Chunking
step = recognizer.max_len - 2 * recognizer.warmup
durations = [len(seq[0]) for seq in X]
chunks = [(i, k, min(k + recognizer.max_len, d))
          for i, d in enumerate(durations)
          for k in range(0, d - recognizer.warmup, step)]
grads = [np.zeros((d, tc_l.output_shape[2]), dtype=theano.config.floatX)
         for d in durations]

# Functions
X_buffers = [np.zeros(shape=(recognizer.batch_size, recognizer.max_len) + shape,
                      dtype=theano.config.floatX)
             for shape in recognizer.input_shapes]
y_buffer = np.zeros(shape=(recognizer.batch_size, recognizer.max_len), dtype=np.int32)
d_buffer = np.zeros((recognizer.batch_size,), dtype=np.int32)
c_buffer = np.zeros((recognizer.batch_size, 3), dtype=np.int32)
tgt_var = T.imatrix()
  
activations, predictions = lasagne.layers.get_output(
    [act_l, recognizer.l_raw], deterministic=True)
g = theano.grad(predictions[T.arange(recognizer.batch_size)[:, None], :, tgt_var].sum(), 
                wrt=activations)
g_fn = theano.function([recognizer.l_in[0].input_var, recognizer.durations_var, tgt_var], g)

j = 0
for i, (seq, start, stop) in enumerate(chunks):
    for b, x in zip(X_buffers, X[seq]):
        b[j][:stop - start] = x[start:stop]
    y_buffer[j][:stop - start] = y[seq][start:stop]
    d_buffer[j] = stop - start
    c_buffer[j] = (seq, start, stop)

    if j + 1 == recognizer.batch_size or i == len(chunks) - 1:
        batch_predictions = g_fn(*X_buffers, d_buffer, y_buffer)[:j + 1]
        for (seq_, start_, stop_), grad in zip(c_buffer, batch_predictions):
            warmup = recognizer.warmup if start_ > 0 else 0
            grads[seq_][start_ + warmup:stop_] = \
                grad[warmup:stop_ - start_]

    j = (j + 1) % recognizer.batch_size
    
all_grads = np.concatenate(grads)
all_labels = np.concatenate(y)

In [None]:
# Per-label histograms of the gradients computed above: for each label l and
# each gradient column i, the distribution of all_grads[:, i] over the frames
# whose true label is l, normalised by that frame count.
rng = .5  # histogram half-range; alternative: max(-all_grads.min(), all_grads.max())
bin_edges = np.linspace(-rng, rng, 16)

ashes = []
for l in range(nlabels):
    where = (all_labels == l)
    # Labels absent from the frames would give 0/0 = NaN histograms, which
    # np.mean below would propagate into every entry of meanh.
    n_frames = max(int(where.sum()), 1)
    h = np.stack([np.histogram(all_grads[where, i], bins=bin_edges)[0] / n_frames
                  for i in range(all_grads.shape[1])])
    ashes.append(h)

# Average histogram over labels.
meanh = np.mean(ashes, axis=0)

In [None]:
fig = plt.figure(figsize=(15, 50))
w_max = np.abs(W).max()

# Left: weight rows in t-SNE order.
ax = plt.subplot2grid((1, 6), (0, 0), colspan=2)
ax.pcolormesh(W[filter_order, :], vmin=-w_max, vmax=w_max, cmap='bwr')
ax.set_yticks(np.arange(0, W.shape[0], 5))
ax.set_yticklabels([])
ax.grid(True)

# Right: deviation of selected labels' gradient histograms from the average
# histogram. Rows are permuted with filter_order so they line up with the
# weight rows on the left (assumes one histogram row per row of W).
bin_ticks = np.arange(0, 16, 3)  # edges 0, 3, ..., 15 of the 15 histogram bins
tick_labels = ["{:.2f}".format(v) for v in np.linspace(-rng, rng, 6)]
for p, i in enumerate([0, 1, 2, -1]):
    h = (ashes[i] - meanh)[filter_order]
    ax = plt.subplot2grid((1, 6), (0, 2 + p), colspan=1)
    ax.pcolormesh(h, vmin=-1, vmax=1, cmap='bwr')
    ax.set_title("label {}".format(range(len(ashes))[i]))
    ax.set_yticks(np.arange(0, W.shape[0], 5))
    ax.set_xticks(bin_ticks)
    ax.set_xticklabels(tick_labels)
    ax.grid(True)
    ax.set_ylim((0, W.shape[0]))
    ax.set_yticklabels([])

fig.subplots_adjust(hspace=0, wspace=0.1)