# Example: RCEMIP

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/swirl-jatmos/blob/main/swirl_jatmos/demos/rcemip_demo.ipynb)

This Colab demonstrates a simulation of the Radiative-Convective Equilibrium Model Intercomparison Project (RCEMIP) setup using Jatmos.

Before you run this Colab notebook, make sure that you have selected a TPU hardware accelerator by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU**. The default TPU runtime has 8 cores available; this demo uses 4 of them.

In [None]:
# Fetch the swirl-jatmos sources, change into the checkout, and install the
# package in editable (`-e`) mode so that `swirl_jatmos` can be imported by the
# cells below.
!git clone https://github.com/google-research/swirl-jatmos.git
%cd swirl-jatmos/
!python3 -m pip install -e .

In [None]:
import functools
import sys
import tempfile

from absl import flags
from etils import epath
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Plot styling: start from the ggplot look, then enlarge tick and axis labels.
matplotlib.style.use('ggplot')
ls = 18  # Tick-label font size.
fs = 18  # Axis-label font size.
for _rc_group, _rc_kwargs in (
    ('xtick', {'labelsize': ls}),
    ('ytick', {'labelsize': ls}),
    ('axes', {'labelsize': fs}),
):
  matplotlib.rc(_rc_group, **_rc_kwargs)

from swirl_jatmos import config
from swirl_jatmos import convection_config
from swirl_jatmos import driver
from swirl_jatmos import sponge_config
from swirl_jatmos import timestep_control_config
from swirl_jatmos.thermodynamics import water
from swirl_jatmos.microphysics import microphysics_config
from swirl_jatmos.boundary_conditions import boundary_conditions
from swirl_jatmos.boundary_conditions import monin_obukhov
from swirl_jatmos.rrtmgp.config import radiative_transfer
from swirl_jatmos.rrtmgp.optics import lookup_volume_mixing_ratio

from swirl_jatmos.sim_setups import walker_circulation
from swirl_jatmos.sim_setups import walker_circulation_diagnostics
from swirl_jatmos.sim_setups import walker_circulation_parameters

# JAX options: enable 64-bit data types, and turn on the
# `jax_threefry_partitionable` option.
# NOTE(review): JAX recommends setting `jax_enable_x64` at start-up, before any
# arrays are created; here it runs after the swirl_jatmos imports above.
# Confirm none of those modules create arrays at import time.
for _jax_option, _jax_value in (
    ('jax_enable_x64', True),
    ('jax_threefry_partitionable', True),
):
  jax.config.update(_jax_option, _jax_value)

FLAGS = flags.FLAGS
# Parse absl flags from the program name only (sys.argv[:1]), presumably so the
# notebook kernel's own command-line arguments are not parsed as absl flags.
FLAGS(sys.argv[:1])
FLAGS.print_for_colab = True
# Use the analytically specified ozone profile from Wing et al (2018).
FLAGS.use_rcemip_ozone_profile = True

# Display the available accelerator devices as the cell output.
jax.devices()

### Some helper functions and path setup

In [None]:
# Helper functions
def remove_halos(f, halo_width: int = 1):
  """Strip halo cells from both ends of the last axis of `f`.

  In this notebook the last axis is the vertical (z) axis: the function is
  applied to `cfg.z_c` and to fields whose horizontal axes (0, 1) are averaged
  out before plotting against z.

  Args:
    f: Array supporting `...`-style slicing (e.g. a NumPy or JAX array).
    halo_width: Number of halo cells to drop from each end of the last axis.
      Defaults to 1, the width this notebook has always used.

  Returns:
    `f[..., halo_width:-halo_width]`, or `f[..., :]` when `halo_width` is 0.

  Raises:
    ValueError: If `halo_width` is negative.
  """
  if halo_width < 0:
    raise ValueError(f'halo_width must be non-negative, got {halo_width}.')
  # `f[..., hw:-hw]` would give an empty slice for hw == 0 because -0 == 0,
  # so use an open-ended stop in that case.
  stop = -halo_width if halo_width else None
  return f[..., halo_width:stop]

