## A Basic Character RNN Trained to Generate Novel Dinosaur Names

### Let's start by reading the dinosaur names dataset and extracting all the names

In [1]:

text = open("../datasets/dinos.txt", "r").read()
lines = text.split("\n")
lines = [ l.strip().lower() for l in lines ]

In [2]:
print(f"Looks like we have {len(lines)} dinosaur names in the dataset")

Looks like we have 1542 dinosaur names in the dataset


### Let's look at a few dinosaur names

In [3]:
lines[:10]

['aachenosaurus',
 'aardonyx',
 'abdallahsaurus',
 'abelisaurus',
 'abrictosaurus',
 'abrosaurus',
 'abydosaurus',
 'acanthopholis',
 'achelousaurus',
 'acheroraptor']

### How long are the dinosaur names? 

In [4]:
lengths = [ len(name) for name in lines ]
max_len = 0
longest_name = None
min_len = float('inf')
shortest_name = None

for line in lines:
    l = len(line)
    if l > max_len:
        max_len = l
        longest_name = line
        print(longest_name)
    if l < min_len:
        min_len = l 
        shortest_name = line

print(f"The longest name is {longest_name} which is {max_len} characters.")
print(f"The shortest name is {shortest_name} which is {min_len} characters.")


aachenosaurus
abdallahsaurus
acrocanthosaurus
archaeodontosaurus
carcharodontosaurus
micropachycephalosaurus
The longest name is micropachycephalosaurus which is 23 characters.
The shortest name is mei which is 3 characters.


### Since we'll be training a character RNN, let's tokenize the data by character

In [5]:
# First, generate a vocabulary
counts = {}
for line in lines:
    chars = list(line)
    for c in chars:
        if c not in counts:
            counts[c] = 1
        else:
            counts[c] += 1

vocab = sorted(list(counts.keys()))

VOCAB_SIZE = len(vocab) + 1

# Then create token lookup indices
tok_to_idx = {}
idx_to_tok = {}

for i,c in enumerate(vocab):
    tok_to_idx[c] = i
    idx_to_tok[i] = c


# Add a stop character to the token index lookup
tok_to_idx["@"] = VOCAB_SIZE - 1
idx_to_tok[VOCAB_SIZE - 1] = "@"


### Now let's generate the training data. 
Since this is a sequence to sequence model, it has to learn to predict the next token in the sequence. Each training example therefore consists of an input sequence of tokens. The label is just the same sequence shifted to the right by a single token.

In [6]:
import jax.numpy as jnp

#Generate the dataset 
X = []
Y = []
for line in lines:
    tokens = [ tok_to_idx[c] for c in line ]
    tokens += [VOCAB_SIZE]*(max_len - len(tokens))
    X.append(jnp.array([VOCAB_SIZE] + tokens))
    Y.append(jnp.array(tokens + [VOCAB_SIZE]))


### Next, let's train the model. 

This is a simple RNN model. Its key feature is a vector $ h_t $, called the "hidden state", which remembers the context from the previous tokens in the input sequence. Observe that the output $ y_t $ is a function of the previous hidden state $ h_{t-1} $. This is what enables it to learn to predict the next token in a sequence. 

The input $ x_t $ is a one-hot vector that encodes the input token. 


$$ \begin{align} h_t = tanh((W_{hx}x_{t} + W_{hh}h_{t-1}) + b_h) \\
y_t = softmax(W_{yh}h_t + b_y)  \end{align} $$

It can also be interpreted as a two layer neural network. The input to the network can now be seen as a single vector consisting of the current token $ x_t $ and the hidden state $ h_{t-1} $, representing the previous tokens in the sequence already seen by the model. 

$$ \begin{align} h_t = tanh( \begin{bmatrix} W_{hx} W_{hh} \end{bmatrix} \begin{bmatrix} x_t \\ h_{t-1} \end{bmatrix} + b_h) \\ 
y_t = softmax(W_{yh}h_t + b_y)  \end{align} $$

