In [None]:
# Import required packages
import numpy as np
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os
import glob
import rioxarray
from rasterstats import zonal_stats
import regionmask

In [None]:
# === FILE PATHS ===
forecast_dir = "/home/u/jamie.towner/jordan_training/data/forecast_data/converted"
shp_path = "/home/u/jamie.towner/jordan_training/data/shapefile/geoBoundaries-JOR-ADM1.shp"
# glob.glob returns matches in arbitrary, filesystem-dependent order; sort them so
# every run processes (and plots) the forecast files in the same, reproducible order.
nc_files = sorted(glob.glob(os.path.join(forecast_dir, "*.nc")))
nc_files

In [None]:
# === READ SHAPEFILE ONCE ===
# Load the province (ADM1) boundary polygons from `shp_path`; later cells reuse `gdf`
# for map outlines and filter it by its "shapeName" column to pick a province.
gdf = gpd.read_file(filename=shp_path)
gdf

In [None]:
# === LOOP THROUGH FORECAST FILES ===
# For every forecast file, map the ensemble-mean precipitation of one lead month.
lead_month_index = 2  # Remember that Python begins indexing at 0 (2 = lead month 3)

for file in nc_files:
    print(f"\nReading: {file}")

    # Context manager releases the file handle, including when a file is skipped.
    with xr.open_dataset(file) as ds:
        print("Variables:", list(ds.data_vars))
        print("Dimensions:", ds.dims)

        # Choose variable (adjust if needed)
        if "tp" in ds.data_vars:
            var_name = "tp"
        elif "tprate_mm" in ds.data_vars:
            var_name = "tprate_mm"
        else:
            print("Precipitation variable not found.")
            continue

        # Check dimensions before selecting the lead month
        if "forecastMonth" not in ds.dims or "number" not in ds.dims:
            print("Missing expected dimensions.")
            continue
        if ds.sizes["forecastMonth"] <= lead_month_index:
            print(f"Lead month index {lead_month_index} out of range for {os.path.basename(file)}.")
            continue

        # Select lead month and compute ensemble mean; .load() pulls the result
        # into memory before the file is closed.
        ensemble_mean = (
            ds[var_name]
            .isel(forecastMonth=lead_month_index)
            .mean(dim="number")
            .load()
        )

    # Plot
    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    ensemble_mean.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap='Blues',
        cbar_kwargs={'label': '(mm/month)'},
    )

    gdf.boundary.plot(ax=ax, edgecolor='black', linewidth=1)
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.COASTLINE)
    ax.set_title(f"Ensemble Mean Precipitation - Lead Time {lead_month_index + 1}\n{os.path.basename(file)}")
    ax.set_extent([34, 40, 29, 34])  # Jordan extent

    plt.show()

In [None]:
# === USER INPUTS ===
lead_month_index = 3  # Define lead time month, 0 = lead month 1
var_name = "tp"
output_dir = "/home/u/jamie.towner/jordan_training/data/forecast_data/converted/temp"  # To store temporary files
os.makedirs(output_dir, exist_ok=True)
# Single scratch GeoTIFF, overwritten for every forecast file.
raster_path = os.path.join(output_dir, "tmp_tp.tif")

