In [1]:
import jax
from jax import lax, random, numpy as jnp
import numpy as np
import flax
from flax import nn
import dataclasses

from typing import Any, Optional

In [5]:
class Dense(nn.base.Module):
  """Fully connected layer that contracts the last axis of its input with a kernel.

  Computes `inputs @ kernel` with `lax.dot_general` and, when `bias` is truthy,
  adds a bias vector of length `features`.

  Args:
    features: size of the output (last) axis.
    bias: whether to create and add a 'bias' parameter.
    dtype: dtype that inputs and parameters are cast to before computing.
    precision: forwarded unchanged to `lax.dot_general`.
    kernel_init: initializer handed to `self.param` for the kernel.
    bias_init: initializer handed to `self.param` for the bias.
    name: module name (Test1 passes name='foo' and the resulting params are
      keyed under 'foo').
  """

  def __init__(self,
               features,
               bias=True,
               dtype=jnp.float32,
               precision=None,
               kernel_init=nn.linear.default_kernel_init,
               bias_init=nn.initializers.zeros,
               name=None):
    # Configuration only; no parameters are created until `apply` runs.
    # Assignment order is kept as-is in case the Module base observes it.
    self.features = features
    self.bias = bias
    self.dtype = dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.name = name

  def apply(self, inputs):
    """Project the last axis of `inputs` to `features` units (plus bias if enabled)."""
    x = jnp.asarray(inputs, self.dtype)
    in_features = x.shape[-1]
    # Parameter creation order (kernel first, then bias) is preserved.
    w = jnp.asarray(
        self.param('kernel', (in_features, self.features), self.kernel_init),
        self.dtype)
    # Contract x's last axis with the kernel's axis 0; there are no batch axes.
    contracting = ((x.ndim - 1,), (0,))
    batch = ((), ())
    out = lax.dot_general(x, w, (contracting, batch), precision=self.precision)
    if not self.bias:
      return out
    b = jnp.asarray(
        self.param('bias', (self.features,), self.bias_init), self.dtype)
    return out + b

class Test1(nn.Module):
  """Minimal parent module wrapping a single `Dense` child named 'foo'."""

  def __init__(self, features):
    self.features = features  # output width forwarded to the Dense child

  def apply(self, x):
    # The child is constructed inside `apply` and invoked immediately.
    return Dense(self.features, name='foo')(x)

# `init` works without manually pushing a `_ModuleFrame`: called directly it
# returns (output, params), and the params come back keyed by the child's
# name ('foo'), as the recorded output shows.
out, ps = Test1(3).init(random.PRNGKey(0), jnp.zeros((2, 10)))
ps

{'foo': {'kernel': DeviceArray([[-0.10392965, -0.09891689,  0.41300794],
               [-0.02804299, -0.2828618 , -0.09221947],
               [ 0.529738  , -0.28473023, -0.29105985],
               [-0.06753982, -0.5128001 ,  0.44331315],
               [-0.13527706, -0.14729215,  0.23034579],
               [-0.1349464 , -0.41411245,  0.2732163 ],
               [-0.03944878, -0.3619894 ,  0.27715003],
               [-0.35726544,  0.1296376 , -0.17981628],
               [ 0.23691937, -0.19829887,  0.51437813],
               [ 0.03203585,  0.37968078,  0.38604495]], dtype=float32),
  'bias': DeviceArray([0., 0., 0.], dtype=float32)}}

In [6]:
@dataclasses.dataclass
class Dense2(nn.base.Module):
  """Dataclass variant of `Dense`: fields replace the hand-written `__init__`.

  The generated constructor takes the fields in declaration order, so
  `Dense2(features, name=...)` behaves like the `Dense` constructor.
  """
  features: Any                 # output width
  bias: bool = True             # create/add a 'bias' parameter
  dtype: Any = jnp.float32      # dtype for inputs and parameters
  precision: Any = None         # forwarded to lax.dot_general
  kernel_init: Any = nn.linear.default_kernel_init
  bias_init: Any = nn.initializers.zeros
  name: Optional[str] = None    # key of this module's params in the tree

  def apply(self, inputs):
    """Project the last axis of `inputs` to `features` units, adding a bias if enabled."""
    inputs = jnp.asarray(inputs, self.dtype)
    kernel_shape = (inputs.shape[-1], self.features)
    kernel = jnp.asarray(
        self.param('kernel', kernel_shape, self.kernel_init), self.dtype)
    # Contract inputs' last axis with kernel axis 0; no batch dimensions.
    dimension_numbers = (((inputs.ndim - 1,), (0,)), ((), ()))
    y = lax.dot_general(inputs, kernel, dimension_numbers,
                        precision=self.precision)
    if self.bias:
      y = y + jnp.asarray(
          self.param('bias', (self.features,), self.bias_init), self.dtype)
    return y
  
