## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset

In [1]:
raw_text = open('../datasets/timemachine.txt').read()

### Let's tokenize the dataset


In [2]:
import re
from transformers import AutoTokenizer

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text)

# Tokenize into characters
char_tokenized_text = list(text)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Most common tokens
from collections import Counter
# Flatten the list of tokens
# Count the tokens
token_counts = Counter(char_tokenized_text)
print(f"Total tokens: {len(token_counts)}")
# Most common tokens
most_common_tokens = token_counts.most_common(10)
for token, count in most_common_tokens:
    print(f"{token} {count}")

Total tokens: 27
  35850
e 19667
t 15040
a 12700
i 11254
o 11082
n 10943
s 9242
r 8833
h 8786


In [4]:
MIN_FREQ = 1 ## Minimum token frequency to include in the vocab

### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set(['<unk>'] + [tok for tok,count in token_counts.items() if count >= MIN_FREQ])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}


# Vocabulary size
vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")



Vocabulary size: 28


In [5]:
import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

### Train-Validate-Test Split

In [6]:
SEQUENCE_LENGTH = 32
WINDOW_OVERLAP = 10 

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]

print(X.shape, Y.shape)

train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
print(X_train.shape, Y_train.shape)
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
print(X_valid.shape, Y_valid.shape)
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]
print(X_test.shape, Y_test.shape)

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == 31 for i in range(X_valid.shape[0])))


(8714, 31) (8714, 31)
(6971, 31) (6971, 31)
(871, 31) (871, 31)
(872, 31) (872, 31)


### Train a Basic GRU Model 

In [7]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=200
BATCH_SIZE=128
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, model = gru.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, train_loss, valid_loss, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid loss={valid_loss:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=11.2605, valid loss=0.1415, elapsed=8.7842s


### Train a Basic GRU + Basic Dot-Product Self-Attention Model 

In [21]:
import logging

from blinker import signal

import jax
from xjax.models import gru_attn
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=200
BATCH_SIZE=128
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, candidate_model = gru_attn.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(candidate_model)
def collect_events(_, *, epoch, train_loss, valid_loss, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid loss={valid_loss:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
candidate_params = gru_attn.train(candidate_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=10.4208, valid loss=0.1434, elapsed=17.8352s
INFO:__main__:Completed epoch=20, train loss=5.8326, valid loss=0.1068, elapsed=60.2208s
INFO:__main__:Completed epoch=40, train loss=5.0064, valid loss=0.0928, elapsed=102.4101s
INFO:__main__:Completed epoch=60, train loss=4.5790, valid loss=0.0868, elapsed=147.8452s
INFO:__main__:Completed epoch=80, train loss=4.3150, valid loss=0.0839, elapsed=191.5985s
INFO:__main__:Completed epoch=100, train loss=4.1582, valid loss=0.0822, elapsed=234.5921s
INFO:__main__:Completed epoch=120, train loss=4.0348, valid loss=0.0814, elapsed=278.7876s
INFO:__main__:Completed epoch=140, train loss=3.9394, valid loss=0.0809, elapsed=321.4549s
INFO:__main__:Completed epoch=160, train loss=3.8632, valid loss=0.0807, elapsed=364.4056s


In [24]:
import xjax

# I generate sequences from a prefix
prefix_str = "the time"
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
generated = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y = gru_attn.generate(rng=sub_rng, prefix=prefix, params=candidate_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    generated.append("".join([idx_to_tok[i] for i in y]))
 


Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1, 1, 28)
Y_pred.shape (1,

In [25]:
for g in generated:
    print(g)

the time machipl soft destless the gre
the time for when i sphins upon the wr
the time then sometion and up and very
the time transed the same lations of t
the time new other vidial face was pla
the time traveller was but i was not t
the time world about me you cannot yet
the time breaded no evide becous struc
the time machine i remember startressi
the time man was good the repread disc
the time caught in a black that had he
the time yet east match and a stared i
the time had from the appear little op
the time traveller against high i had 
the time traveller of a little very wa
the time traveller man showed it was a
the time dargurating little which thei
the time large no brived question and 
the time hands and i was first be brig
the time hapitation from by an inferin


## Calculate Perplexity on the test set