In [2]:
# Enable IPython's autoreload extension (mode 2) so edits to imported modules are picked up live.
%load_ext autoreload
%autoreload 2

In [34]:
import jax
import jax.numpy as jnp
import numpy as np

import flax

In [27]:
def func(x):
    """Stack ``x["a"]`` on top of a (10, 2) block of ones.

    Args:
        x: Mapping with an entry ``"a"`` that can be vertically stacked
            with a two-column array.

    Returns:
        The vertically stacked array.
    """
    ones_block = jax.numpy.ones((10, 2))
    stacked = jax.numpy.vstack((x["a"], ones_block))
    return stacked

In [28]:
# Ask JAX for the shape/dtype of func's output for a length-2 input under key "a"
# (the recorded result is a ShapeDtypeStruct, not a concrete array).
jax.eval_shape(func, {"a": jax.numpy.array([2.0, 3.0])})

ShapeDtypeStruct(shape=(11, 2), dtype=float32)

In [33]:
# NOTE(review): leftover IPython `??` source lookup from exploration; consider removing before sharing.
jax.ShapedArray??

In [54]:
class Dummy(flax.nn.Module):
    """Toy module that multiplies its input by a ones-initialised parameter ``y``."""

    def apply(self, x):
        """Return ``x * y`` where ``y`` has shape ``x.shape + (2,)``."""
        param_shape = x.shape + (2,)
        scale = self.param(
            "y",
            shape=param_shape,
            initializer=flax.nn.initializers.ones,
        )
        return x * scale

In [63]:
class PyTreeDummy(flax.nn.Module):
    """Variant of the toy module whose input is a mapping with an ``"a"`` entry."""

    def apply(self, x):
        """Return ``x["a"] * y`` where ``y`` has shape ``x["a"].shape + (2,)``."""
        leaf = x["a"]
        scale = self.param(
            "y",
            shape=leaf.shape + (2,),
            initializer=flax.nn.initializers.ones,
        )
        return leaf * scale

In [68]:
# Module definition for Dummy with no arguments pre-bound.
model_def = Dummy.partial()

In [70]:
# Initialise from input shapes only: one input with scalar shape ().
# Returns the output spec and the initialised parameters.
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [()])

In [57]:
# Output spec from init_by_shape: shape/dtype only (recorded as a ShapeDtypeStruct).
ys

ShapeDtypeStruct(shape=(2,), dtype=float32)

In [58]:
# Initialised parameters: the recorded output shows a single entry "y" filled with ones.
params

{'y': DeviceArray([1., 1.], dtype=float32)}

In [62]:
# Initialise with a concrete scalar input instead of a shape; the recorded result is (output, params).
model_def.init(jax.random.PRNGKey(0), jnp.array(5))

(DeviceArray([5., 5.], dtype=float32),
 {'y': DeviceArray([1., 1.], dtype=float32)})

In [72]:
# Rebind model_def to the dict-input variant (replaces the earlier Dummy definition).
model_def = PyTreeDummy.partial()

In [73]:
# Pass a dict of shapes as the input spec. The recorded run raised a TypeError:
# the dict was treated as a shape itself ("got ('a',)").
# NOTE(review): this cell fails on Run All; presumably kept as a reproduction of that error — confirm.
model_def.init_by_shape(jax.random.PRNGKey(0), [{"a": ()}])

TypeError: Shapes must be 1D sequences of concrete values of integer type, got ('a',).

In [92]:
import flax
import jax
import jax.numpy as jnp

def f(params, x):
    """Apply the matrix stored under ``params["W"]`` to ``x``.

    Args:
        params: Mapping holding the weight matrix under the key ``"W"``.
        x: Right-hand operand of the matrix product.

    Returns:
        ``params["W"] @ x``.
    """
    weight = params["W"]
    return weight @ x

class jax2flax(flax.nn.Module):
    """Wrap a plain ``f(params, x)`` function as a ``flax.nn.Module``.

    One Flax parameter is registered per entry of ``init_params`` and the
    resulting dict is passed to ``f`` together with the input.
    """

    def apply(self, x, f, init_params):
        """Register the parameters and evaluate ``f(params, x)``.

        Args:
            x: Input forwarded unchanged to ``f``.
            f: Callable invoked as ``f(params, x)``.
            init_params: Mapping from parameter name to an object exposing
                ``.shape``. Only the shapes are read; every parameter is
                created with the ``ones`` initializer.

        Returns:
            Whatever ``f(params, x)`` returns.
        """
        # NOTE(review): the values in init_params are not used as initial
        # values; if that is intended, a custom initializer is needed — confirm.
        params = {
            key: self.param(key, shape=val.shape, initializer=flax.nn.initializers.ones)
            for key, val in init_params.items()
        }
        # The former `if self.is_initializing():` branch only reassigned each
        # entry to itself, so it was removed as a no-op.
        return f(params, x)
    
# Bind f and the parameter template to the module, then initialise from the input shape (4, 1).
model_def = jax2flax.partial(f=f, init_params={"W": jnp.ones((2, 4))})
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(4, 1)])
model = flax.nn.Model(model_def, params)
print("Evaluating model:", model(jnp.ones((4, 1))))

# Gradient-descent optimizer (learning rate 1) wrapping the model as its target.
optim_def = flax.optim.GradientDescent(learning_rate=1)
optimizer = optim_def.create(model)

def loss(model, x, y):
    """Mean squared error between ``y`` and the model's prediction for ``x``.

    Args:
        model: Callable mapping ``x`` to a prediction.
        x: Model input.
        y: Target values compared against the prediction.

    Returns:
        The mean of the squared residuals.
    """
    residual = y - model(x)
    return jnp.mean(jnp.square(residual))

# NOTE(review): this rebinds `loss` from the function to the scalar loss value,
# so the function is no longer reachable under that name after this line.
# Inputs are ones of shape (4, 1); targets are uniform draws of shape (2, 1) from a fixed PRNG key.
loss, grad = jax.value_and_grad(loss)(optimizer.target, jnp.ones((4, 1)), jax.random.uniform(jax.random.PRNGKey(0), (2, 1)))
print("Loss, grad:", loss, grad)
# Apply one optimizer step with the computed gradient and show the resulting model.
optimizer = optimizer.apply_gradient(grad)
print("Updated model:", optimizer.target)

Evaluating model: [[4.]
 [4.]]
Loss, grad: 12.265022 Model(module=<class 'flax.nn.base.jax2flax'>, params={'W': DeviceArray([[3.7837048, 3.7837048, 3.7837048, 3.7837048],
             [3.195876 , 3.195876 , 3.195876 , 3.195876 ]], dtype=float32)})
Updated model: Model(module=<class 'flax.nn.base.jax2flax'>, params={'W': DeviceArray([[-2.7837048, -2.7837048, -2.7837048, -2.7837048],
             [-2.195876 , -2.195876 , -2.195876 , -2.195876 ]],            dtype=float32)})


In [None]:
import flax

class SubModule(flax.nn.Module):
    """Module that multiplies its input by ``d``."""

    def apply(self, x, d):
        """Return the product ``x * d``."""
        product = x * d
        return product

class TopLevelModule(flax.nn.Module):
    def apply(self, x, y, a, b, c):
        e = SubModule()
        