In [None]:
def reshape_for_jax(var):
    """Put the vertical axis last for the JAX component.

    Converts a field stored in Fortran block ordering
    ``(ngpblks, nflevg, nproma)`` into ``(ngpblks, nproma, nflevg)`` by
    exchanging axes 1 and 2.

    Parameters
    ----------
    var : array_like
        Field with at least three dimensions (e.g. ``dataset[name].values``).

    Returns
    -------
    jax.Array
        The same values with axes 1 and 2 exchanged.
    """
    as_jax = jnp.asarray(var)
    return jnp.swapaxes(as_jax, 1, 2)

# NOTE: this cell uses `np`, `jnp` (imports cell) and `dataset` (data-loading
# cell); run it only after those cells.
# Fields stored as (ngpblks, nflevg, nproma) are converted to
# (ngpblks, nproma, nflevg) with `reshape_for_jax`.

# Atmospheric state variables
pabs = reshape_for_jax(dataset["PPABSM"].values)  # Pressure (Pa)
# NOTE(review): `exn` is read from PEXNREF, the same field as `exn_ref`.
# This may deliberately mirror the reference driver; if the dataset provides a
# separate (non-reference) Exner field it should be used here — confirm.
exn = reshape_for_jax(dataset["PEXNREF"].values)  # Exner function
exn_ref = reshape_for_jax(dataset["PEXNREF"].values)  # Reference Exner function
rho_dry_ref = reshape_for_jax(dataset["PRHODREF"].values)  # Dry air density (kg/m³)

# Load state from ZRS array (ngpblks, krr1, nflevg, nproma)
# krr1 contains: [theta, rv, rc, rr, ri, rs, rg]
zrs = dataset["ZRS"].values
zrs = np.swapaxes(zrs, 2, 3)  # → (ngpblks, krr1, nproma, nflevg)

th = jnp.asarray(zrs[:, 0, :, :])  # Potential temperature (K)
rv = jnp.asarray(zrs[:, 1, :, :])  # Water vapor mixing ratio (kg/kg)
rc = jnp.asarray(zrs[:, 2, :, :])  # Cloud water mixing ratio (kg/kg)
rr = jnp.asarray(zrs[:, 3, :, :])  # Rain mixing ratio (kg/kg)
ri = jnp.asarray(zrs[:, 4, :, :])  # Ice mixing ratio (kg/kg)
rs = jnp.asarray(zrs[:, 5, :, :])  # Snow mixing ratio (kg/kg)
rg = jnp.asarray(zrs[:, 6, :, :])  # Graupel mixing ratio (kg/kg)

# Load input tendencies from PRS (ngpblks, krr, nflevg, nproma).
# PRS has no theta slot, so its species indices are shifted by one relative to
# ZRS: 0 = rv, 1 = rc, 3 = ri.
prs = dataset["PRS"].values
prs = np.swapaxes(prs, 2, 3)  # → (ngpblks, krr, nproma, nflevg)

rvs = jnp.asarray(prs[:, 0, :, :])  # Vapor tendency
rcs = jnp.asarray(prs[:, 1, :, :])  # Cloud water tendency
ris = jnp.asarray(prs[:, 3, :, :])  # Ice tendency

# Potential temperature tendency
ths = reshape_for_jax(dataset["PTHS"].values)

# Mass flux variables (from convection scheme)
cf_mf = reshape_for_jax(dataset["PCF_MF"].values)  # Cloud fraction from mass flux
rc_mf = reshape_for_jax(dataset["PRC_MF"].values)  # Cloud water from mass flux
ri_mf = reshape_for_jax(dataset["PRI_MF"].values)  # Ice from mass flux

# Subgrid variability parameters
zsigqsat = dataset["ZSIGQSAT"].values  # (ngpblks, nproma)
# Add a trailing singleton level axis so it broadcasts against (ngpblks, nproma, nflevg).
sigqsat = jnp.asarray(zsigqsat[:, :, np.newaxis])
sigs = reshape_for_jax(dataset["PSIGS"].values)  # Subgrid variance

print("✓ Input data prepared")
print("\nInitial state:")
# `th` is potential temperature, not temperature — label it accordingly so it is
# not confused with the adjusted temperature `t_out` reported later.
print(f"  Potential temperature: {float(th.min()):.1f} - {float(th.max()):.1f} K")
print(f"  Pressure: {float(pabs.min())/100:.1f} - {float(pabs.max())/100:.1f} hPa")
print(f"  Water vapor: {float(rv.min())*1000:.3f} - {float(rv.max())*1000:.3f} g/kg")
print(f"  Cloud water: {float(rc.min())*1000:.3f} - {float(rc.max())*1000:.3f} g/kg")
print(f"  Ice: {float(ri.min())*1000:.3f} - {float(ri.max())*1000:.3f} g/kg")

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/maurinl26/dwarf-p-ice3/blob/dev/examples/run_ice3.ipynb)

# ICE_ADJUST JAX Example

This notebook demonstrates the use of `IceAdjustJAX` for mixed-phase cloud saturation adjustment using real atmospheric test data.

## Overview

The ICE_ADJUST scheme performs saturation adjustment for mixed-phase clouds, computing:
- Cloud fraction based on subgrid variability
- Condensation/evaporation of water vapor
- Deposition/sublimation of ice
- Latent heating from phase changes
- Updated mixing ratios and temperature tendencies

## 1. Setup and Imports

In [None]:
import jax.numpy as jnp
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path

from ice3.jax.ice_adjust import IceAdjustJAX
from ice3.phyex_common.phyex import Phyex

## 2. Load Real Atmospheric Test Data

We'll use the `ice_adjust.nc` dataset which contains real atmospheric profiles from PHYEX simulations.

In [None]:
# Open the ICE_ADJUST reference dataset; the path is relative to the current
# working directory (expected to be the notebook's folder).
data_path = Path("..") / "data" / "ice_adjust.nc"
dataset = xr.open_dataset(data_path)

# Report the block layout: ngpblks blocks of nproma columns, nflevg levels each.
dims = dataset.sizes
n_blocks, n_proma, n_levels = dims["ngpblks"], dims["nproma"], dims["nflevg"]
n_columns = n_blocks * n_proma

print("Dataset dimensions:")
print(f"  ngpblks (number of blocks): {n_blocks}")
print(f"  nproma (points per block): {n_proma}")
print(f"  nflevg (vertical levels): {n_levels}")
print(f"\nTotal grid points: {n_columns}")
print(f"Total domain size: {n_columns * n_levels} points")

## 3. Initialize PHYEX Configuration and JAX Component

In [None]:
# PHYEX physics configuration for AROME with a 50 s timestep.
phyex = Phyex("AROME", TSTEP=50.0)

# JAX saturation-adjustment component; jit=True requests JIT compilation.
ice_adjust = IceAdjustJAX(phyex=phyex, jit=True)

print(
    "✓ IceAdjustJAX initialized with JIT compilation\n"
    f"  Physics package: {phyex.phyex_name}\n"
    f"  Timestep: {phyex.phyex['TSTEP']} seconds"
)

## 4. Prepare Input Data

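The dataset stores fields in Fortran-style ordering `(ngpblks, nflevg, nproma)`, whereas `IceAdjustJAX` expects `(ngpblks, nproma, nflevg)` (vertical levels last), so axes 1 and 2 are swapped. The data-preparation cell uses `np`, `jnp`, and `dataset`, so run it only after the imports (section 1) and data loading (section 2).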

## 5. Run ICE_ADJUST

Now we call the saturation adjustment scheme. This will adjust the temperature and water content to maintain thermodynamic equilibrium.

In [None]:
# Take the timestep from the PHYEX configuration (section 3) instead of
# repeating the literal, so the adjustment call and the physics setup cannot
# silently disagree if TSTEP is changed.
timestep = float(phyex.phyex["TSTEP"])  # seconds

# Run the saturation adjustment. Fields come from the preparation cell in
# (ngpblks, nproma, nflevg) layout; `sigqsat` is (ngpblks, nproma, 1).
result = ice_adjust(
    sigqsat=sigqsat,
    pabs=pabs,
    sigs=sigs,
    th=th,
    exn=exn,
    exn_ref=exn_ref,
    rho_dry_ref=rho_dry_ref,
    rv=rv,
    rc=rc,
    ri=ri,
    rr=rr,
    rs=rs,
    rg=rg,
    cf_mf=cf_mf,
    rc_mf=rc_mf,
    ri_mf=ri_mf,
    rvs=rvs,
    rcs=rcs,
    ris=ris,
    ths=ths,
    timestep=timestep,
)

print("✓ ICE_ADJUST computation completed")

## 6. Extract and Analyze Results

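The scheme returns 16 output arrays: the adjusted state (temperature, vapor, cloud water, ice), the cloud fraction, high-cloud (HLCLOUDS) content/fraction diagnostics, latent heats and related thermodynamic terms, and the vapor, cloud water, ice, and theta tendencies.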

In [None]:
# Extract output variables (positional outputs of IceAdjustJAX)
t_out = result[0]      # Temperature (K)
rv_out = result[1]     # Water vapor mixing ratio (kg/kg)
rc_out = result[2]     # Cloud water mixing ratio (kg/kg)
ri_out = result[3]     # Ice mixing ratio (kg/kg)
cldfr = result[4]      # Cloud fraction (0-1)
# High-cloud ("HLCLOUDS") diagnostics. In PHYEX ICE_ADJUST, PHLC_HRC/PHLI_HRI are
# the liquid/ice content of the high-content part of the cloud and
# PHLC_HCF/PHLI_HCF the corresponding fractions — they are not heating rates.
# NOTE(review): confirm the JAX port keeps these semantics/units.
hlc_hrc = result[5]    # High liquid content in the grid box
hlc_hcf = result[6]    # Fraction of the grid box with high liquid content
hli_hri = result[7]    # High ice content in the grid box
hli_hcf = result[8]    # Fraction of the grid box with high ice content
cph = result[9]        # NOTE(review): presumably moist-air specific heat (cph), not a phase indicator — confirm
lv = result[10]        # Latent heat of vaporization (J/kg)
ls = result[11]        # Latent heat of sublimation (J/kg)
rvs_out = result[12]   # Vapor tendency (kg/kg/s)
rcs_out = result[13]   # Cloud water tendency (kg/kg/s)
ris_out = result[14]   # Ice tendency (kg/kg/s)
ths_out = result[15]   # Theta tendency (K/s)

print("Output summary:")
print("\nAdjusted state:")
print(f"  Temperature: {float(t_out.min()):.1f} - {float(t_out.max()):.1f} K")
print(f"  Water vapor: {float(rv_out.min())*1000:.3f} - {float(rv_out.max())*1000:.3f} g/kg")
print(f"  Cloud water: {float(rc_out.min())*1000:.3f} - {float(rc_out.max())*1000:.3f} g/kg")
print(f"  Ice: {float(ri_out.min())*1000:.3f} - {float(ri_out.max())*1000:.3f} g/kg")
print(f"  Cloud fraction: {float(cldfr.min()):.3f} - {float(cldfr.max()):.3f}")

print("\nPhysical tendencies:")
print(f"  Vapor: {float(rvs_out.min())*1e6:.3f} - {float(rvs_out.max())*1e6:.3f} mg/kg/s")
print(f"  Cloud: {float(rcs_out.min())*1e6:.3f} - {float(rcs_out.max())*1e6:.3f} mg/kg/s")
print(f"  Ice: {float(ris_out.min())*1e6:.3f} - {float(ris_out.max())*1e6:.3f} mg/kg/s")
print(f"  Theta: {float(ths_out.min()):.6f} - {float(ths_out.max()):.6f} K/s")

# Report the HLCLOUDS diagnostics without asserting heating-rate units.
print("\nHigh-cloud (HLCLOUDS) diagnostics:")
print(f"  Liquid content HLC_HRC: {float(hlc_hrc.min()):.3e} - {float(hlc_hrc.max()):.3e}")
print(f"  Liquid fraction HLC_HCF: {float(hlc_hcf.min()):.3f} - {float(hlc_hcf.max()):.3f}")
print(f"  Ice content HLI_HRI: {float(hli_hri.min()):.3e} - {float(hli_hri.max()):.3e}")
print(f"  Ice fraction HLI_HCF: {float(hli_hcf.min()):.3f} - {float(hli_hcf.max()):.3f}")

## 7. Cloud Statistics

In [None]:
# Summarise how much of the domain is cloudy (cloud fraction above 1 %).
dims = cldfr.shape
total_points = dims[0] * dims[1] * dims[2]
cloudy_points = (cldfr > 0.01).sum()
cloudy_pct = 100 * cloudy_points / total_points

report = [
    "Cloud statistics:",
    f"  Total grid points: {total_points}",
    f"  Cloudy points (CF > 1%): {cloudy_points} ({cloudy_pct:.1f}%)",
    f"  Mean cloud fraction: {float(cldfr.mean()):.3f}",
    f"  Max cloud fraction: {float(cldfr.max()):.3f}",
]
print("\n".join(report))

## 8. Physical Validation

Let's verify that the scheme conserves total water (vapor + cloud water + ice, the species exchanged by the adjustment) and keeps its outputs within physical bounds.

In [None]:
# Check total water conservation. "Total water" here is vapor + cloud water +
# ice; rain, snow and graupel are passed to the scheme but not compared.
total_water_in = rv + rc + ri
total_water_out = rv_out + rc_out + ri_out
water_error = float(jnp.abs(total_water_out - total_water_in).max())

# JAX computes in float32 unless 64-bit mode is enabled. A fixed 1e-10 kg/kg
# threshold is below float32 round-off for typical mixing ratios (~1e-2 kg/kg),
# so scale the tolerance with the working precision, never tighter than the
# original 1e-10 kg/kg (= 0.1 µg/kg).
water_eps = float(jnp.finfo(total_water_in.dtype).eps)
water_scale = float(jnp.abs(total_water_in).max())
water_tol = max(1e-10, 8 * water_eps * water_scale)

print("Physical validation:")
print("\nConservation:")
print(f"  Total water error: {water_error*1e9:.3f} µg/kg")
print(f"  {'✓' if water_error < water_tol else '✗'} Water is conserved (error < {water_tol*1e9:.3g} µg/kg)")

print("\nPhysical bounds:")
# Convert to Python bools explicitly rather than relying on JAX array truthiness.
cf_valid = bool((cldfr >= 0).all() & (cldfr <= 1).all())
mr_valid = bool((rv_out >= 0).all() & (rc_out >= 0).all() & (ri_out >= 0).all())
t_valid = bool((t_out > 0).all())

print(f"  {'✓' if cf_valid else '✗'} Cloud fraction in [0, 1]")
print(f"  {'✓' if mr_valid else '✗'} All mixing ratios >= 0")
print(f"  {'✓' if t_valid else '✗'} Temperature > 0 K")

if cf_valid and mr_valid and t_valid:
    print("\n✓ All physical constraints satisfied!")
else:
    print("\n✗ Some physical constraints are violated (see above).")

## 9. Visualization - Vertical Profile

Let's visualize the results by looking at a vertical profile from a cloudy column.

In [None]:
# Column-mean cloud fraction per (block, point): average over vertical levels (axis 2).
column_cloud_fraction = cldfr.mean(axis=2)
cloudy_mask = column_cloud_fraction > 0.05

if cloudy_mask.any():
    # Select the cloudiest column (largest column-mean cloud fraction), not
    # merely the first one above the threshold.
    flat_index = jnp.argmax(column_cloud_fraction)
    block_idx, point_idx = jnp.unravel_index(flat_index, column_cloud_fraction.shape)
    i_block = int(block_idx)
    i_point = int(point_idx)
    
    # Extract vertical profile
    p_profile = pabs[i_block, i_point, :] / 100  # Convert to hPa
    t_profile = t_out[i_block, i_point, :]
    rv_profile = rv_out[i_block, i_point, :] * 1000  # g/kg
    rc_profile = rc_out[i_block, i_point, :] * 1000  # g/kg
    ri_profile = ri_out[i_block, i_point, :] * 1000  # g/kg
    cf_profile = cldfr[i_block, i_point, :]
    
    # Create multi-panel plot (shared pressure axis, surface at the bottom)
    fig, axes = plt.subplots(1, 4, figsize=(16, 6), sharey=True)
    
    # Temperature
    axes[0].plot(t_profile, p_profile, 'r-', linewidth=2)
    axes[0].set_xlabel('Temperature (K)', fontsize=12)
    axes[0].set_ylabel('Pressure (hPa)', fontsize=12)
    axes[0].grid(True, alpha=0.3)
    axes[0].invert_yaxis()
    axes[0].set_title('Temperature Profile', fontsize=13)
    
    # Mixing ratios
    axes[1].plot(rv_profile, p_profile, 'b-', linewidth=2, label='Vapor')
    axes[1].plot(rc_profile, p_profile, 'c-', linewidth=2, label='Cloud water')
    axes[1].plot(ri_profile, p_profile, 'm-', linewidth=2, label='Ice')
    axes[1].set_xlabel('Mixing Ratio (g/kg)', fontsize=12)
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)
    axes[1].set_title('Water Content', fontsize=13)
    
    # Cloud fraction
    axes[2].plot(cf_profile, p_profile, 'k-', linewidth=2)
    axes[2].set_xlabel('Cloud Fraction', fontsize=12)
    axes[2].set_xlim([0, 1])
    axes[2].grid(True, alpha=0.3)
    axes[2].set_title('Cloud Fraction', fontsize=13)
    
    # High-cloud (HLCLOUDS) content diagnostics. In PHYEX ICE_ADJUST,
    # HLC_HRC / HLI_HRI are liquid / ice contents, not heating rates.
    # NOTE(review): confirm units for the JAX port before labelling them.
    hlc_profile = hlc_hrc[i_block, i_point, :]
    hli_profile = hli_hri[i_block, i_point, :]
    axes[3].plot(hlc_profile, p_profile, 'g-', linewidth=2, label='Liquid (HLC_HRC)')
    axes[3].plot(hli_profile, p_profile, 'orange', linewidth=2, label='Ice (HLI_HRI)')
    axes[3].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    axes[3].set_xlabel('High-cloud content', fontsize=12)
    axes[3].legend(fontsize=10)
    axes[3].grid(True, alpha=0.3)
    axes[3].set_title('HLCLOUDS Diagnostics', fontsize=13)
    
    plt.suptitle(f'Vertical Profile - Column (block={i_block}, point={i_point})', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    print(f"Showing profile from column (block={i_block}, point={i_point})")
    print(f"Column-mean cloud fraction: {float(column_cloud_fraction[i_block, i_point]):.3f}")
else:
    print("No significantly cloudy columns found in this dataset.")

## 10. Horizontal Cloud Distribution

In [None]:
# Histograms of column-mean (vertically averaged, not integrated) cloud
# fraction, cloud water and ice over all horizontal columns.


def _column_mean(field):
    """Flatten a (ngpblks, nproma, nflevg) field to (columns, levels) and average each column."""
    return field.reshape(-1, field.shape[2]).mean(axis=1)


cldfr_mean = _column_mean(cldfr)
rc_mean = _column_mean(rc_out) * 1000  # g/kg
ri_mean = _column_mean(ri_out) * 1000  # g/kg

# One entry per panel: (values, bar colour, x-axis label, panel title).
histogram_specs = [
    (cldfr_mean, 'skyblue', 'Column-Mean Cloud Fraction', 'Cloud Fraction Distribution'),
    (rc_mean, 'lightblue', 'Column-Mean Cloud Water (g/kg)', 'Cloud Water Distribution'),
    (ri_mean, 'lavender', 'Column-Mean Ice (g/kg)', 'Ice Distribution'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (values, color, xlabel, title) in zip(axes, histogram_specs):
    ax.hist(np.array(values), bins=50, color=color, edgecolor='black', alpha=0.7)
    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel('Number of Columns', fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:

1. **Loading real atmospheric data** from the PHYEX test dataset
2. **Initializing IceAdjustJAX** with AROME physics configuration
3. **Running saturation adjustment** for mixed-phase clouds
4. **Analyzing physical outputs** including cloud fraction, phase partitioning, and latent heating
5. **Validating conservation** of total water and physical bounds
6. **Visualizing results** through vertical profiles and horizontal distributions

The ICE_ADJUST scheme is a critical component of atmospheric models, ensuring thermodynamic consistency while representing subgrid cloud variability and mixed-phase microphysics.