In [1]:
import os
import numpy as np
import pandas as pd # type: ignore
import xarray as xr # type: ignore
from pathlib import Path
import pickle
from tqdm import tqdm
import multiprocessing
import dask
from dask.distributed import Client, LocalCluster
from dask.diagnostics import ProgressBar
import scipy

# Get number of CPU cores
n_cores = multiprocessing.cpu_count()
print(f"Detected {n_cores} CPU cores")

# put this on your fastest SSD/NVMe with plenty of free space
SPILL_DIR = r"C:\dask-spill"
os.makedirs(SPILL_DIR, exist_ok=True)

dask.config.set({
    "distributed.worker.local-directory": SPILL_DIR,
    "distributed.worker.memory.target": 0.70,    # start spilling earlier
    "distributed.worker.memory.spill": 0.80,
    "distributed.worker.memory.pause": 0.88,
    "distributed.worker.memory.terminate": 0.98,
    "distributed.comm.compression": "auto",      # or None if codec issues reappear
})

cluster = LocalCluster(
    n_workers=6,
    threads_per_worker=4,       # 14900K has plenty; adjust if you want
    processes=False,             # threads-only on Windows
    memory_limit="auto",         # leave headroom for OS/Jupyter
)
client = Client(cluster)

print("\nDask cluster ready for parallel processing!")

# HDF5 locking issues (Rockfish/HPC): set once per process if needed
# os.environ.setdefault("HDF5_USE_FILE_LOCKING", "FALSE")
# os.environ["NETCDF_HDF5_FILE_LOCKING"] = "FALSE"
#xr.set_options(display_style="text") # Display xarray in plain text

# --path--
# era5land_quarterly_data = "/vast/bzaitch1/trp_climate_model_data/era5land_1970_2024_qtrmean" # Location on Rockfish
# era5land_quarterly_data_directory = Path("/Users/kris/Local Job Backup/test set/") # Location on local machine
#output_directory  = "/vast/bzaitch1/trp_climate_model_data/out"
#os.makedirs(output_directory, exist_ok=True)
QUARTER_DIR = Path(rf"E:\backup\trp_climate_model_data\era5land_1970_2024_qtrmean")

# Hourly filename pattern:
# One file per month:    YYYY_MM.nc
PATTERN_PER_MONTH = "{y}_{m}.nc"   # for one file per month

# cvar_aliases = {
#     'tp' : 'total precipitation', 
#     'e' : 'evaporation',
# }

evar_aliases = {
    'y' : 'Log of real GDP', 
    'Dp' : 'Inflation rate (Based on CPI)',
    'eq' : 'Log of real equity prices',
    'ep' : 'Log of real exchange rate',
    'r' : 'Short-term interest rate',
    'lr' : 'Long-term interest rate',
}

Detected 32 CPU cores

Dask cluster ready for parallel processing!


## Load Climate Data for Drought Analysis

We'll use precipitation and evaporation data to calculate drought indices.

In [2]:
# --- fix longitude from 0..360 to -180..180 and sort ---
def fix_long(da):
    if float(da.longitude.max()) > 180:
        lon_new = ((da.longitude + 180) % 360) - 180
        da = da.assign_coords(longitude=lon_new).sortby("longitude")
    
    return da

# --- make latitude ascending (optional but helpful for weights) ---
def sort_lat(da):
    if da.latitude[0] > da.latitude[-1]:
        da = da.sortby("latitude")
    
    return da

In [None]:
# Load quarterly climate data
test_data_set = QUARTER_DIR / '*.nc'
ds = xr.open_mfdataset(str(test_data_set), engine='netcdf4')

# Extract precipitation and evaporation
da_precip = ds['tp']  # Total precipitation
da_evap = ds['e']     # Evaporation

# da_precip = fix_long(da_precip)
# da_evap = fix_long(da_evap) 

print("Loaded climate data:")
print(f"  Precipitation shape: {da_precip.shape}")
print(f"  Evaporation shape: {da_evap.shape}")
print(f"  Time range: {da_precip.valid_time.values[0]} to {da_precip.valid_time.values[-1]}")

ds.close()

