<a href="https://colab.research.google.com/github/david-levin11/alaska_verification/blob/main/Cluster_Sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Testing ECMWF Cluster Techniques**
<br/>
Description--More to come later.

- David Levin, Arctic Testbed & Proving Ground, Anchorage, Alaska

## **1 - Install and Import Packages**
This will take about a minute to run.

In [None]:
# @title
!pip install ecmwf-opendata eccodes==2.38.3 cfgrib xarray scikit-learn cartopy
from datetime import datetime, timezone
from pathlib import Path
from ecmwf.opendata import Client
import xarray as xr
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.1.0-1/Mambaforge-23.1.0-1-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:20
🔁 Restarting kernel...


# **2 - Set Options & Download Data**

In [None]:

def download_ecmwf_ens(
    param: str,
    init_time: datetime,
    steps,
    *,
    level_type: str = "sfc",          # "sfc" or "pl"
    levels=None,                      # e.g., [500] or [1000,850,700] for levtype="pl"
    members="all",                    # "all", "control", "mean", "stdev", or list of ints (e.g., [1,3,5])
    target_dir=".",
    source: str = "ecmwf"             # "ecmwf", "aws", or "azure"
) -> Path:
    """
    Download ECMWF ENS data (Open Data) for a given init, variable, and steps.

    Parameters
    ----------
    param : str or list of str
        Parameter short name(s), passed through to the request unchanged.
        A list is joined with "_" in the output filename.
    init_time : datetime
        Model initialization time; only its date and hour are used.
    steps : int, list/tuple of int, or str
        Forecast step(s). A list/tuple is joined with "/"; anything that is
        neither an int nor a list/tuple is passed on as ``str(steps)``
        (e.g. a MARS-style "0/6/240" string).
    level_type : str
        "sfc" or "pl" (case-insensitive).
    levels : sequence, optional
        Required when ``level_type="pl"`` (e.g. [500]).
    members : str or iterable of int
        "all" (pf), "control" (cf), "mean" (em), "stdev" (es), or an explicit
        collection of perturbed member numbers.
    target_dir : str or Path
        Directory for the output file; created if missing.
    source : str
        Forwarded to ``Client(source=...)``.

    Returns
    -------
    Path
        Path of the GRIB2 file handed to the client as ``target``.

    Raises
    ------
    ValueError
        For an unknown ``members`` string, an empty member collection, a
        pressure-level request without ``levels``, or an unknown ``level_type``.

    Notes
    -----
    The filename encodes only parameter, level and init time, so a second
    request differing only in steps or members writes to the same file.
    """
    date_str = init_time.strftime("%Y-%m-%d")
    hour = init_time.hour

    # steps can be int, list/tuple, or a MARS-style string
    if isinstance(steps, int):
        step_val = steps
    elif isinstance(steps, (list, tuple)):
        step_val = "/".join(str(s) for s in steps)
    else:
        # assume caller passed a MARS-style range like "0/6/240" or "0/to/240/by/6"
        step_val = str(steps)

    # Map the members choice to the request "type" (and optional "number")
    member_types = {"all": "pf", "control": "cf", "mean": "em", "stdev": "es"}
    number_kw = {}
    if isinstance(members, str):
        try:
            mtype = member_types[members.lower()]
        except KeyError:
            raise ValueError(
                "members must be 'all', 'control', 'mean', 'stdev', or a list of ints."
            ) from None
    else:
        # explicit collection of perturbed member numbers
        mtype = "pf"
        nums = list(members)
        if not nums:
            raise ValueError("members list is empty.")
        number_kw = {"number": nums}

    # Level keywords and the label used in the filename
    levtype = level_type.lower()
    if levtype == "pl":
        if not levels:
            raise ValueError("Pressure-level request requires 'levels' (e.g., [500] or [1000,850,700]).")
        level_kw = {"levtype": "pl", "levelist": "/".join(str(l) for l in levels)}
        level_label = f"pl_{'-'.join(str(l) for l in levels)}"
    elif levtype == "sfc":
        level_kw = {"levtype": "sfc"}
        level_label = "sfc"
    else:
        raise ValueError("level_type must be 'sfc' or 'pl'.")

    # Build output filename.
    # NOTE: the previous step/member labels were never used in the filename and
    # the step label crashed on integer `steps` (int has no .replace), so they
    # have been removed rather than repaired.
    param_key = "_".join(param) if isinstance(param, (list, tuple)) else param
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    outfile = target_dir / f"ecmwf_ens_{param_key}_{level_label}_{date_str.replace('-','')}{hour:02d}.grib2"

    # Create client and request
    client = Client(source=source)
    req = {
        "date": date_str,
        "time": hour,
        "stream": "enfo",
        "type": mtype,
        "param": param,
        "step": step_val,
        **level_kw,
        **number_kw,
        "target": str(outfile),
    }

    # Retrieve
    client.retrieve(**req)
    return outfile

