In [5]:
import pandas as pd
import numpy as np
from scipy.stats import gamma, norm
import matplotlib.pyplot as plt
import os

# ----------------------------------------
# Input / output locations
# ----------------------------------------
file, outpath, plot_dir = (
    r"D:\climate change\monthly_averages.xlsx",
    r"D:\climate change\SPI_12_two_basePeriods.xlsx",
    r"D:\climate change\SPI_Plots",
)

# The per-station plot folder must exist before any figure is saved into it.
os.makedirs(plot_dir, exist_ok=True)

# ----------------------------------------
# Load the wide monthly table. The first two sheet rows (station Lon/Lat)
# are skipped; the first remaining column holds the month timestamps.
# ----------------------------------------
df = pd.read_excel(file, skiprows=2)
df = df.rename(columns={df.columns[0]: "Date"})
df["Date"] = pd.to_datetime(df["Date"])

# Reshape to one row per (Date, Station), then tag each row with year/month.
# NOTE(review): every exact 0 is turned into NaN here, i.e. a 0 is treated as
# missing data rather than a dry month; a later rolling sum over a window that
# contains such a month will be NaN. Confirm 0 really is a missing-data code.
df_long = (
    df.melt(id_vars="Date", var_name="Station", value_name="Rainfall")
    .replace(0, np.nan)
)
df_long["Year"] = df_long["Date"].dt.year
df_long["Month"] = df_long["Date"].dt.month

# ----------------------------------------
# Split into the two base (reference) periods; both bounds are inclusive.
# ----------------------------------------
df_b1 = df_long[df_long["Year"].between(1975, 1997)]
df_b2 = df_long[df_long["Year"].between(1998, 2020)]

# ----------------------------------------
# SPI Function
# ----------------------------------------
def _gamma_spi(values):
    """Convert one calendar month's accumulated precipitation samples to SPI.

    A gamma distribution with location fixed at 0 is fitted to the strictly
    positive samples. Zeros use the mixed distribution
    H(x) = q + (1 - q) * G(x), where q is the fraction of non-positive
    samples. This is needed because ``gamma.fit(..., floc=0)`` rejects data
    at the fixed location. When there are no zeros, q is 0 and H is exactly
    the fitted gamma CDF. H is then mapped to a standard-normal quantile.

    Parameters
    ----------
    values : array-like of float
        Accumulated precipitation for one station and one calendar month,
        with no NaNs.

    Returns
    -------
    numpy.ndarray
        SPI values in the same order as ``values``. All values are NaN when
        fewer than two positive samples are available to fit the gamma.
    """
    vals = np.asarray(values, dtype=float)
    positive = vals > 0
    n_positive = int(positive.sum())
    if n_positive < 2:
        return np.full(vals.shape, np.nan)

    q_zero = 1.0 - n_positive / vals.size
    shape, loc, scale = gamma.fit(vals[positive], floc=0)
    cdf = q_zero + (1.0 - q_zero) * gamma.cdf(vals, shape, loc, scale)
    return norm.ppf(cdf)


def compute_spi(data, window, min_count=9):
    """Compute the Standardized Precipitation Index (SPI) for every station.

    For each station, ``Rainfall`` is sorted by ``Date`` and summed over a
    rolling ``window`` of rows. Rows whose window contains a NaN (or is
    incomplete) are dropped. The accumulations are then fitted and
    standardized separately for each calendar ``Month``.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format table with at least ``Station``, ``Date``, ``Rainfall``
        and ``Month`` columns. The rolling window counts rows, not calendar
        months, so each station presumably has one row per month with no
        gaps. TODO confirm against the input sheet.
    window : int
        Accumulation period in rows/months (e.g. 12 for SPI-12).
    min_count : int, default 9
        Minimum number of accumulations a calendar month needs before a
        distribution is fitted; months with fewer samples get NaN.

    Returns
    -------
    pandas.DataFrame
        The rows with a complete window, with added ``Rain_{window}`` and
        ``SPI_{window}`` columns. The input index labels are preserved.
    """
    rain_col = f"Rain_{window}"
    spi_col = f"SPI_{window}"
    out = []
    for _station, group in data.groupby("Station"):
        d = group.sort_values("Date")
        d = d.assign(**{rain_col: d["Rainfall"].rolling(window).sum()})
        d = d.dropna(subset=[rain_col])

        # d is in date order, not month order, so each month's results are
        # written back through a boolean mask to keep every value on its row.
        spi = np.full(len(d), np.nan)
        months = d["Month"].to_numpy()
        accum = d[rain_col].to_numpy(dtype=float)
        for month in np.unique(months):
            in_month = months == month
            if in_month.sum() >= min_count:
                spi[in_month] = _gamma_spi(accum[in_month])

        out.append(d.assign(**{spi_col: spi}))
    return pd.concat(out)

# Calculate SPI-12 separately for each base period, so each period's values
# are fitted only against that period's own data.
spi12_1 = compute_spi(df_b1, 12)
spi12_1["Base"] = "1975-1997"
spi12_2 = compute_spi(df_b2, 12)
spi12_2["Base"] = "1998-2020"