Loaded climate data:
  Precipitation shape: (160, 1801, 3600)
  Evaporation shape: (160, 1801, 3600)
  Time range: 1980-03-31T00:00:00.000000000 to 2019-12-31T00:00:00.000000000


## Drought Detection Functions

We'll implement multiple drought indices:
1. **Precipitation Anomaly** - Detrended precipitation deviations from normal
2. **Water Balance Anomaly** - Precipitation minus evaporation anomalies
3. **Drought Intensity** - Normalized drought severity index

In [6]:
def detrend_dim(da, dim, deg=1):
    """Remove linear trend from data array."""
    p = da.polyfit(dim=dim, deg=deg)
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    return da - fit

def calculate_drought_anomaly(da, var_name='precipitation'):
    """
    Calculate drought anomaly from climate variable.
    
    Parameters:
    -----------
    da : xr.DataArray
        Climate variable (e.g., precipitation, evaporation)
    var_name : str
        Name of the variable for metadata
    
    Returns:
    --------
    xr.DataArray : Standardized anomalies (negative = drought conditions)
    """
    print(f"Calculating drought anomaly for {var_name}...")
    
    # Save metadata
    original_attrs = da.attrs.copy()
    
    # Calculate seasonal climatology (by quarter)
    climatology = da.groupby("valid_time.quarter").mean(dim='valid_time')
    
    # Remove seasonal cycle to get anomalies
    anomalies = da.groupby("valid_time.quarter") - climatology
    
    # Detrend the anomalies
    detrended = detrend_dim(anomalies, 'valid_time')
    
    # Standardize (convert to z-scores)
    # Calculate standard deviation by quarter
    std = da.groupby("valid_time.quarter").std(dim='valid_time')
    standardized = detrended.groupby("valid_time.quarter") / std
    
    # Restore metadata
    standardized.attrs = original_attrs
    standardized.attrs['long_name'] = f'{var_name} drought anomaly (standardized)'
    standardized.attrs['interpretation'] = 'Negative values indicate drought conditions'
    standardized.name = f'{var_name}_drought_anomaly'
    
    print(f"  ✓ Calculated standardized anomalies")
    return standardized

def calculate_water_balance_drought(precip, evap):
    """
    Calculate drought index based on water balance (P - E).
    
    Parameters:
    -----------
    precip : xr.DataArray
        Precipitation data
    evap : xr.DataArray
        Evaporation data
    
    Returns:
    --------
    xr.DataArray : Water balance drought index (negative = drought)
    """
    print("Calculating water balance drought index...")
    
    # Convert evaporation to positive values (it's stored as negative)
    #evap_positive = -evap if evap.mean() < 0 and print('Convert evaporation to positive values (if it is stored as negative)') else evap
    
    # Calculate water balance
    water_balance = precip - evap
    
    # Calculate anomaly
    climatology = water_balance.groupby("valid_time.quarter").mean(dim='valid_time')
    anomalies = water_balance.groupby("valid_time.quarter") - climatology
    
    # Detrend
    detrended = detrend_dim(anomalies, 'valid_time')
    
    # Standardize
    std = water_balance.groupby("valid_time.quarter").std(dim='valid_time')
    standardized = detrended.groupby("valid_time.quarter") / std
    
    # Set metadata
    standardized.attrs['long_name'] = 'Water balance drought index (P-E)'
    standardized.attrs['interpretation'] = 'Negative values indicate drought (water deficit)'
    standardized.name = 'water_balance_drought'
    
    print("  ✓ Calculated water balance drought index")
    return standardized

def classify_drought_severity(drought_index):
    """
    Classify drought severity based on standardized index.
    
    Classification:
    - No drought: > -0.5
    - Mild drought: -0.5 to -1.0
    - Moderate drought: -1.0 to -1.5
    - Severe drought: -1.5 to -2.0
    - Extreme drought: < -2.0
    
    Parameters:
    -----------
    drought_index : xr.DataArray
        Standardized drought index
    
    Returns:
    --------
    xr.DataArray : Drought severity classification (0-4)
    """
    severity = xr.zeros_like(drought_index)
    severity = xr.where(drought_index <= -0.5, 1, severity)  # Mild
    severity = xr.where(drought_index <= -1.0, 2, severity)  # Moderate
    severity = xr.where(drought_index <= -1.5, 3, severity)  # Severe
    severity = xr.where(drought_index <= -2.0, 4, severity)  # Extreme
    
    severity.attrs['long_name'] = 'Drought severity classification'
    severity.attrs['classes'] = '0=None, 1=Mild, 2=Moderate, 3=Severe, 4=Extreme'
    severity.name = 'drought_severity'
    
    return severity

