In [11]:
import jax
from jax import lax, random, numpy as jnp
import numpy as np
import flax
from flax import nn
import dataclasses

from typing import Any, Optional

In [2]:
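# Reference only (not executed): a functional-style `Dense` whose hyperparameters
# are passed to `apply`. Presumably copied from the flax implementation — kept for
# comparison with the attribute-style `Dense` defined in the next cell.
# class Dense(base.Module):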
  
#   def apply(self,
#             inputs,
#             features,
#             bias=True,
#             dtype=jnp.float32,
#             precision=None,
#             kernel_init=default_kernel_init,
#             bias_init=initializers.zeros):
#     inputs = jnp.asarray(inputs, dtype)
#     kernel = self.param('kernel', (inputs.shape[-1], features), kernel_init)
#     kernel = jnp.asarray(kernel, dtype)
#     y = lax.dot_general(inputs, kernel,
#                         (((inputs.ndim - 1,), (0,)), ((), ())),
#                         precision=precision)
#     if bias:
#       bias = self.param('bias', (features,), bias_init)
#       bias = jnp.asarray(bias, dtype)
#       y = y + bias
#     return y


In [3]:
class Dense(nn.base.Module):
  """Linear layer whose hyperparameters live on the instance.

  The constructor only records configuration; the 'kernel' and (optionally)
  'bias' parameters are requested through `self.param` when `apply` runs.
  """

  def __init__(self,
               features,
               bias=True,
               dtype=jnp.float32,
               precision=None,
               kernel_init=nn.linear.default_kernel_init,
               bias_init=nn.initializers.zeros,
               name=None):
    """Records the layer configuration.

    Args:
      features: size of the last dimension of the output.
      bias: whether a bias vector is added to the output.
      dtype: dtype that the inputs and parameters are cast to.
      precision: forwarded unchanged to `lax.dot_general`.
      kernel_init: initializer handed to `self.param` for 'kernel'.
      bias_init: initializer handed to `self.param` for 'bias'.
      name: stored as `self.name`.
    """
    self.features = features
    self.bias = bias
    self.dtype = dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.name = name

  def apply(self, inputs):
    """Contracts the last axis of `inputs` with the kernel, then adds the bias if enabled.

    Args:
      inputs: array-like; cast to `self.dtype` before use.

    Returns:
      The transformed array, whose last dimension has size `self.features`.
    """
    x = jnp.asarray(inputs, self.dtype)
    in_features = x.shape[-1]
    weight = jnp.asarray(
        self.param('kernel', (in_features, self.features), self.kernel_init),
        self.dtype)

    # Contract the trailing axis of `x` with axis 0 of the kernel; no batch dims.
    contracting_dims = ((x.ndim - 1,), (0,))
    batch_dims = ((), ())
    out = lax.dot_general(x, weight,
                          (contracting_dims, batch_dims),
                          precision=self.precision)

    if not self.bias:
      return out

    offset = jnp.asarray(
        self.param('bias', (self.features,), self.bias_init),
        self.dtype)
    return out + offset

In [4]:
class Test1(nn.Module):
  """Minimal wrapper module that applies a single `Dense` child named 'foo'."""

  def __init__(self, features):
    # Output width handed to the child `Dense`.
    self.features = features

  def apply(self, x):
    """Builds the child layer and returns its output for `x`."""
    dense = Dense(self.features, name='foo')
    return dense(x)

# Build a parentless module frame by hand and enter it through flax's module
# stack so that `Test1` is initialized inside it.
# NOTE(review): `_ModuleFrame` and `_module_stack` are private `nn.base`
# internals — confirm they exist in the flax version this notebook targets.
frame = nn.base._ModuleFrame(None, rng=random.PRNGKey(0), parent=None)
with nn.base._module_stack.frame(frame):
  # Initialize with a concrete all-zeros (2, 10) input; the result is unpacked
  # as (out, ps) — presumably the output and the parameter tree.
  out, ps = Test1(10).init(random.PRNGKey(0), jnp.zeros((2,10)))
  # Debug aid (disabled): inspect the params collected on the root frame.
  #print(nn.base._module_stack._stack.frames[0].params)



In [9]:
# Same setup as the concrete-input initialization, but using `init_by_shape`:
# the input is described by a ((shape, dtype),) spec instead of a real array.
# NOTE(review): relies on private `nn.base` internals — confirm against the
# flax version in use.
frame = nn.base._ModuleFrame(None, rng=random.PRNGKey(0), parent=None)
with nn.base._module_stack.frame(frame):
  # Test1(2) requests a Dense with 2 output features on a (2, 10) float32 input.
  out, ps = Test1(2).init_by_shape(random.PRNGKey(0), (((2,10), np.float32),))

In [10]:
# Display the parameter tree produced by the preceding `init_by_shape` call.
ps

{'foo': {'bias': DeviceArray([0., 0.], dtype=float32),
  'kernel': DeviceArray([[-0.4599581 ,  0.20015173],
               [-0.27465212, -0.08700371],
               [-0.06366538, -0.09729218],
               [-0.12015642, -0.4313587 ],
               [-0.02870374,  0.10679117],
               [-0.32761663, -0.06075426],
               [ 0.03170891, -0.49054202],
               [ 0.3635096 ,  0.3704213 ],
               [-0.18155588,  0.52057534],
               [ 0.29326665,  0.04858123]], dtype=float32)}}

In [25]:
@dataclasses.dataclass
class Dense2(nn.base.Module):
  """Dataclass variant of the linear layer: hyperparameters are declared as fields.

  The generated `__init__` stores each field on the instance; the 'kernel' and
  (optionally) 'bias' parameters are requested via `self.param` inside `apply`.
  """
  # Size of the last dimension of the output.
  features: Any
  # Whether a bias vector is added to the output.
  bias: bool = True
  # dtype that the inputs and parameters are cast to.
  dtype: Any = jnp.float32
  # Forwarded unchanged to `lax.dot_general`.
  precision: Any = None
  # Initializer handed to `self.param` for 'kernel'.
  kernel_init: Any = nn.linear.default_kernel_init
  # Initializer handed to `self.param` for 'bias'.
  bias_init: Any = nn.initializers.zeros
  # Optional module name.
  name: Optional[str] = None

  def apply(self, inputs):
    """Contracts the last axis of `inputs` with the kernel, then adds the bias if enabled.

    Args:
      inputs: array-like; cast to `self.dtype` before use.

    Returns:
      The transformed array, whose last dimension has size `self.features`.
    """
    x = jnp.asarray(inputs, self.dtype)
    kernel_shape = (x.shape[-1], self.features)
    weight = jnp.asarray(
        self.param('kernel', kernel_shape, self.kernel_init), self.dtype)

    # Contract the trailing axis of `x` with axis 0 of the kernel; no batch dims.
    dimension_numbers = (((x.ndim - 1,), (0,)), ((), ()))
    result = lax.dot_general(x, weight, dimension_numbers,
                             precision=self.precision)

    if self.bias:
      offset = jnp.asarray(
          self.param('bias', (self.features,), self.bias_init), self.dtype)
      result = result + offset
    return result
  
class Test2(nn.Module):
  """Minimal wrapper module that applies a single `Dense2` child named 'foo'."""

  def __init__(self, features):
    # Output width handed to the child `Dense2`.
    self.features = features

  def apply(self, x):
    """Builds the dataclass-based child layer and returns its output for `x`."""
    dense = Dense2(self.features, name='foo')
    return dense(x)

# Initialize `Test2` (which wraps the dataclass-based `Dense2`) inside a
# hand-built, parentless module frame.
# NOTE(review): `_ModuleFrame` and `_module_stack` are private `nn.base`
# internals — confirm they exist in the flax version this notebook targets.
frame = nn.base._ModuleFrame(None, rng=random.PRNGKey(0), parent=None)
with nn.base._module_stack.frame(frame):
  # Test2(5) requests 5 output features on an all-zeros (2, 10) input; the
  # result is unpacked as (out, ps) — presumably output and parameter tree.
  out, ps = Test2(5).init(random.PRNGKey(0), jnp.zeros((2,10)))
  # Debug aid (disabled): inspect the params collected on the root frame.
  #print(nn.base._module_stack._stack.frames[0].params)
# Display the resulting parameter tree.
ps

{'foo': {'kernel': DeviceArray([[-0.2529256 ,  0.26981813,  0.30638131, -0.08274508,
                -0.08477998],
               [-0.08519974, -0.12731075,  0.18377769,  0.12409567,
                 0.6234259 ],
               [-0.08067774, -0.35942018,  0.13299485,  0.23421109,
                 0.17342107],
               [ 0.14269167,  0.0055981 ,  0.27061704,  0.19526495,
                -0.5933554 ],
               [ 0.6213056 , -0.61469084,  0.04832337, -0.02115068,
                 0.1297735 ],
               [ 0.20668189, -0.03091113, -0.17350008, -0.12020119,
                -0.5258765 ],
               [-0.5250607 ,  0.29690045,  0.5325081 , -0.12721153,
                -0.05707901],
               [ 0.3932148 ,  0.02470948, -0.17474599, -0.4096663 ,
                -0.1654451 ],
               [ 0.5139211 , -0.00969588, -0.08668266,  0.39069012,
                 0.02172046],
               [-0.5894855 , -0.08789891,  0.12177308, -0.15530352,
                -0.53401613]], dt