In [1]:
from __future__ import print_function

import itertools
import random

import jax
from jax import jit, grad, vmap
import jax.numpy as np

import numpy as onp

import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from fastax import optimizers
from fastax.layers import Sigmoid, elementwise, serial, LSTM
# from fastax.activations import 
from fastax.losses import create_loss, crossentropy as cse
from fastax.initializers import glorot_uniform
from data import train_data, test_data

In [2]:
# Build the vocabulary from every space-separated word in the training texts.
# sorted() makes the word order (and therefore every word's index) deterministic;
# iterating a bare set of strings can differ between interpreter runs because of
# string hash randomization, which would silently reshuffle the one-hot encoding.
vocab = sorted(set(w for text in train_data.keys() for w in text.split(' ')))
vocab_size = len(vocab)
print('%d unique words found' % vocab_size)

18 unique words found


In [3]:
# Two-way lookup between each word and its position in `vocab`.
idx_to_word = dict(enumerate(vocab))
word_to_idx = {word: idx for idx, word in idx_to_word.items()}

In [4]:
def createInputs(text):
    '''
    Encode a text string as a list of one-hot row vectors, one per word.

    - text is a string; words are separated by single spaces (text.split(' ')).
    - Returns a list with one entry per word, in order. Each entry is a
      jax array of shape (1, vocab_size) holding a 1 at the column given by
      word_to_idx[word] and 0 elsewhere.
    - Raises KeyError if a word is not present in word_to_idx.
    '''
    inputs = []
    for w in text.split(' '):
        # Fill a mutable NumPy column, then transpose it into a (1, vocab_size)
        # row; jax.numpy.transpose also converts it to a jax array.
        v = onp.zeros((vocab_size, 1))
        v[word_to_idx[w]] = 1
        inputs.append(np.transpose(v))
    return inputs

In [5]:
# Model: an LSTM layer with 2 output units followed by an elementwise Sigmoid.
# Predictions are taken as argmax over the output and compared to int(label),
# so the 2 units presumably correspond to the classes False (0) / True (1).
init_random_params, net = serial(LSTM(2, W_init=glorot_uniform), Sigmoid)

# Training loss: fastax's create_loss wrapping the network with crossentropy.
loss_cse = create_loss(net, cse)

In [6]:
# JIT-compiled gradient of the loss with respect to its first argument (params).
loss_grad = jit(grad(loss_cse))

def update(i, opt_state, batch):
    '''
    Take one optimizer step on a single (x, y) example.

    - i is the step index, forwarded to opt_update.
    - opt_state is the current optimizer state.
    - batch is an (x, y) pair, evaluated as loss_grad(params, x, y).

    Returns the new optimizer state.

    NOTE: reads the module-level get_params / opt_update, which are only
    created later by optimizers.sgd(...); do not call this before that cell
    has run.
    '''
    params = get_params(opt_state)
    x, y = batch
    grads = loss_grad(params, x, y)
    return opt_update(i, grads, opt_state)

In [7]:
def processData(data, i, opt_state, net, backprop=True):
    '''
    Run one pass over `data` in random order, optionally training, and report accuracy.

    - data is a dictionary mapping text to True or False.
    - i is the epoch index; it is only used to form the optimizer step
      number i*len(data) + k for the k-th visited example when backprop is True.
    - opt_state is the current optimizer state; params are read from it with
      get_params before every forward pass.
    - net is the forward function, called as net(params, inputs).
    - backprop determines if the backward phase (one `update` step per
      example) should be run.

    Returns a tuple (opt_state, accuracy): the optimizer state (unchanged when
    backprop is False) and the fraction of examples whose argmax prediction
    equals int(label).

    Raises ValueError if data is empty, since accuracy would be undefined.
    '''
    if not data:
        raise ValueError('processData: data is empty')

    items = list(data.items())
    random.shuffle(items)

    num_correct = 0
    for step_in_epoch, (x, y) in enumerate(items):
        inputs = createInputs(x)
        target = int(y)
        params = get_params(opt_state)

        # Forward
        probs = net(params, inputs)

        # Accuracy: the predicted class is the index of the largest output.
        num_correct += int(np.argmax(probs) == target)

        if backprop:
            # Global step index, so the step number passed to the optimizer keeps
            # increasing across consecutive epochs.
            opt_state = update(i * len(data) + step_in_epoch, opt_state, (inputs, y))
    return opt_state, num_correct / len(data)
        

In [8]:
# Seed every RNG the notebook uses: JAX's key drives parameter initialization,
# and Python's `random` module drives the example shuffling in processData.
rng = jax.random.PRNGKey(0)
random.seed(0)

# Optimizer triple from fastax.optimizers.sgd with step size 0.02.
opt_init, opt_update, get_params = optimizers.sgd(0.02)

# Input shape derived from the vocabulary instead of a hard-coded 18, so it
# follows the one-hot width of createInputs if the training data changes.
# NOTE(review): createInputs yields (1, vocab_size) rows while this passes
# (vocab_size, 1); confirm which axis fastax's LSTM init reads.
_, init_params = init_random_params(rng, (vocab_size, 1))
opt_state = opt_init(init_params)



In [9]:
# Train for 1000 epochs. Every 100th epoch, report train accuracy plus a
# held-out evaluation pass (backprop=False, and its returned state is discarded).
for epoch in range(1000):
    opt_state, train_acc = processData(train_data, epoch, opt_state, net)

    if (epoch + 1) % 100 != 0:
        continue

    print('--- Epoch %d' % (epoch + 1))
    print('Train Accuracy: %.3f' % train_acc)
    _, test_acc = processData(test_data, epoch, opt_state, net, backprop=False)
    print('Test Accuracy: %.3f' % test_acc)

--- Epoch 100
Train Accuracy: 0.552
Test Accuracy: 0.500
--- Epoch 200
Train Accuracy: 0.431
Test Accuracy: 0.350
--- Epoch 300
Train Accuracy: 0.776
Test Accuracy: 0.700
--- Epoch 400
Train Accuracy: 0.793
Test Accuracy: 0.700
--- Epoch 500
Train Accuracy: 0.828
Test Accuracy: 0.750
--- Epoch 600
Train Accuracy: 0.948
Test Accuracy: 0.950
--- Epoch 700
Train Accuracy: 0.983
Test Accuracy: 0.950
--- Epoch 800
Train Accuracy: 1.000
Test Accuracy: 1.000
--- Epoch 900
Train Accuracy: 1.000
Test Accuracy: 1.000
--- Epoch 1000
Train Accuracy: 1.000
Test Accuracy: 0.900