# --- User options -----------------------------------------------------------
# The `#@markdown` / `#@param` comments below are Colab form annotations;
# keep each one on the same line as (or directly above) the variable it drives.
#@markdown What is your model run initialization date?
init_date = "2025-08-11" #@param {type:"date"}
#@markdown What is your model run time
init_time = "00" #@param ["00",  "12"] {type:"raw"}
# Display name of the field to download; param_dict maps it to the short name
# that is passed to download_ecmwf_ens.
param = "500mb Height"
param_dict = {"Total Precipitation": "tp", "500mb Height": "gh", "Surface Pressure": "sp"}
#@markdown What time frame do you want to cluster on?
timeframe = "Day7" #@param ["Day5", "Day6", "Day7", "Day8", "Day9", "Day5-7", "Day6-8", "Day7-9", "Day5-8", "Day6-9"]
# Each timeframe maps to a (start_hour, end_hour) pair of forecast lead times.
timeranges = {
    "Day5": (120, 144),
    "Day6": (144, 168),
    "Day7": (168, 192),
    "Day8": (192, 216),
    "Day9": (216, 240),
    "Day5-7": (120, 192),
    "Day6-8": (144, 216),
    "Day7-9": (168, 240),
    "Day5-8": (120, 216),
    "Day6-9": (144, 240),
}
tr = timeranges[timeframe]

# Download every 6-hourly step from 0 to 240 h at 500 hPa for all perturbed
# members. The level is hard-coded here, so selecting a surface field in
# `param` would also need level_type/levels changed.
path = download_ecmwf_ens(
    param=param_dict[param],
    init_time=datetime.strptime(f"{init_date} {init_time}:00", "%Y-%m-%d %H:%M"),
    steps=list(range(0, 241, 6)),
    level_type="pl",
    levels=[500],
    members="all",
    target_dir=".",
    source="ecmwf"
)


Preparing transaction: ...working... done
Verifying transaction: ...working... done
Executing transaction: ...working... done
 ╭─────────────────────────────────────────────────╮
 │ I'm building Herbie's default config file.      │
 ╰╥────────────────────────────────────────────────╯
 👷🏻‍♂️
 ╭─────────────────────────────────────────────────╮
 │ You're ready to go.                             │
 │ You may edit the config file here:              │
 │ /root/.config/herbie/config.toml                │
 ╰╥────────────────────────────────────────────────╯
 👷🏻‍♂️


  _pyproj_global_context_initialize()


# **3. Perform Cluster Analysis & Plot EOF/Phase Space & Pick Representative Members**

In [None]:
# @title


def plot_mean_height_with_eof_shade_cartopy(
    z_meters,            # xarray.DataArray [number, step, latitude, longitude] in meters
    pca,                 # fitted PCA on stacked, √cosφ-weighted anomalies
    window=(120,168),    # (start_hr, end_hr)
    eof_indices=(0,1),   # which EOFs to draw
    projection=None,     # a cartopy CRS; default is Lambert Conformal for Alaska
    add_states=True,
    add_borders=True,
    coast_res="50m",
    n_contours=20,
    cmap="RdBu_r",
    symmetric_shade=True,
    scale="1sigma"       # "none" or "1sigma"
):
    """
    Draw one map per requested EOF: window-mean height as contours with the
    window-mean EOF pattern shaded underneath.

    Parameters
    ----------
    z_meters : xarray.DataArray
        Heights with ``step``, ``latitude``, ``longitude`` (and ``number``)
        coordinates. ``pca.components_`` is reshaped to
        (n_steps, n_lat, n_lon) of this array, so the PCA must have been fit
        on features stacked in that order from the same grid.
    pca : fitted PCA
        Must expose ``components_``, ``explained_variance_`` and
        ``explained_variance_ratio_``.
    window : (int, int)
        Inclusive lead-time window in hours.
    eof_indices : iterable of int
        Zero-based component indices to plot.
    projection : cartopy CRS, optional
        Defaults to a Lambert Conformal projection centred on 150W.
    add_states, add_borders : bool
        Toggle the corresponding cartopy features.
    coast_res : str
        Resolution passed to coastlines and features.
    n_contours : int
        Passed as ``levels`` to the height contours.
    cmap : str
        Colormap for the EOF shading.
    symmetric_shade : bool
        Use colour limits symmetric about zero.
    scale : str
        "1sigma" multiplies each EOF by sqrt(explained_variance_); any other
        value leaves the component unscaled.

    Returns
    -------
    tuple of matplotlib.figure.Figure
        One figure per entry in ``eof_indices``.

    Raises
    ------
    ValueError
        If no step of ``z_meters`` falls inside ``window``.
    """
    steps_arr = np.asarray(z_meters.step.values)
    lat = z_meters.latitude.values
    lon = z_meters.longitude.values
    n_steps, n_lat, n_lon = len(steps_arr), len(lat), len(lon)

    # Normalize lon to [-180, 180] if needed. The PCA components are laid out
    # in the ORIGINAL longitude order, so remember the permutation and apply it
    # to each EOF map below; otherwise shading and contours would be misaligned.
    lon_order = None
    if np.nanmax(lon) > 180:
        lon = ((lon + 180) % 360) - 180
        lon_order = np.argsort(lon, kind="stable")
        z_meters = z_meters.assign_coords(longitude=lon).sortby("longitude")

    if np.issubdtype(steps_arr.dtype, np.timedelta64):
        step_hours = (steps_arr / np.timedelta64(1, "h")).astype(int)
    else:
        step_hours = steps_arr.astype(int)

    s0, s1 = window
    sel = (step_hours >= s0) & (step_hours <= s1)
    if not np.any(sel):
        raise ValueError(f"No steps in {window}. Available: {step_hours.tolist()}")

    # mean height contours over window
    z_mean = z_meters.sel(step=steps_arr[sel]).mean(dim=("number","step"))  # [lat, lon]

    # weights for unweighting EOFs
    wlat = np.sqrt(np.clip(np.cos(np.deg2rad(lat)), 1e-8, None))   # avoid 0 at poles
    w2d  = wlat[:, None]

    # lon/lat mesh (after any longitude re-sorting)
    Lon, Lat = np.meshgrid(z_meters.longitude.values, z_meters.latitude.values)

    # default projection for Alaska
    if projection is None:
        projection = ccrs.LambertConformal(central_longitude=-150, standard_parallels=(55, 65))

    figs = []
    for k in eof_indices:
        # component -> 3D -> window mean (still in weighted space)
        comp = pca.components_[k].reshape(n_steps, n_lat, n_lon)
        eof_mean_w = comp[sel].mean(axis=0)

        # 1-sigma scaling (still weighted)
        if scale == "1sigma":
            eof_mean_w = eof_mean_w * np.sqrt(pca.explained_variance_[k])

        # unweight back to meters, then match the plotted longitude order
        eof_map = eof_mean_w / w2d
        if lon_order is not None:
            eof_map = eof_map[:, lon_order]

        # symmetric color limits if requested; skip them when the map is
        # all-NaN or identically zero so matplotlib picks its own range
        shade_kw = {}
        if symmetric_shade:
            vlim = np.nanmax(np.abs(eof_map))
            if np.isfinite(vlim) and vlim > 0:
                shade_kw = {"vmin": -vlim, "vmax": vlim}

        # explained variance (%)
        var_pct = 100.0 * float(pca.explained_variance_ratio_[k])

        fig = plt.figure(figsize=(9, 5))
        ax = plt.axes(projection=projection)

        # set extent to the data bbox
        ax.set_extent([Lon.min(), Lon.max(), Lat.min(), Lat.max()], crs=ccrs.PlateCarree())

        # shaded EOF anomaly
        pm = ax.pcolormesh(
            Lon, Lat, eof_map, shading="auto", cmap=cmap,
            transform=ccrs.PlateCarree(),
            **shade_kw
        )
        plt.colorbar(pm, ax=ax, orientation="vertical", pad=0.02, label="EOF anomaly (m)")

        # mean height contours
        cs = ax.contour(
            Lon, Lat, z_mean, levels=n_contours, colors="k", linewidths=0.8,
            transform=ccrs.PlateCarree()
        )
        ax.clabel(cs, fmt="%.0f", fontsize=8)

        # cartographic layers
        ax.coastlines(resolution=coast_res, linewidth=0.8)
        if add_borders:
            ax.add_feature(cfeature.BORDERS.with_scale(coast_res), linewidth=0.6)
        if add_states:
            ax.add_feature(cfeature.STATES.with_scale(coast_res), linewidth=0.4)

        # gridlines with labels on the left/bottom only
        gl = ax.gridlines(draw_labels=True, linewidth=0.4, color="gray", alpha=0.4, linestyle="--")
        gl.top_labels = False
        gl.right_labels = False

        ax.set_title(f"Mean 500 mb (contours) + EOF{k+1} (shaded), {s0}–{s1} h — Var: {var_pct:.1f}%")
        plt.tight_layout()
        figs.append(fig)

    return tuple(figs)


