# In [None]:
# -*- coding: utf-8 -*-
import os
import math
import string
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from statsmodels.tsa.seasonal import STL

# --- Mann–Kendall trend test: Hamed–Rao variant when available, SciPy-based fallback otherwise ---
try:
    import pymannkendall as mk

    def mk_hr_test(y):
        """Return ``(slope, p)`` from pymannkendall's Hamed–Rao modified test."""
        outcome = mk.hamed_rao_modification_test(y)
        return outcome.slope, outcome.p

    MK_BACKEND = "pymannkendall (Hamed–Rao)"
except Exception:
    # Any failure while setting up pymannkendall lands on the SciPy-based path.
    from scipy.stats import kendalltau

    def theil_sen_slope(y):
        """Return the NaN-ignoring median of all pairwise slopes of ``y``.

        Observations are taken to sit at positions 0, 1, 2, ...; with fewer
        than two values there are no pairs and NaN is returned.
        """
        values = np.asarray(y, dtype=float)
        count = len(values)
        if count < 2:
            return np.nan
        # Every (earlier, later) position pair above the diagonal.
        earlier, later = np.triu_indices(count, k=1)
        pair_slopes = (values[later] - values[earlier]) / (later - earlier)
        return np.nanmedian(pair_slopes)

    def mk_hr_test(y):
        """Return ``(slope, p)``: Theil–Sen slope and Kendall-tau p-value against time order."""
        positions = np.arange(len(y))
        _tau, p_value = kendalltau(positions, y, nan_policy="omit")
        return theil_sen_slope(y), p_value

    MK_BACKEND = "fallback (kendalltau + Theil–Sen)"

# ---------------- CONFIG ----------------
# Workbook to read; main() loads every sheet of it as one station.
DATA_FILE = "Transfomed indices.xlsx"
# Default `seasonal` argument of stl_trend, which forwards it to STL.
SEASONAL_WINDOW = 19
# Root folder for all generated plots and tables.
OUT_DIR = "plots_rain"

# Indicators in panel order, each paired with the unit used in titles and y-labels.
RAIN_INDICS = [
    ("R10mm", "days"),    # panel a)
    ("R20mm", "days"),    # panel b)
    ("R30mm", "days"),    # panel c)
    ("Rx1day", "mm"),     # panel d)
    ("Rx5day", "mm"),     # panel e)
    ("PRCPTOT", "mm"),    # panel f)
    ("CDD", "days"),      # panel g)
    ("CWD", "days"),      # panel h)
]

# Make sure the output tree exists before anything is written into it.
for _out_folder in (
    OUT_DIR,
    os.path.join(OUT_DIR, "stations"),
    os.path.join(OUT_DIR, "bangladesh"),
):
    os.makedirs(_out_folder, exist_ok=True)

# --------------- HELPERS ----------------
def read_station_sheet(xls_path, sheet_name):
    """Load one station sheet as a date-indexed frame with numeric indicator columns.

    Args:
        xls_path: Workbook location, or any other input ``pd.read_excel`` accepts.
        sheet_name: Name of the sheet to read.

    Returns:
        DataFrame sorted by an index named "date" that holds the first day of
        each row's Year/Month. Rows whose Year/Month cannot form a date are
        dropped (with a printed warning). Columns named in RAIN_INDICS are
        coerced to numeric; unparseable cells become NaN.

    Raises:
        ValueError: If the sheet has no "Year" or no "Month" column
            (matched case-insensitively, ignoring surrounding whitespace).
    """
    df = pd.read_excel(xls_path, sheet_name=sheet_name)

    def find_column(wanted):
        # Header match is case-insensitive and tolerant of stray spaces.
        return next((c for c in df.columns if str(c).strip().lower() == wanted), None)

    ycol = find_column("year")
    mcol = find_column("month")
    if ycol is None or mcol is None:
        raise ValueError(f"{sheet_name}: could not find Year/Month columns")

    dt = pd.to_datetime(dict(year=df[ycol], month=df[mcol], day=1), errors="coerce")

    # Coercion turns bad Year/Month cells into NaT. A NaT label is not a usable
    # time stamp, and several of them make the index non-unique, which breaks
    # later reindexing of this frame. Drop those rows instead of indexing by them.
    valid = dt.notna().to_numpy()
    if not valid.all():
        print(f"[WARN] {sheet_name}: dropped {int((~valid).sum())} row(s) with invalid Year/Month")
        df = df.loc[valid]
        dt = dt[valid]

    df = df.set_index(dt).sort_index()
    df.index.name = "date"

    # Coerce to numeric
    for name, _unit in RAIN_INDICS:
        if name in df.columns:
            df[name] = pd.to_numeric(df[name], errors="coerce")
    return df

