In [285]:
from dataclasses import dataclass
import dataclasses
from typing import Any, Callable, Iterable, Optional, Tuple, Type
from flax import nn
from flax.nn import initializers
from flax.core.scope import Scope
from flax.core import scope
from jax import numpy as jnp
import functools

from flax.core.frozen_dict import freeze

@dataclass
# TODO: Document that any class that extends from Module must add
#   name = Optional[None]
class Module:
  """Prototype base class for modules arranged in a parent/child tree.

  Each module holds a reference to its `parent` and lazily obtains a
  scope by pushing its own name onto the parent's scope. Several
  private attributes (`_scope`, `_autonamed`, `_submodules`,
  `_in_autonames`) are created on demand and probed with `hasattr`.
  """
  # The module this instance is attached to; `toplevel` passes `None` here.
  # NOTE(review): the annotation says `Type["Module"]` (a class), but the code
  # in this notebook passes module *instances* — confirm the intended type.
  parent: Optional[Type["Module"]]
  
  # TODO: Use Dataclass "hidden" attributes that don't appear on __init__.
  # Then remove all use of hasattr
  
  @classmethod
  def toplevel(cls, *args, rngs=None, variables=None, mutable=False, **kwargs):
    """Builds a root module (parent is `None`) with its own `Scope`.

    Args:
      *args, **kwargs: forwarded to the class constructor after `parent`.
      rngs: passed to `Scope` as `rngs`; defaults to an empty dict.
      variables: passed to `Scope`; defaults to `{'param': {}}`.
      mutable: forwarded to `scope._unfreeze_variables` together with
        `variables`.

    Returns:
      The constructed module with `_scope` set.
    """
    # TODO: Think about the fact that `rngs` and `params` live on args
    # and kwargs
    if rngs is None:
      rngs = {}
    if variables is None:
      variables = {'param': {}}
    module = cls(None, *args, **kwargs)  # first argument is `parent` dataclass attr
    module._scope = Scope(variables, rngs=rngs)    

    # QUESTION: Make sure to unfreeze after calling _recurse so that you don't need
    # to set params as mutable during construction time...?
    # The scope's variables are replaced by the result of `_unfreeze_variables`.
    new_variables = scope._unfreeze_variables(variables, mutable)
    module._scope.variables = new_variables

    return module

  def on_attached(self):
    """Hook intended for subclasses; a no-op here.

    NOTE(review): nothing in this class calls `on_attached` — confirm where
    it is expected to be invoked.
    """
    pass
  
  def autonamed(self):
    """Returns the dict of auto-named children, creating it on first use."""
    if not hasattr(self, '_autonamed'):
      self._autonamed = {}
    return self._autonamed

  def submodules(self):
    """Returns the dict of children registered by `_init_scope`, creating it on first use."""
    if not hasattr(self, '_submodules'):
      self._submodules = {}
    return self._submodules

  def _ensure_has_name(self):
    """Assigns an automatic name if `name` is missing or `None`.

    The generated name is `"<ClassName>/<k>"`, where `k` is the number of
    children the parent has auto-named so far; the module is then recorded
    in `parent.autonamed()`.

    Raises:
      ValueError: if a name is needed but the parent's `_in_autonames` flag
        is absent or falsy.
    """
    if not hasattr(self, 'name') or self.name is None:
      if not hasattr(self.parent, '_in_autonames') or not self.parent._in_autonames:
        raise ValueError("In order to get autonames, must decorate method with @autonames")

      self.name = "{}/{}".format(
        self.__class__.__name__, 
        str(len(self.parent.autonamed())))
      self.parent.autonamed()[self.name] = self
  
  def __setattr__(self, name, value):
    """Sets the attribute, then ensures any assigned modules have names.

    For every attribute except `parent`, a `Module` value, or `Module`s
    found inside (nested) lists, gets `_ensure_has_name()` called on it.
    Other containers (tuples, dicts, ...) are not inspected.
    """
    # GOTCHA: This is very brittle.
    super().__setattr__(name, value)

    if name != 'parent':
      def _recurse(x):
        if isinstance(x, Module):
          x._ensure_has_name()
        elif isinstance(x, list):
          # TODO: Make this work on iterables?
          for submodule in x:
            _recurse(submodule)
        # TODO: Also support dicts?

      _recurse(value)

  
  def _init_scope(self):
    """Creates `_scope` by pushing this module's name onto the parent's scope.

    Also registers the module in `parent.submodules()` and initializes the
    parent's scope first if it does not have one yet.

    Raises:
      ValueError: if `parent` is `None`.
    """
    if self.parent is None:
      # NOTE: This error also happens if you try to initialize parameters
      # during __post_init__. Try to catch this.
      raise ValueError(
        'Trying to create a module instance at the top-level? '
        'Use, e.g. `MyModule.toplevel(...)`')

    self._ensure_has_name()

    self.parent.submodules()[self.name] = self    

    if not hasattr(self.parent, '_scope'):
      self.parent._init_scope()

    # TODO: Make scopes know of sublists, then don't call
    # push by name here.
    self._scope = self.parent._scope.push(self.name)
    
  def scope(self):
    """Returns this module's scope, initializing it on first access."""
    if not hasattr(self, '_scope'):
      self._init_scope()
    return self._scope
    
  def variables(self):
    """Returns the `variables` attribute of this module's scope."""
    return self.scope().variables

  def param(self, name, init_fun, shape):
    """Delegates to `scope().param(name, init_fun, shape)` and returns its result."""
    return self.scope().param(name, init_fun, shape)


