# In[78]:
import mxnet as mx
import numpy as np
from mxnet import nd, gluon, autograd
from mxnet.gluon import nn, Block
from mxnet.gluon import rnn

def one_hots(numerical_list, vocab_size=None):
    """One-hot encode a sequence of integer indices.

    Args:
        numerical_list: sequence of integer indices, assumed to lie in
            ``[0, vocab_size)``.
        vocab_size: width of each one-hot row. When omitted, the module-level
            ``vocab_size`` global is read at call time. (The original default
            ``vocab_size=vocab_size`` captured the global once, when the ``def``
            ran, so later changes to it were ignored, and the ``def`` itself
            failed if the global did not exist yet.)

    Returns:
        float32 NDArray of shape ``(len(numerical_list), vocab_size)`` on the
        global ``ctx``.
    """
    if vocab_size is None:
        vocab_size = globals()["vocab_size"]
    if len(numerical_list) == 0:
        # Keep the exact empty result the original produced.
        return nd.zeros((0, vocab_size), ctx=ctx)
    # One vectorised op instead of one NDArray assignment per element.
    indices = nd.array(numerical_list, ctx=ctx)
    return nd.one_hot(indices, vocab_size)

def textify(embedding, chars=None):
    """Decode per-position character scores back into a string.

    The argmax of each row (axis 1) is taken as a character index and looked
    up in ``chars``.

    Args:
        embedding: NDArray whose axis 1 indexes characters; assumed 2-D
            (one row per position).
        chars: sequence mapping index -> character. Defaults to the
            module-level ``character_list`` global that the original read
            implicitly (this file never defines it; ``get_char_dict`` keeps its
            list local). It must use the same ordering as the index mapping
            used for encoding, e.g.
            ``sorted(character_dict, key=character_dict.get)``.

    Returns:
        The decoded string, one character per row of ``embedding``.
    """
    if chars is None:
        chars = character_list
    indices = nd.argmax(embedding, axis=1).asnumpy()
    # join avoids quadratic string concatenation.
    return "".join(chars[int(idx)] for idx in indices)

def load_time_machine(seq_length=64, batch_size=None,
                      path="../../data/timemachine.txt"):
    """Load *The Time Machine* text as one-hot encoded training batches.

    The text is read from ``path`` and its last 38083 characters are dropped
    (a hard-coded cut the original describes as removing unwanted trailing
    text). Every remaining character is mapped to an index with
    ``get_char_dict``. Inputs are the first
    ``num_batches * batch_size * seq_length`` characters; labels are the same
    span shifted one character to the right.

    Args:
        seq_length: characters per training sequence.
        batch_size: sequences per batch. Defaults to the module-level
            ``batch_size`` global, which the original read implicitly.
        path: location of the text file.

    Returns:
        ``(train_data, train_label)``, each of shape
        ``(num_batches, seq_length, batch_size, vocab_size)``.

    Raises:
        ValueError: if ``batch_size`` is neither passed nor defined at module
            level, or the text is too short to fill a single batch.
    """
    if batch_size is None:
        try:
            batch_size = globals()["batch_size"]
        except KeyError:
            raise ValueError(
                "batch_size must be passed or defined at module level") from None

    # loading dataset
    with open(path) as f:
        time_machine = f.read()
    time_machine = time_machine[:-38083]  # hard-coded cut of the trailing text
    character_dict, vocab_size = get_char_dict(time_machine)
    # Character indices. The original referenced an undefined global
    # ``time_numerical`` and never used ``character_dict``.
    time_numerical = [character_dict[char] for char in time_machine]

    # -1 here so every input character has a following character as its label.
    num_samples = (len(time_numerical) - 1) // seq_length
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(
            "text too short for one batch of %d sequences of length %d"
            % (batch_size, seq_length))
    # Truncate inputs AND labels to whole batches. The original only
    # truncated the inputs, so the label reshape failed whenever num_samples
    # was not a multiple of batch_size.
    num_chars = num_batches * batch_size * seq_length

    # Pass vocab_size explicitly so the one-hot width matches the reshape.
    train_data = _to_batches(
        one_hots(time_numerical[:num_chars], vocab_size=vocab_size),
        batch_size, num_batches, seq_length, vocab_size)
    print('Shape of data set: ', train_data.shape)

    train_label = _to_batches(
        one_hots(time_numerical[1:num_chars + 1], vocab_size=vocab_size),
        batch_size, num_batches, seq_length, vocab_size)
    print('Shape of label set: ', train_label.shape)

    return train_data, train_label


