In [None]:
import enum
import matplotlib.pyplot as plt
import numpy as np

# Boundary Conditions

In [None]:
from numpy import iterable


class BoundaryCondition(enum.Enum):
  """Enumerates the boundary-condition kinds dispatched by `update_boundary`."""

  # Halo copied from the opposite end of the domain.
  PERIODIC = 0
  # Halo set to a prescribed value.
  DIRICHLET = 1
  # Halo extrapolated with a prescribed layer-to-layer difference.
  NEUMANN = 2
  # Halo mirrors the interior about the boundary.
  REFLECTIVE = 3


class DomainType(enum.Enum):
  """Enumerates the kinds of volume a cell of the simulation domain can be.

  The integer values matter: the flux masks store them (possibly negated) and
  are compared against `abs(mask)`.
  """

  FLUID = 0
  SOLID = 1
  INLET = 2
  OUTLET = 3
  # Faces whose fluxes `impenetrable_flux` sets to zero.
  IMPENETRABLE_WALL = 4
  # Faces whose fluxes `reflective_flux` reflects.
  REFLECTIVE_WALL = 5


def periodic(v, dim, halo_width):
  """Fills both halos of `v` along `dim` by wrapping the opposite interior.

  Args:
    v: Array updated in place ('y' indexing assumes it is 3D).
    dim: 'x', 'y', or anything else (treated as 'z').
    halo_width: Number of ghost layers on each side.

  Returns:
    The same array object `v`.
  """

  def at(layers):
    # Index tuple selecting `layers` along `dim` and everything elsewhere.
    if dim == 'x':
      return (layers, Ellipsis)
    if dim == 'y':
      return (slice(None), layers, slice(None))
    return (Ellipsis, layers)

  low_halo = slice(None, halo_width)
  high_halo = slice(-halo_width, None)
  last_interior = slice(-2 * halo_width, -halo_width)
  first_interior = slice(halo_width, 2 * halo_width)

  v[at(low_halo)] = v[at(last_interior)]
  v[at(high_halo)] = v[at(first_interior)]
  return v


def dirichlet(v, dim, face, val, halo_width):
  """Imposes prescribed values on one halo of `v` along `dim`.

  Args:
    v: Array updated in place ('y' indexing assumes it is 3D).
    dim: 'x', 'y', or anything else (treated as 'z').
    face: 0 selects the low-index halo; any other value the high-index one.
    val: Either a value broadcast over the whole halo, or a list/tuple whose
      i-th entry is written to layer i of the halo. Layer i is index i on
      face 0 and index `-halo_width + i` on the other face.
    halo_width: Number of ghost layers.

  Returns:
    The same array object `v`.
  """

  def at(layers):
    # Index tuple selecting `layers` along `dim` and everything elsewhere.
    if dim == 'x':
      return (layers, Ellipsis)
    if dim == 'y':
      return (slice(None), layers, slice(None))
    return (Ellipsis, layers)

  low_face = face == 0
  if isinstance(val, (list, tuple)):
    first_layer = 0 if low_face else -halo_width
    for layer, layer_value in enumerate(val):
      v[at(first_layer + layer)] = layer_value
  elif low_face:
    v[at(slice(None, halo_width))] = val
  else:
    v[at(slice(-halo_width, None))] = val
  return v


def neumann(v, dim, face, val, halo_width):
  """Extrapolates one halo of `v` with a fixed layer-to-layer difference.

  On face 0 each ghost layer equals its inner neighbour minus `val`; on any
  other face it equals its inner neighbour plus `val`. Layers are filled from
  the interior outwards, so each one builds on the value just written. `val`
  is a per-layer difference and is not scaled by the grid spacing.

  Args:
    v: Array updated in place ('y' indexing assumes it is 3D).
    dim: 'x', 'y', or anything else (treated as 'z').
    face: 0 selects the low-index halo; any other value the high-index one.
    val: The difference between neighbouring layers.
    halo_width: Number of ghost layers.

  Returns:
    The same array object `v`.
  """

  def at(layer):
    # Index tuple selecting one layer along `dim`.
    if dim == 'x':
      return (layer, Ellipsis)
    if dim == 'y':
      return (slice(None), layer, slice(None))
    return (Ellipsis, layer)

  if face == 0:
    for ghost in range(halo_width - 1, -1, -1):
      v[at(ghost)] = v[at(ghost + 1)] - val
  else:
    for ghost in range(-halo_width, 0):
      v[at(ghost)] = v[at(ghost - 1)] + val
  return v


def reflective(v, dim, face, halo_width):
  """Applies the reflective (mirror) BC.

  Each ghost layer copies the interior layer at the mirror-image position
  about the boundary face:
    - on face 0, ghost `halo_width - 1 - i` takes layer `halo_width + i`;
    - on the other face, ghost `-halo_width + i` takes layer
      `-halo_width - 1 - i`.

  Args:
    v: Array updated in place ('y' indexing assumes it is 3D).
    dim: 'x', 'y', or anything else (treated as 'z').
    face: 0 selects the low-index halo; any other value the high-index one.
    halo_width: Number of ghost layers.

  Returns:
    The same array object `v`.
  """

  def at(layer):
    # Index tuple selecting one layer along `dim`.
    if dim == 'x':
      return (layer, Ellipsis)
    if dim == 'y':
      return (slice(None), layer, slice(None))
    return (Ellipsis, layer)

  for i in range(halo_width):
    if face == 0:
      v[at(halo_width - 1 - i)] = v[at(halo_width + i)]
    else:
      v[at(-halo_width + i)] = v[at(-halo_width - 1 - i)]
  return v


def update_boundary(states, params):
  """Refreshes the halo of every state variable per `params['bc']`.

  `params['bc'][name][dim][face]` is a sequence:
    - its first entry is a `BoundaryCondition`;
    - its second entry (Dirichlet/Neumann only) is the boundary value.
  Dimensions absent from a variable's spec are left untouched. Unrecognised
  condition kinds are ignored.

  Args:
    states: Mapping from variable name to array.
    params: Must provide 'halo_width' and a 'bc' spec for every variable.

  Returns:
    A new dict with the same keys. The arrays themselves are modified in
    place by the individual BC helpers.
  """
  halo_width = params['halo_width']

  def apply_one(field, dim, face, condition):
    kind = condition[0]
    if kind == BoundaryCondition.PERIODIC:
      return periodic(field, dim, halo_width)
    if kind == BoundaryCondition.DIRICHLET:
      return dirichlet(field, dim, face, condition[1], halo_width)
    if kind == BoundaryCondition.NEUMANN:
      return neumann(field, dim, face, condition[1], halo_width)
    if kind == BoundaryCondition.REFLECTIVE:
      return reflective(field, dim, face, halo_width)
    return field

  def apply_all(field, spec):
    for dim in ('x', 'y', 'z'):
      if dim in spec:
        for face in (0, 1):
          field = apply_one(field, dim, face, spec[dim][face])
    return field

  return {
      name: apply_all(field, params['bc'][name])
      for name, field in states.items()
  }


def reflective_flux(flux, helper_vars):
  """Reflects the face fluxes flagged as reflective walls.

  Faces are flagged per dimension `dim` that has a `mask_<dim>` entry in
  `helper_vars`: a face qualifies when
  `abs(mask) == DomainType.REFLECTIVE_WALL.value`. With
  `s = mask / REFLECTIVE_WALL.value` (+1 or -1 on those faces), each flux
  becomes `-min(f, 0)` where s is +1 and `max(f, 0)` where s is -1.

  Args:
    flux: Mapping dim -> {variable -> array}; the arrays are modified in
      place.
    helper_vars: May contain 'mask_x', 'mask_y', 'mask_z'.

  Returns:
    `flux`, the same mapping that was passed in.
  """
  for dim in ('x', 'y', 'z'):
    mask_name = f'mask_{dim}'
    if mask_name not in helper_vars:
      continue

    wall_id = DomainType.REFLECTIVE_WALL.value
    mask_values = helper_vars[mask_name]
    on_wall = np.where(np.abs(mask_values) == wall_id)
    # `sign` is +1 or -1 on the wall faces, so exactly one of the two factors
    # is non-zero there. These factors are the same for every variable.
    sign = mask_values[on_wall] / wall_id
    pos_part = 0.5 * (sign + np.abs(sign))
    neg_part = 0.5 * (sign - np.abs(sign))
    for face_flux in flux[dim].values():
      face_flux[on_wall] = pos_part * -np.minimum(
          face_flux[on_wall], 0.0
      ) + neg_part * -np.maximum(face_flux[on_wall], 0.0)

  return flux


def impenetrable_flux(flux, helper_vars):
  """Zeroes every face flux flagged as an impenetrable wall.

  For each dimension `dim` with a `mask_<dim>` entry in `helper_vars`, every
  variable's flux is set to 0 where
  `abs(mask) == DomainType.IMPENETRABLE_WALL.value`.

  Args:
    flux: Mapping dim -> {variable -> array}; the arrays are modified in
      place.
    helper_vars: May contain 'mask_x', 'mask_y', 'mask_z'.

  Returns:
    `flux`, the same mapping that was passed in.
  """
  for dim in ('x', 'y', 'z'):
    mask_name = f'mask_{dim}'
    if mask_name in helper_vars:
      wall_id = DomainType.IMPENETRABLE_WALL.value
      blocked = np.where(np.abs(helper_vars[mask_name]) == wall_id)
      for face_flux in flux[dim].values():
        face_flux[blocked] = 0.0
  return flux


def update_flux_boundary_condition(flux, helper_vars, params):
  """Applies every wall treatment to the face fluxes.

  Reflective walls are processed first, then impenetrable walls. `params` is
  accepted for a uniform interface and is currently unused.
  """
  for apply_wall in (reflective_flux, impenetrable_flux):
    flux = apply_wall(flux, helper_vars)
  return flux

# Numerics

