## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset 
The Time Machine is a classic Science Fiction novel. 

In [1]:
import logging

logging.basicConfig(level=logging.INFO)

raw_text = open('../datasets/timemachine.txt').read()
words = raw_text.split(" ")
len(words)

51713

At around 51000 words, it's not a large dataset, but it is should be sufficient for traning small models with few parameters like a single layer GRU.

### Let's tokenize the text
We will be using a character based tokenization scheme and ignore all punctuation and special characters. This should allow the model to learn character sequences without the additional task of learning sentences boundaries, punctuation, formatting etc. 

In [2]:
import re
from collections import Counter

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text)[:3330] #[:333]

# Tokenize into characters
char_tokenized_text = list(text)
token_counts = Counter(char_tokenized_text)



In [3]:
text

'project gutenberg s the time machine by h g herbert george wells this ebook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever you may copy it give it away or re use it under the terms of the project gutenberg license included with this ebook or online at www gutenberg net title the time machine author h g herbert george wells release date october ebook last updated october language english start of this project gutenberg ebook the time machine the time machine by h g wells i the time traveller for so it will be convenient to speak of him was expounding a recondite matter to us his grey eyes shone and twinkled and his usually pale face was flushed and animated the fire burned brightly and the soft radiance of the incandescent lights in the lilies of silver caught the bubbles that flashed and passed in our glasses our chairs being his patents embraced and caressed us rather than submitted to be sat upon and there was that luxurious after dinner atmo

In [4]:
### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set([tok for tok,count in token_counts.items()])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}

vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 26


### Now, let's generate the dataset. 
Just like in the Dinosaur Character RNN experiment, the dataset will consist of character sequences. However, this time we will use fixed-length sequences. This helps parallelize and thus speed up training.
Additionally, we will generate overalapping sequences to augment the size of the dataset.
To prevent overfitting, we will use a validation set during training and a test set to produce the final metrics.

In [5]:

SEQUENCE_LENGTH = 111
WINDOW_OVERLAP = 0

import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

In [6]:

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]


train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == SEQUENCE_LENGTH-1 for i in range(X_valid.shape[0])))

print("X shape:", X.shape, "Y shape:", Y.shape)
print("Train Dataaset", X_train.shape, Y_train.shape)
print("Valid Dataset", X_valid.shape, Y_valid.shape)
print("Test Dataset", X_test.shape, Y_test.shape)

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


X shape: (29, 110) Y shape: (29, 110)
Train Dataaset (23, 110) (23, 110)
Valid Dataset (3, 110) (3, 110)
Test Dataset (3, 110) (3, 110)


### Now, let's train a single-layer GRU as our baseline

In [7]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logger = logging.getLogger("GRU")

# Set up hyperparameters
HIDDEN_SIZE=128
EPOCHS=10000
BATCH_SIZE=8
LEARNING_RATE = 1E-2
MAX_GRAD = 1

params, baseline_model = gru.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(baseline_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(baseline_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



2024-09-13 08:26:54.282655: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit__unnamed_wrapped_function_] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-09-13 08:28:35.125839: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m40.849545s

********************************
[Compiling module jit__unnamed_wrapped_function_] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
INFO:GRU:Completed epoch=0, train loss=1.2798, valid perplexity=1.0918, elapsed=240.0473s
INFO:GRU:Completed epoch=20, train loss=1.2061, valid perplexity=1.0845, elapsed=240.2089s
INFO:GRU:Completed epoch=40, train loss=1.0883, valid perplexity=1.0706, elapsed=240.3667s
INFO:GRU:Completed epoch=60, train loss=1.0055, valid per

### Let's add a basic dot-product attention layer as the candidate model
We're going to use exactly the same hyperparameters as the baseline model to ensure an apples-to-apples comparison. If our theory is correct, then adding the attention layer should provide a performance boost without having to do any additional tuning

In [10]:
from xjax.models import gru_attn

logger = logging.getLogger("GRU+Attn")

# Set up hyperparameters
HIDDEN_SIZE=8
EPOCHS=1 #200
BATCH_SIZE=1024
LEARNING_RATE = 5*10**(-3)
MAX_GRAD = 1

params, candidate_model = gru_attn.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(candidate_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid_perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model with a dot-product attention layer on the data 
candidate_params = gru_attn.train(candidate_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:GRU+Attn:Completed epoch=0, train loss=0.0000, valid_perplexity=1.6250, elapsed=5.2727s


### Let's generate some sample sentences from each of our models

In [12]:
import xjax

# I generate sequences from a prefix
prefix_str = "project gutenberg"
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
baseline_results = []
candidate_results = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y_baseline = gru.generate(rng=sub_rng, prefix=prefix, params=baseline_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    baseline_results.append("".join([idx_to_tok[i] for i in y_baseline]))
    y_candidate = gru_attn.generate(rng=sub_rng, prefix=prefix, params=candidate_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    candidate_results.append("".join([idx_to_tok[i] for i in y_candidate]))
 


In [13]:
for b,c in zip(baseline_results, candidate_results):
    print(b,"      ", c)

project gutenbergxjyajy jajej jajaj jajej jajaj        project gutenbergkgmhyyhyz iymgmjzgfr xfv vpygx
project gutenbergxjej jajaj jajajyaj jaj jajej         project gutenbergshwfecpcvyzewohgfekkxvcimcxpyc
project gutenbergxjaj jajy jajy aj jajajyaj jaj        project gutenbergfnandyfotgfrpypffpchfwulfuc wo
project gutenbergxjy jaj jajej jaj jejajy aj ja        project gutenbergsfuynpdfn  epfyez dewwyv afuhf
project gutenbergxjyajajajajajajyaj jajej jajya        project gutenbergfufmygwjzgeaupgiaedfnveunesgkk
project gutenbergxjyjajyajy jajaj jajajajej jaj        project gutenbergypmfaayavkcwjapdclrffymgmfw tr
project gutenbergxj jajyajajy aj jajajej jajajy        project gutenbergfp zpzyuwsurffgihmysnpw taidet
project gutenbergxjy jajajy jajajajy jajajyaj j        project gutenbergyywlfutbfkftdwgwkpsyfweloxaxwt
project gutenbergxjejaj jajajej jajyajej jajej         project gutenbergjgxxvy yjgpyvrmycsgfwelxmdzrtn
project gutenbergxjej jajajajy aj jajy aj jajya        project gutenbergy

## Calculate Perplexity on the Test Set

In [None]:
baseline_perplexity = xjax.models.gru.perplexity(baseline_model, baseline_params, vocab_size, X_test, Y_test)
candidate_perplexity = xjax.models.gru_attn.perplexity(candidate_model, candidate_params, vocab_size, X_test, Y_test)
print(f"Baseline Perplexity: {baseline_perplexity:0.4f}\nCandidate Perplexity: {candidate_perplexity:0.4f}")

Baseline Perplexity: 1.0000
Candidate Perplexity: 1.6516