def autonames(fun, prefix=''):
  """Decorator that enables automatic child naming while `fun` runs.

  While the wrapped method executes, `self._in_autonames` is `True`, and
  `self.autonamed()` is cleared at the start of every call so children
  receive the same automatic names on each invocation.

  Args:
    fun: the method to wrap; called as `fun(self, *args, **kwargs)`.
    prefix: currently unused.

  Returns:
    The wrapped method.

  Raises:
    ValueError: (from the wrapped method) if a different method of the same
      instance was already used with `autonames`, or if an
      `autonames`-decorated call is made while another is in progress on the
      same instance.
  """
  @functools.wraps(fun)
  def wrapped(self, *args, **kwargs):
    # Validate before mutating any state, so a rejected call leaves the
    # instance (in particular the in-progress autonamed dict) untouched.
    if hasattr(self, '_autonames_fun') and self._autonames_fun != fun:
      raise ValueError(
        "Can only use @autonames on one method. "
        "Use @method_autonames for additional autonaming scopes.")
    if getattr(self, '_in_autonames', False):
      raise ValueError("Can't nest `autonames`-decorated function calls")

    self._autonames_fun = fun

    # "Rewind" the autonaming process
    self.autonamed().clear()

    self._in_autonames = True
    try:
      return fun(self, *args, **kwargs)
    finally:
      # Always reset the flag, even when `fun` raises.
      self._in_autonames = False

  return wrapped

@dataclass
class Dense(Module):
  """Linear layer computing `dot(x, kernel)` with an optional additive bias.

  The kernel parameter has shape `(x.shape[-1], features)` and the bias
  parameter has shape `(features,)`; both are requested via `self.param`
  each time the layer is called.
  """
  features: int
  bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  name: str = None

  def __call__(self, x):
    """Applies the layer to `x` along its last axis and returns the result."""
    input_dim = x.shape[-1]
    weights = self.param('kernel', self.kernel_init, (input_dim, self.features))
    projected = jnp.dot(x, weights)
    if not self.bias:
      return projected
    return projected + self.param('bias', self.bias_init, (self.features,))



In [286]:
# TODO: Can we make this a better error message?
# Dense(3)

In [287]:
# NOTE: It would be nice to make this throw an error,
# but how? I'd like to avoid requiring people to wrap /all/
# methods in a decorator (or the similar metaclass approach with
# hk.transparent).
#
# QUESTION: Can we resolve this by inspecting stack traces
# when constructing modules, or when using them? Only during
# "DEBUG" runs.
@dataclass
class TryReusingByNameCausesError(Module):
  """Demo module that builds two `Dense` children with the same explicit name."""
  name: Optional[str]

  def __call__(self, x):
    """Returns the sum of two `Dense(3)` layers, both named "foo", applied to `x`."""
    first_out = Dense(self, 3, name="foo")(x)
    second_out = Dense(self, 3, name="foo")(x)
    return first_out + second_out
  
