# In [None]:
"""
Heatwave Detection and Indicators Calculation
---------------------------------------------

This script detects and quantifies changes in heatwaves between a historical period 
(2015–2030) and a future period (2080–2100) using gridded daily climate data. 

Methodology (fully consistent with accompanying documentation):
1. Regions: 
   - Northern Hemisphere (lat > 23.5°N): May–September.
   - Southern Hemisphere (lat < -23.5°S): November–March.
   - Tropics/Monsoon (−23.5° ≤ lat ≤ 23.5°): all months.
2. Baseline thresholds: For each location (lon, lat), the 95th percentile of 
   daily maximum (tasmax) and minimum (tasmin) temperatures is computed from the 
   historical period data only, using just the days that fall within the 
   region's season (see item 1).
3. Heatwave definition: A heatwave is ≥3 consecutive days where both tasmax and 
   tasmin exceed their respective thresholds.
4. Indicators:
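   - Change in the share of in-season days that are heatwave days (future minus 
     historical), expressed as a fraction on a 0–1 scale (not multiplied by 100).
   - Change in the highest daily maximum temperature recorded during heatwaves 
     (HWTmax), future minus historical; set to 0 when heatwaves occur in only one 
     of the two periods and left empty when they occur in neither.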
5. Outputs: one CSV file per region (heatwave_analysis_results_<region>.csv) with 
   columns lon, lat, heatwave_percentage_change and HWTmax_change.

Inputs are the 75th-percentile daily climatology CSV files derived from 
multi-model ensembles.
"""
import pandas as pd

# ---- Load ----
# Column dtypes enforced on read; any other columns are left to pandas' inference.
dtype_dict = dict(
    lon=float,
    lat=float,
    tas=float,
    tasmax=float,
    tasmin=float,
    pr=float,
    day_of_year=int,
)
hist, fut = (
    pd.read_csv(csv_path, dtype=dtype_dict)
    for csv_path in (
        '75th_percentile_values-daily-2015-2030_HW_drought.csv',  # historical
        '75th_percentile_values-daily-2080-2100_HW_drought.csv',  # future
    )
)

# ---- Preprocess: wrap longitudes into [-180, 180) ----
# Half-open interval: a longitude of exactly 180 (or -180) maps to -180.
for df in (hist, fut):
    df['lon'] = df['lon'].add(180).mod(360).sub(180)

# ---- Map DOY -> calendar month using a non-leap year (no extra assumptions) ----
def doy_to_month_nonleap(series: pd.Series) -> pd.Series:
    """Map day-of-year values to calendar months using a non-leap year.

    DOY 1-365 are interpreted in the non-leap year 2021 (so DOY 60 -> March).
    DOY 366, which only appears in leap-year data, is mapped to December.

    Parameters
    ----------
    series : pandas.Series
        Integer-like day-of-year values in the range 1..366.

    Returns
    -------
    pandas.Series
        Calendar month numbers (1-12), aligned to ``series.index``.

    Raises
    ------
    ValueError
        If any value lies outside 1..366.
    """
    base = pd.Timestamp('2021-01-01')  # non-leap year reference
    s = series.astype(int)
    out_of_range = (s < 1) | (s > 366)
    if out_of_range.any():
        raise ValueError(
            f"day_of_year values must be in 1..366; got e.g. {s[out_of_range].iloc[0]}"
        )
    offsets = pd.to_timedelta(s.clip(upper=365) - 1, unit='D')
    # Timestamp + timedelta Series gives a datetime Series, whose calendar
    # fields live behind the `.dt` accessor (a plain Series has no `.month`).
    months = (base + offsets).dt.month
    months = months.where(s != 366, other=12)  # if 366 appears, treat as Dec
    return months

# Attach a calendar-month column to both periods; subset_region filters on it.
for df in (hist, fut):
    df['month'] = doy_to_month_nonleap(df['day_of_year'])

# ---- Region masks (exactly as specified) ----
def is_nh(lat):
    """Northern-Hemisphere extratropics: latitude strictly above 23.5."""
    return lat > 23.5


def is_sh(lat):
    """Southern-Hemisphere extratropics: latitude strictly below -23.5."""
    return lat < -23.5


def is_trop(lat):
    """Tropics: -23.5 <= latitude <= 23.5 (both bounds inclusive)."""
    # `&` instead of `and` so the mask also works element-wise on a Series.
    return (lat >= -23.5) & (lat <= 23.5)


NH_MONTHS = {5, 6, 7, 8, 9}                 # May–Sep
SH_MONTHS = {11, 12, 1, 2, 3}               # Nov–Mar (spans the year boundary)
TROP_MONTHS = set(range(1, 13))             # all year


def subset_region(df, region):
    """Return a copy of the rows of *df* that lie in *region* and its season.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``lat`` column and an integer ``month`` column.
    region : str
        One of ``'NH'``, ``'SH'`` or ``'Tropics'``.

    Returns
    -------
    pandas.DataFrame
        Rows whose latitude is inside the region's band and whose month is in
        the region's season (all twelve months for the tropics).

    Raises
    ------
    ValueError
        If *region* is not one of the three supported names.
    """
    if region == 'NH':
        lat_mask, months = is_nh, NH_MONTHS
    elif region == 'SH':
        lat_mask, months = is_sh, SH_MONTHS
    elif region == 'Tropics':
        lat_mask, months = is_trop, TROP_MONTHS
    else:
        raise ValueError("Region must be one of: 'NH', 'SH', 'Tropics'")
    return df[lat_mask(df['lat']) & df['month'].isin(months)].copy()

# ---- Thresholds (from historical, per region) ----
def calc_thresholds(hist_region_df):
    """Per-location 95th-percentile thresholds of tasmax and tasmin.

    Parameters
    ----------
    hist_region_df : pandas.DataFrame
        Historical rows with ``lon``, ``lat``, ``tasmax`` and ``tasmin``.

    Returns
    -------
    pandas.DataFrame
        One row per (lon, lat) with columns ``lon``, ``lat``, ``tasmax_95``
        and ``tasmin_95``.
    """
    def p95(values):
        return values.quantile(0.95)

    per_cell = hist_region_df.groupby(['lon', 'lat'])
    thresholds = per_cell.agg(
        tasmax_95=('tasmax', p95),
        tasmin_95=('tasmin', p95),
    )
    return thresholds.reset_index()

# ---- Heatwave detection (≥3 consecutive days, both tasmax & tasmin > thresholds) ----
def detect_heatwaves(df_region, thresholds, min_duration=3):
    """Flag heatwave days at each location.

    A heatwave is a run of at least ``min_duration`` consecutive days on which
    both ``tasmax`` and ``tasmin`` exceed the location's thresholds. Every day
    belonging to such a run is flagged, not just the days that complete a
    ``min_duration``-day window.

    Parameters
    ----------
    df_region : pandas.DataFrame
        Daily rows with at least ``lon``, ``lat``, ``day_of_year``, ``tasmax``
        and ``tasmin`` columns.
    thresholds : pandas.DataFrame
        Per-location thresholds with ``lon``, ``lat``, ``tasmax_95`` and
        ``tasmin_95`` columns (as produced by ``calc_thresholds``).
    min_duration : int, default 3
        Minimum number of consecutive qualifying days for a heatwave.

    Returns
    -------
    pandas.DataFrame
        ``df_region`` inner-joined with ``thresholds`` (locations without a
        threshold are dropped), sorted by lon/lat/day_of_year, with boolean
        columns ``exceeds_tasmax``, ``exceeds_tasmin`` and ``heatwave_day``.

    Raises
    ------
    ValueError
        If ``min_duration`` is less than 1.

    Notes
    -----
    Two days are consecutive only when their ``day_of_year`` values differ by
    exactly 1 at the same location, so a seasonal gap (e.g. Mar 31 -> Nov 1
    for a Nov–Mar season) ends a run. The Dec 31 -> Jan 1 wrap is not bridged.
    """
    if min_duration < 1:
        raise ValueError("min_duration must be >= 1")

    df = df_region.merge(thresholds, on=['lon', 'lat'], how='inner')
    df = df.sort_values(['lon', 'lat', 'day_of_year'])
    df['exceeds_tasmax'] = df['tasmax'] > df['tasmax_95']
    df['exceeds_tasmin'] = df['tasmin'] > df['tasmin_95']
    hot = df['exceeds_tasmax'] & df['exceeds_tasmin']

    # A new run starts at a new location, at a break in day_of_year, or where
    # the hot / not-hot state flips; cumsum turns those starts into run ids.
    new_location = df['lon'].ne(df['lon'].shift()) | df['lat'].ne(df['lat'].shift())
    day_gap = df['day_of_year'].diff().ne(1)
    state_change = hot.ne(hot.shift(fill_value=False))
    run_id = (new_location | day_gap | state_change).cumsum()

    # Every row of a run carries that run's length, so all days of a long
    # enough hot run are flagged.
    run_length = hot.groupby(run_id).transform('size')
    df['heatwave_day'] = hot & (run_length >= min_duration)
    return df

# ---- Indicators ----
def indicators(hist_hw, fut_hw):
    """Compute per-location heatwave indicators (future minus historical).

    Parameters
    ----------
    hist_hw, fut_hw : pandas.DataFrame
        Output of ``detect_heatwaves`` for the historical and future periods;
        the columns read are ``lon``, ``lat``, ``day_of_year``,
        ``heatwave_day`` and ``tasmax``.

    Returns
    -------
    pandas.DataFrame
        Columns ``lon``, ``lat``, ``heatwave_percentage_change`` and
        ``HWTmax_change``, one row per location present in either period.

    Notes
    -----
    * ``heatwave_percentage_change`` is the difference of two fractions
      (heatwave days / days at that location), i.e. on a 0-1 scale despite
      the name. Locations missing from one period get 0.
    * ``HWTmax_change`` is the difference in the maximum ``tasmax`` over
      heatwave days. It is 0 when heatwaves occur in only one period and NaN
      when they occur in neither.
    """
    cell = ['lon', 'lat']

    def share_of_heatwave_days(frame):
        by_cell = frame.groupby(cell)
        n_days = by_cell['day_of_year'].count()
        n_hw = by_cell['heatwave_day'].sum()
        return (n_hw / n_days).fillna(0)

    def peak_heatwave_tasmax(frame):
        on_hw = frame.loc[frame['heatwave_day']]
        return on_hw.groupby(cell)['tasmax'].max()

    share_delta = share_of_heatwave_days(fut_hw) - share_of_heatwave_days(hist_hw)
    peak_delta = peak_heatwave_tasmax(fut_hw) - peak_heatwave_tasmax(hist_hw)

    out = pd.concat(
        [
            share_delta.rename('heatwave_percentage_change'),
            peak_delta.rename('HWTmax_change').fillna(0),
        ],
        axis=1,
    ).reset_index()
    out['heatwave_percentage_change'] = out['heatwave_percentage_change'].fillna(0)
    return out

# ---- Run exactly for the three regions and save separate CSVs ----
# Thresholds are derived from each region's historical in-season data and
# applied unchanged to the future period, so the indicators measure change
# relative to the historical climate.
for region in ['NH', 'SH', 'Tropics']:
    print(f"\n=== {region} ===")
    hist_r = subset_region(hist, region)
    fut_r = subset_region(fut, region)
    if hist_r.empty or fut_r.empty:
        print(f"No data for {region}; skipping.")
        continue
    thr = calc_thresholds(hist_r)
    hist_hw = detect_heatwaves(hist_r, thr)
    fut_hw = detect_heatwaves(fut_r, thr)
    res = indicators(hist_hw, fut_hw)
    out_path = f'heatwave_analysis_results_{region}.csv'
    res.to_csv(out_path, index=False)
    print(f"Saved {out_path}")