In [348]:
from jax.nn import initializers

import jax
from flax import nn
import jax.numpy as jnp
from jax import jit

from collections.abc import Callable
from typing import Any, List

import numpy as np

import dataclasses
from dataclasses import dataclass
from dataclasses import field

from jax import tree_util

# TODO: Freeze parameters

@jax.tree_util.register_pytree_node_class
@dataclass
class Module:
  """Base class for modules that create their state lazily on first use.

  State lives in a `params` dict that is created on the instance the first
  time `subparam` is called. Dataclass fields of subclasses are treated as
  static metadata when the module is flattened into a JAX pytree; `params`
  and a few hard-coded attributes are treated as pytree children.
  """

  # TODO: Add 'kind' to support module lists
  def subparam(self, key, default_fn):
    """Return `self.params[key]`, creating it with `default_fn()` if absent.

    Args:
      key: name of the entry in `self.params`.
      default_fn: zero-argument callable, invoked only when `key` is missing.

    Returns:
      The stored value. An existing entry always wins over `default_fn`.
    """
    if not hasattr(self, 'params'):
      self.params = {}
    if key not in self.params:
      self.params[key] = default_fn()
    return self.params[key]

  def param(self, key, shape, init_fn):
    """Return the parameter `key`, initializing it as `init_fn(prng_key, shape)`.

    Reads `self.prng_key`, which this class never sets itself; in this file
    it is assigned by `ModuleList.child`.
    """
    # TODO: Check if prng_key exists. If not, error should say that you
    # can get a prng_key by making this a child of another module

    # TODO: split prng_key with name
    return self.subparam(key, lambda: init_fn(self.prng_key, shape))

  def child(self, key, module):
    """Return the child stored under `key`, storing `module` if none exists.

    If `key` is already present the passed `module` is ignored and the
    existing entry is returned.
    """
    return self.subparam(key, lambda: module)

  def is_initialized(self):
    """Return True once `params` has been created on this instance."""
    # Fixed: the method was declared without `self`, so every instance call
    # raised TypeError.
    return hasattr(self, 'params')

  def module_list(self, key='list'):
    """Return a fresh `ModuleList` (cursor at 0) over the list stored at `key`."""
    return ModuleList(self.subparam(key, lambda: []))

  def tree_flatten(self):
    """Split the module into pytree children and static metadata.

    Returns:
      `(data, meta)`, each a 1-tuple wrapping a dict: `data` holds the
      selected instance attributes, `meta` holds the dataclass fields.
    """
    meta_dict = dataclasses.asdict(self)
    # Needed because we must return a tuple. Each element
    # in the tuple need only be a JAXable type
    meta = (meta_dict, )

    data_dict = {}
    for key in dir(self):
      val = getattr(self, key)
      # TODO: Don't duplicate params and ModuleLists (store ModuleLists somewhere else)
      # TODO(!!!): yeah this sucks, we really need to override setattr
      if key in ('params', 'counter', 'mlp'):
        data_dict[key] = val
      elif isinstance(val, ModuleList):
        # Only the underlying list is exported; the cursor is dropped.
        data_dict[key] = val.children
    # Needed because we must return a tuple. Each element
    # in the tuple need only be a JAXable type
    data = (data_dict, )

    return data, meta

  @classmethod
  def tree_unflatten(cls, meta, data):
    """Rebuild an instance from the `(data, meta)` pair of `tree_flatten`.

    The instance is constructed from the dataclass fields in `meta`; every
    entry of `data` is then set as an attribute.
    """
    (meta_dict, ) = meta
    (data_dict, ) = data
    instance = cls(**meta_dict)
    for key in data_dict:
      # TODO: Eliminate this if/else and the equivalent one in `tree_flatten` by making
      # ModuleList JAX tree-able
      if key in ('params', 'counter', 'mlp'):
        setattr(instance, key, data_dict[key])
      else:
        # Anything else was exported as a bare list by `tree_flatten`.
        setattr(instance, key, ModuleList(data_dict[key]))
    return instance

  def __call__(self, *args, **kwargs):
    """Forward computation; subclasses must override."""
    raise NotImplementedError()

  #############

@dataclass
class ModuleList():
  """Ordered collection of child modules consumed through a moving cursor.

  `children` is the backing list and `cursor` is the index of the next slot
  that `child` will hand out.
  """
  children: List[Module] = field(default_factory=list)
  cursor: int = 0

  def child(self, module):
    """Return the module at the current cursor position and advance.

    `module` is appended only when the cursor has reached the end of the
    list; otherwise the already-stored module at that position is returned.
    `module.prng_key` is assigned in both cases.
    """
    index = self.cursor
    # TODO: split prng_key with index
    module.prng_key = jax.random.PRNGKey(0)
    if index >= len(self.children):
      self.children.append(module)
    selected = self.children[index]
    self.cursor = index + 1
    return selected
    

