In [1]:
import functools
import jax
from jax import numpy as jnp, random, lax
import numpy as np


In [2]:
from flax import nn, struct

In [3]:
from flax.core import Scope, init, apply, Array, lift, unfreeze

In [4]:
def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros):
  kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))
  y = jnp.dot(inputs, kernel)
  if bias:
    y += scope.param('bias', bias_init, (features,))
  return y

model_fn = functools.partial(dense, features=3)

x = jnp.ones((1, 2))
y, params = init(model_fn)(random.PRNGKey(0), x)
print(params)

def mlp(scope: Scope, inputs: Array, features: int):
  hidden = scope.child(dense, 'hidden')(inputs, features)
  hidden = nn.relu(hidden)
  return dense(scope.push('out'), hidden, 1)

init(mlp)(random.PRNGKey(0), x, features=3)

FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],
             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})


(DeviceArray([[0.17045607]], dtype=float32),
 FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394,  0.22075175, -0.0925657 ],
              [ 0.40571952,  0.27750877,  1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],
              [-0.01530595],
              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))

In [5]:
@struct.dataclass
class Embedding:
  table: np.ndarray

  def lookup(self, indices):
    return self.table[indices]

  def attend(self, query):
    return jnp.dot(query, self.table.T)

# all the embedding module does is provide a convenient initializers

def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=nn.linear.default_embed_init) -> Embedding:
  table = scope.param('table', init_fn, (num_embeddings, features))
  return Embedding(table)

embedding, _ = init(embedding)(random.PRNGKey(0), num_embeddings=2, features=3)
print(embedding.table)
print(embedding.lookup(1))
print(embedding.attend(jnp.ones((1, 3,))))

[[ 0.11575121 -0.51936364 -1.113899  ]
 [ 0.45569834 -0.5300623  -0.5873911 ]]
[ 0.45569834 -0.5300623  -0.5873911 ]
[[-1.5175114 -0.6617551]]


In [6]:
def lstm(scope, carry, inputs,
         gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,
         kernel_init=nn.linear.default_kernel_init,
         recurrent_kernel_init=nn.initializers.orthogonal(),
         bias_init=nn.initializers.zeros):
  r"""A long short-term memory (LSTM) cell.

  the mathematical definition of the cell is as follows
  .. math::
      \begin{array}{ll}
      i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
      f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
      g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
      o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
      c' = f * c + i * g \\
      h' = o * \tanh(c') \\
      \end{array}
  where x is the input, h is the output of the previous time step, and c is
  the memory.

  Args:
    carry: the hidden state of the LSTM cell,
      initialized using `LSTMCell.initialize_carry`.
    inputs: an ndarray with the input for the current time step.
      All dimensions except the final are considered batch dimensions.
    gate_fn: activation function used for gates (default: sigmoid)
    activation_fn: activation function used for output and memory update
      (default: tanh).
    kernel_init: initializer function for the kernels that transform
      the input (default: lecun_normal).
    recurrent_kernel_init: initializer function for the kernels that transform
      the hidden state (default: orthogonal).
    bias_init: initializer for the bias parameters (default: zeros)
  Returns:
    A tuple with the new carry and the output.
  """
  c, h = carry
  hidden_features = h.shape[-1]
  # input and recurrent layers are summed so only one needs a bias.
  dense_h = lambda name: scope.child(dense, name)(
      h, features=hidden_features, bias=True,
      kernel_init=recurrent_kernel_init, bias_init=bias_init)
  dense_i = lambda name: scope.child(dense, name)(
      inputs, features=hidden_features, bias=False,
      kernel_init=kernel_init)
  i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))
  f = gate_fn(dense_i(name='if') + dense_h(name='hf'))
  g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))
  o = gate_fn(dense_i(name='io') + dense_h(name='ho'))
  new_c = f * c + i * g
  new_h = o * activation_fn(new_c)
  return (new_c, new_h), new_h

def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):
  shape = batch_dims + (size,)
  return init_fn(shape), init_fn(shape)

x = jnp.ones((1, 2))
carry = lstm_init_carry((1,), 3)
y, variables = init(lstm)(random.PRNGKey(0), carry, x)
jax.tree_map(np.shape, (y, variables))

((((1, 3), (1, 3)), (1, 3)),
 FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))

In [9]:
def simple_scan(scope: Scope, xs):
  init_carry = lstm_init_carry(xs.shape[:1], xs.shape[-1])
#   cell = scope.child(lstm, 'cell')
#   ys = []
#   for i in range(xs.shape[1]):
#       x = xs[:, i]
#       init_carry, y = cell(init_carry, x)
#       ys.append(y)
#   return init_carry, ys
  lstm_scan = lift.scan(lstm, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False})
  return lstm_scan(scope, init_carry, xs)

key1, key2 = random.split(random.PRNGKey(0), 2)
xs = random.uniform(key1, (1, 5, 2))


y, init_variables = init(simple_scan)(key2, xs)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))


initialized parameter shapes:
 {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}


In [10]:
y = apply(simple_scan)(init_variables, xs)[0]
print('output:\n', y)

output:
 (DeviceArray([[-0.35626447,  0.25178757]], dtype=float32), DeviceArray([[-0.17885922,  0.13063088]], dtype=float32))
