In [1]:
import functools
from typing import Any, Tuple
import random

from absl import logging
import flax.linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax

### Digit processing

turn strings like `1 + 2 = 3` into `id`s

In [2]:
NUM_DIGITS = 3


class CharacterTable(object):
    @property
    def pad_id(self):
        return 0
    
    @property
    def eos_id(self):
        return 1
    
    @property
    def vocab_size(self):
        return len(self.chars) + 2
    
    @property
    def max_input_len(self):
        return NUM_DIGITS ** 2 + 3
    
    @property
    def max_output_len(self):
        return NUM_DIGITS ** 2 + 3
    
    def __init__(self, chars):
        self.chars = sorted(set(chars))
        self.char2index = dict((c, i+2) for i, c in enumerate(self.chars))
        self.index2char = dict((i+2, c) for i, c in enumerate(self.chars))
    
    def encode(self, inputs):
        return np.array([self.char2index[c] for c in inputs] + [self.eos_id])
    
    def decode(self, inputs):
        chars = []
        for e in inputs.tolist():
            if e == self.eos_id:
                break
            if e not in self.index2char:
                continue
            
            chars.append(self.index2char[e])
        return ''.join(chars)


TABLE = CharacterTable('0123456789+-*= ')
example_ids = TABLE.encode('10 + 10 = 20')
print(example_ids)
example_text = TABLE.decode(example_ids)
print(example_text)

[ 7  6  2  4  2  7  6  2 16  2  8  6  1]
10 + 10 = 20


### Creating input data

generating input data