def plot_pc_phase_space(
    pcs,
    labels=None,
    member_ids=None,
    pca=None,                   # REQUIRED if standardize=True
    cluster_centers=None,       # kmeans.cluster_centers_, optional
    title="Ensemble Phase Space (PC1 vs PC2)",
    annotate=True,
    standardize=True,           # show scores in σ units
    figsize=(7,7)
):
    """
    Scatter ensemble members in the PC1/PC2 plane, coloured by cluster.

    pcs:             (n_members, >=2) PCA scores; only the first two columns are drawn
    labels:          per-member cluster labels; a single cluster 0 is assumed if None
    member_ids:      per-member annotations; 0..n-1 if None
    pca:             fitted PCA; needed for σ-scaling and for the % in the axis labels
    cluster_centers: centroids in the same raw units as pcs; drawn as stars
    standardize:     divide scores (and centroids) by sqrt(explained_variance_)

    Returns the matplotlib Figure. Raises ValueError when pcs has fewer than
    two columns, or when standardize=True without a usable pca.
    """
    scores = np.asarray(pcs)
    if scores.shape[1] < 2:
        raise ValueError("pcs must have at least 2 components.")

    has_centers = cluster_centers is not None

    if not standardize:
        # Raw PC units; percentages are shown only if a PCA was supplied
        centers_xy = np.asarray(cluster_centers) if has_centers else None
        if pca is not None and hasattr(pca, "explained_variance_ratio_"):
            x_label = f"PC1 ({100*pca.explained_variance_ratio_[0]:.1f}%)"
            y_label = f"PC2 ({100*pca.explained_variance_ratio_[1]:.1f}%)"
        else:
            x_label, y_label = "PC1", "PC2"
    else:
        if pca is None or not hasattr(pca, "explained_variance_"):
            raise ValueError("pca with explained_variance_ is required when standardize=True.")
        # One standard deviation per PC column = sqrt(eigenvalue)
        sigma = np.sqrt(pca.explained_variance_)[:scores.shape[1]]
        scores = scores / sigma
        centers_xy = np.asarray(cluster_centers) / sigma if has_centers else None
        x_label = f"PC1 (σ, {100*pca.explained_variance_ratio_[0]:.1f}%)"
        y_label = f"PC2 (σ, {100*pca.explained_variance_ratio_[1]:.1f}%)"

    x_vals = scores[:, 0]
    y_vals = scores[:, 1]
    n_members = len(x_vals)

    cluster_of = np.asarray(np.zeros(n_members, dtype=int) if labels is None else labels)
    ids = np.asarray(np.arange(n_members) if member_ids is None else member_ids)

    cluster_values = np.unique(cluster_of)
    palette = plt.cm.tab10(np.linspace(0, 1, max(10, len(cluster_values))))

    fig, ax = plt.subplots(figsize=figsize)

    # One scatter call per cluster so each gets its own legend entry
    for colour, value in zip(palette, cluster_values):
        in_cluster = cluster_of == value
        ax.scatter(
            x_vals[in_cluster], y_vals[in_cluster],
            s=40, alpha=0.8,
            label=f"Cluster {value} (n={in_cluster.sum()})",
            color=colour,
        )

    if annotate:
        for point_x, point_y, point_id in zip(x_vals, y_vals, ids):
            ax.annotate(str(point_id), (point_x, point_y), fontsize=8,
                        xytext=(3, 3), textcoords="offset points")

    if has_centers:
        ax.scatter(centers_xy[:, 0], centers_xy[:, 1],
                   marker="*", s=200, edgecolor="k", facecolor="none", label="Centroids")

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, linestyle="--", alpha=0.4)
    ax.legend(loc="best", frameon=True)
    plt.tight_layout()
    return fig

