In [1]:
import functools
from importlib import reload
from pprint import pprint
from typing import Callable
import jax
from jax import random, lax, numpy as jnp
import numpy as np

import flax
from flax import nn
from flax.nn import initializers
from flax import core
from flax.core import Scope, init, apply, Array, unfreeze, freeze

from typing import Union, Any, Callable, Optional
import dataclasses

In [6]:
@dataclasses.dataclass
class Module:
  """Dataclass base class that binds a module instance to its own child scope.

  Attributes:
    parent: Either a `Scope` or another `Module`. After construction this
      field always holds a `Scope` (a parent module is replaced by its
      `_scope`).

  Subclasses may declare a `name` field and/or an `init()` method; both are
  picked up in `__post_init__`.
  """
  parent: Union['Module', Scope]

  def __post_init__(self):
    """Resolve the parent scope, push a child scope, then run `init()` if any."""
    owner = self.parent
    if not isinstance(owner, Scope):
      # A Module was passed as parent: attach to the scope it owns.
      owner = owner._scope
    self.parent = owner

    explicit_name = getattr(self, 'name', None)
    if explicit_name is None:
      # No usable name: let the scope derive one from the class name.
      self._scope = owner.push(name_prefix=self.__class__.__name__)
    else:
      self._scope = owner.push(name=explicit_name)

    # Optional subclass hook, run once the child scope exists.
    setup_hook = getattr(self, 'init', None)
    if callable(setup_hook):
      setup_hook()

  def param(self, name, init_fn, *args):
    """Forward a parameter request to this module's scope and return the result."""
    own_scope = self._scope
    return own_scope.param(name, init_fn, *args)

In [10]:
@dataclasses.dataclass
class Dense(Module):
  """Affine layer: `dot(x, kernel) + bias`.

  Attributes:
    features: Size of the output's last dimension.
    kernel_init: Initializer passed to `param` for the 'kernel' entry.
    bias_init: Initializer passed to `param` for the 'bias' entry.
    name: Optional explicit scope name.
  """
  features: int
  kernel_init: Callable = nn.linear.default_kernel_init
  bias_init: Callable = nn.initializers.zeros
  name: Optional[str] = None

  def __call__(self, x):
    """Request 'kernel' then 'bias' from the scope and apply them to `x`."""
    in_dim = x.shape[-1]
    weights = self.param('kernel', self.kernel_init, (in_dim, self.features))
    offset = self.param('bias', self.bias_init, (self.features,))
    projected = jnp.dot(x, weights)
    return projected + offset


@dataclasses.dataclass
class DenseTest1(Module):
  """Two separately constructed, unnamed `Dense` layers, each followed by relu.

  Attributes:
    features: Output width passed to both Dense layers.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def __call__(self, x):
    """Apply Dense -> relu -> Dense -> relu, building a new Dense each time."""
    first_out = nn.relu(Dense(self, self.features)(x))
    # A second, independent Dense instance (not a reuse of the first one).
    second_out = Dense(self, self.features)(first_out)
    return nn.relu(second_out)

# Run DenseTest1 through `init` with a fixed PRNG key and print only the
# shapes of the resulting parameter tree.
x = jnp.ones((4,))
y, params = init(lambda scope, x: DenseTest1(parent=scope, features=4, name='mlp')(x))(random.PRNGKey(0), x)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
# and later removed from JAX, while tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))

{'param': {'mlp': {'Dense_0': {'bias': (4,), 'kernel': (4, 4)},
                   'Dense_1': {'bias': (4,), 'kernel': (4, 4)}}}}


In [11]:
@dataclasses.dataclass
class Dense(Module):
  """Affine layer computing `dot(x, kernel) + bias`.

  NOTE(review): this re-declares `Dense` with the same body as the earlier
  cell in this notebook; the later definition shadows the earlier one.

  Attributes:
    features: Size of the output's last dimension.
    kernel_init: Initializer passed to `param` for the 'kernel' entry.
    bias_init: Initializer passed to `param` for the 'bias' entry.
    name: Optional explicit scope name.
  """
  features: int
  kernel_init: Callable = nn.linear.default_kernel_init
  bias_init: Callable = nn.initializers.zeros
  name: Optional[str] = None

  def __call__(self, x):
    """Request 'kernel' then 'bias' from the scope and apply them to `x`."""
    kernel_shape = (x.shape[-1], self.features)
    bias_shape = (self.features,)
    w = self.param('kernel', self.kernel_init, kernel_shape)
    b = self.param('bias', self.bias_init, bias_shape)
    return jnp.dot(x, w) + b


@dataclasses.dataclass
class DenseTest1B(Module):
  """One `Dense` instance applied twice, with relu after each application.

  Attributes:
    features: Output width passed to the Dense layer.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def __call__(self, x):
    """Build a single Dense and run `relu(dense(.))` two times over `x`."""
    shared_layer = Dense(self, self.features)
    out = x
    # Same layer object on both passes, so both passes go through one scope.
    for _ in range(2):
      out = nn.relu(shared_layer(out))
    return out

# Run DenseTest1B through `init` with a fixed PRNG key and print only the
# shapes of the resulting parameter tree.
x = jnp.ones((4,))
y, params = init(lambda scope, x: DenseTest1B(parent=scope, features=4, name='mlp')(x))(random.PRNGKey(0), x)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
# and later removed from JAX, while tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))

{'param': {'mlp': {'Dense_0': {'bias': (4,), 'kernel': (4, 4)}}}}


In [8]:
@dataclasses.dataclass
class DenseTest2A(Module):
  """Wrapper exposing two entry points that share one `Dense` submodule.

  Attributes:
    features: Output width passed to the Dense layer.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def init(self):
    """Create the shared Dense once; `Module.__post_init__` calls this hook."""
    self.hidden = Dense(self, self.features, name='hidden')

  def fwd(self, x):
    """Apply the shared `hidden` layer to `x`."""
    return self.hidden(x)

  def rev(self, x):
    """Apply the same shared `hidden` layer to `x` (same body as `fwd`)."""
    return self.hidden(x)

@dataclasses.dataclass
class DenseTest2B(Module):
  """Outer module that drives a `DenseTest2A` wrapper through both entry points.

  Attributes:
    features: Output width forwarded to the wrapper.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def init(self):
    """Create the wrapper once; `Module.__post_init__` calls this hook."""
    self.wrapper = DenseTest2A(self, self.features, name='wrapper')

  def __call__(self, x):
    """Run `x` through `wrapper.fwd` and then `wrapper.rev`."""
    x = self.wrapper.fwd(x)
    # Call rev (not fwd a second time) so that both entry points are exercised
    # against the shared `hidden` layer, matching the fwd/rev sequence used by
    # the other DenseTest2B cell in this notebook.
    x = self.wrapper.rev(x)
    return x

# Run DenseTest2B through `init` with a fixed PRNG key and print only the
# shapes of the resulting parameter tree.
x = jnp.ones((4,))
y, params = init(lambda scope, x: DenseTest2B(parent=scope, features=4, name='mlp')(x))(random.PRNGKey(0), x)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
# and later removed from JAX, while tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))


{'param': {'mlp': {'wrapper': {'hidden': {'bias': (4,), 'kernel': (4, 4)}}}}}


In [9]:
@dataclasses.dataclass
class Dense(Module):
  """Affine layer variant that keeps its parameters as instance attributes.

  Unlike a version using local variables, `__call__` here assigns the fetched
  parameters to `self.kernel` and `self.bias` before using them.

  Attributes:
    features: Size of the output's last dimension.
    kernel_init: Initializer passed to `param` for the 'kernel' entry.
    bias_init: Initializer passed to `param` for the 'bias' entry.
    name: Optional explicit scope name.
  """
  features: int
  kernel_init: Callable = nn.linear.default_kernel_init
  bias_init: Callable = nn.initializers.zeros
  name: Optional[str] = None

  def __call__(self, x):
    """Fetch 'kernel' then 'bias', store both on `self`, return the affine map."""
    kernel_shape = (x.shape[-1], self.features)
    self.kernel = self.param('kernel', self.kernel_init, kernel_shape)
    bias_shape = (self.features,)
    self.bias = self.param('bias', self.bias_init, bias_shape)
    projected = jnp.dot(x, self.kernel)
    return projected + self.bias


@dataclasses.dataclass
class DenseTest1(Module):
  """Two separately constructed, unnamed `Dense` layers, each followed by relu.

  NOTE(review): re-declares `DenseTest1` from an earlier cell so that it picks
  up the `Dense` definition of this cell.

  Attributes:
    features: Output width passed to both Dense layers.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def __call__(self, x):
    """Apply Dense -> relu -> Dense -> relu, building a new Dense each time."""
    layer_a = Dense(self, self.features)
    activated = nn.relu(layer_a(x))
    # Constructed only after the first layer has run, as a distinct instance.
    layer_b = Dense(self, self.features)
    return nn.relu(layer_b(activated))

# Run DenseTest1 through `init` with a fixed PRNG key and print only the
# shapes of the resulting parameter tree.
x = jnp.ones((4,))
y, params = init(lambda scope, x: DenseTest1(parent=scope, features=4, name='mlp')(x))(random.PRNGKey(0), x)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
# and later removed from JAX, while tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))


{'param': {'mlp': {'Dense_0': {'bias': (4,), 'kernel': (4, 4)},
                   'Dense_1': {'bias': (4,), 'kernel': (4, 4)}}}}


In [12]:
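# BREAKS: running this cell ends in the AssertionError shown below.
# fwd() and rev() each construct their own Dense named 'hidden' under the same
# parent; presumably the second use of that name is rejected by the scope
# (TODO confirm against flax.core.Scope.push).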

@dataclasses.dataclass
class DenseTest2A(Module):
  """Wrapper whose two entry points each build their own `Dense` named 'hidden'.

  Nothing is created up front here (there is no `init` hook); the Dense is
  constructed inside each method call.

  Attributes:
    features: Output width passed to the Dense layer.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def fwd(self, x):
    """Construct a Dense named 'hidden' under this module and apply it to `x`."""
    layer = Dense(self, self.features, name='hidden')
    return layer(x)

  def rev(self, x):
    """Construct another Dense with the same name 'hidden' and apply it to `x`."""
    layer = Dense(self, self.features, name='hidden')
    return layer(x)

@dataclasses.dataclass
class DenseTest2B(Module):
  """Outer module that calls `fwd` and then `rev` on a `DenseTest2A` wrapper.

  Attributes:
    features: Output width forwarded to the wrapper.
    name: Optional explicit scope name.
  """
  features: int
  name: Optional[str] = None

  def init(self):
    """Create the wrapper once; `Module.__post_init__` calls this hook."""
    self.wrapper = DenseTest2A(self, self.features, name='wrapper')

  def __call__(self, x):
    """Return `wrapper.rev(wrapper.fwd(x))`."""
    forward_out = self.wrapper.fwd(x)
    return self.wrapper.rev(forward_out)

# Run the intentionally failing DenseTest2B variant. The AssertionError is the
# point of this cell, so it is caught and printed instead of being left
# uncaught; that way a Restart & Run All still reaches the cells below.
x = jnp.ones((4,))
try:
  y, params = init(lambda scope, x: DenseTest2B(parent=scope, features=4, name='mlp')(x))(random.PRNGKey(0), x)
  # Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
  # and later removed from JAX, while tree_util.tree_map is the stable spelling.
  pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
except AssertionError as err:
  # Print in "ExceptionName: message" form so the expected failure stays visible.
  print(f'{type(err).__name__}: {err}')


AssertionError: 

In [13]:
@dataclasses.dataclass
class Embed(Module):
  """Lookup table with a row-indexing method and a transposed-projection method.

  Attributes:
    features: First dimension of the table (number of rows).
    depth: Second dimension of the table (length of each row).
    kernel_init: Initializer passed to `param` for the 'kernel' entry.
    name: Optional explicit scope name.
  """
  features: int
  depth: int
  kernel_init: Callable = nn.linear.default_kernel_init
  name: Optional[str] = None

  def init(self):
    """Create the `(features, depth)` table up front and keep it on `self`."""
    table_shape = (self.features, self.depth)
    self.table = self.param('kernel', self.kernel_init, table_shape)

  def f1(self, idx):
    """Index the table with `idx` and return the selected rows."""
    selected = self.table[idx]
    return selected

  def f2(self, q):
    """Return `dot(q, table transposed)`."""
    table_t = jnp.transpose(self.table)
    return jnp.dot(q, table_t)


@dataclasses.dataclass
class EmbedTest(Module):
  """Round-trips indices through an `Embed` whose table is an identity-like matrix.

  Attributes:
    features: Forwarded to `Embed` (table rows) and to the `hidden` Dense.
    depth: Forwarded to `Embed` (table columns).
    name: Optional explicit scope name.
  """
  features: int
  depth: int
  name: Optional[str] = None

  def init(self):
    """Create a Dense named 'hidden'.

    NOTE(review): `hidden` is never called in `__call__`, and the recorded
    output of this cell shows no 'hidden' entry in the parameter tree.
    """
    self.hidden = Dense(self, self.features, name='hidden')

  def __call__(self, x):
    """Look up rows for `x` with `f1`, then project them back with `f2`."""
    def eye_init(key, shape):
      # Ignores the PRNG key; fills the table with jnp.eye of the given shape.
      return jnp.eye(*shape)

    embed = Embed(self, self.features, self.depth, kernel_init=eye_init, name='embed')
    looked_up = embed.f1(x)
    return embed.f2(looked_up)

# Run EmbedTest on a small index vector through `init` with a fixed PRNG key,
# then print the output and the shapes of the resulting parameter tree.
x = jnp.array([0,2,3,1])
y, params = init(lambda scope, x: EmbedTest(parent=scope, name='test', features=4, depth=5)(x))(random.PRNGKey(0), x)
print(y)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias was deprecated
# and later removed from JAX, while tree_util.tree_map is the stable spelling.
pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))

[[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]]
{'param': {'test': {'embed': {'kernel': (4, 5)}}}}
