In [None]:
import jax
import optax
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from deeprte.model import modules
from deeprte.train_lib import checkpointing
from deeprte.train_lib import utils as train_utils

# Start from the default DeepRTE configuration and override only the fields
# this notebook sets explicitly.
config = modules.DeepRTEConfig()
# Empty strings for both load paths.
# NOTE(review): presumably this means "nothing to restore", so the state is
# freshly initialised — confirm against train_utils.
config.load_parameters_path = ""
config.load_full_state_path = ""
config.enable_single_replica_ckpt_restoring = False
config.dataset_type = "tfds"

# Fixed seed for repeatable runs. `init_rng` is the key passed to the
# state-setup helpers in later cells; `rng` is the remaining half of the split.
rng = jax.random.PRNGKey(42)
rng, init_rng = jax.random.split(rng)

# Two-axis mesh: "data" has size 1 and "fsdp" spans every available device.
# create_device_mesh raises unless the product of the mesh shape equals the
# number of devices, so the "fsdp" size is taken from the runtime instead of
# being hard-coded to 2 (identical result on a 2-device host).
devices_array = mesh_utils.create_device_mesh((1, jax.device_count()))
mesh = Mesh(devices_array, ("data", "fsdp"))


def constructor(config, key: jax.Array):
    """Build a DeepRTE model whose ``params`` RNG stream is seeded by ``key``.

    Args:
        config: Configuration object forwarded unchanged to ``modules.DeepRTE``.
        key: JAX PRNG key used as the ``params`` stream of ``nnx.Rngs``.

    Returns:
        The constructed ``modules.DeepRTE`` instance.
    """
    param_rngs = nnx.Rngs(params=key)
    return modules.DeepRTE(config, rngs=param_rngs)

In [None]:
# Ask train_utils for the train-state structure and its shardings for this
# model constructor / Adam optimizer pair on `mesh`.
# NOTE(review): presumably nothing is materialised on device here (the helper
# is named "abstract") — confirm in train_utils.
(abstract_state, state_shardings) = train_utils.get_abstract_state(
    constructor,
    optax.adam(learning_rate=1e-3),
    config,
    init_rng,
    mesh,
)

# Count of jax.Array objects currently alive; displayed as the cell output.
len(jax.live_arrays())

In [None]:
# Display the `params` attribute of the abstract state as the cell output.
abstract_state.params

In [None]:
# Replace every leaf of the abstract state with its `.sharding` attribute,
# keeping the tree structure, and display the result as the cell output.
jax.tree.map(lambda leaf: leaf.sharding, abstract_state)

In [None]:
# Build the initial train state with the same constructor, optimizer settings,
# config, key and mesh used above. The helper returns three values: the state,
# its sharding, and a data iterator.
# NOTE(review): the meaning of the two positional `None` arguments is defined
# by train_utils.setup_initial_state and is not visible here — confirm.
(init_state, state_sharding, data_iter) = train_utils.setup_initial_state(
    constructor,
    None,
    optax.adam(learning_rate=1e-3),
    config,
    init_rng,
    mesh,
    None,
)

In [None]:
# Write the parameters of the initial state to disk.
# NOTE(review): the destination is a hard-coded absolute path; it only exists
# in environments that mount the repository at /workspaces/deeprte.
checkpointing.save_params_to_path(
    "/workspaces/deeprte/test",
    init_state.params,
)