# Example: Supercell simulation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/swirl-jatmos/blob/main/swirl_jatmos/demos/supercell_demo.ipynb)


This Colab notebook runs and visualizes a simulation of a supercell thunderstorm using Jatmos.

Before you run this Colab notebook, make sure that you have selected a hardware accelerator (either TPU or GPU) by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator**. The default TPU runtime has 8 cores available, but only one is used in this demo.

In [None]:
# Fetch the Swirl-Jatmos source and install it in editable mode.
# NOTE: this cell changes the working directory (%cd), so re-running it would
# clone a second copy *inside* swirl-jatmos/ -- run it once per runtime.
!git clone https://github.com/google-research/swirl-jatmos.git
# `pip install -e .` installs from the current directory, so move into the repo first.
%cd swirl-jatmos/
!python3 -m pip install -e .

In [None]:
import sys

from absl import flags
import jax
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

# Enable 64-bit data types in JAX (required by Jatmos). JAX only honours this
# reliably if it is set before any JAX arrays are created -- including any
# module-level constants the swirl_jatmos imports below may define -- so it
# must come before those imports rather than after them.
jax.config.update('jax_enable_x64', True)

from swirl_jatmos import config
from swirl_jatmos import convection_config
from swirl_jatmos import driver
from swirl_jatmos import timestep_control_config
from swirl_jatmos.sim_setups import supercell
from swirl_jatmos.thermodynamics import water

# Plot styling used by the figures at the end of the notebook.
matplotlib.style.use('ggplot')
ls = 18  # Tick-label font size.
fs = 18  # Axis-label font size.
matplotlib.rc('xtick', labelsize=ls)
matplotlib.rc('ytick', labelsize=ls)
matplotlib.rc('axes', labelsize=fs)

FLAGS = flags.FLAGS

# Parse absl flags using only the program name: the rest of sys.argv holds the
# notebook kernel's own arguments (e.g. ipykernel's `-f <connection file>`),
# which absl would otherwise reject as unknown flags.
FLAGS(sys.argv[:1])

# Flag registered by a swirl_jatmos module (presumably the driver); enables
# Colab-friendly printing -- TODO confirm exact effect.
FLAGS.print_for_colab = True

# Show which accelerator devices JAX can see.
jax.devices()

## Set the configuration

In [None]:
# Adaptive timestep: aim for a CFL number of 0.8, keep dt within
# [min_dt, max_dt], and let it change by at most 1.4x per update, re-evaluated
# every 4 steps.
timestep_cfg = timestep_control_config.TimestepControlConfig(
    desired_cfl=0.8,
    max_dt=10.0,
    min_dt=0.5,
    max_change_factor=1.4,
    update_interval_steps=4,
)

# The same WENO5-Z advection scheme is used for momentum, theta_li and q_t.
advection_scheme = 'weno5_z'
convection_cfg = convection_config.ConvectionConfig(
    momentum_scheme=advection_scheme,
    theta_li_scheme=advection_scheme,
    q_t_scheme=advection_scheme,
)

# Single-core run on a 128^3 grid covering 100 km x 100 km x 20 km
# (domain bounds are in meters).
cfg_ext = config.ConfigExternal(
    cx=1, cy=1, cz=1,
    nx=128, ny=128, nz=128,
    domain_x=(0, 100e3),
    domain_y=(0, 100e3),
    domain_z=(0, 20e3),
    dt=1.5,  # Initial timestep; adjusted by timestep_cfg during the run.
    timestep_control_cfg=timestep_cfg,
    convection_cfg=convection_cfg,
    wp=water.WaterParams(),
    use_sgs=False,
    poisson_solver_type=config.PoissonSolverType.FAST_DIAGONALIZATION,
    aux_output_fields=('q_c',),  # Condensate, plotted at the end.
    viscosity=1e-3,
    diffusivity=1e-3,
    disable_checkpointing=True,
)
cfg = config.config_from_config_external(cfg_ext)

## Run the simulation

In [None]:
# The supercell setup returns four thermodynamic quantities on the z grid; only
# the last one (the reference density profile, going by its name) is needed to
# drive the run, as a float64 NumPy array.
_, _, _, rho_ref_xxc = supercell.thermodynamic_initial_condition(cfg.z_c, cfg.wp)
rho_ref_f64 = np.array(rho_ref_xxc, dtype=np.float64)

# Run-length parameters, in seconds.
output_dir = '/tmp/abcd'  # Not used when checkpointing is disabled.
t_final = 3600.0  # Simulate 60 minutes.
sec_per_cycle = 600  # 10 min per cycle.

states, aux_output, diagnostics = driver.run_driver(
    supercell.init_fn,  # Defines the initial conditions.
    rho_ref_f64,
    output_dir,
    t_final,
    sec_per_cycle,
    cfg,
)

## Visualize the results

In [None]:
# Plot the condensate at the end of the run (t = t_final = 1 hour).


def _plot_field(ax, coord_a, coord_b, field, title, xlabel, ylabel):
  """Draws a cell-centred 2D field with a colorbar on `ax`.

  Args:
    ax: Matplotlib axes to draw on.
    coord_a: 2D array of horizontal-axis coordinates, same shape as `field`.
    coord_b: 2D array of vertical-axis coordinates, same shape as `field`.
    field: 2D array of values located at (coord_a, coord_b).
    title: Axes title.
    xlabel: Horizontal-axis label.
    ylabel: Vertical-axis label.
  """
  # Coordinates have the same shape as `field`, i.e. they are cell centres.
  # 'nearest' centres each quad on its point; the 'flat' default of older
  # Matplotlib versions would silently drop the last row and column.
  mesh = ax.pcolormesh(
      coord_a, coord_b, field, cmap='viridis', shading='nearest'
  )
  ax.figure.colorbar(mesh, ax=ax)
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)


hw = 1  # Halo width in z.
# Interior z indices; written so that hw = 0 keeps everything instead of
# producing an empty [0:-0] slice.
z_interior = slice(hw, -hw if hw else None)

# Center x & y at 0 (the domain midpoint) and convert x, y, z from m to km.
x_mid_km = 0.5 * (cfg_ext.domain_x[0] + cfg_ext.domain_x[1]) / 1e3
y_mid_km = 0.5 * (cfg_ext.domain_y[0] + cfg_ext.domain_y[1]) / 1e3
x_c = np.array(cfg.x_c) / 1e3 - x_mid_km
y_c = np.array(cfg.y_c) / 1e3 - y_mid_km
z_c = np.array(cfg.z_c)[z_interior] / 1e3

# Condensate density rho * q_c with the z halo stripped.
q_c = np.array(aux_output['q_c'])[:, :, z_interior]
rho_xxc = np.array(states['rho_xxc'])[:, :, z_interior]
rho_qc = rho_xxc * q_c

# Top view: average over z.
XX, YY = np.meshgrid(x_c, y_c, indexing='ij')
rho_q_c_2d = np.mean(rho_qc, axis=2)
fig_xy, ax_xy = plt.subplots(figsize=(6, 4))
_plot_field(
    ax_xy, XX, YY, rho_q_c_2d, r'mean_z($\rho q_c$)', 'x (km)', 'y (km)'
)

# Side view: average over y.
XX2, ZZ2 = np.meshgrid(x_c, z_c, indexing='ij')
rho_q_c_xz = np.mean(rho_qc, axis=1)
fig_xz, ax_xz = plt.subplots(figsize=(6, 4))
_plot_field(
    ax_xz, XX2, ZZ2, rho_q_c_xz, r'mean_y($\rho q_c$)', 'x (km)', 'z (km)'
);