In [None]:
import jax.numpy as jnp
from einops import rearrange
import numpy as np

import meshio

In [None]:
from cglbm.lbm import grid_eq_dist, eq_dist_phase_field
from cglbm.simulation import multi_step_simulation
from cglbm.environment import State, System

In [None]:
# from jax import config

# config.update("jax_debug_nans", True)
# config.update("jax_enable_x64", True)

# Simulation Code

## Density and Velocity update

### Notes about the functions

1. Every function is pure, i.e. it has no side effects.
2. Each function operates on a single grid point.
3. These functions are vectorized to map over the entire grid.


# Initial conditions


In [None]:
from cglbm.test_utils import load_config

# Simulation parameters (grid size, densities, wall velocity, ...) come from this file.
config_path = "params.ini"
system = load_config(config_path)

## Phase Field

In [None]:
LX = system.LX
LY = system.LY
grid_shape = (LX, LY)

phase_field = jnp.zeros(grid_shape)

center = (LX/2, LY/2)
radius = LY/4
# indexing="ij" gives X[i, j] == i and Y[i, j] == j, both of shape grid_shape == (LX, LY),
# so they line up with every other field on the grid. The default "xy" indexing would
# return (LY, LX) arrays with the two axes swapped.
X, Y = jnp.meshgrid(jnp.arange(LX), jnp.arange(LY), indexing="ij")


### Initialising drop

# Sharp circular drop: nodes strictly inside the circle of radius `radius` around
# `center` get phase 1.0; every other node keeps its current phase value.
drop = ((X - center[0]) ** 2) + ((Y - center[1]) ** 2) < (radius ** 2)
phase_field = jnp.where(drop, 1.0, phase_field)

### Initialising drop with a smooth (tanh) interface

In [None]:
#### change to code when running
distanceFromCenter = jnp.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)

# Smooth drop: phase_field -> 1 inside the circle of radius `radius`, -> 0 outside,
# with a tanh transition whose thickness scales with `system.width`.
# This keeps phase_field in [0, 1] with the drop at 1, matching the sharp/square
# initialisers (inside set to 1.0) and the mixing rule
# rho = density_one * phase_field + density_two * (1 - phase_field).
# (The previous `1 + tanh(...)` ranged over [0, 2] and was ~0 inside the drop.)
phase_field = 0.5 * (1.0 - jnp.tanh((distanceFromCenter - radius) * 2.0 / system.width))

### Initialising square drop

#### Change this cell to a code cell to use this initialisation
# Square drop: an axis-aligned square of edge `side` centred on `center` gets phase 1.0;
# every other node keeps its current phase value.
side = min(LX, LY) / 2

inside_x = jnp.abs(X - center[0]) < side / 2
inside_y = jnp.abs(Y - center[1]) < side / 2
square = inside_x & inside_y
phase_field = jnp.where(square, 1.0, phase_field)

## Density, Velocity, Pressure

In [None]:
# Density mixed linearly between the two fluids by the phase field; uniform pressure.
rho = system.density_one * phase_field + system.density_two * (1.0 - phase_field)
pressure = jnp.full(grid_shape, system.ref_pressure)

# Linear shear start profile for u_x across the second grid axis, which is the axis
# the moving walls occupy in the Obstacle cell (3 nodes at each end).
# The profile is -uWallX at y = 2.5 and +uWallX at y = LY - 3.5, i.e. midway between
# the outermost fluid node and the adjacent wall node.
# It is built from a 1-D y coordinate and broadcast over x, so it varies along the
# wall-normal axis independently of how X/Y were generated.
y_coords = jnp.arange(LY)
u_x_profile = -system.uWallX + (y_coords - 2.5) * 2 * system.uWallX / (LY - 6)
u_x = jnp.broadcast_to(u_x_profile, grid_shape)
u_y = jnp.zeros(grid_shape)
# shape (LX, LY, 2): velocity components in the last axis
u = jnp.stack([u_x, u_y], axis=-1)

## Obstacle

