<a href="https://colab.research.google.com/github/mampisarkar111/wisonet-colab-demo/blob/main/Wisonet_kernel_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Comparing WisoMIP Model Ensemble and Satellite (TES) Water Vapor Isotope Profiles Before and After Applying Averaging Kernels

This notebook demonstrates how to compare modeled HDO (deuterated water vapor) profiles with satellite retrievals (e.g., TES) in a physically consistent way.  

It shows:
- How to read, clean, and subset model and satellite data
- How to convert specific humidity to HDO volume mixing ratio (VMR)
- How to apply satellite averaging kernels (AK) and a priori profiles to the model for fair comparison
- How to visualize and interpret the results, both spatially and vertically

**Key concepts:**
- **Averaging kernel (AK):** Describes how the satellite retrieval "sees" the vertical structure of the atmosphere, i.e., its vertical sensitivity at each level.  
- **A priori profile:** The climatological first-guess profile used in the retrieval. The AK acts on the difference between the model profile and the a priori.
- **AK smoothing:** Applying the AK to model output produces a "retrieved-like" version, directly comparable to the satellite product.

This workflow uses monthly WisoSAT TES and WisoMIP ensemble data on a global 5° x 5° grid for 2006 as an example model-observation comparison (see Rodgers, 2000; TES User Guide).


In [None]:
# Download SWING3 2006 subset data (NetCDF format) from Box (no output)
!curl -L "https://rice.box.com/shared/static/bcoy3ob0dme3umpurqmf0p6o48bznkj1" -o SWING3_2006_subset.nc > /dev/null 2>&1

# Download TES monthly 5° x 5° gridded isotope data (filtered) for 2006 from Box (no output)
!curl -L "https://rice.box.com/shared/static/uuy9m15qc1p7s4wm1yrzfzxc6knx7hzw" -o TES_monthly_5deg_strict.nc > /dev/null 2>&1

# Install cartopy (no output)
!pip install -q cartopy


In [None]:
# Data Loading and Setup

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import calendar
import ipywidgets as widgets
from ipywidgets import interactive_output, VBox, HBox
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import warnings
from cartopy.io import DownloadWarning

# Silence noisy but harmless warnings (NaN means, xarray decoding, cartopy downloads)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="xarray")
warnings.filterwarnings("ignore", category=DownloadWarning, module="cartopy")

# Load model (SWING3) and satellite (TES) datasets
ds = xr.open_dataset("SWING3_2006_subset.nc")           # Model output
sat = xr.open_dataset("TES_monthly_5deg_strict.nc")     # Satellite retrieval

# Common dimension arrays for grids and levels
lon = ds['lon'].values
lat = ds['lat'].values
p = ds['p'].values                     # Model pressure [hPa]
tes_lat = sat['lat'].values
tes_lon = sat['lon'].values
tes_p = sat['level'].values            # TES pressure [hPa]
tes_time = sat['time'].values


## Scientific Background

### 1. Specific Humidity to Volume Mixing Ratio (VMR)

Specific humidity ($q$) is converted to molar/volume mixing ratio ($q_\mathrm{VMR}$) as:
$$
q_\mathrm{VMR} = \frac{q / M_\mathrm{H_2O}}{(q / M_\mathrm{H_2O}) + ((1-q) / M_\mathrm{air})}
$$
where:
- $M_\mathrm{H_2O}$: molar mass of H$_2$O ($18.01528$ g/mol)
- $M_\mathrm{air}$: molar mass of dry air ($28.9647$ g/mol)
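
The conversion above can be sketched as a small standalone function. This is a minimal illustration using the molar masses listed here; the function name `q_to_vmr` and the sample humidity values are assumptions made for the example, not part of the notebook's data.

```python
import numpy as np

M_H2O = 18.01528   # molar mass of H2O [g/mol]
M_AIR = 28.9647    # molar mass of dry air [g/mol]

def q_to_vmr(q):
    """Convert specific humidity [kg/kg] to H2O volume mixing ratio [mol/mol]."""
    q = np.asarray(q, dtype=float)
    n_h2o = q / M_H2O            # moles of water per unit mass of moist air
    n_dry = (1.0 - q) / M_AIR    # moles of dry air per unit mass of moist air
    return n_h2o / (n_h2o + n_dry)

# Illustrative humidities, from very dry to moist tropical boundary layer
q_example = np.array([0.0, 0.001, 0.01, 0.02])
print(q_to_vmr(q_example))
```

For small $q$ the result is close to $q \times M_\mathrm{air} / M_\mathrm{H_2O}$, i.e. roughly 1.6 times the specific humidity.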




### 2. Converting to HDO VMR

Given the reference HDO/H$_2$O ratio in Vienna Standard Mean Ocean Water (VSMOW), $R_\mathrm{vsmow} = 3.1152 \times 10^{-4}$, and the $\delta D$ value in permil ($\text{dD}$):
$$
\mathrm{HDO_{VMR}} = q_\mathrm{VMR} \times R_\mathrm{vsmow} \times \left(1 + \frac{\text{dD}}{1000}\right)
$$
We convert model-based specific humidity ($q$) to volume mixing ratio ($q_{\mathrm{VMR}}$) in order to match the satellite data units.

- **Model output** is typically in specific humidity ($q$), which is the mass of water vapor per mass of air (kg/kg).
- **Satellite retrievals** (such as TES) provide the volume mixing ratio (VMR), which is the number of water vapor molecules per total air molecules (mol/mol).

**Converting $q$ to $q_{\mathrm{VMR}}$ ensures that both model and satellite data are in the same units, allowing for direct, apples-to-apples comparison and proper application of the satellite averaging kernel.**
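
As a worked illustration of the HDO formula in this section, the sketch below computes HDO VMR from an H$_2$O VMR and a $\delta D$ value, then inverts the relation as a sanity check. The helper names `hdo_vmr` and `delta_d` and the sample profile values are hypothetical and chosen only for the example.

```python
import numpy as np

R_VSMOW = 3.1152e-4   # HDO/H2O ratio in VSMOW, as given in this section

def hdo_vmr(q_vmr, dD):
    """HDO volume mixing ratio [mol/mol] from H2O VMR [mol/mol] and dD [permil]."""
    q_vmr = np.asarray(q_vmr, dtype=float)
    dD = np.asarray(dD, dtype=float)
    return q_vmr * R_VSMOW * (1.0 + dD / 1000.0)

def delta_d(hdo, q_vmr):
    """Inverse relation: recover dD [permil] from HDO and H2O VMR."""
    hdo = np.asarray(hdo, dtype=float)
    q_vmr = np.asarray(q_vmr, dtype=float)
    return (hdo / (q_vmr * R_VSMOW) - 1.0) * 1000.0

# Illustrative values: moist and weakly depleted near the surface,
# dry and strongly depleted aloft
q_vmr = np.array([0.016, 0.004, 0.0005])
dD = np.array([-80.0, -150.0, -300.0])

hdo = hdo_vmr(q_vmr, dD)
print(hdo)
print(delta_d(hdo, q_vmr))   # round trip back to the dD values
```

The round trip through `delta_d` is a convenient check that the units and the reference ratio are applied consistently.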

### 3. Applying the Averaging Kernel (AK)

Satellite retrievals do not resolve the full vertical structure of the atmosphere. The retrieval smooths the true profile in a way described by the **averaging kernel** ($\mathbf{A}$) and is anchored to an a priori profile ($x_a$).

Directly comparing a model profile to a satellite retrieval is not valid because the satellite retrieval is always "blurred" (smoothed) by its AK and influenced by its a priori. To accurately simulate what the satellite would "see" if the model were reality, you must **apply the AK and a priori to the model** before comparing.

Following the approach described in **Worden et al. (2011)**, we apply the averaging kernel as follows (in log space):

$$
\ln(x_\mathrm{smoothed}) = \ln(x_a) + \mathbf{A} \cdot \left[\ln(x_\mathrm{model}) - \ln(x_a)\right]
$$

where:
- $x_\mathrm{model}$ is the vertical model profile, interpolated to the satellite pressure grid,
- $x_a$ is the satellite retrieval a priori profile,
- $\mathbf{A}$ is the satellite averaging kernel matrix (describing vertical sensitivity),
- $\ln$ denotes the natural logarithm (applied elementwise).
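
The interpolation of $x_\mathrm{model}$ onto the satellite pressure grid can be sketched with `np.interp`. This is a minimal example that assumes simple linear interpolation in pressure; the function name `to_satellite_grid` and all pressure and VMR values are made up for illustration.

```python
import numpy as np

def to_satellite_grid(p_model, x_model, p_sat):
    """Linearly interpolate a model profile onto satellite pressure levels."""
    p_model = np.asarray(p_model, dtype=float)
    x_model = np.asarray(x_model, dtype=float)
    # np.interp requires increasing sample points, so sort by pressure first
    order = np.argsort(p_model)
    return np.interp(p_sat, p_model[order], x_model[order])

# Hypothetical model levels (surface first) and HDO VMR values
p_model = np.array([1000.0, 850.0, 700.0, 500.0, 300.0, 200.0])
x_model = np.array([4e-6, 3e-6, 2e-6, 8e-7, 1e-7, 2e-8])

# Hypothetical satellite retrieval levels [hPa]
p_sat = np.array([908.0, 825.0, 681.0, 511.0, 316.0])

print(to_satellite_grid(p_model, x_model, p_sat))
```

Note that `np.interp` holds the end values constant outside the model pressure range, so satellite levels above or below the model domain are not extrapolated.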

**This process ensures that model and satellite data are compared on the same "vertical sensitivity" basis, providing a fair, physically consistent comparison.**

*Reference: Worden, J. et al. (2011), "Estimate of bias in Aura TES HDO/H₂O profiles from comparison of TES and in situ HDO/H₂O measurements at the Mauna Loa observatory," ACP, 11, 4491–4503, [link](https://www.atmos-chem-phys.net/11/4491/2011/).*
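
The log-space smoothing equation in this section can be illustrated with synthetic inputs. The sketch below is not the TES kernel: the tridiagonal matrix `A`, the a priori values, and the helper name `apply_averaging_kernel` are assumptions chosen to show the two limiting cases, a perfect retrieval (identity kernel) and no sensitivity (zero kernel).

```python
import numpy as np

def apply_averaging_kernel(x_model, x_apriori, A):
    """Smooth a model profile with an averaging kernel in log space.

    All profiles must already be on the same (satellite) pressure grid.
    """
    ln_xa = np.log(x_apriori)
    ln_smoothed = ln_xa + A @ (np.log(x_model) - ln_xa)
    return np.exp(ln_smoothed)

n = 5
x_apriori = np.array([3e-6, 2e-6, 1e-6, 4e-7, 1e-7])   # synthetic a priori
x_model = 1.2 * x_apriori                               # model 20% above the prior

# Synthetic kernel: partial sensitivity, spread over neighboring levels
A = 0.5 * np.eye(n) + 0.15 * np.eye(n, k=1) + 0.15 * np.eye(n, k=-1)

smoothed = apply_averaging_kernel(x_model, x_apriori, A)
print(smoothed / x_apriori)   # ratios lie between 1.0 (prior) and 1.2 (model)

# Limiting cases
print(apply_averaging_kernel(x_model, x_apriori, np.eye(n)) / x_model)            # all ones
print(apply_averaging_kernel(x_model, x_apriori, np.zeros((n, n))) / x_apriori)   # all ones
```

With an identity kernel the smoothed profile equals the model, and with a zero kernel it collapses to the a priori; a real kernel sits between these extremes, pulling the model toward the prior wherever the retrieval has little sensitivity.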



## 🌎 Interactive HDO VMR Maps: Satellite vs Model

This interactive tool lets you compare spatial maps of HDO volume mixing ratio (VMR) from:

- **TES Satellite Retrievals**: Observed values at each grid point.
- **Model "True" Field**: WisoMIP simulated $q$ and $\delta D$, converted to HDO VMR.
- **AK-Corrected Model**: The model output processed with the TES averaging kernel and a priori, simulating what the satellite would “see” if the model were reality.

**How to use:**  
- Select a **pressure level** (in hPa) and a **month** (1 = January, 12 = December) to view.
- The three maps update automatically: left = TES, middle = raw model, right = AK-corrected model.

**Tip:**  
The AK-corrected model panel is the best comparison to TES. It answers:  
*If the satellite were looking at the model world, what would it see?*

---

*See “Scientific Background” above for formulas and details.*


In [None]:
plevel_input = widgets.BoundedFloatText(value=900, min=200, max=1000, step=10, description='Level [hPa]:')
month_slider = widgets.IntSlider(min=1, max=12, step=1, value=1, description='Month')

def update_plot(month, plevel):
    m = month - 1
    last_day = calendar.monthrange(2006, month)[1]
    user_date_str = f'2006-{month:02d}-{last_day}'
    target_date = datetime.strptime(user_date_str, '%Y-%m-%d')

    dD = ds['dD'].values.astype(float)
    q = ds['q'].values.astype(float)
    dD[dD == -999] = np.nan
    q[q == -999] = np.nan

    # --- Model HDO calculations on the full global grid (used for the maps)
    Rvsmow = 3.1152e-4   # HDO/H2O ratio in VSMOW
    M_air = 28.9647      # molar mass of dry air [g/mol]
    M_H2O = 18.01528     # molar mass of H2O [g/mol]

    # Specific humidity [kg/kg] -> H2O VMR [mol/mol] -> HDO VMR [mol/mol]
    q_vmr_3d = (q / M_H2O) / ((q / M_H2O) + ((1 - q) / M_air))
    HDO_vmr1_3d = q_vmr_3d * Rvsmow * (1 + dD / 1000)
    HDO_vmr1_3d = HDO_vmr1_3d[m, :, :, :]  # (level, lat, lon) for the selected month

    # --- Find TES time index for this month
    if np.issubdtype(tes_time.dtype, np.datetime64):
        t_tes = np.argmin(np.abs(tes_time - np.datetime64(user_date_str)))
    else:
        base_date = datetime(2000, 1, 1)
        t_tes = np.argmin(np.abs(tes_time - (target_date - base_date).days))

    idx_tes = int(np.argmin(np.abs(tes_p - plevel)))
    idx_mod = int(np.argmin(np.abs(p - plevel)))

    # --- 2D global maps (TES, raw model)
    HDO_globe = sat['HDO_vmr'].values[t_tes, idx_tes, :, :]
    HDO_model = HDO_vmr1_3d[idx_mod, :, :]

    # --- Compute AK-corrected model global map (on TES grid) ---
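    nlat_tes, nlon_tes = len(tes_lat), len(tes_lon)
    HDO_model_corr_2d = np.full((nlat_tes, nlon_tes), np.nan)
    p_order = np.argsort(p)  # np.interp needs increasing pressure
    for i in range(nlat_tes):
        for j in range(nlon_tes):
            model_ilat = np.argmin(np.abs(lat - tes_lat[i]))
            # Circular distance, so 0..360 and -180..180 longitude conventions both match
            dlon = np.abs(((lon - tes_lon[j] + 180) % 360) - 180)
            model_ilon = np.argmin(dlon)
            # Interpolate the model column onto the TES pressure grid before applying the AK
            x = np.interp(tes_p, p[p_order], HDO_vmr1_3d[p_order, model_ilat, model_ilon])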
            a = sat['HDO_ConstraintVector'].values[t_tes, :, i, j]
            A = sat['AK_HDO'].values[t_tes, :, :, i, j]
            if np.all(np.isfinite(x)) and np.all(np.isfinite(a)) and np.all(np.isfinite(A)):
                ln_true = np.log(x)
                ln_apriori = np.log(a)
                ln_smoothed = ln_apriori + A @ (ln_true - ln_apriori)
                HDO_model_corr_2d[i, j] = np.exp(ln_smoothed[idx_tes])

    # --- Mask fill values
    HDO_globe[HDO_globe == -999] = np.nan
    HDO_model[HDO_model == -999] = np.nan
    HDO_model_corr_2d[HDO_model_corr_2d == -999] = np.nan

    # --- Longitude sorting for model grid
    lon_plot = ((lon + 180) % 360) - 180
    sort_idx = np.argsort(lon_plot)
    lon_plot_sorted = lon_plot[sort_idx]
    HDO_model_sorted = HDO_model[:, sort_idx]

    # --- Colorbar limits (include all 3 maps)
    combined = np.concatenate([
        HDO_globe[np.isfinite(HDO_globe)],
        HDO_model_sorted[np.isfinite(HDO_model_sorted)],
        HDO_model_corr_2d[np.isfinite(HDO_model_corr_2d)]
    ])
    vmin, vmax = np.nanpercentile(combined, 1), np.nanpercentile(combined, 99)
    levels = np.linspace(vmin, vmax, 21)

    fig = plt.figure(figsize=(22, 8))

    ax1 = fig.add_subplot(1, 3, 1, projection=ccrs.PlateCarree())
    cs1 = ax1.contourf(tes_lon, tes_lat, HDO_globe, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax1.coastlines()
    ax1.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax1.gridlines(draw_labels=True, linewidth=0.3)
    ax1.set_title(f'TES HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs1, ax=ax1, orientation='horizontal', pad=0.05).set_label('TES HDO VMR')

    ax2 = fig.add_subplot(1, 3, 2, projection=ccrs.PlateCarree())
    cs2 = ax2.contourf(lon_plot_sorted, lat, HDO_model_sorted, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax2.coastlines()
    ax2.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax2.gridlines(draw_labels=True, linewidth=0.3)
    ax2.set_title(f'Model HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs2, ax=ax2, orientation='horizontal', pad=0.05).set_label('Model HDO VMR')

    ax3 = fig.add_subplot(1, 3, 3, projection=ccrs.PlateCarree())
    cs3 = ax3.contourf(tes_lon, tes_lat, HDO_model_corr_2d, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax3.coastlines()
    ax3.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax3.gridlines(draw_labels=True, linewidth=0.3)
    ax3.set_title(f'AK-corrected Model HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs3, ax=ax3, orientation='horizontal', pad=0.05).set_label('AK-corrected Model HDO VMR')

    plt.tight_layout()
    plt.show()

# Display widgets
ui = VBox([
    HBox([plevel_input, month_slider])
])
out = interactive_output(update_plot, {
    'plevel': plevel_input,
    'month': month_slider
})
display(ui, out)



## Advanced Interactive Exploration: Regional Profiles and Vertical Sensitivity

In this section, you can:

- **Select a latitude/longitude region** (Lat Min/Max, Lon Min/Max) over which the vertical profiles are spatially averaged.
- **Choose which curves the profile plot shows** with the Step View dropdown (*Model True vs TES*, *Model True vs TES vs Apriori*, or *All*):  
  - *Model True*: Profile from the model, converted to HDO VMR.
  - *TES Retrieved*: The satellite-retrieved HDO VMR.
  - *TES Apriori*: The a priori (prior) profile used in the satellite retrieval.
  - *AK-applied*: The model profile smoothed by the TES averaging kernel, simulating what the satellite would observe if the model were reality (shown only in the *All* view).
- **Toggle log scaling** of the x-axis to see the full dynamic range of HDO in the vertical.

The **profile plot** (bottom left) shows how satellite and model vertical structures compare over your chosen region and month.

The **AK matrix plot** (bottom center) visualizes the vertical sensitivity of the TES retrieval at each pressure level—the “blurring” effect applied to the model.

**Tips:**
- Start with a broad latitude/longitude range, then zoom into smaller regions (e.g., a box over the subtropics or tropics) to see regional differences.
- Change the Step View to see the impact of the averaging kernel and a priori.
- Log scale is especially useful to view both upper and lower troposphere on the same axis.
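
Selecting a region that may straddle the 0°/360° meridian needs a little care with longitude wrapping. The sketch below shows one way to build the masks and take an unweighted regional mean; the helper names `lon_mask` and `regional_mean`, the synthetic 5° grid, and the test field are assumptions for illustration. No area (cosine-latitude) weighting is applied in this sketch.

```python
import numpy as np

def lon_mask(lon, lon_min, lon_max):
    """Boolean mask for longitudes inside [lon_min, lon_max], in any convention."""
    lon = np.asarray(lon, dtype=float) % 360
    lo, hi = lon_min % 360, lon_max % 360
    if lo <= hi:
        return (lon >= lo) & (lon <= hi)
    # The box crosses the 0/360 meridian, so it is the union of two pieces
    return (lon >= lo) | (lon <= hi)

def regional_mean(field, lat, lon, lat_min, lat_max, lon_min, lon_max):
    """Unweighted mean of a (lat, lon) field over a lat/lon box, ignoring NaNs."""
    lat_sel = (lat >= lat_min) & (lat <= lat_max)
    lon_sel = lon_mask(lon, lon_min, lon_max)
    return np.nanmean(field[np.ix_(lat_sel, lon_sel)])

# Synthetic 5-degree grid with cell centers, longitudes in 0..360
lat = np.arange(-87.5, 90, 5)
lon = np.arange(2.5, 360, 5)
field = np.repeat(lat[:, None], lon.size, axis=1)   # each cell holds its latitude

print(lon_mask(lon, -70, -40).sum())    # cells between 70W and 40W
print(lon_mask(lon, -10, 10).sum())     # box crossing the prime meridian
print(regional_mean(field, lat, lon, 20, 40, -70, -40))
```

Because the test field stores the latitude of each cell, the regional mean over 20°N to 40°N returns the average of the selected cell-center latitudes, which makes the masking easy to verify by eye.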

---

*For scientific details, see the “Scientific Background” section above.*


In [None]:

# User controls for interactive exploration
plevel_input = widgets.BoundedFloatText(value=900, min=200, max=1000, step=10, description='Level [hPa]:')
month_slider = widgets.IntSlider(min=1, max=12, step=1, value=1, description='Month')
log_toggle = widgets.Checkbox(value=True, description='Log X-axis')
step_dropdown = widgets.Dropdown(options=['All', 'Model True vs TES', 'Model True vs TES vs Apriori'],
                                value='All', description='Step View:')
lat_min_input = widgets.BoundedFloatText(value=20, min=-90, max=90, step=1, description='Lat Min:')
lat_max_input = widgets.BoundedFloatText(value=40, min=-90, max=90, step=1, description='Lat Max:')
lon_min_input = widgets.BoundedFloatText(value=-70, min=-180, max=180, step=5, description='Lon Min:')
lon_max_input = widgets.BoundedFloatText(value=-40, min=-180, max=180, step=5, description='Lon Max:')

def update_plot(month, plevel, log_x, step_view, lat_min, lat_max, lon_min, lon_max):
    m = month - 1
    last_day = calendar.monthrange(2006, month)[1]
    user_date_str = f'2006-{month:02d}-{last_day}'
    target_date = datetime.strptime(user_date_str, '%Y-%m-%d')

    dD = ds['dD'].values.astype(float)
    q = ds['q'].values.astype(float)
    dD[dD == -999] = np.nan
    q[q == -999] = np.nan

    # --- Lat/lon subsetting for profile
    lat_mask = (lat >= lat_min) & (lat <= lat_max)
    lon_wrapped = (lon + 360) % 360
    lon_min_wrapped = (lon_min + 360) % 360
    lon_max_wrapped = (lon_max + 360) % 360
    if lon_min_wrapped < lon_max_wrapped:
        lon_mask = (lon_wrapped >= lon_min_wrapped) & (lon_wrapped <= lon_max_wrapped)
    else:
        lon_mask = (lon_wrapped >= lon_min_wrapped) | (lon_wrapped <= lon_max_wrapped)

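    # NumPy moves the advanced-index (lat) axis to the front here,
    # so both subsets have shape (lat, level, lon)
    dD_eq = dD[m, :, lat_mask, :][:, :, lon_mask]
    q_eq = q[m, :, lat_mask, :][:, :, lon_mask]

    # Average over lon (axis 2), then over lat (axis 0): one value per level
    dD_mean = np.nanmean(np.nanmean(dD_eq, axis=2), axis=0)
    q_mean = np.nanmean(np.nanmean(q_eq, axis=2), axis=0)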

    # --- Model HDO calculations (all-lat/lon for maps, subregion for profile)
    Rvsmow = 3.1152e-4
    M_air = 28.9647
    M_H2O = 18.01528
    q_vmr = (q_mean / M_H2O) / ((q_mean / M_H2O) + ((1 - q_mean) / M_air))
    HDO_vmr1 = q_vmr * Rvsmow * (1 + dD_mean / 1000)

    q_vmr_3d = (q / M_H2O) / ((q / M_H2O) + ((1 - q) / M_air))
    HDO_vmr1_3d = q_vmr_3d * Rvsmow * (1 + dD / 1000)
    HDO_vmr1_3d = HDO_vmr1_3d[m, :, :, :]  # (level, lat, lon)

    # --- Find TES time index for this month
    if np.issubdtype(tes_time.dtype, np.datetime64):
        t_tes = np.argmin(np.abs(tes_time - np.datetime64(user_date_str)))
    else:
        base_date = datetime(2000, 1, 1)
        t_tes = np.argmin(np.abs(tes_time - (target_date - base_date).days))

    # --- TES region mask for vertical profile
    tes_mask = (tes_lat >= lat_min) & (tes_lat <= lat_max)
    tes_lon_wrapped = (tes_lon + 360) % 360
    if lon_min_wrapped < lon_max_wrapped:
        tes_lon_mask = (tes_lon_wrapped >= lon_min_wrapped) & (tes_lon_wrapped <= lon_max_wrapped)
    else:
        tes_lon_mask = (tes_lon_wrapped >= lon_min_wrapped) | (tes_lon_wrapped <= lon_max_wrapped)

    HDO_ret = sat['HDO_vmr'].values[t_tes, :, tes_mask, :][:, :, tes_lon_mask]
    HDO_a = sat['HDO_ConstraintVector'].values[t_tes, :, tes_mask, :][:, :, tes_lon_mask]
    AK = sat['AK_HDO'].values[t_tes, :, :, tes_mask, :][:, :, :, tes_lon_mask]

    idx_tes = int(np.argmin(np.abs(tes_p - plevel)))
    idx_mod = int(np.argmin(np.abs(p - plevel)))

    # --- 2D global maps (TES, raw model)
    HDO_globe = sat['HDO_vmr'].values[t_tes, idx_tes, :, :]
    HDO_model = HDO_vmr1_3d[idx_mod, :, :]

    # --- Compute AK-corrected model global map (on TES grid) ---
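    nlat_tes, nlon_tes = len(tes_lat), len(tes_lon)
    HDO_model_corr_2d = np.full((nlat_tes, nlon_tes), np.nan)
    p_order = np.argsort(p)  # np.interp needs increasing pressure
    for i in range(nlat_tes):      # TES lat
        for j in range(nlon_tes):  # TES lon
            model_ilat = np.argmin(np.abs(lat - tes_lat[i]))
            # Circular distance, so 0..360 and -180..180 longitude conventions both match
            dlon = np.abs(((lon - tes_lon[j] + 180) % 360) - 180)
            model_ilon = np.argmin(dlon)
            # Interpolate the model column onto the TES pressure grid before applying the AK
            x = np.interp(tes_p, p[p_order], HDO_vmr1_3d[p_order, model_ilat, model_ilon])   # [TES level]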
            a = sat['HDO_ConstraintVector'].values[t_tes, :, i, j]
            A = sat['AK_HDO'].values[t_tes, :, :, i, j]
            if np.all(np.isfinite(x)) and np.all(np.isfinite(a)) and np.all(np.isfinite(A)):
                ln_true = np.log(x)
                ln_apriori = np.log(a)
                ln_smoothed = ln_apriori + A @ (ln_true - ln_apriori)
                HDO_model_corr_2d[i, j] = np.exp(ln_smoothed[idx_tes])

    # --- Mask fill values
    HDO_globe[HDO_globe == -999] = np.nan
    HDO_model[HDO_model == -999] = np.nan
    HDO_model_corr_2d[HDO_model_corr_2d == -999] = np.nan

    # --- Longitude sorting for model grid
    lon_plot = ((lon + 180) % 360) - 180
    sort_idx = np.argsort(lon_plot)
    lon_plot_sorted = lon_plot[sort_idx]
    HDO_model_sorted = HDO_model[:, sort_idx]

    # --- Colorbar limits (include all 3 maps)
    combined = np.concatenate([
        HDO_globe[np.isfinite(HDO_globe)],
        HDO_model_sorted[np.isfinite(HDO_model_sorted)],
        HDO_model_corr_2d[np.isfinite(HDO_model_corr_2d)]
    ])
    vmin, vmax = np.nanpercentile(combined, 1), np.nanpercentile(combined, 99)
    levels = np.linspace(vmin, vmax, 21)

    # --- Map figure: 3 horizontal panels (TES, Model, AK-corrected Model) ---
    fig = plt.figure(figsize=(22, 8))

    ax1 = fig.add_subplot(2, 3, 1, projection=ccrs.PlateCarree())
    cs1 = ax1.contourf(tes_lon, tes_lat, HDO_globe, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax1.coastlines()
    ax1.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax1.gridlines(draw_labels=True, linewidth=0.3)
    ax1.add_patch(plt.Rectangle((lon_min, lat_min), lon_max - lon_min, lat_max - lat_min,
                                linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree()))
    ax1.set_title(f'TES HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs1, ax=ax1, orientation='horizontal', pad=0.05).set_label('TES HDO VMR')

    ax2 = fig.add_subplot(2, 3, 2, projection=ccrs.PlateCarree())
    cs2 = ax2.contourf(lon_plot_sorted, lat, HDO_model_sorted, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax2.coastlines()
    ax2.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax2.gridlines(draw_labels=True, linewidth=0.3)
    ax2.add_patch(plt.Rectangle((lon_min, lat_min), lon_max - lon_min, lat_max - lat_min,
                                linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree()))
    ax2.set_title(f'Model Ensemble HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs2, ax=ax2, orientation='horizontal', pad=0.05).set_label('Model Ensemble HDO VMR')

    ax3 = fig.add_subplot(2, 3, 3, projection=ccrs.PlateCarree())
    cs3 = ax3.contourf(tes_lon, tes_lat, HDO_model_corr_2d, levels=levels, cmap='viridis',
                       vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax3.coastlines()
    ax3.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax3.gridlines(draw_labels=True, linewidth=0.3)
    ax3.add_patch(plt.Rectangle((lon_min, lat_min), lon_max - lon_min, lat_max - lat_min,
                                linewidth=2, edgecolor='red', facecolor='none', transform=ccrs.PlateCarree()))
    ax3.set_title(f'AK-corrected Model HDO VMR at ~{plevel} hPa')
    fig.colorbar(cs3, ax=ax3, orientation='horizontal', pad=0.05).set_label('AK-corrected Model HDO VMR')

    # === Regional profile and AK plots (second row of the figure) ===
    HDO_ret[HDO_ret == -999] = np.nan
    HDO_a[HDO_a == -999] = np.nan
    AK[AK == -999] = np.nan

    HDO_ret_avg = np.nanmean(np.nanmean(HDO_ret, axis=2), axis=0)
    HDO_a_avg = np.nanmean(np.nanmean(HDO_a, axis=2), axis=0)
    AK_avg = np.nanmean(np.nanmean(AK, axis=3), axis=0)

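    # Interpolate the regional-mean model profile onto the TES pressure grid
    # (np.interp needs increasing sample points, so sort by pressure first)
    p_order = np.argsort(p)
    HDO_vmr1_interp = np.interp(tes_p, p[p_order], HDO_vmr1[p_order])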
    ln_true = np.log(HDO_vmr1_interp)
    ln_apriori = np.log(HDO_a_avg)
    ln_smoothed = ln_apriori + AK_avg @ (ln_true - ln_apriori)
    HDO_recon = np.exp(ln_smoothed)

    # --- Profile plot ---
    ax4 = fig.add_subplot(2, 3, 4)
    if step_view == 'Model True vs TES':
        ax4.plot(HDO_vmr1_interp, tes_p, 'k-', label='Model True')
        ax4.plot(HDO_ret_avg, tes_p, 'b-o', label='TES Retrieved')
    elif step_view == 'Model True vs TES vs Apriori':
        ax4.plot(HDO_vmr1_interp, tes_p, 'k-', label='Model True')
        ax4.plot(HDO_ret_avg, tes_p, 'b-o', label='TES Retrieved')
        ax4.plot(HDO_a_avg, tes_p, 'g--', label='TES Apriori')
    else:
        ax4.plot(HDO_ret_avg, tes_p, 'b-o', label='TES Retrieved')
        ax4.plot(HDO_a_avg, tes_p, 'g--', label='TES Apriori')
        ax4.plot(HDO_vmr1_interp, tes_p, 'k-', label='Model True')
        ax4.plot(HDO_recon, tes_p, 'm-.', label='AK-applied')
    ax4.set_xlabel('HDO VMR [mol/mol]')
    ax4.set_ylabel('Pressure [hPa]')
    lat_label = f"{abs(lat_min):.0f}°{'S' if lat_min < 0 else 'N'}–{abs(lat_max):.0f}°{'S' if lat_max < 0 else 'N'}"
    lon_label = f"{abs(lon_min):.0f}°{'W' if lon_min < 0 else 'E'}–{abs(lon_max):.0f}°{'W' if lon_max < 0 else 'E'}"
    ax4.set_title(f'Vertical Profile ({lat_label}, {lon_label}, {user_date_str})')

    all_profiles = [HDO_ret_avg, HDO_a_avg, HDO_vmr1_interp, HDO_recon]
    valid_values = np.concatenate([x[np.isfinite(x)] for x in all_profiles])
    x_max = np.nanmax(valid_values)
    ax4.invert_yaxis()
    ax4.grid(True)
    ax4.set_ylim(1000, 0)
    ax4.legend(loc='lower left')
    if log_x:
        ax4.set_xscale('log')
        ax4.set_xlim(1e-10, x_max * 1.1)
    else:
        ax4.set_xlim(0, x_max * 1.1)

    # --- AK matrix plot ---
    import matplotlib.cm as cm
    import matplotlib.colors as colors
    ax5 = fig.add_subplot(2, 3, 5)
    norm = colors.Normalize(vmin=np.min(tes_p), vmax=np.max(tes_p))
    cmap = cm.viridis
    sm = cm.ScalarMappable(norm=norm, cmap=cmap)
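    # After this transpose, AK_avg[i, :] is column i of the matrix applied in the
    # smoothing step above; confirm the AK_HDO dimension order in the file before
    # interpreting these curves as kernel rows
    AK_avg = AK_avg.T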
    for i in range(AK_avg.shape[0]):
        color = cmap(norm(tes_p[i]))
        ax5.plot(AK_avg[i, :], tes_p, color=color, label=f'Level {tes_p[i]:.0f} hPa')
    ax5.set_xlabel('AK Value (dx_est / dx_true)')
    ax5.set_ylabel('Pressure [hPa]')
    ax5.set_title('Averaging Kernel Rows (colored by level)')
    ax5.invert_yaxis()
    ax5.grid(True)
    cbar = fig.colorbar(sm, ax=ax5, orientation='vertical', pad=0.02)
    cbar.set_label('Retrieved Level [hPa]')

    plt.tight_layout()
    plt.show()

# === Display Widgets ===
ui = VBox([
    HBox([plevel_input]),
    HBox([month_slider]),
    HBox([lat_min_input, lat_max_input, lon_min_input, lon_max_input]),
    HBox([log_toggle, step_dropdown])
])
out = interactive_output(update_plot, {
    'plevel': plevel_input,
    'month': month_slider,
    'log_x': log_toggle,
    'step_view': step_dropdown,
    'lat_min': lat_min_input,
    'lat_max': lat_max_input,
    'lon_min': lon_min_input,
    'lon_max': lon_max_input
})
display(ui, out)