In [None]:
def central2(f, h, dim):
  """First derivative of `f` along `dim` by second-order central difference.

  Node i receives (f[i + 1] - f[i - 1]) / (2 h). The first and last nodes
  along `dim` are left at zero.

  Args:
    f: Array ('y' indexing assumes it is 3D).
    h: Grid spacing.
    dim: 'x', 'y', or anything else (treated as 'z').

  Returns:
    A new array shaped like `f`.
  """

  def along(section):
    # Index tuple selecting `section` along `dim`.
    if dim == 'x':
      return (section, Ellipsis)
    if dim == 'y':
      return (slice(None), section, slice(None))
    return (Ellipsis, section)

  derivative = np.zeros_like(f)
  upper = f[along(slice(2, None))]
  lower = f[along(slice(None, -2))]
  derivative[along(slice(1, -1))] = (upper - lower) / (2.0 * h)
  return derivative


def backward1(f, h, dim):
  """First derivative of `f` along `dim` by first-order backward difference.

  Node i receives (f[i] - f[i - 1]) / h. The first node along `dim` is left
  at zero.

  Args:
    f: Array ('y' indexing assumes it is 3D).
    h: Grid spacing.
    dim: 'x', 'y', or anything else (treated as 'z').

  Returns:
    A new array shaped like `f`.
  """

  def along(section):
    # Index tuple selecting `section` along `dim`.
    if dim == 'x':
      return (section, Ellipsis)
    if dim == 'y':
      return (slice(None), section, slice(None))
    return (Ellipsis, section)

  tail = along(slice(1, None))
  head = along(slice(None, -1))
  derivative = np.zeros_like(f)
  derivative[tail] = (f[tail] - f[head]) / h
  return derivative


def forward1(f, h, dim):
  """First derivative of `f` along `dim` by first-order forward difference.

  Node i receives (f[i + 1] - f[i]) / h. The last node along `dim` is left
  at zero.

  Args:
    f: Array ('y' indexing assumes it is 3D).
    h: Grid spacing.
    dim: 'x', 'y', or anything else (treated as 'z').

  Returns:
    A new array shaped like `f`.
  """

  def along(section):
    # Index tuple selecting `section` along `dim`.
    if dim == 'x':
      return (section, Ellipsis)
    if dim == 'y':
      return (slice(None), section, slice(None))
    return (Ellipsis, section)

  tail = along(slice(1, None))
  head = along(slice(None, -1))
  derivative = np.zeros_like(f)
  derivative[head] = (f[tail] - f[head]) / h
  return derivative


def dot(vec1, vec2):
  """Computes the dot product of two 3-component vectors.

  Only the first three entries of each sequence are used. Entries may be
  scalars or arrays, which are combined element-wise.
  """
  x1, y1, z1 = vec1[0], vec1[1], vec1[2]
  x2, y2, z2 = vec2[0], vec2[1], vec2[2]
  return x1 * x2 + y1 * y2 + z1 * z2


def linear_interp(v, dim):
  """Linearly interpolates `v` to the right face of every node along `dim`.

  Args:
    v: A 3D array of node values.
    dim: The dimension of interpolation: 'x', 'y', or anything else
      (treated as 'z').

  Returns:
    A new array shaped like `v`. Node i holds the i + 1/2 face value
    v[i] + 0.5 * (v[i + 1] - v[i]). The last node along `dim` has no right
    neighbour and simply copies v[-1].
  """
  axis = {'x': 0, 'y': 1}.get(dim, 2)

  all_but_last = [slice(None)] * 3
  all_but_last[axis] = slice(None, -1)
  last = [slice(None)] * 3
  last[axis] = -1
  all_but_last, last = tuple(all_but_last), tuple(last)

  faces = np.zeros_like(v)
  faces[all_but_last] = v[all_but_last] + 0.5 * np.diff(v, axis=axis)
  faces[last] = v[last]
  return faces


def weno5(v, dim):
  """Computes face values of `v` with the 5th-order WENO reconstruction.

  Uses the classic Jiang-Shu WENO5 scheme: three 3-point sub-stencil
  reconstructions are blended with nonlinear weights built from the
  smoothness indicators `beta` (eps = 1e-6, power 2).

  Args:
    v: A 3D array to which the WENO5 reconstruction is applied.
    dim: The dimension of reconstruction, one of 'x', 'y', 'z'.

  Returns:
    A tuple `(v_neg, v_pos)` of arrays with the same shape as `v`. Node i of
    each holds a value at the i + 1/2 face:
      - `v_neg` is reconstructed from the 5-point stencil centred on node i
        (the left state of the face);
      - `v_pos` is reconstructed from the stencil centred on node i + 1 (the
        right state).
    Nodes without a full stencil are NOT zeroed; they keep the original
    node values of `v`. These are the first and last two nodes along `dim`
    for `v_neg`, and the first node and last three nodes for `v_pos`.
  """
  k = 3
  eps = 1e-6
  # Sub-stencil coefficients: the face value from the sub-stencil shifted r
  # cells to the left is sum_j c[k][r][j] * v[i - r + j]. The right-side
  # values use the row r - 1, hence the extra r = -1 row.
  c = {
      3: {
          -1: [11.0 / 6.0, -7.0 / 6.0, 1.0 / 3.0],
          0: [1.0 / 3.0, 5.0 / 6.0, -1.0 / 6.0],
          1: [-1.0 / 6.0, 5.0 / 6.0, 1.0 / 3.0],
          2: [1.0 / 3.0, -7.0 / 6.0, 11.0 / 6.0],
      }
  }
  # Linear (optimal) weights of the sub-stencils.
  d = {
      3: {0: 0.3, 1: 0.6, 2: 0.1},
  }

  nx, ny, nz = v.shape

  # Nodes closer than `offset` to either end lack a complete stencil; `l` is
  # the number of nodes that have one.
  offset = k - 1
  l = {'x': nx, 'y': ny, 'z': nz}[dim] - 2 * offset

  # Compute the sub-stencil reconstructions on faces.
  vr_neg = {}
  vr_pos = {}
  for r in range(k):
    if dim == 'x':
      buf_neg = np.zeros((l, ny, nz), dtype=v.dtype)
      buf_pos = np.zeros((l, ny, nz), dtype=v.dtype)
    elif dim == 'y':
      buf_neg = np.zeros((nx, l, nz), dtype=v.dtype)
      buf_pos = np.zeros((nx, l, nz), dtype=v.dtype)
    else:  # dim == 'z':
      buf_neg = np.zeros((nx, ny, l), dtype=v.dtype)
      buf_pos = np.zeros((nx, ny, l), dtype=v.dtype)

    for j in range(k):
      idx_0 = offset - r + j
      idx_1 = idx_0 + l
      if dim == 'x':
        buf_neg += c[k][r][j] * v[idx_0:idx_1, ...]
        buf_pos += c[k][r - 1][j] * v[idx_0:idx_1, ...]
      elif dim == 'y':
        buf_neg += c[k][r][j] * v[:, idx_0:idx_1, :]
        buf_pos += c[k][r - 1][j] * v[:, idx_0:idx_1, :]
      else:  # dim == 'z':
        buf_neg += c[k][r][j] * v[..., idx_0:idx_1]
        buf_pos += c[k][r - 1][j] * v[..., idx_0:idx_1]

    vr_neg[r] = buf_neg
    vr_pos[r] = buf_pos

  # Compute the smoothness indicators of the three sub-stencils.
  beta_0 = (
      lambda v1, v2, v3: 13.0 / 12.0 * (v1 - 2.0 * v2 + v3) ** 2
      + 0.25 * (3.0 * v1 - 4.0 * v2 + v3) ** 2
  )
  beta_1 = (
      lambda v1, v2, v3: 13.0 / 12.0 * (v1 - 2.0 * v2 + v3) ** 2
      + 0.25 * (v1 - v3) ** 2
  )
  beta_2 = (
      lambda v1, v2, v3: 13.0 / 12.0 * (v1 - 2.0 * v2 + v3) ** 2
      + 0.25 * (v1 - 4.0 * v2 + 3.0 * v3) ** 2
  )
  if dim == 'x':
    beta = {
        0: beta_0(v[2 : 2 + l, ...], v[3 : 3 + l, ...], v[4 : 4 + l, ...]),
        1: beta_1(v[1 : 1 + l, ...], v[2 : 2 + l, ...], v[3 : 3 + l, ...]),
        2: beta_2(v[:l, ...], v[1 : 1 + l, ...], v[2 : 2 + l, ...]),
    }
  elif dim == 'y':
    beta = {
        0: beta_0(v[:, 2 : 2 + l, :], v[:, 3 : 3 + l, :], v[:, 4 : 4 + l, :]),
        1: beta_1(v[:, 1 : 1 + l, :], v[:, 2 : 2 + l, :], v[:, 3 : 3 + l, :]),
        2: beta_2(v[:, :l, :], v[:, 1 : 1 + l, :], v[:, 2 : 2 + l, :]),
    }
  else:  # dim == 'z':
    beta = {
        0: beta_0(v[..., 2 : 2 + l], v[..., 3 : 3 + l], v[..., 4 : 4 + l]),
        1: beta_1(v[..., 1 : 1 + l], v[..., 2 : 2 + l], v[..., 3 : 3 + l]),
        2: beta_2(v[..., :l], v[..., 1 : 1 + l], v[..., 2 : 2 + l]),
    }

  # Compute the nonlinear weights. The right-side values use the mirrored
  # linear weights d[k][k - 1 - r].
  alpha_neg = {}
  alpha_pos = {}
  alpha_s_neg = np.zeros_like(beta[0])
  alpha_s_pos = np.zeros_like(beta[0])
  for r in range(k):
    alpha_neg[r] = d[k][r] / (eps + beta[r]) ** 2
    alpha_pos[r] = d[k][k - 1 - r] / (eps + beta[r]) ** 2
    alpha_s_neg += alpha_neg[r]
    alpha_s_pos += alpha_pos[r]

  omega_neg = {r: alpha_neg[r] / alpha_s_neg for r in range(k)}
  omega_pos = {r: alpha_pos[r] / alpha_s_pos for r in range(k)}

  # Compute the weighted face values.
  def clear_interior(u, dim, sign):
    """Zeroes the nodes that will receive a WENO value; others keep `u`."""
    idx_0 = offset if sign == '-' else offset - 1
    idx_1 = -offset if sign == '-' else -offset - 1
    if dim == 'x':
      u[idx_0:idx_1, ...] = 0.0
    elif dim == 'y':
      u[:, idx_0:idx_1, :] = 0.0
    else:  # dim == 'z':
      u[..., idx_0:idx_1] = 0.0
    return u

  # Start from copies of `v`, so nodes without a full stencil keep v.
  v_neg = clear_interior(np.copy(v), dim, '-')
  v_pos = clear_interior(np.copy(v), dim, '+')
  for r in range(k):
    # The right-side value computed at node i belongs to face i - 1/2, so it
    # is stored one node to the left (face i + 1/2 convention).
    if dim == 'x':
      v_neg[offset:-offset, ...] += omega_neg[r] * vr_neg[r]
      v_pos[offset - 1 : -offset - 1, ...] += omega_pos[r] * vr_pos[r]
    elif dim == 'y':
      v_neg[:, offset:-offset, :] += omega_neg[r] * vr_neg[r]
      v_pos[:, offset - 1 : -offset - 1, :] += omega_pos[r] * vr_pos[r]
    else:  # dim == 'z':
      v_neg[..., offset:-offset] += omega_neg[r] * vr_neg[r]
      v_pos[..., offset - 1 : -offset - 1] += omega_pos[r] * vr_pos[r]

  return v_neg, v_pos


