<a href="https://colab.research.google.com/github/hashimmg/jax_IB/blob/main/Flapping_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
%load_ext autoreload
%autoreload 2

In [26]:
import numpy as np
import tree_math as tm
import jax
import jax.numpy as jnp
from jax_ib.base import particle_class as pc
from jax_ib.base import grids, fast_diagonalization, boundaries, pressure, diffusion, advection, finite_differences, IBM_Force,convolution_functions,particle_motion, equations
import jax_cfd.base as cfd
import jax_ib.MD as MD
from jax import random
from jax_md import space, quantity
import jax_ib
import jax_ib.base as ib
from jax_ib.base import kinematics as ks
from jax.random import uniform as random_uniform
import matplotlib.pyplot as plt
import functools as fct
import scipy 
import tree_math
from jax_ib.base import array_utils

In [27]:
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
import numpy as np

In [28]:
# Two-dimensional device mesh: 4 devices along axis 'i' and 2 along axis 'j'.
mesh_axis_sizes = (4, 2)
mesh_axis_names = ('i', 'j')
mesh2d = jax.make_mesh(mesh_axis_sizes, mesh_axis_names)

In [29]:
# --- physical and numerical parameters ---------------------------------
density = 1.0       # fluid density
viscosity = 0.05    # fluid viscosity
dt = 5e-4           # time step
num_boundaries = 4

# --- global grid: 128 x 128 cells on [0, 15] x [0, 15], built with mesh2d ---
grid = grids.Grid(
    (128, 128),
    domain=((0, 15.), (0, 15.0)),
    device_mesh=mesh2d,
    periods=(15, 15),
)

# --- boundary conditions -------------------------------------------------
# Four functions of time that always return 0.0, plus all-zero boundary values.
bc_fns = [lambda t: 0.0 for _ in range(4)]
vx_bc = ((0.0, 0.0), (0.0, 0.0))
vy_bc = ((0.0, 0.0), (0.0, 0.0))

# One periodic boundary-condition object per velocity component (x first, then y).
velocity_bc = tuple(
    boundaries.new_periodic_boundary_conditions(
        ndim=2, bc_vals=component_vals, bc_fn=bc_fns, time_stamp=0.0
    )
    for component_vals in (vx_bc, vy_bc)
)

# Zero-valued field functions; they are NOT used by the initialisation below.
vx_fn = lambda x, y: jnp.zeros_like(x)
vy_fn = lambda x, y: jnp.zeros_like(x)

# --- initial fields ------------------------------------------------------
# Each velocity component is evaluated on its own cell face and filled with
# ones (jnp.ones_like), then paired with its boundary condition.
vx_0, vy_0 = (
    grids.GridVariable(
        grid.eval_on_mesh(fn=lambda x, y: jnp.ones_like(x), offset=face_offset),
        face_bc,
    )
    for face_offset, face_bc in zip(grid.cell_faces, velocity_bc)
)
v0 = (vx_0, vy_0)

# Pressure starts at zero on the cell centers; its boundary condition is
# derived from the velocity components.
global_pressure = grids.GridVariable(
    grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid),
    boundaries.get_pressure_bc_from_velocity((vx_0, vy_0)),
)

# Subgrid for mesh position (1, 1) with a boundary layer of width 1 per axis.
subgrid = grid.subgrid((1, 1), boundary_layer_widths=(1, 1))

