In [None]:
import os
os.environ[ 'JAX_PLATFORM_NAME' ] = 'cpu'

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import scipy.signal as spsig
from scipy.ndimage import binary_dilation, binary_erosion
import h5py as h5
import Pulse as pulse
import glob

# logging
from logzero import logger

# jwave imports
from jwave.geometry import Domain, Medium, TimeAxis, Sources, sphere_mask, circ_mask
import jax.numpy as jnp
from jax import jit
from jwave import FourierSeries
from jwave.utils import show_field
import jax

from jax import jit
from jax import numpy as jnp
from jaxdf import FourierSeries
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from jwave.acoustics import simulate_wave_propagation
from jwave.geometry import *

# pyvista/meshing imports
import pyvista as pv

In [None]:
# Keep JAX 64-bit types disabled (its default), so jnp arrays built below are 32-bit.
jax.config.update( 'jax_enable_x64', False )
# Interactive matplotlib figures in the notebook (ipympl "widget" backend).
%matplotlib widget

# User edit
You should not need to change anything beyond the next cell.

In [None]:
############## user edit #######################

# Input surface mesh and cached voxel mask (the mask is rebuilt from the mesh if this file does not exist)
meshfile = './AcrylicWedge.stl'
maskfile = './AcrylicWedgeHires3D_adjusted.h5'
string_handle = 'Axial' # used in the plot title and as the prefix of the output file name
padding_fraction = 0.1 # total padding around the object as a fraction of its largest extent (half on each side)
num_points_along_largest_dimension = 400 # grid points across the padded largest extent; sets the grid spacing
pml_fraction = 0.5 # fraction of the buffer region outside the object used as PML

# material properties
v_acrylic = 2730. # m/s, speed of sound
v_air = 341. # m/s, speed of sound in air
rho_acrylic = 1.18e3    # kg/m^3, acrylic density
rho_air = 1.293         # kg/m^3, air density

# pulse characteristics
pulse_hg_order = 3 # order of the Hermite-Gauss function modeling the pulse
center_frequency = 2.e6 # Hz, the dominant frequency of the input pressure wave
simulation_time = 30.e-6 # s; we simulate only 30 us for now, a little beyond what we currently measure.
# Source grid index along the first (z) axis of the 2D axial domain; also used in the output file name.
# NOTE(review): TEMPLATE looks like a placeholder substituted by an external script before running — confirm.
z_loc = TEMPLATE

# hardware parameters
# num_desired_devices = None # if None, uses all available GPUs

################################################

# Stop!
Don't edit anything beyond this cell, unless you know what you're doing. 

In [None]:
# if not num_desired_devices:
#     num_desired_devices = jax.device_count()
# else: 
#     assert num_desired_devices <= jax.device_count(), f'Number of requested devices sbould be at most {jax.device_count()}. '
# while num_points_along_largest_dimension % num_desired_devices != 0: 
#     num_points_along_largest_dimension += 1
# print( f'Adjusted number of points along largest dimension: {num_points_along_largest_dimension}' )

# Acoustic impedance and reflectivity

In [None]:
# Characteristic acoustic impedance Z = c * rho of each medium.
Z_acrylic, Z_air = v_acrylic * rho_acrylic, v_air * rho_air

# Normal-incidence pressure reflection coefficient at the acrylic/air boundary,
# signed so that it is positive when Z_acrylic > Z_air.
R = ( Z_acrylic - Z_air ) / ( Z_acrylic + Z_air )
logger.info( f'Reflectivity of acrylic-air interface: {R:.4f}' )

# Grid details

In [None]:
# Load the surface mesh. pyvista bounds are (xmin, xmax, ymin, ymax, zmin, zmax),
# reshaped here into one (min, max) row per axis.
mesh = pv.read( meshfile )
bounds = np.array( mesh.bounds ).reshape( -1, 2 )

# Grid spacing: the largest extent plus padding on both sides, divided into the requested number of points.
largest_dimension = ( bounds[:, 1] - bounds[:, 0] ).max()
excess_padding = largest_dimension * padding_fraction / 2.
dx_mm = ( largest_dimension + 2*excess_padding ) / num_points_along_largest_dimension # spatial step size
# Part of the padding (in grid points) that jwave will treat as absorbing PML.
pml_size = np.round( excess_padding * pml_fraction / dx_mm ).astype( int )