def muscl(v, dim):
  """Computes left/right face values with a limited MUSCL reconstruction.

  The slope at each node is the minmod of the centred, forward and backward
  one-cell differences of `v` along `dim`, all per unit cell. The centred
  difference always lies between the two one-sided ones, so this reduces to
  the classic minmod limiter. The slope is zero at extrema and on the
  first/last node along `dim`, where one of the differences is zero.

  Args:
    v: A 3D array of node values.
    dim: The dimension of reconstruction: 'x', 'y', or anything else
      (treated as 'z').

  Returns:
    A tuple `(v_neg, v_pos)` of arrays shaped like `v`. At the i + 1/2 face:
      - node i of `v_neg` holds the value extrapolated from node i;
      - node i of `v_pos` holds the value extrapolated from node i + 1.
    The last node of `v_pos` along `dim` has no right neighbour and copies
    `v`.
  """

  def minmod(a):
    """Returns the smallest-magnitude entry where all share a sign, else 0."""
    s = np.sum([np.sign(a_i) for a_i in a], axis=0) / len(a)
    a_min = np.abs(a[0])
    for a_i in a[1:]:
      a_min = np.minimum(a_min, np.abs(a_i))

    return np.where(np.abs(s) == 1, s * a_min, np.zeros_like(a_min))

  # All three slopes must be per unit cell. `central2` already divides by
  # 2 * h, so h = 1.0 yields (v[i + 1] - v[i - 1]) / 2. With h = 2.0 the
  # centred slope would be halved and would always win the minmod. The
  # reconstruction would then be inexact even for linear data.
  grad = (
      central2(v, 1.0, dim),
      forward1(v, 1.0, dim),
      backward1(v, 1.0, dim),
  )

  m = minmod(grad)

  v_neg = v + 0.5 * m
  v_pos = np.copy(v)
  if dim == 'x':
    v_pos[:-1, ...] = v[1:, ...] - 0.5 * m[1:, ...]
  elif dim == 'y':
    v_pos[:, :-1, :] = v[:, 1:, :] - 0.5 * m[:, 1:, :]
  else:  # dim == 'z':
    v_pos[..., :-1] = v[..., 1:] - 0.5 * m[..., 1:]

  return v_neg, v_pos

In [None]:
class TimeIntegrationScheme(enum.Enum):
  """Enumerates the time-integration schemes accepted by `time_integration`."""

  # Three-stage Runge-Kutta, implemented by `rk3`.
  RK3 = 'rk3'


def rk3(rhs_fn, states_0, helper_vars, params):
  """Advances `states_0` by one step of `params['dt']` with SSP-RK3.

  Uses the three-stage, third-order strong-stability-preserving Runge-Kutta
  scheme (Shu & Osher):
    u1      = u0 + dt L(u0)
    u2      = 3/4 u0 + 1/4 (u1 + dt L(u1))
    u^{n+1} = 1/3 u0 + 2/3 (u2 + dt L(u2))
  Boundary conditions are re-applied after every stage.

  Args:
    rhs_fn: Callable `(states, helper_vars, params) -> dict` returning the
      time derivative L of every state.
    states_0: Dict of states at the start of the step.
    helper_vars: Passed through to `rhs_fn`.
    params: Needs 'dt'; also passed to `rhs_fn` and `update_boundary`.

  Returns:
    Dict of states at the end of the step, with boundaries updated.
  """

  def forward_euler(states):
    # One explicit Euler step from `states` over the keys of `states_0`.
    tendency = rhs_fn(states, helper_vars, params)
    return {
        key: states[key] + params['dt'] * tendency[key] for key in states_0
    }

  def combine(weight_old, weight_new, stepped):
    # Convex combination of the initial states and a stepped stage.
    return {
        key: weight_old * states_0[key] + weight_new * stepped[key]
        for key in states_0
    }

  states_1 = update_boundary(forward_euler(states_0), params)
  states_2 = update_boundary(
      combine(0.75, 0.25, forward_euler(states_1)), params
  )
  states_n = combine(1.0 / 3.0, 2.0 / 3.0, forward_euler(states_2))
  return update_boundary(states_n, params)


def time_integration(
    rhs_fn, states_0, helper_vars, params, scheme=TimeIntegrationScheme.RK3
):
  """Advances the states by one time step with the requested scheme.

  Args:
    rhs_fn: Callable `(states, helper_vars, params) -> dict` of tendencies.
    states_0: Dict of states at the start of the step.
    helper_vars: Passed through to the integrator.
    params: Passed through to the integrator.
    scheme: A `TimeIntegrationScheme`; only RK3 is implemented.

  Returns:
    Dict of states at the end of the step.

  Raises:
    NotImplementedError: If `scheme` is not supported.
  """
  if scheme != TimeIntegrationScheme.RK3:
    raise NotImplementedError(f'{scheme.name} is not supported.')
  return rk3(rhs_fn, states_0, helper_vars, params)

# Physics

In [None]:
def strain_rate(states_n, states_fx, states_fy, states_fz, params):
  """Computes the deviatoric strain-rate tensor on faces.

  s[i][j] = 0.5 * (du_i/dx_j + du_j/dx_i) - 1/3 * div(u) * delta_ij is
  evaluated on the j-faces:
    - derivatives normal to those faces are forward differences of the node
      values in `states_n`;
    - derivatives tangential to them are central differences of the values
      interpolated to the j-faces.

  Args:
    states_n: Velocity components 'u', 'v', 'w' at nodes.
    states_fx: Velocity components on the x-faces.
    states_fy: Velocity components on the y-faces.
    states_fz: Velocity components on the z-faces.
    params: Grid spacings 'dx', 'dy', 'dz'.

  Returns:
    A 3x3 nested list of arrays, s[i][j].
  """
  axes = ('x', 'y', 'z')
  components = ('u', 'v', 'w')
  face_states = (states_fx, states_fy, states_fz)
  spacing = (params['dx'], params['dy'], params['dz'])

  def normal_derivative(comp, face):
    # d(u_comp)/d(x_face) on the `face` faces, from node values.
    return forward1(states_n[components[comp]], spacing[face], axes[face])

  def tangential_derivative(face, comp, axis):
    # d(u_comp)/d(x_axis) on the `face` faces, from face values.
    return central2(
        face_states[face][components[comp]], spacing[axis], axes[axis]
    )

  def divergence(face):
    terms = [
        normal_derivative(axis, face)
        if axis == face
        else tangential_derivative(face, axis, axis)
        for axis in range(3)
    ]
    return terms[0] + terms[1] + terms[2]

  def entry(i, j):
    if i == j:
      return normal_derivative(j, j) - 1.0 / 3.0 * divergence(j)
    return 0.5 * (normal_derivative(i, j) + tangential_derivative(j, j, i))

  return [[entry(i, j) for j in range(3)] for i in range(3)]


def shear_stress(states_n, states_fx, states_fy, states_fz, params):
  """Computes the viscous (Newtonian) stress tensor on faces.

  Under Stokes' hypothesis,
    tau_ij = mu * (du_i/dx_j + du_j/dx_i - 2/3 * div(u) * delta_ij)
           = 2 * mu * s_ij,
  where mu = rho * nu and s is the deviatoric strain rate returned by
  `strain_rate`:
    s_ij = 0.5 * (du_i/dx_j + du_j/dx_i) - 1/3 * div(u) * delta_ij.
  Component (i, j) is evaluated on the j-faces, so it uses the density
  interpolated to those faces.

  Args:
    states_n: Primitive variables at nodes.
    states_fx: Primitive variables on the x-faces; must provide 'rho'.
    states_fy: Primitive variables on the y-faces; must provide 'rho'.
    states_fz: Primitive variables on the z-faces; must provide 'rho'.
    params: Needs 'nu' plus the grid spacings used by `strain_rate`.

  Returns:
    A 3x3 nested list of arrays, tau[i][j].
  """
  s = strain_rate(states_n, states_fx, states_fy, states_fz, params)
  rho = (states_fx['rho'], states_fy['rho'], states_fz['rho'])
  # The factor 2 converts the deviatoric strain rate into the Newtonian
  # stress; without it the effective viscosity would be halved.
  return [
      [2.0 * rho[j] * params['nu'] * s[i][j] for j in range(3)]
      for i in range(3)
  ]

In [None]:
# Physical constants for dry air, in SI units.
G = 9.81  # gravitational acceleration [m s^-2]
GAMMA = 1.4  # ratio of specific heats cp / cv

_R_UNIVERSAL = 8.314  # universal gas constant [J mol^-1 K^-1]
_MOLAR_MASS_AIR = 0.029  # molar mass of air [kg mol^-1]
R = _R_UNIVERSAL / _MOLAR_MASS_AIR  # specific gas constant [J kg^-1 K^-1]

CV = R / (GAMMA - 1.0)  # specific heat at constant volume
CP = CV + R  # specific heat at constant pressure (Mayer's relation)
KAPPA = R / CP  # exponent linking temperature and potential temperature


