In [2]:
## imports
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn
import jax_cfd.base as cfd
import xarray

In [3]:
## setting up the model parameters and shape
#original
size = (512, 128)
domain = ((0, 8), (0, 2))

size = (256,64)
domain = ((0, 8), (0, 2))

density = 1. ### original 1.
viscosity = 1e-4 # kinematic viscosity ### original 1e-3

pressure_gradient = 2e-3  # uniform dP/dx ### original 2e-3

# Define the physical dimensions of the simulation.


In [4]:
## boundary conditions
# Specify periodic BC on x-boundaries, and no-slip walls on the y-boundaries. <- channel flow settings
velocity_bc = (cfd.boundaries.channel_flow_boundary_conditions(ndim=2),
               cfd.boundaries.channel_flow_boundary_conditions(ndim=2))

In [5]:
grid = cfd.grids.Grid(size, domain=domain)

masks = cfd.grids.domain_interior_masks(grid)



In [6]:
## pressure solver
pressure_solve = cfd.pressure.solve_fast_diag_channel_flow ### solve_fast_diag_channel_flow OR solve_cg

In [7]:
## initial velocity
def x_velocity_fn(x, y):
  return jnp.zeros_like(x + y) + 0.2 * np.random.normal(
      size=grid.shape) * masks[0]

def y_velocity_fn(x, y):
  return jnp.zeros_like(x + y) #+ 0.2 * np.random.normal(size=grid.shape) * masks[1]

vx_fn = lambda x, y: jnp.zeros_like(x)
vy_fn = lambda x, y: jnp.zeros_like(x)

v0 = cfd.initial_conditions.initial_velocity_field(
           velocity_fns = (vx_fn, vy_fn),
           grid = grid,
           velocity_bc = velocity_bc,
           pressure_solve = pressure_solve,
           iterations=5)

In [8]:
## divergence check
div = cfd.finite_differences.divergence(v0)
np.testing.assert_allclose(div.data, 0)

In [9]:
## time step
max_velocity = 1
cfl_safety_factor = 0.5

dt = cfd.equations.stable_time_step(
    max_velocity, cfl_safety_factor, viscosity, grid)

In [10]:
## forcing

## Here are different types of forcing functions/conditions

# Linear force due to uniform pressure gradient
def pressure_gradient_forcing(pressure_gradient: float):
  
  def forcing(v):
    force_vector = (pressure_gradient, 0)
    return tuple(cfd.grids.GridArray(f * jnp.ones_like(u.data), u.offset, u.grid)
                 for f, u in zip(force_vector, v))
  
  return forcing

# Turbulent forcing (kolmogorov forcing)
def turbulentForcing(grid):
    return cfd.forcings.simple_turbulence_forcing(grid, constant_magnitude = 0.5, linear_coefficient = -0.8) # params specified in jax-cfd for 2D turbulence

In [11]:
forcings = cfd.forcings.sum_forcings(
    pressure_gradient_forcing(pressure_gradient),
    #turbulentForcing(grid),
    cfd.forcings.taylor_green_forcing(grid, scale = 0.05)
    #cfd.forcings.kolmogorov_forcing(grid)
    #cfd.forcings.linear_forcing(grid,1.01)
    #cfd.forcings.filtered_forcing(grid = grid, spectral_density = 1)
    #cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber = 2, coefficient = 1, lower_wavenumber = 1)
)

In [12]:
forcing_combinations = {
    cfd.forcings.sum_forcings(
    pressure_gradient_forcing(pressure_gradient),
    #turbulentForcing(grid),
    cfd.forcings.taylor_green_forcing(grid, scale = 0.05)
    #cfd.forcings.kolmogorov_forcing(grid)
    #cfd.forcings.linear_forcing(grid,1.01)
    #cfd.forcings.filtered_forcing(grid = grid, spectral_density = 1)
    #cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber = 2, coefficient = 1, lower_wavenumber = 1)
),
    cfd.forcings.sum_forcings(
    pressure_gradient_forcing(pressure_gradient),
    #turbulentForcing(grid),
    #cfd.forcings.taylor_green_forcing(grid, scale = 0.05)
    #cfd.forcings.kolmogorov_forcing(grid)
    #cfd.forcings.linear_forcing(grid,1.01)
    #cfd.forcings.filtered_forcing(grid = grid, spectral_density = 1)
    #cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber = 2, coefficient = 1, lower_wavenumber = 1)
),
    cfd.forcings.sum_forcings(
    pressure_gradient_forcing(pressure_gradient),
    #turbulentForcing(grid),
    cfd.forcings.taylor_green_forcing(grid, scale = 0.05)
    #cfd.forcings.kolmogorov_forcing(grid)
    #cfd.forcings.linear_forcing(grid,1.01)
    #cfd.forcings.filtered_forcing(grid = grid, spectral_density = 1)
    #cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber = 2, coefficient = 1, lower_wavenumber = 1)
),
    cfd.forcings.sum_forcings(
    pressure_gradient_forcing(pressure_gradient),
    #turbulentForcing(grid),
    cfd.forcings.taylor_green_forcing(grid, scale = 0.05)
    #cfd.forcings.kolmogorov_forcing(grid)
    #cfd.forcings.linear_forcing(grid,1.01)
    #cfd.forcings.filtered_forcing(grid = grid, spectral_density = 1)
    #cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber = 2, coefficient = 1, lower_wavenumber = 1)
)
}

In [13]:
## define convection function
def convect(v):
  return tuple(
      cfd.advection.advect_van_leer(u, v, dt) for u in v)

In [15]:
def the_iteration_over_samples(forcing):

    ## step function

    # time steps per output
    inner_steps = 1_000

    # number of outputs
    outer_steps = 20

    # Define a step function and use it to compute a trajectory.
    step_fn = cfd.funcutils.repeated(
        cfd.equations.semi_implicit_navier_stokes(
            density=density,
            viscosity=viscosity,
            dt=dt,
            grid=grid,
            convect=convect,
            pressure_solve = pressure_solve, # defined above for setting v0
            forcing= forcing
            ),
        steps=inner_steps)
    rollout_fn = jax.jit(cfd.funcutils.trajectory(
        step_fn, outer_steps, start_with_input=True))

    ## compute trajectory
    %time _, trajectory = jax.device_get(rollout_fn(v0))

    ## load into xarray for visualization and analysis
    ds = xarray.Dataset(
        {
            'u': (('time', 'x', 'y'), trajectory[0].data),
            'v': (('time', 'x', 'y'), trajectory[1].data),
        },
        coords={
            'x': grid.axes()[0],
            'y': grid.axes()[1],
            'time': dt * inner_steps * np.arange(1, outer_steps + 1)
        }
    )
    
    return ds

In [23]:
i = 0
for combination in forcing_combinations:
    print(i)
    the_iteration_over_samples(combination) #returns xarray dataset
    i+=1
    # now we need to merge the datasets by a new index
    # NOTE: we need to change the function defined above to make sure it includes the simulation specs as attributes and the velocities in a dask.array format

0
CPU times: user 33.8 s, sys: 1.59 s, total: 35.4 s
Wall time: 19 s


NameError: name 'trajectory' is not defined