class Test2(nn.Module):
  """Parent module wrapping a single dataclass-based `Dense2` child named 'foo'."""

  def __init__(self, features):
    self.features = features  # output width forwarded to the Dense2 child

  def apply(self, x):
    # Same pattern as Test1, but exercising the dataclass-generated __init__.
    return Dense2(self.features, name='foo')(x)

# Same check as for Test1, now with the dataclass-based Dense2 child; no
# manually pushed `_ModuleFrame` is needed for `init` to return the params.
out, ps = Test2(5).init(random.PRNGKey(0), jnp.zeros((2, 10)))
ps

{'foo': {'kernel': DeviceArray([[-0.2529256 ,  0.26981813,  0.30638131, -0.08274508,
                -0.08477998],
               [-0.08519974, -0.12731075,  0.18377769,  0.12409567,
                 0.6234259 ],
               [-0.08067774, -0.35942018,  0.13299485,  0.23421109,
                 0.17342107],
               [ 0.14269167,  0.0055981 ,  0.27061704,  0.19526495,
                -0.5933554 ],
               [ 0.6213056 , -0.61469084,  0.04832337, -0.02115068,
                 0.1297735 ],
               [ 0.20668189, -0.03091113, -0.17350008, -0.12020119,
                -0.5258765 ],
               [-0.5250607 ,  0.29690045,  0.5325081 , -0.12721153,
                -0.05707901],
               [ 0.3932148 ,  0.02470948, -0.17474599, -0.4096663 ,
                -0.1654451 ],
               [ 0.5139211 , -0.00969588, -0.08668266,  0.39069012,
                 0.02172046],
               [-0.5894855 , -0.08789891,  0.12177308, -0.15530352,
                -0.53401613]], dt

In [12]:
@dataclasses.dataclass
class Dense3(nn.base.Module):
  """Dataclass dense layer used as an attribute-bound submodule in Test3.

  Field layout and defaults mirror `Dense2`; instances are created in a
  parent's `__init__` and called later from the parent's `apply`.
  """
  features: Any                 # output width
  bias: bool = True             # create/add a 'bias' parameter
  dtype: Any = jnp.float32      # dtype for inputs and parameters
  precision: Any = None         # forwarded to lax.dot_general
  kernel_init: Any = nn.linear.default_kernel_init
  bias_init: Any = nn.initializers.zeros
  name: Optional[str] = None    # key of this module's params in the tree

  def apply(self, inputs):
    """Return `inputs` projected to `features` units, plus bias if enabled."""
    x = jnp.asarray(inputs, self.dtype)
    w = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    w = jnp.asarray(w, self.dtype)
    last_axis = x.ndim - 1
    out = lax.dot_general(
        x, w, (((last_axis,), (0,)), ((), ())), precision=self.precision)
    if self.bias:
      b = self.param('bias', (self.features,), self.bias_init)
      out = out + jnp.asarray(b, self.dtype)
    return out
  

class Test3(nn.Module):
  """Two stacked Dense3 layers whose children are created in `__init__`.

  Unlike Test1/Test2, the submodules are bound as attributes at construction
  time and only invoked later, from `apply`.
  """

  def __init__(self, features1, features2):
    self.features1 = features1
    self.features2 = features2
    # Children are constructed here, in this order: dense1, then dense2.
    self.dense1 = Dense3(features1, name='dense1')
    self.dense2 = Dense3(features2, name='dense2')

  def apply(self, x):
    """Map x through dense1 (features1 units) and then dense2 (features2 units)."""
    hidden = self.dense1(x)
    return self.dense2(hidden)

  
# class Test4(nn.Module):
#   def __init__(self, features1, features2):
#     self.features1 = features1
#     self.features2 = features2
#   def apply(self, x):
#     return Test3(self.features1, self.features2)(x)

# Ones rather than zeros as input: with zero inputs and Dense3's zero bias
# initializer, the forward output would be identically zero.
out, ps = Test3(5, 6).init(random.PRNGKey(0), jnp.ones((2, 10)))
(out, ps)

(DeviceArray([[-1.9115574 ,  2.5983858 , -0.3795683 , -0.67614144,
                0.2463553 ,  1.2554313 ],
              [-1.9115574 ,  2.5983858 , -0.3795683 , -0.67614144,
                0.2463553 ,  1.2554313 ]], dtype=float32),
 {'dense1': {'kernel': DeviceArray([[-0.42626938, -0.06291714,  0.40380943,  0.33022466,
                 -0.44087473],
                [ 0.14540413, -0.16723464,  0.372912  , -0.14302   ,
                  0.33056796],
                [-0.13415512, -0.08430405,  0.33054024, -0.00531342,
                  0.49308527],
                [ 0.12483141,  0.5022159 ,  0.05051188, -0.47747606,
                 -0.08046041],
                [ 0.10952947,  0.3962143 ,  0.61411387, -0.06005407,
                 -0.26729408],
                [ 0.5076167 ,  0.0774442 ,  0.54077667, -0.26842186,
                 -0.24723075],
                [-0.05678061, -0.64803886,  0.14430496,  0.05134998,
                 -0.03412019],
                [-0.09557235,  0.45958194, -0

In [8]:
class Test4(nn.Module):
  """Wraps a Test3 that is constructed inside `apply` rather than in `__init__`.

  NOTE(review): in the recorded output of the following cell, the params tree
  has an empty 'Test3_0' entry, and 'dense1' appears as a top-level sibling
  instead of nested under 'Test3_0' (the output is truncated; 'dense2' is
  presumably hoisted too). It looks like submodules created in Test3.__init__
  attach to the outer frame; confirm whether that is intended.
  """

  def __init__(self, features1, features2):
    self.features1 = features1
    self.features2 = features2

  def apply(self, x):
    inner = Test3(self.features1, self.features2)
    return inner(x)

# Initialize the nested model and inspect where Test3's children land in the
# returned params tree (no manually pushed `_ModuleFrame` needed for `init`).
out, ps = Test4(5, 6).init(random.PRNGKey(0), jnp.zeros((2, 10)))
ps

5
6
1 <flax.nn.base._ModuleFrame object at 0x1323e4d30>
2 <flax.nn.base._ModuleFrame object at 0x1321b9978>


{'Test3_0': {},
 'dense1': {'kernel': DeviceArray([[-0.42626938, -0.06291714,  0.40380943,  0.33022466,
                -0.44087473],
               [ 0.14540413, -0.16723464,  0.372912  , -0.14302   ,
                 0.33056796],
               [-0.13415512, -0.08430405,  0.33054024, -0.00531342,
                 0.49308527],
               [ 0.12483141,  0.5022159 ,  0.05051188, -0.47747606,
                -0.08046041],
               [ 0.10952947,  0.3962143 ,  0.61411387, -0.06005407,
                -0.26729408],
               [ 0.5076167 ,  0.0774442 ,  0.54077667, -0.26842186,
                -0.24723075],
               [-0.05678061, -0.64803886,  0.14430496,  0.05134998,
                -0.03412019],
               [-0.09557235,  0.45958194, -0.04055655,  0.6842091 ,
                 0.28904867],
               [ 0.34298617,  0.07379756, -0.5520766 , -0.4631825 ,
                 0.4163055 ],
               [-0.13763581,  0.14039777,  0.367441  , -0.47176048,
              

In [20]:
# True: the dataclass field `name` has a plain default (None), which remains a
# class attribute on Dense3 (nn.base.Module may also define `name`; unverified).
hasattr(Dense3, 'name')

True

In [5]:
# Construct Test3 by hand under a manually pushed root frame, using private
# internals (nn.base._ModuleFrame / nn.base._module_stack), so the submodules
# it builds in __init__ can be inspected in the next cell.
# NOTE(review): the meaning of the first positional argument (None) is not
# visible here (presumably the frame's name); confirm against nn.base.
frame = nn.base._ModuleFrame(None, rng=random.PRNGKey(0), parent=None)
with nn.base._module_stack.frame(frame):
  x = Test3(5,6)

In [15]:
# Print `_parent` for every attribute of `x` that is itself a Module. In the
# recorded run, both Dense3 children reported the same frame object.
for v in (getattr(x, k) for k in x.__dict__.keys()):
  if isinstance(v, nn.base.Module):
    print(v._parent)

<flax.nn.base._ModuleFrame object at 0x131ebeda0>
<flax.nn.base._ModuleFrame object at 0x131ebeda0>