def expanded_coordinates(params, dim):
  """Reshapes the 1D coordinate `params[dim]` to broadcast along axis `dim`.

  Returns a view of shape (n, 1, 1), (1, n, 1) or (1, 1, n) for 'x', 'y' and
  anything else ('z'), respectively.
  """
  axis = {'x': 0, 'y': 1}.get(dim, 2)
  index = [np.newaxis, np.newaxis, np.newaxis]
  index[axis] = slice(None)
  return params[dim][tuple(index)]


def gravity_direction(params):
  """Returns the axis index (0, 1, 2) along which gravity acts, or -1.

  An axis qualifies when |params['g'][dim]| equals 1 to within float32
  resolution. The first qualifying axis in x, y, z order is returned.
  """
  tolerance = np.finfo(np.float32).resolution
  for axis, dim in enumerate(('x', 'y', 'z')):
    if np.abs(np.abs(params['g'][dim]) - 1.0) < tolerance:
      return axis
  return -1


def pressure(states, params):
  """Computes the pressure from the ideal-gas equation of state.

  p = (GAMMA - 1) * rho * e_int. The internal energy e_int is the specific
  total energy `states['e']` minus:
    - the kinetic energy;
    - when `params` provides the coordinates 'x', 'y', 'z' and the gravity
      vector 'g', the potential energy G * sum_d |g_d| * coord_d.
  """
  kinetic = 0.5 * (states['u'] ** 2 + states['v'] ** 2 + states['w'] ** 2)
  e_int = states['e'] - kinetic

  if all(key in params for key in ('x', 'y', 'z', 'g')):
    potential = (
        G * np.abs(params['g']['x']) * expanded_coordinates(params, 'x')
        + G * np.abs(params['g']['y']) * expanded_coordinates(params, 'y')
        + G * np.abs(params['g']['z']) * expanded_coordinates(params, 'z')
    )
    e_int -= potential

  return (GAMMA - 1.0) * states['rho'] * e_int


def total_enthalpy(states):
  """Returns the specific total enthalpy h = e + p / rho."""
  pressure_work = states['p'] / states['rho']
  return states['e'] + pressure_work


def sound_speed(states):
  """Returns the adiabatic speed of sound sqrt(GAMMA * p / rho)."""
  scaled_pressure = GAMMA * states['p']
  return np.sqrt(scaled_pressure / states['rho'])


def primitive_variables(states, params):
  """Converts conservative variables to primitive variables.

  Returns a dict with:
    - the velocity components 'u', 'v', 'w';
    - the specific total energy 'e';
    - the density 'rho';
    - the pressure 'p' (from `pressure`);
    - the specific total enthalpy 'h' (from `total_enthalpy`).
  """
  rho = states['rho']
  primitive = {
      name: states[f'rho{name}'] / rho for name in ('u', 'v', 'w', 'e')
  }
  primitive['rho'] = rho
  primitive['p'] = pressure(primitive, params)
  primitive['h'] = total_enthalpy(primitive)
  return primitive


def pressure_hydrostatic(states, params):
  """Computes the hydrostatic pressure for a potential-temperature field.

  Integrates the hydrostatic balance written for the Exner function
  pi = (p / p0) ** KAPPA,
    d(pi)/dz = -G * KAPPA / (R * theta)   (= -G / (CP * theta)),
  along the gravity axis with the trapezoidal rule. pi = 1 (p = p0) at the
  first interior node (index `halo_width`). Ghost nodes below it are
  integrated downwards from there. Only gravity aligned with a single axis
  is supported; without one, a uniform p0 is returned.

  Args:
    states: Must contain 'theta', a 3D potential-temperature array.
    params: Needs 'p0', 'g', 'halo_width' and the 1D coordinate array of the
      gravity axis ('x', 'y' or 'z').

  Returns:
    The pressure field. It is shaped like `states['theta']`, assuming the
    coordinate array matches theta's length along the gravity axis.
  """
  dims = ('x', 'y', 'z')
  # Find the dimension of gravity. Supports gravity along a single axis only.
  g_dir = gravity_direction(params)

  if g_dir == -1:
    return params['p0'] * np.ones_like(states['theta'])

  # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
  # fall back to the old name on NumPy releases that lack the new one.
  trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

  g_dim = dims[g_dir]
  halo_width = params['halo_width']
  coords = params[g_dim]

  def integrate(start, stop):
    """Trapezoidal integral of -G*KAPPA/(R*theta) over nodes [start, stop)."""
    l = expanded_coordinates({g_dim: coords[start:stop]}, g_dim)
    index = [slice(None)] * 3
    index[g_dir] = slice(start, stop)
    theta = states['theta'][tuple(index)]
    return trapezoid(-G * KAPPA / R / theta, l, axis=g_dir)

  # From the first interior node up to every node at or above it.
  buf = [
      integrate(halo_width, i)
      for i in range(halo_width + 1, len(coords) + 1)
  ]

  # Ghost nodes below the first interior node, integrated downwards.
  for i in range(halo_width):
    buf.insert(0, -integrate(halo_width - 1 - i, halo_width + 1))

  # Stacking put the gravity axis first; move it back to position `g_dir`.
  buf = np.array(buf).transpose(np.insert((1, 2), g_dir, 0))

  return params['p0'] * (1.0 + buf) ** (1.0 / KAPPA)


def potential_temperature(states, params):
  """Returns the potential temperature theta = T * (p / p0) ** -KAPPA."""
  pressure_ratio = states['p'] / params['p0']
  return states['T'] * pressure_ratio ** -KAPPA


def temperature(states, params, opt):
  """Computes the temperature.

  Args:
    states: Needs 'theta' and 'p' for opt='theta', or 'p' and 'rho' for
      opt='eos'.
    params: Needs 'p0' for opt='theta'.
    opt: Which relation to use:
      - 'theta' converts from potential temperature,
        T = theta * (p / p0) ** KAPPA;
      - 'eos' uses the ideal-gas law, T = p / (R * rho).

  Returns:
    The temperature.

  Raises:
    ValueError: If `opt` is not one of the options above.
  """
  if opt == 'theta':
    return states['theta'] * (states['p'] / params['p0']) ** KAPPA
  if opt == 'eos':
    return states['p'] / (R * states['rho'])
  raise ValueError(
      f'{opt} is not a valid option for temperature. Available options are '
      '"eos", "theta".'
  )


def density(states, params):
  """Returns the density from the ideal-gas law, rho = p / (R * T).

  `params` is accepted for a uniform interface and is not used.
  """
  gas_term = R * states['T']
  return states['p'] / gas_term


def total_energy(states, params):
  """Computes the specific total energy CV * T + |u|^2 / 2 (+ G * z).

  The potential term G * z is added only when `gravity_direction` finds an
  axis aligned with gravity. z is then that axis' coordinate from `params`.
  """
  kinetic = 0.5 * (states['u'] ** 2 + states['v'] ** 2 + states['w'] ** 2)
  e_t = CV * states['T'] + kinetic

  g_dim = gravity_direction(params)
  if g_dim == -1:
    return e_t

  height = expanded_coordinates(params, ('x', 'y', 'z')[g_dim])
  e_t += G * height
  return e_t


def thermal_ref_states(states, params, opt='const_theta'):
  """Computes hydrostatic reference pressure, temperature and density.

  Args:
    states: Only `states['rho']` is read, to fix the output shape.
    params: Forwarded to `pressure_hydrostatic`, `temperature` and `density`.
    opt: Reference profile. Only 'const_theta' (a uniform potential
      temperature of 300) is supported.

  Returns:
    Dict with the 'p', 'T' and 'rho' reference fields.

  Raises:
    ValueError: If `opt` is not supported.
  """
  if opt != 'const_theta':
    raise ValueError(f'{opt} is not a valid option.')

  theta = 300.0 * np.ones_like(states['rho'])
  p = pressure_hydrostatic({'theta': theta}, params)
  t = temperature({'p': p, 'theta': theta}, params, 'theta')
  return {'p': p, 'T': t, 'rho': density({'p': p, 'T': t}, params)}

In [None]:
def convective_flux(states_fx, states_fy, states_fz, params):
  """Computes the inviscid fluxes of mass, momentum and energy on faces.

  Across the j-faces:
    - the mass flux is rho u_j;
    - the momentum flux is rho u_i u_j + p delta_ij;
    - the energy flux is rho u_j h.

  Args:
    states_fx: Conservative variables ('rho', 'rhou', 'rhov', 'rhow',
      'rhoe') on the x-faces.
    states_fy: Conservative variables on the y-faces.
    states_fz: Conservative variables on the z-faces.
    params: Forwarded to `primitive_variables`.

  Returns:
    Dict mapping each conservative variable to a 3-tuple of its x, y and z
    face fluxes.
  """
  faces = (states_fx, states_fy, states_fz)
  primitives = tuple(primitive_variables(face, params) for face in faces)
  momenta = ('rhou', 'rhov', 'rhow')
  velocities = ('u', 'v', 'w')

  def momentum_flux(comp, axis):
    flux = faces[axis][momenta[comp]] * primitives[axis][velocities[axis]]
    if comp == axis:
      flux = flux + primitives[axis]['p']
    return flux

  fluxes = {'rho': tuple(faces[axis][momenta[axis]] for axis in range(3))}
  for comp, name in enumerate(momenta):
    fluxes[name] = tuple(momentum_flux(comp, axis) for axis in range(3))
  fluxes['rhoe'] = tuple(
      faces[axis][momenta[axis]] * primitives[axis]['h'] for axis in range(3)
  )
  return fluxes