def pick_cluster_representatives(
    pcs,
    labels,
    member_ids,
    *,
    pca=None,                      # required if standardize=True
    standardize=True,              # match the phase-space plot (σ units)
    centers=None,                  # e.g., kmeans.cluster_centers_
    n_components=2,                # use PC1..PCn
):
    """
    Returns {cluster_label: representative_member_id} where the representative is
    the member closest (Euclidean) to that cluster's centroid in PC space.

    pcs:          (n_members, n_pcs) PCA scores (from pca.transform or fit_transform)
    labels:       (n_members,) cluster labels (ints)
    member_ids:   (n_members,) integer identifiers (e.g., ENS numbers)
    pca:          fitted PCA exposing explained_variance_ (needed if standardize=True)
    standardize:  divide scores and centers by sqrt(explained_variance_) first
    centers:      array-like (n_clusters, n_pcs) of cluster centers in raw PC units
                  (same as pcs), indexed by cluster label; if None, each cluster's
                  own mean is used as its centroid
    n_components: number of leading PCs used for the distance

    Raises ValueError if pcs has fewer than n_components columns, or if
    standardize=True and no usable pca is given.
    """
    pcs = np.asarray(pcs)
    labels = np.asarray(labels)
    member_ids = np.asarray(member_ids)

    if pcs.shape[1] < n_components:
        raise ValueError(f"pcs has only {pcs.shape[1]} components; need >= {n_components}")

    if standardize and (pca is None or not hasattr(pca, "explained_variance_")):
        raise ValueError("Provide fitted pca when standardize=True.")

    scores = pcs[:, :n_components]
    # centers may arrive as a nested list; convert like the other array inputs
    centroids = None if centers is None else np.asarray(centers)[:, :n_components]

    # Standardize to σ units if requested (divide each PC by sqrt(eigenvalue))
    if standardize:
        sd = np.sqrt(pca.explained_variance_)[:n_components]
        scores = scores / sd
        if centroids is not None:
            centroids = centroids / sd

    reps = {}
    for c in np.unique(labels):
        mask = labels == c          # never empty: c comes from labels itself
        cluster_scores = scores[mask]

        # Provided centers are looked up by label value; otherwise use the cluster mean
        if centroids is not None:
            centroid = centroids[c]
        else:
            centroid = cluster_scores.mean(axis=0)

        # Squared Euclidean distance gives the same argmin as the distance itself
        d2 = np.sum((cluster_scores - centroid) ** 2, axis=1)
        reps[int(c)] = int(member_ids[mask][int(np.argmin(d2))])

    return reps