spi = pd.concat([spi12_1, spi12_2])
spi = spi[["Date", "Year", "Month", "Station", "Rainfall", "SPI_12", "Base"]]

spi.to_excel(outpath, index=False)
print(f"✅ SPI saved at: {outpath}")

# ----------------------------------------------------------
# 1️⃣ SPI Time-Series Plot (per station)
# ----------------------------------------------------------
stations = spi["Station"].unique()

for st in stations:
    # Sort so the line is drawn chronologically regardless of row order.
    sub = spi[spi["Station"] == st].sort_values("Date")
    plt.figure(figsize=(12, 4))
    plt.plot(sub["Date"], sub["SPI_12"], label=f"{st} SPI-12")
    plt.axhline(0, linestyle="--")
    plt.title(f"SPI-12 Trend - {st}")
    plt.xlabel("Year")
    plt.ylabel("SPI-12")
    # The series is labelled, so actually display the legend.
    plt.legend()
    plt.tight_layout()
    # NOTE(review): station names are used verbatim in the file name; a
    # header containing '/', ':' or similar would break savefig — confirm
    # the sheet's station headers are filename-safe.
    plt.savefig(os.path.join(plot_dir, f"{st}_SPI_12.png"))
    plt.close()

print("📈 SPI time-series plots saved!")

# ----------------------------------------------------------
# 2️⃣ Annual Drought Heatmap
# ----------------------------------------------------------
# pivot_table aggregates with the mean by default, so each cell is the
# station's average SPI-12 over that year's months.
pivot_spi = spi.pivot_table(index="Station", columns="Year", values="SPI_12")

# SPI is signed (negative = below the fitted median), so use a diverging
# colormap with limits symmetric about 0. The neutral middle colour then
# always means SPI = 0, red means drought and blue means wet. Only finite
# cells set the limits, because norm.ppf can yield +/-inf at the extremes.
heat = pivot_spi.to_numpy(dtype=float)
finite = np.isfinite(heat)
lim = float(np.abs(heat[finite]).max()) if finite.any() else 0.0
lim = lim or 1.0  # avoid a zero-width colour range (all 0 or all NaN)

plt.figure(figsize=(18, 6))
plt.imshow(heat, aspect="auto", cmap="RdBu", vmin=-lim, vmax=lim)
plt.colorbar(label="SPI-12")
plt.xticks(range(len(pivot_spi.columns)), pivot_spi.columns, rotation=90)
plt.yticks(range(len(pivot_spi.index)), pivot_spi.index)
plt.title("Annual Drought Map (SPI-12)")
plt.xlabel("Year")
plt.ylabel("Station")
plt.tight_layout()
plt.savefig(r"D:\climate change\SPI_12_Drought_Heatmap.png")
plt.close()

print("🗺️ Annual drought heatmap saved!")

# ----------------------------------------------------------
# 5️⃣ SPI vs Rainfall Anomaly Cross-Validation
# ----------------------------------------------------------
# SPI-12 at month t standardizes the 12-month rainfall total ending at t,
# fitted separately per station, base period and calendar month. A fair
# cross-check therefore uses the anomaly of that same 12-month total within
# the same groups. A single-month anomaly against the station's all-month
# mean mostly reflects the seasonal cycle, not the 12-month signal.
rain_12 = (
    df_long.sort_values("Date")
    .groupby("Station")["Rainfall"]
    .transform(lambda s: s.rolling(12).sum())
)
# spi keeps df_long's row labels (nothing upstream resets the index), so this
# assignment lines each 12-month total up with its SPI row.
spi["Rain_12"] = rain_12
group_mean = spi.groupby(["Station", "Base", "Month"])["Rain_12"].transform("mean")
spi["Rain_Anomaly"] = spi["Rain_12"] - group_mean

# Pearson correlation; pandas ignores rows where either value is NaN.
r = spi["Rain_Anomaly"].corr(spi["SPI_12"])

plt.figure(figsize=(6, 5))
plt.scatter(spi["Rain_Anomaly"], spi["SPI_12"])
plt.axhline(0, ls="--")
plt.axvline(0, ls="--")
plt.title(f"SPI-12 vs 12-month Rainfall Anomaly (r = {r:.2f})")
plt.xlabel("12-month Rainfall Anomaly")
plt.ylabel("SPI-12")
plt.tight_layout()
plt.savefig(r"D:\climate change\SPI_vs_Rainfall_Anomaly.png")
plt.close()

print("✅ SPI vs Rainfall Anomaly plot saved!")
print("\n🎯 All requested tasks completed successfully!")


✅ SPI saved at: D:\climate change\SPI_12_two_basePeriods.xlsx
📈 SPI time-series plots saved!
🗺️ Annual drought heatmap saved!
✅ SPI vs Rainfall Anomaly plot saved!

🎯 All requested tasks completed successfully!