def diffusive_fluxes(states_n, params):
  """Computes the viscous fluxes of mass, momentum and energy on faces.

  Node states are interpolated to the faces with `linear_interp`. Then:
    - the mass flux is zero;
    - the momentum flux is -tau (from `shear_stress`);
    - the energy flux across the j-faces is
      -rho * nu * dh/dx_j - sum_i u_i * tau[i][j],
      where dh/dx_j is a forward difference of the node total enthalpy.

  Args:
    states_n: Conservative variables at nodes.
    params: Needs 'nu', 'dx', 'dy', 'dz' plus whatever
      `primitive_variables` uses.

  Returns:
    Dict mapping each conservative variable to a 3-tuple of x, y and z face
    fluxes.
  """
  dims = ('x', 'y', 'z')
  spacings = ('dx', 'dy', 'dz')
  momenta = ('rhou', 'rhov', 'rhow')

  # Face states from continuous (linear) interpolation.
  face_states = tuple(
      {key: linear_interp(val, dim) for key, val in states_n.items()}
      for dim in dims
  )
  primitive_n = primitive_variables(states_n, params)
  primitive_f = tuple(
      primitive_variables(face, params) for face in face_states
  )

  tau = shear_stress(primitive_n, *primitive_f, params)

  # Mass has no diffusive flux.
  fluxes = {
      'rho': tuple(
          np.zeros_like(face_states[j][momenta[j]]) for j in range(3)
      )
  }
  for i, name in enumerate(momenta):
    fluxes[name] = tuple(-tau[i][j] for j in range(3))

  energy = []
  for j, dim in enumerate(dims):
    face = primitive_f[j]
    velocity = (face['u'], face['v'], face['w'])
    stress_column = (tau[0][j], tau[1][j], tau[2][j])
    conduction = -face_states[j]['rho'] * params['nu'] * forward1(
        primitive_n['h'], params[spacings[j]], dim
    )
    energy.append(conduction - dot(velocity, stress_column))
  fluxes['rhoe'] = tuple(energy)

  return fluxes


def source_fn(states, helper_vars, params):
  """Returns the source term of every conservative variable.

  All sources are zero. The exception is when `params` has a 'g' entry: the
  momentum components then receive the gravitational body force
  rho * G * params['g'][dim]. `helper_vars` is unused.
  """
  src = {key: np.zeros_like(val) for key, val in states.items()}
  if 'g' not in params:
    return src

  for dim, momentum in zip(('x', 'y', 'z'), ('rhou', 'rhov', 'rhow')):
    src[momentum] += states['rho'] * G * params['g'][dim]
  return src

# Approximate Riemann Solvers and Numerical Flux

In [None]:
class FluxType(enum.Enum):
  """Enumerates the numerical-flux functions selectable in `monotonic_flux`."""

  # Selects `godunov`.
  GODUNOV = 'godunov'
  # Selects `hll` (Harten-Lax-van Leer).
  HLL = 'hll'


def godunov(states_neg, states_pos, flux_neg, flux_pos, dim):
  """Selects a face flux per variable from the two one-sided fluxes.

  Where the left state exceeds the right state, the larger of the two fluxes
  is taken; otherwise the smaller one. This is a component-wise form of the
  scalar Godunov rule that only considers the two end states. `dim` is
  unused.
  """
  del dim  # unused

  def select(key):
    left_is_larger = states_neg[key] > states_pos[key]
    larger = np.maximum(flux_neg[key], flux_pos[key])
    smaller = np.minimum(flux_neg[key], flux_pos[key])
    return np.where(left_is_larger, larger, smaller)

  return {key: select(key) for key in states_neg}


def hll(states_neg, states_pos, flux_neg, flux_pos, dim, params=None):
  """Computes the Harten-Lax-van Leer (HLL) numerical fluxes on faces.

  Signal speeds use Einfeldt-type bounds that combine each side's own
  characteristic speed with the Roe-averaged one:
    s_L = min(u_L - c_L, u_roe - c_roe),
    s_R = max(u_R + c_R, u_roe + c_roe),
  where u is the velocity component normal to the face. Roe averages are set
  to 0 where sqrt(rho_L) + sqrt(rho_R) <= eps. The HLL weights are set to 0
  where s_R - s_L <= eps, which yields the left flux there.

  Args:
    states_neg: Conservative variables on the left side of each face.
    states_pos: Conservative variables on the right side of each face.
    flux_neg: Physical fluxes evaluated from `states_neg`.
    flux_pos: Physical fluxes evaluated from `states_pos`.
    dim: Index (0, 1, 2) of the face-normal direction.
    params: Forwarded to `primitive_variables`. When omitted, the
      module-level global named `params` is used. Callers that do not pass
      it (e.g. `monotonic_flux`) rely on this fallback.

  Returns:
    Dict mapping each key of `states_neg` to its HLL flux.

  Raises:
    NameError: If `params` is omitted and no global `params` exists.
  """
  eps = 1e-6

  if params is None:
    # Make the dependency on notebook-level state explicit instead of relying
    # on a free variable that only resolves if another cell defined it.
    if 'params' not in globals():
      raise NameError(
          "hll() needs `params`: pass it explicitly or define a global "
          "named 'params'."
      )
    params = globals()['params']

  def safe_divide(numerator, denominator, valid):
    # `np.divide(..., where=...)` without `out` leaves the masked entries
    # uninitialised; start from zeros so they are well defined.
    out = np.zeros(
        np.broadcast(numerator, denominator).shape,
        dtype=np.result_type(numerator, denominator),
    )
    return np.divide(numerator, denominator, out=out, where=valid)

  primitive_neg = primitive_variables(states_neg, params)
  primitive_pos = primitive_variables(states_pos, params)

  c_neg = sound_speed(primitive_neg)
  c_pos = sound_speed(primitive_pos)

  # Roe averages, weighted by sqrt(rho) of each side.
  rho_neg_s = np.sqrt(states_neg['rho'])
  rho_pos_s = np.sqrt(states_pos['rho'])
  rho_sum = rho_neg_s + rho_pos_s
  rho_valid = np.abs(rho_sum) > eps

  def roe_average(name):
    return safe_divide(
        rho_neg_s * primitive_neg[name] + rho_pos_s * primitive_pos[name],
        rho_sum,
        rho_valid,
    )

  u = roe_average('u')
  v = roe_average('v')
  w = roe_average('w')
  h = roe_average('h')

  c = np.sqrt((GAMMA - 1.0) * (h - 0.5 * (u**2 + v**2 + w**2)))

  # Estimate the slowest and fastest signal speeds normal to the face.
  normal_key = ('u', 'v', 'w')[dim]
  vel = (u, v, w)[dim]
  s_l = np.minimum(primitive_neg[normal_key] - c_neg, vel - c)
  s_r = np.maximum(primitive_pos[normal_key] + c_pos, vel + c)

  # Compute the HLL weights. This form is valid for any signs of s_l and s_r:
  # it reduces to the left/right flux when both speeds share a sign.
  ds = s_r - s_l
  ds_valid = np.abs(ds) > eps
  t1 = safe_divide(
      np.minimum(s_r, 0.0) - np.minimum(s_l, 0.0), ds, ds_valid
  )
  t2 = 1.0 - t1
  t3 = safe_divide(
      s_r * np.abs(s_l) - s_l * np.abs(s_r), 2.0 * ds, ds_valid
  )

  def flux_hll(q_m, q_p, f_m, f_p):
    return t1 * f_p + t2 * f_m - t3 * (q_p - q_m)

  return {
      key: flux_hll(
          states_neg[key], states_pos[key], flux_neg[key], flux_pos[key]
      )
      for key in states_neg.keys()
  }


def monotonic_flux(states, helper_vars, params, flux_type=FluxType.HLL):
  """Computes the numerical face fluxes of all conservative variables.

  Steps:
    1. Reconstruct states to faces with `params['interp']` ('weno5' or
       'muscl').
    2. Evaluate the convective fluxes on both sides of each face.
    3. Combine them with the solver selected by `flux_type`.
    4. If `params['include_diffusion']` is truthy, add the viscous fluxes
       from `diffusive_fluxes`.

  Args:
    states: Conservative variables at nodes.
    helper_vars: Unused.
    params: Simulation parameters; needs 'interp' and 'include_diffusion'.
    flux_type: A `FluxType` selecting the numerical flux.

  Returns:
    Dict mapping 'x', 'y', 'z' to a dict of per-variable face fluxes.

  Raises:
    ValueError: If `params` has no 'interp' entry.
    NotImplementedError: If the interpolation or flux type is unsupported.
  """
  del helper_vars

  dims = ('x', 'y', 'z')

  if 'interp' not in params:
    raise ValueError('Interpolation function not defined.')

  interp_opt = params['interp']
  if interp_opt == 'weno5':
    interp_fn = weno5
  elif interp_opt == 'muscl':
    interp_fn = muscl
  else:
    raise NotImplementedError(f'{interp_opt} is not supported.')

  # face_values[key][dim] is the (left, right) pair of reconstructed values.
  face_values = {
      key: {dim: interp_fn(val, dim) for dim in dims}
      for key, val in states.items()
  }
  left = [
      {key: pair[dim][0] for key, pair in face_values.items()} for dim in dims
  ]
  right = [
      {key: pair[dim][1] for key, pair in face_values.items()} for dim in dims
  ]

  # TODO: Apply boundary conditions to states on faces.

  flux_left = convective_flux(*left, params)
  flux_right = convective_flux(*right, params)

  if flux_type == FluxType.HLL:
    flux_fn = hll
  elif flux_type == FluxType.GODUNOV:
    flux_fn = godunov
  else:
    raise NotImplementedError(f'{flux_type.name} is not supported.')

  flux = {}
  for i, dim in enumerate(dims):
    f_left = {key: val[i] for key, val in flux_left.items()}
    f_right = {key: val[i] for key, val in flux_right.items()}
    flux[dim] = flux_fn(left[i], right[i], f_left, f_right, i)

  if params['include_diffusion']:
    flux_diff = diffusive_fluxes(states, params)
    for i, dim in enumerate(dims):
      for key, val in flux_diff.items():
        flux[dim][key] += val[i]

  return flux

# Lax-Friedrichs Split Flux

In [None]:
"""Helper functions for Lax-Friedrichs split flux."""