# NOTE(review): `jax` and `np` are only imported in a later cell (In [288]),
# so on a fresh kernel this cell fails unless that cell has run first.
# The positional `3` lands in the `name` field (the first field after
# `parent`) — NOTE(review): confirm this is intended.
try_reuse = TryReusingByNameCausesError.toplevel(3, rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
# Call the module twice on the same input; neither call raises, even though
# both Dense children share the name "foo".
try_reuse(np.ones((3, 3)))
try_reuse(np.ones((3, 3)))

DeviceArray([[-2.3335357,  1.6050829, -2.0810487],
             [-2.3335357,  1.6050829, -2.0810487],
             [-2.3335357,  1.6050829, -2.0810487]], dtype=float32)

In [288]:
import numpy as np
import jax

# init: build a root-level Dense with 3 output features, supplying an RNG for
# the 'param' collection and passing 'param' as the mutable collection.
d = Dense.toplevel(3, rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
print(d(np.ones((3, 3))))
# Show the variables held by the module's scope after the first call.
print(d.variables())

# Can call method twice on the same instance.
print(d(np.ones((3, 3))))



[[ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]]
{'param': {'kernel': DeviceArray([[ 0.32717842,  0.05599118,  0.17998298],
             [-0.12294921,  0.7071209 ,  0.28972217],
             [ 0.1373136 , -0.02852853, -0.62830234]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}}
[[ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]
 [ 0.3415428   0.73458356 -0.15859717]]


In [281]:
# apply: build a second Dense from the variables of `d` (no rngs, default
# `mutable=False`) and evaluate it on the same input.
# NOTE(review): this cell's execution count (In [281]) is older than the cell
# that defines `d` (In [288]); the recorded output may be stale.
d2 = Dense.toplevel(3, variables=d.variables())
d2(np.ones((3, 3)))



setattr features 3
setattr bias True
setattr kernel_init <function variance_scaling.<locals>.init at 0x132853cb0>
setattr bias_init <function zeros at 0x1118000e0>
setattr name None
setattr _scope <flax.core.scope.Scope object at 0x1342fb190>


DeviceArray([[ 0.3415428 ,  0.73458356, -0.15859717],
             [ 0.3415428 ,  0.73458356, -0.15859717],
             [ 0.3415428 ,  0.73458356, -0.15859717]], dtype=float32)

In [290]:
@dataclass
class MLP(Module):
  """Multi-layer perceptron whose `Dense` children are created inside `__call__`.

  `widths` gives the output size of each layer; every layer except the last
  is followed by a ReLU. Children are auto-named via `@autonames`.
  """
  widths: Tuple
  name: str = None

  @autonames
  def __call__(self, x):
    """Applies the stack of Dense layers to `x` and returns the final output."""
    for width in self.widths[:-1]:
      x = nn.relu(Dense(self, width)(x))
    x = Dense(self, self.widths[-1])(x)
    return x

  def params_by_layer_index(self, layer_index):
    """Returns the 'param' variables of the `layer_index`-th auto-named child.

    Indexing follows the insertion order of `self.autonamed()`.
    """
    return list(self.autonamed().values())[layer_index].variables()['param']
  
  def params_by_layer_name(self, layer_name):
    """Returns the 'param' variables of the child registered under `layer_name`.

    The lookup uses `self.submodules()`. The original signature omitted
    `self`, so the method could not be called on an instance.
    """
    return self.submodules()[layer_name].variables()['param']
    
@dataclass
class Sequential(Module):
  """Applies a fixed sequence of layers one after another."""
  # NOTE(review): `Module.__setattr__` only inspects lists when auto-naming,
  # so modules supplied in a tuple are not visited — confirm callers pass
  # already-named layers.
  layers: Tuple[Module]
  name: str = None

  def __call__(self, x):
    """Feeds `x` through every layer in `self.layers`, in order."""
    # The original iterated over a bare `layers` name, which is undefined in
    # this scope; the field lives on the instance.
    for layer in self.layers:
      x = layer(x)
    return x

@dataclass
class MLP2(Module):
  """MLP variant that builds its `Dense` children once, in `__post_init__`."""
  widths: Tuple
  name: str = None

  # QUESTION: If you implement __init__ do you need to call super.__init__ with parent and
  # name_or_list
  
  @autonames
  def __post_init__(self):
    """Creates one Dense layer per entry of `widths` and stores them in `self.layers`."""
    dense_layers = []
    for width in self.widths:
      dense_layers.append(Dense(self, width))
    # Assign the complete list in one go: the attribute assignment is what
    # triggers auto-naming of the children.
    self.layers = dense_layers
    
  def __call__(self, x):
    """Applies all layers to `x`, with a ReLU after every layer but the last."""
    hidden_layers = self.layers[:-1]
    for hidden in hidden_layers:
      x = nn.relu(hidden(x))
    output_layer = self.layers[-1]
    return output_layer(x)
    

@dataclass
class AutoEncoder(Module):
  """Autoencoder whose encoder/decoder are nested module classes.

  The nested `Encoder` and `Decoder` read their layer widths from the
  parent (`self.parent.encoder_widths` / `self.parent.decoder_widths`).
  Both width collections are sliced and indexed, so they must be sequences.
  """
  encoder_widths: Iterable
  decoder_widths: Iterable
  # NOTE(review): this field is never written here; the input shape is
  # stored on the Encoder instance instead — confirm whether it is needed.
  in_shape: Tuple = None
  name: str = None

  # QUESTION: Should only one method be allowed to be wrapped
  # in `autonames`?

  def __post_init__(self):
    """Creates the explicitly named 'encoder' and 'decoder' children."""
    self.encoder = self.Encoder(self, 'encoder')
    self.decoder = self.Decoder(self, 'decoder')
  
  def reconstruct(self, x):
    """Encodes `x` and decodes the result."""
    return self.decoder(self.encoder(x))
  
  @dataclass
  class Encoder(Module):
    """Stack of Dense layers sized by the parent's `encoder_widths`."""
    name: str
    
    @autonames
    def __call__(self, x):
      """Records `x.shape[1:]` as `in_shape` and returns the encoding of `x`."""
      self.in_shape = x.shape[1:]
      # QUESTION: Is this a legitimate use of `self.parent`?
      for width in self.parent.encoder_widths[:-1]:
        x = nn.relu(Dense(self, width)(x))
      z = Dense(self, self.parent.encoder_widths[-1])(x)
      return z
  
  @dataclass
  class Decoder(Module):
    """Stack of Dense layers sized by the parent's `decoder_widths`."""
    name: str
    
    @autonames
    def __call__(self, z):
      """Decodes `z` and reshapes the output using the encoder's `in_shape`."""
      for width in self.parent.decoder_widths[:-1]:
        z = nn.relu(Dense(self, width)(z))
      # The final layer takes its width from `decoder_widths`; the original
      # used `encoder_widths[-1]` (the latent size), which only worked when
      # the two happened to be equal.
      x = Dense(self, self.parent.decoder_widths[-1])(z)
      # QUESTION: Is this weird? Navigating up then into encoder?
      x = x.reshape(x.shape[:-1] + self.parent.encoder.in_shape)
      return x

@dataclass
class AutoEncoder1_5(Module):
  """Autoencoder variant that defines its encoder/decoder classes inside methods.

  `encode` and `decode` each declare a local module class that closes over
  the autoencoder instance (`_ae`) to read the layer widths, then
  instantiate and call it. Both width collections are sliced and indexed,
  so they must be sequences.
  """
  encoder_widths: Iterable
  decoder_widths: Iterable
  in_shape: Tuple = None
  name: str = None

  # QUESTION: Should only one method be allowed to be wrapped
  # in `autonames`?

  def encode(self, x):
    """Records `x.shape[1:]` in `self.in_shape` and returns the encoding of `x`."""
    _ae = self
    @dataclass
    class Encoder(Module):
      name: str = 'encoder'

      @autonames
      def __call__(self, x):
        # QUESTION: Is this a legitimate use of `self.parent`?
        for width in _ae.encoder_widths[:-1]:
          x = nn.relu(Dense(self, width)(x))
        z = Dense(self, _ae.encoder_widths[-1])(x)
        return z
      
    self.in_shape = x.shape[1:]
    return Encoder(self)(x)
    
  def decode(self, x):
    """Decodes `x` and reshapes the output using the `in_shape` stored by `encode`."""
    _ae = self
    @dataclass
    class Decoder(Module):
      name: str = 'decode'

      @autonames
      def __call__(self, z):
        for width in _ae.decoder_widths[:-1]:
          z = nn.relu(Dense(self, width)(z))
        # The final layer takes its width from `decoder_widths`; the original
        # used `encoder_widths[-1]` (the latent size).
        x = Dense(self, _ae.decoder_widths[-1])(z)
        # QUESTION: Is this weird? Navigating up then into encoder?
        x = x.reshape(x.shape[:-1] + _ae.in_shape)
        return x

    return Decoder(self)(x)

  def reconstruct(self, x):
    """Encodes `x` and decodes the result."""
    return self.decode(self.encode(x))
  
  
  # NOTE(review): this nested class is not used by any method above (`decode`
  # defines its own local `Decoder`), and it reads `self.parent.encoder`,
  # which this class never sets. It looks like a leftover copy from
  # `AutoEncoder`; kept only so `AutoEncoder1_5.Decoder` still resolves.
  @dataclass
  class Decoder(Module):
    name: str
    
    @autonames
    def __call__(self, z):
      for width in self.parent.decoder_widths[:-1]:
        z = nn.relu(Dense(self, width)(z))
      x = Dense(self, self.parent.decoder_widths[-1])(z)
      # QUESTION: Is this weird? Navigating up then into encoder?
      x = x.reshape(x.shape[:-1] + self.parent.encoder.in_shape)
      return x

    
@dataclass
class AutoEncoder2(Module):
  """Autoencoder composed of two `MLP2` children named 'encode' and 'decode'."""
  encoder_widths: Iterable
  decoder_widths: Iterable
  in_shape: Tuple = None
  name: str = None
    
  def __post_init__(self):
    """Builds the encoder and decoder MLPs from the configured widths."""
    self.encoder = MLP2(self, self.encoder_widths, name='encode')
    self.decoder = MLP2(self, self.decoder_widths, name='decode')
    
  def encode(self, x):
    """Stores `x.shape[1:]` in `self.in_shape` and returns the encoder output."""
    self.in_shape = x.shape[1:]
    latent = self.encoder(x)
    return latent
  
  def decode(self, x):
    """Runs the decoder and reshapes its output to end with `self.in_shape`."""
    decoded = self.decoder(x)
    target_shape = decoded.shape[:-1] + self.in_shape
    return decoded.reshape(target_shape)
  
  def reconstruct(self, x):
    """Encodes `x` and decodes the result."""
    latent = self.encode(x)
    return self.decode(latent)
  
@dataclass
class DAE(Module):
  """Denoising-autoencoder wrapper around externally supplied encoder/decoder modules."""
  encoder: Module
  decoder: Module
  
  def reconstruction_loss(self, x):
    """Returns the loss between `x` and the reconstruction of its noised version.

    The input is corrupted with `apply_noise`, reconstructed, and compared
    against the clean `x`. The original compared the noised input against a
    reconstruction of the clean input, i.e. the two roles were swapped; with
    the identity `apply_noise` below both orderings give the same value.
    """
    rng = 'foo'  # placeholder; `apply_noise` ignores it in this class
    noisy_x = self.apply_noise(rng, x)
    return self.loss(x, self.reconstruct(noisy_x))

  def reconstruct(self, x):
    """Encodes `x` and decodes the result."""
    return self.decoder(self.encoder(x))
  
  def apply_noise(self, rng, x):
    """Returns `x` unchanged; override to inject noise using `rng`."""
    return x

  def loss(self, inputs, reconstruction):
    """Mean absolute error between `inputs` and `reconstruction`.

    Uses `jnp` rather than `np` so the loss can be traced by `jax.jit` /
    `jax.grad`, like `loss_fn` elsewhere in this notebook.
    """
    return jnp.mean(jnp.abs(inputs - reconstruction))
    
    

In [291]:
# Build a root-level MLP with layer widths 3, 4, 5 and run it once.
# The mutable collection is 'param' — the key used by `Module.toplevel`'s
# default variables and by the other cells — not 'params'.
mlp = MLP.toplevel([3, 4, 5], rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
print(mlp(np.ones((3, 3))))
print(mlp.variables())
 
# QUESTION: Can you point two models to the same parameter object?


[[-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]]
{'param': {'Dense/0': {'kernel': DeviceArray([[ 0.31915382, -0.76709324,  0.07335479],
             [ 0.3674749 ,  0.6602518 , -0.09766117],
             [ 0.84561044,  0.16911158,  0.42713115]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}, 'Dense/1': {'kernel': DeviceArray([[ 0.1547904 ,  0.00701489,  0.05388883,  0.0586841 ],
             [-0.6559397 , -0.5662961 , -0.9021673 ,  0.4098059 ],
             [ 0.07156377,  0.72754294, -0.01657445, -1.0881848 ]],            dtype=float32), 'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)}, 'Dense/2': {'kernel': DeviceArray([[-0.532973  , -0.23218885,  0.38628116,  0.31261072,
              -0.55765074],
             [-0.14478445,  0.83220255, -0.52289605,  0.11223655,
               0.41507402],
             [-0.62332535,  0.1

In [284]:
# Same as the MLP demo, but with MLP2, which creates its layers in
# `__post_init__`. The mutable collection is 'param' (not 'params'), matching
# the key used by `Module.toplevel`'s default variables.
mlp2 = MLP2.toplevel([3, 4, 5], rngs={'param': jax.random.PRNGKey(0)}, mutable=['param'])
print(mlp2(np.ones((3, 3))))
print(mlp2.variables())
 
# QUESTION: Can you point two models to the same parameter object?


setattr widths [3, 4, 5]
setattr name None
setattr _autonames_fun <function MLP2.__post_init__ at 0x132885680>
setattr _autonamed {}
setattr _in_autonames True
setattr features 3
setattr bias True
setattr kernel_init <function variance_scaling.<locals>.init at 0x132853cb0>
setattr bias_init <function zeros at 0x1118000e0>
setattr name None
setattr features 4
setattr bias True
setattr kernel_init <function variance_scaling.<locals>.init at 0x132853cb0>
setattr bias_init <function zeros at 0x1118000e0>
setattr name None
setattr features 5
setattr bias True
setattr kernel_init <function variance_scaling.<locals>.init at 0x132853cb0>
setattr bias_init <function zeros at 0x1118000e0>
setattr name None
setattr layers [Dense(parent=MLP2(parent=None, widths=[3, 4, 5], name=None), features=3, bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x132853cb0>, bias_init=<function zeros at 0x1118000e0>, name=None), Dense(parent=MLP2(parent=None, widths=[3, 4, 5], name=None), features

In [276]:
# Display the variables currently held by `mlp`'s scope.
# NOTE(review): this cell's execution count (In [276]) is older than the cell
# that builds `mlp` (In [291]); the recorded output may be stale.
mlp.variables()

{'param': {}}

In [93]:
# Display the 'param' variables of the second auto-named layer (index 1).
mlp.params_by_layer_index(1)

{'kernel': DeviceArray([[ 0.1547904 ,  0.00701489,  0.05388883,  0.0586841 ],
              [-0.6559397 , -0.5662961 , -0.9021673 ,  0.4098059 ],
              [ 0.07156377,  0.72754294, -0.01657445, -1.0881848 ]],            dtype=float32),
 'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)}

In [94]:
# Build an MLP2 from the variables of `mlp` (no rngs supplied) and evaluate it
# on the same input.
mlp2 = MLP2.toplevel([3, 4, 5], variables=mlp.variables())
print(mlp2(np.ones((3, 3))))



[[-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]
 [-0.17117555  0.17427535 -0.070427    0.10287903 -0.03070323]]


In [95]:
# TODO: Make a clear error if you call AutoEncoder(...) without a parent
# The mutable collection is 'param' (not 'params'), matching the key used by
# `Module.toplevel`'s default variables and by the other cells.
ae = AutoEncoder.toplevel(
  encoder_widths=[3, 3], decoder_widths=[3, 3],
  rngs={'param': jax.random.PRNGKey(1)}, mutable=['param']
)
ae.reconstruct(np.ones((4, 3)))

DeviceArray([[-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634]], dtype=float32)

In [96]:
# TODO: Make a clear error if you call AutoEncoder2(...) without a parent
# The mutable collection is 'param' (not 'params'), matching the key used by
# `Module.toplevel`'s default variables and by the other cells.
ae2 = AutoEncoder2.toplevel(
  encoder_widths=[3, 3], decoder_widths=[3, 3],
  rngs={'param': jax.random.PRNGKey(1)}, mutable=['param']
)
ae2.reconstruct(np.ones((4, 3)))

DeviceArray([[ 0.0273236 , -0.30911958, -0.9112693 ],
             [ 0.0273236 , -0.30911958, -0.9112693 ],
             [ 0.0273236 , -0.30911958, -0.9112693 ],
             [ 0.0273236 , -0.30911958, -0.9112693 ]], dtype=float32)

In [103]:
# TODO: Make a clear error if you call AutoEncoder1_5(...) without a parent
# The mutable collection is 'param' (not 'params'), matching the key used by
# `Module.toplevel`'s default variables and by the other cells.
ae1_5 = AutoEncoder1_5.toplevel(
  encoder_widths=[3, 3], decoder_widths=[3, 3],
  rngs={'param': jax.random.PRNGKey(1)}, mutable=['param']
)
ae1_5.reconstruct(np.ones((4, 3)))

DeviceArray([[-0.0291461 , -1.0122029 , -0.33748102],
             [-0.0291461 , -1.0122029 , -0.33748102],
             [-0.0291461 , -1.0122029 , -0.33748102],
             [-0.0291461 , -1.0122029 , -0.33748102]], dtype=float32)

In [104]:

# QUESTION: Does this work?!
# Should this error? We're connecting submodules of another module into here.
# We should either think carefully about what kind of (both good and bad)
# behavior this may lead to, or, if we're not sure, make it raise an Error.
#
# `ae.encoder` and `ae.decoder` keep `ae` as their parent; they are only
# stored as fields on the new root-level DAE.
dae = DAE.toplevel(
  encoder=ae.encoder, decoder=ae.decoder,
  rngs={'param': jax.random.PRNGKey(0)}, mutable=['param']
)
dae.reconstruct(np.ones((4, 3)))


DeviceArray([[-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634],
             [-0.01497263, -0.805205  , -0.94335634]], dtype=float32)

In [105]:
from jax import jit

# Toy regression data used by `predict`, `loss_fn` and `init_params` below:
# a single input row of 10 ones and a target vector of 5 ones.
X = np.ones((1, 10))
Y = np.ones((5, ))

@jit
def predict(params):
  """Evaluates a [3, 4, 5] MLP on the global `X` using `params` as its 'param' collection."""
  # TODO: Think about the fact that you have to put the hyperparameters here
  return MLP.toplevel([3, 4, 5], variables={'param': params})(X)
  
@jit
def loss_fn(params):
  """Mean absolute error between the global target `Y` and `predict(params)`."""
  # TODO: Print in jit
  residual = Y - predict(params)
  return jnp.mean(jnp.abs(residual))

@jit
def init_params(rng):
  """Creates a [3, 4, 5] MLP, runs it once on `X`, and returns its 'param' collection."""
  # TODO: Think about the fact that you have to put the hyperparameters here
  model = MLP.toplevel([3, 4, 5], rngs={'param': rng}, mutable=['param'])
  # The forward pass is what creates the parameters; its output is discarded.
  model(X)
  initialized_variables = model.variables()
  return initialized_variables['param']



In [106]:
# Loss of a freshly initialized model (PRNG seed 42).
loss_fn(init_params(jax.random.PRNGKey(42)))

DeviceArray(1.4570823, dtype=float32)

In [107]:
# Gradient of the loss with respect to the freshly initialized parameters.
jax.grad(loss_fn)(init_params(jax.random.PRNGKey(42)))

{'Dense/0': {'bias': DeviceArray([0.25291714, 0.        , 0.        ], dtype=float32),
  'kernel': DeviceArray([[0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ],
               [0.25291714, 0.        , 0.        ]], dtype=float32)},
 'Dense/1': {'bias': DeviceArray([ 0.04557207, -0.20579326,  0.5184991 ,  0.        ], dtype=float32),
  'kernel': DeviceArray([[ 0.08235972, -0.37191805,  0.9370529 ,  0.        ],
               [ 0.        , -0.        ,  0.        ,  0.        ],
               [ 0.        , -0.        ,  0.        ,  0.        ]],            dtype=float32)},
 'Dens

In [108]:
# Plain gradient descent on `loss_fn` for 50 steps.
params = init_params(jax.random.PRNGKey(42))
lr = 0.03
# Build the value-and-gradient function once instead of on every iteration.
loss_and_grad = jax.value_and_grad(loss_fn)
for i in range(50):
  loss, grad = loss_and_grad(params)
  print(i, "loss = ", loss, "Yhat = ", predict(params))
  # `jax.tree_multimap` has been removed from JAX; `jax.tree_util.tree_map`
  # accepts several trees and applies the function leaf-wise.
  params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad)
  

0 loss =  1.4570823 Yhat =  [[-0.65048295 -0.89209783  0.22747914 -0.8393059  -0.13100383]]
1 loss =  1.3724773 Yhat =  [[-0.58094865 -0.7692777   0.25747335 -0.68587863 -0.08375458]]
2 loss =  1.2976424 Yhat =  [[-0.5199375  -0.6605372   0.28393385 -0.5494852  -0.04218574]]
3 loss =  1.2308433 Yhat =  [[-0.46609223 -0.5635818   0.30782127 -0.42721057 -0.00515322]]
4 loss =  1.1706475 Yhat =  [[-0.41830716 -0.47650898  0.32992083 -0.31662038  0.02827771]]
5 loss =  1.1158575 Yhat =  [[-0.37567115 -0.39772335  0.3508847  -0.21565796  0.058881  ]]
6 loss =  1.0654583 Yhat =  [[-0.3374258  -0.32587245  0.37126485 -0.1225654   0.08730697]]
7 loss =  1.01858 Yhat =  [[-0.3029324  -0.2597958   0.39153802 -0.03582046  0.11411095]]
8 loss =  0.9744629 Yhat =  [[-0.27164698 -0.19848418  0.41212633  0.04591412  0.13977619]]
9 loss =  0.93243355 Yhat =  [[-0.24309914 -0.14104642  0.4334131   0.12383318  0.1647319 ]]
10 loss =  0.89188176 Yhat =  [[-0.21687557 -0.08668173  0.4557565   0.19902366  

In [216]:
@dataclass
class DenseExplicit(Module):
  """Dense layer with explicit input/output sizes whose parameters are created in `on_attached`."""
  in_features: int
  out_features: int
  with_bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  name: str = None

  def on_attached(self):
    """Creates `self.kernel` and, when `with_bias` is set, `self.bias`."""
    kernel_shape = (self.in_features, self.out_features)
    self.kernel = self.param('kernel', self.kernel_init, kernel_shape)

    if not self.with_bias:
      return
    self.bias = self.param('bias', self.bias_init, (self.out_features,))
  
  def __call__(self, x):
    """Returns `dot(x, kernel)`, plus the bias when `with_bias` is set."""
    projected = jnp.dot(x, self.kernel)
    if not self.with_bias:
      return projected
    return projected + self.bias


In [217]:
# Build a root-level DenseExplicit (3 -> 3) and apply it twice to the same
# vector of ones.
# NOTE(review): nothing in the `Module` class shown in this notebook calls
# `on_attached`, which is where `kernel`/`bias` are created — confirm which
# `Module` version produced the recorded output.
dense_expl = DenseExplicit.toplevel(
  in_features=3, out_features=3,
  # TODO: This should have required 
  # (the original note is truncated; presumably `mutable=['param']` — confirm)
  rngs={'param': jax.random.PRNGKey(1)}
)
print(dense_expl(np.ones((3, ))))
print(dense_expl(np.ones((3, ))))


[-0.6191777  0.8351118 -0.5551028]
[-0.6191777  0.8351118 -0.5551028]


In [221]:
class MLPExplicit(Module):
  """MLP built from `DenseExplicit` layers with a hand-written `__init__`."""

  def __init__(self, parent, features):
    """Creates one `DenseExplicit` per consecutive pair in `features`."""
    self.parent = parent
    layer_dims = zip(features, features[1:])
    self.layers = [
      DenseExplicit(self, fan_in, fan_out)
      for fan_in, fan_out in layer_dims
    ]

  def on_attached(self):
    """Forwards the `on_attached` hook to every layer."""
    # NOTE: We can fix the need for this by changing __init__ to:
    # self.layers = self.module_list([...]) which would register the
    # module lists and then loop over those
    for layer in self.layers:
      layer.on_attached()

  def __call__(self, x):
    """Applies all layers to `x`, with a ReLU after every layer but the last."""
    hidden_layers = self.layers[:-1]
    for hidden in hidden_layers:
      x = nn.relu(hidden(x))
    output_layer = self.layers[-1]
    return output_layer(x)

In [222]:
# Build a root-level MLPExplicit with layer sizes 3 -> 4 -> 5 -> 6 and apply
# it twice to the same input.
mlp_expl = MLPExplicit.toplevel(
  features=[3, 4, 5, 6],
  rngs={'param': jax.random.PRNGKey(1)}
)
print(mlp_expl(np.ones((3, 3))))
print(mlp_expl(np.ones((3, 3))))

[[-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]
 [-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]
 [-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]]
[[-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]
 [-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]
 [-0.3998984   0.47382313 -0.79696596  0.68610305 -0.5683795   0.09743354]]


In [225]:
# Display the bias stored on the second layer of `mlp_expl`.
mlp_expl.layers[1].bias


DeviceArray([0., 0., 0., 0., 0.], dtype=float32)

In [None]:
def std_weight(module):
  """Returns a module class that applies `module` with a transformed kernel.

  Sketch (this cell has no execution count): the returned `StdWeight` reads
  `module`'s 'param' collection, replaces its 'kernel' entry with
  `std(kernel)`, and evaluates a fresh root-level copy of `module` on the
  modified parameters.

  NOTE(review): `std` is not defined anywhere in this notebook; it must be
  provided before `StdWeight` can be called.
  """
  @dataclass
  class StdWeight(Module):
    initialized: bool = False
    
    def __call__(self, x):
      # `Module` has no `params()` method; check the wrapped module's own
      # 'param' collection, which is what is read right below.
      if not module.variables().get('param'):
        # initialize parameters
        module(x)
      
      # `variables` is a method on `Module`, and collections are keyed by name.
      param = module.variables()['param']
      # TODO: Test that I would get an error if I directly modified `param`
      # Build a new mapping instead of mutating `param` in place.
      param = {**param, 'kernel': std(param['kernel'])}

      def with_vars(variables):
        # QUESTION: Can `with_vars` be implemented without assuming
        # that modules are dataclasses?
        # `toplevel` supplies `parent=None` itself, so `parent` must not be
        # forwarded; reading fields directly also avoids `asdict` recursing
        # into nested dataclasses.
        hyperparams = {
          field.name: getattr(module, field.name)
          for field in dataclasses.fields(module)
          if field.name != 'parent'
        }
        # Return the rebuilt module so the caller can apply it.
        return module.__class__.toplevel(**hyperparams, variables=variables)
      return with_vars({'param': param})(x)
  return StdWeight