def save_array_to_tempfile_and_return_path(x: np.ndarray) -> str:
  """Write `x` as text (via `np.savetxt`) to a new temporary file.

  The file is created with `delete=False`, so it persists after this call and
  its path can be handed to code that reads it later.

  Args:
    x: Array to write.

  Returns:
    Path of the temporary file holding the text dump of `x`.
  """
  handle = tempfile.NamedTemporaryFile(delete=False)
  try:
    np.savetxt(handle, x)
  finally:
    # Closing flushes the data to disk; the file itself is kept.
    handle.close()
  return handle.name

# Data files for the RRTMGP radiation scheme and the RCEMIP volume mixing
# ratios. Each filename is relative to the installed `swirl_jatmos` package
# root and is resolved against it immediately below its definition.
root = epath.resource_path('swirl_jatmos')

# Volume-mixing-ratio inputs.
_VMR_GLOBAL_MEAN_FILENAME = (
    'rrtmgp/optics/test_data/rcemip_global_mean_vmr.json'
)
_VMR_GLOBAL_MEAN_FILEPATH = root / _VMR_GLOBAL_MEAN_FILENAME
_VMR_SOUNDING_FILENAME = 'rrtmgp/optics/test_data/rcemip_vmr_sounding.csv'
_VMR_SOUNDING_FILEPATH = root / _VMR_SOUNDING_FILENAME

# Gas-optics lookup tables (longwave / shortwave).
_LW_LOOKUP_TABLE_FILENAME = 'rrtmgp/optics/rrtmgp_data/rrtmgp-gas-lw-g128.nc'
_LW_LOOKUP_TABLE_FILEPATH = root / _LW_LOOKUP_TABLE_FILENAME
_SW_LOOKUP_TABLE_FILENAME = 'rrtmgp/optics/rrtmgp_data/rrtmgp-gas-sw-g112.nc'
_SW_LOOKUP_TABLE_FILEPATH = root / _SW_LOOKUP_TABLE_FILENAME

# Cloud-optics lookup tables (longwave / shortwave).
_CLD_LW_LOOKUP_TABLE_FILENAME = 'rrtmgp/optics/rrtmgp_data/cloudysky_lw.nc'
_CLD_LW_LOOKUP_TABLE_FILEPATH = root / _CLD_LW_LOOKUP_TABLE_FILENAME
_CLD_SW_LOOKUP_TABLE_FILENAME = 'rrtmgp/optics/rrtmgp_data/cloudysky_sw.nc'
_CLD_SW_LOOKUP_TABLE_FILEPATH = root / _CLD_SW_LOOKUP_TABLE_FILENAME

### Define the (nonuniform) z grid to be used

In [None]:
stretched_grid_z_str = """37
111
194
288
395
520
667
843
1062
1331
1664
2055
2505
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
10000
10500
11000
11500
12000
12500
13000
13500
14000
14500
15000
15500
16000
16500
17000
17500
18000
18500
19000
19500
20000
20500
21000
21500
22000
22500
23000
23500
24000
24500
25000
25500
26000
26500
27000
27500
28000
28750
29750
31000
32500
34250
36250
38500
41000
43500
46000
48500
51000
53500
56000
"""

# Parse the whitespace-separated heights into a 1-D float64 array.
z_data = np.array(stretched_grid_z_str.split(), dtype=float)
z = z_data

# Put the top of the domain half of the topmost spacing above the last level.
_dz_top = z[-1] - z[-2]
lz = float(z[-1] + _dz_top / 2)

print(
    f'First z level: {z[0]} m.  Last z level: {z[-1]} m',
    f'{lz=}',
    f'# of z levels: {len(z)}',
    sep='\n',
)

# The config below reads the grid back from this text file.
stretched_grid_path_z = save_array_to_tempfile_and_return_path(z)

### Simulation configuration

In [None]:
# @title The config
_P0 = 1.0148e5  # Pressure at the surface [Pa].

# The sub-configurations are built first, in the same order in which they
# appear as keyword arguments, and then assembled into `cfg_ext` below.