def plot_cluster_composites_500hpa(
    z_meters,                 # xarray.DataArray [number, step, latitude, longitude] in meters
    labels,                   # array-like of cluster labels per member (aligned to member_ids)
    member_ids=None,          # array-like of z_meters.number matching labels; if None, assume same order as z_meters.number
    window=(120, 168),        # (start_hr, end_hr) time window in hours
    clusters_to_show=None,    # list/array of cluster labels to show; if None, first 4 sorted unique
    projection=None,          # cartopy CRS; default Lambert Conformal for AK
    coast_res="50m",
    cmap="RdBu_r",
    units="m",                # "m" or "dam" for contour and anomaly units
    n_contours=20,
    symmetric_shade=True,
    title="Cluster composites: mean 500 mb height (contours) + anomaly vs ensemble mean (shaded)"
):
    """
    Builds a 5-panel figure (2 x 3 grid, last cell blank):
      Panels 1-4: each cluster's mean height (contours) + anomaly vs ensemble mean (shaded)
      Panel 5: total ensemble mean (contours only)

    Means are taken over the selected members and over every step inside the
    inclusive ``window``. At most four clusters are drawn; cluster panels that
    are not needed are hidden. With ``symmetric_shade=True`` all cluster panels
    share one colour range, symmetric about zero and expressed in the plotted
    ``units``. Any ``units`` value other than "m" is treated as "dam".

    Returns the matplotlib Figure.

    Raises ValueError if no step lies inside ``window`` or if ``labels`` and
    the member identifiers differ in length.
    """

    # --- coords and basic prep
    lon = z_meters.longitude.values

    # normalize lon to [-180,180] if needed
    if np.nanmax(lon) > 180:
        lon = ((lon + 180) % 360) - 180
        z_meters = z_meters.assign_coords(longitude=lon).sortby("longitude")

    # step mask
    steps_arr = np.asarray(z_meters.step.values)
    if np.issubdtype(steps_arr.dtype, np.timedelta64):
        step_hours = (steps_arr / np.timedelta64(1, "h")).astype(int)
    else:
        step_hours = steps_arr.astype(int)
    s0, s1 = window
    sel = (step_hours >= s0) & (step_hours <= s1)
    if not np.any(sel):
        raise ValueError(f"No steps in window {window}; available: {step_hours.tolist()}")
    window_steps = steps_arr[sel]

    # member id alignment
    labels = np.asarray(labels)
    if member_ids is None:
        # assume labels align with z_meters.number order
        member_ids = z_meters["number"].values
        if labels.shape[0] != member_ids.shape[0]:
            raise ValueError("labels length does not match z_meters.number; provide member_ids explicitly.")
    else:
        member_ids = np.asarray(member_ids)
        if labels.shape[0] != member_ids.shape[0]:
            raise ValueError("labels and member_ids must be same length.")

    # limit the data to members listed in member_ids (in case some were dropped before)
    z_sub = z_meters.sel(number=member_ids)

    # ensemble mean over provided members + time window
    ens_mean = z_sub.sel(step=window_steps).mean(dim=("number", "step"))  # [lat, lon]

    # clusters to show (up to 4)
    if clusters_to_show is None:
        clusters_to_show = np.unique(labels)[:4]
    else:
        clusters_to_show = np.asarray(clusters_to_show)[:4]

    # unit conversion factor, applied BEFORE colour limits are derived so that
    # vmin/vmax are in the same units as the shaded field
    unit_factor = 1.0 if units == "m" else 0.1  # meters->dam
    unit_label = "m" if units == "m" else "dam"
    label_fmt = "%.0f" if units == "m" else "%.1f"

    # cluster means and anomalies (cluster - ensemble), already in plot units
    cluster_fields_plot = []
    diffs_plot = []
    counts = []
    for c in clusters_to_show:
        mask = labels == c
        counts.append(int(mask.sum()))
        if counts[-1] == 0:
            # Empty cluster; fill with NaNs
            cl_mean = xr.full_like(ens_mean, np.nan)
        else:
            cl_mean = z_sub.sel(number=member_ids[mask], step=window_steps).mean(dim=("number", "step"))
        cluster_fields_plot.append(cl_mean * unit_factor)
        diffs_plot.append((cl_mean - ens_mean) * unit_factor)
    ens_mean_plot = ens_mean * unit_factor

    # consistent symmetric color limits across cluster panels (plot units);
    # all-NaN panels are ignored, and a zero/undefined range falls back to auto
    vmax = None
    if symmetric_shade:
        peaks = [
            float(np.nanmax(np.abs(d.values)))
            for d in diffs_plot
            if np.isfinite(d.values).any()
        ]
        if peaks and max(peaks) > 0:
            vmax = max(peaks)
    shade_kw = {} if vmax is None else {"vmin": -vmax, "vmax": +vmax}

    # lon/lat mesh for pcolormesh/contour
    Lon, Lat = np.meshgrid(z_meters.longitude.values, z_meters.latitude.values)
    extent = [Lon.min(), Lon.max(), Lat.min(), Lat.max()]

    # projection default
    if projection is None:
        projection = ccrs.LambertConformal(central_longitude=-150, standard_parallels=(55, 65))

    # --- figure layout: 2 rows x 3 cols; last axis is a plain (non-map) axis that is switched off
    fig = plt.figure(figsize=(14, 8))
    axes = [
        fig.add_subplot(2, 3, i + 1, projection=projection) if i < 5 else fig.add_subplot(2, 3, i + 1)
        for i in range(6)
    ]

    # plot up to 4 cluster panels
    for i, (c, cl_mean, diff, n) in enumerate(zip(clusters_to_show, cluster_fields_plot, diffs_plot, counts)):
        ax = axes[i]
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        # shaded anomaly vs ensemble mean
        pm = ax.pcolormesh(Lon, Lat, diff, shading="auto", cmap=cmap,
                           transform=ccrs.PlateCarree(), **shade_kw)
        # mean height contours (cluster mean)
        cs = ax.contour(Lon, Lat, cl_mean, levels=n_contours, colors="k", linewidths=0.8,
                        transform=ccrs.PlateCarree())
        ax.clabel(cs, fmt=label_fmt, fontsize=8)

        # cartographic layers
        ax.coastlines(resolution=coast_res, linewidth=0.8)
        ax.add_feature(cfeature.BORDERS.with_scale(coast_res), linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale(coast_res), linewidth=0.4)

        ax.set_title(f"Cluster {c} (n={n}); shaded: Δ vs ens mean")

        # add a colorbar only once (right side)
        if i == 0:
            cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
            fig.colorbar(pm, cax=cax, label=f"Anomaly ({unit_label})")

    # hide cluster panels that were not needed (fewer than 4 clusters)
    for spare in axes[len(clusters_to_show):4]:
        spare.set_visible(False)

    # panel 5: ensemble mean contours only
    ax5 = axes[4]
    ax5.set_extent(extent, crs=ccrs.PlateCarree())
    cs5 = ax5.contour(Lon, Lat, ens_mean_plot, levels=n_contours, colors="k", linewidths=0.9,
                      transform=ccrs.PlateCarree())
    ax5.clabel(cs5, fmt=label_fmt, fontsize=8)
    ax5.coastlines(resolution=coast_res, linewidth=0.8)
    ax5.add_feature(cfeature.BORDERS.with_scale(coast_res), linewidth=0.6)
    ax5.add_feature(cfeature.STATES.with_scale(coast_res), linewidth=0.4)
    ax5.set_title(f"Ensemble mean {unit_label} ({s0}–{s1} h)")

    # panel 6: turn off
    axes[5].axis("off")

    fig.suptitle(title, y=0.98, fontsize=12)
    plt.tight_layout(rect=[0, 0, 0.9, 0.96])
    return fig