In [7]:
import logging

import jax
import xjax
from xjax.signals import train_epoch_completed

logging.basicConfig(level=logging.INFO)

# Module logger
logger = logging.getLogger("Char RNN")

# Define hyperparameters
HIDDEN_SIZE = 64
VOCAB_SIZE = len(tok_to_idx)
learning_rate = 1E-2
max_grad = 10
epochs = 300

rng = jax.random.key(42)

# I define a character-rnn model
params, model = xjax.models.char_rnn.char_rnn(rng, VOCAB_SIZE, HIDDEN_SIZE)

# I log events
@train_epoch_completed.connect_via(model)
def collect_events(_, *, epoch, loss, elapsed, **__):
    logger.info(f"epoch={epoch}, loss={loss:0.4f}, elapsed={elapsed:0.4f}")

# I train a character RNN model on the data 
trained_params = xjax.models.char_rnn.train(model, rng=rng, params=params, 
                                            X_train=X, Y_train=Y, 
                                            vocab_size=VOCAB_SIZE, 
                                            epochs=epochs, 
                                            learning_rate=learning_rate,
                                            max_grad=max_grad)



INFO:Char RNN:epoch=0, loss=5101.4146, elapsed=1.7701
INFO:Char RNN:epoch=10, loss=2159.6357, elapsed=11.1020
INFO:Char RNN:epoch=20, loss=1812.7123, elapsed=20.4835
INFO:Char RNN:epoch=30, loss=1953.0304, elapsed=29.7877
INFO:Char RNN:epoch=40, loss=1881.3508, elapsed=39.0794
INFO:Char RNN:epoch=50, loss=1878.3793, elapsed=48.3536
INFO:Char RNN:epoch=60, loss=1782.6591, elapsed=57.6883
INFO:Char RNN:epoch=70, loss=1772.7905, elapsed=67.0791
INFO:Char RNN:epoch=80, loss=1757.2007, elapsed=76.3859
INFO:Char RNN:epoch=90, loss=1774.6650, elapsed=86.0161
INFO:Char RNN:epoch=100, loss=1717.6869, elapsed=95.5716
INFO:Char RNN:epoch=110, loss=1778.9965, elapsed=105.1843
INFO:Char RNN:epoch=120, loss=1684.2278, elapsed=115.7771
INFO:Char RNN:epoch=130, loss=1680.6908, elapsed=125.9145
INFO:Char RNN:epoch=140, loss=1701.4093, elapsed=135.4562
INFO:Char RNN:epoch=150, loss=1704.2050, elapsed=147.0271
INFO:Char RNN:epoch=160, loss=1710.7242, elapsed=157.2308
INFO:Char RNN:epoch=170, loss=1703.76

### Finally, let's look at what the model learned
Ideally, the model should have learned to generate plausible-sounding dinosaur names. 

In [8]:
prefix_str = "a"
prefix = [tok_to_idx[c] for c in prefix_str]
generated = []
for i in range(30):
    rng, sub_rng = jax.random.split(rng)
    y = xjax.models.char_rnn.generate(rng=sub_rng, prefix= prefix, params=trained_params, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE) 
    generated.append(y)

In [9]:
for g in generated:
    print("".join([idx_to_tok[i] for i in g[:-1]]))

agraton
agrososaurus
agryinta
aucastosaurus
agrasstoiltesaurus
aucanopus
agrasoctomus
austosus
agroryonskosaurus
aucastoreritochengon
agrapikus
aucanatopus
auchttimti
anotosaurus
ageacingorosaurus
aspatarosaurus
aniloctor
anintrantosaurucan
agertoncttoptingonguandyn
agramosaurus
aucrenosaurusaurus
agrandosaurus
agrangosaurus
agrontochuos
auclachrasaurus
abroplos
agrodan
agronongosaurus
aucactodorodox
adrangynotron


Some of these look plausible. It seems to have figured out that dinosaur names tend to end with an 'rus'