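A rough JAX / Flax NNX port of the PyTorch RNN tutorial at https://jaketae.github.io/study/pytorch-rnn/: a character-level RNN that predicts a name's language of origin.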

In [1]:
import random
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from data import fetch_names, name_to_array

In [2]:
# One seed for every source of randomness in this notebook: Flax parameter
# initialisation (through nnx.Rngs) and Python's `random`, which shuffles the
# training set every epoch. Without seeding `random` the shuffle order, and
# therefore the whole training curve, changes on every Restart & Run All.
SEED = 0
random.seed(SEED)
rngs = nnx.Rngs(SEED)

In [3]:
# Load the train/test splits and the vocabulary maps. The two jnp.array arguments
# are presumably converters applied to the inputs and labels; confirm in data.fetch_names.
train_set, test_set, char_to_idx, lang_to_label = fetch_names(jnp.array, jnp.array)

# Reverse lookup so a predicted integer label can be turned back into a language name.
label_to_lang = dict(zip(lang_to_label.values(), lang_to_label.keys()))

# Model dimensions: one input slot per known character, one output per language.
num_letters, num_langs = len(char_to_idx), len(lang_to_label)

In [4]:
class RNN(nnx.Module):
    """Recurrent classifier that consumes its input one step at a time.

    At every step the current input and the previous hidden state are
    concatenated and fed to two separate linear layers:

    * ``linear1`` followed by ``tanh`` produces the next hidden state;
    * ``linear2`` followed by ``log_softmax`` produces log-probabilities over
      the ``output_size`` classes.

    The output is computed from the *previous* hidden state (the same
    concatenation), not from the freshly updated one.
    """

    def __init__(self, *, input_size, hidden_size, output_size, rngs):
        """Build the two linear layers.

        Args:
            input_size: width of each per-step input vector.
            hidden_size: width of the recurrent hidden state.
            output_size: number of output classes.
            rngs: ``nnx.Rngs`` used to initialise the layer parameters.
        """
        self.hidden_size = hidden_size
        self.linear1 = nnx.Linear(input_size + hidden_size, hidden_size, rngs=rngs)
        self.linear2 = nnx.Linear(input_size + hidden_size, output_size, rngs=rngs)

    def __call__(self, x, h):
        """Advance the recurrence by one step.

        Args:
            x: current input, shape ``(batch, input_size)``.
            h: previous hidden state, shape ``(batch, hidden_size)``.

        Returns:
            ``(output, hidden)``: log-probabilities of shape
            ``(batch, output_size)`` and the next hidden state of shape
            ``(batch, hidden_size)``.
        """
        x = jnp.concat([x, h], axis=1)
        hidden = nnx.tanh(self.linear1(x))
        output = nnx.log_softmax(self.linear2(x))
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: leading dimension of the state. Defaults to 1, which
                matches the one-name-at-a-time training loop.

        Returns:
            ``jnp.zeros((batch_size, hidden_size))``.
        """
        return jnp.zeros((batch_size, self.hidden_size))

In [5]:
learning_rate = 0.0005
# Name carried over from the tutorial this notebook copies. For AdamW this value
# is used as b1, the exponential decay rate of the first-moment estimate.
momentum = 0.9
hidden_size = 256

model = RNN(input_size=num_letters, hidden_size=hidden_size, output_size=num_langs, rngs=rngs)
# b1 is passed by keyword: it is the second positional parameter of optax.adamw,
# so passing `momentum` positionally silently set b1 while reading like SGD momentum.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, b1=momentum))
# Running accuracy and mean loss; the "loss" argument names the keyword that
# metrics.update(...) must supply to the Average metric.
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [6]:
def loss_fn(model, name, label, hidden_state):
    """Feed ``name`` through ``model`` one character at a time and score the final output.

    Args:
        model: callable ``(char, hidden) -> (output, hidden)``, e.g. an ``RNN``.
        name: non-empty sequence whose elements are the per-step model inputs.
        label: integer class label(s) for the name; presumably shape ``(1,)``
            to match the model's batch dimension. Confirm against
            ``data.fetch_names``.
        hidden_state: initial hidden state passed to the first step.

    Returns:
        ``(loss, logits)``: the mean cross-entropy and the model output after
        the last character.

    Raises:
        ValueError: if ``name`` is empty. Without this check ``logits`` would
            never be assigned and the failure would surface as an
            ``UnboundLocalError``.
    """
    if len(name) == 0:
        raise ValueError("loss_fn requires a non-empty name")
    for char in name:
        logits, hidden_state = model(char, hidden_state)
    # The RNN already emits log_softmax outputs. softmax_cross_entropy applies
    # log_softmax again, which is idempotent, so this is still the negative
    # log-likelihood of `label`.
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=label).mean()
    return loss, logits


@nnx.jit
def train_step(model, optimizer, metrics, name, label, hidden_state):
    """Run one jitted optimisation step on a single ``(name, label)`` example.

    Computes the loss and its gradients with respect to ``model``, records the
    loss and model output in ``metrics``, then applies the gradients through
    ``optimizer``. ``model``, ``optimizer`` and ``metrics`` are updated in place.
    """
    (step_loss, step_logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(
        model, name, label, hidden_state
    )
    metrics.update(loss=step_loss, logits=step_logits, labels=label)
    optimizer.update(grads)


@nnx.jit
def eval_step(model, metrics, name, label, hidden_state):
    """Score one ``(name, label)`` example without updating the model.

    Only ``metrics`` is modified: it accumulates the loss and the model output.
    """
    example_loss, example_logits = loss_fn(model, name, label, hidden_state)
    metrics.update(labels=label, logits=example_logits, loss=example_loss)

In [7]:
num_epochs = 2
# Evaluate and log five times per epoch. max() keeps the interval positive, so a
# training set with fewer than five names cannot make `step % eval_every` raise
# ZeroDivisionError.
eval_every = max(1, len(train_set) // 5)
metrics_history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}


def _record_metrics(prefix: str) -> None:
    """Append the current aggregates of `metrics` to `metrics_history` under `prefix`, then reset them."""
    for metric, value in metrics.compute().items():
        metrics_history[f"{prefix}_{metric}"].append(value)
    metrics.reset()


# NOTE: the few training examples seen after the last evaluation of an epoch stay
# in `metrics` and are folded into the next epoch's first "train" report.
for epoch in range(1, num_epochs + 1):
    random.shuffle(train_set)  # in place; assumes train_set is a mutable sequence
    for step, (name, label) in enumerate(train_set):
        train_step(model, optimizer, metrics, name, label, model.init_hidden())
        if step > 0 and step % eval_every == 0:
            _record_metrics("train")
            # Separate names so the test loop does not shadow the training example.
            for test_name, test_label in test_set:
                eval_step(model, metrics, test_name, test_label, model.init_hidden())
            _record_metrics("test")
            # `step` restarts every epoch, so the epoch is logged too to keep lines unambiguous.
            for split in ("train", "test"):
                print(
                    f"[{split}] epoch: {epoch}, step: {step}, "
                    f"loss: {metrics_history[f'{split}_loss'][-1]:.4f}, "
                    f"accuracy: {metrics_history[f'{split}_accuracy'][-1] * 100:.2f}"
                )

[train] step: 3612, loss: 1.5445, accuracy: 56.02
[test] step: 3612, loss: 1.3976, accuracy: 60.39
[train] step: 7224, loss: 1.2843, accuracy: 62.40
[test] step: 7224, loss: 1.2691, accuracy: 63.98
[train] step: 10836, loss: 1.1957, accuracy: 64.53
[test] step: 10836, loss: 1.1865, accuracy: 65.92
[train] step: 14448, loss: 1.0944, accuracy: 67.72
[test] step: 14448, loss: 1.1349, accuracy: 64.87
[train] step: 18060, loss: 1.0181, accuracy: 69.91
[test] step: 18060, loss: 1.0439, accuracy: 68.61
[train] step: 3612, loss: 0.9944, accuracy: 70.37
[test] step: 3612, loss: 1.0900, accuracy: 67.66
[train] step: 7224, loss: 0.9631, accuracy: 70.63
[test] step: 7224, loss: 1.0529, accuracy: 68.61
[train] step: 10836, loss: 0.9498, accuracy: 71.15
[test] step: 10836, loss: 1.0109, accuracy: 68.26
[train] step: 14448, loss: 0.9251, accuracy: 71.40
[test] step: 14448, loss: 0.9313, accuracy: 72.20
[train] step: 18060, loss: 0.9249, accuracy: 71.71
[test] step: 18060, loss: 0.9578, accuracy: 72.7

In [8]:
def predict(name: str) -> str:
    """Return the language the trained global ``model`` assigns to ``name``.

    The name is encoded with ``name_to_array`` and ``char_to_idx`` and fed to the
    model one step at a time from a fresh hidden state. The highest-scoring
    class is then mapped back through ``label_to_lang``.

    Raises:
        ValueError: if the encoded name has no steps. Without this check
            ``logits`` would never be assigned and the failure would surface as
            an ``UnboundLocalError``.
    """
    tensor_name = jnp.array(name_to_array(name, char_to_idx))
    if len(tensor_name) == 0:
        raise ValueError(f"cannot classify empty name {name!r}")
    hidden_state = model.init_hidden()
    for char in tensor_name:
        logits, hidden_state = model(char, hidden_state)
    # argmax over the flattened output (batch size 1 here). int() gives a plain
    # Python int key for the label -> language lookup.
    pred = int(np.asarray(logits).argmax())
    return label_to_lang[pred]

In [9]:
# Spot-check the trained classifier on a handful of names.
demo_names = ("Jake", "Qin", "Fernando", "Demirkan")
for name in demo_names:
    guess = predict(name)
    print(f"{name}: {guess}")

Jake: Russian
Qin: Chinese
Fernando: Italian
Demirkan: Russian