def strain_rate_node(states, params):
  """Computes the deviatoric strain-rate tensor at nodes.

  s[i][j] = 0.5 * (du_i/dx_j + du_j/dx_i) - 1/3 * div(u) * delta_ij. Every
  derivative is taken by `central2` with spacing params['dx'|'dy'|'dz']. The
  tensor is symmetric: s[i][j] and s[j][i] hold the same values.

  Args:
    states: Velocity components 'u', 'v', 'w' at nodes.
    params: Grid spacings 'dx', 'dy', 'dz'.

  Returns:
    A 3x3 nested list of arrays, s[i][j].
  """
  dims = ('x', 'y', 'z')
  comps = ('u', 'v', 'w')
  # grad[c][d] = d(u_c)/d(x_d).
  grad = {
      comp: {
          dim: central2(states[comp], params[f'd{dim}'], dim) for dim in dims
      }
      for comp in comps
  }
  div_u = grad['u']['x'] + grad['v']['y'] + grad['w']['z']

  def component(i, j):
    if i == j:
      return grad[comps[i]][dims[i]] - 1.0 / 3.0 * div_u
    # Building every off-diagonal entry from the same formula avoids
    # copy-paste slips such as 0.5 * (dw/dz + dw/dz) for s[2][1].
    return 0.5 * (grad[comps[i]][dims[j]] + grad[comps[j]][dims[i]])

  return [[component(i, j) for j in range(3)] for i in range(3)]


def shear_stress_node(states, params):
  """Computes the viscous (Newtonian) stress tensor at nodes.

  Under Stokes' hypothesis,
    tau_ij = mu * (du_i/dx_j + du_j/dx_i - 2/3 * div(u) * delta_ij)
           = 2 * mu * s_ij,
  where mu = rho * nu and s is the deviatoric strain rate from
  `strain_rate_node`.

  Args:
    states: Primitive variables at nodes; needs 'rho', 'u', 'v', 'w'.
    params: Needs 'nu' and the grid spacings 'dx', 'dy', 'dz'.

  Returns:
    A 3x3 nested list of arrays, tau[i][j].
  """
  s = strain_rate_node(states, params)
  # The factor 2 converts the deviatoric strain rate into the Newtonian
  # stress; without it the effective viscosity would be halved.
  return [
      [2.0 * states['rho'] * params['nu'] * s[i][j] for j in range(3)]
      for i in range(3)
  ]


def physical_fluxes_node(states, params):
  """Evaluates the physical (convective + viscous) fluxes at nodes.

  For every conservative variable returns the tuple of its (x, y, z) flux
  components:
    rho:    rho u_j
    rho u_i: rho u_i u_j - tau_ij + p delta_ij
    rhoe:   rho u_j h - rho nu dh/dx_j - dot(u, tau_j)
  `h` is read from the primitive variables (presumably the specific total
  enthalpy — confirm against `primitive_variables`).

  Args:
    states: dict of conservative variables 'rho', 'rhou', 'rhov', 'rhow',
      'rhoe'.
    params: dict providing 'nu', 'dx', 'dy', 'dz' and whatever
      `primitive_variables` needs.

  Returns:
    dict mapping each conservative variable to a 3-tuple of flux components.
  """
  primitive = primitive_variables(states, params)
  tau = shear_stress_node(primitive, params)

  dims = ('x', 'y', 'z')
  vels = ('u', 'v', 'w')
  moms = ('rhou', 'rhov', 'rhow')

  # Continuity: the mass flux is the momentum itself.
  fluxes = {'rho': tuple(states[mom] for mom in moms)}

  # Momentum: advection minus viscous stress, plus pressure on the diagonal.
  for i, mom in enumerate(moms):
    components = []
    for j, vel in enumerate(vels):
      component = states[mom] * primitive[vel] - tau[i][j]
      if i == j:
        component = component + primitive['p']
      components.append(component)
    fluxes[mom] = tuple(components)

  # Energy: enthalpy advection, enthalpy diffusion, and viscous work.
  velocity = tuple(primitive[vel] for vel in vels)
  energy = []
  for j, (mom, dim) in enumerate(zip(moms, dims)):
    diffusion = (
        states['rho']
        * params['nu']
        * central2(primitive['h'], params[f'd{dim}'], dim)
    )
    energy.append(
        states[mom] * primitive['h'] - diffusion - dot(velocity, tau[j])
    )
  fluxes['rhoe'] = tuple(energy)

  return fluxes

In [None]:
def lf(states, params):
  """Computes the Lax-Friedrichs split numerical flux on cell faces.

  Each nodal physical flux f is split as f+- = 0.5 * (f +- alpha * q), where
  alpha = max(|u_d - c|, |u_d + c|, |u_d|) is evaluated pointwise per
  direction d. The split fluxes are reconstructed on faces with `weno5`;
  index 0 of its result is used for f+ and index 1 for f-.

  Args:
    states: dict of conservative variables.
    params: simulation parameters forwarded to the helpers.

  Returns:
    dict {dim: {variable: face flux}} for dim in ('x', 'y', 'z').
  """
  dims = ('x', 'y', 'z')
  vels = ('u', 'v', 'w')

  primitive = primitive_variables(states, params)
  c = sound_speed(primitive)

  # Largest characteristic speed magnitude in each direction.
  alpha = []
  for vel in vels:
    speed = primitive[vel]
    alpha.append(
        np.maximum(
            np.maximum(np.abs(speed - c), np.abs(speed + c)), np.abs(speed)
        )
    )

  physical = physical_fluxes_node(states, params)

  numerical_flux = {dim: {} for dim in dims}
  for key, val in physical.items():
    for i, dim in enumerate(dims):
      f_minus = 0.5 * (val[i] - alpha[i] * states[key])
      f_plus = 0.5 * (val[i] + alpha[i] * states[key])
      numerical_flux[dim][key] = (
          weno5(f_plus, dim)[0] + weno5(f_minus, dim)[1]
      )
  return numerical_flux

# Navier-Stokes Solver

In [None]:
def rhs(states, helper_vars, params):
  """Computes the right-hand side of the semi-discrete Navier-Stokes equations.

  dq/dt = -sum_d backward1(F_d) + S, with the face fluxes F_d from
  `monotonic_flux` (after boundary treatment) and an optional source term S.
  The result is zeroed wherever `helper_vars['mask']` is non-zero, i.e. in
  every cell that is not fluid.

  Args:
    states: dict of conservative variables.
    helper_vars: dict providing at least 'mask'.
    params: simulation parameters. Optional keys:
      'flux_type': a `FluxType` passed to `monotonic_flux` (default HLL).
        `lf(states, params)` is an alternative flux formulation.
      'source_fn': callable(states, helper_vars, params) -> dict of sources.

  Returns:
    dict with the time derivative of every conservative variable.
  """
  dims = ('x', 'y', 'z')

  # Defaults to HLL so existing configurations keep their behavior.
  flux_type = params.get('flux_type', FluxType.HLL)
  flux = monotonic_flux(states, helper_vars, params, flux_type=flux_type)
  flux = update_flux_boundary_condition(flux, helper_vars, params)

  # Flux divergence.
  d_f = {key: np.zeros_like(val) for key, val in states.items()}
  for key in states:
    for dim in dims:
      d_f[key] -= backward1(flux[dim][key], params[f'd{dim}'], dim)

  if 'source_fn' in params:
    src = params['source_fn'](states, helper_vars, params)
    d_f = {key: val + src[key] for key, val in d_f.items()}

  # The arrays in d_f are freshly allocated above, so zeroing the non-fluid
  # cells in place does not touch `states`.
  not_fluid = helper_vars['mask'] != 0
  for val in d_f.values():
    val[not_fluid] = 0.0

  return d_f


def solve(states, helper_vars, params):
  """Solves the compressible NS equations with RK3 time stepping.

  Runs `params['nt']` steps of size `params['dt']`. Every `params['n_step']`
  steps it prints the simulation time and the acoustic CFL numbers
  max(|u_d| + c) * dt / h_d for each direction d.

  Args:
    states: dict of conservative variables (initial condition).
    helper_vars: helper fields forwarded to `rhs` (e.g. 'mask').
    params: simulation parameters.

  Returns:
    The conservative variables after the final step.
  """
  print('t [s]\t CFL_x\t CFL_y\t CFL_z')

  dt = params['dt']
  vels = ('u', 'v', 'w')
  dims = ('x', 'y', 'z')

  for i in range(params['nt']):
    states = time_integration(
        rhs, states, helper_vars, params, TimeIntegrationScheme.RK3
    )

    if i % params['n_step'] == 0:
      var = primitive_variables(states, params)
      # `states` has just been advanced by step i, so it is at (i + 1) * dt.
      t = (i + 1) * dt
      c = sound_speed(var)
      cfl_x, cfl_y, cfl_z = (
          np.max(np.abs(var[vel]) + c) * dt / params[f'd{dim}']
          for vel, dim in zip(vels, dims)
      )
      print(f'{t:5.3e}\t {cfl_x:5.3e}\t {cfl_y:5.3e}\t {cfl_z:5.3e}')

  return states

# Simulations

In [None]:
# @title General Parameters
interp_fn = 'weno5'  # @param['weno5', 'muscl']
# Halo (ghost-cell) width required by each reconstruction stencil; any other
# choice falls back to a width of 2.
hw = {'weno5': 3, 'muscl': 2}.get(interp_fn, 2)

params = dict(
    interp=interp_fn,
    halo_width=hw,
    include_diffusion=True,
)


def get_coord(l, n, halo_width=None):
  """Generates uniform 1D node coordinates, including halo nodes.

  The `n - 2 * halo_width` interior nodes cover [0, l) with spacing
  d = l / (n - 2 * halo_width); `halo_width` extra nodes are added on each
  side, so the first node is at -halo_width * d.

  Args:
    l: Length of the interior domain.
    n: Total number of nodes, halo included.
    halo_width: Number of halo nodes per side. Defaults to the global `hw`
      set in the General Parameters cell.

  Returns:
    A 1D numpy array of `n` coordinates.

  Raises:
    ValueError: if there are no interior nodes.
  """
  if halo_width is None:
    halo_width = hw
  n_interior = n - 2 * halo_width
  if n_interior <= 0:
    raise ValueError(
        f'n = {n} leaves no interior nodes for halo_width = {halo_width}.'
    )
  d = l / n_interior
  return d * np.arange(n) - halo_width * d

In [None]:
# @title 1D Burgers equation


