In [1]:
import functools
from importlib import reload

import jax
from jax import random, lax, numpy as jnp
import numpy as np

import flax
from flax import nn
from flax import core
from flax.core import Scope, init, apply, Array

In [2]:
# Sketch of a scope-based dense layer, kept commented out for reference; the
# cell binds `dense` to the library's `nn.dense` instead.
# NOTE(review): presumably this sketch mirrors what `nn.dense` does -- confirm.
# def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,
#           kernel_init=nn.linear.default_kernel_init,
#           bias_init=nn.initializers.zeros):
#   kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))
#   y = jnp.dot(inputs, kernel)
#   if bias:
#     y += scope.param('bias', bias_init, (features,))
#   return y
dense = nn.dense

# Fix the output width up front so only the scope and inputs remain to be passed.
model_fn = functools.partial(dense, features=3)
# All-ones input of shape (1, 2); later cells reuse this `x`.
x = jnp.ones((1, 2))
# `init(fn)(rng, *args)` yields the layer output and the variables it created.
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)



FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],
             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})


In [3]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Hidden dense layer (called twice on one pushed scope) -> relu -> 1-unit dense.

  Args:
    scope: parent scope; the 'hidden' and 'out' sub-scopes are pushed from it.
    inputs: array fed to the hidden layer on both calls.
    features: output width of the hidden layer.

  Returns:
    Output of the 'out' dense layer applied to the activated hidden result.
  """
  hidden_scope = scope.push('hidden')
  # The first call's result is discarded; only the second call feeds the relu.
  # NOTE(review): presumably the repeat exercises re-using a single scope --
  # the recorded variables contain one 'hidden' kernel/bias pair.
  dense(hidden_scope, inputs, features)
  activated = nn.relu(dense(hidden_scope, inputs, features))
  out_scope = scope.push('out')
  return dense(out_scope, activated, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.00302252]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [4]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Same network as a two-call hidden layer, with the scope pre-bound via partial.

  Args:
    scope: parent scope; the 'hidden' and 'out' sub-scopes are pushed from it.
    inputs: array fed to the hidden layer on both calls.
    features: output width of the hidden layer.

  Returns:
    Output of the 'out' dense layer applied to the activated hidden result.
  """
  # Bind the pushed 'hidden' scope once; every call below goes through it.
  hidden_layer = functools.partial(dense, scope.push('hidden'))
  # First result is discarded; the second one is what gets activated.
  hidden_layer(inputs, features)
  activated = nn.relu(hidden_layer(inputs, features))
  out_scope = scope.push('out')
  return dense(out_scope, activated, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.00302252]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [5]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Two-call hidden layer built with `scope.child(dense)`, then relu and 1-unit dense.

  Args:
    scope: parent scope; `child` and `push('out')` are both called on it.
    inputs: array fed to the hidden layer on both calls.
    features: output width of the hidden layer.

  Returns:
    Output of the 'out' dense layer applied to the activated hidden result.
  """
  # `scope.child(dense)` hands back a callable that no longer takes a scope.
  # In the recorded output its parameters appear once, under 'dense__0'.
  hidden_layer = scope.child(dense)
  hidden_layer(inputs, features)  # result discarded
  pre_activation = hidden_layer(inputs, features)
  activated = nn.relu(pre_activation)
  return dense(scope.push('out'), activated, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[0.44402555]], dtype=float32),
 FrozenDict({'param': FrozenDict({'dense__0': FrozenDict({'kernel': DeviceArray([[ 0.83161545, -0.2180224 ,  0.41811994],
              [ 0.17165233, -0.14596988,  1.1707549 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [6]:
def mlp(scope: Scope, inputs: Array, features: int):
  """`scope.child(dense)` MLP; repeats the definition from the preceding cell.

  Args:
    scope: parent scope; `child` and `push('out')` are both called on it.
    inputs: array passed to the child dense layer on each of its two calls.
    features: output width of the child dense layer.

  Returns:
    Output of the 'out' dense layer applied to the relu of the last hidden result.
  """
  child_dense = scope.child(dense)
  # Two identical invocations; only the result of the last one is kept.
  for _ in range(2):
    hidden = child_dense(inputs, features)
  activated = nn.relu(hidden)
  return dense(scope.push('out'), activated, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[0.44402555]], dtype=float32),
 FrozenDict({'param': FrozenDict({'dense__0': FrozenDict({'kernel': DeviceArray([[ 0.83161545, -0.2180224 ,  0.41811994],
              [ 0.17165233, -0.14596988,  1.1707549 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

### dense_general

In [7]:
# Initialise `nn.dense_general` with 3 output features on an all-ones (1, 2)
# input and print the variables it creates.
model_fn = functools.partial(nn.dense_general, features=3)
x = jnp.ones((1, 2))
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)

FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],
             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})


### embedding

In [33]:
# `init(nn.embedding)` is unpacked into the embedding object and its variables `v`.
embedding, v = init(nn.embedding)(random.PRNGKey(0), num_embeddings=2, features=3)
print(embedding.table)
# In the recorded output this prints row 1 of the table.
print(embedding.lookup(1))
# With an all-ones (1, 3) query the recorded values equal the per-row sums of the table.
print(embedding.attend(jnp.ones((1, 3,))))
v

[[ 0.11575121 -0.51936364 -1.113899  ]
 [ 0.45569834 -0.5300623  -0.5873911 ]]
[ 0.45569834 -0.5300623  -0.5873911 ]
[[-1.5175114 -0.6617551]]


FrozenDict({'param': FrozenDict({'table': DeviceArray([[ 0.11575121, -0.51936364, -1.113899  ],
             [ 0.45569834, -0.5300623 , -0.5873911 ]], dtype=float32)})})

### layernorm

In [9]:
def mlp(scope: Scope, inputs: Array, features: int):
  """dense -> relu -> layer_norm -> 1-unit dense, each on its own pushed scope.

  Args:
    scope: parent scope; 'hidden', 'lnorm' and 'out' are pushed from it in that order.
    inputs: array fed to the hidden dense layer.
    features: output width of the hidden dense layer.

  Returns:
    Output of the 'out' dense layer applied to the normalised activations.
  """
  activated = nn.relu(dense(scope.push('hidden'), inputs, features))
  normalized = nn.layer_norm(scope.push('lnorm'), activated)
  out_scope = scope.push('out')
  return dense(out_scope, normalized, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.2751354]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'lnorm': FrozenDict({'scale': DeviceArray([1., 1., 1.], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

### attention

In [10]:
x = jnp.ones((1, 2, 4))
attn = functools.partial(nn.multi_head_dot_product_attention, num_heads=1, cache=True, causal_mask=True)
y, variables = init(attn)(random.PRNGKey(0), x, x)
params = variables['param']
# After init the 'cache' leaves are callables; invoke each with (1, 2) to get
# concrete cache entries. Non-callable leaves are passed through unchanged.
# `jax.tree_util.tree_map` is the stable spelling; the top-level `jax.tree_map`
# alias is deprecated and absent from newer JAX releases.
# NOTE(review): (1, 2) presumably matches x.shape[:2] -- confirm.
cache = jax.tree_util.tree_map(lambda fn: fn((1, 2)) if callable(fn) else fn, variables['cache'])
variables = variables.copy(cache=cache)
# Apply to the first position only, letting the 'cache' collection be updated.
apply(attn, mutable='cache')(variables, x[:, 0:1], x)

(DeviceArray([[[ 0.28011173, -0.55635864,  0.6146434 ,  0.5917805 ]]], dtype=float32),
 FrozenDict({'param': FrozenDict({'query': FrozenDict({'kernel': DeviceArray([[[-0.25566727,  0.24481767,  0.70468867,  0.10142951]],
 
              [[-0.24900651, -0.43240824, -0.07879252, -0.35603613]],
 
              [[ 0.04119987, -0.5303452 ,  0.5928286 , -0.5084808 ]],
 
              [[-0.01729889,  0.76639396, -0.4523484 ,  0.8094689 ]]],            dtype=float32), 'bias': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}), 'key': FrozenDict({'kernel': DeviceArray([[[-0.27695838,  0.73972565, -0.00267717,  0.669741  ]],
 
              [[-0.6818866 ,  0.13959199, -0.3723538 , -0.35600063]],
 
              [[ 0.1109049 , -0.22054054,  0.5455148 ,  0.21493351]],
 
              [[-0.19008224,  0.06578335,  0.24833658,  0.00166761]]],            dtype=float32), 'bias': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}), 'value': FrozenDict({'kernel': DeviceArray([[[ 0.7953513 , -0.31489673,  0.05

# test model

In [4]:
# Explicit magic: a bare `cd` only works while IPython automagic is enabled and
# no variable named `cd` exists. The path is relative, so this cell must be run
# exactly once from the repository root.
%cd examples/wmt

/Users/levskaya/repos/alflax/examples/wmt


In [5]:
import models

In [6]:
# Re-execute the `models` module and rebind the name to the reloaded module.
models=reload(models)

In [5]:
x = jnp.ones((1, 2, 4))
# In the recorded output the returned variables are an empty FrozenDict.
y, variables = init(models.add_position_embs)(random.PRNGKey(0), x)
y, variables



(DeviceArray([[[1.       , 1.       , 2.       , 2.       ],
               [1.841471 , 1.0001   , 1.5403023, 2.       ]]],            dtype=float32),
 FrozenDict({}))

In [6]:
x = jnp.ones((1, 2, 4))
mlp_block = functools.partial(models.mlp_block, mlp_dim=6)
# Here `init` receives a dict of PRNG keys keyed by 'param' and 'dropout'
# rather than a single key. Both entries are built from the same seed (0).
y, variables = init(mlp_block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}, x)
y, variables

(DeviceArray([[[-0.14761418, -0.12177372,  0.00229907, -0.25567493],
               [ 0.00163177,  0.0016789 , -0.00219643,  0.        ]]],            dtype=float32),
 FrozenDict({'param': FrozenDict({'dense__0': FrozenDict({'kernel': DeviceArray([[ 0.01424811,  0.05006719, -0.3585709 , -0.14347874,
               -0.21958077,  0.06684854],
              [ 0.73106635, -0.49544278,  0.42408884, -0.4566572 ,
                0.13028866, -0.24688083],
              [-0.6033098 ,  0.29077   ,  0.15498346, -0.55903727,
                0.10362869,  0.40677932],
              [-0.20636722,  0.15732443, -0.00088572,  0.2146615 ,
               -0.22184326,  0.16878225]], dtype=float32), 'bias': DeviceArray([-1.0963597e-06, -1.1821694e-06,  3.7060437e-07,
              -3.0882313e-07, -1.3826963e-06,  3.3113665e-06],            dtype=float32)}), 'dense__1': FrozenDict({'kernel': DeviceArray([[-1.2714542e-01,  2.1636164e-02,  4.6816525e-01,
                2.8025460e-01],
              [ 4.860620

In [23]:
x = jnp.ones((1, 2, 4))
block = functools.partial(models.encoder_1d_block, qkv_dim=6, mlp_dim=24, num_heads=2)
y, variables = init(block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}, x)
# Display the output plus only the *shapes* of the variables (unfrozen to plain
# dicts) instead of dumping every array. `jax.tree_util.tree_map` is used since
# the top-level `jax.tree_map` alias is deprecated and gone from newer JAX.
y, jax.tree_util.tree_map(jnp.shape, core.unfreeze(variables))

(DeviceArray([[[1.0000001 , 0.99999934, 0.99999803, 1.0000008 ],
               [1.0000001 , 0.99999934, 0.99999803, 1.0000008 ]]],            dtype=float32),
 {'param': {'layer_norm__0': {'bias': (4,), 'scale': (4,)},
   'layer_norm__1': {'bias': (4,), 'scale': (4,)},
   'mlp_block__0': {'dense__0': {'bias': (24,), 'kernel': (4, 24)},
    'dense__1': {'bias': (4,), 'kernel': (24, 4)}},
   'multi_head_dot_product_attention__0': {'key': {'kernel': (4, 2, 3)},
    'out': {'kernel': (2, 3, 4)},
    'query': {'kernel': (4, 2, 3)},
    'value': {'kernel': (4, 2, 3)}}}})

In [24]:
models = reload(models)
x = jnp.ones((1, 2))
block = functools.partial(models.encoder, qkv_dim=6, mlp_dim=24, num_heads=2, vocab_size=256)
y, variables = init(block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}, x)

# Summarise the variables by shape only. `jax.tree_util.tree_map` replaces the
# deprecated top-level `jax.tree_map` alias, which newer JAX no longer provides.
y, jax.tree_util.tree_map(jnp.shape, core.unfreeze(variables))

(DeviceArray([[[-1.4978968 , -0.51244444,  0.12619986, ...,  0.9570363 ,
                 1.6442691 , -0.27091283],
               [ 0.6670244 ,  0.6246195 ,  0.07661428, ...,  0.932581  ,
                -0.6971056 ,  0.0879988 ]]], dtype=float32),
 {'param': {'embedding__0': {'table': (256, 512)},
   'encoder_norm': {'bias': (512,), 'scale': (512,)},
   'encoderblock_0': {'layer_norm__0': {'bias': (512,), 'scale': (512,)},
    'layer_norm__1': {'bias': (512,), 'scale': (512,)},
    'mlp_block__0': {'dense__0': {'bias': (24,), 'kernel': (512, 24)},
     'dense__1': {'bias': (512,), 'kernel': (24, 512)}},
    'multi_head_dot_product_attention__0': {'key': {'kernel': (512, 2, 3)},
     'out': {'kernel': (2, 3, 512)},
     'query': {'kernel': (512, 2, 3)},
     'value': {'kernel': (512, 2, 3)}}},
   'encoderblock_1': {'layer_norm__0': {'bias': (512,), 'scale': (512,)},
    'layer_norm__1': {'bias': (512,), 'scale': (512,)},
    'mlp_block__0': {'dense__0': {'bias': (24,), 'kernel': (512,

In [26]:
models = reload(models)
# All-ones stand-in for the encoder output, with a trailing dimension of 6.
encoded = jnp.ones((1, 2, 6))
x = jnp.ones((1, 2))
# True wherever x > 0, with a trailing singleton axis added.
src_padding_mask = (x > 0)[..., None]
block = functools.partial(models.decoder, qkv_dim=6, mlp_dim=24, num_heads=2, output_vocab_size=256)
y, variables = init(block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}, encoded, src_padding_mask, x)

# Summarise the variables by shape only. `jax.tree_util.tree_map` replaces the
# deprecated top-level `jax.tree_map` alias, which newer JAX no longer provides.
y.shape, jax.tree_util.tree_map(jnp.shape, core.unfreeze(variables))

((1, 2, 256),
 {'param': {'embedding__0': {'table': (256, 512)},
   'encoderdecoder_norm': {'bias': (512,), 'scale': (512,)},
   'encoderdecoderblock_0': {'layer_norm__0': {'bias': (512,),
     'scale': (512,)},
    'layer_norm__1': {'bias': (512,), 'scale': (512,)},
    'layer_norm__2': {'bias': (512,), 'scale': (512,)},
    'mlp_block__0': {'dense__0': {'bias': (24,), 'kernel': (512, 24)},
     'dense__1': {'bias': (512,), 'kernel': (24, 512)}},
    'multi_head_dot_product_attention__0': {'key': {'kernel': (512, 2, 3)},
     'out': {'kernel': (2, 3, 512)},
     'query': {'kernel': (512, 2, 3)},
     'value': {'kernel': (512, 2, 3)}},
    'multi_head_dot_product_attention__1': {'key': {'kernel': (6, 2, 3)},
     'out': {'kernel': (2, 3, 512)},
     'query': {'kernel': (512, 2, 3)},
     'value': {'kernel': (6, 2, 3)}}},
   'encoderdecoderblock_1': {'layer_norm__0': {'bias': (512,),
     'scale': (512,)},
    'layer_norm__1': {'bias': (512,), 'scale': (512,)},
    'layer_norm__2': {'bi

In [45]:
models = reload(models)
x = jnp.ones((1, 2))
block = functools.partial(models.transformer, qkv_dim=6, mlp_dim=24, num_heads=2, vocab_size=256, share_embeddings=True, train=True)
# The same all-ones array is passed as both positional model arguments.
y, variables = init(block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)}, x, x)

# Summarise the variables by shape only. `jax.tree_util.tree_map` replaces the
# deprecated top-level `jax.tree_map` alias, which newer JAX no longer provides.
y.shape, jax.tree_util.tree_map(jnp.shape, core.unfreeze(variables))


((1, 2, 256),
 {'param': {'decoder': {'encoderdecoder_norm': {'bias': (512,),
     'scale': (512,)},
    'encoderdecoderblock_0': {'layer_norm__0': {'bias': (512,),
      'scale': (512,)},
     'layer_norm__1': {'bias': (512,), 'scale': (512,)},
     'layer_norm__2': {'bias': (512,), 'scale': (512,)},
     'mlp_block__0': {'dense__0': {'bias': (24,), 'kernel': (512, 24)},
      'dense__1': {'bias': (512,), 'kernel': (24, 512)}},
     'multi_head_dot_product_attention__0': {'key': {'kernel': (512, 2, 3)},
      'out': {'kernel': (2, 3, 512)},
      'query': {'kernel': (512, 2, 3)},
      'value': {'kernel': (512, 2, 3)}},
     'multi_head_dot_product_attention__1': {'key': {'kernel': (512, 2, 3)},
      'out': {'kernel': (2, 3, 512)},
      'query': {'kernel': (512, 2, 3)},
      'value': {'kernel': (512, 2, 3)}}},
    'encoderdecoderblock_1': {'layer_norm__0': {'bias': (512,),
      'scale': (512,)},
     'layer_norm__1': {'bias': (512,), 'scale': (512,)},
     'layer_norm__2': {'bias'

In [60]:
models=reload(models)
x = jnp.ones((1, 2))

# Encoder-only and decoder-only entry points, configured with identical hyperparameters.
encode_fn = functools.partial(models.transformer_encode, qkv_dim=6, mlp_dim=24, num_heads=2, vocab_size=256, share_embeddings=True)
decode_fn = functools.partial(models.transformer_decode, qkv_dim=6, mlp_dim=24, num_heads=2, vocab_size=256, share_embeddings=True)

# NOTE(review): `variables` is not defined in this cell; it is whatever an
# earlier-run cell last bound, so the result depends on execution order.
encoded = apply(encode_fn)(variables, x, rngs={'dropout': random.PRNGKey(24)})

In [61]:
models=reload(models)
x = jnp.ones((1, 2))
# True wherever x > 0, with a trailing singleton axis added.
src_padding_mask = (x > 0)[..., None]


# NOTE(review): `decode_fn`, `variables` and `encoded` are not defined in this
# cell; they must already exist from previously executed cells.
apply(decode_fn)(variables, encoded, src_padding_mask, x, rngs={'dropout': random.PRNGKey(24)}).shape

(1, 2, 256)

In [98]:
#y, variables = init(model)(random.PRNGKey(0), jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32), cache=True, train=True)
input_shape = (1,10)
target_shape = (1,10)
block = functools.partial(models.transformer, qkv_dim=6, mlp_dim=24, num_heads=2, vocab_size=256, share_embeddings=True, train=True)
# Initialise with cache=True so the returned variables include a 'cache' collection.
# NOTE(review): `train=True` is already bound in the partial above; repeating it
# here passes the same value again and is redundant.
y, variables = init(block)({'param': random.PRNGKey(0), 'dropout': random.PRNGKey(0)},  jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32), cache=True, train=True)


In [65]:
# Inspect the 'cache' collection. In the recorded output the attention entries
# are still `init_fn` callables, next to a concrete 'idx' array.
variables['cache']

FrozenDict({'decoder': FrozenDict({'posembed_output': FrozenDict({'idx': DeviceArray(0, dtype=uint32)}), 'encoderdecoderblock_0': FrozenDict({'multi_head_dot_product_attention__0': FrozenDict({'entry': <function multi_head_dot_product_attention.<locals>.init_fn at 0x138d179d8>})}), 'encoderdecoderblock_1': FrozenDict({'multi_head_dot_product_attention__0': FrozenDict({'entry': <function multi_head_dot_product_attention.<locals>.init_fn at 0x138d17620>})}), 'encoderdecoderblock_2': FrozenDict({'multi_head_dot_product_attention__0': FrozenDict({'entry': <function multi_head_dot_product_attention.<locals>.init_fn at 0x138d176a8>})}), 'encoderdecoderblock_3': FrozenDict({'multi_head_dot_product_attention__0': FrozenDict({'entry': <function multi_head_dot_product_attention.<locals>.init_fn at 0x139045268>})}), 'encoderdecoderblock_4': FrozenDict({'multi_head_dot_product_attention__0': FrozenDict({'entry': <function multi_head_dot_product_attention.<locals>.init_fn at 0x1390451e0>})}), 'enco

In [95]:
# Materialise the cache: leaves that are still callables are invoked with
# (1, 2); leaves that are already concrete values pass through untouched.
# `jax.tree_util.tree_map` is the stable spelling; the top-level `jax.tree_map`
# alias is deprecated and absent from newer JAX releases.
# NOTE(review): the (1, 2) argument is hard-coded -- confirm it is the intended
# cache shape for the inputs used at init time.
cache = jax.tree_util.tree_map(lambda fn: fn((1, 2)) if callable(fn) else fn, variables['cache'])
variables = variables.copy(cache=cache)

In [74]:
from flax import jax_utils


In [76]:
# Replicate a Python scalar; the recorded result is ShardedDeviceArray([1]).
jax_utils.replicate(1)



ShardedDeviceArray([1], dtype=int32)

In [78]:
# Identity function under `pmap`, fed the replicated scalar; the recorded
# output is the same ShardedDeviceArray([1]) as the replicated input.
jax.pmap(lambda x: x)(jax_utils.replicate(1))

ShardedDeviceArray([1], dtype=int32)

In [88]:
# NOTE(review): rebinds `variables` to a copy created with `param={}`
# (presumably replacing the parameters with an empty dict -- confirm `copy`
# semantics). The previous parameters are no longer reachable via `variables`
# until an init cell is re-run.
variables = variables.copy(param={})

In [92]:
# Show the container type of `variables` (recorded: flax.core.frozen_dict.FrozenDict).
type(variables)

flax.core.frozen_dict.FrozenDict

In [93]:
# Build a FrozenDict directly from keyword arguments; recorded repr is
# FrozenDict({'param': {}}).
core.FrozenDict(param={})

FrozenDict({'param': {}})

In [99]:
# List the top-level keys of `variables` (recorded: ['param', 'cache']).
list(variables.keys())

['param', 'cache']