In [3]:
def get_examples(num_examples):
    ops = [
        ('+', lambda x, y: x + y), 
        ('-', lambda x, y: x - y),
        ('*', lambda x, y: x * y),
        ('/', lambda x, y: x // y)]
    for _ in range(num_examples):
        max_digit = pow(10, NUM_DIGITS) -1
        op_code = random.randint(0, 1)
        key = tuple(sorted((random.randint(0, max_digit), random.randint(0, max_digit))))
        yield f'{key[0]}{ops[op_code][0]}{key[1]}', f'={ops[op_code][1](key[0], key[1])}'


print(next(get_examples(1)))

def encode_onehot(inputs):
    e = np.eye(TABLE.vocab_size)
    def encode_str(s):
        tokens = TABLE.encode(s)
        unpadded_len = len(tokens)
        tokens = np.pad(tokens, [(0, TABLE.max_input_len - len(tokens))], mode='constant')
        return e[tokens]
        # return jax.nn.one_hot(tokens, TABLE.vocab_size, dtype=jnp.float32)
    
    return np.array([encode_str(inp) for inp in inputs])


def get_batch(batch_size):
    inputs, outputs = zip(*get_examples(batch_size))
    return {
         'query': encode_onehot(inputs),
         'answer': encode_onehot(outputs)
    }


example_batch = get_batch(32)
print(example_batch['query'].shape)
print(example_batch['answer'].shape)
print(TABLE.decode(np.argmax(example_batch['query'][0, :], axis=-1)))
print(TABLE.decode(np.argmax(example_batch['answer'][0, :], axis=-1)))

('15+372', '=387')
(32, 12, 17)
(32, 12, 17)
674-834
=-160


In [4]:
%timeit get_batch(128)

6.86 ms ± 75.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Create Model

I don't really understand this

```python
@functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1,
    out_axes=1,
    split_rngs={'params': False})
```

but it seems like the example is all doing this for sequence data

In [5]:
key = jax.random.PRNGKey(666)


class EncoderLSTM(nn.Module):
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        lstm_state, is_eos = carry
        new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
        
        def select_carried_state(new_state, old_state):
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
        
        carried_lstm_state = tuple(select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
        
        is_eos = jnp.logical_or(is_eos, x[:, TABLE.eos_id])
        return (carried_lstm_state, is_eos), y
    
    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        return nn.LSTMCell.initialize_carry(key, (batch_size,), hidden_size)
    

class Encoder(nn.Module):
    hidden_size: int
    
    @nn.compact
    def __call__(self, inputs):
        batch_size = inputs.shape[0]
        lstm = EncoderLSTM(name='encoder_lstm')
        init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
        init_is_eos = jnp.zeros(batch_size, dtype=np.bool)
        init_carry = (init_lstm_state, init_is_eos)
        (final_state, _), _ = lstm(init_carry, inputs)
        return final_state


class DecoderLSTM(nn.Module):
    teacher_force: bool
    
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        rng, lstm_state, last_prediction = carry
        carry_rng, categorical_rng = jax.random.split(rng, 2)
        if not self.teacher_force:
            x = last_prediction
        lstm_state, y = nn.LSTMCell()(lstm_state, x)
        logits = nn.Dense(features=TABLE.vocab_size)(y)
        predicted_token = jax.random.categorical(categorical_rng, logits)
        prediction = jax.nn.one_hot(predicted_token, TABLE.vocab_size, dtype=jnp.float32)
        return (carry_rng, lstm_state, prediction), (logits, prediction)


class Decoder(nn.Module):
    init_state: Tuple[Any]
    teacher_force: bool
    
    @nn.compact
    def __call__(self, inputs):
        lstm = DecoderLSTM(teacher_force=self.teacher_force)
        init_carry = (self.make_rng('lstm'), self.init_state, inputs[:, 0])
        _, (logits, prediction) = lstm(init_carry, inputs)
        return logits, prediction


class Seq2Seq(nn.Module):
    teacher_force: bool
    hidden_size: int
    
    @nn.compact
    def __call__(self, encoder_inputs, decoder_inputs):
        init_decoder_state = Encoder(hidden_size=self.hidden_size)(encoder_inputs)
        logits, prediction = Decoder(init_state=init_decoder_state, teacher_force=self.teacher_force)(decoder_inputs[:, :-1])
        return logits, prediction


HIDDEN_SIZE = 512
model = Seq2Seq(teacher_force=False, hidden_size=HIDDEN_SIZE)

In [6]:
key, init_key = jax.random.split(key)

def get_initial_params(model, key):
    encoder_shape = jnp.ones((1, TABLE.max_input_len, TABLE.vocab_size), jnp.float32)
    decoder_shape = jnp.ones((1, TABLE.max_output_len, TABLE.vocab_size), jnp.float32)
    return model.init({
        'params': key,
        'lstm': key,
    }, encoder_shape, decoder_shape)['params']


params = get_initial_params(model, init_key)
tx = optax.adam(0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [7]:
test_batch = get_batch(32)

key, lstm_key = jax.random.split(key)

test_logits, test_prediction = model.apply(
    {'params': params},
    test_batch['query'],
    test_batch['answer'],
    rngs={'lstm': lstm_key})

print(test_logits.shape)
print(test_prediction.shape)

(32, 11, 17)
(32, 11, 17)


In [8]:
def get_sequence_lengths(sequence_batch, eos_id=TABLE.eos_id):
    eos_row = sequence_batch[:, :, eos_id]
    eos_index = jnp.argmax(eos_row, axis=-1)
    return jnp.where(eos_row[jnp.arange(eos_row.shape[0]), eos_index],
                     eos_index+1,
                     sequence_batch.shape[1])


test_label = test_batch['answer'][:, 1:]
seq_len = get_sequence_lengths(test_label)
print(seq_len)

[5 4 4 4 5 5 4 5 5 4 4 4 5 5 5 5 4 4 5 4 3 4 4 4 4 5 4 4 4 5 5 4]


### Loss function

In [9]:
def mask_sequences(sequence_batch, lengths):
    return sequence_batch * (lengths[:, np.newaxis] > np.arange(sequence_batch.shape[1])[np.newaxis])


def cross_entropy_loss(logits, labels, lengths):
    xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
    print(xe)
    print(mask_sequences(xe, lengths))
    masked_xe = jnp.mean(mask_sequences(xe, lengths))
    return -masked_xe


loss = cross_entropy_loss(test_logits, test_label, seq_len)
print(loss)

[[-2.8904777 -2.7868254 -2.7626944 -2.6934536 -2.7697418 -2.8016784
  -2.8280935 -2.8075247 -2.7960792 -2.8003495 -2.863701 ]
 [-2.9214404 -2.8051474 -2.923343  -2.862573  -2.8832521 -2.9184513
  -2.9227433 -3.0260184 -2.8678405 -2.8433309 -2.836333 ]
 [-2.8314075 -2.8569617 -2.9167051 -2.6773245 -2.8577676 -2.9421928
  -2.8451893 -2.8190424 -2.831006  -2.8373528 -2.8160846]
 [-2.870552  -2.7175772 -2.8592885 -2.629376  -2.8829362 -2.8722124
  -2.8933296 -2.896182  -2.882987  -2.903448  -2.842338 ]
 [-2.8889878 -2.858055  -2.8215797 -2.8070657 -2.704901  -2.907819
  -2.8423374 -2.8353229 -2.8649101 -2.902028  -2.871624 ]
 [-2.8629386 -2.7979014 -2.8074915 -2.7963216 -2.7594671 -2.800516
  -2.806804  -2.8525014 -2.8304312 -2.798877  -2.7914522]
 [-2.8836682 -2.7426717 -2.7585287 -2.8434894 -2.8093634 -2.891889
  -2.9059656 -2.8476896 -2.8394704 -2.8349142 -2.7896755]
 [-2.8837144 -2.8440342 -2.8400433 -2.8653557 -2.7386038 -2.828299
  -2.86696   -2.8940384 -2.9070866 -2.8749878 -2.96705

In [10]:
def compute_metrics(logits, labels):
    lengths = get_sequence_lengths(labels)
    loss = cross_entropy_loss(logits, labels, lengths)
    
    token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    sequence_accuracy = (jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths)
    accuracy = jnp.mean(sequence_accuracy)
    return {
        'loss': loss,
        'accuracy': accuracy,
    }
    

@jax.jit
def train_step(state, batch, lstm_key):
    labels = batch['answer'][:, 1:]
    
    def loss_fn(params):
        logits, _ = state.apply_fn({'params': params},
                                   batch['query'],
                                   batch['answer'],
                                   rngs={'lstm': lstm_key})
        loss = cross_entropy_loss(logits, labels, get_sequence_lengths(labels))
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)
    return state, metrics


next_state, test_metrics = train_step(state, test_batch, lstm_key)
print(test_metrics)

Traced<ShapedArray(float32[32,11])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[32,11])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[32,11]):JaxprTrace(level=1/1)>
Traced<ShapedArray(float32[32,11])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[32,11])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[32,11]):JaxprTrace(level=1/1)>
Traced<ShapedArray(float32[32,11])>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(float32[32,11])>with<DynamicJaxprTrace(level=0/1)>
{'accuracy': DeviceArray(0., dtype=float32), 'loss': DeviceArray(1.1243855, dtype=float32)}


In [11]:
%timeit train_step(state, test_batch, lstm_key)

6.12 ms ± 142 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
BATCH_SIZE = 128
HIDDEN_SIZE = 512


def train_model(state, key, train_steps):
    for step in range(train_steps):
        key, lstm_key = jax.random.split(key)
        batch = get_batch(BATCH_SIZE)
        state, metrics = train_step(state, batch, lstm_key)
        if step % 300 == 0:
            key, lstm_key = jax.random.split(key)
            batch = get_batch(6)
            print(f"accuracy: {metrics['accuracy']}, loss: {metrics['loss']}")
            # test_decode(state.params, batch, lstm_key)
    return state


key = jax.random.PRNGKey(666)
model = Seq2Seq(teacher_force=False, hidden_size=HIDDEN_SIZE)
params = get_initial_params(model, init_key)
tx = optax.adam(1e-3)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=tx)
done_state = train_model(state, key, 30000)