def burgers_1d(states, helper_vars, params):
  """Computes the right-hand side function of the 1D inviscid Burgers equation.

  du/dt = -d(u^2 / 2)/dx, discretized with a global Lax-Friedrichs flux
  splitting (alpha = max |u|) and WENO5 reconstruction of the split fluxes.

  Args:
    states: dict with key 'u'; x runs along axis 0.
    helper_vars: unused.
    params: dict providing 'dx' and optionally 'halo_width' (default 3).

  Returns:
    dict {'u': du/dt}, zeroed in the halo cells on both ends of x.
  """
  del helper_vars

  u = states['u']
  halo = params.get('halo_width', 3)

  f = 0.5 * u**2
  # f'(u) = u, so the global splitting speed is max |u|.
  alpha = np.max(np.abs(u))
  f_m = 0.5 * (f - alpha * u)
  f_p = 0.5 * (f + alpha * u)

  f_neg, _ = weno5(f_p, 'x')
  _, f_pos = weno5(f_m, 'x')

  flux = f_neg + f_pos

  dfdx = backward1(flux, params['dx'], 'x')
  # Halo cells are filled by the boundary conditions, not by the PDE.
  # Guard halo == 0: a slice [-0:] would zero the whole array.
  if halo > 0:
    dfdx[:halo, ...] = 0.0
    dfdx[-halo:, ...] = 0.0

  return {'u': -dfdx}


def solve_burgers_1d(states, params):
  """Advances the 1D Burgers equation by `params['nt']` time steps.

  Each step calls `time_integration` with `burgers_1d` as the right-hand
  side, no helper variables, and the integrator's default scheme.

  Returns:
    The states after the last step.
  """
  n_steps = params['nt']
  for _ in range(n_steps):
    states = time_integration(burgers_1d, states, {}, params)
  return states


nx = 104
# Halo width of the WENO5 stencil used by `burgers_1d`.
burgers_hw = 3
dx = 2 * np.pi / (nx - 2 * burgers_hw)
x = dx * np.arange(nx) - burgers_hw * dx
x = x[:, np.newaxis, np.newaxis]

states = {'u': np.sin(x)}

# Use a dedicated dict: rebinding the shared `params` here would drop the
# 'interp' and 'include_diffusion' entries from the General Parameters cell,
# and the Navier-Stokes cells below only `update` it (monotonic_flux reads
# params['include_diffusion']), so Run All would fail with a KeyError.
burgers_params = {
    'dx': dx,
    'halo_width': burgers_hw,
    'dt': 1e-2,
    'nt': 100,
    'bc': {
        'u': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
        },
    },
}

states_n = solve_burgers_1d(states, burgers_params)

plt.plot(np.squeeze(x), np.squeeze(states_n['u']))
plt.xlabel('x')
plt.ylabel('u')
plt.title(
    f"Burgers' equation, t = {burgers_params['nt'] * burgers_params['dt']:g}"
);

In [None]:
# @title Taylor-Green Vortex

nx = ny = nz = 36

x, y, z = (get_coord(2 * np.pi, n) for n in (nx, ny, nz))

params.update({
    'dx': x[1] - x[0],
    'dy': y[1] - y[0],
    'dz': z[1] - z[0],
    'dt': 5e-3,
    'nt': 100,
    'n_step': 10,
    'nu': 1e-3,
    # Every conservative variable is periodic on both faces of every axis.
    'bc': {
        var: {
            dim: {face: (BoundaryCondition.PERIODIC,) for face in (0, 1)}
            for dim in ('x', 'y', 'z')
        }
        for var in ('rho', 'rhou', 'rhov', 'rhow', 'rhoe')
    },
})

xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')

# Taylor-Green velocity field; rho = 1, so the momentum equals the velocity.
u = np.sin(xx) * np.cos(yy) * np.cos(zz)
v = -np.cos(xx) * np.sin(yy) * np.cos(zz)
w = np.zeros_like(xx)
# Pressure: 100 + (cos 2x + cos 2y) * (cos 2z + 2) / 16.
p = (
    1.0 / 16.0 * (np.cos(2.0 * xx) + np.cos(2.0 * yy)) * (np.cos(2.0 * zz) + 2.0)
) + 100.0
# Specific total energy with gamma - 1 = 0.4.
e = p / 0.4 + 0.5 * (u**2 + v**2 + w**2)

states = {
    'rho': np.ones((nx, ny, nz), dtype=np.float32),
    'rhou': u,
    'rhov': v,
    'rhow': w,
    'rhoe': e,
}

# 0 marks fluid cells; the halo layer of width hw is masked out.
mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw:-hw, hw:-hw, hw:-hw] = 0
helper_vars = {'mask': mask}

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
# @title Periodic Channel

nx = 64
# `periodic` copies v[-2*hw:-hw] into the low halo, so each periodic axis
# needs at least `hw` interior cells. With WENO5 (hw = 3), ny = 8 leaves only
# 2 interior cells and halo data would be copied into the halo.
ny = 9
nz = 32

params.update({
    'dx': 1.0,
    'dy': 1.0,
    'dz': 1.0,
    'dt': 1e-2,
    'nt': 1000,
    'n_step': 50,
    'nu': 1e-3,
    # Periodic in x and y; Dirichlet (no-slip) walls at both z faces.
    'bc': {
        'rho': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 1.0),
                1: (BoundaryCondition.DIRICHLET, 1.0),
            },
        },
        'rhou': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhov': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhow': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhoe': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.NEUMANN, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
        },
    },
})

states = {
    'rho': np.ones((nx, ny, nz), dtype=np.float32),
    'rhou': np.ones((nx, ny, nz), dtype=np.float32),
    'rhov': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhow': np.zeros((nx, ny, nz), dtype=np.float32),
    # p = 100 Pa.
    'rhoe': 251.5 * np.ones((nx, ny, nz), dtype=np.float32),
}

# Use the configured halo width; a hard-coded 2 marked WENO5 halo cells
# (hw = 3) as fluid.
mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw:-hw, hw:-hw, hw:-hw] = 0
helper_vars = {
    'mask': mask,
}

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
# @title Channel

nx, nz, ny = 64, 9, 32

params.update({
    'dx': 1.0,
    'dy': 1.0,
    'dz': 1.0,
    'dt': 2.5e-2,
    'nt': 1000,
    'n_step': 50,
    'nu': 1e-3,
    # Dirichlet inflow at x = 0, zero-gradient outflow at the far x face,
    # periodic in z, and walls at both y faces.
    'bc': {
        'rho': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 1.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.DIRICHLET, 1.0),
                1: (BoundaryCondition.DIRICHLET, 1.0),
            },
        },
        'rhou': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 1.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhov': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhow': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhoe': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 251.5),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.NEUMANN, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
        },
    },
})

# Uniform initial state; rhoe = 251.5 corresponds to p = 100 Pa.
states = {
    'rho': np.ones((nx, ny, nz), dtype=np.float32),
    'rhou': np.ones((nx, ny, nz), dtype=np.float32),
    'rhov': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhow': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhoe': 251.5 * np.ones((nx, ny, nz), dtype=np.float32),
}

# 0 marks fluid cells; the halo layer of width hw is masked out.
mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw:-hw, hw:-hw, hw:-hw] = 0
helper_vars = {'mask': mask}

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
# @title Forward step

nx = 66
nz = 9
ny = 26

params.update({
    'dx': 0.1,
    'dy': 0.1,
    'dz': 0.1,
    'dt': 5e-3,
    'nt': 10000,
    'n_step': 50,
    'nu': 1.0e-1,
    # Dirichlet inflow at x = 0, zero-gradient outflow at the far x face,
    # periodic in z, and reflective walls at both y faces.
    'bc': {
        'rho': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 1.4),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.REFLECTIVE,),
                1: (BoundaryCondition.REFLECTIVE,),
            },
        },
        'rhou': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 4.2),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.REFLECTIVE,),
                1: (BoundaryCondition.REFLECTIVE,),
            },
        },
        'rhov': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.REFLECTIVE,),
                1: (BoundaryCondition.REFLECTIVE,),
            },
        },
        'rhow': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.REFLECTIVE,),
                1: (BoundaryCondition.REFLECTIVE,),
            },
        },
        'rhoe': {
            'x': {
                0: (BoundaryCondition.DIRICHLET, 8.8),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.REFLECTIVE,),
                1: (BoundaryCondition.REFLECTIVE,),
            },
        },
    },
})

# Fluid cells (mask == 0): the full channel height for the first 12 cells,
# then only the cells above a 3-cell-high step. Both regions stop at the
# halo (-hw); ending at -2 marked a halo cell at the outflow as fluid.
mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw : hw + 12, hw:-hw, hw:-hw] = 0
mask[hw + 12 : -hw, hw + 3 : -hw, hw:-hw] = 0

# Reflective-wall masks for the step faces. They are not passed in
# `helper_vars` below (the entries are commented out), so they are unused.
mask_x = np.zeros((nx, ny, nz), dtype=np.int32)
mask_x[hw + 11, hw : hw + 4, hw:-hw] = DomainType.REFLECTIVE_WALL.value

mask_y = np.zeros((nx, ny, nz), dtype=np.int32)
mask_y[hw + 12, hw + 4 : -hw, hw:-hw] = DomainType.REFLECTIVE_WALL.value

helper_vars = {
    'mask': mask,
    # 'mask_x': mask_x,
    # 'mask_y': mask_y,
}

# Start at rest inside the step and the halo, and at the inflow state in the
# fluid.
rhou = 4.2 * np.ones((nx, ny, nz), dtype=np.float32)
rhou[mask != 0] = 0.0
states = {
    'rho': 1.4 * np.ones((nx, ny, nz), dtype=np.float32),
    'rhou': rhou,
    'rhov': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhow': np.zeros((nx, ny, nz), dtype=np.float32),
    # Assuming gamma = 1.4: p = 0.4 * (8.8 - 0.5 * 4.2**2 / 1.4) = 1.0 in the
    # inflow (u = 3, c = 1, i.e. Mach 3).
    'rhoe': 8.8 * np.ones((nx, ny, nz), dtype=np.float32),
}

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
# @title Sod's shock tube

nx = 9
ny = 9
nz = 106
# Follow the configured halo width rather than hard-coding 3, so the mask,
# the coordinates, and the boundary conditions agree.
hw = params['halo_width']