def stl_trend(series, seasonal=SEASONAL_WINDOW, period=None):
    """Return the robust STL trend component of ``series`` on the input's index.

    Args:
        series: pandas Series or other 1-D array-like of observations.
        seasonal: Seasonal smoother length passed to STL. It also sets the
            minimum number of non-missing observations required,
            ``max(7, 2 * seasonal + 1)``.
        period: Optional seasonal period passed to STL. ``None`` (the default)
            leaves period detection to STL, exactly as before this argument
            existed; pass it explicitly when the index carries no usable
            frequency (e.g. array-like input or a series with gaps).

    Returns:
        Float Series aligned to the input's index. Positions with missing
        input are NaN; the whole result is NaN when there are too few
        observations to decompose.
    """
    # Convert once and use the converted index everywhere, so array-like
    # input (which has no `.index`) is handled as well as a Series.
    full = pd.Series(series)
    ser = full.dropna()
    if ser.shape[0] < max(7, seasonal * 2 + 1):
        return pd.Series(index=full.index, dtype=float)
    res = STL(ser, period=period, seasonal=seasonal, robust=True).fit()
    return res.trend.reindex(full.index)

def fit_line_from_slope(trend, slope):
    """Build a straight line with the given slope through the first valid trend value.

    The line advances by ``slope`` per non-missing trend point and is returned
    on ``trend``'s full index, NaN wherever ``trend`` itself is missing. An
    all-missing trend or a NaN slope gives an all-NaN result.
    """
    observed = trend.dropna()
    if observed.empty or np.isnan(slope):
        return pd.Series(index=trend.index, dtype=float)

    steps = np.arange(len(observed), dtype=float)
    start = observed.iloc[0]
    line = pd.Series(start + slope * steps, index=observed.index)
    return line.reindex(trend.index)

def format_p(p):
    """Format a p-value for a plot annotation.

    Returns "< 0.05" for values below 0.05, "= NA" when ``p`` is None or NaN,
    and "= <p rounded to three decimals>" otherwise.
    """
    # NaN is the only value that is not equal to itself. None has to be caught
    # here too: it cannot be rendered with the ".3f" format below.
    if p is None or p != p:
        return "= NA"
    if p < 0.05:
        return "< 0.05"
    return f"= {p:.3f}"

def panel_label(i: int) -> str:
    """Map a 0-based panel index to its letter tag: a), b), ..., z), aa), ab), ..."""
    remaining = i + 1  # switch to 1-based so the letters work like spreadsheet columns
    chars = []
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        chars.append(string.ascii_lowercase[offset])
    # Letters were collected least-significant first.
    return "".join(reversed(chars)) + ")"