Traced<ShapedArray(float32[128,11])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[128,11])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[128,11]):JaxprTrace(level=1/1)>
Traced<ShapedArray(float32[128,11])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[128,11])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[128,11]):JaxprTrace(level=1/1)>
Traced<ShapedArray(float32[128,11])>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(float32[128,11])>with<DynamicJaxprTrace(level=0/1)>
accuracy: 0.0, loss: 1.1941652297973633
accuracy: 0.0, loss: 0.5542354583740234
accuracy: 0.0234375, loss: 0.4907089173793793
accuracy: 0.0, loss: 0.46042659878730774
accuracy: 0.015625, loss: 0.4465414583683014
accuracy: 0.0078125, loss: 0.38888493180274963
accuracy: 0.0703125, loss: 0.3176511824131012
accuracy: 0.0859375, loss: 0.2880670428276062
accuracy: 0.1484375, loss: 0.26955246925354004
accuracy: 

In [13]:
@jax.jit
def decode(params, inputs, key):
    init_decoder_input = jax.nn.one_hot(TABLE.encode('=')[0:1], TABLE.vocab_size, dtype=jnp.float32)
    init_decoder_inputs = jnp.tile(init_decoder_input,
                                   (inputs.shape[0], TABLE.max_output_len, 1))
    model = Seq2Seq(teacher_force=False, hidden_size=HIDDEN_SIZE)
    _, prediction = model.apply({'params': params},
                                inputs,
                                init_decoder_inputs,
                                rngs={'lstm': key})
    return prediction