@tree_util.register_pytree_node_class
@dataclass
class Dense(Module):
  """Affine layer applied to the last axis of its input.

  Fields:
    features: size of the output's last axis.
    bias: whether a bias vector is added.
    kernel_init: initializer called as `kernel_init(prng_key, shape)`.
    bias_init: initializer called as `bias_init(prng_key, shape)`.
  """
  features: int
  bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Return `x @ kernel`, plus `bias` when `self.bias` is set."""
    in_features = x.shape[-1]
    weights = self.param('kernel', (in_features, self.features), self.kernel_init)
    projected = jnp.dot(x, weights)
    if not self.bias:
      return projected
    # QUESTION: Does += work in JAX?
    offset = self.param('bias', (self.features,), self.bias_init)
    return projected + offset

@tree_util.register_pytree_node_class
@dataclass
class MLP(Module):
  """Stack of `depth` ReLU-activated Dense layers plus one linear Dense layer.

  Every layer, including the last one, has `width` outputs. The `features`
  field is declared but not read by `__call__`.
  """
  depth: int = 3
  width: int = 32
  features: int = 10

  def __call__(self, x):
    """Apply the layers to `x` and return the output of the final layer."""
    layers = self.module_list()
    self.layers = layers

    hidden = x
    for _ in range(self.depth):
      hidden = nn.relu(layers.child(Dense(self.width))(hidden))
    # Final layer: no activation.
    return layers.child(Dense(self.width))(hidden)

@jit
def init():
  """Build an MLP, run it once on a (3, 3) array of ones, and return it."""
  model = MLP(depth=3, width=32)
  sample = np.ones((3, 3))
  # The forward pass is what creates the layer parameters.
  model(sample)
  return model

# Build the MLP via the jitted `init` and show the first layer's parameters.
module = init()
# NOTE(review): this value is discarded (it is not the cell's last expression).
dir(module)
# TODO: Consider making `module.layers` more of an actual list by extending it?
module.layers.children[0].params


{'bias': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0.], dtype=float32),
 'kernel': DeviceArray([[ 0.20814848, -1.0710016 ,  0.6439947 , -0.5429579 ,
               -0.3246118 ,  1.208985  , -0.257471  , -0.32065013,
                0.72582203,  0.08994231,  0.3341397 , -0.22079287,
                0.69877183,  0.8303373 ,  0.3428869 ,  0.94781363,
                0.08524317, -0.45154864,  0.17062634,  0.8318948 ,
               -0.6483707 , -0.24204564, -0.04201688, -0.3538871 ,
               -0.22361928, -0.07344878, -0.20832072,  0.17268421,
               -0.64753693,  0.59740335, -0.07335723, -0.76099616],
              [-0.4295275 ,  1.2338059 ,  0.102417  , -0.13127546,
               -0.37081522,  0.18604888,  0.19740908, -0.3535717 ,
                0.7684898 ,  0.21345675,  0.5770429 , -0.328339  ,
                0.04007532,  0.15846986, -0.9928673 ,  0.

In [291]:

# NOTE: This is again like Flax's current `Model` -- you take gradients
# w.r.t. it.
def loss_fn(mlp):
  """Mean absolute error of `mlp` on a fixed all-ones example.

  The input is a (1, 3) array of ones and the target a (1, 4) array of ones.
  """
  inputs = np.ones((1, 3))
  targets = np.ones((1, 4))
  abs_error = jnp.abs(mlp(inputs) - targets)
  return jnp.mean(abs_error)

@jit
def opt_step(mlp):
  """Take one gradient-descent step on `loss_fn`.

  Returns:
    `(loss, updated_mlp)`, where the loss is evaluated before the update.
  """
  # TODO: Mark `mlp` as dead
  lr = 1e-1

  def sgd_update(w, g):
    return w - lr * g

  loss, grad = jax.value_and_grad(loss_fn)(mlp)
  # NOTE(review): `jax.tree_multimap` appears to be deprecated in later JAX
  # releases -- confirm against the JAX version this notebook is pinned to.
  updated = jax.tree_multimap(sgd_update, mlp, grad)
  return loss, updated

# Run 40 optimizer steps, printing the pre-update loss and the updated
# model's prediction after each one.
# NOTE(review): `mlp` is never called before the first `opt_step`, so it has no
# `params` on entry -- presumably why the first two printed losses match; confirm.
mlp = MLP(depth=3, width=4)
for i in range(40):
  loss, mlp = opt_step(mlp)
  print(loss, mlp(np.ones((1, 3))))

1.0903456 [[-0.04370475 -0.14652367 -0.20754965  0.03639566]]
1.0903456 [[ 0.00516562 -0.04317067 -0.07185964  0.04282168]]
1.0167608 [[0.05 0.05 0.05 0.05]]
0.95 [[0.075 0.075 0.075 0.075]]
0.925 [[0.1 0.1 0.1 0.1]]
0.9 [[0.125 0.125 0.125 0.125]]
0.875 [[0.15 0.15 0.15 0.15]]
0.85 [[0.17500001 0.17500001 0.17500001 0.17500001]]
0.825 [[0.20000002 0.20000002 0.20000002 0.20000002]]
0.79999995 [[0.22500002 0.22500002 0.22500002 0.22500002]]
0.775 [[0.25000003 0.25000003 0.25000003 0.25000003]]
0.75 [[0.27500004 0.27500004 0.27500004 0.27500004]]
0.72499996 [[0.30000004 0.30000004 0.30000004 0.30000004]]
0.6999999 [[0.32500005 0.32500005 0.32500005 0.32500005]]
0.67499995 [[0.35000005 0.35000005 0.35000005 0.35000005]]
0.65 [[0.37500006 0.37500006 0.37500006 0.37500006]]
0.62499994 [[0.40000007 0.40000007 0.40000007 0.40000007]]
0.5999999 [[0.42500007 0.42500007 0.42500007 0.42500007]]
0.5749999 [[0.45000008 0.45000008 0.45000008 0.45000008]]
0.54999995 [[0.47500008 0.47500008 0.4750000

In [292]:
@jax.tree_util.register_pytree_node_class
@dataclass
class Counter:
  """Pytree-registered counter whose single child is `value`."""
  # QUESTION: register_buffer_variable? field(kind='param'|'counter'|'log', ...)
  # Need special wrapper that's not Module. More like Jonathan's dataclass. Then
  # Have Module use that decorator as well.
  value: int = 0

  def __call__(self):
    """Increase `value` by one; returns nothing."""
    incremented = self.value + 1
    self.value = incremented

  def tree_flatten(self):
    """Return `((value,), None)`: one child and no static metadata."""
    children = (self.value, )
    aux = None
    return children, aux

  @classmethod
  def tree_unflatten(cls, meta, data):
    """Rebuild a counter from the 1-tuple `data`; `meta` is ignored."""
    [value] = data
    return cls(value)

# Start from a fresh counter (value 0) for the jit demonstration.
counter = Counter()
@jit
def inc3(c):
  """Call `c` three times and return it."""
  for _ in range(3):
    c()
  return c

# Apply the jitted increment twice, rebinding `counter` to the returned
# object each time, and print the result after each call.
counter = inc3(counter)
print(counter)
counter = inc3(counter)
print(counter)



Counter(value=DeviceArray(3, dtype=int32))
Counter(value=DeviceArray(6, dtype=int32))


In [306]:
@tree_util.register_pytree_node_class
class WithCounter(Module):
  """Module whose only state is a `Counter` child stored under "counter"."""

  def __call__(self):
    """Create the counter child on first use and expose it as `self.counter`."""
    # QUESTION: Use __setattr__ so that this becomes
    #     self.counter = Counter()?
    counter = self.child("counter", Counter())
    self.counter = counter

# TODO: @cloned -- simulate jit but still being able to debug
  
@jit
def inc3_with(with_counter):
  """Call `with_counter.counter` three times and return `with_counter`."""
  for _ in range(3):
    with_counter.counter()
  return with_counter

with_counter = WithCounter()
# Call once eagerly: `WithCounter.__call__` is what creates the `counter`
# child and the `counter` attribute that `inc3_with` reads.
with_counter()
# TODO: Somehow make this code fail -- it doesn't behave
# the same under a jit
# print(increment_twice(with_counter).counter)
# print(increment_twice(with_counter).counter)

# Apply the jitted increment twice, rebinding to the returned module each time.
with_counter = inc3_with(with_counter)
print(with_counter.counter)
with_counter = inc3_with(with_counter)
print(with_counter.counter)

# TODO: If we use __setattr__ then can we not place those things on `self.params`? Then
# this won't be possible to try.
# with_counter = inc3_with(with_counter)
# print(with_counter.params['counter'])
# with_counter = inc3_with(with_counter)
# print(with_counter.params['counter'])



Counter(value=DeviceArray(3, dtype=int32))
Counter(value=DeviceArray(6, dtype=int32))


In [349]:
@tree_util.register_pytree_node_class
class MLPAndCounter(Module):
  """Module that owns an `MLP` and a `Counter`, both created at construction."""

  def __init__(self):
    """Register the `mlp` and `counter` children and expose them as attributes."""
    mlp = self.child('mlp', MLP(depth=3, width=4))
    counter = self.child('counter', Counter())
    self.mlp = mlp
    self.counter = counter

def clone(x):
  """Return a structural copy of the pytree `x`.

  The tree is rebuilt by mapping the identity function over its leaves:
  containers and registered pytree nodes are reconstructed, while the leaves
  themselves are reused rather than copied.
  """
  # Use `jax.tree_util.tree_map`: the top-level `jax.tree_map` alias was
  # deprecated and later removed from JAX.
  return jax.tree_util.tree_map(lambda v: v, x)
    
# TODO: Is there a general pattern for extracting just the 
# trainable parameters? Use `kind`
def opt_step(mlp_and_counter):
  """Increment the counter and take one gradient-descent step on the MLP.

  Works on a copy produced by `clone`, and redefines the `opt_step` declared
  earlier in the notebook.

  Args:
    mlp_and_counter: object exposing `counter`, `mlp` and `params`.

  Returns:
    `(loss, mlp_and_counter)`, where the loss is evaluated before the update
    and the returned object is the updated copy.
  """
  mlp_and_counter = clone(mlp_and_counter)
  mlp_and_counter.counter()
  # Gradients are taken with respect to the MLP only, not the counter.
  loss, grad = jax.value_and_grad(loss_fn)(mlp_and_counter.mlp)
  lr = 1e-1
  # TODO: This really implies we should probably override __setattr__
  # TODO(!!!): Why did I need to do this? Mental model breakdown!
  # `Module.child` goes through `subparam`, which returns the existing
  # `params['mlp']` whenever that key is present. The old entry therefore has
  # to be removed first, otherwise the updated MLP passed to `child` below
  # would be ignored.
  old_mlp = mlp_and_counter.mlp
  del mlp_and_counter.params['mlp']
  del mlp_and_counter.mlp
  mlp_and_counter.mlp = mlp_and_counter.child('mlp', jax.tree_multimap(lambda w, g: w - lr * g, old_mlp, grad))
  return loss, mlp_and_counter

@jit
def init():
  """Build an `MLPAndCounter`, run its MLP once on ones of shape (1, 3), and return it.

  Redefines the `init` declared earlier in the notebook.
  """
  model = MLPAndCounter()
  sample = np.ones((1, 3))
  # The forward pass is what creates the MLP's layer parameters.
  model.mlp(sample)
  print(model.mlp)
  return model

# Initialize under jit, then run 40 optimizer steps. Each iteration prints the
# pre-update loss, the counter, and the updated MLP's prediction.
mlp_and_counter = init()

for i in range(40):
  loss, mlp_and_counter = opt_step(mlp_and_counter)
  print(loss, mlp_and_counter.counter, mlp_and_counter.mlp(np.ones((1, 3))))

MLP(depth=3, width=4, features=10)
1.0903456 Counter(value=1) [[ 0.00516562 -0.04317067 -0.07185964  0.04282168]]
1.0167607 Counter(value=2) [[0.05 0.05 0.05 0.05]]
0.95 Counter(value=3) [[0.075 0.075 0.075 0.075]]
0.925 Counter(value=4) [[0.1 0.1 0.1 0.1]]
0.9 Counter(value=5) [[0.125 0.125 0.125 0.125]]
0.875 Counter(value=6) [[0.15 0.15 0.15 0.15]]
0.85 Counter(value=7) [[0.17500001 0.17500001 0.17500001 0.17500001]]
0.825 Counter(value=8) [[0.20000002 0.20000002 0.20000002 0.20000002]]
0.79999995 Counter(value=9) [[0.22500002 0.22500002 0.22500002 0.22500002]]
0.775 Counter(value=10) [[0.25000003 0.25000003 0.25000003 0.25000003]]
0.75 Counter(value=11) [[0.27500004 0.27500004 0.27500004 0.27500004]]
0.72499996 Counter(value=12) [[0.30000004 0.30000004 0.30000004 0.30000004]]
0.6999999 Counter(value=13) [[0.32500005 0.32500005 0.32500005 0.32500005]]
0.67499995 Counter(value=14) [[0.35000005 0.35000005 0.35000005 0.35000005]]
0.65 Counter(value=15) [[0.37500006 0.37500006 0.3750000

In [357]:
@tree_util.register_pytree_node_class
class MLPAndCounter2(Module):
  """Variant that creates its `Counter` and `MLP` children lazily in `__call__`."""

  def __call__(self, x):
    """Increment the counter, then return the MLP's output for `x`."""
    counter = self.child('counter', Counter())
    self.counter = counter
    counter()
    mlp = self.child('mlp', MLP(depth=3, width=4))
    self.mlp = mlp
    return mlp(x)
  