print("Drought detection functions loaded!")

Drought detection functions loaded!


## Calculate Drought Indices

Compute multiple drought indicators for correlation analysis.

In [7]:
# Calculate drought indices
print("="*60)
print("CALCULATING DROUGHT INDICES")
print("="*60)

# 1. Precipitation drought anomaly
drought_precip = calculate_drought_anomaly(da_precip, var_name='precipitation')

# 2. Water balance drought (P - E)
drought_water_balance = calculate_water_balance_drought(da_precip, da_evap)

print("\n" + "="*60)
print("DROUGHT INDICES CALCULATED")
print("="*60)
print(f"Precipitation drought: {drought_precip.shape}")
print(f"Water balance drought: {drought_water_balance.shape}")

# Store in a dictionary for easy access
drought_indices = {
    #'precipitation': drought_precip,
    'water_balance': drought_water_balance
}

drought_index_names = {
    'precipitation': 'Precipitation Deficit Index',
    'water_balance': 'Water Balance Deficit Index (P-E)'
}

CALCULATING DROUGHT INDICES
Calculating drought anomaly for precipitation...
  ✓ Calculated standardized anomalies
Calculating water balance drought index...
  ✓ Calculated water balance drought index

DROUGHT INDICES CALCULATED
Precipitation drought: (160, 1801, 3600)
Water balance drought: (160, 1801, 3600)


In [None]:
# Quick check: plot detrended anomalies (spatial average time series)
# Use subset for faster plotting
import matplotlib.pyplot as plt

# Downsample spatially for quick check (every 20th point)
det_da_sample = drought_water_balance.isel(latitude=slice(None, None, 20), longitude=slice(None, None, 20))

evar_label = drought_index_names.get('water_balance')

# Compute and plot
det_da_sample.mean(dim=['latitude','longitude']).plot(figsize=(12, 4))
plt.title(f'Detrended {evar_label} Anomalies (Global Average - Downsampled)')
plt.xlabel(f'Time') #(°C)
plt.ylabel(f'{evar_label} Anomaly (m)') #(°C)
plt.grid(True)
plt.show()



## Correlation with Economic Data

Correlate drought indices with economic indicators for all countries.

In [33]:
def correlate_drought_with_econ(drought_index, econ_var_series):
    """
    Correlate drought index at every grid point with an economic time series.
    
    Parameters:
    -----------
    drought_index : xr.DataArray
        Drought index with dimensions (valid_time, latitude, longitude)
    econ_var_series : pd.Series
        Economic indicator time series with datetime index
    
    Returns:
    --------
    xr.DataArray : Correlation coefficients at each grid point
    """
    
    # Find overlapping time period
    climate_times = pd.DatetimeIndex(drought_index.valid_time.values)
    econ_times = econ_var_series.index
    
    # Get intersection of times
    common_times = climate_times.intersection(econ_times)
    
    if len(common_times) == 0:
        raise ValueError("No overlapping time periods!")
    
    print(f"    Found {len(common_times)} overlapping time periods")
    
    # Subset both datasets to common times
    drought_subset = drought_index.sel(valid_time=common_times)
    econ_subset = econ_var_series.loc[common_times]
    
    # Remove any NaN values from economic series
    valid_mask = econ_subset.notna()
    if not valid_mask.all():
        print(f"    Removing {(~valid_mask).sum()} NaN values from economic series")
        drought_subset = drought_subset.sel(valid_time=valid_mask.values)
        econ_subset = econ_subset[valid_mask]
    
    # Convert to xarray DataArray
    econ_da = xr.DataArray(
        econ_subset.values,
        coords={'valid_time': drought_subset.valid_time},
        dims=['valid_time']
    )
    
    # Calculate correlation
    correlation_map = xr.corr(drought_subset, econ_da, dim='valid_time')
    
    return correlation_map

