## Comparing a Char-RNN with a GRU for Generating Dinosaur Names

### Let's start by reading the dinosaur names dataset and extracting all the names

In [1]:

text = open("../datasets/dinos.txt", "r").read()
lines = text.split("\n")
lines = [ l.strip().lower() for l in lines ]

In [2]:
print(f"Looks like we have {len(lines)} dinosaur names in the dataset")

Looks like we have 1542 dinosaur names in the dataset


### Let's look at a few dinosaur names

In [3]:
lines[:10]


['aachenosaurus',
 'aardonyx',
 'abdallahsaurus',
 'abelisaurus',
 'abrictosaurus',
 'abrosaurus',
 'abydosaurus',
 'acanthopholis',
 'achelousaurus',
 'acheroraptor']

### How long are the dinosaur names? 

In [4]:
lengths = [ len(name) for name in lines ]
max_len = 0
longest_name = None
min_len = float('inf')
shortest_name = None

for line in lines:
    l = len(line)
    if l > max_len:
        max_len = l
        longest_name = line
        print(longest_name)
    if l < min_len:
        min_len = l 
        shortest_name = line

print(f"The longest name is {longest_name} which is {max_len} characters.")
print(f"The shortest name is {shortest_name} which is {min_len} characters.")


aachenosaurus
abdallahsaurus
acrocanthosaurus
archaeodontosaurus
carcharodontosaurus
micropachycephalosaurus
The longest name is micropachycephalosaurus which is 23 characters.
The shortest name is mei which is 3 characters.


### Since we'll be training a character RNN, let's tokenize the data by character

In [5]:
# First, generate a vocabulary
counts = {}
for line in lines:
    chars = list(line)
    for c in chars:
        if c not in counts:
            counts[c] = 1
        else:
            counts[c] += 1

vocab = sorted(list(counts.keys()))

VOCAB_SIZE = len(vocab) + 1

# Then create token lookup indices
tok_to_idx = {}
idx_to_tok = {}

for i,c in enumerate(vocab):
    tok_to_idx[c] = i
    idx_to_tok[i] = c


# Add a stop character to the token index lookup
tok_to_idx["@"] = VOCAB_SIZE - 1
idx_to_tok[VOCAB_SIZE - 1] = "@"


### Now let's generate the Dataset. 
Since this is a sequence to sequence model, it has to learn to predict the next token in the sequence. Each training example therefore consists of an input sequence of tokens. The label is just the same sequence shifted to the right by a single token.

In [6]:
import jax.numpy as jnp

#Generate the dataset 
X = []
Y = []
for line in lines:
    tokens = [ tok_to_idx[c] for c in line ] 
    tokens += [VOCAB_SIZE]*(max_len - len(tokens))
    X.append(jnp.array([VOCAB_SIZE] + tokens))
    Y.append(jnp.array(tokens + [VOCAB_SIZE]))



### Next, let's train the Char RNN model. 


In [7]:
import logging

import jax
import xjax
from xjax.signals import train_epoch_completed

# Module logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Char RNN")

# Define hyperparameters
HIDDEN_SIZE = 64
VOCAB_SIZE = len(tok_to_idx)
learning_rate = 1E-2
max_grad = 10
epochs = 300

rng = jax.random.key(42)

# I define a character-rnn model
params, model = xjax.models.char_rnn.char_rnn(rng, 27, HIDDEN_SIZE)

# I log events
@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, loss, elapsed, **__):
    logger.info(f"epoch={epoch}, loss={loss:0.4f}, elapsed={elapsed:0.4f}")

# I train a character RNN model on the data 
trained_params = xjax.models.char_rnn.train(model, rng=rng, params=params, 
                                            X_train=X, Y_train=Y, 
                                            vocab_size=VOCAB_SIZE, 
                                            epochs=epochs, 
                                            learning_rate=learning_rate,
                                            max_grad=max_grad)



INFO:Char RNN:epoch=0, loss=5101.4146, elapsed=1.9660


KeyboardInterrupt: 

### Finally, let's look at what the model learned
Ideally, the model should have learned to generate plausible-sounding dinosaur names. 

In [None]:
prefix_str = "bac"
prefix = [tok_to_idx[c] for c in prefix_str]
generated = []
for i in range(25):
    rng, sub_rng = jax.random.split(rng)
    y = xjax.models.char_rnn.generate(rng=sub_rng, prefix= prefix, params=trained_params, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE) 
    generated.append(y)

In [None]:
for g in generated:
    print("".join([idx_to_tok[i] for i in g[:-1]]))

baceoromossus
baceorodor
bacron
bacudochus
baceoron
baceurobrg
bacuoran
baceog
bacurrurus
bacroh
bachoreruchus
bacurnos
baceon
bacuon
bacroh
bacuhe
bacueris
bacyuropops
baceorynn
bacuho
bacurnerus
bacuhe
bachorerolosaurus
bacuinhia
bacudoraurus


Some of these look plausible. It seems to have figured out that dinosaur names tend to end with an 'rus'

### Now, let's train a single-layer GRU as our baseline

In [8]:
import logging

from blinker import signal

import jax

from xjax.models import gru
from xjax.signals import train_epoch_started, train_epoch_completed

rng = jax.random.key(42)

logger = logging.getLogger("GRU")

# Set up hyperparameters
HIDDEN_SIZE=128
EPOCHS=200
BATCH_SIZE=128
LEARNING_RATE = 5*10**(-3)
MAX_GRAD = 1

params, baseline_model = gru.gru(rng, vocab_size=VOCAB_SIZE, hidden_size=HIDDEN_SIZE)

@train_epoch_completed.connect_via(baseline_model)
def collect_events(_, *, epoch, train_loss, valid_perplexity, elapsed, **__):
    logger.info(f"Completed epoch={epoch}, train loss={train_loss:0.4f}, valid perplexity={valid_perplexity:0.4f}, elapsed={elapsed:0.4f}s")


# I train a GRU model on the data 
baseline_params = gru.train(baseline_model, rng=rng, params=params, 
                                            X_train=X, Y_train=Y, 
                                            vocab_size=VOCAB_SIZE, 
                                            batch_size=BATCH_SIZE,
                                            num_epochs=EPOCHS, 
                                            learning_rate=LEARNING_RATE,
                                            max_grad=MAX_GRAD)



TypeError: Only integer scalar arrays can be converted to a scalar index.