# Adaptive time stepping targeting a CFL number of 0.8, with dt in [3, 23].
_timestep_control_cfg = timestep_control_config.TimestepControlConfig(
    desired_cfl=0.8,
    max_dt=23.0,
    min_dt=3.0,
    max_change_factor=1.4,
    update_interval_steps=4,
)
_water_params = water.WaterParams(exner_reference_pressure=_P0)
# WENO5-Z for momentum, theta_li and q_t; first-order upwind for rain and snow.
_convection_cfg = convection_config.ConvectionConfig(
    momentum_scheme='weno5_z',
    theta_li_scheme='weno5_z',
    q_t_scheme='weno5_z',
    q_r_scheme='upwind1',
    q_s_scheme='upwind1',
)
_microphysics_cfg = microphysics_config.MicrophysicsConfig(
    autoconversion_params=microphysics_config.AutoconversionParams(tau_is=1.5e3),
    terminal_velocity_method=microphysics_config.TerminalVelocityMethod.CHEN_2022,
    sedimentation_method=1,
)
# Monin-Obukhov surface boundary condition at the bottom, no flux at the top.
_z_bcs = boundary_conditions.ZBoundaryConditions(
    bottom=boundary_conditions.ZBC(
        bc_type='monin_obukhov', mop=monin_obukhov.MoninObukhovParameters()
    ),
    top=boundary_conditions.ZBC(bc_type='no_flux'),
)
_sponge_cfg = sponge_config.SpongeConfig(coeff=6.0, sponge_fraction=0.55, c2=0.5)

# Radiation: RRTMGP gas and cloud optics plus RCEMIP surface/solar settings.
_rrtm_optics = radiative_transfer.RRTMOptics(
    longwave_nc_filepath=_LW_LOOKUP_TABLE_FILEPATH,
    shortwave_nc_filepath=_SW_LOOKUP_TABLE_FILEPATH,
    cloud_longwave_nc_filepath=_CLD_LW_LOOKUP_TABLE_FILEPATH,
    cloud_shortwave_nc_filepath=_CLD_SW_LOOKUP_TABLE_FILEPATH,
)
_optics_params = radiative_transfer.OpticsParameters(optics=_rrtm_optics)
_atmospheric_state_cfg = radiative_transfer.AtmosphericStateCfg(
    sfc_emis=0.98,
    sfc_alb=0.07,
    zenith=0.733911,  # If in radians, ~42.05 degrees.
    irrad=551.58,
    toa_flux_lw=0.0,
    vmr_sounding_filepath=_VMR_SOUNDING_FILEPATH,
    vmr_global_mean_filepath=_VMR_GLOBAL_MEAN_FILEPATH,
)
_radiative_transfer_cfg = radiative_transfer.RadiativeTransfer(
    optics=_optics_params,
    atmospheric_state_cfg=_atmospheric_state_cfg,
    update_cycle_seconds=1800.0,  # 30 minutes.
    apply_cadence_seconds=100.0,
    use_scan=True,
)

cfg_ext = config.ConfigExternal(
    # Presumably per-axis core counts: 2 * 2 * 1 = 4, matching the 4 TPU cores
    # this demo uses.
    cx=2,
    cy=2,
    cz=1,
    nx=56,
    ny=56,
    nz=80,
    domain_x=(0, 100e3),
    domain_y=(0, 100e3),
    domain_z=(0, lz),
    dt=20.0,
    timestep_control_cfg=_timestep_control_cfg,
    wp=_water_params,
    convection_cfg=_convection_cfg,
    microphysics_cfg=_microphysics_cfg,
    use_sgs=False,
    stretched_grid_path_z=stretched_grid_path_z,
    z_bcs=_z_bcs,
    include_qt_sedimentation=True,
    sponge_cfg=_sponge_cfg,
    poisson_solver_type=config.PoissonSolverType.FAST_DIAGONALIZATION,
    aux_output_fields=('q_liq', 'q_ice', 'T'),
    diagnostic_fields=('T_1d_z',),
    viscosity=1e-3,
    diffusivity=1e-3,
    radiative_transfer_cfg=_radiative_transfer_cfg,
    disable_checkpointing=True,
)
cfg = config.config_from_config_external(cfg_ext)
print(f'dx={cfg.grid_spacings[0]}, dy={cfg.grid_spacings[1]}')

# RCEMIP-specific configuration for the Mock-Walker Circulation setup.
wcp = walker_circulation_parameters.WalkerCirculationParameters(
    sst_0=300.0,
    delta_sst=0.0,  # Sea surface temperature is uniform
    theta_li_pert_scaling=1.0,
)