print("Drought correlation function loaded!")

Drought correlation function loaded!


In [None]:
# Process all countries and economic variables for ALL drought indices
df = pd.ExcelFile("./static/df_country_data_climate.xlsx")
list_countries = df.sheet_names

# Dictionary structure: {drought_type: {country: {econ_var: correlation_map}}}
all_drought_correlations = {}

print(f"Processing drought correlations for {len(list_countries)} countries...")
print(f"Drought indices: {list(drought_indices.keys())}")
print(f"Economic variables: {list(evar_aliases.keys())}")
print("="*60)

# Process each drought index type
for drought_type, drought_data in drought_indices.items():
    print(f"\n{'='*60}")
    print(f"PROCESSING: {drought_index_names[drought_type]}")
    print(f"{'='*60}")
    
    all_drought_correlations[drought_type] = {}
    
    for country in tqdm(list_countries, desc=f"Countries ({drought_type})"):
        print(f"\n  {country}...")
        
        try:
            # Load economic data
            econ_data = pd.read_excel(df, sheet_name=country)
            econ_data = econ_data[3:].copy()
            econ_data['time'] = pd.to_datetime(econ_data['Unnamed: 0'])
            econ_data = econ_data.set_index('time')
            econ_data = econ_data.drop(columns=['Unnamed: 0'], errors='ignore')
            
            # Initialize correlation maps for this country
            correlation_maps = {}
            
            # Loop through economic variables
            for econ_var in evar_aliases.keys():
                if econ_var not in econ_data.columns:
                    continue
                
                try:
                    econ_series = econ_data[econ_var]
                    
                    # Check data quality
                    nan_ratio = econ_series.isna().sum() / len(econ_series)
                    if nan_ratio > 0.5:
                        continue
                    
                    # Compute correlation map
                    print(f"    Computing: {evar_aliases[econ_var]}...")
                    corr_map = correlate_drought_with_econ(drought_data, econ_series)
                    correlation_maps[econ_var] = corr_map
                    print(f"      ✓ Complete")
                    
                except Exception as e:
                    print(f"      ✗ Error: {e}")
                    continue
            
            # Store results
            all_drought_correlations[drought_type][country] = correlation_maps
            
            # Save individual country results
            out_path = Path(r'.\cache\tmp\drought') / f"{drought_type}_{country}.pkl"
            out_path.parent.mkdir(parents=True, exist_ok=True)
            with out_path.open('wb') as f:
                pickle.dump(correlation_maps, f)
            
            print(f"    ✓ Saved {len(correlation_maps)} correlation maps")
            
        except Exception as e:
            print(f"    ✗ Failed: {e}")
            continue

# Save all results
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
for drought_type, country_data in all_drought_correlations.items():
    print(f"\n{drought_index_names[drought_type]}:")
    print(f"  Processed {len(country_data)} countries")
    total_vars = sum(len(maps) for maps in country_data.values())
    print(f"  Total correlations: {total_vars}")

# Save complete dataset
out_path = Path(r'.\cache\tmp\drought') / "all_drought_correlations.pkl"
with out_path.open('wb') as f:
    pickle.dump(all_drought_correlations, f)
print(f"\n✓ Saved all drought correlations to {out_path}")

## Visualization: Drought-GDP Correlations

Generate plots for all countries showing drought impact on GDP.

In [37]:
# Generate GDP correlation plots for all drought indices and countries
import matplotlib.pyplot as plt
from tqdm import tqdm

gdp_var = 'y'  # Real GDP variable

