## First corrections to the Navier–Stokes (NS) equations

In [27]:
## imports
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn
import jax_cfd.base as cfd
import xarray
import pandas as pd #for saving into xarray

In [28]:
## forcing

## Here are different types of forcing functions/conditions

# Linear force due to uniform pressure gradient
def pressure_gradient_forcing(pressure_gradient: float):
  """Build a forcing that applies a constant, uniform force along the first axis.

  Args:
    pressure_gradient: magnitude of the force added to the first (x) velocity
      component at every grid point. It is used as-is (no sign flip and no
      division by density). NOTE(review): confirm this sign/scaling matches
      the intended dP/dx convention.

  Returns:
    A function ``forcing(v)`` that maps a tuple of velocity components to a
    tuple of ``cfd.grids.GridArray`` force components. There is one component
    per velocity component, each on that component's offset and grid.
  """

  def forcing(v):
    # Drive only the first component; all remaining components get zero force.
    # Sizing the vector from len(v), instead of hard-coding a 2-tuple, stops
    # zip() from silently dropping components when v has more than two.
    force_vector = (pressure_gradient,) + (0,) * (len(v) - 1)
    return tuple(
        cfd.grids.GridArray(f * jnp.ones_like(u.data), u.offset, u.grid)
        for f, u in zip(force_vector, v))

  return forcing

# Turbulent forcing (kolmogorov forcing)
def turbulentForcing(grid):
    """Return jax-cfd's ``simple_turbulence_forcing`` configured for ``grid``.

    Uses constant_magnitude=0.5 and linear_coefficient=-0.8, the parameters
    specified in jax-cfd for 2D turbulence.
    """
    forcing_kwargs = dict(constant_magnitude=0.5, linear_coefficient=-0.8)
    return cfd.forcings.simple_turbulence_forcing(grid, **forcing_kwargs)

In [84]:
# One dict per simulation. The next cell turns 'size'/'domain' into a concrete
# 'grid' and calls the 'forcing' factory with (pressure_gradient, grid).
simSpecs = [
    # run 1
    dict(
        description="channel, TG and pressure gradient",
        # grid shape (cells per axis) and physical extent per axis
        size=(256, 64),
        domain=((0, 8), (0, 2)),

        density=1.,              # original 1.
        viscosity=1e-4,          # kinematic viscosity; original 1e-3
        pressure_gradient=2e-3,  # uniform dP/dx; original 2e-3

        velocity_bc=(cfd.boundaries.channel_flow_boundary_conditions(ndim=2),
                     cfd.boundaries.channel_flow_boundary_conditions(ndim=2)),

        # alternatives: cfd.pressure.solve_fast_diag_channel_flow or cfd.pressure.solve_cg
        pressure_solve=cfd.pressure.solve_fast_diag_channel_flow,

        # initial velocity: fluid at rest
        vx_fn=lambda x, y: jnp.zeros_like(x),
        vy_fn=lambda x, y: jnp.zeros_like(x),

        # combined with the CFL safety factor to choose a stable dt
        max_velocity=1,
        cfl_safety_factor=0.5,

        inner_steps=1,      # solver time steps per saved output
        outer_steps=50000,  # number of saved outputs

        # Other forcings that can be added to sum_forcings:
        #   turbulentForcing(grid), cfd.forcings.kolmogorov_forcing(grid),
        #   cfd.forcings.linear_forcing(grid, 1.01),
        #   cfd.forcings.filtered_forcing(grid=grid, spectral_density=1),
        #   cfd.forcings.filtered_linear_forcing(grid=grid, upper_wavenumber=2,
        #                                        coefficient=1, lower_wavenumber=1)
        forcing=lambda pressure_gradient, grid: cfd.forcings.sum_forcings(
            pressure_gradient_forcing(pressure_gradient),
            cfd.forcings.taylor_green_forcing(grid, scale=0.05),
        ),
    ),
]


In [85]:
# Evaluate grid and forcing for each spec.
# The original forcing factory is stashed under 'forcing_factory', so re-running
# this cell rebuilds 'forcing' from the factory. Without it, the cell would call
# an already-built forcing, which does not take these keyword arguments.
# `grid` is left bound to the last spec's grid for code that reads that global.
for spec in simSpecs:
    spec['grid'] = grid = cfd.grids.Grid(spec["size"], domain=spec["domain"])
    forcing_factory = spec.setdefault('forcing_factory', spec['forcing'])
    spec['forcing'] = forcing_factory(pressure_gradient=spec['pressure_gradient'],
                                      grid=spec['grid'])