for file in nc_files:
    print(f"Processing {file}")

    # Context manager closes the file even when it is skipped with `continue`.
    with xr.open_dataset(file) as ds:
        if var_name not in ds:
            print(f"Variable '{var_name}' not found in {file}")
            continue

        if "forecastMonth" not in ds.dims or ds.sizes["forecastMonth"] <= lead_month_index:
            print(f"Unexpected forecastMonth dimension in {file}")
            continue

        # .mean(dim="number") raises if the ensemble dimension is absent; skip instead.
        if "number" not in ds[var_name].dims:
            print(f"Ensemble dimension 'number' not found in {file}")
            continue

        # Take ensemble mean and select a lead month; load before the file closes
        tp = ds[var_name].isel(forecastMonth=lead_month_index).mean(dim="number").load()

    # Assign CRS and save as GeoTIFF so rasterstats can read it from disk
    tp = tp.rio.write_crs("EPSG:4326")
    tp.rio.to_raster(raster_path)

    # Compute zonal stats: mean per province
    stats = zonal_stats(shp_path, raster_path, stats=["mean"], geojson_out=True, all_touched=True, nodata=np.nan)

    # Convert to GeoDataFrame
    gdf_stats = gpd.GeoDataFrame.from_features(stats)
    gdf_stats["mean_tp"] = gdf_stats["mean"]

    # Plot result
    fig, ax = plt.subplots(figsize=(10, 8))
    gdf_stats.plot(
        column="mean_tp",
        cmap="Blues",
        edgecolor="black",
        legend=True,
        legend_kwds={"label": "(mm/month)"},
        ax=ax
    )

    ax.set_title(f"Mean Precipitation - Lead Month {lead_month_index + 1}\n{os.path.basename(file)}")
    ax.set_axis_off()
    plt.show()

In [None]:
province_name = "Karak"  # Choose a province of your choice.
gdf_province = gdf[gdf["shapeName"] == province_name]

if gdf_province.empty:
    raise ValueError(f"Province '{province_name}' not found in shapefile.")

for file_path in nc_files:
    print(f"Processing file: {os.path.basename(file_path)}")

    # Load the rainfall variable into memory, then close the file right away.
    with xr.open_dataset(file_path) as ds:
        tp = ds["tp"].load()  # rainfall variable (mm/month)
    tp = tp.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude")
    tp = tp.rio.write_crs("EPSG:4326")

    lead_times = tp["forecastMonth"].values
    ensemble_members = tp["number"].values

    # Every (member, lead) slice shares the same grid, so the affine transform
    # is loop-invariant: compute it once from the first slice.
    affine = tp.sel(number=ensemble_members[0], forecastMonth=lead_times[0]).squeeze().rio.transform()

    # mean_tp[i][j] = province-mean rainfall of member i at lead time j
    mean_tp = []
    for member in ensemble_members:
        member_means = []
        for lead in lead_times:
            tp_slice = tp.sel(number=member, forecastMonth=lead).squeeze()

            stats = zonal_stats(
                gdf_province,
                tp_slice.values,
                affine=affine,
                stats=["mean"],
                nodata=np.nan
            )
            mean_val = stats[0]["mean"]
            # zonal_stats reports None when no valid pixel overlaps the province;
            # store NaN so the line simply has a gap at that lead time.
            member_means.append(np.nan if mean_val is None else mean_val)
        mean_tp.append(member_means)

    # Plotting: one line per ensemble member
    fig, ax = plt.subplots(figsize=(10, 6))
    for member, member_means in zip(ensemble_members, mean_tp):
        ax.plot(lead_times, member_means, label=f"Member {member}", alpha=0.7)

    ax.set_xlabel("Forecast Month (Lead Time)")
    ax.set_ylabel("Precipitation (mm/month)")
    ax.set_title(f"Precipitation Forecast for {province_name}\nFile: {os.path.basename(file_path)}")
    ax.grid(True)
    fig.tight_layout()
    plt.show()


In [None]:
province_name = "Karak"  # Choose a province of your choice
gdf_province = gdf[gdf["shapeName"] == province_name]

if gdf_province.empty:
    raise ValueError(f"Province '{province_name}' not found in shapefile.")

