# Serialization with Meshio

This notebook contains the steps to initialise the state, run the simulation and serialize using Meshio.

In this notebook, meshio is used to serialize data to, and deserialize it from, VTK and XDMF files.

A single snapshot of the system state is stored as a VTK file.

The full time series of snapshots is stored in the XDMF file format.

Both VTK and XDMF files can be visualised using ParaView.

## Installing dependencies and Importing modules

In [None]:
%pip install h5py meshio

In [None]:
import jax.numpy as jnp
from einops import rearrange
import numpy as np

import meshio

In [None]:
from cglbm.lbm import grid_eq_dist, eq_dist_phase_field
from cglbm.simulation import multi_step_simulation
from cglbm.environment import State, System

## Simulation Setup

### Loading pre-defined environment

In [None]:
from cglbm.config import load_sandbox_config

# Load the pre-defined "stationary drop" system from its config file. Later
# cells read their parameters from this object (e.g. LX, LY, drop_radius,
# width, density_one/density_two, ref_pressure, uWallX, cXYs, weights).
system = load_sandbox_config("stationary-drop-config.ini")

### Initial conditions of simulation

In [None]:
# Grid dimensions from the loaded config.
LX = system.LX
LY = system.LY
# meshgrid's default "xy" indexing: X and Y both have shape (LY, LX) with
# X[j, i] == i and Y[j, i] == j, i.e. axis 0 is y and axis 1 is x.
X, Y = jnp.meshgrid(jnp.arange(LX), jnp.arange(LY))