# Per-axis sample coordinates over the padded bounds; odd-length axes get one extra
# trailing sample so every axis has an even number of points.
grid_ranges = []
for axis_lo, axis_hi in bounds + np.array( [ -excess_padding, excess_padding ] ):
    axis_grid = np.arange( axis_lo, axis_hi, dx_mm )
    if axis_grid.size % 2:
        axis_grid = np.concatenate( ( axis_grid, np.array( [ axis_grid[-1] + dx_mm ] ) ) )
    grid_ranges.append( axis_grid )

# np.meshgrid's default 'xy' indexing returns arrays shaped (ny, nx, nz), so the mask uses that order.
shp = tuple( grid_ranges[axis].shape[0] for axis in ( 1, 0, 2 ) )
logger.info( f'Domain size: {shp} pixels. ' )
logger.info( f'PML buffer size: {pml_size} pixels. ')


In [None]:
if glob.glob( maskfile ):
    # Cached mask available: load it and restore the (ny, nx, nz) grid shape.
    logger.info( 'Mask found. Loading...' )
    with h5.File( maskfile, 'r' ) as fid:
        mask = fid[ 'mask' ][:].reshape( shp ).astype( float )
else:
    logger.info( 'Mask not found. Creating...' )
    # One (x, y, z) row per grid node, flattened in the same C order as the meshgrid arrays.
    pts = np.stack( [ coord.ravel() for coord in np.meshgrid( *grid_ranges ) ], axis=1 )
    point_cloud = pv.PolyData( pts )
    selected = point_cloud.select_enclosed_points( mesh, progress_bar=True )
    inside = np.array( selected[ 'SelectedPoints' ] ).reshape( *shp ).astype( bool )
    # Dilation followed by erosion (a morphological closing) fills small gaps in the voxelised object.
    mask = binary_erosion( binary_dilation( inside ) ).astype( float )
    with h5.File( maskfile, 'w' ) as fid:
        fid.create_dataset( 'mask', data=mask )