## Run the simulation

In [None]:
# This takes about 7 minutes on the Cloud TPU v2-8.

# Reference density from the analytic profiles of the Walker-circulation setup,
# evaluated at the z cell centers.
p_0 = cfg.wp.exner_reference_pressure
# NOTE(review): the heights are cast to float32 even though 64-bit mode is
# enabled, so rho_ref is computed in single precision and only widened to
# float64 afterwards. Confirm this is intentional.
_z_c_f32 = jnp.array(cfg.z_c, dtype=jnp.float32)
_, _, _, rho_ref_xxc = walker_circulation.analytic_profiles_from_paper(
    _z_c_f32, wcp, p_0
)
_rho_ref_f64 = np.array(rho_ref_xxc, dtype=np.float64)

# Bind the setup parameters so the driver can call the initializer directly.
init_fn = functools.partial(walker_circulation.init_fn, wcp=wcp)

output_dir = '/tmp/abcd'  # Not used (checkpointing is disabled in the config).
t_final = 24 * 3600  # Time to simulate, in seconds (1 day).
sec_per_cycle = 6 * 3600  # Seconds of simulated time per driver cycle.

states, aux_output, diagnostics = driver.run_driver(
    init_fn,
    _rho_ref_f64,
    output_dir,
    t_final,
    sec_per_cycle,
    cfg,
    preprocess_update_fn=walker_circulation.preprocess_update_fn,
    diagnostics_update_fn=walker_circulation_diagnostics.diagnostics_update_fn,
)

## Visualize the results
To keep the demo short, the simulation above runs for only a single day, which is not nearly enough time for the atmosphere to reach equilibrium. However, we can still see the direction in which the atmospheric state is evolving. Below, we plot horizontal averages of selected quantities.

In [None]:
# Halo-free cell-center heights, converted from m to km for the y axis.
z = remove_halos(cfg.z_c) / 1e3

# Strip halos from every field used below.
T = remove_halos(aux_output['T'])
q_liq = remove_halos(aux_output['q_liq'])
q_ice = remove_halos(aux_output['q_ice'])
theta_li_0 = remove_halos(states['theta_li_0'])
dtheta_li = remove_halos(states['dtheta_li'])
w = remove_halos(states['w'])
u = remove_halos(states['u'])
q_t = remove_halos(states['q_t'])
q_r = remove_halos(states['q_r'])
q_s = remove_halos(states['q_s'])
rad_heat_src = remove_halos(states['rad_heat_src'])
# Full theta_li field; formed before dtheta_li is averaged below.
theta_li = theta_li_0 + dtheta_li

# Horizontal (x, y) average, leaving a profile along the last (z) axis.
_horizontal_mean = functools.partial(np.mean, axis=(0, 1))

T = _horizontal_mean(T)
dtheta_li = _horizontal_mean(dtheta_li)
q_t = _horizontal_mean(q_t)
theta_li = _horizontal_mean(theta_li)
theta_li_0 = np.squeeze(theta_li_0)
w_rms = np.sqrt(_horizontal_mean(w**2))
w_max = np.max(np.abs(w), axis=(0, 1))
u_rms = np.sqrt(_horizontal_mean(u**2))
q_liq = _horizontal_mean(q_liq)
q_ice = _horizontal_mean(q_ice)
q_r = _horizontal_mean(q_r, dtype=np.float64)
q_s = _horizontal_mean(q_s)
rad_heat_src = _horizontal_mean(rad_heat_src)

fig, ax = plt.subplots(1, 5, figsize=(15, 4), sharey=True)
ax0, ax1, ax2, ax3, ax4 = ax
ax0.set_ylabel('z (km)')

# One vertical profile per panel; the radiative heating source is multiplied
# by 86400 (seconds per day) to plot it in K/day.
for _panel, _profile, _title in (
    (ax0, T, 'T [K]'),
    (ax1, theta_li, 'theta_li [K]'),
    (ax2, 86400 * rad_heat_src, 'rad_heat_src [K/day]'),
    (ax3, q_liq, 'q_liq [kg/kg]'),
    (ax4, q_ice, 'q_ice [kg/kg]'),
):
  _panel.plot(_profile, z)
  _panel.set_title(_title)