In [None]:
# Obstacle mask and obstacle velocity. The channel walls are the first and last
# WALL_THICKNESS nodes along the second grid axis. The wall at the end of that axis
# moves with +uWallX, the one at the start with -uWallX (x component only).
WALL_THICKNESS = 3

obs = jnp.zeros(grid_shape, dtype=bool)
obs_velX = jnp.zeros(grid_shape)
obs_velY = jnp.zeros(grid_shape)

# cylinder = (X - center[0]) ** 2 + (Y - center[1]) ** 2 < radius ** 2
# obs = obs.at[cylinder].set(1.0)
# obs_velX = obs_velX.at[cylinder].set(system.uWallX)
# obs_velY = obs_velY.at[cylinder].set(system.uWallX)

# Slice updates cover every row at once. This replaces the former per-row Python loop,
# which issued one `.at[...]` update per x index. The order (end wall first, then
# start wall) is kept, so the start wall wins should the two ever overlap.
obs = obs.at[:, :WALL_THICKNESS].set(True).at[:, -WALL_THICKNESS:].set(True)
obs_velX = obs_velX.at[:, -WALL_THICKNESS:].set(system.uWallX)
obs_velX = obs_velX.at[:, :WALL_THICKNESS].set(-system.uWallX)

# shape (LX, LY, 2): obstacle velocity components in the last axis
obs_vel = jnp.stack([obs_velX, obs_velY], axis=-1)

In [None]:
# Equilibrium distributions with the fluid at rest (zero velocity everywhere):
# f is built from the phase field, N from the pressure.
zero_velocity = jnp.zeros((LX, LY, 2))
f = eq_dist_phase_field(system.cXYs, system.weights, phase_field, zero_velocity)
N = grid_eq_dist(
    system.cXYs,
    system.weights,
    system.phi_weights,
    pressure,
    zero_velocity,
)

In [None]:
# Bundle the initial distributions, macroscopic fields and obstacle description
# into the simulation state.
state = State(
    f=f,
    N=N,
    rho=rho,
    pressure=pressure,
    u=u,
    obs=obs,
    obs_velocity=obs_vel,
)

# Simulation

In [None]:
# Iteration and snapshot counts for the run; final_step is indexed by field name
# further down (e.g. final_step["u"]).
nr_iter, nr_snapshots = 10, 10
final_step = multi_step_simulation(system, state, nr_iter, nr_snapshots)

In [None]:
# %timeit multi_step_simulation(system, state, 1000, 100)

## Meshio

### Serialization

In [None]:
def create_mesh(Nx, Ny):
    """
    Create a 2-D meshio mesh made of Nx * Ny unit quad cells.

    Points lie on the integer lattice [0, Nx] x [0, Ny]. Cell ``i * Ny + j`` covers
    [i, i+1] x [j, j+1]. This is the C-order flattening of an (Nx, Ny) array, so
    ``field.flatten()`` for a field of shape (Nx, Ny) lines up with the cells.

    Args:
        Nx: int, number of cells along x (first array axis)
        Ny: int, number of cells along y (second array axis)

    Returns:
        mesh: meshio.Mesh with a single "quad" cell block
    """
    x = np.linspace(0, Nx, Nx + 1)
    y = np.linspace(0, Ny, Ny + 1)

    # "ij" indexing yields arrays of shape (Nx + 1, Ny + 1) with xv[i, j] == x[i] and
    # yv[i, j] == y[j]. The default "xy" indexing would yield (Ny + 1, Nx + 1), which
    # breaks the (Nx + 1, Ny + 1) reshape below and transposes the cell ordering.
    xv, yv = np.meshgrid(x, y, indexing="ij")
    points = np.stack((xv, yv), axis=-1).reshape(-1, 2)

    # point_ids[i, j] is the row of `points` holding (x[i], y[j])
    point_ids = np.arange(len(points)).reshape(Nx + 1, Ny + 1)

    # Corners listed counter-clockwise in the x-y plane:
    # (i, j) -> (i+1, j) -> (i+1, j+1) -> (i, j+1)
    quad_cells = np.stack(
        (
            point_ids[:-1, :-1],
            point_ids[1:, :-1],
            point_ids[1:, 1:],
            point_ids[:-1, 1:],
        ),
        axis=-1,
    ).reshape(-1, 4)

    return meshio.Mesh(points, [("quad", quad_cells)])