def make_8panel_figure(title_name, df, save_base=None):
    """Plot the STL trend and its Mann–Kendall slope line for each available indicator.

    Args:
        title_name: Name shown in the figure's super-title.
        df: Frame whose columns are looked up by the names in RAIN_INDICS;
            indicators missing from ``df`` are skipped.
        save_base: Optional output path without extension. When given, the
            figure is written to ``<save_base>.png`` and ``<save_base>.pdf``.

    Returns:
        The matplotlib Figure, or None when ``df`` holds none of the
        indicators. The figure is returned open; closing it is up to the caller.
    """
    # keep only the requested indicators, in their configured order
    avail = [(ind, unit) for ind, unit in RAIN_INDICS if ind in df.columns]
    if not avail:
        return None

    n = len(avail)
    n_cols = 4
    # Two rows hold the standard eight indicators; grow the grid rather than
    # run past the last axis if RAIN_INDICS ever lists more.
    n_rows = max(2, math.ceil(n / n_cols))
    fig_w = 4.8 * n_cols
    fig_h = 3.6 * n_rows

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h), sharex=False)
    axes = np.atleast_1d(axes).ravel()

    for i, ((ind, unit), ax) in enumerate(zip(avail, axes)):
        trend = stl_trend(df[ind])
        y = trend.dropna().values
        # With fewer than five trend points the test is skipped and NaN is annotated.
        slope, p = (np.nan, np.nan) if len(y) < 5 else mk_hr_test(y)
        yhat = fit_line_from_slope(trend, slope)

        ax.plot(trend.index, trend, color='blue', linewidth=1.8)
        ax.plot(yhat.index, yhat, color='red', linewidth=1.6)

        label = f"slope = {slope:.4g}, p {format_p(p)}"
        ax.text(0.98, 0.92, label, transform=ax.transAxes, ha="right", va="top",
                fontsize=9, bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

        ax.set_title(f"{panel_label(i)} {ind} ({unit})", loc="left", fontsize=10, pad=6)
        ax.set_ylabel(f"{ind} ({unit})")

    # turn off any unused panels (in case some columns are missing)
    for spare_ax in axes[n:]:
        spare_ax.axis('off')

    fig.suptitle(f"{title_name} — STL trend & Mann–Kendall ({MK_BACKEND})", fontsize=14, y=0.995)
    fig.tight_layout(rect=[0, 0.00, 1, 0.97])

    if save_base:
        png = f"{save_base}.png"
        pdf = f"{save_base}.pdf"
        fig.savefig(png, dpi=200, bbox_inches="tight")
        with PdfPages(pdf) as pdfw:
            pdfw.savefig(fig, bbox_inches="tight")
        print(f"Saved: {png}\nSaved: {pdf}")
    return fig

def aggregate_bangladesh(station_frames):
    """Average each indicator across stations on the union of their dates.

    Args:
        station_frames: Mapping of station name to that station's
            date-indexed DataFrame.

    Returns:
        DataFrame indexed by the sorted union of all station indexes, with one
        column per RAIN_INDICS indicator found in at least one station. Each
        value is the mean over the stations that have data for that date
        (missing values are skipped). An empty mapping gives an empty frame.
    """
    # Nothing to combine: return an empty frame instead of failing on a
    # missing index further down.
    if not station_frames:
        return pd.DataFrame(index=pd.DatetimeIndex([], name="date"))

    frames = list(station_frames.values())

    # union of all dates
    all_idx = frames[0].index
    for df in frames[1:]:
        all_idx = all_idx.union(df.index)
    all_idx = all_idx.sort_values()

    nat = pd.DataFrame(index=all_idx)
    for ind, _unit in RAIN_INDICS:
        # One column per station that reports this indicator, aligned to the common dates.
        cols = [df[ind].reindex(all_idx) for df in frames if ind in df.columns]
        if cols:
            nat[ind] = pd.concat(cols, axis=1).mean(axis=1, skipna=True)
    return nat

# ------------------ MAIN ------------------
def main(close_figures=True):
    """Read every station sheet, plot per-station and national figures, and save the composite.

    Sheets that fail to load are skipped with a warning. If no sheet can be
    read, nothing is plotted or written.

    Args:
        close_figures: Close each figure once it has been saved so figures do
            not pile up in memory across stations. Pass False to leave them
            open (e.g. for inline display in a notebook).
    """
    print(f"Reading: {DATA_FILE}")

    stations = {}
    # Open the workbook once, read every sheet through the same handle, and
    # let the context manager close it afterwards.
    with pd.ExcelFile(DATA_FILE) as xls:
        sheets = xls.sheet_names
        print(f"Stations: {', '.join(sheets)}")
        for sh in sheets:
            try:
                stations[sh] = read_station_sheet(xls, sh)
            except Exception as e:
                print(f"[WARN] Skipping '{sh}': {e}")

    if not stations:
        print("[WARN] No station sheets could be read; nothing to plot.")
        return

    # Per-station 8-panel figures
    for st, df in stations.items():
        base = os.path.join(OUT_DIR, "stations", f"{st}_rain_8panels")
        fig = make_8panel_figure(st, df, save_base=base)
        if close_figures and fig is not None:
            plt.close(fig)

    # Bangladesh composite (mean across stations)
    bd = aggregate_bangladesh(stations)
    bd_base = os.path.join(OUT_DIR, "bangladesh", "Bangladesh_rain_8panels")
    fig = make_8panel_figure("Bangladesh (mean across stations)", bd, save_base=bd_base)
    if close_figures and fig is not None:
        plt.close(fig)

    # Save composite time series too
    out_csv = os.path.join(OUT_DIR, "bangladesh", "Bangladesh_rain_timeseries.csv")
    bd.to_csv(out_csv)
    print(f"Saved national composite data: {out_csv}")


if __name__ == "__main__":
    main()
