# Train the chatbot

This notebook is loosely based on a combination of [this](https://github.com/saltypaul/Seq2Seq-Chatbot)
and [that](https://github.com/marekrei/sequence-labeler) repository. The images are taken from
the first one.

In [1]:
import os
import time
import json
import pickle

import numpy as np

import pandas as pd
import tqdm

We'll definitely need to plot something

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

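Uninterruptible section: a context manager that postpones Ctrl-C (`KeyboardInterrupt`) until the protected block has finished.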

In [3]:
import signal

class DelayedKeyboardInterrupt(object):
    """Context manager that postpones SIGINT (Ctrl-C) until the block exits.

    Inside the ``with`` block a SIGINT is only recorded. On exit the previous
    handler is restored, and a recorded signal is replayed to it. Python's
    default handler then raises ``KeyboardInterrupt``. If the previous handler
    was ``SIG_IGN`` the signal stays ignored. For any other non-callable handler,
    ``KeyboardInterrupt`` is raised directly.
    """

    def __enter__(self):
        self.signal_received = False
        self.old_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, self.handler)
        # Return the guard so that `with DelayedKeyboardInterrupt() as g` works.
        return self

    def handler(self, sig, frame):
        # Only remember the signal; it is replayed in __exit__.
        self.signal_received = (sig, frame)

    def __exit__(self, type, value, traceback):
        signal.signal(signal.SIGINT, self.old_handler)
        if not self.signal_received:
            return
        if callable(self.old_handler):
            self.old_handler(*self.signal_received)
        elif self.old_handler == signal.SIG_IGN:
            # The interrupt was being ignored before entering: keep ignoring it.
            return
        else:
            # SIG_DFL (or a handler not installed from Python): there is no
            # Python callable to replay, so surface the interrupt explicitly.
            raise KeyboardInterrupt

Fix the random seed

In [4]:
# NumPy RNG used for shuffling the training pairs (generate_batch) and for
# picking the sample questions shown during training (sample_qa).
random_state = np.random.RandomState(0x0BADC0DE)

Import Theano and Lasagne

In [5]:
# %env THEANO_FLAGS='device=cuda0,force_device=True,mode=FAST_RUN,floatX=float32'

import theano
theano.config.exception_verbosity = 'high'

import theano.tensor as tt

import lasagne
from lasagne.utils import floatX

Using cuDNN version 5103 on context None
Mapped name None to device cuda: GeForce GTX 980 Ti (0000:06:00.0)


Fix Lasagne's random seed.

In [6]:
# Seed Lasagne's package-level RNG. It is presumably the source of the initial
# weights, so it has to be set before any layer is created.
lasagne.random.set_rng(np.random.RandomState(0xDEADC0DE))

Load the line lookup table. It was generated from the [Cornell Movie--Dialogs Corpus](https://people.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html)

In [7]:
# Map every line id to its text wrapped in STX ("\x02", start) and ETX
# ("\x03", stop) service characters.
with open("../processed_lines.json", "r", encoding="utf-8") as fin:
    db_lines = {line_id: "".join(("\x02", text, "\x03"))
                for line_id, text in json.load(fin).items()}

Load the dialogues into question/answer (Q&A) pairs

In [8]:
# Each dialogue is a list of line ids (keys of `db_lines`). Every pair of
# consecutive lines becomes a (question_id, answer_id) pair.
with open("../processed_dialogues.json", "r", encoding="utf-8") as fin:
    db_dialogues = json.load(fin)

qa_pairs = []
for lines in db_dialogues:
    qa_pairs.extend(zip(lines[:-1], lines[1:]))
del db_dialogues

# for easier indexing (fancy indexing by arrays of row ids, see generate_batch)
qa_pairs = np.array(qa_pairs)

Build the vocabulary

In [9]:
from collections import Counter

## It is very important that these service characters be added first:
## the generator uses vocab.index("\x02") as the start symbol and stops on any
## id not greater than vocab.index("\x03").
service_chars = ["\x02", "\x03"]

# Character frequencies over the line bodies (the wrapping STX/ETX are stripped).
token_counts = Counter(c for l in db_lines.values() for c in l[1:-1])

vocab = list(service_chars)
# Skip service characters that might occur inside the raw text. Otherwise they
# would be listed twice, and `token_to_index` would disagree with `vocab.index`.
vocab += [c for c in token_counts if c not in service_chars]

token_to_index = {w: i for i, w in enumerate(vocab)}

A function to convert lines into matrices of character ids.

In [10]:
def as_matrix(lines, max_len=None, vocab_index=None):
    """Convert a line (or a list of lines) into a padded matrix of character ids.

    Parameters
    ----------
    lines : str or list of str
        A single line or a list of lines.
    max_len : int, optional
        Truncate every line to at most this many characters. ``None`` (or 0)
        means no limit. Truncation can cut off the trailing "\\x03".
    vocab_index : dict, optional
        Character -> id mapping. Defaults to the global ``token_to_index``.

    Returns
    -------
    numpy.ndarray
        ``int32`` matrix of shape ``(len(lines), length)``. Padding positions and
        characters missing from the mapping are ``-1``; the models mask out
        negative ids.
    """
    if isinstance(lines, str):
        lines = [lines]
    index = token_to_index if vocab_index is None else vocab_index

    # `default=0` keeps an empty batch from crashing `max`.
    length = max(map(len, lines), default=0)
    length = min(length, max_len or length)

    matrix = np.full((len(lines), length), -1, dtype='int32')
    for i, line in enumerate(lines):
        row_ix = [index.get(c, -1) for c in line[:length]]
        matrix[i, :len(row_ix)] = row_ix

    return matrix

A function to convert History-Reply pairs into character id matrices

In [11]:
def retrieve_sentences(pairs):
    """Look up the texts of (question_id, answer_id) pairs in ``db_lines``.

    Returns a tuple ``(questions, answers)`` of two parallel lists of lines.
    """
    questions, answers = [], []
    for question_id, answer_id in pairs:
        questions.append(db_lines[question_id])
        answers.append(db_lines[answer_id])
    return questions, answers

def get_matrices(pairs, max_len=None):
    """Encode (question_id, answer_id) pairs as (encoder, decoder) id matrices."""
    questions, answers = retrieve_sentences(pairs)
    encoder_matrix = as_matrix(questions, max_len)
    decoder_matrix = as_matrix(answers, max_len)
    return encoder_matrix, decoder_matrix

A function to sample a batch from History-Reply pairs

In [12]:
def generate_batch(batch_size=32, max_len=None):
    """Yield ``(encoder, decoder)`` id matrices for one shuffled pass over ``qa_pairs``.

    The pairs are permuted once, when iteration starts. Batch ``k`` takes every
    ``n_batches``-th entry of that permutation starting at ``k``. Every pair is
    therefore used exactly once, and no batch exceeds ``batch_size`` pairs.
    """
    total = len(qa_pairs)
    n_batches = -(-total // batch_size)  # ceiling division
    order = random_state.permutation(total)

    for offset in range(n_batches):
        yield get_matrices(qa_pairs[order[offset::n_batches]], max_len)

Define a simple seq2seq network (preferably in Lasagne)

In [13]:
from lasagne.layers import InputLayer, EmbeddingLayer
from lasagne.layers import GRULayer, DenseLayer
from lasagne.layers import NonlinearityLayer

from broadcast import BroadcastLayer, UnbroadcastLayer
from lasagne.layers import SliceLayer

from lasagne.layers.base import Layer

The architecture hyperparameters (restored from the saved model if it exists)

In [14]:
model_file = "pickles/simple_mdl_epoch-15.pkl"
if os.path.exists(model_file):
    # NOTE: unpickling can execute arbitrary code -- only load trusted checkpoints.
    with open(model_file, "rb") as fin:
        ver, *rest = pickle.load(fin)
        assert (ver == "2.0") or (ver == "GRULayer")

    hyper, vocab, weights = rest

    # The checkpoint's weights are tied to the checkpoint's vocabulary, so the
    # character -> id lookup used by `as_matrix` must be rebuilt from it.
    # Otherwise inputs would keep using the ids of the freshly built vocabulary.
    token_to_index = {w: i for i, w in enumerate(vocab)}

else:
    hyper = {
        "n_embed_char": 32,
        "n_hidden_decoder": 256,
        "n_hidden_encoder": 512,
        "n_recurrent_layers": 2,
        "b_xfeed": False,
        "b_project": True
    }

Set shortcuts

In [15]:
n_embed_char = hyper["n_embed_char"]              # fresh-start default: 32
n_hidden_encoder = hyper["n_hidden_encoder"]      # fresh-start default: 512
n_hidden_decoder = hyper["n_hidden_decoder"]      # fresh-start default: 256
n_recurrent_layers = hyper["n_recurrent_layers"]  # fresh-start default: 2
b_xfeed = hyper["b_xfeed"]                        # fresh-start default: False
b_project = hyper["b_project"]                    # fresh-start default: True

<img src="https://raw.githubusercontent.com/saltypaul/Seq2Seq-Chatbot/master/pics/Training%20Phase.jpg" />

### Embedding subgraph (pinkish)

A helper to create stacked RNN

In [16]:
from lasagne.layers.base import Layer

def gru_column(input, num_units, hidden, **kwargs):
    """Stack ``len(hidden)`` GRULayers on top of ``input``; return them bottom-up.

    ``hidden[i]`` controls layer i's initial state. A Layer is used as that
    layer's ``hid_init``, and ``learn_init`` is dropped for it. Anything else
    (e.g. None) keeps the GRULayer default. ``only_return_final`` is discarded,
    so every layer returns the full sequence. ``name`` (default "default")
    prefixes the layer names as ``<name>/gru_XX``. The remaining kwargs are
    forwarded to every GRULayer.
    """
    kwargs.pop("only_return_final", None)
    assert isinstance(hidden, (list, tuple))

    prefix = kwargs.pop("name", "default")
    layers = []
    below = input
    for depth, hid in enumerate(hidden):
        layer_kwargs = dict(kwargs)
        if isinstance(hid, Layer):
            layer_kwargs.pop("learn_init", None)
            layer_kwargs["hid_init"] = hid

        below = GRULayer(below, num_units,
                         name=os.path.join(prefix, "gru_%02d" % depth),
                         **layer_kwargs)
        layers.append(below)
    return layers

Create readouts of the last hidden state

In [17]:
def gru_hidden_readout(column, indices):
    """Slice each layer of ``column`` at time ``indices`` (axis 1).

    Each slice is named ``<layer.name>/slice``.
    """
    return [SliceLayer(layer, indices, axis=1,
                       name=os.path.join(layer.name, "slice"))
            for layer in column]

### Encoder

Tap into the common embedding layer

In [18]:
# The encoder has no embedding of its own. It receives already-embedded
# characters (`batch x time x n_embed_char`) and a `batch x time` mask, both
# supplied later through `get_output(..., inputs={...})`.
l_encoder_mask = InputLayer((None, None), name="encoder/mask")
l_encoder_embed = InputLayer((None, None, n_embed_char), name="encoder/input")

### Sentence representation

Construct layered GRU columns atop the embedding (we can also make parallel fwd / rev layers)

In [19]:
# One `None` per layer: no encoder layer gets an external initial state, so each
# learns its own (learn_init=True).
hidden = n_recurrent_layers * [None]
enc_rnn_layers = gru_column(l_encoder_embed, n_hidden_encoder, hidden,
                            mask_input=l_encoder_mask, learn_init=True,
                            backwards=False, name="encoder")

# The last time step of every encoder layer.
enc_rnn_layers_sliced = gru_hidden_readout(enc_rnn_layers, -1)

### Decoder

Tap into the common embedding layer but with decoder's own input.

In [20]:
# Same interface as the encoder inputs: embedded characters plus a mask,
# supplied at `get_output` time.
l_decoder_mask = InputLayer((None, None), name="decoder/mask")
l_decoder_embed = InputLayer((None, None, n_embed_char), name="decoder/input")

Cross-feed is not currently supported

In [21]:
# Guard: the graph below has no cross-feed path.
assert not b_xfeed

Project the hidden state of the encoder

In [22]:
# Map each encoder layer's final state to the decoder's hidden size. The
# projection is mandatory when the sizes differ and optional otherwise.
if b_project or (n_hidden_encoder != n_hidden_decoder):
    dec_hid_inputs = []
    for layer in enc_rnn_layers_sliced:
        l_project = DenseLayer(layer, n_hidden_decoder, nonlinearity=None,
                               name=os.path.join(layer.name, "proj"))
        dec_hid_inputs.append(l_project)
else:
    # Sizes already match: feed the encoder's final states directly.
    dec_hid_inputs = enc_rnn_layers_sliced

Construct layers of GRUs that receive the (projected) final state of the encoder's network as their initial state.

In [23]:
# Decoder stack: layer i starts from dec_hid_inputs[i], i.e. from the (projected)
# final state of encoder layer i.
dec_rnn_layers = gru_column(l_decoder_embed, n_hidden_decoder, dec_hid_inputs,
                            mask_input=l_decoder_mask, learn_init=True,
                            backwards=False, name="decoder")

# Last time step of every decoder layer; the generator uses these as next states.
dec_rnn_layers_sliced = gru_hidden_readout(dec_rnn_layers, -1)

Read the output of the top layer of the RNN and re-embed into the character space

In [24]:
# Affine map from the top decoder state to one logit per vocabulary entry.
# `num_leading_axes=2` keeps the (batch, time) axes.
l_decoder_reembedder = DenseLayer(dec_rnn_layers[-1], num_units=len(vocab),
                                  nonlinearity=None, num_leading_axes=2,
                                  name="decoder/project")

Construct the softmax layer

In [25]:
# Presumably BroadcastLayer folds the (batch, time) axes together so the 2-D
# softmax applies per time step, and UnbroadcastLayer restores them
# (local `broadcast` module -- TODO confirm).
l_bc = BroadcastLayer(l_decoder_reembedder, broadcasted_axes=(0, 1), name="decoder/bc")
l_softmax = NonlinearityLayer(l_bc, nonlinearity=lasagne.nonlinearities.softmax, name="decoder/softmax")
l_decoder_output = UnbroadcastLayer(l_softmax, l_bc, name="decoder/ub")

### Embedding layer 

The common embedding layer

In [26]:
# Character embedding shared by the encoder and the decoder. It is connected to
# their InputLayers only at `get_output` time, when embedded ids are fed in.
l_input_char = InputLayer((None, None), name="char/input")
l_embed_char = EmbeddingLayer(l_input_char, len(vocab), n_embed_char, name="char/embed")

### Resume training

In [27]:
# Restore the trained parameters, but only when a checkpoint was loaded above.
# On a fresh start `weights` does not exist yet.
if os.path.exists(model_file):
    lasagne.layers.set_all_param_values(l_embed_char,
                                        weights["l_embed_char"])
    lasagne.layers.set_all_param_values(l_decoder_reembedder,
                                        weights["l_decoder_reembedder"])

### Loss

Collect the encoder input

In [28]:
# Encoder character ids -> embeddings through the shared layer. Ids < 0
# (padding / unknown, see as_matrix) are masked out.
v_encoder_input = tt.imatrix(name="encoder/input")
v_encoder_embed = l_embed_char.get_output_for(v_encoder_input)

inputs = {l_encoder_embed: v_encoder_embed,
          l_encoder_mask: tt.ge(v_encoder_input, 0)}

And the decoder's inputs

In [29]:
# Decoder character ids (teacher forcing): embedded with the same shared layer
# and masked the same way.
v_decoder_input = tt.imatrix(name="decoder/input")
v_decoder_embed = lasagne.layers.get_output(l_embed_char, v_decoder_input)

inputs.update({l_decoder_embed: v_decoder_embed,
               l_decoder_mask: tt.ge(v_decoder_input, 0)})

Get the output of the decoder

In [30]:
# Per-step softmax outputs of the decoder and its mask in the training graph.
v_decoder_output, v_decoder_mask = lasagne.layers.get_output(
    [l_decoder_output, l_decoder_mask], inputs, deterministic=False)

Slice the output to match the forward character-level language model

In [31]:
# Next-character prediction: the output at step t is scored against the input
# at step t + 1. So drop the last output and the first target (the start
# symbol), then flatten (batch, time) into a single axis.
v_predicted = v_decoder_output[:, :-1].reshape(
    (-1, v_decoder_output.shape[-1]))

v_targets = v_decoder_input[:, 1:].reshape((-1,))

v_mask = v_decoder_mask[:, 1:].reshape((-1,))

Construct the cross-entropy loss

In [32]:
# Mean cross-entropy over the non-padding target characters only.
loss_ij = lasagne.objectives.categorical_crossentropy(v_predicted, v_targets)
loss = (loss_ij * v_mask).sum()
loss /= v_mask.sum()

It can be beneficial to project the character embeddings onto the unit sphere.
However, we are going to project the embeddings into the unit $l^2$ ball instead.

In [33]:
W_emb = l_embed_char.get_params()[0]

# Rescale every row of W_emb whose l2 norm exceeds 1 back to norm 1; rows that
# are already inside the unit ball are left untouched.
op_project_embedding = theano.function([], updates={
    W_emb: W_emb / tt.maximum(W_emb.norm(2, axis=-1, keepdims=True), 1.0)
})

Alternatively, we can add an $l^2$ regularization term to the loss (disabled below).

In [34]:
# Disabled: adds C_embed times the mean (unsquared) l2 norm of the embedding rows.
if False:
    C_embed = 1e-1
    loss += C_embed * W_emb.norm(2, axis=-1).mean()

Collect all trainable parameters 

In [35]:
# The shared embedding, plus everything reachable from the decoder output. The
# latter includes the encoder, through the projected hid_init inputs.
trainable = []
trainable.extend(lasagne.layers.get_all_params(l_embed_char, trainable=True))
trainable.extend(lasagne.layers.get_all_params(l_decoder_output, trainable=True))

Get the updates

In [36]:
# Shared, so the learning rate can be changed between calls without recompiling.
learning_rate = theano.shared(floatX(1e-3), name="eta")

# updates = lasagne.updates.sgd(loss, trainable, learning_rate)
updates = lasagne.updates.adam(loss, trainable, learning_rate)

Create the ops

In [37]:
# One Adam step on a batch. Arguments are (decoder_input, encoder_input); returns the loss.
op_train = theano.function([v_decoder_input, v_encoder_input], loss,
                           updates=updates, givens={},
                           mode=theano.Mode(optimizer="fast_run"))

In [38]:
# Same loss as op_train, without updating the parameters.
op_test_loss = theano.function([v_decoder_input, v_encoder_input], loss,
                               mode=theano.Mode(optimizer="fast_run"))

In [39]:
# Teacher-forced decoder softmax outputs for a (decoder_input, encoder_input) batch.
op_predict = theano.function([v_decoder_input, v_encoder_input],
                             v_decoder_output,
                             mode=theano.Mode(optimizer="fast_run"))

In [40]:
# Teacher-forced pre-softmax decoder scores, built on the training-graph `inputs`.
# Arguments are (decoder_input, encoder_input), like op_train.
v_decoder_logits = lasagne.layers.get_output(
    l_decoder_reembedder, inputs, deterministic=True)

op_predict_logits = theano.function(
    [v_decoder_input, v_encoder_input], v_decoder_logits,
    mode=theano.Mode(optimizer="fast_run"))

In [41]:
# Inspect the list of trainable parameters.
trainable

[char/embed.W,
 encoder/gru_00.W_in_to_updategate,
 encoder/gru_00.W_hid_to_updategate,
 encoder/gru_00.b_updategate,
 encoder/gru_00.W_in_to_resetgate,
 encoder/gru_00.W_hid_to_resetgate,
 encoder/gru_00.b_resetgate,
 encoder/gru_00.W_in_to_hidden_update,
 encoder/gru_00.W_hid_to_hidden_update,
 encoder/gru_00.b_hidden_update,
 encoder/gru_00.hid_init,
 encoder/gru_00/slice/proj.W,
 encoder/gru_00/slice/proj.b,
 decoder/gru_00.W_in_to_updategate,
 decoder/gru_00.W_hid_to_updategate,
 decoder/gru_00.b_updategate,
 decoder/gru_00.W_in_to_resetgate,
 decoder/gru_00.W_hid_to_resetgate,
 decoder/gru_00.b_resetgate,
 decoder/gru_00.W_in_to_hidden_update,
 decoder/gru_00.W_hid_to_hidden_update,
 decoder/gru_00.b_hidden_update,
 encoder/gru_01.W_in_to_updategate,
 encoder/gru_01.W_hid_to_updategate,
 encoder/gru_01.b_updategate,
 encoder/gru_01.W_in_to_resetgate,
 encoder/gru_01.W_hid_to_resetgate,
 encoder/gru_01.b_resetgate,
 encoder/gru_01.W_in_to_hidden_update,
 encoder/gru_01.W_hid_to_hi

### The generator

<img src="https://raw.githubusercontent.com/saltypaul/Seq2Seq-Chatbot/master/pics/Eval.jpg" />

A handy slicer (copied and modified)

In [42]:
def slice_(x, i, n):
    """Return ``x[..., i:i + n]``.

    When ``n == 1`` the trailing axis of the result is marked broadcastable.
    """
    s = x[..., slice(i, i + n)]
    # Theano's Rebroadcast looks axes up by their non-negative index (a -1 key
    # is never matched), so pass the explicit last axis.
    return s if n > 1 else tt.addbroadcast(s, s.ndim - 1)

Define one step of the scan function

In [43]:
# Generator's one step update function
def generator_step_sm(x_tm1, h_tm1, m_tm1, tau, eps):
    """One step of the generative decoder with soft (Gumbel-softmax) feedback.

    Parameters
    ----------
    x_tm1 : `batch x n_embed_char` decoder input for this step. At the first
        step it is the embedding of "\\x02"; afterwards it is the
        softmax-weighted mix of character embeddings from the previous step.
    h_tm1 : hidden states of `dec_rnn_layers`, concatenated in layer order.
    m_tm1 : `batch` boolean mask, False once a sequence has stopped.
    tau, eps : scalars. tau > 0 adds Gumbel noise and uses temperature tau;
        otherwise the logits are used as-is. eps guards the logs of the noise.

    Returns
    -------
    x_t, h_t, m_t, p_t, c_t : next decoder input, packed hidden state,
        updated mask, un-tempered model probability of the picked symbol, and
        the picked symbol id (replaced by the "\\x03" id wherever the mask is off).
    """
    # Collect the inputs: the decoder sees a length-1 sequence at every step.
    inputs = {l_decoder_embed: x_tm1.dimshuffle(0, "x", 1),
              l_decoder_mask: m_tm1.dimshuffle(0, "x")}

    # Feed each layer's slice of the previous packed state into its hid_init.
    j = 0
    for layer in dec_rnn_layers:
        inputs[layer.hid_init] = slice_(h_tm1, j, layer.num_units)
        j += layer.num_units

    # Logits plus the last hidden state of every decoder layer.
    outputs = [l_decoder_reembedder] + dec_rnn_layers_sliced
    logit_t, *h_t_list = lasagne.layers.get_output(outputs, inputs,
                                                   deterministic=True)

    # Pack the hidden states
    h_t = tt.concatenate(h_t_list, axis=-1)

    # logit_t is `B x 1 x V`: drop the time axis.
    logit_t = logit_t[:, 0]
    prob_t = tt.nnet.softmax(logit_t)

    # Gumbel-softmax sampling: Gumbel (e^{-e^{-x}}) distributed random noise,
    # applied only when tau > 0.
    gumbel = -tt.log(-tt.log(theano_random_state.uniform(size=logit_t.shape) + eps) + eps)
    logit_t = tt.switch(tt.gt(tau, 0), gumbel + logit_t, logit_t)
    inv_temp = tt.switch(tt.gt(tau, 0), 1.0 / tau, tt.constant(1.0))

    # Tempered softmax: x_t is `B x V`.
    x_t = tt.nnet.softmax(logit_t * inv_temp)

    # Best symbol. Use int32: int8 overflows once the vocabulary exceeds 127 symbols.
    c_t = tt.cast(tt.argmax(x_t, axis=-1), "int32")

    # Un-tempered model probability of the picked symbol.
    p_t = prob_t[tt.arange(c_t.shape[0]), c_t]

    # Stop on the service symbols (ids up to that of "\x03"). The recurrent
    # layers keep their previous state wherever the mask is False.
    m_t = m_tm1 & tt.gt(c_t, vocab.index("\x03"))
    c_t = tt.switch(m_t, c_t, vocab.index("\x03"))

    # Next input: the softmax-weighted mix of the character embeddings.
    x_t = tt.dot(x_t, l_embed_char.W)

    return x_t, h_t, m_t, p_t, c_t

Create scalar inputs to the scan loop. Also initialize the random stream.

In [44]:
# Symbolic RNG for the Gumbel noise (generate() may re-seed it), plus the
# scalar inputs of the generator's scan loop.
theano_random_state = tt.shared_randomstreams.RandomStreams(seed=42)

eps = tt.fscalar("generator/epsilon")
n_steps = tt.iscalar("generator/n_steps")
tau = tt.fscalar("generator/gumbel/tau")

Let's compile an autofeeding generator with softmax.

In [45]:
# Question ids for the generator (negative ids are padding, see as_matrix),
# embedded with the shared embedding layer.
v_gen_input = tt.imatrix(name="generator/Q_input")

v_gen_embed = lasagne.layers.get_output(l_embed_char, v_gen_input)

A helper to freeze a GRULayer's hidden-state initialization,
if that initialization is a (learnt) parameter.

In [46]:
def GRULayer_freeze(layer, input):
    """Replace a GRULayer's parameter ``hid_init`` with a batch-broadcast InputLayer.

    Layers whose ``hid_init`` is already a Layer are returned untouched.
    Otherwise the shared ``hid_init`` is tiled over ``input.shape[0]`` rows
    (presumably a ``(1, num_units)`` parameter -- TODO confirm) and wrapped in
    an InputLayer appended to the layer's inputs. The initial state can then
    be fed via ``get_output(..., inputs={layer.hid_init: ...})``. The previous
    configuration is cached for ``GRULayer_unfreeze``. Mutates and returns
    ``layer``.
    """
    assert isinstance(layer, GRULayer)
    if isinstance(layer.hid_init, Layer):
        return layer

    # A non-Layer hid_init must not be wired in as an input (index <= 0).
    assert not (layer.hid_init_incoming_index > 0)
    assert isinstance(layer.hid_init, theano.compile.SharedVariable)

    # Broadcast the fixed / learnt hidden init over the batch dimension
    hid_init = tt.dot(tt.ones((input.shape[0], 1)), layer.hid_init)

    # Cache former values. Copy the lists: they are extended in place below,
    # and caching the live list objects would make the "old" values contain
    # the fake input too.
    layer._old_hid_init = layer.hid_init
    layer._old_input_layers = list(layer.input_layers)
    layer._old_input_shapes = list(layer.input_shapes)
    layer._old_hid_init_incoming_index = layer.hid_init_incoming_index

    # Create a fake Input Layer, which receives it as input
    layer.hid_init = InputLayer((None, None), input_var=hid_init,
                                name=os.path.join(layer.name,
                                                  "hid_init_fix"))

    # Emulate hidden layer input (as done in GRULayer/MergeLayer.__init__())
    layer.input_layers.append(layer.hid_init)
    layer.input_shapes.append(layer.hid_init.output_shape)
    layer.hid_init_incoming_index = len(layer.input_layers) - 1

    layer._layer_frozen = True
    return layer

Freeze the hidden inputs of the decoder layers, which do not tap into the encoder.

In [47]:
# No-op for layers whose hid_init already comes from the encoder.
for layer in dec_rnn_layers:
    GRULayer_freeze(layer, v_gen_input)

Readout the last state from the encoder.

In [48]:
# Read out the decoder's initial states for a generator batch. Use separate
# names so the training-graph `inputs` dict built earlier is not clobbered.
gen_inputs = {l_encoder_embed: v_gen_embed,
              l_encoder_mask: tt.ge(v_gen_input, 0)}
gen_outputs = [l.hid_init for l in dec_rnn_layers]

dec_hid_inits = lasagne.layers.get_output(gen_outputs, gen_inputs,
                                          deterministic=True)

Prepare the initial values.

In [49]:
# Initial scan state: the decoder layers' initial states packed side by side,
# the embedded start symbol for every row, and an all-True mask.
h_0 = tt.concatenate(dec_hid_inits, axis=-1)

x_0 = tt.fill(tt.zeros((v_gen_input.shape[0],), dtype="int32"),
              vocab.index("\x02"))
x_0 = lasagne.layers.get_output(l_embed_char, x_0)

m_0 = tt.ones((v_gen_input.shape[0],), 'bool')

Add a scan op and compile

In [50]:
# Unroll the soft generator for `n_steps` steps. The scan's random-stream
# updates get their own name, so they do not overwrite the training `updates`
# passed to op_train.
result, gen_updates = theano.scan(generator_step_sm, sequences=None, n_steps=n_steps,
                                  outputs_info=[x_0, h_0, m_0, None, None],
                                  strict=False, return_list=True,
                                  non_sequences=[tau, eps], go_backwards=False,
                                  name="generator/scan")
# Scan stacks along time first; make everything batch-major.
x_t, h_t, m_t, p_t, c_t = [r.swapaxes(0, 1) for r in result]

compile_mode = theano.Mode(optimizer="fast_run", linker="cvm")
# Returns (symbol ids, hidden states, masks, picked-symbol probabilities).
op_generate = theano.function([v_gen_input, n_steps, tau],
                              [c_t, h_t, m_t, p_t],
                              updates=gen_updates, givens={eps: floatX(1e-20)},
                              mode=compile_mode)

This function undoes the freezing done by the previous one.

In [51]:
def GRULayer_unfreeze(layer):
    """Undo ``GRULayer_freeze``: restore the layer's own ``hid_init`` and inputs.

    Layers that were never frozen are returned unchanged. Otherwise the layer
    is mutated in place, the bookkeeping attributes are removed, and the layer
    is returned.
    """
    assert isinstance(layer, GRULayer)
    freeze_attr = ["_layer_frozen",
                   "_old_input_layers", "_old_input_shapes",
                   "_old_hid_init_incoming_index", "_old_hid_init"]
    if not all(hasattr(layer, a) for a in freeze_attr):
        return layer

    assert layer._layer_frozen
    assert isinstance(layer.hid_init, Layer)
    # basename() understands the separator os.path.join used when freezing.
    assert os.path.basename(layer.hid_init.name) == "hid_init_fix"

    fake_hid_init = layer.hid_init
    input_layers = list(layer._old_input_layers)
    input_shapes = list(layer._old_input_shapes)
    # If the cached lists alias the live ones (they were not copied when
    # freezing), they still contain the fake input appended by the freeze.
    for k in reversed(range(len(input_layers))):
        if input_layers[k] is fake_hid_init:
            del input_layers[k]
            del input_shapes[k]

    # Thaw the frozen hidden input
    layer.hid_init = layer._old_hid_init
    layer.input_layers = input_layers
    layer.input_shapes = input_shapes
    layer.hid_init_incoming_index = layer._old_hid_init_incoming_index

    for attr in freeze_attr:
        delattr(layer, attr)

    return layer

Unfreeze the decoder's layers, so that those which do not tap into the encoder
may continue to use / learn their own `hid_init` state.

In [52]:
# No-op for layers that were never frozen.
for layer in dec_rnn_layers:
    GRULayer_unfreeze(layer)

A generator procedure, which samples several replies per question and ranks them best first (lowest perplexity, i.e. lowest average negative log2-probability per character).

In [53]:
def generate(questions, n_steps, n_samples=10, tau=0, seed=None):
    """Sample ``n_samples`` replies per question and rank them.

    ``questions`` is an id matrix (one row per question, e.g. from as_matrix).
    For each question this returns a list of ``(reply, score)`` sorted by
    ascending score. The score is the mean ``-log2`` model probability over the
    reply's unmasked characters. Scores that are NaN (e.g. empty replies)
    sort last. When ``seed`` is given, the noise RNG is re-seeded before every
    question.
    """
    results = []
    for question in questions:
        # Replicate the query once per sample.
        batch = np.repeat(question[np.newaxis], n_samples, axis=0)
        if seed is not None:
            theano_random_state.seed(seed)

        chars, _, mask, probs = op_generate(batch, n_steps, tau)

        lengths = mask.sum(axis=-1)
        bits = (-np.log2(probs) * mask).sum(axis=-1)
        bits /= lengths

        ranked = []
        for k in bits.argsort():
            text = "".join(vocab[c] for c in chars[k, :lengths[k]])
            ranked.append((text, bits[k]))
        results.append(ranked)
    return results

<br/>

### Train the Bot

In [54]:
def sample_qa():
    """Print 3 random (question, true answer, best generated reply) rows via tqdm."""
    picked = qa_pairs[random_state.choice(len(qa_pairs), 3)]
    questions, answers = retrieve_sentences(picked)

    candidates = generate(as_matrix(questions), 140, tau=2**-5, n_samples=20)
    for question, answer, ranked in zip(questions, answers, candidates):
        best_reply = ranked[0][0]
        # [1:-1] strips the "\x02" / "\x03" service characters.
        row = (question[1:-1], answer[1:-1], best_reply)
        tqdm.tqdm.write("|%-40.40s | %-30.30s | %-30.30s|" % row)

Set the batch size and the number of epochs.

In [None]:
# Training schedule, plus a timestamped directory for this run's checkpoints.
batch_size, n_epochs = 160, 15
epoch, loss_val_hist = 0, []

model_path = os.path.join("pickles", time.strftime("%Y%m%d-%H%M%S"))
if not os.path.exists(model_path):
    os.makedirs(model_path)

# e.g. simple_mdl_epoch-03.pkl or simple_mdl_epoch-03_interrupted.pkl
filename_fmt_ = os.path.join(model_path, "simple_mdl_epoch-%02d%s.pkl")

Now let's train the model!

In [None]:
progress_fmt_, interrupted = "%(loss).3f", False
n_batches = (len(qa_pairs) + batch_size - 1) // batch_size
while epoch < n_epochs:
    try:
        with tqdm.tqdm(total=n_batches) as progress_:
            for i, (be, bd) in enumerate(generate_batch(batch_size, max_len=512)):
                # Show a few generated replies every 100 batches.
                if (i % 100) == 0:
                    sample_qa()

                # Never interrupt a parameter update half-way.
                with DelayedKeyboardInterrupt():
                    loss_val_hist.append(op_train(bd, be))
                    # op_project_embedding()

                progress_.postfix = progress_fmt_ % {
                    "loss": np.mean(loss_val_hist[-100:]),
                }
                progress_.update(1)

        epoch += 1
    except KeyboardInterrupt:
        interrupted = True

    finally:
        # Checkpoint after every epoch, on interrupt, and before any other
        # exception propagates.
        weights = {
            "l_embed_char": lasagne.layers.get_all_param_values(l_embed_char),
            "l_decoder_reembedder": lasagne.layers.get_all_param_values(l_decoder_reembedder)
        }
        filename = filename_fmt_ % (epoch, "_interrupted" if interrupted else "")
        # A Ctrl-C in the middle of the write would leave a truncated pickle;
        # defer it until the checkpoint is complete.
        with DelayedKeyboardInterrupt():
            with open(filename, "wb") as fout:
                pickle.dump(("GRULayer", hyper, vocab, weights), fout)

    if interrupted:
        break

<hr/>

### Trunk

In [None]:
# Training loss per batch on a logarithmic axis. (log(log(loss)) was NaN as
# soon as the loss dropped to 1 or below.)
fig, ax = plt.subplots()
ax.semilogy(loss_val_hist)
ax.set(title="Training loss", xlabel="batch", ylabel="cross-entropy (log scale)");

In [None]:
# 200 replies to a single prompt. tau > 0 enables Gumbel noise at a very low
# temperature; seed=42 makes the sampling repeatable.
generate(as_matrix(["\x02" + "Hi, there." + "\x03"]), 75, tau=1e-5, n_samples=200, seed=42)

In [None]:
def softmax(x_ij, axis=-1):
    """Numerically stable softmax of the array ``x_ij`` along ``axis``."""
    # Subtracting the max leaves the result unchanged and avoids overflow in exp.
    shifted = x_ij - x_ij.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    total = weights.sum(axis=axis, keepdims=True)
    return weights / total

In [None]:
# op_predict_logits takes (decoder_input, encoder_input), the same order as op_train.
logits = op_predict_logits(bd, be)

# Fixed-seed standard Gumbel noise with the same shape as the logits.
rasta = np.random.RandomState(42)
gumbel = -np.log(-np.log(rasta.uniform(size=logits.shape)))


In [None]:
# How the Gumbel-perturbed softmax at one position (sample 0, 10th step from
# the end) changes with the temperature, for tau from 2^-20 to 2^20.
p = gumbel + logits
test = p[0, -10]
ttau = np.logspace(-20, 20, num=101, base=2)[:, np.newaxis]
test = softmax(test[np.newaxis] / ttau)

plt.plot(np.log2(ttau), test);

In [None]:
# Sanity check: the noise should look like a standard Gumbel distribution.
plt.hist(gumbel.flat)

In [None]:
# Pair every parameter reachable from the re-embedder with its saved value's shape.
for param, value in zip(lasagne.layers.get_all_params(l_decoder_reembedder),
                        weights["l_decoder_reembedder"]):
    print(param, value.shape)

In [None]:
# list(zip(trainable[1:], weights["l_decoder_reembedder"]))

In [None]:
# Grab a single (encoder, decoder) batch of at most 32 pairs for the scratch cells.
for be, bd in generate_batch(32, max_len=512):
    break

In [None]:
# One training step on that batch (decoder input first, as in the training loop).
op_train(bd, be)

In [None]:
# Decoder initial state for the encoder batch `be`. h_0 is whichever generator
# cell was run last (both define it the same way).
ass = h_0.eval({v_gen_input: be})

In [None]:
# Initial decoder state of the first example.
plt.plot(ass[0])

<hr/>

The older generator

In [None]:
# Generator's one step update function
def generator_step(x_tm1, h_tm1, m_tm1, tau, eps):
    """One step of the older generative decoder (hard feedback of the picked id).

    Unlike generator_step_sm, the picked character id itself is fed back and
    re-embedded. p_t is read after stopped rows are switched to the "\\x03" id.
    Returns (x_t, h_t, m_t, p_t): picked ids, packed hidden state, updated
    mask, and the un-tempered model probability of x_t.
    """
    # x_tm1 is `batch` integer ids (dtype of x_0, int32), h_tm1 is `batch x ...`
    # m_tm1 is `batch`, tau, eps are scalars

    # embed the previous character. x_t is `batch x embed`
    x_t = l_embed_char.get_output_for(x_tm1, deterministic=True)

    # collect the inputs (a length-1 sequence per step)
    inputs = {l_decoder_embed: x_t.dimshuffle(0, "x", 1),
              l_decoder_mask: m_tm1.dimshuffle(0, "x")}

    # Connect the previous packed state slices to each layer's hid_init feed
    j = 0
    for layer in dec_rnn_layers:
        inputs[layer.hid_init] = slice_(h_tm1, j, layer.num_units)
        j += layer.num_units

    # Get the outputs
    outputs = [l_decoder_reembedder] + dec_rnn_layers_sliced

    # propagate through the decoder column
    logit_t, *h_t_list = lasagne.layers.get_output(outputs, inputs,
                                                   deterministic=True)

    # drop the time axis: logit_t becomes `batch x V`
    logit_t = logit_t[:, 0]
    prob_t = tt.nnet.softmax(logit_t)

    # Gumbel-softmax sampling: Gumbel (e^{-e^{-x}}) distributed random noise,
    # applied only when tau > 0 (the ifelse variant below is unused)
    gumbel = -tt.log(-tt.log(theano_random_state.uniform(size=logit_t.shape) + eps) + eps)
#     logit_t = theano.ifelse.ifelse(tt.gt(tau, 0), gumbel + logit_t, logit_t)
#     inv_temp = theano.ifelse.ifelse(tt.gt(tau, 0), 1.0 / tau, tt.constant(1.0))
    logit_t = tt.switch(tt.gt(tau, 0), gumbel + logit_t, logit_t)
    inv_temp = tt.switch(tt.gt(tau, 0), 1.0 / tau, tt.constant(1.0))

    # Pick one element
    x_t = tt.cast(tt.argmax(tt.nnet.softmax(logit_t * inv_temp), axis=-1), x_tm1.dtype)

    # Pack the hidden states
    h_t = tt.concatenate(h_t_list, axis=-1)

    # Compute the mask and inhibit the propagation on a stop symbol (any id up
    # to that of "\x03").
    # Recurrent layers return the previous state if m_tm1 is False
    m_t = m_tm1 & tt.gt(x_t, vocab.index("\x03"))
    x_t = tt.switch(m_t, x_t, vocab.index("\x03"))

    # There is no need to freeze the states as they will be frozen by
    # the RNN passthrough according to the mask `m_t`.

    # Get the estimated probability of the picked symbol.
    p_t = prob_t[tt.arange(x_t.shape[0]), x_t]
    return x_t, h_t, m_t, p_t

# Initial state of the older (hard-feedback) generator: packed decoder states,
# integer start-symbol ids (not embedded), and an all-True mask.
# NOTE: this rebinds h_0 / x_0 / m_0 and op_generate.
h_0 = tt.concatenate(dec_hid_inits, axis=-1)

x_0 = tt.fill(tt.zeros((v_gen_input.shape[0],), dtype="int32"),
              vocab.index("\x02"))

m_0 = tt.ones((v_gen_input.shape[0],), 'bool')

# `gen_updates`, so the scan's random-stream updates do not overwrite the
# training `updates` passed to op_train.
result, gen_updates = theano.scan(generator_step, sequences=None, n_steps=n_steps,
                                  outputs_info=[x_0, h_0, m_0, None],
                                  strict=False, return_list=True,
                                  non_sequences=[tau, eps], go_backwards=False,
                                  name="generator/scan")
# Scan stacks along time first; make everything batch-major.
x_t, h_t, m_t, p_t = [r.swapaxes(0, 1) for r in result]

compile_mode = theano.Mode(optimizer="fast_run", linker="cvm")
op_generate = theano.function([v_gen_input, n_steps, tau],
                              [x_t, h_t, m_t, p_t],
                              updates=gen_updates, givens={eps: floatX(1e-20)},
                              mode=compile_mode)

def generate(questions, n_steps, n_samples=10, tau=0, seed=None):
    """Sample and rank replies with the older generator.

    NOTE(review): verbatim copy of the earlier `generate`; running this cell
    redefines it. For each question, returns a list of (reply, score) pairs
    sorted by ascending mean -log2 probability per unmasked character.
    """
    results = []
    for question in questions:
        # Replicate the query
        question = np.repeat(question[np.newaxis], n_samples, axis=0)
        if seed is not None:
            theano_random_state.seed(seed)
        x_t, h_t, m_t, p_t = op_generate(question, n_steps, tau)

        # may produce NaN, but they are shifted to the back by argsort
        perplexity, n_chars = (- np.log2(p_t) * m_t).sum(axis=-1), m_t.sum(axis=-1)
        perplexity /= n_chars

        result = []
        for i in perplexity.argsort():
            reply = "".join(map(vocab.__getitem__, x_t[i, :n_chars[i]]))
            result.append((reply, perplexity[i]))
        results.append(result)
    return results