# Checkpointing with Orbax

A demo of how checkpoints of the simulation state are saved, restored, and used to resume a run, using [Orbax](https://orbax.readthedocs.io/en/latest/orbax_checkpoint_101.html).

## Importing modules

In [None]:
import jax.numpy as jnp
from einops import rearrange
import numpy as np

In [None]:
from cglbm.lbm import grid_eq_dist, eq_dist_phase_field
from cglbm.simulation import multi_step_simulation_with_checkpointing
from cglbm.environment import State
from cglbm.utils import validate_sim_params, restore_state

## Simulation Setup

### Loading pre-defined environment

In [None]:
from cglbm.config import load_sandbox_config

# Predefined sandbox configuration describing a drop at rest; `system` carries
# the grid size and physical/lattice parameters read by the cells below.
config_file = "stationary-drop-config.ini"
system = load_sandbox_config(config_file)

### Initial conditions of simulation

In [None]:
LX = system.LX
LY = system.LY
X, Y = jnp.meshgrid(jnp.arange(LX), jnp.arange(LY))

# meshgrid's default "xy" indexing yields arrays of shape (LY, LX): axis 0 runs
# along y and axis 1 along x. X and Y share that shape, so either can be used.
grid_shape = X.shape
# Drop centre given in (y, x) order to match the axis order of the grid arrays.
center = (grid_shape[0] // 2, grid_shape[1] // 2)

# Initial drop radius from the configuration.
radius = system.drop_radius

### Initializing drop

In [None]:
# (y, x) index of every grid node, stacked along a trailing axis: shape (LY, LX, 2).
coordinates = rearrange(jnp.stack([Y, X]), "v y x -> y x v")

# Euclidean distance of each node from the drop centre.
squared_offsets = jnp.square(coordinates - jnp.array(center))
distanceFromCenter = jnp.sqrt(squared_offsets.sum(axis=2))

# Smooth tanh interface whose thickness is set by system.width:
# approaches 0 inside the drop (distance < radius) and 1 outside it.
scaled_gap = (distanceFromCenter - radius) * 2.0 / system.width
phase_field = 0.5 * (1.0 + jnp.tanh(scaled_gap))

### Initializing Density, Velocity, Pressure

In [None]:
# Uniform reference pressure over the whole grid.
pressure = jnp.full(grid_shape, system.ref_pressure)
# Density blends the two fluids by the phase field (phase_field == 1 -> density_one).
rho = system.density_one * phase_field + system.density_two * (1.0 - phase_field)

# Linear shear profile in Y: u_x = -uWallX at Y = 2.5, +uWallX at Y = LY - 3.5.
# NOTE(review): this profile varies along axis 0 (Y), while the obstacle cell
# places its walls on axis-1 columns -- confirm the intended wall orientation.
u_y = jnp.zeros(grid_shape)
u_x = -system.uWallX + (Y - 2.5) * 2 * system.uWallX / (LY - 6)

# Velocity components stacked along a trailing axis: shape (LY, LX, 2).
u = jnp.stack([u_x, u_y], axis=-1)

### Defining Obstacle

In [None]:
# Solid-node mask: the two outermost columns (axis 1) on each side are walls.
wall_columns = [0, 1, -2, -1]
obs = jnp.zeros(grid_shape, dtype=bool).at[:, wall_columns].set(True)

# Wall velocity, x-component: +uWallX on the last two columns, then -uWallX on
# the first two (applied in that order).
obs_velX = (
    jnp.zeros(grid_shape)
    .at[:, [-2, -1]].set(system.uWallX)
    .at[:, [0, 1]].set(-system.uWallX)
)
# Wall velocity, y-component: zero everywhere.
obs_velY = jnp.zeros(grid_shape)

# Components stacked along a trailing axis: shape (LY, LX, 2).
obs_vel = jnp.stack([obs_velX, obs_velY], axis=-1)

### Initializing f and N

In [None]:
# Zero velocity field with the same shape as `coordinates` (one 2-vector per node),
# shared by both equilibrium evaluations (JAX arrays are immutable).
# NOTE(review): the initial velocity `u` is not used here -- confirm that starting
# both distributions from a fluid at rest is intended.
rest_velocity = jnp.zeros(coordinates.shape)

# f is built from the phase field, N from the pressure field.
f = eq_dist_phase_field(system.cXYs, system.weights, phase_field, rest_velocity)
N = grid_eq_dist(system.cXYs, system.weights, system.phi_weights, pressure, rest_velocity)

### Initializing the simulation state

In [None]:
# Bundle every per-node field into the State consumed by the simulation.
state = State(
    phase_field=phase_field,
    rho=rho,
    pressure=pressure,
    u=u,
    f=f,
    N=N,
    obs=obs,
    obs_velocity=obs_vel,
)

### Initializing and validating simulation parameters

In [None]:
# Iteration count, snapshot count and checkpoint count for the first run.
nr_iter, nr_snapshots, nr_checkpoints = 100, 20, 5

# Check the three parameters before running anything
# (presumably raises on an inconsistent combination -- see validate_sim_params).
validate_sim_params(nr_iter, nr_snapshots, nr_checkpoints)

### Initializing checkpoint manager

In [None]:
import os
import tempfile

# Add path to your checkpoint directory here. When left as None, a fresh
# temporary directory is created so the notebook runs out of the box and every
# execution starts from an empty checkpoint history (a reused directory would
# already hold steps from an earlier run and break the step-count checks below).
checkpoint_dir = None

if checkpoint_dir is None:
    checkpoint_dir = tempfile.mkdtemp(prefix="cglbm_checkpoints_")

# Normalize to an absolute path; Orbax may reject relative checkpoint directories.
checkpoint_dir = os.path.abspath(checkpoint_dir)

In [None]:
import orbax.checkpoint as ocp

# Save every (nr_iter // nr_checkpoints) steps; retain only the 3 most recent checkpoints.
save_every = nr_iter // nr_checkpoints
mngr_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=save_every,
)
mngr = ocp.CheckpointManager(checkpoint_dir, options=mngr_options)

## Running the Simulation

In [None]:
# Run the simulation from the initial state, checkpointing through `mngr`.
# The call returns a pair; only the second element (the final state) is kept.
sim_outputs = multi_step_simulation_with_checkpointing(
    system, state, mngr, nr_iter, nr_snapshots, nr_checkpoints
)
_, final_state = sim_outputs

# Block until any outstanding asynchronous checkpoint saves have completed.
mngr.wait_until_finished()

### Benchmarking

In [None]:
# Optional: time one full run (single loop, single repeat).
# NOTE(review): this reuses `mngr` and therefore the same checkpoint directory as the
# run above -- confirm that re-saving into it is intended before enabling.
# %timeit -n 1 -r 1 multi_step_simulation_with_checkpointing(system, state, mngr, nr_iter, nr_snapshots, nr_checkpoints)

### Restoring the checkpoint and checking that it is correct

In [None]:
# Zero-filled template passed to restore_state as the restore target. Each field
# mirrors the shape and dtype of the corresponding field of the initial `state`,
# so the template tracks the lattice size (e.g. the number of directions in f and
# N) and the dtypes actually saved, instead of hard-coding them.
temp_state = State(
    rho=jnp.zeros_like(state["rho"]),
    pressure=jnp.zeros_like(state["pressure"]),
    u=jnp.zeros_like(state["u"]),
    phase_field=jnp.zeros_like(state["phase_field"]),
    obs=jnp.zeros_like(state["obs"]),
    obs_velocity=jnp.zeros_like(state["obs_velocity"]),
    f=jnp.zeros_like(state["f"]),
    N=jnp.zeros_like(state["N"]),
)

In [None]:
# Restore a checkpoint from `mngr` into the structure of `temp_state`
# (presumably the latest saved step -- see restore_state).
restored_state = restore_state(mngr, temp_state)
# The zero-filled template is no longer needed.
del temp_state

In [None]:
# Every field of the restored state must equal the final simulated state
# element-wise (same shape, every entry equal). A check such as
# np.any(a == b) would pass if only a single entry happened to match.
for key in final_state:
    assert np.array_equal(restored_state[key], final_state[key]), f"State not matching for key={key}"

# The manager's most recent step should correspond to the end of the run.
assert mngr.latest_step() == nr_iter, "Number of iterations completed till now does not match"

## Re-running from the last checkpoint

In [None]:
# Iterations for the continued run; the snapshot and checkpoint counts are
# re-declared for this second run (same values as the first run).
new_nr_iter, nr_snapshots, nr_checkpoints = 100, 20, 5

In [None]:
# Second manager over the same directory, so the continued run can find the
# checkpoints written above; this one retains at most the 2 most recent.
new_mngr_options = ocp.CheckpointManagerOptions(
    max_to_keep=2,
    save_interval_steps=new_nr_iter // nr_checkpoints,
)
new_mngr = ocp.CheckpointManager(checkpoint_dir, options=new_mngr_options)

In [None]:
# Continue the simulation from the last checkpoint in `checkpoint_dir`.
# NOTE(review): the trailing positional True appears to be the "resume from
# checkpoint" switch of multi_step_simulation_with_checkpointing -- confirm.
_, final_state = multi_step_simulation_with_checkpointing(
    system, state, new_mngr, new_nr_iter, nr_snapshots, nr_checkpoints, True
)
# As after the first run, wait for asynchronous checkpoint saves to complete
# before the manager is queried (e.g. via latest_step()).
new_mngr.wait_until_finished()

In [None]:
# The latest saved step must equal the iterations accumulated over both runs.
expected_step = nr_iter + new_nr_iter
assert new_mngr.latest_step() == expected_step, "Number of iterations completed till now does not match"