In [86]:
# function that iterates over the different simulation specifications
def theFunction(thisSim, target_sim_time=0.0, toTime=False):
    """Run one simulation described by a spec dict and return its trajectory.

    Args:
        thisSim: one entry of ``simSpecs`` whose "grid" and "forcing" entries
            have already been evaluated. Reads vx_fn, vy_fn, grid, velocity_bc,
            pressure_solve, max_velocity, cfl_safety_factor, viscosity,
            density, inner_steps, forcing and (unless ``toTime``) outer_steps.
        target_sim_time: simulated time to reach; used only when ``toTime``.
        toTime: if True, derive the number of saved outputs from
            ``target_sim_time`` instead of ``thisSim["outer_steps"]``.

    Returns:
        ``((final_state, trajectory), dt, outer_steps)``. The first item is
        the host-side result of the jitted rollout, ``dt`` is the time step
        used, and ``outer_steps`` is the integer number of saved frames.

    Raises:
        AssertionError: if the initial velocity field's divergence is not
            exactly zero.
        ValueError: if ``toTime`` is set and ``target_sim_time`` is shorter
            than one output interval (``inner_steps * dt``).
    """
    grid = thisSim["grid"]

    v0 = cfd.initial_conditions.initial_velocity_field(
        velocity_fns=(thisSim["vx_fn"], thisSim["vy_fn"]),
        grid=grid,
        velocity_bc=thisSim["velocity_bc"],
        pressure_solve=thisSim["pressure_solve"],
        iterations=5)

    # Sanity check on the initial condition. With assert_allclose's default
    # atol=0, this demands a divergence of exactly zero.
    div = cfd.finite_differences.divergence(v0)
    np.testing.assert_allclose(div.data, 0)

    # Stable time step from the velocity scale, CFL safety factor, viscosity and grid.
    dt = cfd.equations.stable_time_step(
        thisSim["max_velocity"], thisSim["cfl_safety_factor"],
        thisSim["viscosity"], grid)

    def convect(v):
        # Advect every velocity component by the full velocity field.
        return tuple(cfd.advection.advect_van_leer(u, v, dt) for u in v)

    inner_steps = thisSim["inner_steps"]  # solver steps per saved frame

    if toTime:
        # Floor division of floats returns a float; the frame count must be an int.
        outer_steps = int(target_sim_time // (inner_steps * dt))
        if outer_steps < 1:
            raise ValueError(
                f"target_sim_time={target_sim_time} is shorter than one output "
                f"interval (inner_steps * dt = {inner_steps * dt})")
    else:
        outer_steps = thisSim["outer_steps"]

    step_fn = cfd.funcutils.repeated(
        cfd.equations.semi_implicit_navier_stokes(
            density=thisSim["density"],
            viscosity=thisSim["viscosity"],
            dt=dt,
            grid=grid,
            convect=convect,
            pressure_solve=thisSim["pressure_solve"],
            forcing=thisSim["forcing"]),
        steps=inner_steps)
    # Per jax-cfd's trajectory docs, start_with_input=True saves frames at
    # steps 0 .. outer_steps-1, so the first frame is the initial state.
    rollout_fn = jax.jit(cfd.funcutils.trajectory(
        step_fn, outer_steps, start_with_input=True))

    return jax.device_get(rollout_fn(v0)), dt, outer_steps

In [87]:
def runAllSims(simSpecs, measureTotalRuntime=True, verbose=False):
    """Run every simulation spec and wrap each trajectory in an xarray.Dataset.

    Args:
        simSpecs: list of spec dicts whose "grid" and "forcing" entries have
            already been evaluated. Each spec gains a "dt" entry holding the
            time step actually used.
        measureTotalRuntime: if True, print the wall-clock time spent running
            all simulations.
        verbose: if True, also print each spec's description.

    Returns:
        ``(datasets, sample_nums)``:
            datasets: one Dataset per spec, with variables "u" and "v" on
                dims (time, x, y).
            sample_nums: the matching zero-based spec indices.
    """
    import time

    start = time.perf_counter()
    datasets = []
    sample_nums = []
    for i, spec in enumerate(simSpecs):
        print("Simulation number: " + str(i))  # zero-based, matching sample_nums
        if verbose:
            print("\tdescription: " + spec["description"])
        sample_nums.append(i)

        (_, trajectory), spec["dt"], outer_steps = theFunction(spec)
        print("\n")

        # Use this spec's own grid, not the module-level `grid` (which holds
        # only the last evaluated grid). Use the frame count the rollout
        # actually produced.
        grid = spec["grid"]
        # The rollout uses trajectory(..., start_with_input=True), which saves
        # steps 0 .. outer_steps-1, so the first frame sits at t = 0.
        frame_interval = spec["dt"] * spec["inner_steps"]
        times = frame_interval * np.arange(outer_steps)

        datasets.append(xarray.Dataset(
            {
                'u': (('time', 'x', 'y'), trajectory[0].data),
                'v': (('time', 'x', 'y'), trajectory[1].data),
            },
            coords={
                'x': grid.axes()[0],
                'y': grid.axes()[1],
                'time': times,
            }))

    if measureTotalRuntime:
        print(f"\nTOTAL runtime: {time.perf_counter() - start:.2f} s")

    return datasets, sample_nums

In [88]:
%time datasets,sample_nums = runAllSims(simSpecs, measureTotalRuntime=True, verbose=True)

Simulation number: 0
	description: channel, TG and pressure gradient



TOTAL runtime: 
CPU times: user 1min 20s, sys: 6.01 s, total: 1min 26s
Wall time: 46.2 s


In [89]:
# TODO: choose the save resolutions and coarsen the datasets as needed, including along time.

In [90]:
# Stack the per-simulation datasets along a new "sample" dimension.
# TODO: attributes do not get saved for each "sample"; only one set of
# attributes is stored on the concatenated dataset.
final_ds = xarray.concat(datasets, dim=pd.Index(sample_nums, name="sample"))