In [30]:
def add_halo_layer(array):
    """Wrap a local 2-D patch in a one-cell halo fetched from neighbouring devices.

    Must be called in a context where the named axes 'i' and 'j' are bound
    (for example inside ``shard_map``). Rows are exchanged along 'i' and
    columns along 'j'; neighbours wrap around periodically on both axes.

    Args:
      array: 2-D array holding this device's patch.

    Returns:
      An array with two more rows and two more columns than ``array``. The
      interior is ``array``; the first/last column hold the neighbouring
      patches' last/first column, the first/last row hold the neighbouring
      patches' last/first row, and the four corner entries are 0.0.
    """
    # Summing the constant 1 over a named axis gives the number of devices on it.
    count_i = jax.lax.psum(1, 'i')
    count_j = jax.lax.psum(1, 'j')

    def shift(values, axis_name, axis_size, step):
        # Every device sends `values` to the device `step` positions further
        # along `axis_name` (with wrap-around) and receives its counterpart.
        routes = [(src, (src + step) % axis_size) for src in range(axis_size)]
        return jax.lax.ppermute(values, axis_name, routes)

    from_left = shift(array[:, -1], 'j', count_j, 1)
    from_right = shift(array[:, 0], 'j', count_j, -1)
    from_above = shift(array[-1, :], 'i', count_i, 1)
    from_below = shift(array[0, :], 'i', count_i, -1)

    # Attach the halo columns first, then the halo rows padded with zero corners.
    middle = jnp.concatenate(
        [from_left[:, None], array, from_right[:, None]], axis=1
    )
    top_row = jnp.append(0.0, jnp.append(from_above, 0.0))[None, :]
    bottom_row = jnp.append(0.0, jnp.append(from_below, 0.0))[None, :]
    return jnp.concatenate([top_row, middle, bottom_row], axis=0)

def convect(v):
    """Apply upwind advection to every component of ``v``.

    Each component is advected by the full field ``v`` using the module-level
    time step ``dt``.

    Args:
      v: sequence of velocity components.

    Returns:
      Tuple with one ``advection.advect_upwind`` result per component of ``v``.
    """
    advected = []
    for component in v:
        advected.append(advection.advect_upwind(component, v, dt))
    return tuple(advected)


# we assume all subgrids have identical shapes and steps
def get_subgrid_laplacians(grid, boundary_layer_widths):
    """Build one Laplacian matrix per axis of a representative subgrid.

    The subgrid at mesh position (1, 1) is used as the representative, on the
    assumption that all subgrids have identical shapes and steps.

    Args:
      grid: global grid providing ``subgrid``.
      boundary_layer_widths: boundary-layer widths forwarded to ``grid.subgrid``.

    Returns:
      Tuple of jax arrays, one ``array_utils.laplacian_matrix`` per
      (size, step) pair of the subgrid.
    """
    patch = grid.subgrid((1, 1), boundary_layer_widths)
    return tuple(
        jnp.array(array_utils.laplacian_matrix(axis_size, axis_step))
        for axis_size, axis_step in zip(patch.shape, patch.step)
    )

####################################################################
# TODO: this needs more investigation to be sure it actually runs distributed
# and isn't causing any unwanted host-device communication.
boundary_layer_widths = (1, 1)
laplacians = get_subgrid_laplacians(grid, boundary_layer_widths)

# Pseudoinverse solver built from the per-axis subgrid Laplacians, using the
# dtype of the global pressure data.
pinv = fast_diagonalization.pseudoinverse(
    laplacians,
    global_pressure.array.data.dtype,
    hermitian=True,
    circulant=True,
    implementation=None,
)

