#  Gated recurrent unit (GRU) RNNs

This chapter requires some exposition. The GRU updates are fully implemented and the code appears to work properly.

In [None]:
from __future__ import print_function
import mxnet as mx
from mxnet import nd, autograd
import numpy as np
# Seed MXNet's random number generator (used below by nd.random_normal when
# the parameters are initialized).
mx.random.seed(1)
# Device context passed as ctx= whenever an NDArray is allocated below;
# this selects GPU 0.
ctx = mx.gpu(0)

## Dataset: "The Time Machine" 

In [None]:
# Read the whole corpus in one go and drop its final 38083 characters
# (presumably trailing text that is not part of the novel -- TODO confirm).
with open("data/nlp/timemachine.txt") as f:
    time_machine = f.read()[:-38083]

## Numerical representations of characters

In [None]:
# Build the character vocabulary. The characters are sorted so that the
# character -> index mapping is the same on every run: iterating a bare set of
# strings can produce a different order from one interpreter run to the next,
# which would silently change what each one-hot position means.
character_list = sorted(set(time_machine))
vocab_size = len(character_list)
# Map every character to its position in character_list.
character_dict = {char: e for e, char in enumerate(character_list)}
# The whole corpus re-expressed as a list of integer character indices.
time_numerical = [character_dict[char] for char in time_machine]

## One-hot representations

In [None]:
def one_hots(numerical_list, vocab_size=vocab_size):
    """Turn a sequence of integer indices into a list of one-hot row vectors.

    Args:
        numerical_list: iterable of integer character indices.
        vocab_size: width of each one-hot vector. The default is the value of
            the module-level ``vocab_size`` at the time this function is defined.

    Returns:
        A list with one NDArray of shape ``(1, vocab_size)`` per index,
        allocated on ``ctx``, holding 1.0 at the index's column and 0 elsewhere.
    """
    def encode(index):
        row = nd.zeros(shape=(1, vocab_size), ctx=ctx)
        row[0, index] = 1.0
        return row

    return [encode(index) for index in numerical_list]

In [None]:
def textify(vector_list):
    """Decode a sequence of score vectors back into a string.

    For every element of ``vector_list`` the first row is taken, the index of
    its largest entry is found, and the matching character is looked up in
    ``character_list``.

    Returns:
        The decoded characters joined into one string ("" for an empty input).
    """
    decoded = []
    for scores in vector_list:
        best = nd.argmax(scores[0], axis=0)
        decoded.append(character_list[int(best.asscalar())])
    return "".join(decoded)

## Preparing the data for training

In [None]:
# Number of characters in each training sequence.
seq_length = 64
# Cut the corpus into consecutive, non-overlapping sequences of one-hot vectors.
# One character is held back (the "- 1") because the labels are the same
# sequences shifted one character ahead.
num_sequences = (len(time_numerical) - 1) // seq_length
dataset = [one_hots(time_numerical[i*seq_length:(i+1)*seq_length]) for i in range(num_sequences)]
batch_size = 32
# Split the sequences into batch_size contiguous runs ("rows"); sequences left
# over after the integer division are dropped. Integer floor division replaces
# the original int(np.floor(len(dataset))/batch_size), whose floor() was applied
# to an integer by mistake; the resulting value is the same.
sequences_per_batch_row = len(dataset) // batch_size
data_rows = [dataset[i*sequences_per_batch_row:(i+1)*sequences_per_batch_row]
            for i in range(batch_size)]

In [None]:
def stack_the_datasets(datasets):
    """Merge parallel rows of sequences into batched sequences.

    Args:
        datasets: a list of rows; each row is a list of sequences and each
            sequence is a list of NDArrays (one per time step). The number of
            sequences is taken from the first row and the number of time steps
            from the first sequence of the first row.

    Returns:
        A list of sequences in which each time step is the concatenation along
        axis 0 of that step from every row, each first reshaped to one row.
    """
    first_row = datasets[0]
    stacked = []
    # walk over the sequences
    for seq_idx in range(len(first_row)):
        # walk over the elements of the sequence, stacking one row per dataset
        stacked.append([
            nd.concatenate(
                [row[seq_idx][step].reshape((1, -1)) for row in datasets],
                axis=0)
            for step in range(len(first_row[0]))
        ])
    return stacked

In [None]:
# Stack the batch rows so that every time step of every training sequence is a
# single matrix with one row per entry of data_rows.
training_data = stack_the_datasets(data_rows)

## Preparing our labels

