# In [None]:
# ============================================================
# DMI code example (Jupyter-friendly, single block)
# Same idea as my original PhD script
# ============================================================

import os
import numpy as np
import xarray as xr
from scipy.ndimage import binary_dilation
import matplotlib.pyplot as plt
from matplotlib.dates import AutoDateLocator, DateFormatter

# ----------------------------
# USER SETTINGS
# ----------------------------
directory_path = "..."   # <-- change: folder that holds the *_PERT{i}.nc input files
output_dir = "./outputs"
os.makedirs(output_dir, exist_ok=True)   # created up front so the export step can write into it

# Variables opened per scenario; SEED additionally needs IN_T1 for the storm mask.
vars_ctrl = ["W", "TQI"]
vars_seed = vars_ctrl + ["IN_T1"]

# Storm-mask thresholds
w_threshold1 = 1.0               # threshold on column-max W (presumably m s^-1 — confirm units of the W file)
particle_threshold1 = 150000.0   # threshold on IN_T1 (units as stored in the IN_T1 file — TODO confirm)
level_start_threshold1 = 35      # first altitude index of the IN_T1 layer (inclusive)
level_end_threshold1 = 79        # used as a slice end, so it is EXCLUSIVE: layer = indices 35..78

# 3x3 (rlat, rlon) neighbourhood: one dilation pass grows the mask by one
# grid cell in every direction, diagonals included.
structure_threshold1 = np.full((3, 3), True, dtype=bool)

# ----------------------------
# 1) LOAD DATA (CTRL + SEED) for PERT1..PERT10
# ----------------------------
# Input file naming: {VAR}_DOMAIN_CELL_{TAG}_PERT{member}.nc
#   TAG = "CTRL"        -> control run
#   TAG = "SEEDING_MED" -> seeding run used as SEED here
def _open_member(tag, variables, member):
    """Open one NetCDF file per variable for a single ensemble member.

    Parameters
    ----------
    tag : str
        Scenario tag embedded in the filename ("CTRL" or "SEEDING_MED").
    variables : list[str]
        Variable names; each one maps to its own file.
    member : int
        Perturbation/member number (the PERT index in the filename).

    Returns
    -------
    dict[str, xr.Dataset]
        Variable name -> opened dataset.
    """
    # os.path.join (rather than plain string concatenation) keeps the path
    # valid whether or not directory_path ends with a separator.
    return {
        v: xr.open_dataset(os.path.join(directory_path, f"{v}_DOMAIN_CELL_{tag}_PERT{member}.nc"))
        for v in variables
    }


ctrl = {}
seed = {}

for i in range(1, 11):
    ctrl[i] = _open_member("CTRL", vars_ctrl, i)
    seed[i] = _open_member("SEEDING_MED", vars_seed, i)

print("✅ Loaded CTRL + SEED datasets for members 1..10")

# ----------------------------
# 2) COMPUTE w_max over altitude (SEED)
# ----------------------------
# Column maximum of the SEED vertical velocity for each member: the "altitude"
# dimension is reduced away and the remaining dims are kept
# (presumably (time, rlat, rlon) — confirm against the W files).
w_max_seed = {
    member: seed[member]["W"]["W"].max(dim="altitude")
    for member in range(1, 11)
}

print("✅ Computed w_max_seed for members 1..10")

# ----------------------------
# 3) BUILD CONVECTIVE STORM MASK per member (SEED thresholds + dilation)
#    Mask definition: w_max > w_threshold AND AgI > particle_threshold in a vertical layer
# ----------------------------
expanded_mask = {}

for i in range(1, 11):
    w_field = w_max_seed[i]                              # column-max W
    in_t1 = seed[i]["IN_T1"]["IN_T1"]

    # Pick the altitude layer by dimension NAME instead of by axis position, so
    # the mask does not silently depend on the axis order of IN_T1.
    # Slice end is exclusive: indices level_start_threshold1 .. level_end_threshold1 - 1.
    layer = in_t1.isel(altitude=slice(level_start_threshold1, level_end_threshold1))

    mask_i = (
        (w_field > w_threshold1) &
        (layer > particle_threshold1).any(dim="altitude")
    )

    # Dilation must act on the horizontal axes only, so put them last.
    mask_i = mask_i.transpose(..., "rlat", "rlon")
    mask_np = mask_i.values

    # Give the 2-D structuring element leading singleton axes. One
    # binary_dilation call then dilates every (rlat, rlon) slice independently,
    # which is equivalent to looping over time but done in a single pass.
    structure = structure_threshold1.reshape(
        (1,) * (mask_np.ndim - structure_threshold1.ndim) + structure_threshold1.shape
    )
    expanded_np = binary_dilation(mask_np, structure=structure)

    expanded_mask[i] = xr.DataArray(expanded_np, coords=mask_i.coords, dims=mask_i.dims)

print("✅ Created expanded storm masks for members 1..10")

# ----------------------------
# 4) APPLY MASK to TQI (SEED + CTRL)
#    NOTE: For CTRL we apply the SAME storm region mask as in SEED member i
#    so the comparison CTRL vs SEED is done over the same grid points.
# ----------------------------
masked_tqi_seed = {}
masked_tqi_ctrl = {}