for drought_type, country_data in all_drought_correlations.items():
    # Create plots directory for this drought type
    plots_dir = Path(r'.\cache\drought_plots') / drought_type / 'gdp'
    plots_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"\n{'='*60}")
    print(f"PLOTTING: {drought_index_names[drought_type]} - GDP Correlations")
    print(f"{'='*60}")
    print(f"Saving to: {plots_dir}")
    
    successful = 0
    skipped = 0
    
    for country, maps in tqdm(country_data.items(), desc=f"{drought_type}"):
        if gdp_var not in maps:
            skipped += 1
            continue
        
        try:
            corr_map_lazy = maps[gdp_var]
            
            # Compute if dask array
            if hasattr(corr_map_lazy, 'compute'):
                corr_map = corr_map_lazy.compute()
            else:
                corr_map = corr_map_lazy
            
            # Create the plot
            fig, ax = plt.subplots(figsize=(14, 7))
            im = corr_map.plot(ax=ax, cmap='RdBu_r', vmin=-0.8, vmax=0.8,
                              cbar_kwargs={'label': 'Correlation Coefficient', 'shrink': 0.8})
            
            ax.set_title(f'{country} - {evar_aliases[gdp_var]} vs {drought_index_names[drought_type]}',
                        fontsize=14, fontweight='bold', pad=20)
            ax.set_xlabel('Longitude', fontsize=11)
            ax.set_ylabel('Latitude', fontsize=11)
            ax.grid(True, alpha=0.3, linestyle='--')
            
            # Add statistics box
            corr_values = corr_map.values
            valid_corr = corr_values[~np.isnan(corr_values)]
            if len(valid_corr) > 0:
                stats_text = (f'Mean corr: {valid_corr.mean():.3f}\n'
                             f'Max corr: {valid_corr.max():.3f}\n'
                             f'Min corr: {valid_corr.min():.3f}\n'
                             f'Interpretation: Negative = drought reduces GDP')
                ax.text(0.02, 0.98, stats_text,
                       transform=ax.transAxes,
                       fontsize=9,
                       verticalalignment='top',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
            plt.tight_layout()
            
            # Save the plot
            plot_filename = plots_dir / f"{country}_GDP_vs_{drought_type}.png"
            plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
            plt.close(fig)
            
            successful += 1
            
        except Exception as e:
            print(f"  ✗ Error plotting {country}: {e}")
            plt.close('all')
            skipped += 1
            continue
    
    print(f"  ✓ Successfully plotted: {successful} countries")
    print(f"  ⚠ Skipped: {skipped} countries")

print(f"\n{'='*60}")
print("ALL DROUGHT-GDP PLOTS COMPLETE")
print(f"{'='*60}")


PLOTTING: Water Balance Deficit Index (P-E) - GDP Correlations
Saving to: cache\drought_plots\water_balance\gdp


water_balance:   0%|          | 0/33 [26:40<?, ?it/s]
Task exception was never retrieved
future: <Task finished name='Task-33754460' coro=<Client._gather.<locals>.wait() done, defined at c:\Users\Kris\AppData\Local\anaconda3\Lib\site-packages\distributed\client.py:2384> exception=AllExit()>
Traceback (most recent call last):
  File "c:\Users\Kris\AppData\Local\anaconda3\Lib\site-packages\distributed\client.py", line 2393, in wait
    raise AllExit()
distributed.client.AllExit


KeyboardInterrupt: 

## Summary

This notebook implements a complete drought detection and economic correlation workflow:

### Drought Indices Computed:
1. **Precipitation Deficit Index**: Standardized precipitation anomalies (negative = drought)
2. **Water Balance Deficit Index**: Precipitation minus Evaporation anomalies (negative = water deficit/drought)

### Analysis Pipeline:
1. Load ERA5-Land quarterly climate data (precipitation & evaporation)
2. Calculate drought indices with seasonal detrending and standardization
3. Correlate drought indices with economic variables for all countries
4. Generate correlation maps and save to organized folders

### Interpretation:
- **Negative correlations**: Drought (water deficit) negatively impacts the economic variable
- **Positive correlations**: Drought positively impacts the economic variable (unusual)
- **Near-zero correlations**: Little to no relationship between drought and economy

### Output Structure:
```
cache/
├── drought/                           # Pickle files with correlation data
│   ├── precipitation_Argentina.pkl
│   ├── water_balance_Argentina.pkl
│   └── all_drought_correlations.pkl
└── drought_plots/                     # PNG plots
    ├── precipitation/
    │   └── gdp/
    │       ├── Argentina_GDP_vs_precipitation.png
    │       └── ...
    └── water_balance/
        └── gdp/
            ├── Argentina_GDP_vs_water_balance.png
            └── ...
```