<a href="https://colab.research.google.com/github/hhhezhang/jax-flax-learning/blob/main/flax_immutable_variables.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [66]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

In [109]:
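def complicated_values(shape):
    # Deliberately slow, pure NumPy stand-in for an expensive initializer;
    # the print makes every call visible in the cell outputs.
    print("Calling complicated_values.")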
    values = np.zeros(np.prod(shape))
    for i in range(values.size):
        values[i] = i ** 2
    return values.reshape(shape)

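This example illustrates an expensive approach: the value is stored with `self.param`, and the complicated init function is called not only by `init` but again on every `apply`. This appears to happen because, during `apply`, Flax's `param` re-runs the init function under `jax.eval_shape` to check that the stored parameter still has the expected shape; since `complicated_values` is plain NumPy code, its whole Python body executes during that check.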

In [180]:
class MyModelExpensive(nn.Module):

    @nn.compact
    def __call__(self, x, train=False):
        immutable_x = self.param("x", lambda rng, shape: complicated_values(shape), x.shape[1:])
        # stop_gradient returns a new array; its result must be used to take effect.
        immutable_x = jax.lax.stop_gradient(immutable_x)
        return x * immutable_x

model_exp = MyModelExpensive()
vars = model_exp.init(jax.random.key(0), jnp.ones((3, 2)))
print(vars)
output = model_exp.apply(vars, jnp.ones((3, 2)))
output = model_exp.apply(vars, jnp.ones((3, 2)))

Calling complicated_values.
{'params': {'x': array([0., 1.])}}
Calling complicated_values.
Calling complicated_values.
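A minimal sketch of the mechanism, assuming the extra calls come from the `jax.eval_shape` shape check described above: `eval_shape` only computes output shapes and dtypes, yet it still executes the Python body of the function it traces. The counter `calls` is a hypothetical addition for illustration.

```python
import jax
import numpy as np

calls = 0  # hypothetical counter, only to make the calls observable


def complicated_values(shape):
    global calls
    calls += 1
    print("Calling complicated_values.")
    values = np.zeros(np.prod(shape))
    for i in range(values.size):
        values[i] = i ** 2
    return values.reshape(shape)


# Only the abstract result (shape/dtype) is returned, but the NumPy loop
# inside complicated_values still runs once while JAX traces the lambda.
abstract = jax.eval_shape(lambda: complicated_values((2,)))
print(abstract.shape)  # (2,)
print(calls)           # 1
```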


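Wrapping the `apply` call in `jax.jit` is better: the Python body, including the call to `complicated_values`, runs only once while the function is traced on its first call, and later calls with the same input shape and dtype reuse the compiled computation.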

In [177]:
@jax.jit
def apply_model(x):
    print("Calling jit model.")
    return model_exp.apply(vars, x)

output = apply_model(jnp.ones((3, 2)))
output = apply_model(jnp.ones((3, 2)))
output = apply_model(jnp.ones((3, 2)))

Calling jit model.
Calling complicated_values.
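The caching behaviour can be checked with a small self-contained sketch; `scale` and `trace_count` are hypothetical names for illustration. A Python side effect inside a jitted function runs only while JAX traces it, so it fires once per new input shape/dtype, not once per call. A call to `apply_model` with a different input shape would likewise retrace and run `complicated_values` again.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter of how often JAX traces `scale`


@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x * 2.0


scale(jnp.ones((3, 2)))  # first call: traced and compiled
scale(jnp.ones((3, 2)))  # same shape and dtype: cached, no retrace
scale(jnp.ones((4, 2)))  # new shape: traced again
print(trace_count)  # 2
```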


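Next, we store the value with `self.variable` in a custom `immutable` collection instead of using `self.param`. `self.has_variable` tells us whether the variable already exists: it does not during `init`, so the complicated value is computed once and assigned; during `apply` it does, so the assignment is skipped. Because the value is not in the `params` collection, it also stays out of the trainable parameters.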

In [187]:
class MyModel(nn.Module):

    @nn.compact
    def __call__(self, x, train=False):
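        # False during init (the variable does not exist yet), True during apply.
        is_initialized = self.has_variable('immutable', 'x')
        # Cheap zero placeholder; here it only fixes the variable's shape.
        immutable_x = self.variable("immutable", "x", lambda shape: jnp.zeros(shape), x.shape[1:])
        if not is_initialized:
            # Only reached during init, where every collection is mutable.
            immutable_x.value = complicated_values(immutable_x.value.shape)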
        return x * immutable_x.value

model = MyModel()
vars = model.init(jax.random.key(0), jnp.ones((3, 2)))
print(vars)
output = model.apply(vars, jnp.ones((3, 2)))
output = model.apply(vars, jnp.ones((3, 2)))

Calling complicated_values.
{'immutable': {'x': array([0., 1.])}}


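A final, simplified version: `self.variable` calls its init function (which, unlike the one passed to `self.param`, receives no RNG key) only when the variable does not exist yet, so `complicated_values` can be passed to it directly and the `has_variable` check is no longer needed.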

In [188]:
class MyModel(nn.Module):

    @nn.compact
    def __call__(self, x, train=False):
        immutable_x = self.variable("immutable", "x", complicated_values, x.shape[1:])
        return x * immutable_x.value

model = MyModel()
vars = model.init(jax.random.key(0), jnp.ones((3, 2)))
print(vars)
output = model.apply(vars, jnp.ones((3, 2)))
output = model.apply(vars, jnp.ones((3, 2)))

Calling complicated_values.
{'immutable': {'x': array([0., 1.])}}