for member in range(1, 11):
    storm_region = expanded_mask[member]
    # Grid points outside the SEED-derived storm region become NaN in both runs.
    # NOTE(review): assumes CTRL and SEED TQI share the same time/rlat/rlon
    # coordinates as the mask — TODO confirm.
    for target, runs in ((masked_tqi_seed, seed), (masked_tqi_ctrl, ctrl)):
        target[member] = runs[member]["TQI"]["TQI"].where(storm_region, np.nan)

print("✅ Applied storm mask to TQI for SEED + CTRL")

# ----------------------------
# 5) FILTER OUTLIERS: keep only 1–99% range (per member, NaNs ignored)
# ----------------------------
def filter_1_99(da: xr.DataArray) -> xr.DataArray:
    """Blank out values of *da* that fall outside its 1st–99th percentile band.

    Both percentiles are computed over all elements of *da* at once (no ``dim``
    is passed to ``quantile``), skipping NaNs. Values inside the closed band
    [p1, p99] are kept. Everything else, including values that were already
    NaN, is NaN in the result.
    """
    lower, upper = (da.quantile(q, skipna=True) for q in (0.01, 0.99))
    inside_band = (da >= lower) & (da <= upper)
    return da.where(inside_band, np.nan)

# Replace each member's masked TQI (SEED, then CTRL) with its outlier-filtered version.
for member in range(1, 11):
    masked_tqi_seed[member], masked_tqi_ctrl[member] = (
        filter_1_99(masked_tqi_seed[member]),
        filter_1_99(masked_tqi_ctrl[member]),
    )

print("✅ Applied 1–99% filtering to masked TQI (SEED + CTRL)")

# ----------------------------
# 6) SPATIAL MEAN -> time series per member
# ----------------------------
# Average over the horizontal grid only. NaNs (points outside the storm mask,
# or removed by the 1–99% filter) are skipped, so each value is a storm-only mean.
spatial_dims = ["rlat", "rlon"]
tqi_seed_mean = {m: masked_tqi_seed[m].mean(dim=spatial_dims, skipna=True) for m in range(1, 11)}
tqi_ctrl_mean = {m: masked_tqi_ctrl[m].mean(dim=spatial_dims, skipna=True) for m in range(1, 11)}

print("✅ Computed domain-mean time series per member (SEED + CTRL)")

# ----------------------------
# 7) + 8) PLOT ALL MEMBERS (SEED, then CTRL)
# ----------------------------
def _plot_member_series(member_means, scenario):
    """Draw one storm-only mean TQI line per ensemble member and show the figure.

    Parameters
    ----------
    member_means : dict[int, xr.DataArray]
        Member number (1..10) -> mean TQI series (presumably 1-D in time).
    scenario : str
        Scenario label inserted into the title, e.g. "SEED" or "CTRL".
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    # Members are drawn from pert10 down to pert1, so the legend lists pert10 first.
    for i in range(10, 0, -1):
        da = member_means[i]
        # Fall back to the first dimension if none is literally named "time".
        tdim = "time" if "time" in da.dims else da.dims[0]
        ax.plot(da[tdim].values, da.values, linewidth=1.6, alpha=0.9, label=f"pert{i}")

    ax.set_xlabel("Time")
    ax.set_ylabel("TQI (kg m$^{-2}$)")
    ax.set_title(f"TQI ({scenario}) — storm-only domain mean per member")
    ax.xaxis.set_major_locator(AutoDateLocator())
    ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d\n%H:%M"))
    ax.grid(True, alpha=0.3)
    # Legend sits outside the axes, to the right, so it never hides data.
    ax.legend(title="Member", frameon=False, bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


# 7) PLOT ALL MEMBERS (SEED)
_plot_member_series(tqi_seed_mean, "SEED")

# 8) PLOT ALL MEMBERS (CTRL)
_plot_member_series(tqi_ctrl_mean, "CTRL")

# ----------------------------
# 9) EXPORT: save all member time series to one NetCDF
# ----------------------------
# Variables are interleaved per member:
# tqi_pert1_seed_mean, tqi_pert1_ctrl_mean, tqi_pert2_seed_mean, ...
series = {}
for member in range(1, 11):
    for tag, means in (("seed", tqi_seed_mean), ("ctrl", tqi_ctrl_mean)):
        series[f"tqi_pert{member}_{tag}_mean"] = means[member]

# Put every series on a common time axis. The outer join keeps the union of
# timestamps and fills the gaps with NaN.
names = list(series)
aligned_list = xr.align(*(series[name].rename(name) for name in names), join="outer")

aligned = {}
for name, arr in zip(names, aligned_list):
    time_dim = "time" if "time" in arr.dims else arr.dims[0]
    # Chronological order, stored as float32 to keep the file small.
    aligned[name] = arr.sortby(time_dim).astype("float32")

ds_out = xr.Dataset(aligned)

# Light zlib compression (level 1) on every data variable.
encoding = {name: dict(zlib=True, complevel=1) for name in ds_out.data_vars}

output_fname = os.path.join(output_dir, "tqi_storm_only_domain_mean_seed_vs_ctrl.nc")
ds_out.to_netcdf(output_fname, encoding=encoding)

print(f"✅ Saved: {output_fname}")