def decode_onehot(batch_inputs):
    decode_inputs = lambda inputs: TABLE.decode(inputs.argmax(axis=-1))
    return np.array(list(map(decode_inputs, batch_inputs)))


def decode_batch(params, batch, key):
    inputs, labels = batch['query'], batch['answer'][:, 1:]
    inferred = decode(params, inputs, key)
    questions = decode_onehot(inputs)
    infers = decode_onehot(inferred)
    answers = decode_onehot(labels)
    
    for q, i, a in zip(questions, infers, answers):
        suffix = '(CORRECT)' if i == a else f'(INCORRECT): correct={a}'
        print('DECODE: %s = %s %s' % (q, i, suffix))


decode_batch(done_state.params, test_batch, lstm_key)

DECODE: 285-692 = -407 (CORRECT)
DECODE: 26-47 = -21 (CORRECT)
DECODE: 159+252 = 411 (CORRECT)
DECODE: 111+437 = 548 (CORRECT)
DECODE: 369-930 = -561 (CORRECT)
DECODE: 483-856 = -373 (CORRECT)
DECODE: 730-768 = -38 (CORRECT)
DECODE: 362-754 = -392 (CORRECT)
DECODE: 877+963 = 1840 (CORRECT)
DECODE: 239+673 = 912 (CORRECT)
DECODE: 52+605 = 657 (CORRECT)
DECODE: 39+775 = 814 (CORRECT)
DECODE: 431-676 = -245 (CORRECT)
DECODE: 403+930 = 1333 (CORRECT)
DECODE: 616-735 = -119 (CORRECT)
DECODE: 678-864 = -186 (CORRECT)
DECODE: 45+248 = 293 (CORRECT)
DECODE: 268+325 = 593 (CORRECT)
DECODE: 536+884 = 1420 (CORRECT)
DECODE: 833-886 = -53 (CORRECT)
DECODE: 101-103 = -2 (CORRECT)
DECODE: 290+677 = 967 (CORRECT)
DECODE: 62+489 = 551 (CORRECT)
DECODE: 90+115 = 205 (CORRECT)
DECODE: 27+292 = 319 (CORRECT)
DECODE: 289-814 = -525 (CORRECT)
DECODE: 165+491 = 656 (CORRECT)
DECODE: 560-620 = -60 (CORRECT)
DECODE: 54+696 = 750 (CORRECT)
DECODE: 340+978 = 1318 (CORRECT)
DECODE: 296-964 = -668 (CORRECT)
DECOD