In [5]:
import functools
from typing import Any, Dict, Tuple
from clu import metric_writers
from flax import linen as nn
from flax.training import train_state, checkpoints
import jax
import jax.numpy as jnp
import optax
import numpy as np

import models
from input_pipeline import CharacterTable as CTable
from input_pipeline import get_sequence_lengths
from input_pipeline import mask_sequences



In [2]:
Array = Any
PRNGKey = Any

flags = {
    "workdir": ".",
    "learning_rate": 0.003,
    "batch_size": 128,
    "hidden_size": 512,
    "num_training_steps": 2000,
    "decode_frequency": 200,
    "max_len_query_digit": 3,
}


In [3]:
def get_model(ctable: CTable, *, teacher_force: bool = False) -> models.Seq2seq:
    return models.Seq2seq(
        teacher_force=teacher_force,
        hidden_size=flags["hidden_size"],
        eos_id=ctable.eos_id,
        vocab_size=ctable.vocab_size,
    )


In [4]:
def get_initial_params(
    model: models.Seq2seq, rng: PRNGKey, ctable: CTable
) -> Dict[str, Any]:
    rng1, rng2 = jax.random.split(rng)
    variables = model.init(
        {"params": rng1, "lstm": rng2},
        jnp.ones(ctable.encoder_input_shape, jnp.float32),
        jnp.ones(ctable.decoder_input_shape, jnp.float32),
    )
    return variables["params"]


In [5]:
def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainState:
    model = get_model(ctable)
    params = get_initial_params(model, rng, ctable)
    tx = optax.adam(flags["learning_rate"])
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state


In [6]:
def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float:
    xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
    masked_xe = jnp.mean(mask_sequences(xe, lengths))
    return -masked_xe

In [7]:
def compute_metrics(logits: Array, labels: Array, eos_id: int) -> Dict[str, float]:
    lengths = get_sequence_lengths(labels, eos_id)
    loss = cross_entropy_loss(logits, labels, lengths)
    token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    sequence_accuracy = (
        jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths
    )
    accuracy = jnp.mean(sequence_accuracy)
    metrics = {"loss": loss, "accuracy": accuracy}

    return metrics


In [8]:
@jax.jit
def train_step(
    state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int
) -> Tuple[train_state.TrainState, Dict[str, float]]:
    labels = batch["answer"][:, 1:]
    lstm_key = jax.random.fold_in(lstm_rng, state.step)

    def loss_fn(params):
        logits, _ = state.apply_fn(
            {"params": params}, batch["query"], batch["answer"], rngs={"lstm": lstm_key}
        )
        loss = cross_entropy_loss(logits, labels, get_sequence_lengths(labels, eos_id))
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels, eos_id)

    return state, metrics


In [9]:
def log_decode(question: str, inferred: str, golden: str) -> None:
    suffix = (
        "(CORRECT)" if inferred == golden else (f"(INCORRECT) " f"correct={golden}")
    )
    print(f"DECODE: {question} = {inferred} {suffix}")


In [10]:
@functools.partial(jax.jit, static_argnums=3)
def decode(
    params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable
) -> Array:
    init_decoder_input = ctable.one_hot(ctable.encode("=")[0:1])
    init_decoder_inputs = jnp.tile(
        init_decoder_input, (inputs.shape[0], ctable.max_output_len, 1)
    )
    model = get_model(ctable, teacher_force=False)
    _, predictions = model.apply(
        {"params": params}, inputs, init_decoder_inputs, rngs={"lstm": decode_rng}
    )
    return predictions


In [11]:
def decode_batch(state: train_state.TrainState, batch: Dict[str, Array], decode_rng: PRNGKey, ctable: CTable) -> None: 
    inputs, outputs = batch['query'], batch['answer'][:, 1:]
    decode_rng = jax.random.fold_in(decode_rng, state.step)
    inferred = decode(state.params, inputs, decode_rng, ctable)
    questions = ctable.decode_onehot(inputs)
    infers = ctable.decode_onehot(inferred)
    goldens = ctable.decode_onehot(outputs)

    for question, inferred, golden in zip(questions, infers, goldens):
        log_decode(question, inferred, golden)