In [None]:
# Sanity-check plot of the y-z cross-section through the middle of the x axis.
# mask is laid out (y, x, z), so the x mid-plane is indexed with the x length, mask.shape[1];
# the transpose puts z along rows so pcolormesh gets (y, z, C[z, y]).
plt.figure()
plt.pcolormesh( grid_ranges[1], grid_ranges[2], mask[:,mask.shape[1]//2,:].T )
# NOTE(review): this marker is hard-coded in mm and is not derived from z_loc / y_loc used for the
# actual source below — update it if the source moves.
plt.plot( [ -11. ], [ 10 ], 'or', label='Source location' )
plt.axis( 'equal' )
plt.xlabel( '$y~(mm)~\\longrightarrow$' )
plt.ylabel( '$z~(mm)~\\longrightarrow$' )
plt.colorbar()
plt.title( f'{string_handle} cross-section of wedge mask\n$\\Delta x = {dx_mm:.4f}$ mm', weight='bold' )
plt.legend();

# `jwave` definitions

## Sharding (currently disabled — the simulation runs on the CPU backend without sharding)

In [None]:
# msh = Mesh( jax.devices(), ( 'x', ) )

## Cross-section simulation
### 2D Spatial domain

In [None]:
# 2D axial slice: x mid-plane of the (y, x, z) mask, transposed so the domain axes are (z, y).
mask_axial = mask[:, mask.shape[1] // 2, :].T

# Isotropic grid spacing, converted from mm to m.
dx = ( float( dx_mm )*1.e-3, ) * 2
domain_axial = Domain( mask_axial.shape, dx )

# Voxel-wise material maps: acrylic values where mask == 1, air values where mask == 0.
sound_speed_map = v_acrylic*mask_axial + v_air*( 1. - mask_axial )
density_map = rho_acrylic*mask_axial + rho_air*( 1. - mask_axial )

# Wrap as jwave FourierSeries fields with a trailing singleton (channel) axis.
sound_speed_field = FourierSeries( jnp.asarray( sound_speed_map )[..., jnp.newaxis], domain=domain_axial )
density_field = FourierSeries( jnp.asarray( density_map )[..., jnp.newaxis], domain=domain_axial )

# Attenuation is not modelled yet.
medium = Medium(
    domain=domain_axial,
    sound_speed=sound_speed_field,
    density=density_field,
    pml_size=int( pml_size ),
)

In [None]:
# show_field( sound_speed_map )
# show_field( density_map )

## Time domain

In [None]:
# Time axis: dt comes from jwave's CFL-based choice for this medium; the end time is then
# overridden so only `simulation_time` seconds are simulated.
time_axis = TimeAxis.from_medium( medium, cfl=0.1 )
time_axis.t_end = simulation_time # this is all we are currently measuring

# calculating time-domain characteristics of pulse
center_wavelength = v_acrylic / center_frequency # pulse wavelength in acrylic (m)
period = 1. / center_frequency

logger.info( f'Dominant frequency: {center_frequency:.2e} Hz' )
logger.info( f'Dominant wavelength: {center_wavelength:.2e} m' )
logger.info( f'Dominant time period: {period:.2e} sec' )
logger.info( f'Time step: {time_axis.dt:.2e} sec' )
logger.info( f'Number of time steps: {int(time_axis.Nt)}' )

# Sample times for the source signal.
# NOTE(review): assumes len(t) matches jwave's number of time steps — confirm.
t = np.arange( 0., float( time_axis.t_end ), float( time_axis.dt ) )
init = pulse.HermiteGauss( t, 1.5*period, period/2., pulse_hg_order )
plt.figure()
plt.plot( t, init, label=f'Input pulse (HG{pulse_hg_order})' )
plt.xlabel( 'Simulation time (sec) $\\longrightarrow$' )
plt.legend()
plt.grid()

# Source y-index: the lowest y row occupied by the object in the z mid-plane (max taken over x).
# mask is (y, x, z), so the z mid-plane index must come from the z length, mask.shape[2].
# NOTE(review): the simulation uses only the x mid-plane; projecting over x assumes the face is uniform along x.
y_candidates = np.where( mask[:,:,mask.shape[2]//2].max(axis=1)==1 )[0]
if y_candidates.size == 0:
    raise ValueError( 'No object voxels in the z mid-plane of the mask; cannot place the source.' )
y_loc = y_candidates.min()

# Point source at grid index (z_loc, y_loc) of the (z, y) axial domain.
sources = Sources(
    positions=( ( z_loc,  ),  ( y_loc, ) ),
    signals=jnp.stack( [ init ] ),
    dt=time_axis.dt,
    domain=domain_axial
)

# Run simulation

In [None]:
def compiled_simulator( sources ):
    """Run the jwave time-domain simulation for `sources`.

    `medium` and `time_axis` are read from the notebook namespace when JAX traces
    this function, so they are fixed into the compiled computation; re-run this
    cell after changing either of them.
    """
    return simulate_wave_propagation( medium, time_axis, sources=sources )


compiled_simulator = jit( compiled_simulator )


In [None]:
# Run the jitted simulation (the first call also compiles it).
pressure = compiled_simulator( sources )

# Field values on the grid with singleton axes removed.
pres = np.squeeze( pressure.on_grid )
logger.info( 'Simulation size: {}'.format( pres.shape ) )

# Symmetric limits +/- pmax, with pmax the smaller of the two extreme magnitudes.
extreme_magnitudes = ( np.abs( pres.min() ), np.abs( pres.max() ) )
pmax = min( extreme_magnitudes )
pmin = -pmax

In [None]:
# Save the pressure field plus the metadata needed to interpret it.
# The axial domain is (z, y), and y is mask axis 0, whose coordinates are grid_ranges[1]
# (np.meshgrid 'xy' ordering); grid_ranges[0] holds the x coordinates, not y.
with h5.File( f'{string_handle}Waves_{z_loc}.h5', 'w' ) as fid:
    dset = fid.create_dataset( 'pressure', data=pres )
    for key, val in {
        'y': grid_ranges[1],          # mm
        'z': grid_ranges[2],          # mm
        'time_steps': t.size,         # samples in the source-signal time grid
        'dt': float( time_axis.dt ),  # s; stored as a plain float
    }.items():
        dset.attrs[ key ] = val