params.update({
    'dx': 0.01,
    'dy': 0.01,
    'dz': 0.01,
    'dt': 5e-3,
    'nt': 40,
    'n_step': 5,
    'nu': 1e-4,
    # Periodic in x and y; the left/right Sod states are held at the z faces.
    'bc': {
        'rho': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 1.0),
                1: (BoundaryCondition.DIRICHLET, 0.125),
            },
        },
        'rhou': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhov': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhow': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhoe': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.DIRICHLET, 2.5),
                1: (BoundaryCondition.DIRICHLET, 0.25),
            },
        },
    },
})


x = params['dx'] * np.arange(nx)
y = params['dy'] * np.arange(ny)
# Offset by the halo width so the first interior node is at z = 0 and the
# diaphragm at z = 0.5 splits the interior into two equal halves.
z = params['dz'] * (np.arange(nz) - hw)
_, _, zz = np.meshgrid(x, y, z, indexing='ij')
assert zz.shape == (nx, ny, nz)

# Left state: rho = 1, rhoe = 2.5; right state: rho = 0.125, rhoe = 0.25.
states = {
    'rho': np.where(zz < 0.5, np.ones_like(zz), 0.125 * np.ones_like(zz)),
    'rhou': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhov': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhow': np.zeros((nx, ny, nz), dtype=np.float32),
    'rhoe': np.where(zz < 0.5, 2.5 * np.ones_like(zz), 0.25 * np.ones_like(zz)),
}

mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw:-hw, hw:-hw, hw:-hw] = 0
helper_vars = {
    'mask': mask,
}

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
# @title Buoyant bubble

nx = 106
nz = 9
ny = 56

lx = 2e4
ly = 1e4

dx = lx / (nx - 2 * hw)
dy = ly / (ny - 2 * hw)
dz = dx

x = dx * (np.arange(nx) - hw)
y = dy * (np.arange(ny) - hw)
z = dz * (np.arange(nz) - hw)
xx, yy, _ = np.meshgrid(x, y, z, indexing='ij')

# Warm bubble: a cos^2 perturbation of amplitude theta_p added to a uniform
# potential temperature of 300, centered at (xc, yc) with radii (xr, yr).
theta_p = 2.0
xc = 1e4
yc = 2e3
xr = 2e3
yr = 2e3

params.update({
    'dx': dx,
    'dy': dy,
    'dz': dz,
    'dt': 0.25,
    'nt': 3000,
    'n_step': 50,
    'x': x,
    'y': y,
    'z': z,
    'nu': 1e-5,
    'g': {'x': 0.0, 'y': -1.0, 'z': 0.0},
    'p0': 1e5,
    'source_fn': source_fn,
})

r = np.sqrt(((xx - xc) / xr) ** 2 + ((yy - yc) / yr) ** 2)
theta = 300.0 * np.ones_like(xx) + np.where(
    np.abs(r) < 1.0, theta_p * np.cos(0.5 * np.pi * r) ** 2, np.zeros_like(r)
)
u = np.zeros((nx, ny, nz), dtype=np.float32)
v = np.zeros((nx, ny, nz), dtype=np.float32)
w = np.zeros((nx, ny, nz), dtype=np.float32)
p = pressure_hydrostatic({'theta': theta}, params)
t = temperature({'p': p, 'theta': theta}, params, 'theta')
rho = density({'p': p, 'T': t}, params)
e = total_energy({'T': t, 'u': u, 'v': v, 'w': w}, params)

# The fluid starts at rest, so the momenta equal the zero velocity fields.
states = {
    'rho': rho,
    'rhou': u,
    'rhov': v,
    'rhow': w,
    'rhoe': rho * e,
}

mask = np.ones((nx, ny, nz), dtype=np.int32)
mask[hw:-hw, hw:-hw, hw:-hw] = 0

# Only the fluid mask is used; reflective-wall masks ('mask_x', 'mask_y')
# could be added here as in the forward-step case.
helper_vars = {
    'mask': mask,
}

# The boundary conditions are set after `states` exists: the y-face halo of
# rho and rhoe is pinned to copies of the initial hydrostatic profile.
params.update({
    'bc': {
        'rho': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (
                    BoundaryCondition.DIRICHLET,
                    [np.copy(states['rho'][:, i, :]) for i in range(hw)],
                ),
                1: (
                    BoundaryCondition.DIRICHLET,
                    [np.copy(states['rho'][:, -hw + i, :]) for i in range(hw)],
                ),
            },
        },
        'rhou': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.NEUMANN, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
        },
        'rhov': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.DIRICHLET, 0.0),
                1: (BoundaryCondition.DIRICHLET, 0.0),
            },
        },
        'rhow': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (BoundaryCondition.NEUMANN, 0.0),
                1: (BoundaryCondition.NEUMANN, 0.0),
            },
        },
        'rhoe': {
            'x': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'z': {
                0: (BoundaryCondition.PERIODIC,),
                1: (BoundaryCondition.PERIODIC,),
            },
            'y': {
                0: (
                    BoundaryCondition.DIRICHLET,
                    [np.copy(states['rhoe'][:, i, :]) for i in range(hw)],
                ),
                1: (
                    BoundaryCondition.DIRICHLET,
                    [np.copy(states['rhoe'][:, -hw + i, :]) for i in range(hw)],
                ),
            },
        },
    },
})

states_n = solve(states, helper_vars, params)
print('Done!')

In [None]:
#@title Visualization

primitive = primitive_variables(states_n, params)

# 'T' and 'theta' are only available after uncommenting these lines.
# primitive.update({'T': temperature(primitive, params, 'eos')})
# primitive.update({'theta': potential_temperature(primitive, params)})

varname = 'h'  #@param['rho', 'u', 'v', 'w', 'p', 'e', 'h', 'T', 'theta']
axis = 'z'  #@param['x', 'y', 'z']
include_ghost_cells = False  #@param{type: "boolean"}
transpose = False  #@param{type: "boolean"}

var = primitive[varname]
nx, ny, nz = var.shape
if axis == 'x':
  plane = np.squeeze(var[nx // 2, ...])
elif axis == 'y':
  plane = np.squeeze(var[:, ny // 2, :])
else:
  plane = np.squeeze(var[..., nz // 2])

if not include_ghost_cells:
  # A dedicated name avoids rebinding the global `hw` that the simulation
  # cells read.
  halo = params['halo_width']
  plane = plane[halo:-halo, halo:-halo]

fig, ax = plt.subplots(2, 1, figsize=(16, 16))
c = ax[0].contourf(np.transpose(plane), cmap='jet', levels=21)
ax[0].set_aspect('equal', 'box')
ax[0].set_title(f'{varname} on the mid-plane normal to {axis}')
# Attach the colorbar to the contour axes only; passing the whole axes array
# would shrink the line plot as well.
fig.colorbar(c, ax=ax[0])

n0, n1 = plane.shape
x = plane[n0 // 2, :] if transpose else np.arange(n1)
y = np.arange(n1) if transpose else plane[n0 // 2, :]
ax[1].plot(x, y)
ax[1].set_title(f'{varname} along the centre line of the plane');


In [None]:
# @title Strain rate and shear stress test data generator.
# Analytic fields are evaluated on a grid twice as fine as the target grid:
# even fine-grid indices give the cell centers, and odd indices give the
# points half a target cell (one fine spacing `h`) further along each axis.
nx = ny = nz = 16
n_tot = [2.0 * n_i for n_i in (nx, ny, nz)]
h = [np.pi / (n_i - 1) for n_i in (nx, ny, nz)]
x, y, z = (np.arange(n_fine) * h_i for h_i, n_fine in zip(h, n_tot))

xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
u = np.cos(xx) * np.sin(yy) * np.sin(zz)
v = np.sin(xx) * np.cos(yy) * np.sin(zz)
w = np.sin(xx) * np.sin(yy) * np.cos(zz)
# NOTE(review): `np.sin(yy)` appears twice, so `rho` does not vary along z.
# This looks like a typo for `np.sin(zz)`. It is kept as-is because the saved
# reference arrays below may depend on it; confirm with their consumers
# before changing.
rho = 0.5 * (np.sin(xx) * np.sin(yy) * np.sin(yy) + 2.0)

_CENTERS = slice(None, None, 2)  # Even fine-grid indices.
_FACES = slice(1, None, 2)  # Odd fine-grid indices.


def _sample_fields(sx, sy, sz):
  """Returns {'u', 'v', 'w', 'rho'} views of the fine-grid fields.

  Args:
    sx: Slice applied along the first (x) axis.
    sy: Slice applied along the second (y) axis.
    sz: Slice applied along the third (z) axis.

  Returns:
    A dict with keys 'u', 'v', 'w', 'rho' (in that order) mapping to the
    sliced views of the corresponding module-level arrays.
  """
  named_fields = (('u', u), ('v', v), ('w', w), ('rho', rho))
  return {key: field[sx, sy, sz] for key, field in named_fields}


states = _sample_fields(_CENTERS, _CENTERS, _CENTERS)
states_fx = _sample_fields(_FACES, _CENTERS, _CENTERS)
states_fy = _sample_fields(_CENTERS, _FACES, _CENTERS)
states_fz = _sample_fields(_CENTERS, _CENTERS, _FACES)

# The sampled grids have spacing 2 * h (every other fine-grid point).
cfg = {f'd{dim}': 2.0 * h_i for dim, h_i in zip('xyz', h)}
cfg['nu'] = 1e-1

s = strain_rate(states, states_fx, states_fy, states_fz, cfg)
tau = shear_stress(states, states_fx, states_fy, states_fz, cfg)


def _to_zxy(tensor):
  """Moves the last axis of every component of a 3x3 tensor to the front.

  Assumes each component is laid out (x, y, z) like the inputs (TODO confirm
  against `strain_rate` / `shear_stress`), giving (z, x, y) components.
  """
  return [
      [np.transpose(tensor[row][col], (2, 0, 1)) for col in range(3)]
      for row in range(3)
  ]


s_zxy = _to_zxy(s)
tau_zxy = _to_zxy(tau)

np.save('strain_rate_tgv.npy', s_zxy)
np.save('shear_stress_tgv.npy', tau_zxy)