<a href="https://colab.research.google.com/github/hashimmg/jax_IB/blob/main/Flapping_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%load_ext autoreload
%autoreload 2

In [68]:
import numpy as np
import tree_math as tm
import jax
import jax.numpy as jnp
from jax_ib.base import particle_class as pc
from jax_ib.base import grids, fast_diagonalization, boundaries, pressure, diffusion, advection, finite_differences, IBM_Force,convolution_functions,particle_motion, equations
import jax_cfd.base as cfd
import jax_ib.MD as MD
from jax import random
from jax_md import space, quantity
import jax_ib
import jax_ib.base as ib
from jax_ib.base import kinematics as ks
from jax.random import uniform as random_uniform
import matplotlib.pyplot as plt
import functools as fct
import scipy 
import tree_math
from jax_ib.base import array_utils

In [3]:
from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

In [4]:
jax.devices()

[CudaDevice(id=0),
 CudaDevice(id=1),
 CudaDevice(id=2),
 CudaDevice(id=3),
 CudaDevice(id=4),
 CudaDevice(id=5),
 CudaDevice(id=6),
 CudaDevice(id=7)]

In [72]:
mesh2d = jax.make_mesh((4, 2), ('i', 'j'))

In [73]:
#grid = jax_ib.base.grids.Grid(shape=(16,16) ,domain = ((0,16),(0,16)), device_mesh=mesh2d)


In [90]:
density = 1.0    # fluid density
viscosity = 0.05 # fluid viscocity
dt=5e-4          # time step 
num_boundaries = 4
grid = grids.Grid((128,128), domain=((0,15.),(0,15.0)), device_mesh = mesh2d, periods = (15,15)) #gridpoints in 
bc_fns = [lambda t: 0.0 for _ in range(4)]
vx_bc=((0.0, 0.0), (0.0, 0.0))
vy_bc=((0.0, 0.0), (0.0, 0.0))

velocity_bc = (boundaries.new_periodic_boundary_conditions(ndim=2,bc_vals=vx_bc,bc_fn=bc_fns,time_stamp=0.0),
               boundaries.new_periodic_boundary_conditions(ndim=2,bc_vals=vy_bc,bc_fn=bc_fns,time_stamp=0.0))

vx_fn = lambda x, y: jnp.zeros_like(x)
vy_fn = lambda x, y: jnp.zeros_like(x)
vx_0, vy_0 = tuple(
    [
       grids.GridVariable
        (
           grid.eval_on_mesh(fn = lambda x, y: jnp.ones_like(x), offset = offset), bc # initial values for fluid velocities are 0 both in x and y direction
        ) 
        for offset, bc in zip(grid.cell_faces,velocity_bc)
    ]
)
v0 = (vx_0, vy_0)
global_pressure = grids.GridVariable(
    grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid), 
    boundaries.get_pressure_bc_from_velocity((vx_0, vy_0)))




subgrid=grid.subgrid((1,1),boundary_layer_widths=(1,1))

In [122]:
def gather_boundary_layers(array):
    I, J = jax.lax.psum(1, 'i'), jax.lax.psum(1, 'j')
    left_halo = jax.lax.ppermute(array[:,-1], 'j', [(j, (j + 1) % J) for j in range(J)])
    right_halo = jax.lax.ppermute(array[:,0], 'j', [(j, (j - 1) % J) for j in range(J)])
    upper_halo = jax.lax.ppermute(array[-1,:], 'i', [(i, (i + 1) % I) for i in range(I)])
    lower_halo = jax.lax.ppermute(array[0,:], 'i', [(i, (i - 1) % I) for i in range(I)])
    
    upper_halo = jnp.append(0.0, jnp.append(upper_halo, 0.0))
    lower_halo = jnp.append(0.0, jnp.append(lower_halo, 0.0))
    
    temp = jnp.concatenate([left_halo[:,None], array, right_halo[:,None]], axis=1)
    return  jnp.concatenate([upper_halo[None, :], temp, lower_halo[None,:]], axis=0)

def convect(v):
    return tuple(advection.advect_upwind(u, v, dt) for u in v)


# we assume all subgrids have identical shapes and steps
def get_subgrid_laplacians(grid, boundary_layer_widths):
    subgrid = grid.subgrid((1,1), boundary_layer_widths)
    return tuple([jnp.array(a) for a in list(map(array_utils.laplacian_matrix,subgrid.shape, subgrid.step))])

####################################################################
# this requires more investigation to be sure it actually runs distributed and isnt' causing any
# unwanted host-device comms
boundary_layer_widths = (1,1)
laplacians = get_subgrid_laplacians(grid, boundary_layer_widths)
pinv = fast_diagonalization.pseudoinverse(
    laplacians,global_pressure.array.data.dtype, 
    hermitian=True,circulant=True,implementation=None)
pinv(np.ones(laplacians[1].shape[0]).astype(np.float32))
####################################################################


