In [1]:
from flax.training import train_state, checkpoints
from models import Seq2seq
import models
from input_pipeline import CharacterTable as CTable
from input_pipeline import get_sequence_lengths
from input_pipeline import mask_sequences
from absl import flags
from typing import Any, Dict, Tuple
import jax
import train
import functools
import jax.numpy as jnp
import jax.profiler



In [2]:
Array = Any
PRNGKey = Any

In [3]:
CKPT_DIR = "ckpts"
restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=None)

In [4]:
flags = {
    "workdir": ".",
    "learning_rate": 0.003,
    "batch_size": 128,
    "hidden_size": 512,
    "num_training_steps": 10000,
    "max_len_query_digit": 3,
}


In [5]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


In [6]:
flags = dotdict(flags)


In [7]:
ctable = CTable("0123456789+= ", flags.max_len_query_digit)


In [9]:
@functools.partial(jax.jit, static_argnums=3)
def decode(
    params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable
) -> Array:
    """Decodes inputs."""
    init_decoder_input = ctable.one_hot(ctable.encode("=")[0:1])
    init_decoder_inputs = jnp.tile(
        init_decoder_input, (inputs.shape[0], ctable.max_output_len, 1)
    )
    model = train.get_model(ctable, teacher_force=False)
    _, predictions = model.apply(
        {"params": params}, inputs, init_decoder_inputs, rngs={"lstm": decode_rng}
    )
    return predictions

In [13]:
from models import Seq2seq

In [15]:
model = models.Seq2seq(
        teacher_force=False,
        hidden_size=512,
        eos_id=ctable.eos_id,
        vocab_size=ctable.vocab_size,
    )

In [21]:
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)
variables = model.init(
    {"params": rng1, "lstm": rng2},
    jnp.ones(ctable.encoder_input_shape, jnp.float32),
    jnp.ones(ctable.decoder_input_shape, jnp.float32),
)

In [24]:
f = functools.partial(model.apply, variables, rngs={"lstm", rng})


TypeError: unhashable type: 'DeviceArray'

In [26]:
z = jax.xla_computation(f)(
    jnp.ones(ctable.encoder_input_shape, jnp.float32),
    jnp.ones(ctable.decoder_input_shape, jnp.float32),
    rngs={"lstm": rng}
)


In [27]:
with open("t2.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())

In [30]:
from jax.tree_util import tree_structure

In [31]:
tree_structure(variables)

PyTreeDef(CustomNode(FrozenDict[()], [{'params': {'Decoder_0': {'DecoderLSTM_0': {'Dense_0': {'bias': *, 'kernel': *}, 'LSTMCell_0': {'hf': {'bias': *, 'kernel': *}, 'hg': {'bias': *, 'kernel': *}, 'hi': {'bias': *, 'kernel': *}, 'ho': {'bias': *, 'kernel': *}, 'if': {'kernel': *}, 'ig': {'kernel': *}, 'ii': {'kernel': *}, 'io': {'kernel': *}}}}, 'Encoder_0': {'encoder_lstm': {'LSTMCell_0': {'hf': {'bias': *, 'kernel': *}, 'hg': {'bias': *, 'kernel': *}, 'hi': {'bias': *, 'kernel': *}, 'ho': {'bias': *, 'kernel': *}, 'if': {'kernel': *}, 'ig': {'kernel': *}, 'ii': {'kernel': *}, 'io': {'kernel': *}}}}}}]))

In [40]:
jnp.DeviceArray

jaxlib.xla_extension.DeviceArrayBase

In [38]:
variables['params']

FrozenDict({
    Encoder_0: {
        encoder_lstm: {
            LSTMCell_0: {
                hf: {
                    bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0.

In [38]:
import jax.random as rn
import jax.numpy as jnp
import numpy as np
import string

import sys
sys.path.append('..')
from utils import get_random_string

In [25]:
digits = np.array([d for d in string.digits])

In [41]:
'qwertyuio'[rn.choice(rn.PRNGKey(0), jnp.arange(10), (5,), replace=False,)]

TypeError: only integer scalar arrays can be converted to a scalar index

In [42]:
s = get_random_string( np.array([ch for ch in string.ascii_lowercase]), 10, rn.PRNGKey(10))

In [43]:
np.array([ch for ch in s])[rn.choice(rn.PRNGKey(0), jnp.arange(10), (5,), replace=False,)]

array(['j', 'a', 'w', 'c', 'a'], dtype='<U1')