## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset

In [1]:
raw_text = open('../datasets/timemachine.txt').read()

### Let's tokenize the dataset


In [2]:
import re
from transformers import AutoTokenizer

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text)

# Tokenize into characters
char_tokenized_text = list(text)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Most common tokens
from collections import Counter
# Flatten the list of tokens
# Count the tokens
token_counts = Counter(char_tokenized_text)
print(f"Total tokens: {len(token_counts)}")
# Most common tokens
most_common_tokens = token_counts.most_common(10)
for token, count in most_common_tokens:
    print(f"{token} {count}")

Total tokens: 27
  35850
e 19667
t 15040
a 12700
i 11254
o 11082
n 10943
s 9242
r 8833
h 8786


In [4]:
MIN_FREQ = 1 ## Minimum token frequency to include in the vocab

### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set(['<unk>'] + [tok for tok,count in token_counts.items() if count >= MIN_FREQ])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}


# Vocabulary size
vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")



Vocabulary size: 28


In [5]:
import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

### Train-Validate-Test Split

In [6]:
SEQUENCE_LENGTH = 32
WINDOW_OVERLAP = 10 

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]

print(X.shape, Y.shape)

train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
print(X_train.shape, Y_train.shape)
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
print(X_valid.shape, Y_valid.shape)
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]
print(X_test.shape, Y_test.shape)

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == 31 for i in range(X_valid.shape[0])))


(8714, 31) (8714, 31)
(6971, 31) (6971, 31)
(871, 31) (871, 31)
(872, 31) (872, 31)


### Train a Basic GRU Model 

In [12]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=20 # 200
BATCH_SIZE=128
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, model = gru.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, train_loss, valid_loss, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid loss={valid_loss:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=11.2605, valid loss=0.1415, elapsed=8.4326s


### Train a Basic GRU + Basic Dot-Product Self-Attention Model 

In [21]:
import logging

from blinker import signal

import jax
from xjax.models import gru_attn
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=60
BATCH_SIZE=32
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, candidate_model = gru_attn.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(candidate_model)
def collect_events(_, *, epoch, train_loss, valid_loss, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid loss={valid_loss:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
candidate_params = gru_attn.train(candidate_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=33.2713, valid loss=0.1371, elapsed=19.9271s


In [24]:
import xjax

# I generate sequences from a prefix
prefix_str = "the time"
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
generated = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y = gru_attn.generate(rng=sub_rng, prefix=prefix, params=candidate_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    generated.append("".join([idx_to_tok[i] for i in y]))
 


H.shape (1, 9, 128)
H.shape (1, 10, 128)
H.shape (1, 11, 128)
H.shape (1, 12, 128)
H.shape (1, 13, 128)
H.shape (1, 14, 128)
H.shape (1, 15, 128)
H.shape (1, 16, 128)
H.shape (1, 17, 128)
H.shape (1, 18, 128)
H.shape (1, 19, 128)
H.shape (1, 20, 128)
H.shape (1, 21, 128)
H.shape (1, 22, 128)
H.shape (1, 23, 128)
H.shape (1, 24, 128)
H.shape (1, 25, 128)
H.shape (1, 26, 128)
H.shape (1, 27, 128)
H.shape (1, 28, 128)
H.shape (1, 29, 128)
H.shape (1, 30, 128)
H.shape (1, 31, 128)
H.shape (1, 32, 128)
H.shape (1, 33, 128)
H.shape (1, 34, 128)
H.shape (1, 35, 128)
H.shape (1, 36, 128)
H.shape (1, 37, 128)
H.shape (1, 38, 128)
H.shape (1, 9, 128)
H.shape (1, 10, 128)
H.shape (1, 11, 128)
H.shape (1, 12, 128)
H.shape (1, 13, 128)
H.shape (1, 14, 128)
H.shape (1, 15, 128)
H.shape (1, 16, 128)
H.shape (1, 17, 128)
H.shape (1, 18, 128)
H.shape (1, 19, 128)
H.shape (1, 20, 128)
H.shape (1, 21, 128)
H.shape (1, 22, 128)
H.shape (1, 23, 128)
H.shape (1, 24, 128)
H.shape (1, 25, 128)
H.shape (1, 26,

In [25]:
for g in generated:
    print(g)

the timeflond presofounnd ceresthrofow
the timehe bewhesthesphaiman med bvewr
the timeathxorlyouanthallesufwaralveme
the timelnclliver mnlleffila indeengho
the timefnelinthe tolle semetily wepll
the timeusf anoumoffonsr whevevengofof
the timethorome whellrengefllld nofoul
the timee rewhexthipevinghangheswhat i
the timeheneakenthorellenghistheereles
the timely mewn ngegenger mererined th
the timebchiteremaland e ormhanghorelw
the timeny ptetendmald fons angeateabe
the timer te fimfftereraffllllene wton
the timeceneghementheveshinhigefoleeth
the timevellvsevelingof len waanenghal
the timefas lly sengngrenghereverafofo
the timeke nghewhendsesmofandlllllened
the timeylfrfognspled wind ofinf warul
the timememessend blgrengrereshillerso
the timehenfily fesmfrighend mlfofwhil


## Calculate Perplexity on the test set