## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset 
The Time Machine is a classic Science Fiction novel. 

In [1]:
import logging

import jax

logging.basicConfig(level=logging.INFO)

rng = jax.random.key(42)


raw_text = open('../datasets/timemachine.txt').read()
words = raw_text.split(" ")
len(words)

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


51713

At around 51000 words, it's not a large dataset, but it is should be sufficient for traning small models with few parameters like a single layer GRU.

### Let's tokenize the text
We will be using a character based tokenization scheme and ignore all punctuation and special characters. This should allow the model to learn character sequences without the additional task of learning sentences boundaries, punctuation, formatting etc. 

In [2]:
import re
from collections import Counter

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text) #[:333]

# Tokenize into characters
char_tokenized_text = list(text)
token_counts = Counter(char_tokenized_text)



In [3]:
### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set([tok for tok,count in token_counts.items()])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}

vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 27


### Now, let's generate the dataset. 
Just like in the Dinosaur Character RNN experiment, the dataset will consist of character sequences. However, this time we will use fixed-length sequences. This helps parallelize and thus speed up training.
Additionally, we will generate overalapping sequences to augment the size of the dataset.
To prevent overfitting, we will use a validation set during training and a test set to produce the final metrics.

In [4]:

SEQUENCE_LENGTH = 64
WINDOW_OVERLAP = 10

import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

In [5]:

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]


train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == SEQUENCE_LENGTH-1 for i in range(X_valid.shape[0])))

print("X shape:", X.shape, "Y shape:", Y.shape)
print("Train Dataaset", X_train.shape, Y_train.shape)
print("Valid Dataset", X_valid.shape, Y_valid.shape)
print("Test Dataset", X_test.shape, Y_test.shape)

X shape: (3550, 63) Y shape: (3550, 63)
Train Dataaset (2840, 63) (2840, 63)
Valid Dataset (355, 63) (355, 63)
Test Dataset (355, 63) (355, 63)


### Now, let's train a single-layer GRU as our baseline

In [6]:
import logging
from blinker import signal
from xjax.signals import train_epoch_started, train_epoch_completed

In [52]:
from xjax.models import gru

logger = logging.getLogger("GRU")

# Set up hyperparameters
gru_hparams = {
    "hidden_size":256
}
gru_train_hparams = {
    "num_epochs":80,
    "batch_size":256,
    "learning_rate":1E-2,
    "max_grad":1,
}

params, baseline_model = gru.gru(rng, vocab_size=vocab_size, **gru_hparams)

@train_epoch_completed.connect_via(baseline_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(baseline_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            **gru_train_hparams)


INFO:GRU:Completed epoch=0, train loss=3.5608, valid perplexity=7.9802, elapsed=25.6233s
INFO:GRU:Completed epoch=20, train loss=1.9772, valid perplexity=3.8990, elapsed=126.3586s
INFO:GRU:Completed epoch=40, train loss=1.7083, valid perplexity=3.3421, elapsed=227.2008s
INFO:GRU:Completed epoch=60, train loss=1.5570, valid perplexity=3.1111, elapsed=328.7526s
INFO:GRU:Completed epoch=79, train loss=1.4594, valid perplexity=2.9869, elapsed=425.0190s


### Let's add a basic dot-product attention layer as the candidate model
We're going to use exactly the same hyperparameters as the baseline model to ensure an apples-to-apples comparison. If our theory is correct, then adding the attention layer should provide a performance boost without having to do any additional tuning

In [25]:
from xjax.models import gru_attn

logger = logging.getLogger("GRU+Attn")

# Set up hyperparameters
gru_attn_hparams = {
    "hidden_size":256
}
gru_attn_train_hparams = {
    "num_epochs":50,
    "batch_size":256,
    "learning_rate":1E-2,
    "max_grad":1
}

params, candidate_model = gru_attn.gru(rng, vocab_size=vocab_size, **gru_attn_hparams)

@train_epoch_completed.connect_via(candidate_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid_perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model with a dot-product attention layer on the data 
candidate_params = gru_attn.train(candidate_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            **gru_attn_train_hparams)



INFO:GRU+Attn:Completed epoch=0, train loss=3.1477, valid_perplexity=7.4411, elapsed=66.6060s
INFO:GRU+Attn:Completed epoch=20, train loss=1.7620, valid_perplexity=3.4132, elapsed=173.1571s
INFO:GRU+Attn:Completed epoch=40, train loss=1.4238, valid_perplexity=2.9117, elapsed=281.0839s
INFO:GRU+Attn:Completed epoch=49, train loss=1.3379, valid_perplexity=2.8647, elapsed=330.9421s


### Let's generate some sample sentences from each of our models

In [49]:
import xjax

# I generate sequences from a prefix
prefix_str = "the time "
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
baseline_results = []
candidate_results = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y_baseline = gru.generate(rng=sub_rng, prefix=prefix, params=baseline_params, 
                 hidden_size=gru_hparams["hidden_size"], vocab_size=vocab_size, max_len=30) 
    baseline_results.append("".join([idx_to_tok[i] for i in y_baseline]))
    y_candidate = gru_attn.generate(rng=sub_rng, prefix=prefix, params=candidate_params, 
                 hidden_size=gru_attn_hparams["hidden_size"], vocab_size=vocab_size, max_len=30) 
    candidate_results.append("".join([idx_to_tok[i] for i in y_candidate]))
 


In [50]:
for b,c in zip(baseline_results, candidate_results):
    print(b,"      ", c)

the time e so face dazy the malaus wand        the time e long so dain there was littl
the time  ons the apposmed a jourta y a        the time  once it anough down sounaly a
the time imal vanage any darkness our i        the time e were began wy mattle balked 
the time ud any mach i rearn in rilavio        the time e to be vigiefte palloss i som
the time kot tertar ve was have eages g        the time  oth even his went and wells g
the time ed mach purle may a vast and a        the time et machine but to the sky but 
the time kless a and the over tient and        the time  leas that i sat who silptence
the time  that a just realize adjusing         the time  too faces three fill diffinin
the time ikn was a blauch of path by a         the time e of the little all footh year
the time naveridg straw purstable even         the time et this not after and took the
the time ed i don their munuca this was        the time ew in my eecimilurulion i pass
the time ed puzy little with do what i     

## Calculate Perplexity on the Test Set

In [51]:
baseline_perplexity = xjax.models.gru.perplexity(baseline_model, baseline_params, vocab_size, X_test, Y_test)
candidate_perplexity = xjax.models.gru_attn.perplexity(candidate_model, candidate_params, vocab_size, X_test, Y_test)
print(f"Baseline Perplexity: {baseline_perplexity:0.4f}\nCandidate Perplexity: {candidate_perplexity:0.4f}")

Baseline Perplexity: 4.1332
Candidate Perplexity: 3.9717