########################## Clustering & Plotting ##############################

# 1) Load
ds = xr.open_dataset(
    path,
    engine="cfgrib"
)

# Optional: normalize longitudes to -180..180 if dataset is 0..360
if ds.longitude.max() > 180:
    ds = ds.assign_coords(longitude=((ds.longitude + 180) % 360) - 180).sortby("longitude")

# 2) Subset time + bbox
lat_min, lat_max = 40, 75
lon_min, lon_max = -179, -125

subset = ds.sel(
    step=slice(np.timedelta64(tr[0], "h"), np.timedelta64(tr[1], "h")),
    latitude=slice(lat_max, lat_min),      # lat usually decreasing
    longitude=slice(lon_min, lon_max)
)

# 3) Heights in meters.
# "gh" is geopotential HEIGHT (gpm), i.e. already in meters, so it must not be
# divided by g. Only convert if the field's units say it is geopotential.
gh = subset["gh"]   # dims: number, step, latitude, longitude
gh_units = str(gh.attrs.get("units", "")).lower()
if "m**2" in gh_units or "m2" in gh_units:
    z_meters = gh / 9.80665
else:
    z_meters = gh

# 4) Latitude weights (sqrt(cos(lat))) over the latitude dimension
lat = z_meters["latitude"]
wlat = np.sqrt(np.clip(np.cos(np.deg2rad(lat)), 0, None))
weights = xr.DataArray(wlat, dims=["latitude"], coords={"latitude": lat})

# Apply weights (broadcast over step/number/longitude automatically)
weighted = z_meters * weights

# 5) Stack features AFTER weighting
features = weighted.stack(features=("step", "latitude", "longitude"))  # dims: number, features
X = features.values  # shape: [n_members, n_features]

# 6) Handle NaNs
# Identify members (rows) with any NaNs
all_member_ids = z_meters["number"].values
bad_members_mask = np.any(np.isnan(X), axis=1)

if np.any(bad_members_mask):
    print(f"Dropping members with missing data: {all_member_ids[bad_members_mask]}")
    # Keep only rows without NaNs
    X = X[~bad_members_mask, :]
else:
    print("All members have complete data.")

# member_ids is the single source of truth for which member each row of
# X / pcs / labels refers to; it must not be rebuilt from z_meters later.
member_ids = all_member_ids[~bad_members_mask]

# Weighted anomalies (subtract mean, don't scale by std)
X_anoms = StandardScaler(with_mean=True, with_std=False).fit_transform(X)

# 7) PCA to 2 components
pca = PCA(n_components=2)
pcs = pca.fit_transform(X_anoms)
print(f"Explained variance (2 PCs): {pca.explained_variance_ratio_.sum():.2%}")

# 8) K-means in PC space
n_clusters = 4
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
labels = kmeans.fit_predict(pcs)

# 9) Members per cluster (labels are aligned with the filtered member_ids)
clusters = {i: [] for i in range(n_clusters)}
for member, cid in zip(member_ids, labels):
    clusters[int(cid)].append(int(member))

for cid, members in clusters.items():
    print(f"Cluster {cid}: {members}")

# Plotting EOFs
fig1, fig2 = plot_mean_height_with_eof_shade_cartopy(
    z_meters, pca,
    window=tr,
    eof_indices=(0, 1),
    projection=ccrs.NorthPolarStereo(central_longitude=-150, true_scale_latitude=60),
    scale="1sigma"            # 1-σ amplitude patterns (meters)
)

# Plotting phase space
fig = plot_pc_phase_space(
    pcs=pcs,
    labels=labels,
    member_ids=member_ids,
    pca=pca,
    cluster_centers=getattr(kmeans, "cluster_centers_", None),
    title="ECMWF ENS 500 hPa — Phase Space (PC1 vs PC2)"
)

representatives = pick_cluster_representatives(
    pcs=pcs,
    labels=labels,
    member_ids=member_ids,
    pca=pca,
    standardize=True,                       # matches the standardized phase-space plot
    centers=getattr(kmeans, "cluster_centers_", None),
    n_components=2
)

print(representatives)

fig = plot_cluster_composites_500hpa(
    z_meters=z_meters,
    labels=labels,
    member_ids=member_ids,     # restricts the composites to the members that were clustered
    window=tr,
    units="m"                  # contours + anomalies in meters ("dam" for decameters)
)

# **4. Plot Representative Member**

In [None]:
import matplotlib as mpl
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.patches import Patch