def _to_batches(encoded, batch_size, num_batches, seq_length, vocab_size):
    """Reshape flat one-hot rows to (num_batches, seq_length, batch_size, vocab_size)."""
    batched = encoded.reshape((batch_size, num_batches, seq_length, vocab_size))
    # swap batch_size and seq_length axes to make later access easier
    batched = nd.swapaxes(batched, 0, 1)
    return nd.swapaxes(batched, 1, 2)
    

def get_char_dict(data):
    """Build a character -> index vocabulary for ``data``.

    Characters are numbered in sorted order so the mapping is identical on
    every run. The original numbered them in ``set`` iteration order, which
    for strings depends on hash randomisation (PYTHONHASHSEED). That made the
    encoding differ between processes.

    Args:
        data: the text (any iterable of hashable, orderable characters).

    Returns:
        ``(character_dict, vocab_size)``: the mapping from character to its
        index, and the number of distinct characters.
    """
    character_list = sorted(set(data))
    character_dict = {char: idx for idx, char in enumerate(character_list)}
    return character_dict, len(character_list)

def decoder_RNN_helper(num_hidden=100, scale=.01):
    """Create randomly initialised decoder RNN parameters with gradients attached.

    Input and output widths are the module-level ``vocab_size``. Arrays are
    placed on the global ``ctx``.

    Args:
        num_hidden: hidden-state size (default 100, the value previously
            hard-coded here and in ``encoder_RNN``).
        scale: multiplier applied to the standard-normal draws.

    Returns:
        ``[Wxh, Whh, bh, Why, by]``.
    """
    num_inputs = vocab_size
    num_outputs = vocab_size

    def _init(shape):
        return nd.random_normal(shape=shape, ctx=ctx) * scale

    # Same creation order as before, so the random draws are unchanged.
    params = [
        _init((num_inputs, num_hidden)),   # Wxh: input  -> hidden
        _init((num_hidden, num_hidden)),   # Whh: hidden -> hidden
        _init(num_hidden),                 # bh
        _init((num_hidden, num_outputs)),  # Why: hidden -> output
        _init(num_outputs),                # by
    ]
    for param in params:
        param.attach_grad()
    return params

def decoder_RNN(steps, encoder_outputs, state, temperature=1.0, params=None):
    """Unroll the attention decoder for ``steps`` steps.

    Args:
        steps: number of decoding steps.
        encoder_outputs: passed to ``attention`` at every step.
        state: initial decoder hidden state ``h``.
        temperature: softmax temperature. Logits are divided by it, so values
            above 1 flatten the output distribution and values below 1
            sharpen it.
        params: optional ``[Wxh, Whh, bh, Why, by]`` from
            ``decoder_RNN_helper()``. When omitted, fresh random weights are
            created (as the original always did). Such weights are not shared
            across calls and so cannot be trained.

    Returns:
        ``(outputs, h)``: list of per-step softmax outputs and the final
        hidden state.
    """
    if params is None:
        params = decoder_RNN_helper()
    Wxh, Whh, bh, Why, by = params
    outputs = []
    h = state
    for _ in range(steps):
        # The original referenced an undefined ``X`` and discarded the
        # attention result. This assumes the attention context was meant to
        # be the step input.
        # NOTE(review): attention returns a vector with the encoder outputs'
        # row count, while Wxh expects vocab_size rows — confirm the shapes.
        X = attention(h, encoder_outputs)
        h_linear = nd.dot(X, Wxh) + nd.dot(h, Whh) + bh
        h = nd.tanh(h_linear)
        yhat_linear = nd.dot(h, Why) + by
        # Apply temperature by scaling the logits. Passing temperature= to
        # this file's softmax(y_linear) raised TypeError.
        yhat = softmax(yhat_linear / temperature)
        outputs.append(yhat)
    return (outputs, h)

