In [1]:
from dataclasses import dataclass
import dataclasses
from typing import Any, Callable, Iterable, Optional, Tuple, Type
from flax import nn
from flax.nn import initializers
from flax.core.scope import Scope
from flax.core import scope
from jax import numpy as jnp
import functools
import numpy as np
import jax


from flax.core.frozen_dict import freeze

@dataclass
class Module:
  """Base class for modules that hang off a parent and own a scope.

  A module is either built through `toplevel(...)`, which attaches a root
  `Scope` directly, or as a child of another module, in which case its scope is
  created lazily by pushing its `name` onto the parent's scope the first time
  `scope()` is called.
  """
  parent: Optional["Module"]

  @classmethod
  def toplevel(cls, *args, rngs=None, variables=None, mutable=False, **kwargs):
    """Construct a parentless module and attach a root `Scope` to it.

    Args:
      *args: positional dataclass fields, placed after `parent`.
      rngs: dict handed to `Scope(..., rngs=rngs)`; defaults to `{}`.
      variables: variable collections for the scope; defaults to
        `{'param': {}}`.
      mutable: forwarded, together with `variables`, to
        `scope._unfreeze_variables`.
      **kwargs: keyword dataclass fields.

    Returns:
      The new module instance with `_scope` already set.
    """
    # TODO: Think about the fact that `rngs` and `params` live on args
    # and kwargs
    if rngs is None:
      rngs = {}
    if variables is None:
      variables = {'param': {}}
    module = cls(None, *args, **kwargs)  # first argument is `parent` dataclass attr
    new_variables = scope._unfreeze_variables(variables, mutable)
    module._scope = Scope(new_variables, rngs=rngs)
    return module

  def autonamed(self, prefix=''):
    """Return the dict of auto-named submodules recorded under `prefix`.

    The per-prefix dict (and the container holding it) is created on demand.
    """
    if not hasattr(self, '_autonamed'):
      self._autonamed = {}
    return self._autonamed.setdefault(prefix, {})

  def submodules(self):
    """Return the dict mapping submodule name to submodule, creating it on demand."""
    if not hasattr(self, '_submodules'):
      self._submodules = {}
    return self._submodules

  def _autoname(self):
    """Assign `self.name` as `<prefix><ClassName>/<index>` and record it on the parent.

    The index is the number of modules already auto-named under the parent's
    *active* prefix, so each prefix has its own counter and names stay unique
    within it.
    """
    prefix = self.parent._dynamic_autoname_prefix
    siblings = self.parent.autonamed(prefix)
    self.name = prefix + "{}/{}".format(
      self.__class__.__name__,
      str(len(siblings)))
    siblings[self.name] = self

  def _init_scope(self):
    """Name this module if necessary, register it on the parent and push a child scope.

    Raises:
      ValueError: if the module has no parent, or if it has no name and the
        parent is not currently inside an autonaming method.
    """
    if self.parent is None:
      raise ValueError(
        'Trying to create a module instance at the top-level? '
        'Use, e.g. `MyModule.toplevel(...)`')

    if not hasattr(self, 'name') or self.name is None:
      parent_prefix = getattr(self.parent, '_dynamic_autoname_prefix', None)
      if parent_prefix is not None:
        self._autoname()
      else:
        raise ValueError("To use automatically named submodules, wrap your method in `@autonames`.")

    self.parent.submodules()[self.name] = self

    # TODO: Make scopes know of sublists, then don't call
    # push by name here.
    self._scope = self.parent._scope.push(self.name)

  def scope(self):
    """Return this module's scope, initializing it on first access."""
    if not hasattr(self, '_scope'):
      self._init_scope()
    return self._scope

  def variables(self):
    """Return the `variables` attribute of this module's scope."""
    return self.scope().variables

  def param(self, name, init_fun, shape):
    """Forward to `scope().param(name, init_fun, shape)` and return its result."""
    return self.scope().param(name, init_fun, shape)