for file_path in nc_files:
    print(f"Processing file: {os.path.basename(file_path)}")

    # Load the rainfall variable into memory, then close the file right away.
    with xr.open_dataset(file_path) as ds:
        tp = ds["tp"].load()  # rainfall variable (mm/month)
    tp = tp.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude")
    tp = tp.rio.write_crs("EPSG:4326")

    lead_times = tp["forecastMonth"].values
    ensemble_members = tp["number"].values

    # The grid is identical for all slices, so compute the affine transform once.
    affine = tp.sel(number=ensemble_members[0], forecastMonth=lead_times[0]).squeeze().rio.transform()

    # data_for_boxplot[j] = province-mean rainfall of every member at lead time j
    data_for_boxplot = []

    for lead in lead_times:
        # Collect all members' means for this lead time
        lead_values = []
        for member in ensemble_members:
            tp_slice = tp.sel(number=member, forecastMonth=lead).squeeze()

            stats = zonal_stats(
                gdf_province,
                tp_slice.values,
                affine=affine,
                stats=["mean"],
                nodata=np.nan
            )
            mean_val = stats[0]["mean"]
            # zonal_stats returns None when no valid pixel overlaps the province;
            # boxplot cannot compute percentiles over None (and NaN gives an empty
            # box), so only valid values are kept.
            if mean_val is not None and not np.isnan(mean_val):
                lead_values.append(mean_val)
        data_for_boxplot.append(lead_values)

    # Plot boxplot: spread of the ensemble at each lead time
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.boxplot(data_for_boxplot, positions=lead_times, widths=0.6, patch_artist=True)

    ax.set_xlabel("Forecast Month (Lead Time)")
    ax.set_ylabel("Precipitation (mm/month)")
    ax.set_title(f"Precipitation Forecast for {province_name}\nFile: {os.path.basename(file_path)}")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
# Threshold (mm/month) - set your desired value here
rainfall_threshold = 10.0

# === Read province shapefile and select province ===
gdf = gpd.read_file(shp_path)
province_name = "Karak"  # Choose a province of your choice
gdf_province = gdf[gdf["shapeName"] == province_name]

if gdf_province.empty:
    raise ValueError(f"Province '{province_name}' not found in shapefile.")

for file_path in nc_files:
    print(f"Processing file: {os.path.basename(file_path)}")

    # Load the rainfall variable into memory, then close the file right away.
    with xr.open_dataset(file_path) as ds:
        tp = ds["tp"].load()  # rainfall variable (mm/month)
    tp = tp.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude")
    tp = tp.rio.write_crs("EPSG:4326")

    lead_times = tp["forecastMonth"].values
    ensemble_members = tp["number"].values

    # The grid is identical for all slices, so compute the affine transform once.
    affine = tp.sel(number=ensemble_members[0], forecastMonth=lead_times[0]).squeeze().rio.transform()

    exceedance_probs = []

    for lead in lead_times:
        # Province-mean rainfall of every member that has at least one valid pixel
        valid_means = []

        for member in ensemble_members:
            tp_slice = tp.sel(number=member, forecastMonth=lead).squeeze()

            # Calculate spatial mean over the province
            stats = zonal_stats(
                gdf_province,
                tp_slice.values,
                affine=affine,
                stats=["mean"],
                nodata=np.nan
            )
            mean_val = stats[0]["mean"]
            if mean_val is not None:  # None: no valid pixel inside the province
                valid_means.append(mean_val)

        # Probability of EXCEEDANCE = fraction of members strictly above the
        # threshold (the original `<` counted members below it). Members without
        # a valid mean are excluded from the denominator instead of being counted
        # as "not exceeding"; NaN when no member has a valid mean.
        exceed_count = sum(val > rainfall_threshold for val in valid_means)
        prob = exceed_count / len(valid_means) if valid_means else np.nan
        exceedance_probs.append(prob)

    # Plot probability of exceedance
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(lead_times, exceedance_probs, marker='o')
    ax.set_xlabel("Forecast Month (Lead Time)")
    ax.set_ylabel(f"Probability of Exceeding {rainfall_threshold} mm/month")
    ax.set_title(f"Probability of Precipitation Exceedance in {province_name}\n{os.path.basename(file_path)}")
    ax.set_ylim(0, 1)
    ax.grid(True)
    fig.tight_layout()
    plt.show()