# Attention mechanism.
# Inputs at decoder step t: every output of the encoder (A) and the decoder's
# (B) hidden state from step t-1.
def attention(decoder_hidden_t, encoder_output):
    """Return the attention context vector for one decoder step.

    Args:
        decoder_hidden_t: previous decoder hidden state with HB elements,
            documented as an HB x 1 column. Any layout with HB elements works,
            because it is flattened to a 1 x HB row.
        encoder_output: HB x T matrix, one column per encoder time step.

    Returns:
        NDArray of shape (HB,): the encoder columns summed with softmax
        weights over T.
    """
    query = decoder_hidden_t.reshape((1, -1))             # 1 x HB
    # The original nd.dot(HBx1, HBxT) only type-checks when HB == 1.
    scores = nd.dot(query, encoder_output)                # 1 x T
    # This file's softmax normalises over every axis but 0, i.e. over T here.
    weights = softmax(scores)                             # 1 x T
    # The original nd.dot(weights, encoder_output), (1xT)·(HBxT), only
    # type-checks when T == HB. A weighted sum is a broadcast element-wise
    # product followed by a sum over T.
    weighted = nd.broadcast_mul(weights, encoder_output)  # HB x T
    return nd.sum(weighted, axis=1)                       # HB

def softmax(y_linear, temperature=1.0):
    """Softmax over every axis except axis 0 (row-wise for 2-D input).

    Args:
        y_linear: NDArray of logits; expected to be 2-D (rows x classes).
        temperature: logits are divided by this before normalising (1.0
            leaves them unchanged).

    Returns:
        NDArray of the same shape whose rows sum to 1.
    """
    scaled = y_linear / temperature
    # Subtract each row's own maximum, not the global one. Otherwise a row
    # far below the global max underflows to all zeros and yields 0/0 = nan.
    row_max = nd.max(scaled, axis=0, exclude=True, keepdims=True)
    exp = nd.exp(nd.broadcast_sub(scaled, row_max))
    partition = nd.nansum(exp, axis=0, exclude=True).reshape((-1, 1))
    return exp / partition

# Plain (non-attention) RNN encoder built from Gluon layers.
class encoder_RNN(Block):
    """Single-layer ReLU RNN followed by a dense projection back to vocab_size.

    Args:
        mode: unused; kept for interface compatibility.
        vocab_size: width of the RNN input and of the projected output.
        tie_weights: unused; kept for interface compatibility.
        num_hidden: RNN hidden units (default 100, matching
            ``decoder_RNN_helper``).
    """

    def __init__(self, mode, vocab_size, tie_weights=False, num_hidden=100,
                 **kwargs):
        # The original called super(encoderRNN, ...): a NameError, because
        # the class is named encoder_RNN.
        super(encoder_RNN, self).__init__(**kwargs)
        self._num_hidden = num_hidden
        with self.name_scope():
            self.rnn = rnn.RNN(num_hidden, 1, activation='relu',
                               input_size=vocab_size)
            self.decoder = nn.Dense(vocab_size, in_units=num_hidden)

    def forward(self, inputs, hidden):
        """Run the RNN over ``inputs`` and project every position to vocab_size.

        Args:
            inputs: sequence batch fed to ``self.rnn``.
            hidden: initial RNN state. NOTE(review): per the Gluon docs this is
                the list returned by ``self.rnn.begin_state`` (arrays of shape
                (num_layers, batch, num_hidden)), not a 100x1 vector — confirm.

        Returns:
            ``(decoded, hidden)``: logits with one row per position (time and
            batch axes flattened), and the final RNN state.
        """
        output, hidden = self.rnn(inputs, hidden)
        # Flatten every leading axis so Dense sees one row per position.
        decoded = self.decoder(output.reshape((-1, self._num_hidden)))
        return decoded, hidden

# In[72]:
# Run all NDArrays on CPU device 0 (the default device of mx.cpu()), then
# build the one-hot training tensors with the default sequence length.
ctx = mx.cpu(0)
data, labels = load_time_machine(seq_length=64)