def _autonames(prefix=''):
  """Build a method decorator that enables automatic submodule naming.

  While the decorated method runs, `self._dynamic_autoname_prefix` is set to
  `prefix`. On every call the record of modules auto-named under `prefix` is
  cleared first, so naming restarts from index 0.

  Args:
    prefix: string prepended to automatically generated names. Each decorated
      method on a given instance must use a distinct prefix.

  Returns:
    A decorator for methods of an object exposing `autonamed(prefix)`.

  Raises:
    ValueError: (at call time) if two different methods of the same instance
      are decorated with the same prefix.
  """
  def _wrap(fun):
    @functools.wraps(fun)
    def wrapped(self, *args, **kwargs):
      if not hasattr(self, '_autonames_fun'):
        self._autonames_fun = {}
      if prefix in self._autonames_fun and self._autonames_fun[prefix] != fun:
        raise ValueError(
          "To use @autonames on more than one method, "
          "you must give each (other than one) a unique prefix "
          "via the `prefix` argument to @autonames.")
      self._autonames_fun[prefix] = fun

      # "Rewind" the autonaming process
      # NOTE: This might be worth documenting; if you store attributes on submodules during a call to, say, __call__()
      # and then modify them from the outside (if __call__ returned the module instances), then the next time
      # that you run that __call__, you will no longer have those attributes, because you're creating new instances of
      # submodules.

      # QUESTION: Should we bring the cursor back to 0 instead of clearing? Then we can re-use the same module instances
      # and it would presumably be faster (during jit time?)
      self.autonamed(prefix).clear()

      # Restore whatever prefix was active before, so that an autonaming method
      # calling another autonaming method on the same instance keeps its own
      # prefix once the inner call returns.
      previous_prefix = getattr(self, '_dynamic_autoname_prefix', None)
      self._dynamic_autoname_prefix = prefix
      try:
        return fun(self, *args, **kwargs)
      finally:
        self._dynamic_autoname_prefix = previous_prefix

    return wrapped
  return _wrap

# `@autonames` uses the empty prefix; `@autonames.prefix('...')` picks a custom one.
autonames = _autonames()
autonames.prefix = _autonames

@dataclass
class Dense(Module):
  """Linear layer computing `dot(x, kernel)`, optionally followed by a bias add.

  The kernel is requested through `self.param` with shape
  `(x.shape[-1], features)`; the bias, when enabled, with shape `(features,)`.
  """
  features: int
  bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  name: str = None

  def __call__(self, x):
    """Apply the layer to `x` and return the result."""
    in_features = x.shape[-1]
    kernel = self.param('kernel', self.kernel_init, (in_features, self.features))
    out = jnp.dot(x, kernel)
    if not self.bias:
      return out
    return out + self.param('bias', self.bias_init, (self.features,))



In [2]:
# Make sure you can call a method twice that has named submodules
@dataclass
class TestCallTwiceWithNamedSubmodule(Module):
  """Module whose `__call__` builds a `Dense` child explicitly named "foo"."""
  name: Optional[str] = None

  def __call__(self, x):
    # A new Dense instance is constructed on every call, always under the
    # same explicit name.
    foo = Dense(self, 3, name="foo")
    return foo(x)
  