def init2():
  """Build an `MLPAndCounter2`, call it once on ones of shape (1, 3), and return it."""
  model = MLPAndCounter2()
  sample = np.ones((1, 3))
  # The call creates the children and increments the counter once.
  model(sample)
  print(model.mlp)
  return model

# `init2` calls the module once, so the counter has already been incremented
# once before the loop starts.
mlp_and_counter2 = init2()

# Reuses the `opt_step` defined in the earlier cell, which increments
# `.counter` itself and updates `.mlp`.
for i in range(40):
  loss, mlp_and_counter2 = opt_step(mlp_and_counter2)
  print(loss, mlp_and_counter2.counter.value, mlp_and_counter2.mlp(np.ones((1, 3))))

MLP(depth=3, width=4, features=10)
1.0903456 2 [[ 0.00516562 -0.04317067 -0.07185964  0.04282168]]
1.0167607 3 [[0.05 0.05 0.05 0.05]]
0.95 4 [[0.075 0.075 0.075 0.075]]
0.925 5 [[0.1 0.1 0.1 0.1]]
0.9 6 [[0.125 0.125 0.125 0.125]]
0.875 7 [[0.15 0.15 0.15 0.15]]
0.85 8 [[0.17500001 0.17500001 0.17500001 0.17500001]]
0.825 9 [[0.20000002 0.20000002 0.20000002 0.20000002]]
0.79999995 10 [[0.22500002 0.22500002 0.22500002 0.22500002]]
0.775 11 [[0.25000003 0.25000003 0.25000003 0.25000003]]
0.75 12 [[0.27500004 0.27500004 0.27500004 0.27500004]]
0.72499996 13 [[0.30000004 0.30000004 0.30000004 0.30000004]]
0.6999999 14 [[0.32500005 0.32500005 0.32500005 0.32500005]]
0.67499995 15 [[0.35000005 0.35000005 0.35000005 0.35000005]]
0.65 16 [[0.37500006 0.37500006 0.37500006 0.37500006]]
0.62499994 17 [[0.40000007 0.40000007 0.40000007 0.40000007]]
0.5999999 18 [[0.42500007 0.42500007 0.42500007 0.42500007]]
0.5749999 19 [[0.45000008 0.45000008 0.45000008 0.45000008]]
0.54999995 20 [[0.4750000

Counter(value=2)
Counter(value=4)


In [None]:
@dataclass
class AutoEncoder(Module):
  """Dense encoder/decoder pair that reconstructs its input.

  Fields:
    width: number of features of every hidden layer and of the code `z`.
    depth: number of ReLU-activated layers in each of the two halves.

  `encode` must run before `decode`, because `decode` needs the input shape
  recorded by `encode`.
  """
  width: int = 32
  depth: int = 3

  def encode(self, x):
    """Map `x` to a code with `width` features and record `x.shape`."""
    self.input_shape = x.shape
    self.encoder_layers = self.module_list('encoder')

    z = x
    for i in range(self.depth):
      # Fixed: feed the running activation `z` (not the raw input `x`) so
      # that the layers are actually stacked.
      z = self.encoder_layers.child(Dense(self.width))(z)
      z = nn.relu(z)
    # final layer without relu
    # Fixed: this layer belongs to the encoder list; `decoder_layers` does
    # not exist until `decode` has run.
    z = self.encoder_layers.child(Dense(self.width))(z)
    return z

  def decode(self, z):
    """Map a code `z` back to an array of the shape recorded by `encode`.

    Raises:
      AssertionError: if `encode` has not been called on this instance.
    """
    assert hasattr(self, 'input_shape'), "Need to call `encode` to know the input shape"
    self.decoder_layers = self.module_list('decoder')

    x = z
    for i in range(self.depth):
      x = self.decoder_layers.child(Dense(self.width))(x)
      x = nn.relu(x)

    # NOTE(review): the last layer emits prod(input_shape) features on the last
    # axis only, so the reshape below looks like it succeeds only when every
    # leading axis of the input has size 1 -- confirm the intended input layout.
    x = self.decoder_layers.child(Dense(np.prod(self.input_shape)))(x)
    x = x.reshape(self.input_shape)
    return x

  def __call__(self, x):
    """Return the reconstruction `decode(encode(x))`."""
    return self.decode(self.encode(x))

In [None]:
class Logger(Module)