@partial(shard_map, mesh=mesh2d, in_specs=(P('i','j'),(P('i','j'), P('i','j')), None,None,(None,None)), out_specs=((P('i','j'),P('i','j')), P('i','j')))
def master(global_pressure, global_velocities, grid, boundary_layer_widths, laplacians):
    local_pressure = gather_boundary_layers(global_pressure)
    global_ux, global_uy = global_velocities

    
    local_ux = gather_boundary_layers(global_ux)
    local_uy = gather_boundary_layers(global_uy)
    i = jax.lax.axis_index('i')
    j = jax.lax.axis_index('j')
    subgrid = grid.subgrid((i,j), boundary_layer_widths=boundary_layer_widths)#, boundary_conditions = ('periodic','periodic'))
    bc_fns = [lambda t: 0.0 for _ in range(4)]
    ux_bc=((0.0, 0.0), (0.0, 0.0))
    uy_bc=((0.0, 0.0), (0.0, 0.0))
    pressure_bc = ((0.0, 0.0), (0.0, 0.0))
    boundary_condition = boundaries.new_periodic_boundary_conditions(ndim=2,bc_vals=pressure_bc,bc_fn=bc_fns,time_stamp=0.0)

    velocity_bc = (boundaries.new_periodic_boundary_conditions(ndim=2,bc_vals=ux_bc,bc_fn=bc_fns,time_stamp=0.0),
                   boundaries.new_periodic_boundary_conditions(ndim=2,bc_vals=uy_bc,bc_fn=bc_fns,time_stamp=0.0))

    local_pressure = grids.GridVariable(grids.GridArray(local_pressure, subgrid.cell_center, subgrid), boundary_condition)
    local_ux = grids.GridVariable(grids.GridArray(local_ux, subgrid.cell_faces[0], subgrid), velocity_bc[0])
    local_uy = grids.GridVariable(grids.GridArray(local_uy, subgrid.cell_faces[1], subgrid), velocity_bc[1])
    
    explicit_update = equations.navier_stokes_explicit_terms(
        density=1.0, viscosity=1.0, dt=5E-4,grid=subgrid, convect=convect, diffuse=diffusion.diffuse, forcing=None)
    explicit = explicit_update((local_ux, local_uy))
    dP = finite_differences.forward_difference(local_pressure)

    local_u_star = tuple([u.array.data + dt * e.array.data - dp.data for u, e, dp in zip((local_ux, local_uy), explicit, dP)])
    local_u_star = tuple([grids.GridVariable(grids.GridArray(u, os, subgrid), bc) for os, u, bc in zip(subgrid.cell_faces, local_u_star, velocity_bc)])
   
    local_u_final, new_local_pressure= pressure.projection_and_update_pressure(local_pressure, local_u_star, pinv)
    
    return tuple([u.array.data for u in local_u_final]), new_local_pressure.array.data

In [123]:
subgrid_laplacians = get_subgrid_laplacians(grid, boundary_layer_widths)
y = master(global_pressure.array.data, (vx_0.array.data, vy_0.array.data), grid,boundary_layer_widths, subgrid_laplacians)
print('FINAL RESULT:\n', y)

FINAL RESULT:
 ((Array([[0.5355639 , 0.8930345 , 0.9865329 , ..., 1.1325518 , 1.2622547 ,
        0.802761  ],
       [0.7995359 , 0.8930347 , 0.9540392 , ..., 1.0727541 , 1.0738642 ,
        0.90533894],
       [0.9048818 , 0.9255292 , 0.9540402 , ..., 1.0323198 , 1.0123001 ,
        0.947328  ],
       ...,
       [0.905082  , 1.0711942 , 1.0723054 , ..., 0.95433354, 0.8939764 ,
        0.801521  ],
       [0.8004292 , 1.2595847 , 1.1338688 , ..., 0.98643106, 0.8939762 ,
        0.53789556],
       [0.5355639 , 1.1570066 , 1.0918789 , ..., 1.0904495 , 1.1576016 ,
        0.53789556]], dtype=float32), Array([[0.5374064 , 0.8013785 , 0.9067246 , ..., 0.90692484, 0.80227196,
        0.5374064 ],
       [0.8948771 , 0.8948769 , 0.92737055, ..., 1.0730366 , 1.261427  ,
        1.1588492 ],
       [0.98837584, 0.9558815 , 0.9558804 , ..., 1.0741467 , 1.1357107 ,
        1.0937217 ],
       ...,
       [1.1343938 , 1.0745953 , 1.034161  , ..., 0.95617574, 0.98827404,
        1.0922922 ],
  

In [8]:

fmv = 6.79
sp = 1.72
price = 13.00
shares = 100000


strike = shares * 1.72
einkst = shares * (fmv - sp) * 0.5
kapst = shares * (price - fmv) * 0.275
print(shares * price - kapst - einkst-strike, kapst, einkst)




703725.0 170775.0 253500.0


In [None]:
*