# No positional arguments: the only field besides `parent` is `name`, which
# defaults to None.
try_reuse = TestCallTwiceWithNamedSubmodule.toplevel(rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
try_reuse(np.ones((3, 3)))
try_reuse(np.ones((3, 3)))



DeviceArray([[-1.1667678 ,  0.80254143, -1.0405244 ],
             [-1.1667678 ,  0.80254143, -1.0405244 ],
             [-1.1667678 ,  0.80254143, -1.0405244 ]], dtype=float32)

In [3]:
# TODO: Can we make this a better error message?
# Dense(features=3)

In [4]:
# NOTE: It would be nice to make this throw an error,
# but how? I'd like to avoid requiring people to wrap /all/
# methods in a decorator (or the similar metaclass approach with
# hk.transparent).
#
# QUESTION: Can we resolve this by inspecting stack traces
# when constructing modules, or when using them? Only during
# "DEBUG" runs?
@dataclass
class TryReusingByNameCausesError(Module):
  """Creates two `Dense` children with the same explicit name "foo" in one call.

  Both children push the same name onto this module's scope. No error is
  raised for the duplicate name today.
  """
  # Defaults to None like the `name` field of the other modules in this notebook.
  name: Optional[str] = None

  def __call__(self, x):
    """Return the sum of the outputs of the two identically named layers."""
    return Dense(self, 3, name="foo")(x) + Dense(self, 3, name="foo")(x)
  
# The only field besides `parent` is `name`; a top-level module needs no name,
# so pass None explicitly rather than an unrelated integer.
try_reuse = TryReusingByNameCausesError.toplevel(name=None, rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
try_reuse(np.ones((3, 3)))
try_reuse(np.ones((3, 3)))

DeviceArray([[-2.3335357,  1.6050829, -2.0810487],
             [-2.3335357,  1.6050829, -2.0810487],
             [-2.3335357,  1.6050829, -2.0810487]], dtype=float32)

In [5]:
# init: build a top-level Dense layer from an RNG, then call it.
d = Dense.toplevel(
  features=3,
  rngs={'param': jax.random.PRNGKey(0)},
  mutable=['param'],
)
print(d(np.ones((3, 3))))
print(d.variables())

# The same instance can be called a second time.
print(d(np.ones((3, 3))))



[[ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]]
{'param': {'kernel': DeviceArray([[ 0.32717842,  0.05599118,  0.17998298],
             [-0.12294921,  0.7071209 ,  0.28972217],
             [ 0.1373136 , -0.02852853, -0.62830234]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}}
[[ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]]


In [6]:
# apply: rebuild the layer from the variables held by `d` (no RNGs supplied).
d2 = Dense.toplevel(features=3, variables=d.variables())
d2(np.ones((3, 3)))



DeviceArray([[ 0.3415428 ,  0.73458356, -0.15859717],
             [ 0.3415428 ,  0.73458356, -0.15859717],
             [ 0.3415428 ,  0.73458356, -0.15859717]], dtype=float32)

In [7]:
@dataclass
class MLP(Module):
  """Multi-layer perceptron that creates its `Dense` layers inside `__call__`.

  One auto-named `Dense` is built per entry of `widths`; every layer except the
  last is followed by `nn.relu`.
  """
  widths: Tuple
  name: str = None

  @autonames
  def __call__(self, x):
    """Apply all layers to `x` and return the output of the last one."""
    for width in self.widths[:-1]:
      x = nn.relu(Dense(self, width)(x))
    x = Dense(self, self.widths[-1])(x)
    return x

  def params_by_layer_index(self, layer_index):
    """Return the 'param' variables of the `layer_index`-th auto-named layer.

    Layers are indexed in the order they were created during the most recent
    `__call__`.
    """
    return list(self.autonamed().values())[layer_index].variables()['param']

  def params_by_layer_name(self, layer_name):
    """Return the 'param' variables of the submodule registered as `layer_name`."""
    return self.submodules()[layer_name].variables()['param']
    
@dataclass
class Sequential(Module):
  """Applies the callables in `layers` one after another."""
  layers: Tuple[Module, ...]
  name: str = None

  def __call__(self, x):
    """Feed `x` through every layer in order and return the final output."""
    for layer in self.layers:
      x = layer(x)
    return x

@dataclass
class MLP2(Module):
  """MLP variant that builds its `Dense` layers once, in `__post_init__`.

  NOTE(review): as this notebook shows, calling a top-level `MLP2` currently
  raises ValueError ("To use automatically named submodules, ..."). The layers
  are only named when their scope is first requested, which happens during
  `__call__`, after the `@autonames` wrapper around `__post_init__` has
  already reset the naming prefix.
  """
  widths: Tuple
  name: str = None

  # QUESTION: If you implement __init__ do you need to call super.__init__ with parent and
  # name_or_list

  @autonames
  def __post_init__(self):
    """Create one unnamed `Dense` child per entry of `widths`."""
    built = []
    for width in self.widths:
      built.append(Dense(self, width))
    self.layers = built

  def __call__(self, x):
    """Apply every layer but the last with `nn.relu`, then the last one as is."""
    for hidden_layer in self.layers[:-1]:
      x = nn.relu(hidden_layer(x))
    return self.layers[-1](x)
    

@dataclass
class AutoEncoder(Module):
  """Autoencoder whose encoder and decoder layers are created inside the methods.

  `encode` and `decode` are separate autonaming methods with the prefixes
  'encoder:' and 'decoder:'. `encode` stores `x.shape[1:]` in `in_shape`, and
  `decode` reads it to reshape its output.
  """
  encoder_widths: Iterable
  decoder_widths: Iterable
  in_shape: Tuple = None
  name: str = None

  def reconstruct(self, x):
    """Encode `x` and decode the result."""
    code = self.encode(x)
    return self.decode(code)

  @autonames.prefix('encoder:')
  def encode(self, x):
    """Map `x` to its code; all layers but the last are followed by `nn.relu`."""
    # Remember the per-example shape for `decode`.
    self.in_shape = x.shape[1:]
    hidden = x
    for width in self.encoder_widths[:-1]:
      hidden = nn.relu(Dense(self, width)(hidden))
    return Dense(self, self.encoder_widths[-1])(hidden)

  @autonames.prefix('decoder:')
  def decode(self, z):
    """Map a code `z` back, reshaping the trailing axis to `in_shape`."""
    hidden = z
    for width in self.decoder_widths[:-1]:
      hidden = nn.relu(Dense(self, width)(hidden))
    out = Dense(self, self.decoder_widths[-1])(hidden)
    target_shape = out.shape[:-1] + self.in_shape
    return out.reshape(target_shape)

@dataclass
class AutoEncoder2(Module):
  """Autoencoder composed of two explicitly named `MLP2` submodules.

  The encoder is named 'encode' and the decoder 'decode'. If `in_shape` is not
  supplied, it is taken from the first input passed to `encode`.
  """
  encoder_widths: Iterable
  decoder_widths: Iterable
  in_shape: Tuple = None
  name: str = None

  @autonames
  def __post_init__(self):
    """Create the encoder and decoder submodules."""
    # MLP2's fields are (parent, widths, name): pass the widths positionally
    # and the name by keyword so the two cannot be swapped.
    self.encoder = MLP2(self, self.encoder_widths, name='encode')
    self.decoder = MLP2(self, self.decoder_widths, name='decode')

  def encode(self, x):
    """Run the encoder on `x`, recording the per-example input shape if unset."""
    if self.in_shape is None:
      # `decode` needs this to reshape its output; an explicitly provided
      # `in_shape` is left untouched.
      self.in_shape = x.shape[1:]
    return self.encoder(x)

  def decode(self, x):
    """Run the decoder on `x` and reshape the trailing axis to `in_shape`."""
    x = self.decoder(x)
    x = x.reshape(x.shape[:-1] + self.in_shape)
    return x
  
@dataclass
class DenoisingAutoEncoder3(Module):
  """Denoising autoencoder assembled from externally supplied modules.

  `apply_noise` is currently an identity placeholder, so the loss equals the
  plain reconstruction loss.
  """
  encoder: Module
  decoder: Module

  def reconstruction_loss(self, x):
    """Return the loss between `x` and the reconstruction of its noisy version."""
    rng = 'foo'  # placeholder; `apply_noise` ignores it for now
    # Corrupt the input, reconstruct from the corrupted version and compare
    # against the clean input.
    noisy = self.apply_noise(rng, x)
    return self.loss(x, self.reconstruct(noisy))

  def reconstruct(self, x):
    """Encode `x` and decode the result."""
    return self.decoder(self.encoder(x))

  def apply_noise(self, rng, x):
    """Return `x` unchanged (noise is not implemented yet)."""
    return x

  def loss(self, inputs, reconstruction):
    """Mean absolute difference between `inputs` and `reconstruction`."""
    # jnp (not np) so the loss also works on traced values under jit/grad.
    return jnp.mean(jnp.abs(inputs - reconstruction))
    
    

In [8]:
# The collection is called 'param' everywhere else in this notebook (rngs,
# default variables), so the mutable filter must name the same collection.
mlp = MLP.toplevel([3, 4, 5], rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
print(mlp(np.ones((3, 3))))
print(mlp.variables())
 
# QUESTION: Can you point two models to the same parameter object?


[[-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]]
{'param': {'Dense/0': {'kernel': DeviceArray([[ 0.31915382, -0.76709324,  0.07335479],
             [ 0.3674749 ,  0.6602518 , -0.09766117],
             [ 0.84561044,  0.16911158,  0.42713115]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}, 'Dense/1': {'kernel': DeviceArray([[ 0.1547904 ,  0.00701489,  0.05388883,  0.0586841 ],
             [-0.6559397 , -0.5662961 , -0.9021673 ,  0.4098059 ],
             [ 0.07156377,  0.72754294, -0.01657445, -1.0881848 ]],            dtype=float32), 'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)}, 'Dense/2': {'kernel': DeviceArray([[-0.532973  , -0.23218885,  0.38628116,  0.31261072,
              -0.55765074],
             [-0.14478445,  0.83220255, -0.52289605,  0.11223655,
               0.41507402],
             [-0.62332535,  0.1

In [9]:
# Display the variables held by `mlp` as the cell's output value.
mlp.variables()

{'param': {'Dense/0': {'kernel': DeviceArray([[ 0.31915382, -0.76709324,  0.07335479],
                [ 0.3674749 ,  0.6602518 , -0.09766117],
                [ 0.84561044,  0.16911158,  0.42713115]], dtype=float32),
   'bias': DeviceArray([0., 0., 0.], dtype=float32)},
  'Dense/1': {'kernel': DeviceArray([[ 0.1547904 ,  0.00701489,  0.05388883,  0.0586841 ],
                [-0.6559397 , -0.5662961 , -0.9021673 ,  0.4098059 ],
                [ 0.07156377,  0.72754294, -0.01657445, -1.0881848 ]],            dtype=float32),
   'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)},
  'Dense/2': {'kernel': DeviceArray([[-0.532973  , -0.23218885,  0.38628116,  0.31261072,
                 -0.55765074],
                [-0.14478445,  0.83220255, -0.52289605,  0.11223655,
                  0.41507402],
                [-0.62332535,  0.15522319, -0.8609153 ,  0.1192041 ,
                 -0.84271395],
                [ 1.0219636 , -0.0413699 ,  0.39705953, -0.3570713 ,
                 -0.1

In [10]:
# Parameters of the auto-named layer at index 1 (the second one created).
mlp.params_by_layer_index(1)

{'kernel': DeviceArray([[ 0.1547904 ,  0.00701489,  0.05388883,  0.0586841 ],
              [-0.6559397 , -0.5662961 , -0.9021673 ,  0.4098059 ],
              [ 0.07156377,  0.72754294, -0.01657445, -1.0881848 ]],            dtype=float32),
 'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)}

In [11]:
mlp2 = MLP2.toplevel([3, 4, 5], variables=mlp.variables())
try:
  print(mlp2(np.ones((3, 3))))
except ValueError as err:
  # Known limitation: MLP2 creates its layers in `__post_init__`, but they are
  # only named on first use, when no `@autonames` method is active any more.
  # Show the error instead of aborting a full notebook run.
  print("ValueError:", err)



ValueError: To use automatically named submodules, wrap your method in `@autonames`.

In [None]:
# TODO: Make a clear error if you call AutoEncoder(...) without a parent
# NOTE: the collection is named 'param' (singular), matching `rngs`.
ae = AutoEncoder.toplevel(
  encoder_widths=[3, 3], decoder_widths=[3, 3],
  rngs={'param': jax.random.PRNGKey(0)}, mutable=['param']
)
ae.reconstruct(np.ones((3, 3)))

In [None]:

# Should this error? We're connecting submodules of another module into here.
# We should either think carefully about what kind of (both good and bad) behavior
# this may lead to, or, if we're not sure, make it raise an Error.
ae2 = AutoEncoder2.toplevel(
  encoder_widths=[3, 3], decoder_widths=[3, 3],
  rngs={'param': jax.random.PRNGKey(0)}, mutable=['param']
)
dae = DenoisingAutoEncoder3.toplevel(
  encoder=ae2.encoder, decoder=ae2.decoder,
  rngs={'param': jax.random.PRNGKey(0)}, mutable=['param']
)


In [None]:
from jax import jit

# Fixed toy data read as globals by the jitted functions below:
# X is a single input row of ten ones, Y is a length-5 target of ones.
X = np.ones((1, 10))
Y = np.ones((5, ))

@jit
def predict(params):
  """Run a [3, 4, 5] MLP on the global `X`, using `params` as its 'param' collection."""
  # TODO: Think about the fact that you have to put the hyperparameters here
  variables = {'param': params}
  model = MLP.toplevel([3, 4, 5], variables=variables)
  return model(X)
  
@jit
def loss_fn(params):
  """Mean absolute error between the global target `Y` and `predict(params)`."""
  # TODO: Print in jit
  residual = Y - predict(params)
  return jnp.mean(jnp.abs(residual))

@jit
def init_params(rng):
  """Create the 'param' collection of a [3, 4, 5] MLP by running it once on `X`."""
  # TODO: Think about the fact that you have to put the hyperparameters here
  model = MLP.toplevel([3, 4, 5], rngs={'param': rng}, mutable=['param'])
  # The output is discarded; the call is made only so the layers request
  # their parameters.
  model(X)
  return model.variables()['param']



In [None]:
# Loss of freshly initialized parameters (PRNG seed 42).
loss_fn(init_params(jax.random.PRNGKey(42)))

In [None]:
# Gradient of the loss with respect to freshly initialized parameters (PRNG seed 42).
jax.grad(loss_fn)(init_params(jax.random.PRNGKey(42)))

In [None]:
# Plain gradient descent on `loss_fn` for 50 steps.
params = init_params(jax.random.PRNGKey(42))
lr = 0.03
# Build the transformed function once instead of on every iteration.
value_and_grad_fn = jax.value_and_grad(loss_fn)
# `jax.tree_multimap` no longer exists in recent JAX releases, where
# `jax.tree_util.tree_map` accepts several trees; fall back to it when the
# old name is unavailable.
tree_map_fn = getattr(jax, 'tree_multimap', None) or jax.tree_util.tree_map
for i in range(50):
  loss, grad = value_and_grad_fn(params)
  print(i, "loss = ", loss, "Yhat = ", predict(params))
  params = tree_map_fn(lambda x, d: x - lr * d, params, grad)
  