def make_custom_cmaps(name, colors, bounds: list = None, N: int = None):
    """Register continuous and segmented colormaps built from ``colors``.

    Four colormaps are registered with matplotlib (overwriting any existing
    entries of the same name): ``name``, ``name + "2"`` and the reversed
    version of each. NaN values are drawn fully transparent.

    Parameters
    ----------
    name : str
        Registry name of the continuous colormap; the segmented one is
        registered as ``name + "2"``.
    colors : sequence
        Colors handed to ``LinearSegmentedColormap.from_list``.
    bounds : sequence of numbers, optional
        Bin edges. Lists, tuples and NumPy arrays are all accepted.
    N : int, optional
        Number of levels in the segmented colormap; defaults to ``len(colors)``.

    Returns
    -------
    tuple or None
        ``(Normalize(min(bounds), max(bounds)), BoundaryNorm(bounds, cmap.N))``
        when ``bounds`` is given, otherwise ``None``.
    """
    if N is None:
        N = len(colors)
    linear_cmap = mcolors.LinearSegmentedColormap.from_list(name, colors)
    segment_cmap = mcolors.LinearSegmentedColormap.from_list(name + "2", colors, N=N)

    for cm in (linear_cmap, segment_cmap):
        # When data is NaN, set color to transparent
        cm.set_bad("#ffffff00")
        mpl.colormaps.register(cmap=cm, force=True)
        mpl.colormaps.register(cmap=cm.reversed(), force=True)

    if bounds is None:
        return None

    # Built-in min/max work for plain sequences as well as NumPy arrays,
    # whereas bounds.min()/bounds.max() failed on the annotated `list` type.
    return (
        mcolors.Normalize(min(bounds), max(bounds)),
        mcolors.BoundaryNorm(bounds, linear_cmap.N),
    )


class NWSPrecipitation:
    """Colorbar properties for National Weather Service precipitation amounts.

    Also known as Qualitative Precipitation Forecast/Estimate (QPF/QPE).
    Defining the class registers the colormaps via ``make_custom_cmaps``.
    ``kwargs``/``cbar_kwargs`` pair the continuous colormap with a linear
    norm; ``kwargs2``/``cbar_kwargs2`` pair the same colormap with a
    BoundaryNorm on ``bounds`` and uniformly spaced colorbar ticks.
    """

    name = "nws.pcp"
    units = "in"
    variable = "Precipitation"

    # One color per interval between consecutive entries of `bounds`
    colors = np.array([
        "#ffffff", "#c7e9c0", "#a1d99b", "#74c476",
        "#31a353", "#006d2c", "#fffa8a", "#ffcc4f",
        "#fe8d3c", "#fc4e2a", "#d61a1c", "#ad0026",
        "#700026", "#3b0030", "#4c0073", "#ffdbff",
    ])

    # NWS bounds in inches
    bounds = np.array([
        0, 0.01, 0.1, 0.25, 0.5, 1, 1.5, 2, 3,
        4, 6, 8, 10, 15, 20, 30, 50,
    ])

    norm, norm2 = make_custom_cmaps(name, colors, bounds)
    cmap = plt.get_cmap(name)
    cmap2 = plt.get_cmap(name + "2")

    kwargs = {"cmap": cmap, "norm": norm}
    kwargs2 = {"cmap": cmap, "norm": norm2}

    cbar_kwargs = {"label": f"{variable} ({units})"}
    cbar_kwargs2 = {**cbar_kwargs, "spacing": "uniform", "ticks": bounds}

def mm_to_in(mm):
    """Convert a length in millimeters to inches (factor 0.0393701)."""
    inches_per_mm = 0.0393701
    return mm * inches_per_mm

