# Model Rematerialization Demo

In [1]:
from etils import ecolab
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax

with ecolab.adhoc(reload='parallax'):
  import parallax
  from parallax.examples import models
  from parallax.examples import utils

In [3]:
# Build the baseline MLP with a fixed seed so runs are reproducible.
# NOTE(review): the positional sizes are presumably (in, hidden, out) —
# confirm against models.SimpleMLP.
model = models.SimpleMLP(
    16,
    64,
    16,
    rngs=nnx.Rngs(0),
)
# Rematerialized variant of the same model, as produced by parallax.
# NOTE(review): whether `model` itself is modified is not visible here.
rematted_model = parallax.remat_model(model)


@nnx.jit
def loss_fn(model):
  """Mean softmax cross-entropy of `model` on a fixed one-example batch.

  The input is a (1, 16) float32 array of ones and the target is the single
  integer label 1. These fixed values are only meant to drive a forward pass
  for cost analysis; they are not real training data.

  Args:
    model: callable model mapping a (1, 16) array to logits.

  Returns:
    Scalar mean cross-entropy loss.
  """
  dummy_inputs = jnp.ones((1, 16), dtype=jnp.float32)
  logits = model(dummy_inputs)
  dummy_labels = jnp.array([1])
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits,
      labels=dummy_labels,
  )
  return per_example_loss.mean()


def cost_analysis(model):
  """Return XLA's compiled cost analysis for one loss-and-gradient step.

  Wraps `nnx.value_and_grad(loss_fn)` in `nnx.jit`, traces it for `model`,
  lowers and compiles it, and returns the compiler's cost metrics. Callers
  index the result by key, e.g. 'bytes accessed'.

  Args:
    model: the model whose forward and backward pass is analysed.

  Returns:
    A dict of cost metrics reported by the compiler.
  """
  # `loss_fn` is already jitted; the outer jit is needed so `.trace` is
  # available on the value-and-grad transform.
  traced = nnx.jit(nnx.value_and_grad(loss_fn)).trace(model)
  compiled = traced.lower().compile()
  costs = compiled.cost_analysis()
  # Older JAX releases return a list with one dict per computation instead
  # of a single dict; normalise so callers can always index by metric name.
  if isinstance(costs, (list, tuple)):
    costs = costs[0] if costs else {}
  return costs


# Compare memory traffic of the loss-and-gradient computation before and
# after rematerialization.
original_bytes = cost_analysis(model)['bytes accessed']
print('Original model bytes accessed:', original_bytes)
rematted_bytes = cost_analysis(rematted_model)['bytes accessed']
print(
    'Rematted model bytes accessed:',
    rematted_bytes,
)

Original model bytes accessed: 156672.0
Rematted model bytes accessed: 2048.0
