## Comparing the Performance of Gated Recurrent Units With and Without Attention
...by learning to predict the text of H.G. Wells' The Time Machine

### Dataset

In [1]:
raw_text = open('../datasets/timemachine.txt').read()

### Let's tokenize the dataset


In [2]:
import re
from transformers import AutoTokenizer

# Remove punctuation and convert to lowercase
text = raw_text.lower()
text = re.sub(r'[^a-z]+', ' ', text)

# Tokenize into characters
char_tokenized_text = list(text)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Most common tokens
from collections import Counter
# Flatten the list of tokens
# Count the tokens
token_counts = Counter(char_tokenized_text)
print(f"Total tokens: {len(token_counts)}")
# Most common tokens
most_common_tokens = token_counts.most_common(10)
for token, count in most_common_tokens:
    print(f"{token} {count}")

Total tokens: 27
  35850
e 19667
t 15040
a 12700
i 11254
o 11082
n 10943
s 9242
r 8833
h 8786


In [4]:
MIN_FREQ = 1 ## Minimum token frequency to include in the vocab

### Generate the Vocabulary
tok_to_idx = {}
idx_to_tok =  list(sorted(set(['<unk>'] + [tok for tok,count in token_counts.items() if count >= MIN_FREQ])))
tok_to_idx = {tok: idx for idx, tok in enumerate(idx_to_tok)}


# Vocabulary size
vocab_size = len(idx_to_tok)
print(f"Vocabulary size: {vocab_size}")



Vocabulary size: 28


In [5]:
import jax.numpy as jnp

def sliding_window(seq, window_size, overlap):
    for i in range(0, len(seq) - window_size, window_size - overlap):
        yield [ tok_to_idx[tok] if tok in tok_to_idx else tok_to_idx['<unk>'] for tok in seq[i:i + window_size] ]


## Generate dataset 
def generate_data(text, seq_length, overlap):
    num_tokens = len(char_tokenized_text)
    return jnp.array([ seq for seq in sliding_window(char_tokenized_text, seq_length, overlap)])

### Train-Validate-Test Split

In [40]:
SEQUENCE_LENGTH = 64
WINDOW_OVERLAP = 10 

data = generate_data(text, SEQUENCE_LENGTH, WINDOW_OVERLAP)
X = data[:,:-1]
Y = data[:,1:]

print(X.shape, Y.shape)

train_idxs = list(range(int(0.8*X.shape[0])))
valid_idxs = list(range(int(0.8*X.shape[0]), int(0.9*X.shape[0])))
test_idxs = list(range(int(0.9*X.shape[0]), X.shape[0]))
X_train = X[train_idxs,:]
Y_train = Y[train_idxs,:]
print(X_train.shape, Y_train.shape)
X_valid = X[valid_idxs,:]
Y_valid = Y[valid_idxs,:]
print(X_valid.shape, Y_valid.shape)
X_test = X[test_idxs,:]
Y_test = Y[test_idxs,:]
print(X_test.shape, Y_test.shape)

assert(X_train.shape[0] + X_valid.shape[0] + X_test.shape[0] == X.shape[0])
assert(all(X_valid[i,:].shape[0] == SEQUENCE_LENGTH-1 for i in range(X_valid.shape[0])))


(3550, 63) (3550, 63)
(2840, 63) (2840, 63)
(355, 63) (355, 63)
(355, 63) (355, 63)


### Train a Basic GRU Model 