def plot_msl_with_tp6h(
    ds,
    *,
    lat_min=50, lat_max=72,
    lon_min=-180, lon_max=-130,
    projection=None,
    coast_res="50m",
    tp_mm_max=4,            # unused; kept for backward compatibility
    n_contours=20,
    clip_negative=True       # clip tiny negative diffs to 0
):
    """
    Plots (and shows) one map per accumulation interval:
      - MSLP (hPa) as contours
      - interval precipitation (inches) as shading on the NWS precipitation scale

    Assumes ds has 'msl' in Pa and cumulative 'tp' in meters, at steps that
    include the step preceding the first one to be plotted. The first step in
    ds is NOT plotted; accumulations are labelled by the later ('upper') step
    of each pair.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain 'msl' and 'tp' on (step, latitude, longitude).
    lat_min, lat_max, lon_min, lon_max : float
        Bounding box; latitude is sliced from lat_max down to lat_min.
    projection : cartopy CRS, optional
        Defaults to a Lambert Conformal projection centred on 150W.
    coast_res : str
        Resolution for coastlines, borders and states.
    tp_mm_max : float
        Ignored. Colour limits come from ``NWSPrecipitation``.
    n_contours : int
        Passed as ``levels`` to the MSLP contours.
    clip_negative : bool
        Replace negative accumulations with 0.

    Returns
    -------
    None
        Each figure is displayed with ``plt.show()``.
    """
    # Normalize longitudes to [-180, 180] if needed
    if float(ds.longitude.max()) > 180:
        ds = ds.assign_coords(longitude=((ds.longitude + 180) % 360) - 180).sortby("longitude")

    # Subset bbox (lat usually decreasing in GRIB)
    ds_box = ds.sel(latitude=slice(lat_max, lat_min), longitude=slice(lon_min, lon_max))

    # Convert units
    msl_hpa = ds_box["msl"] / 100.0         # Pa -> hPa
    tp_cum  = ds_box["tp"]                  # meters, cumulative since T0

    # Interval accumulation: m -> mm -> in, aligned to the upper step (e.g., 168, 174, …)
    tp6 = mm_to_in(tp_cum.diff("step", label="upper") * 1000.0)
    if clip_negative:
        tp6 = xr.where(tp6 < 0, 0, tp6)

    # Mesh for plotting
    Lon, Lat = np.meshgrid(ds_box.longitude.values, ds_box.latitude.values)
    extent = [Lon.min(), Lon.max(), Lat.min(), Lat.max()]

    # Projection
    if projection is None:
        projection = ccrs.LambertConformal(central_longitude=-150, standard_parallels=(55, 65))

    # Warn (but carry on) if the steps are not 6 h apart, since the title says "6-h"
    step_is_timedelta = np.issubdtype(ds_box.step.dtype, np.timedelta64)
    if step_is_timedelta:
        dh = (ds_box.step.diff("step") / np.timedelta64(1, "h")).astype(float)
        if not np.allclose(dh, 6.0):
            print("⚠️  Step spacing is not uniformly 6 h; plotting interval accumulations as-is.")

    # Shading and colorbar settings are the same for every map
    shade_kwargs = NWSPrecipitation.kwargs2
    cbar_kwargs = NWSPrecipitation.cbar_kwargs2

    # Loop over *tp6* steps (skips the first original step by construction)
    for step_val in tp6.step.values:
        z = msl_hpa.sel(step=step_val)
        p = tp6.sel(step=step_val)

        plt.figure(figsize=(10, 6))
        ax = plt.axes(projection=projection)
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        # shaded interval precip (in)
        pm = ax.pcolormesh(Lon, Lat, p, transform=ccrs.PlateCarree(), **shade_kwargs)
        plt.colorbar(pm, ax=ax, pad=0.02, **cbar_kwargs)

        # MSLP contours (hPa)
        cs = ax.contour(Lon, Lat, z, levels=n_contours, colors="k", linewidths=0.8,
                        transform=ccrs.PlateCarree())
        ax.clabel(cs, fmt="%.0f", fontsize=8)

        # carto layers
        ax.coastlines(resolution=coast_res, linewidth=0.8)
        ax.add_feature(cfeature.BORDERS.with_scale(coast_res), linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale(coast_res), linewidth=0.4)
        gl = ax.gridlines(draw_labels=True, linewidth=0.4, color="gray", alpha=0.4, linestyle="--")
        gl.top_labels = False
        gl.right_labels = False

        # Title: prefer the valid time, fall back to the lead time
        if "valid_time" in ds_box:
            vt = ds_box.valid_time.sel(step=step_val).values  # numpy.datetime64
            tstr = np.datetime_as_string(np.asarray(vt, dtype="datetime64[ns]"), unit="h")
        elif step_is_timedelta:
            lead_h = int(step_val / np.timedelta64(1, "h"))
            tstr = f"T+{lead_h:02d} h"
        else:
            tstr = str(step_val)
        # Units match the shaded field and the colorbar (inches, not mm)
        ax.set_title(f"MSLP (hPa) contours + 6-h precip (in) shaded — {tstr}")

        plt.tight_layout()
        plt.show()



#@markdown Which cluster do you want to see?
cluster = 1 #@param [0,1,2,3]
#@markdown Which variable do you want to plot?
wxvar = "MSLP_Precip" #@param ["MSLP_Precip"]

#@markdown What time frame do you want to see?
timeframe = "Day7" #@param ["Day5", "Day6", "Day7", "Day8", "Day9", "Day5-7", "Day6-8", "Day7-9", "Day5-8", "Day6-9"]
# (start_hour, end_hour) forecast lead times for each selectable timeframe
timeranges = {
    "Day5": (120, 144),
    "Day6": (144, 168),
    "Day7": (168, 192),
    "Day8": (192, 216),
    "Day9": (216, 240),
    "Day5-7": (120, 192),
    "Day6-8": (144, 216),
    "Day7-9": (168, 240),
    "Day5-8": (120, 216),
    "Day6-9": (144, 240),
}
tr = timeranges[timeframe]
vardict = {
    "MSLP_Precip": "MSLP",
}

if wxvar == "MSLP_Precip":
    # MSLP + total precipitation for the selected cluster's representative member.
    # Start one step (6 h) before the window so the first 6-h accumulation can be
    # formed, and use tr[1] + 1 so the final step of the window (e.g. 192 h for
    # Day7) is included; range() excludes its end value.
    path1 = download_ecmwf_ens(
        param=['msl', 'tp'],
        init_time=datetime.strptime(f"{init_date} {init_time}:00", "%Y-%m-%d %H:%M"),
        steps=list(range(tr[0]-6, tr[1] + 1, 6)),
        level_type="sfc",
        members=[representatives[cluster]],
        target_dir=".",
        source="ecmwf"
    )
else:
    # Fail here with a clear message instead of a NameError on path1 below
    raise ValueError(f"Unsupported wxvar: {wxvar!r}")

#path1 = "ecmwf_ens_msl_tp_sfc_2025081100.grib2"
# 1) Load
ds = xr.open_dataset(
    path1,
    engine="cfgrib"
)
# ds is your Dataset with coords and variables: step, latitude, longitude, msl, tp
plot_msl_with_tp6h(
    ds,
    lat_min=50, lat_max=75,
    lon_min=-180, lon_max=-125,
    tp_mm_max=30,              # currently ignored by plot_msl_with_tp6h
    n_contours=20
)