In [None]:
# One quad cell per lattice node of the (LX, LY) simulation grid.
mesh = create_mesh(Nx=LX, Ny=LY)

In [None]:
# Move the trailing (component) axis of final_step["u"] to the front and unpack it.
# Presumably final_step["u"] is (time, LX, LY, 2), giving two (time, LX, LY) arrays — verify.
ux_final, uy_final = rearrange(final_step["u"], "t i j c -> c t i j")

In [None]:
# Creating vtk visualization (creates a .vtk file)

# Store the 2nd frame (index 1) as cell data. meshio expects every cell-data entry to be
# a list holding one array per cell block; this mesh has a single "quad" block. A bare
# 1-D array is not accepted by meshio's writers.
mesh.cell_data["ux"] = [np.asarray(ux_final[1, :, :]).flatten()]
mesh.cell_data["uy"] = [np.asarray(uy_final[1, :, :]).flatten()]

mesh.write("lbm.vtk", file_format="vtk")

### Deserialization

In [None]:
# Assertion on data read

mesh = meshio.read("lbm.vtk", file_format="vtk")

# cell_data[name] holds one array per cell block; the mesh has a single block
cell_data_ux = mesh.cell_data["ux"][0].reshape((LX, LY))
cell_data_uy = mesh.cell_data["uy"][0].reshape((LX, LY))

# comparing only the frame written above (index 1), as there are NaNs in later frames.
# Assert instead of print so a round-trip mismatch stops the notebook.
assert np.allclose(cell_data_ux, ux_final[1, :, :]), "ux changed in the VTK round trip"
assert np.allclose(cell_data_uy, uy_final[1, :, :]), "uy changed in the VTK round trip"

### Time-series Serialization

In [None]:
# Creating time-series visualization (creates xdmf file)

with meshio.xdmf.TimeSeriesWriter("lbm.xdmf") as writer:
    writer.write_points_cells(mesh.points, mesh.cells)
    for t in range(len(ux_final)):
        # Pass cell data in meshio's documented form: a list with one array per cell
        # block (a single "quad" block here). Each array is flattened in C order to
        # match the mesh's cell ordering, instead of relying on meshio implicitly
        # concatenating the rows of a 2-D array.
        writer.write_data(
            t,
            cell_data={
                "ux": [np.asarray(ux_final[t]).ravel()],
                "uy": [np.asarray(uy_final[t]).ravel()],
            },
        )

### Time-series Deserialization

In [None]:
# Assertion on data read

cell_data_ux = []
cell_data_uy = []
with meshio.xdmf.TimeSeriesReader("lbm.xdmf") as reader:
    points, cells = reader.read_points_cells()
    for k in range(reader.num_steps):
        t, point_data, cell_data = reader.read_data(k)
        cell_data_ux.append(cell_data["ux"])
        cell_data_uy.append(cell_data["uy"])

# Take the frame count from what was read instead of assuming nr_snapshots + 1,
# then check it agrees with the simulation output.
cell_data_ux = np.stack(cell_data_ux).reshape(-1, LX, LY)
cell_data_uy = np.stack(cell_data_uy).reshape(-1, LX, LY)
assert cell_data_ux.shape == tuple(ux_final.shape), (cell_data_ux.shape, ux_final.shape)
assert cell_data_uy.shape == tuple(uy_final.shape), (cell_data_uy.shape, uy_final.shape)

# comparing only 2 frames, as there are NaNs after that.
# Assert instead of print so a round-trip mismatch stops the notebook.
assert np.allclose(ux_final[0:2], cell_data_ux[0:2]), "ux changed in the XDMF round trip"
assert np.allclose(uy_final[0:2], cell_data_uy[0:2]), "uy changed in the XDMF round trip"