# Call the solver once on a float32 vector of ones; the result is discarded.
pinv(np.ones(laplacians[1].shape[0], dtype=np.float32))
####################################################################

    
@partial(
    shard_map,
    mesh=mesh2d,
    in_specs=(P('i', 'j'), (P('i', 'j'), P('i', 'j')), None, None, (None, None), None),
    out_specs=((P('i', 'j'), P('i', 'j')), P('i', 'j')),
)
def evolve_navier_stokes(pressure_field, velocities, grid, boundary_layer_widths, laplacians, num_steps):
    """Advance the local pressure and velocity patches by ``num_steps`` steps.

    Runs under ``shard_map`` on ``mesh2d``; the pressure and both velocity
    components are partitioned over the mesh axes ('i', 'j'). In every step
    each patch is extended with a one-cell halo from its neighbours
    (``add_halo_layer``), the explicit Navier-Stokes terms and the pressure
    difference are applied, the result is projected with
    ``pressure.projection_and_update_pressure``, and the halo is stripped again.

    Args:
      pressure_field: local pressure patch (2-D array).
      velocities: pair ``(ux, uy)`` of local velocity patches (2-D arrays).
      grid: global grid; ``grid.subgrid`` supplies the patch for this device.
      boundary_layer_widths: boundary-layer widths forwarded to ``grid.subgrid``.
      laplacians: currently unused in the body; the projection uses the
        module-level ``pinv`` solver instead.
      num_steps: number of time steps (the Python loop is unrolled).

    Returns:
      ``((ux, uy), pressure_field)`` after ``num_steps`` steps, without halo.
    """
    # Create the subgrid for the patch owned by this device.
    i = jax.lax.axis_index('i')
    j = jax.lax.axis_index('j')
    subgrid = grid.subgrid((i, j), boundary_layer_widths=boundary_layer_widths)

    # NOTE(review): viscosity is hard-coded to 1.0 here while the module-level
    # `viscosity` is 0.05 -- confirm which value is intended.
    explicit_update = equations.navier_stokes_explicit_terms(
        density=1.0, viscosity=1.0, dt=5E-4, grid=subgrid, convect=convect,
        diffuse=diffusion.diffuse, forcing=None)

    # Hard-coded (all-zero) periodic boundary conditions for pressure and for
    # each velocity component.
    bc_fns = [lambda t: 0.0 for _ in range(4)]
    zero_bc_vals = ((0.0, 0.0), (0.0, 0.0))

    def make_periodic_bc():
        return boundaries.new_periodic_boundary_conditions(
            ndim=2, bc_vals=zero_bc_vals, bc_fn=bc_fns, time_stamp=0.0)

    pressure_bc = make_periodic_bc()
    # Use the locally built velocity boundary conditions (constructed with the
    # same arguments as the module-level `velocity_bc`) so the function does
    # not depend on that global.
    local_velocity_bc = (make_periodic_bc(), make_periodic_bc())

    ux, uy = velocities
    for _ in range(num_steps):
        # Extend every field with a one-cell halo from the neighbouring patches.
        local_pressure = add_halo_layer(pressure_field)
        local_ux = add_halo_layer(ux)
        local_uy = add_halo_layer(uy)

        # Wrap into GridVariable to make jax_ib happy.
        local_pressure = grids.GridVariable(
            grids.GridArray(local_pressure, subgrid.cell_center, subgrid), pressure_bc)
        local_ux = grids.GridVariable(
            grids.GridArray(local_ux, subgrid.cell_faces[0], subgrid), local_velocity_bc[0])
        local_uy = grids.GridVariable(
            grids.GridArray(local_uy, subgrid.cell_faces[1], subgrid), local_velocity_bc[1])

        explicit = explicit_update((local_ux, local_uy))
        dP = finite_differences.forward_difference(local_pressure)

        # Intermediate velocity: explicit terms scaled by the module-level time
        # step `dt`, minus the forward difference of the pressure.
        local_u_star = tuple(
            u.array.data + dt * e.array.data - dp.data
            for u, e, dp in zip((local_ux, local_uy), explicit, dP)
        )
        local_u_star = tuple(
            grids.GridVariable(grids.GridArray(u, offset, subgrid), bc)
            for offset, u, bc in zip(subgrid.cell_faces, local_u_star, local_velocity_bc)
        )

        local_u_final, new_local_pressure = pressure.projection_and_update_pressure(
            local_pressure, local_u_star, pinv)

        # Strip the halo. The y-component must come from index 1; the previous
        # code read index 0 for both components, overwriting uy with ux.
        ux = local_u_final[0].array.data[1:-1, 1:-1]
        uy = local_u_final[1].array.data[1:-1, 1:-1]
        pressure_field = new_local_pressure.array.data[1:-1, 1:-1]

    return (ux, uy), pressure_field


In [None]:
# Rebuild the subgrid Laplacians, run four time steps from the initial state
# and print the resulting ((ux, uy), pressure) structure.
subgrid_laplacians = get_subgrid_laplacians(grid, boundary_layer_widths)
y = evolve_navier_stokes(
    global_pressure.array.data,
    (vx_0.array.data, vy_0.array.data),
    grid,
    boundary_layer_widths,
    subgrid_laplacians,
    4,
)
print('FINAL RESULT:\n', y)