In [12]:
def train_and_evaluate(workdir: str) -> train_state.TrainState:
    ctable = CTable("0123456789+= ", flags["max_len_query_digit"])
    rng = jax.random.PRNGKey(0)
    state = get_train_state(rng, ctable)

    writer = metric_writers.create_default_writer(workdir)
    for step in range(flags["num_training_steps"]):
        rng, n_rng = jax.random.split(rng)
        batch = ctable.get_batch(flags["batch_size"], n_rng)
        state, metrics = train_step(state, batch, rng, ctable.eos_id)
        if step and step % flags["decode_frequency"] == 0:
            writer.write_scalars(step, metrics)
            rng, n_rng = jax.random.split(rng)
            batch = ctable.get_batch(5, n_rng)
            decode_batch(state, batch, n_rng, ctable)

    return state


In [13]:
state = train_and_evaluate('./logs')

DECODE: 11+812 = 834 (INCORRECT) correct=823
DECODE: 89+835 = 939 (INCORRECT) correct=924
DECODE: 49+804 = 835 (INCORRECT) correct=853
DECODE: 54+205 = 236 (INCORRECT) correct=259
DECODE: 6+73 = 71 (INCORRECT) correct=79
DECODE: 10+210 = 212 (INCORRECT) correct=220
DECODE: 54+634 = 681 (INCORRECT) correct=688
DECODE: 39+658 = 697 (CORRECT)
DECODE: 94+806 = 905 (INCORRECT) correct=900
DECODE: 57+862 = 925 (INCORRECT) correct=919
DECODE: 27+65 = 91 (INCORRECT) correct=92
DECODE: 89+169 = 264 (INCORRECT) correct=258
DECODE: 67+93 = 151 (INCORRECT) correct=160
DECODE: 73+659 = 735 (INCORRECT) correct=732
DECODE: 53+630 = 679 (INCORRECT) correct=683
DECODE: 73+442 = 514 (INCORRECT) correct=515
DECODE: 25+563 = 581 (INCORRECT) correct=588
DECODE: 54+440 = 490 (INCORRECT) correct=494
DECODE: 69+852 = 928 (INCORRECT) correct=921
DECODE: 68+244 = 312 (CORRECT)
DECODE: 24+322 = 345 (INCORRECT) correct=346
DECODE: 48+966 = 1013 (INCORRECT) correct=1014
DECODE: 61+711 = 773 (INCORRECT) correct=772

In [16]:
CKPT_DIR = 'ckpts'
checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=0)

'ckpts/checkpoint_0'

In [3]:
from input_pipeline import AdditionTaskCT
import jax

In [2]:
ctable = AdditionTaskCT("0123456789+= ", 3)

In [7]:
ctable.vocab_size

15

In [3]:
import string
string.ascii_lowercase

'abcdefghijklmnopqrstuvwxyz'

In [6]:
rnd = jax.random.PRNGKey(0)
np.array([ch for ch in string.ascii_lowercase])[jax.random.randint(rnd, shape=(1, 3), minval=0, maxval=len(string.ascii_lowercase))]

array([['o', 'z', 't']], dtype='<U1')

In [14]:
string.ascii_lowercase

TypeError: string indices must be integers

In [10]:
a = np.array([ch for ch in string.ascii_lowercase])[jax.random.randint(
                    rnd,
                    shape=(1, 3),
                    minval=0,
                    maxval=10,
                )]
b = np.array([ch for ch in string.ascii_lowercase])[[9, 0, 15]]

In [11]:
type(a)

numpy.ndarray

In [13]:
a[0]

array(['i', 'b', 'h'], dtype='<U1')

In [15]:
''.join(a[0])

'ibh'

In [1]:
from input_pipeline import WordReverseTaskCT
import jax

task = WordReverseTaskCT(3)
rnd = jax.random.PRNGKey(0)
g = task.get_batch(batch_size=1, rnd_key=rnd)

for i in g:
    print(i)

ValueError: Sequence is too long (5 > 4): '=ucq'

In [7]:
task.get_batch(10, rnd)

['q' 'c' 'u']
['z' 'u' 'g']
['t' 'f' 's']
['s' 'q' 'v']
['u' 'e' 'd']
['k' 'x' 'z']
['j' 'p' 'r']
['w' 't' 's']
['i' 't' 'f']
['c' 'u' 'q']
['j' 'n' 'g']
['e' 'a' 'w']
['z' 's' 'b']
['n' 'm' 'r']
['f' 'v' 'z']
['g' 'w' 'r']
['i' 'g' 'a']
['v' 'j' 's']
['y' 'n' 'z']
['z' 't' 'x']


KeyError: '['