<a href="https://colab.research.google.com/github/mampisarkar111/wisonet-colab-demo/blob/main/Wisonet_kernel_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Interactive TES vs. Model Comparison

This interactive analysis is part of the upcoming **WISO-SAT: A global database of stable hydrogen isotopes in water vapor from eight satellite missions** (*Sarkar et al., to be submitted*).  
WISO-SAT unifies retrievals from multiple satellite missions, including TES, AIRS, and CrIS, into a consistent framework for isotope–climate analysis.

> **Sarkar, M.**, Dee, S., Frazer, M., Bailey, A., Good, S., Mandal, S., Noël, S., Schneider, M., Bong, H., Bowman, K., Steen-Larsen, H.C., Haagsma, M., Worden, J., *WISO-SAT: A global database of stable hydrogen isotopes in water vapor from eight satellite missions*, in preparation for submission.

The model data used here is from the **WisoMIP Ensemble** (*Bong et al., 2025, submitted*):  

> Bong, H., LeGrande, A.N., Dee, S., Zhu, J., Cauquoin, A., Fiorella, R.P., Ding, Q., Dutrievoz, N., Tanoue, M., Frazer, M., Sarkar, M., Agosta, C., Yoshimura, K., Werner, M., Okazaki, A., Risi, C., Steen-Larsen, H.C., Casado, M., Wahl, S., Nusbaumer, J., Worden, J., Good, S., Bailey, A., Schneider, M., Noël, S., Mandal, S., Bowman, K., Li, Y., and Schmidt, G.A. (2025) *Water Isotope Model Intercomparison Project (WisoMIP): Present-day Climate*, *Journal of Advances in Modeling Earth Systems*, manuscript submitted July 2025.

---

### Why AK correction is essential

Directly comparing raw model output to satellite retrievals can be misleading because satellites have **limited and variable vertical sensitivity**. The **averaging kernel (AK)** matrix describes how the retrieved value at each level responds to changes in the true state at every level; where its values are small, the retrieval mostly reflects the a priori rather than the atmosphere. Applying the AK filters model output the same way the satellite “sees” the atmosphere, enabling a like-for-like comparison.
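
As a toy illustration of limited vertical sensitivity, the sketch below applies the Rodgers-style smoothing used later in this notebook to a made-up three-level profile. The function name `apply_averaging_kernel` and all numbers are hypothetical, chosen only to show the behaviour: where a row of the AK is close to zero (the upper level here), the smoothed value stays near the a priori whatever the true state is.

```python
import numpy as np

def apply_averaging_kernel(x_true, x_a, A):
    """Rodgers-style smoothing: x_hat = x_a + A @ (x_true - x_a)."""
    return x_a + A @ (x_true - x_a)

# Hypothetical 3-level column (lower, middle, upper); values are illustrative only.
x_a = np.array([-100.0, -150.0, -250.0])     # a priori (δD-like values)
x_true = np.array([-80.0, -120.0, -180.0])   # "true" (model) profile
A = np.array([[0.8, 0.1, 0.0],               # lower level: well constrained
              [0.1, 0.6, 0.1],
              [0.0, 0.05, 0.05]])            # upper level: almost no sensitivity

x_hat = apply_averaging_kernel(x_true, x_a, A)
print(x_hat)  # approximately [-81, -123, -245]: the upper level stays close to x_a
```

With `A` set to the identity the function returns `x_true` unchanged, and with `A` set to zeros it returns `x_a`; real kernels fall between these limits and vary with level and location.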

---

### What this tool does

We build an **interactive visualization** that lets you compare:
1. **TES δD retrievals** – from the monthly TES gridded product (2006 in this example, but other years are supported).  
2. **Interpolated model δD** – raw WisoMIP model output interpolated onto the TES grid.  
3. **AK-corrected model δD** – model output smoothed with the TES averaging kernel and adjusted using the TES a priori profile.

---

### How it works

#### 1. **Month and pressure level selection**
Widgets allow you to choose:
- **Month** (1–12)
- **Pressure level** (hPa)
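
The selection step can be sketched as below, assuming the same rules the code cells use: the month is mapped to the last day of that month in 2006, and the requested pressure is matched to the nearest available level with `np.argmin(np.abs(p - target))`. The helper names and the pressure grid are hypothetical, not the actual TES or model levels. Because the nearest level is used, the TES and model levels shown can differ slightly from the requested value and from each other, which is why the map titles say "~… hPa".

```python
import calendar
import numpy as np

def month_end_date(year, month):
    """Last calendar day of a month as 'YYYY-MM-DD' (hypothetical helper)."""
    last_day = calendar.monthrange(year, month)[1]
    return f"{year}-{month:02d}-{last_day:02d}"

def nearest_level_index(levels_hpa, target_hpa):
    """Index of the pressure level closest to target_hpa (hypothetical helper)."""
    levels_hpa = np.asarray(levels_hpa, dtype=float)
    return int(np.argmin(np.abs(levels_hpa - target_hpa)))

# Hypothetical pressure grid in hPa
levels = np.array([1000.0, 908.5, 825.4, 681.3, 510.5, 316.2])

print(month_end_date(2006, 2))           # 2006-02-28
idx = nearest_level_index(levels, 900)
print(idx, levels[idx])                  # 1 908.5
```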

#### 2. **Longitude alignment**
Model longitudes are converted from the `0…360°` convention to `-180…180°`, and the model fields are reordered along longitude to match, for consistency with TES data and map plotting.
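
A minimal sketch of the conversion, assuming a 1-D longitude coordinate (the helper `wrap_longitudes` and the synthetic values are hypothetical). The key point is that the data must be permuted with the same index array as the coordinate; otherwise the field ends up shifted by 180° relative to its labels.

```python
import numpy as np

def wrap_longitudes(lon):
    """Map longitudes from 0…360 to -180…180; return them sorted plus the sort order."""
    lon_wrapped = ((np.asarray(lon, dtype=float) + 180.0) % 360.0) - 180.0
    order = np.argsort(lon_wrapped)
    return lon_wrapped[order], order

lon_0_360 = np.array([0.0, 90.0, 180.0, 270.0, 357.5])
values = np.array([10.0, 11.0, 12.0, 13.0, 14.0])  # synthetic field along longitude

lon_sorted, order = wrap_longitudes(lon_0_360)
values_sorted = values[..., order]  # reorder the data with the same indices
print(lon_sorted)     # -180, -90, -2.5, 0, 90
print(values_sorted)  # 12, 13, 14, 10, 11
```

For an `xarray.Dataset`, `ds.assign_coords(lon=((ds['lon'] + 180) % 360) - 180).sortby('lon')` reorders the coordinate and every variable with a `lon` dimension in one step.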

#### 3. **Spatial interpolation**
Model δD values are bilinearly interpolated to the TES 5°×5° grid using `scipy.interpolate.RegularGridInterpolator` (`method='linear'`, with points outside the model grid set to NaN).
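
The interpolation step can be illustrated on a synthetic field. The grids and the field `f = lat + 2·lon` below are made up for the example; because bilinear interpolation reproduces a linear function exactly, the results are easy to check by hand. With `bounds_error=False, fill_value=np.nan`, target points beyond the outermost source longitude come back as NaN, which can also happen to TES cells near ±180° if they lie outside the model's longitude range.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Synthetic "model" field on a coarse lat/lon grid: f = lat + 2 * lon
lat_src = np.array([-10.0, 0.0, 10.0])
lon_src = np.array([-20.0, -10.0, 0.0, 10.0, 20.0])
LAT_S, LON_S = np.meshgrid(lat_src, lon_src, indexing="ij")
field = LAT_S + 2.0 * LON_S

interp = RegularGridInterpolator((lat_src, lon_src), field, method="linear",
                                 bounds_error=False, fill_value=np.nan)

# Hypothetical target grid; lon = 22.5 lies outside the source domain
lat_tgt = np.array([-7.5, 2.5])
lon_tgt = np.array([-17.5, 2.5, 22.5])
LAT_T, LON_T = np.meshgrid(lat_tgt, lon_tgt, indexing="ij")
out = interp((LAT_T, LON_T))  # shape (2, 3): (lat, lon)
print(out)
```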

#### 4. **Averaging kernel application**
TES retrievals follow the **Rodgers (2000)** formulation:

\[
\mathbf{x}_{\text{retrieved}} = \mathbf{x}_a + \mathbf{A} \left( \mathbf{x}_{\text{true}} - \mathbf{x}_a \right)
\]

Where:
- \(\mathbf{x}_{\text{retrieved}}\): TES-like smoothed profile (our AK-corrected model)  
- \(\mathbf{x}_{\text{true}}\): True model profile  
- \(\mathbf{x}_a\): TES a priori profile  
- \(\mathbf{A}\): Averaging kernel matrix  

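For isotopes, smoothing is done **in ratio space**: model δD is converted to an HDO/H₂O ratio with \(R = R_{\text{VSMOW}}\,(\delta D/1000 + 1)\), where \(R_{\text{VSMOW}} = 3.1152 \times 10^{-4}\) (the value used in the code), the AK is applied, and the result is converted back to δD in per mil (‰) with \(\delta D = (R/R_{\text{VSMOW}} - 1) \times 1000\).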
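
The sketch below applies the equation above to a **full** model profile in ratio space. The helper names, the three-level column and the AK values are hypothetical; only `R_VSMOW` matches the notebook's `Rvsmow`. Two observations, both easy to check with this code: (1) in the code cells, `R_mod` is computed from the single model value at the selected level, so NumPy broadcasts it to every TES level, which amounts to assuming a vertically uniform model profile; passing a model profile defined on every TES level, as here, matches the definition of \(\mathbf{x}_{\text{true}}\). (2) Because δD is an affine function of \(R\), this linear smoothing in ratio space gives the same numbers as applying \(\mathbf{A}\) directly to δD; working in ratio space only changes the result for a nonlinear (for example, log-ratio) formulation.

```python
import numpy as np

R_VSMOW = 3.1152e-4  # HDO/H2O reference ratio, same value as Rvsmow in the notebook

def dD_to_ratio(dD_permil):
    """δD (‰) -> HDO/H2O ratio."""
    return R_VSMOW * (np.asarray(dD_permil, dtype=float) / 1000.0 + 1.0)

def ratio_to_dD(ratio):
    """HDO/H2O ratio -> δD (‰)."""
    return (np.asarray(ratio, dtype=float) / R_VSMOW - 1.0) * 1000.0

def ak_smooth_dD(dD_model_profile, R_apriori, A):
    """Apply x_hat = x_a + A (x_true - x_a) to a full model profile in ratio space."""
    R_mod = dD_to_ratio(dD_model_profile)
    R_adj = R_apriori + A @ (R_mod - R_apriori)
    return ratio_to_dD(R_adj)

# Hypothetical 3-level column; profiles and AK are illustrative, not TES values.
dD_apri = np.array([-100.0, -150.0, -250.0])
dD_mod = np.array([-90.0, -160.0, -300.0])
A = np.array([[0.8, 0.1, 0.0],
              [0.1, 0.6, 0.1],
              [0.0, 0.05, 0.05]])

dD_smoothed = ak_smooth_dD(dD_mod, dD_to_ratio(dD_apri), A)
print(dD_smoothed)  # approximately [-93, -160, -253]
```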

#### 5. **Visualization**
- **Three side-by-side maps** (TES, interpolated model, AK-corrected model)
- All three use the **same contour levels** (2nd–98th percentile of the pooled δD values)
- **Geographic context:** coastlines, borders, lat/lon gridlines
- δD (‰) colorbars beneath the maps; the extended tool in the next section uses a single shared colorbar
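
The shared color scale can be sketched as a small helper, assuming the same recipe as the code cells: pool the finite values of all three fields, take the 2nd and 98th percentiles, and build 21 evenly spaced contour levels. The helper name `shared_levels` and the test arrays are hypothetical. One consequence of percentile clipping: `contourf` leaves values outside the level range unfilled unless `extend='both'` is passed, so the most extreme cells can look like missing data.

```python
import numpy as np

def shared_levels(*fields, lo_pct=2, hi_pct=98, n_levels=21):
    """Common contour levels from the pooled finite values of several fields."""
    arrays = [np.asarray(f, dtype=float) for f in fields]
    finite = np.concatenate([a[np.isfinite(a)] for a in arrays])
    vmin, vmax = np.percentile(finite, [lo_pct, hi_pct])
    return np.linspace(vmin, vmax, n_levels)

# Synthetic fields; NaNs (missing cells) are ignored
a = np.array([[0.0, 100.0], [np.nan, 50.0]])
b = np.array([[25.0, 75.0]])
levels = shared_levels(a, b, lo_pct=0, hi_pct=100, n_levels=5)
print(levels)  # 0, 25, 50, 75, 100
```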

---

**Note:**  
This example uses **monthly 2006** TES and WisoMIP data, but the workflow can be applied to:
- Other years in the TES record
- Any isotope-enabled climate model output
- Any in-situ observations


In [None]:
# Download SWING3 2006 subset data (NetCDF format) from Box
!curl -L "https://rice.box.com/shared/static/bcoy3ob0dme3umpurqmf0p6o48bznkj1" -o SWING3_2006_subset.nc > /dev/null 2>&1

# Download TES monthly 5° x 5° gridded isotope data (filtered) for 2006 from Box
!curl -L "https://rice.box.com/shared/static/uuy9m15qc1p7s4wm1yrzfzxc6knx7hzw" -o TES_monthly_5deg_strict.nc > /dev/null 2>&1

# Install cartopy (no output)
!pip install -q cartopy

import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import calendar
from datetime import datetime
import ipywidgets as widgets
from ipywidgets import interactive_output, VBox, HBox
import xarray as xr
from scipy.interpolate import RegularGridInterpolator
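import matplotlib.ticker as mticker
from IPython.display import display  # built in inside Jupyter; explicit import for other front ends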

# --- Load datasets ---
ds = xr.open_dataset("SWING3_2006_subset.nc")       # Model output
sat = xr.open_dataset("TES_monthly_5deg_strict.nc") # Satellite retrieval

# --- Constants ---
Rvsmow = 3.1152e-4

# --- Common variables ---
# Convert model longitudes from 0…360 to -180…180 and sort the whole dataset by
# longitude, so the lon coordinate and every data variable stay aligned.
ds = ds.assign_coords(lon=((ds['lon'] + 180) % 360) - 180).sortby('lon')
lon_m = ds['lon'].values
lat_m = ds['lat'].values
p_m   = ds['p'].values

lon_t = ((sat['lon'].values + 180) % 360) - 180  # TES longitudes, also in -180…180
lat_t = sat['lat'].values
p_t   = sat['level'].values
time_t = sat['time'].values

# --- Widgets ---
plevel_input = widgets.BoundedFloatText(value=900, min=200, max=1000, step=10, description='Level [hPa]:')
month_slider = widgets.IntSlider(min=1, max=12, step=1, value=1, description='Month')

def update_plot(month, plevel):
    # --- Date handling ---
    m = month - 1
    last_day = calendar.monthrange(2006, month)[1]
    user_date_str = f'2006-{month:02d}-{last_day}'
    target_date = datetime.strptime(user_date_str, '%Y-%m-%d')

    if np.issubdtype(time_t.dtype, np.datetime64):
        t_tes = np.argmin(np.abs(time_t - np.datetime64(user_date_str)))
    else:
        base_date = datetime(2000, 1, 1)
        t_tes = np.argmin(np.abs(time_t - (target_date - base_date).days))

    idx_tes = int(np.argmin(np.abs(p_t - plevel)))
    idx_mod = int(np.argmin(np.abs(p_m - plevel)))

    # --- Extract model δD ---
    dD_model = ds['dD'].values[m, idx_mod, :, :].astype(float)  # (lat, lon); astype() copies
    dD_model[dD_model == -999] = np.nan  # mask fill values without modifying the dataset

    # --- Extract TES δD ---
    dD_tes = sat['dD'].values[t_tes, idx_tes, :, :].astype(float)  # copy, not a view
    dD_tes[dD_tes == -999] = np.nan

    # --- Interpolate model to TES lat/lon grid ---
    interp_func = RegularGridInterpolator((lat_m, lon_m), dD_model, bounds_error=False, fill_value=np.nan)
    LAT_T, LON_T = np.meshgrid(lat_t, lon_t, indexing='ij')
    dD_model_interp = interp_func((LAT_T, LON_T))

    # --- Apply AK correction in ratio space ---
    prior_HDO = sat['HDO_ConstraintVector'].values[t_tes, :, :, :]
    prior_H2O = sat['H2O_ConstraintVector'].values[t_tes, :, :, :]
    AK = sat['AK_HDO'].values[t_tes, :, :, :, :]

    dD_model_corr = np.full_like(dD_model_interp, np.nan)
    for i in range(len(lat_t)):
        for j in range(len(lon_t)):
            x_mod = dD_model_interp[i, j]
            if np.isnan(x_mod):
                continue
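            # NOTE: R_mod is a single value (model δD at the selected level); NumPy broadcasts
            # it to every TES level below, i.e. the model profile is treated as vertically uniform.
            R_mod = Rvsmow * (x_mod / 1000.0 + 1.0)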
            R_apri = prior_HDO[:, i, j] / prior_H2O[:, i, j]
            A = AK[:, :, i, j]
            if np.isfinite(R_apri[idx_tes]) and np.all(np.isfinite(A[idx_tes, :])):
                R_adj = R_apri + A @ (R_mod - R_apri)
                dD_model_corr[i, j] = (R_adj[idx_tes] / Rvsmow - 1.0) * 1000.0

    # --- Shift TES lons for plotting ---
    lon_t_shift = ((lon_t + 180) % 360) - 180
    sort_t = np.argsort(lon_t_shift)

    # --- Color scale ---
    combined = np.concatenate([
        dD_tes[np.isfinite(dD_tes)],
        dD_model_interp[np.isfinite(dD_model_interp)],
        dD_model_corr[np.isfinite(dD_model_corr)]
    ])
    vmin, vmax = np.nanpercentile(combined, 2), np.nanpercentile(combined, 98)
    levels = np.linspace(vmin, vmax, 21)

    # --- Plot ---
    fig = plt.figure(figsize=(22, 8))
    pro = ccrs.PlateCarree()
    xticks = np.arange(-180, 181, 60)
    yticks = np.arange(-90, 91, 30)

    datasets = [
        (f'TES δD at ~{plevel} hPa', dD_tes),
        (f'Model δD interp at ~{plevel} hPa', dD_model_interp),
        (f'AK-corrected Model δD at ~{plevel} hPa', dD_model_corr)
    ]

    for k, (title, data) in enumerate(datasets):
        ax = fig.add_subplot(1, 3, k+1, projection=pro)
        cs = ax.contourf(lon_t_shift[sort_t], lat_t, data[:, sort_t], levels=levels,
                         vmin=vmin, vmax=vmax, cmap='coolwarm', extend='both', transform=pro)
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
        gl = ax.gridlines(draw_labels=True, linewidth=0.3, color='gray', alpha=0.5)
        gl.xlocator = mticker.FixedLocator(xticks)
        gl.ylocator = mticker.FixedLocator(yticks)
        ax.set_title(title)
        fig.colorbar(cs, ax=ax, orientation='horizontal', pad=0.05).set_label('δD (‰)')

    plt.tight_layout()
    plt.show()

# --- Display widgets ---
ui = VBox([HBox([plevel_input, month_slider])])
out = interactive_output(update_plot, {'plevel': plevel_input, 'month': month_slider})
display(ui, out)


---

## Adding Vertical Sensitivity (Averaging Kernel) Context

The previous tool compares maps only, but the **vertical sensitivity** of TES varies across regions and affects how biases should be interpreted. Where the AK values at a given level are small, TES is less sensitive to that altitude and the retrieval is pulled toward the a priori, so model–observation differences there may partly reflect this limited sensitivity rather than true atmospheric differences.

In this section, we expand the interactive tool to:
- Select a **lat/lon box** and average the TES averaging kernel (AK) over that area and over all time steps in the file.
- Plot the **mean AK matrix** (retrieved vs. true levels) alongside the maps.
- Highlight the selected lat/lon box on each map.
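
The box average can be sketched as follows, assuming the AK array is laid out as `(time, retrieved level, true level, lat, lon)`, which is how the code cell indexes `AK_HDO`. The helper `box_mean_ak` and the synthetic array are hypothetical. Like the code cell, it averages over all time steps; indexing `ak[t]` first would restrict it to one month. An empty box triggers NumPy's "Mean of empty slice" warning and returns NaN.

```python
import numpy as np

def box_mean_ak(ak, lat, lon, lat_range, lon_range):
    """Mean AK matrix over a lat/lon box and all time steps (NaN-aware)."""
    lat_min, lat_max = sorted(lat_range)   # accept corners in either order
    lon_min, lon_max = sorted(lon_range)
    lat_mask = (lat >= lat_min) & (lat <= lat_max)
    lon_mask = (lon >= lon_min) & (lon <= lon_max)
    subset = ak[:, :, :, lat_mask, :][:, :, :, :, lon_mask]
    return np.nanmean(subset, axis=(0, 3, 4))  # -> (retrieved level, true level)

# Synthetic AK: 2 times, 3 levels, 4 lats, 5 lons
lat = np.array([-15.0, -5.0, 5.0, 15.0])
lon = np.array([-40.0, -20.0, 0.0, 20.0, 40.0])
ak = np.ones((2, 3, 3, 4, 5))
ak[..., 0] = 100.0                # outside the box below; must not affect the mean
ak[:, :, :, 1, 2] = np.nan        # a missing cell inside the box

ak_box = box_mean_ak(ak, lat, lon, (-10, 10), (-20, 20))
print(ak_box.shape)  # (3, 3)
```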


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import calendar
from datetime import datetime
import ipywidgets as widgets
from ipywidgets import interactive_output, VBox, HBox
import xarray as xr
from scipy.interpolate import RegularGridInterpolator
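import matplotlib.ticker as mticker
from IPython.display import display  # built in inside Jupyter; explicit import for other front ends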
from matplotlib.patches import Rectangle
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*tight_layout.*")

# --- Load datasets ---
ds = xr.open_dataset("SWING3_2006_subset.nc")
sat = xr.open_dataset("TES_monthly_5deg_strict.nc")

# --- Constants ---
Rvsmow = 3.1152e-4

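# --- Common variables ---
# Convert model longitudes to -180…180 and sort the whole dataset by longitude,
# so the lon coordinate and every data variable stay aligned.
ds = ds.assign_coords(lon=((ds['lon'] + 180) % 360) - 180).sortby('lon')
lon_m = ds['lon'].values
lat_m = ds['lat'].values
p_m   = ds['p'].values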

lon_t = ((sat['lon'].values + 180) % 360) - 180  # TES longitudes in -180…180, matching the box widgets
lat_t = sat['lat'].values
p_t   = sat['level'].values
time_t = sat['time'].values
akhdo = sat['AK_HDO'].values
level = sat['level'].values

# --- Widgets ---
plevel_input = widgets.BoundedFloatText(value=900, min=200, max=1000, step=10, description='Level [hPa]:')
month_slider = widgets.IntSlider(min=1, max=12, step=1, value=1, description='Month')
lat_min_input = widgets.BoundedFloatText(value=-20, min=-90, max=90, step=1, description='Lat Min:')
lat_max_input = widgets.BoundedFloatText(value=20, min=-90, max=90, step=1, description='Lat Max:')
lon_min_input = widgets.BoundedFloatText(value=-20, min=-180, max=180, step=1, description='Lon Min:')
lon_max_input = widgets.BoundedFloatText(value=20, min=-180, max=180, step=1, description='Lon Max:')

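def update_all(month, plevel, lat_min, lat_max, lon_min, lon_max):
    # Accept the box corners in either order
    lat_min, lat_max = sorted((lat_min, lat_max))
    lon_min, lon_max = sorted((lon_min, lon_max))

    # --- Date handling ---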
    m = month - 1
    last_day = calendar.monthrange(2006, month)[1]
    user_date_str = f'2006-{month:02d}-{last_day}'
    target_date = datetime.strptime(user_date_str, '%Y-%m-%d')
    if np.issubdtype(time_t.dtype, np.datetime64):
        t_tes = np.argmin(np.abs(time_t - np.datetime64(user_date_str)))
    else:
        base_date = datetime(2000, 1, 1)
        t_tes = np.argmin(np.abs(time_t - (target_date - base_date).days))

    idx_tes = int(np.argmin(np.abs(p_t - plevel)))
    idx_mod = int(np.argmin(np.abs(p_m - plevel)))

    # --- Extract model δD ---
    dD_model = ds['dD'].values[m, idx_mod, :, :].astype(float)  # (lat, lon); astype() copies
    dD_model[dD_model == -999] = np.nan  # mask fill values without modifying the dataset

    # --- Extract TES δD ---
    dD_tes = sat['dD'].values[t_tes, idx_tes, :, :].astype(float)  # copy, not a view
    dD_tes[dD_tes == -999] = np.nan

    # --- Interpolate model to TES grid ---
    interp_func = RegularGridInterpolator((lat_m, lon_m), dD_model, bounds_error=False, fill_value=np.nan)
    LAT_T, LON_T = np.meshgrid(lat_t, lon_t, indexing='ij')
    dD_model_interp = interp_func((LAT_T, LON_T))

    # --- AK correction ---
    prior_HDO = sat['HDO_ConstraintVector'].values[t_tes, :, :, :]
    prior_H2O = sat['H2O_ConstraintVector'].values[t_tes, :, :, :]
    AK = sat['AK_HDO'].values[t_tes, :, :, :, :]
    dD_model_corr = np.full_like(dD_model_interp, np.nan)
    for i in range(len(lat_t)):
        for j in range(len(lon_t)):
            x_mod = dD_model_interp[i, j]
            if np.isnan(x_mod):
                continue
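            # NOTE: R_mod is a single value (model δD at the selected level); NumPy broadcasts
            # it to every TES level below, i.e. the model profile is treated as vertically uniform.
            R_mod = Rvsmow * (x_mod / 1000.0 + 1.0)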
            R_apri = prior_HDO[:, i, j] / prior_H2O[:, i, j]
            A = AK[:, :, i, j]
            if np.isfinite(R_apri[idx_tes]) and np.all(np.isfinite(A[idx_tes, :])):
                R_adj = R_apri + A @ (R_mod - R_apri)
                dD_model_corr[i, j] = (R_adj[idx_tes] / Rvsmow - 1.0) * 1000.0

    # --- Shift TES lons ---
    lon_t_shift = ((lon_t + 180) % 360) - 180
    sort_t = np.argsort(lon_t_shift)

    # --- Color scale ---
    combined = np.concatenate([
        dD_tes[np.isfinite(dD_tes)],
        dD_model_interp[np.isfinite(dD_model_interp)],
        dD_model_corr[np.isfinite(dD_model_corr)]
    ])
    vmin, vmax = np.nanpercentile(combined, 2), np.nanpercentile(combined, 98)
    levels = np.linspace(vmin, vmax, 21)

    # === Create figure with reduced height ===
    fig = plt.figure(figsize=(14, 8))

    # --- Left: Averaging Kernel plot ---
    lon_mask = (lon_t >= lon_min) & (lon_t <= lon_max)
    lat_mask = (lat_t >= lat_min) & (lat_t <= lat_max)
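    akhdo_subset = akhdo[:, :, :, lat_mask, :][:, :, :, :, lon_mask]
    # Average over all time steps in the file (axis 0) and over the box (axes 3, 4),
    # giving a (retrieved level, true level) matrix that does not depend on the selected month
    AK_avg = np.nanmean(akhdo_subset, axis=(0, 3, 4))

    ax_left = fig.add_subplot(1, 2, 1)
    # plt.get_cmap instead of matplotlib.cm.get_cmap, which recent Matplotlib releases removed
    cmap = plt.get_cmap('viridis', AK_avg.shape[0])
    for i in range(AK_avg.shape[0]):
        ax_left.plot(AK_avg[i, :], level, color=cmap(i), label=f'{level[i]:.0f} hPa')
    ax_left.invert_yaxis()
    ax_left.set_xlabel('AK Value', fontsize=9)
    ax_left.set_ylabel('Pressure [hPa]', fontsize=9)
    ax_left.set_title(f'Mean AKs, all time steps\n({lat_min}°–{lat_max}°, {lon_min}°–{lon_max}°)', fontsize=10)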
    ax_left.grid(True)
    ax_left.legend(fontsize=6, loc='best')

    # --- Right: Spatial plots ---
    pro = ccrs.PlateCarree()
    xticks = np.arange(-180, 181, 60)
    yticks = np.arange(-90, 91, 30)

    datasets = [
        (f'TES δD ~{plevel} hPa', dD_tes),
        (f'Model δD interp ~{plevel} hPa', dD_model_interp),
        (f'AK-corrected Model δD ~{plevel} hPa', dD_model_corr)
    ]

    cs_bottom = None
    for idx, (title, data) in enumerate(datasets):
        ax = fig.add_subplot(3, 2, idx * 2 + 2, projection=pro)
        cs = ax.contourf(lon_t_shift[sort_t], lat_t, data[:, sort_t], levels=levels,
                         vmin=vmin, vmax=vmax, cmap='coolwarm', extend='both', transform=pro)
        if idx == len(datasets) - 1:
            cs_bottom = cs
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
        gl = ax.gridlines(draw_labels=True, linewidth=0.3, color='gray', alpha=0.5)
        gl.xlocator = mticker.FixedLocator(xticks)
        gl.ylocator = mticker.FixedLocator(yticks)
        ax.set_title(title, fontsize=9)
        ax.add_patch(Rectangle((lon_min, lat_min), lon_max - lon_min, lat_max - lat_min,
                               linewidth=2, edgecolor='red', facecolor='none', transform=pro))

    # Single bottom colorbar
    cax = fig.add_axes([0.56, 0.05, 0.35, 0.015])
    cb = fig.colorbar(cs_bottom, cax=cax, orientation='horizontal')
    cb.set_label('δD (‰)', fontsize=9)

    plt.tight_layout(rect=[0, 0.06, 1, 1])
    plt.show()

# --- Display widgets ---
ui = VBox([
    HBox([plevel_input, month_slider]),
    HBox([lat_min_input, lat_max_input]),
    HBox([lon_min_input, lon_max_input])
])
out = interactive_output(update_all, {
    'plevel': plevel_input,
    'month': month_slider,
    'lat_min': lat_min_input,
    'lat_max': lat_max_input,
    'lon_min': lon_min_input,
    'lon_max': lon_max_input
})
display(ui, out)


---

## Summary

By combining the **spatial comparison** with the **vertical sensitivity context**, this notebook helps us:
- Identify regions and levels where TES is truly sensitive.
- Understand where AK smoothing significantly alters the model output.
- Interpret model–observation differences more robustly.