In [41]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=200
BATCH_SIZE=128
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, baseline_model = gru.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(baseline_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(baseline_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=6.4597, valid perplexity=1.1092, elapsed=28.2568s
INFO:__main__:Completed epoch=20, train loss=2.9397, valid perplexity=1.0963, elapsed=66.1678s
INFO:__main__:Completed epoch=40, train loss=2.5822, valid perplexity=1.0838, elapsed=103.3859s
INFO:__main__:Completed epoch=60, train loss=2.3959, valid perplexity=1.0777, elapsed=141.1152s
INFO:__main__:Completed epoch=80, train loss=2.2776, valid perplexity=1.0739, elapsed=179.3270s
INFO:__main__:Completed epoch=100, train loss=2.1850, valid perplexity=1.0707, elapsed=218.1402s
INFO:__main__:Completed epoch=120, train loss=2.1196, valid perplexity=1.0687, elapsed=256.2913s
INFO:__main__:Completed epoch=140, train loss=2.0657, valid perplexity=1.0672, elapsed=294.6254s
INFO:__main__:Completed epoch=160, train loss=2.0219, valid perplexity=1.0659, elapsed=334.0070s
INFO:__main__:Completed epoch=180, train loss=1.9789, valid perplexity=1.0649, elapsed=372.3579s


### Train a Basic GRU + Basic Dot-Product Self-Attention Model 

In [42]:
import logging

from blinker import signal

import jax
from xjax.models import gru_attn
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HIDDEN_SIZE=128
EPOCHS=200
BATCH_SIZE=128
LEARNING_RATE = 2*10**(-3)
MAX_GRAD = 1

params, candidate_model = gru_attn.gru(rng, vocab_size=vocab_size, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(candidate_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid_perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
candidate_params = gru_attn.train(candidate_model, rng=rng, params=params, 
                                            X_train=X_train, Y_train=Y_train, 
                                            X_valid=X_valid, Y_valid=Y_valid, 
                                            vocab_size=vocab_size, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



INFO:__main__:Completed epoch=0, train loss=5.5176, valid_perplexity=1.1142, elapsed=97.8642s
INFO:__main__:Completed epoch=20, train loss=2.8486, valid_perplexity=1.0929, elapsed=152.2923s
INFO:__main__:Completed epoch=40, train loss=2.4838, valid_perplexity=1.0805, elapsed=206.8317s
INFO:__main__:Completed epoch=60, train loss=2.2775, valid_perplexity=1.0736, elapsed=260.7903s
INFO:__main__:Completed epoch=80, train loss=2.1182, valid_perplexity=1.0687, elapsed=314.6506s
INFO:__main__:Completed epoch=100, train loss=2.0077, valid_perplexity=1.0653, elapsed=369.5693s
INFO:__main__:Completed epoch=120, train loss=1.9093, valid_perplexity=1.0627, elapsed=423.3096s
INFO:__main__:Completed epoch=140, train loss=1.8364, valid_perplexity=1.0609, elapsed=477.3755s
INFO:__main__:Completed epoch=160, train loss=1.7748, valid_perplexity=1.0595, elapsed=531.0542s
INFO:__main__:Completed epoch=180, train loss=1.7265, valid_perplexity=1.0585, elapsed=584.8786s


In [48]:
import xjax

# I generate sequences from a prefix
prefix_str = "the time"
prefix = [ tok_to_idx[i] for i in list(prefix_str)]
baseline_results = []
candidate_results = []
for i in range(20):
    rng, sub_rng = jax.random.split(rng)
    y_baseline = gru.generate(rng=sub_rng, prefix=prefix, params=baseline_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    baseline_results.append("".join([idx_to_tok[i] for i in y_baseline]))
    y_candidate = gru_attn.generate(rng=sub_rng, prefix=prefix, params=candidate_params, 
                 hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, max_len=30) 
    candidate_results.append("".join([idx_to_tok[i] for i in y_candidate]))
 


In [49]:
for b,c in zip(baseline_results, candidate_results):
    print(b,"      ", c)

the time was in it had had sale jight         the time was smart in the morlocking t
the time in the fam i wat the time thi        the time in the far in the psity all i
the time a ffike a mizid preet very di        the time and i singer i had survery di
the time and at with in an the dable s        the time at the with mire was earth he
the time the lices and were raallisuse        the time there you and were raw wiruse
the time the lagate the blich of i fou        the time refoll as this alf had no in 
the time sove the way enound of tole i        the time to deem come in my eople leve
the time were all it logk at as in the        the time were all it loggeesay mind th
the time some his all the dapreess rem        the time some of my everys becouss ope
the time to rer peplly fil the expelt         the time traver there of my coneas of 
the time and ruint discume whike it th        the time at the strain the fige with a
the time that that upon to fechey that        the time that some 

## Calculate Perplexity on the test set

In [50]:
baseline_perplexity = xjax.models.gru.perplexity(baseline_model, baseline_params, vocab_size, X_test, Y_test)
candidate_perplexity = xjax.models.gru_attn.perplexity(candidate_model, candidate_params, vocab_size, X_test, Y_test)
print(f"Baseline Perplexity: {baseline_perplexity:0.4f}\nCandidate Perplexity: {candidate_perplexity:0.4f}")

Baseline Perplexity: 1.0761
Candidate Perplexity: 1.0727
