## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset

In [23]:
raw_text = open('../datasets/timemachine.txt').read()

### Let's tokenize the dataset


In [24]:
import re
from transformers import AutoTokenizer

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text)

# Tokenize into characters
char_tokenized_text = list(text)

In [25]:
",".join(char_tokenized_text[200:300])

's,e, ,i,t, ,u,n,d,e,r, ,t,h,e, ,t,e,r,m,s, ,o,f, ,t,h,e, ,p,r,o,j,e,c,t, ,g,u,t,e,n,b,e,r,g, ,l,i,c,e,n,s,e, ,i,n,c,l,u,d,e,d, ,w,i,t,h, ,t,h,i,s, ,e,b,o,o,k, ,o,r, ,o,n,l,i,n,e, ,a,t, ,w,w,w, ,g,u,t'

In [26]:
# Most common tokens
from collections import Counter
# Flatten the list of tokens
# Count the tokens
token_counts = Counter(char_tokenized_text)
print(f"Total tokens: {len(token_counts)}")
# Most common tokens
most_common_tokens = token_counts.most_common(10)
for token, count in most_common_tokens:
    print(f"{token} {count}")

Total tokens: 27
  35850
e 19667
t 15040
a 12700
i 11254
o 11082
n 10943
s 9242
r 8833
h 8786


In [27]:
MIN_FREQ = 1 ## Minimum token frequency to include in the vocab

### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set(['<unk>'] + [tok for tok,count in token_counts.items() if count >= MIN_FREQ])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}


# Vocabulary size
vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")



Vocabulary size: 28


In [28]:
import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

### Train-Validate-Test Split

In [29]:
SEQUENCE_LENGTH = 32
WINDOW_OVERLAP = 10 

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]

print(X.shape, Y.shape)

train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
print(X_train.shape, Y_train.shape)
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
print(X_valid.shape, Y_valid.shape)
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]
print(X_test.shape, Y_test.shape)

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == 31 for i in range(X_valid.shape[0])))


(8714, 31) (8714, 31)
(6971, 31) (6971, 31)
(871, 31) (871, 31)
(872, 31) (872, 31)


In [30]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=100
BATCH_SIZE=32
LEARNING_RATE = 10**(-3)
MAX_GRAD = 1

params, model = gru.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, train_loss, valid_loss, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid loss={valid_loss:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
trained_params = gru.train(model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=37.3494, valid loss=0.1404, elapsed=13.0957s
INFO:__main__:Completed epoch=20, train loss=21.2666, valid loss=0.0975, elapsed=85.4223s
INFO:__main__:Completed epoch=40, train loss=18.7488, valid loss=0.0879, elapsed=158.1976s


In [None]:
import xjax

# I generate sequences from a prefix
prefix_str = "i saw"
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
generated = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y = gru.generate(rng=sub_rng, prefix=prefix, params=trained_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=20) 
    generated.append("".join([idx_to_tok[i] for i in y]))
 


In [None]:
for g in generated:
    print(g)

i saw and down the perhap
i saw no solute you the p
i saw the time as in the 
i saw was is a potward of
i saw the thing to my min
i saw the running but was
i saw must was the bit my
i saw the interest of fin
i saw the actual and diff
i saw was that sowe went 
i saw at ground and white
i saw that the down i saw
i saw the time stread has
i saw you know the than h
i saw the sky little pazz
i saw a perfect no potent
i saw increass the adreac
i saw was fire for rested
i saw was other human sen
i saw in the stunly the d


## Calculate Perplexity on the test set