In [1]:
import functools
import jax
from jax import random, lax, numpy as jnp
import numpy as np

import flax
from flax import nn
from flax import core
from flax.core import Scope, init, apply, Array

In [2]:
# Hand-written sketch of a functional dense layer, left commented out for
# reference; the notebook binds `dense` to the library's `nn.dense` instead.
# NOTE(review): presumably this sketch mirrors `nn.dense` -- confirm.
# def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,
#           kernel_init=nn.linear.default_kernel_init,
#           bias_init=nn.initializers.zeros):
#   kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))
#   y = jnp.dot(inputs, kernel)
#   if bias:
#     y += scope.param('bias', bias_init, (features,))
#   return y
dense = nn.dense

# All-ones dummy input; `x` is reused by the following cells.
x = jnp.ones((1, 2))
# Fix the layer width so the remaining call signature is (scope, inputs).
model_fn = functools.partial(dense, features=3)
# `init` yields the layer output together with the variables it created.
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)



FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],
             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})


In [3]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Dense -> ReLU -> dense network built with explicitly pushed scopes.

  The hidden layer is applied twice through the very same 'hidden' scope; the
  first result is thrown away. The printed variables below contain a single
  'hidden' entry for both applications.

  Args:
    scope: parent scope; children named 'hidden' and 'out' are pushed from it.
    inputs: array passed to the hidden dense layer.
    features: number of features of the hidden dense layer.

  Returns:
    The result of the final dense layer, which is created with 1 feature.
  """
  hidden_scope = scope.push('hidden')
  # First application: its output is discarded, only the scope is touched.
  dense(hidden_scope, inputs, features)
  activations = nn.relu(dense(hidden_scope, inputs, features))
  out_scope = scope.push('out')
  return dense(out_scope, activations, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.00302252]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [4]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Dense -> ReLU -> dense network whose hidden layer is a partial of `dense`.

  `dense` is bound once to the pushed 'hidden' scope and the bound layer is
  then invoked twice on the same inputs; only the second result is used.

  Args:
    scope: parent scope; children named 'hidden' and 'out' are pushed from it.
    inputs: array passed to the hidden dense layer.
    features: number of features of the hidden dense layer.

  Returns:
    The result of the final dense layer, which is created with 1 feature.
  """
  hidden_layer = functools.partial(dense, scope.push('hidden'))
  # First invocation: return value intentionally unused.
  hidden_layer(inputs, features)
  activations = nn.relu(hidden_layer(inputs, features))
  out_scope = scope.push('out')
  return dense(out_scope, activations, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.00302252]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [5]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Dense -> ReLU -> dense network whose hidden layer comes from `scope.child`.

  `scope.child(dense)` is called once and the callable it returns is invoked
  twice on the same inputs; only the second result is used. No name is given
  for the hidden layer; the printed variables below list it as 'dense__0'.

  Args:
    scope: parent scope providing `child` and `push`.
    inputs: array passed to the hidden dense layer.
    features: number of features of the hidden dense layer.

  Returns:
    The result of the final dense layer, which is created with 1 feature.
  """
  hidden_layer = scope.child(dense)
  # First invocation: return value intentionally unused.
  hidden_layer(inputs, features)
  activations = nn.relu(hidden_layer(inputs, features))
  out_scope = scope.push('out')
  return dense(out_scope, activations, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[0.44402555]], dtype=float32),
 FrozenDict({'param': FrozenDict({'dense__0': FrozenDict({'kernel': DeviceArray([[ 0.83161545, -0.2180224 ,  0.41811994],
              [ 0.17165233, -0.14596988,  1.1707549 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [6]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Small MLP: child-scoped dense layer, ReLU, then a 1-feature dense layer.

  The hidden layer is obtained from `scope.child(dense)` and called two times
  with identical arguments; the first output is dropped.

  Args:
    scope: parent scope providing `child` and `push`.
    inputs: array passed to the hidden dense layer.
    features: number of features of the hidden dense layer.

  Returns:
    The output of the dense layer applied under the pushed 'out' scope.
  """
  child_dense = scope.child(dense)
  # Result of this first call is not kept.
  child_dense(inputs, features)
  rectified = nn.relu(child_dense(inputs, features))
  output_scope = scope.push('out')
  return dense(output_scope, rectified, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[0.44402555]], dtype=float32),
 FrozenDict({'param': FrozenDict({'dense__0': FrozenDict({'kernel': DeviceArray([[ 0.83161545, -0.2180224 ,  0.41811994],
              [ 0.17165233, -0.14596988,  1.1707549 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [7]:
# Initialise `nn.dense_general` with 3 features on an all-ones (1, 2) input
# and print the variables that `init` returns alongside the output.
x = jnp.ones((1, 2))
model_fn = functools.partial(nn.dense_general, features=3)
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)

FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],
             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})


In [8]:
# `init` returns a pair; only its first element (the embedding object) is kept.
embedding, _ = init(nn.embedding)(
    random.PRNGKey(0),
    num_embeddings=2,
    features=3,
)
# Show the table, a lookup of index 1, and `attend` on an all-ones (1, 3) array.
print(embedding.table)
print(embedding.lookup(1))
print(embedding.attend(jnp.ones((1, 3))))

[[ 0.11575121 -0.51936364 -1.113899  ]
 [ 0.45569834 -0.5300623  -0.5873911 ]]
[ 0.45569834 -0.5300623  -0.5873911 ]
[[-1.5175114 -0.6617551]]


In [9]:
def mlp(scope: Scope, inputs: Array, features: int):
  """Dense -> ReLU -> layer norm -> dense network.

  Each layer that takes a scope gets its own child scope pushed from `scope`:
  'hidden', 'lnorm' and 'out'.

  Args:
    scope: parent scope the three child scopes are pushed from.
    inputs: array passed to the hidden dense layer.
    features: number of features of the hidden dense layer.

  Returns:
    The result of the final dense layer, which is created with 1 feature.
  """
  pre_activation = dense(scope.push('hidden'), inputs, features)
  rectified = nn.relu(pre_activation)
  normalized = nn.layer_norm(scope.push('lnorm'), rectified)
  out_scope = scope.push('out')
  return dense(out_scope, normalized, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

(DeviceArray([[-0.2751354]], dtype=float32),
 FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674,  0.33191404],
              [-0.7799348 ,  0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'lnorm': FrozenDict({'scale': DeviceArray([1., 1., 1.], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))

In [10]:
# Single-head attention with `cache=True` and `causal_mask=True`, initialised
# on an all-ones (1, 2, 4) input used as both positional arguments.
# NOTE: this rebinds the global `x` used by the earlier cells.
x = jnp.ones((1, 2, 4))
attn = functools.partial(nn.multi_head_dot_product_attention, num_heads=1, cache=True, causal_mask=True)
y, variables = init(attn)(random.PRNGKey(0), x, x)
params = variables['param']
# After `init`, every leaf of the 'cache' collection is a callable; call each
# one with (1, 2) and keep the results as the concrete cache.
# NOTE(review): (1, 2) presumably matches the leading dims of `x` -- confirm.
# `jax.tree_util.tree_map` is used instead of the top-level `jax.tree_map`
# alias, which is deprecated and removed in newer JAX releases.
cache = jax.tree_util.tree_map(lambda fn: fn((1, 2)), variables['cache'])
variables = variables.copy(cache=cache)
# Apply to the slice x[:, 0:1] with 'cache' marked as mutable.
apply(attn, mutable='cache')(variables, x[:, 0:1], x)

(DeviceArray([[[ 0.28011173, -0.55635864,  0.6146434 ,  0.5917805 ]]], dtype=float32),
 FrozenDict({'param': FrozenDict({'query': FrozenDict({'kernel': DeviceArray([[[-0.25566727,  0.24481767,  0.70468867,  0.10142951]],
 
              [[-0.24900651, -0.43240824, -0.07879252, -0.35603613]],
 
              [[ 0.04119987, -0.5303452 ,  0.5928286 , -0.5084808 ]],
 
              [[-0.01729889,  0.76639396, -0.4523484 ,  0.8094689 ]]],            dtype=float32), 'bias': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}), 'key': FrozenDict({'kernel': DeviceArray([[[-0.27695838,  0.73972565, -0.00267717,  0.669741  ]],
 
              [[-0.6818866 ,  0.13959199, -0.3723538 , -0.35600063]],
 
              [[ 0.1109049 , -0.22054054,  0.5455148 ,  0.21493351]],
 
              [[-0.19008224,  0.06578335,  0.24833658,  0.00166761]]],            dtype=float32), 'bias': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}), 'value': FrozenDict({'kernel': DeviceArray([[[ 0.7953513 , -0.31489673,  0.05