grid_shape = X.shape # this is taken from meshgrid, can also be Y.shape; equals (LY, LX)
# Placeholder only: phase_field is recomputed from the drop profile in the
# "Initializing drop" cell before it is used.
phase_field = jnp.zeros(grid_shape)
# Grid centre in (y, x) order, matching the [Y, X] coordinate stacking used
# when the drop is initialised.
center = (grid_shape[0]//2, grid_shape[1]//2)

radius = system.drop_radius

### Initializing drop

In [None]:
# Integer (y, x) coordinates of every node, stacked on a trailing axis:
# shape (LY, LX, 2).
coordinates = jnp.stack([Y, X], axis=-1)
# Euclidean distance of each node from the grid centre (also in (y, x) order).
distanceFromCenter = jnp.sqrt(
    jnp.square(coordinates - jnp.array(center)).sum(axis=2)
)

# Smooth tanh interface of thickness ~system.width around the circle of the
# given radius: values tend to 0 inside the drop and to 1 outside it.
phase_field = 0.5 * (1.0 + jnp.tanh((distanceFromCenter - radius) * 2.0 / system.width))

### Initializing Density, Velocity, Pressure

In [None]:
# Density blends linearly between density_one (where phase_field == 1) and
# density_two (where phase_field == 0); pressure starts uniform.
rho = system.density_one * phase_field + system.density_two * (1.0 - phase_field)
pressure = jnp.full(grid_shape, system.ref_pressure)

# Linear x-velocity profile along Y: -uWallX at Y == 2.5 rising to +uWallX
# at Y == LY - 3.5. There is no initial y-velocity.
u_x = -system.uWallX + (Y - 2.5) * 2 * system.uWallX / (LY - 6)
u_y = jnp.zeros(grid_shape)
# Velocity components on a trailing axis: shape (LY, LX, 2).
u = jnp.stack([u_x, u_y], axis=-1)

### Defining Obstacle

In [None]:
# Wall mask: True on the two outermost columns at each side. Axis 1 is x for
# arrays built from the meshgrid above.
obs = jnp.zeros(grid_shape, dtype=bool)
obs = obs.at[:, [0, 1, -2, -1]].set(True)

# Wall velocity x-component: +uWallX on the last two columns and -uWallX on
# the first two. The y-component is zero everywhere.
obs_velX = jnp.zeros(grid_shape)
obs_velX = obs_velX.at[:, [-2, -1]].set(system.uWallX)
obs_velX = obs_velX.at[:, [0, 1]].set(-system.uWallX)
obs_velY = jnp.zeros(grid_shape)

# NOTE(review): the walls are placed on x-columns, while the initial u_x
# profile varies with Y; confirm the intended wall orientation.
# Wall velocities on a trailing axis: shape (LY, LX, 2).
obs_vel = jnp.stack([obs_velX, obs_velY], axis=-1)

### Initialising f and N

In [None]:
# Initial distributions, presumably equilibrium distributions (per the helper
# names): f is built from phase_field, N from pressure. Both use a zero
# velocity field (zeros shaped like `coordinates`, i.e. one 2-vector per node)
# rather than the `u` set above.
# NOTE(review): confirm that starting from zero velocity is intended when
# system.uWallX != 0.
f = eq_dist_phase_field(system.cXYs, system.weights, phase_field, jnp.zeros(coordinates.shape))
N = grid_eq_dist(system.cXYs, system.weights, system.phi_weights, pressure, jnp.zeros(coordinates.shape))

### Initialising state of the simulation

In [None]:
# Bundle all initial fields into the State that multi_step_simulation
# consumes below.
state = State(
    rho=rho,
    pressure=pressure,
    u=u,
    phase_field=phase_field,
    obs=obs,
    obs_velocity=obs_vel,
    f=f,
    N=N
)

## Running the Simulation

In [None]:
# nr_iter and nr_snapshots are passed straight to multi_step_simulation;
# presumably the iteration count and the number of stored snapshots — confirm
# against its definition. The second returned value is discarded.
nr_iter = 10
nr_snapshots = 10
final_results, _ = multi_step_simulation(system, state, nr_iter, nr_snapshots)

In [None]:
# Split the stored velocity snapshots into their two components. The pattern
# moves the last axis (the velocity component, labelled "x" in the pattern)
# to the front so it can be unpacked into two arrays; that axis must have
# size 2. Each result keeps the remaining (t, i, y) axes.
# NOTE(review): presumably (t, LY, LX), matching the (LY, LX) state arrays.
ux_final, uy_final = rearrange(final_results["u"], "t i y x -> x t i y")
phase_field_final = final_results["phase_field"]

## Meshio Serialization

In [None]:
def create_mesh(Nx, Ny):
    """
    Create a structured 2D mesh of Nx x Ny unit quad cells.

    Points lie on the integer lattice [0, Nx] x [0, Ny] and are numbered row
    by row (x varies fastest). Cells are numbered the same way: cell
    ``j * Nx + i`` covers [i, i + 1] x [j, j + 1]. This matches the row-major
    ``.flatten()`` of an array of shape ``(Ny, Nx)``.

    Args:
        Nx: int, number of cells along x
        Ny: int, number of cells along y

    Returns:
        mesh: meshio.Mesh with (Nx + 1) * (Ny + 1) 2D points and a single
            "quad" cell block of Nx * Ny cells
    """
    x = np.linspace(0, Nx, Nx + 1)
    y = np.linspace(0, Ny, Ny + 1)

    # Default "xy" indexing: xv and yv have shape (Ny + 1, Nx + 1).
    xv, yv = np.meshgrid(x, y)
    points = np.stack((xv, yv), axis=-1).reshape(-1, 2)

    # The index grid must use the same (Ny + 1, Nx + 1) layout as the points
    # above. Reshaping to (Nx + 1, Ny + 1) garbles the connectivity whenever
    # Nx != Ny.
    point_ids = np.arange(len(points)).reshape(Ny + 1, Nx + 1)

    # Corners of cell (j, i), listed counterclockwise as VTK expects for quads.
    lower_left = point_ids[:-1, :-1]
    lower_right = point_ids[:-1, 1:]
    upper_right = point_ids[1:, 1:]
    upper_left = point_ids[1:, :-1]
    quad_cells = np.stack((lower_left, lower_right, upper_right, upper_left), axis=-1)
    quad_cells = quad_cells.reshape(-1, 4)

    return meshio.Mesh(points, [("quad", quad_cells)])

### Serialization to VTK file

In [None]:
mesh = create_mesh(LX, LY)

# Storing 2nd frame in cell_data
frame = 1

# meshio expects every cell_data entry to be a list with one array per cell
# block; create_mesh produces a single "quad" block. Flattening the snapshot
# (presumably shape (LY, LX)) row-major matches create_mesh's cell numbering.
mesh.cell_data["ux"] = [np.asarray(ux_final[frame]).ravel()]
mesh.cell_data["uy"] = [np.asarray(uy_final[frame]).ravel()]

mesh.write("lbm.vtk", file_format="vtk")

### Deserialization from VTK file

In [None]:
mesh = meshio.read("lbm.vtk", file_format="vtk")

# Reading data: one array per cell block, and there is a single "quad" block.
# Cells are numbered row-major over (y, x), so restore the (LY, LX) layout.
cell_data_ux = mesh.cell_data["ux"][0].reshape((LY, LX))
cell_data_uy = mesh.cell_data["uy"][0].reshape((LY, LX))

# Assertions: the round trip must reproduce the single frame that was written.
# equal_nan=True so that NaNs from the simulation compare as equal.
print(np.allclose(cell_data_ux, ux_final[frame], equal_nan=True))
print(np.allclose(cell_data_uy, uy_final[frame], equal_nan=True))

### Serialization of time-series data to XDMF file

In [None]:
mesh = create_mesh(LX, LY)

with meshio.xdmf.TimeSeriesWriter("lbm.xdmf") as writer:
    # The mesh is written once; each time step then only adds its data.
    writer.write_points_cells(mesh.points, mesh.cells)
    for t, (ux_t, uy_t) in enumerate(zip(ux_final, uy_final)):
        # One flattened array per cell block, matching create_mesh's row-major
        # (y, x) cell numbering.
        writer.write_data(
            t,
            cell_data={
                "ux": [np.asarray(ux_t).ravel()],
                "uy": [np.asarray(uy_t).ravel()],
            },
        )

### Deserialization of time-series data from XDMF file

In [None]:
cell_data_ux = []
cell_data_uy = []

# Reading data
with meshio.xdmf.TimeSeriesReader("lbm.xdmf") as reader:
    points, cells = reader.read_points_cells()
    for k in range(reader.num_steps):
        _, _, cell_data = reader.read_data(k)
        # Single "quad" cell block: take its (flattened) array.
        cell_data_ux.append(cell_data["ux"][0])
        cell_data_uy.append(cell_data["uy"][0])

# Restore (num_steps, LY, LX); cells are numbered row-major over (y, x).
cell_data_ux = np.stack(cell_data_ux).reshape(-1, LY, LX)
cell_data_uy = np.stack(cell_data_uy).reshape(-1, LY, LX)

# Assertions: compare every frame. equal_nan=True lets the frames that
# contain NaNs take part in the check instead of being sliced away.
print(np.allclose(ux_final, cell_data_ux, equal_nan=True))
print(np.allclose(uy_final, cell_data_uy, equal_nan=True))