In [None]:
# Labels are the inputs shifted one character ahead: the target for each input
# character is the character that follows it in the corpus.
labels = [
    one_hots(time_numerical[start + 1:start + seq_length + 1])
    for start in range(0, ((len(time_numerical) - 1) // seq_length) * seq_length, seq_length)
]
# Group the label sequences into the same contiguous batch rows as the inputs.
label_rows = [
    labels[row * sequences_per_batch_row:(row + 1) * sequences_per_batch_row]
    for row in range(batch_size)
]
training_labels = stack_the_datasets(label_rows)

## Gated recurrent units (GRU) RNNs

[Placeholder for explanation]

$$z_t = \sigma(X_t W_{xz} + h_{t-1} W_{hz} + b_z)$$
$$r_t = \sigma(X_t W_{xr} + h_{t-1} W_{hr} + b_r)$$
$$h_t = z_t \odot h_{t-1} + (1-z_t) \odot \text{tanh}(X_t W_{xh} + (r_t \odot h_{t-1}) W_{hh} + b_h)$$

<!--
$$i_t = \sigma(X_t W_{xi} + b_i)$$
$$f_t = \sigma(X_t W_{xf} + b_f)$$
$$o_t = \sigma(X_t W_{xo} + b_o)$$
$$c_t = f \odot c_{t-1} + g_t \odot i_t$$
$$h_t = \text{tanh}(c_t) \odot o_t $$ 
-->

## Allocate parameters

In [None]:
# The network consumes and predicts one-hot encoded characters, so the input
# and output widths are both the vocabulary size. They are derived from
# vocab_size rather than hard-coded so the shapes always match what one_hots
# produces.
num_inputs = vocab_size
num_hidden = 256
num_outputs = vocab_size

########################
#  Weights connecting the inputs to the hidden layer
#  (update gate z, reset gate r, candidate hidden state h)
########################
Wxz = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01
Wxr = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01
Wxh = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01


########################
#  Recurrent weights connecting the hidden layer across time steps
########################
Whz = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01
Whr = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01
Whh = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01


########################
#  Bias vectors for the hidden layer
########################
bz = nd.random_normal(shape=num_hidden, ctx=ctx) * .01
br = nd.random_normal(shape=num_hidden, ctx=ctx) * .01
bh = nd.random_normal(shape=num_hidden, ctx=ctx) * .01


########################
# Weights and bias mapping the hidden state to the output nodes
# (sized with num_outputs, not num_inputs)
########################
Why = nd.random_normal(shape=(num_hidden,num_outputs), ctx=ctx) * .01
by = nd.random_normal(shape=num_outputs, ctx=ctx) * .01

## Attach the gradients

In [None]:
# Collect every trainable parameter in one list so the optimizer can iterate
# over them: the GRU weights and biases first, then the output layer.
params = [Wxz, Wxr, Wxh, Whz, Whr, Whh, bz, br, bh]
params += [Why, by]

# Allocate a gradient buffer for each parameter so that autograd can populate
# param.grad. One pass is enough; the original repeated this loop.
for param in params:
    param.attach_grad()

## Softmax Activation

In [None]:
def softmax(y_linear, temperature=1.0):
    """Row-wise softmax with an optional temperature.

    Args:
        y_linear: NDArray of unnormalized scores, one example per row along
            axis 0 (the ``reshape((-1, 1))`` calls assume it is 2-D).
        temperature: divisor applied to the shifted scores before
            exponentiation; values below 1 sharpen the distribution and values
            above 1 flatten it. Must be non-zero.

    Returns:
        An NDArray of the same shape as ``y_linear`` whose rows each sum to 1.
    """
    # Shift every row by its OWN maximum before exponentiating. Softmax is
    # unchanged by a per-row shift, and this keeps the largest exponent in each
    # row at exp(0) = 1. Shifting by the global maximum instead lets a row whose
    # scores are all far below it underflow to zeros, giving a zero partition
    # and NaN outputs.
    row_max = nd.max(y_linear, axis=0, exclude=True).reshape((-1, 1))
    lin = (y_linear - row_max) / temperature
    exp = nd.exp(lin)
    # Sum over every axis except 0, i.e. one normalizer per row.
    partition = nd.sum(exp, axis=0, exclude=True).reshape((-1, 1))
    return exp / partition

## Define the model

In [None]:
def gru_rnn(inputs, h, temperature=1.0):
    """Run the GRU over a sequence and emit a prediction at every step.

    Args:
        inputs: iterable of input matrices, one per time step. Each is
            multiplied by the input weights, so its second dimension must be
            ``num_inputs``.
        h: hidden state to start from; it is multiplied by the recurrent
            weights, so its second dimension must be ``num_hidden``.
        temperature: forwarded to ``softmax``.

    Returns:
        A tuple ``(outputs, h)``: the list of softmax outputs, one per time
        step, and the hidden state after the last step.
    """
    outputs = []
    for X in inputs:
        # Update gate: how much of the previous hidden state is kept.
        update_gate = nd.sigmoid(nd.dot(X, Wxz) + nd.dot(h, Whz) + bz)
        # Reset gate: how much of the previous state feeds the candidate.
        reset_gate = nd.sigmoid(nd.dot(X, Wxr) + nd.dot(h, Whr) + br)
        # Candidate hidden state built from the input and the gated old state.
        candidate = nd.tanh(nd.dot(X, Wxh) + nd.dot(reset_gate * h, Whh) + bh)
        # Blend the old state and the candidate according to the update gate.
        h = update_gate * h + (1 - update_gate) * candidate

        # Project the hidden state to output scores and normalize them.
        outputs.append(softmax(nd.dot(h, Why) + by, temperature=temperature))
    return (outputs, h)

## Cross-entropy loss function

In [None]:
def cross_entropy(yhat, y):
    """Mean cross-entropy between predictions ``yhat`` and targets ``y``.

    ``y * log(yhat)`` is summed over every axis except 0 (one value per row),
    the per-row values are averaged, and the sign is flipped.
    """
    per_example = nd.sum(y * nd.log(yhat), axis=0, exclude=True)
    return - nd.mean(per_example)

## Averaging the loss over the sequence

In [None]:
def average_ce_loss(outputs, labels):
    """Average the cross-entropy loss over the time steps of a sequence.

    Args:
        outputs: list of predictions, one per time step.
        labels: list of targets, one per time step; must have the same length
            as ``outputs``.

    Returns:
        A one-element NDArray on ``ctx``: the sum of the per-step
        ``cross_entropy`` values divided by the number of steps.
    """
    assert(len(outputs) == len(labels))
    num_steps = len(outputs)
    running = nd.array([0.], ctx=ctx)
    for predicted, target in zip(outputs, labels):
        running = running + cross_entropy(predicted, target)
    return running / num_steps

## Optimizer

In [None]:
def SGD(params, lr):
    """Apply one step of plain stochastic gradient descent, in place.

    Args:
        params: iterable of parameter arrays, each exposing its gradient as
            ``.grad``.
        lr: learning rate that scales each gradient.

    Every parameter is overwritten through slice assignment, so the same array
    object keeps being used after the update. Nothing is returned.
    """
    for weight in params:
        step = lr * weight.grad
        weight[:] = weight - step

## Generating text by sampling

In [None]:
def sample(prefix, num_chars, temperature=1.0):
    """Generate text by repeatedly sampling the next character from the GRU.

    Args:
        prefix: string used to warm up the hidden state. Every character must
            be a key of ``character_dict``, and the prefix must be non-empty
            when ``num_chars`` is positive (the first prediction is read from
            the output of the last prefix character).
        num_chars: number of characters to generate after the prefix.
        temperature: softmax temperature forwarded to ``gru_rnn``.

    Returns:
        ``prefix`` followed by the ``num_chars`` sampled characters.
    """
    #####################################
    # Prepare the prefix as a sequence of one-hots for ingestion by RNN
    #####################################
    prefix_numerical = [character_dict[char] for char in prefix]
    next_input = one_hots(prefix_numerical)

    #####################################
    # Set the initial state of the hidden representation ($h_0$) to the zero vector
    #####################################
    h = nd.zeros(shape=(1, num_hidden), ctx=ctx)

    #####################################
    # For num_chars iterations,
    #     1) feed in the current input
    #     2) sample next character from the output distribution
    #     3) add sampled character to the decoded string
    #     4) prepare the sampled character as a one_hot (to be the next input)
    #####################################
    sampled = []
    for _ in range(num_chars):
        outputs, h = gru_rnn(next_input, h, temperature=temperature)
        # Draw from the distribution predicted at the last time step. The
        # number of choices is the vocabulary size rather than a hard-coded 77.
        choice = np.random.choice(vocab_size, p=outputs[-1][0].asnumpy())
        sampled.append(character_list[choice])
        next_input = one_hots([choice])
    return prefix + "".join(sampled)

In [None]:
epochs = 5
moving_loss = 0.

learning_rate = 2.0

for e in range(epochs):
    ############################
    # Halve the learning rate after every 100th epoch.
    # (With epochs = 5 this condition is never met.)
    ############################
    if (e + 1) % 100 == 0:
        learning_rate = learning_rate / 2.0

    # Each epoch starts from a zero hidden state, which is then carried over
    # from one batch to the next.
    h = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)
    for i, (data, label) in enumerate(zip(training_data, training_labels)):
        with autograd.record():
            outputs, h = gru_rnn(data, h)
            loss = average_ce_loss(outputs, label)
            loss.backward()
        SGD(params, learning_rate)

        ##########################
        #  Keep a moving average of the losses; the very first batch seeds it
        ##########################
        current_loss = nd.mean(loss).asscalar()
        if i == 0 and e == 0:
            moving_loss = current_loss
        else:
            moving_loss = .99 * moving_loss + .01 * current_loss

    print("Epoch %s. Loss: %s" % (e, moving_loss))
    # Uncomment to print generated text after each epoch:
    # print(sample("The Time Ma", 1024, temperature=.1))
    # print(sample("The Medical Man rose, came to the lamp,", 1024, temperature=.1))
            

## Conclusions

For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)