In [38]:
import xarray as xr
import utils.agcd_agg_functions as agg
from importlib import reload
from shapely.geometry import mapping, box
import geopandas as gpd
import matplotlib.pyplot as plt
import rioxarray
import numpy as np
from rasterio.enums import Resampling
from rasterio.features import geometry_mask
import dill
import pandas as pd
from pathlib import Path
from collections import defaultdict
import seaborn as sns

# Show every column when a DataFrame is displayed (no truncation of wide frames)
pd.set_option('display.max_columns',  None)

# Re-import the local helper module so edits to it are picked up without restarting the kernel
reload(agg)

<module 'utils.agcd_agg_functions' from 'c:\\Users\\jake.allen.ALLUVIUMQLD\\Documents\\Repos\\climate_health_reference_manual\\utils\\agcd_agg_functions.py'>

### Paths to NetCDF files
We first collect the NetCDF file paths, then load them into xarray datasets further down so they are easy to iterate over. Note that no data is held in memory at this point thanks to dask; that only happens once computations are triggered later.

In [13]:
# build path dictionary
clim_fold = 'E:/Jake_ClimateRasters'

def nested_dict():
    """Return a two-level defaultdict whose innermost missing keys default to an empty list."""
    return defaultdict(lambda: defaultdict(list))

def find_ncdfs(folder):
    """Return every ``*.nc`` file found recursively under ``folder``, sorted by path.

    ``rglob`` yields files in filesystem order, which is not guaranteed to be stable,
    so the result is sorted to keep the file lists deterministic between runs.
    A folder that does not exist simply produces an empty list.
    """
    return sorted(Path(folder).rglob('*.nc'))

# model data is addressed as path_dict[model][ssp][epoch][var];
# historical observations as path_dict['Historical'][var]
path_dict = defaultdict(lambda: defaultdict(nested_dict))

models = ['ACCESS-ESM1-5']
ssps = ['ssp245', 'ssp370']
epochs = ['mid', 'late']
# NOTE: `vars` shadows the Python builtin; the name is kept because later cells reference it
vars = ['hurs', 'tasmax', 'tasmin']

# folder name (year range) used on disk for each epoch; an unknown epoch now fails
# loudly with a KeyError instead of silently falling back to the late period
epoch_years = {'mid': '2035-2064', 'late': '2070-2099'}

# assign model paths
for model in models:
    for ssp in ssps:
        for epoch in epochs:
            thisepoch = epoch_years[epoch]
            for thisvar in vars:
                thispath = Path(f'{clim_fold}/{model}/{ssp}/{thisvar}/AUS-11/{thisepoch}')
                path_dict[model][ssp][epoch][thisvar] = find_ncdfs(thispath)

# assign historical observation paths
for thisvar in vars:
    thispath = Path(f'{clim_fold}/Historical/{thisvar}')
    path_dict['Historical'][thisvar] = find_ncdfs(thispath)



In [3]:
# Read the health district polygons
districts = gpd.read_file('Inputs/health_district_merged.json')

# Keep only the districts of interest (edit the list to change the subset) and
# express them in EPSG:4326 so the bounds are in lon/lat
is_selected = districts['health_district_name'].isin(
    ['Northern NSW', 'Western Sydney', 'Eyre and Far North']
)
sub = districts[is_selected].copy().to_crs('EPSG:4326')

# Bounding box around the whole subset, wrapped in a list of GeoJSON-like mappings
# (this is the form passed to agg.load_var in the loading cell)
minx, miny, maxx, maxy = sub.total_bounds
bbox_geom = box(minx, miny, maxx, maxy)
bbox = [mapping(bbox_geom)]


In [14]:
# Load ncdfs into dictionary for easy wrangling, subsetting to bounding box and chunking by time = 365
# use load_var helper function

# initialise empty dictionary
data_dict = defaultdict(lambda: defaultdict(nested_dict))

# load model data: one Dataset per model / ssp / epoch holding every variable
for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = xr.Dataset({
                thisvar: agg.load_var(path_dict[model][ssp][epoch][thisvar], thisvar, bbox)
                for thisvar in vars
            })
            data_dict[model][ssp][epoch] = ds


# Load all historical variables into a dict
# historical data has misaligned time coordinates, causing issues when combining into a single dataset
hist_vars = {}
# time coordinate of the first variable loaded; every later variable is forced onto it.
# Kept in its own variable rather than as a sentinel entry in hist_vars, so the dict
# only ever contains real data variables.
reference_time = None
for thisvar in vars:
    ds = agg.load_var(path_dict['Historical'][thisvar], thisvar, bbox, chunk=False)

    # Force matching time coordinates
    # NOTE(review): this overwrites the coordinate values without checking that the
    # timestamps actually correspond — confirm the historical files cover the same days
    if reference_time is None:
        reference_time = ds['time']
    else:
        ds['time'] = reference_time

    hist_vars[thisvar] = ds

# Now build the dataset
data_dict['Historical'] = xr.Dataset(hist_vars)


## Calculate metrics

#### Define mean temperature and heat index
These metrics need to be calculated for each day throughout the time series. Here we are just defining the metrics, not triggering the dask computation yet. 
Dill is used as a checkpointing mechanism - saving interim results after each model iteration.

In [15]:
# Define mean temperature and heat index on every dataset (model scenarios and historical).

def _add_daily_heat_metrics(ds):
    """Add 'tas' and 'hi' to ``ds`` in place.

    'tas' is the midpoint of 'tasmax' and 'tasmin'; 'hi' comes from
    agg.calculate_heat_index (utils functions file) applied to 'tasmax' and 'hurs'.
    """
    # mean temp
    ds['tas'] = (ds['tasmax'] + ds['tasmin']) / 2
    # heat index
    ds['hi'] = agg.calculate_heat_index(ds['tasmax'], ds['hurs'])

for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            _add_daily_heat_metrics(ds)

hist = data_dict['Historical']
_add_daily_heat_metrics(hist)

##### Average hot days (thresholds) per year

In [None]:
# results are stored as hot_days_dict[model][ssp][epoch][threshold]
hot_days_dict = defaultdict(lambda: defaultdict(nested_dict))

def _avg_hot_days_by_threshold(ds):
    """Return {threshold: computed agg.calculate_avg_hot_days result} for thresholds 35 and 40."""
    return {
        thresh: agg.calculate_avg_hot_days(ds, threshold=thresh).compute()
        for thresh in [35,40]
    }

for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            hot_days_dict[model][ssp][epoch] = _avg_hot_days_by_threshold(ds)
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'hot_days': hot_days_dict})

# historical (stored under three 'Historical' keys to mirror the model/ssp/epoch nesting)
hot_days_dict['Historical']['Historical']['Historical'] = _avg_hot_days_by_threshold(data_dict['Historical'])
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'hot_days': hot_days_dict})


#### Percentile temps maximum 95%

In [None]:
# results are stored as pc_max_95_dict[model][ssp][epoch]
pc_max_95_dict = defaultdict(lambda: defaultdict(nested_dict))

def _tasmax_percentile(ds):
    """Computed agg.calculate_percentile result for 'tasmax'.

    Uses the helper's default percentile — presumably 95, going by this section's
    heading; confirm in utils.
    """
    return agg.calculate_percentile(ds, 'tasmax').compute()

for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            pc_95_max = _tasmax_percentile(ds)
            pc_max_95_dict[model][ssp][epoch] = pc_95_max
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'pc_max_95': pc_max_95_dict})

# historical
ds = data_dict['Historical']
pc_95_max = _tasmax_percentile(ds)
pc_max_95_dict['Historical']['Historical']['Historical'] = pc_95_max
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'pc_max_95': pc_max_95_dict})

#### Percentile temps minimum, 95%
This could probably be merged into a single loop with the maximum-temperature cell above, but it is sometimes useful to run the two separately.

In [None]:
# results are stored as pc_min_95_dict[model][ssp][epoch]
pc_min_95_dict = defaultdict(lambda: defaultdict(nested_dict))

def _tasmin_percentile(ds):
    """Computed agg.calculate_percentile result for 'tasmin'.

    Uses the helper's default percentile — presumably 95, going by this section's
    heading; confirm in utils.
    """
    return agg.calculate_percentile(ds, 'tasmin').compute()

for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            pc_95_min = _tasmin_percentile(ds)
            pc_min_95_dict[model][ssp][epoch] = pc_95_min
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'pc_min_95': pc_min_95_dict})

# historical
ds = data_dict['Historical']
pc_95_min = _tasmin_percentile(ds)
pc_min_95_dict['Historical']['Historical']['Historical'] = pc_95_min
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'pc_min_95': pc_min_95_dict})

#### Heat Index 95% - Not needed currently

In [None]:
# results are stored as pc_hi_95_dict[model][ssp][epoch]
# Only the model scenarios are processed here: there is no historical entry and no
# checkpoint, and the result is not part of the combined metric dictionaries.
pc_hi_95_dict = defaultdict(lambda: defaultdict(nested_dict))

for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            # percentile of the heat index, using the helper's default percentile
            # (presumably 95 — confirm in utils)
            pc_95_hi = agg.calculate_percentile(ds, 'hi').compute()
            pc_hi_95_dict[model][ssp][epoch] = pc_95_hi

#### Heat Index mean

In [None]:
# results are stored as mean_hi_dict[model][ssp][epoch]
mean_hi_dict = defaultdict(lambda: defaultdict(nested_dict))

def _time_mean_heat_index(ds):
    """Heat index ('hi') averaged over the whole 'time' dimension, computed eagerly."""
    return ds['hi'].mean(dim='time').compute()

# mean heat index for whole period
for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            hi_mean = _time_mean_heat_index(ds)
            mean_hi_dict[model][ssp][epoch] = hi_mean
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'mean_hi': mean_hi_dict})

# historical
ds = data_dict['Historical']
hi_mean = _time_mean_heat_index(ds)
mean_hi_dict['Historical']['Historical']['Historical'] = hi_mean
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'mean_hi': mean_hi_dict})

##### Mean temp mean

In [None]:
# results are stored as mean_mean_dict[model][ssp][epoch]
mean_mean_dict = defaultdict(lambda: defaultdict(nested_dict))

def _time_mean_temperature(ds):
    """Mean temperature ('tas') averaged over the whole 'time' dimension, computed eagerly."""
    return ds['tas'].mean(dim='time').compute()

# mean temperature for the whole period
for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            mean_mean = _time_mean_temperature(ds)
            mean_mean_dict[model][ssp][epoch] = mean_mean
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'mean_mean': mean_mean_dict})

# historical
ds = data_dict['Historical']
mean_mean = _time_mean_temperature(ds)
mean_mean_dict['Historical']['Historical']['Historical'] = mean_mean
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'mean_mean': mean_mean_dict})

#### 95th percentile of historical mean temp for EHF

#### Excess Heat Factor (EHF) (Heatwaves)

In [None]:
# 95th percentile of historical mean temperature; used as the threshold for every
# EHF calculation below (model scenarios and historical alike)
mean_95_hist = agg.calculate_percentile(data_dict['Historical'], 'tas', 95)

# results are stored as hw_dict[model][ssp][epoch] -> {summary name: computed value}
hw_dict = defaultdict(lambda: defaultdict(nested_dict))

def _rechunked_ehf(tas):
    """EHF of ``tas`` against the historical threshold, rechunked to 50x50 lat/lon
    blocks with the whole time axis in a single chunk (``'time': -1``)."""
    return agg.calculate_ehf(tas, mean_95_hist).chunk({'lat': 50, 'lon': 50, 'time': -1})

def _computed_heatwave_summary(ehf):
    """Run agg.summarise_heatwaves on ``ehf`` and compute every entry of the result."""
    return {k: v.compute() for k, v in agg.summarise_heatwaves(ehf).items()}

# heatwave summaries for each model scenario
for model in models:
    for ssp in ssps:
        for epoch in epochs:
            ds = data_dict[model][ssp][epoch]
            ehf = _rechunked_ehf(ds['tas'])
            hw = _computed_heatwave_summary(ehf)
            hw_dict[model][ssp][epoch] = hw
    # checkpoint after each model has been processed
    agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'heatwaves': hw_dict})

# historical
ds = data_dict['Historical']
ehf = _rechunked_ehf(ds['tas'])
# show the resulting chunk layout
print(ehf.chunks)
hw = _computed_heatwave_summary(ehf)
hw_dict['Historical']['Historical']['Historical'] = hw
agg.checkpoint_update('netcdf-wrangle-dicts.pkl', {'heatwaves': hw_dict})

  result = blockwise(
  result = blockwise(
  result = blockwise(
  result = blockwise(


((10957,), (50, 32), (50, 50, 50, 50, 24))


#### Combine metric dictionaries

In [27]:
# Collect every metric dictionary under the name used for it in the results table.
# Each value is nested as [model][ssp][epoch]; heatwaves and avg_hot_days hold a
# dict of sub-metrics at the innermost level.
metric_dicts = dict(
    heatwaves=hw_dict,
    pc_max_95=pc_max_95_dict,
    pc_min_95=pc_min_95_dict,
    mean_heat_index=mean_hi_dict,
    mean_temp=mean_mean_dict,
    avg_hot_days=hot_days_dict,
)

#### Population weighting - resample and align population raster to netcdfs.

In [32]:
# open population raster
pop = rioxarray.open_rasterio("Inputs/australian_pop_grid_2024/apg24e_1_0_0.tif").squeeze() # remove band dimension

# ensure correct projections are assigned (GDA94 for pop grid, WGS 84 for netcdfs)
# NOTE(review): write_crs only labels the raster with EPSG:3577, it does not reproject it —
# confirm this matches the CRS the GeoTIFF was actually produced in.
# NOTE(review): only the population grid is labelled in this cell; the netcdf CRS is
# presumably assigned inside agg.load_var — confirm.
pop.rio.write_crs("EPSG:3577", inplace=True)
# convert -1 to NaN and set as no data
pop = pop.where(pop != -1, np.nan)
pop = pop.rio.write_nodata(np.nan, inplace=False)

# reference grid to align to: the first time step of the historical 'hurs' variable,
# with 'lon'/'lat' registered as its x/y spatial dimensions
ref = data_dict['Historical']['hurs'].isel(time=0)
ref.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True)

# reproject and resample, summing values as the resample method
# (population is a count, so source cells are summed into each target cell rather than averaged)
pop_aligned = pop.rio.reproject_match(
    ref,
    resampling=Resampling.sum)

#### Build metrics dataframe - applying population weighting to health districts

In [33]:
# retrieve health districts and reproject to match.

all_districts = gpd.read_file('Inputs/health_district_merged.json')

# subset
districts = all_districts[all_districts['health_district_name'].isin(['Northern NSW', 'Western Sydney', 'Eyre and Far North'])].copy()

# to_crs returns a new GeoDataFrame rather than modifying in place, so the result
# has to be assigned; previously the reprojected frame was only displayed and
# `districts` itself stayed in its original CRS.
districts = districts.to_crs(pop_aligned.rio.crs)

# display the reprojected districts
districts

Unnamed: 0,OBJECTID,state,health_district_name,Shape_Length,Shape_Area,district_id,geometry
6,7,New South Wales,Northern NSW,11.494546,1.989182,NSW_7,"MULTIPOLYGON (((153.63873 -28.6361, 153.63869 ..."
8,9,New South Wales,Western Sydney,7.340251,0.686059,NSW_9,"POLYGON ((150.98407 -33.38804, 150.98402 -33.3..."
38,39,South Australia,Eyre and Far North,57.676729,46.122717,SA_5,"MULTIPOLYGON (((135.95444 -35.00627, 135.9545 ..."


In [None]:

# Geometries (as GeoJSON-like mappings) and the raster's affine transform for the zonal statistics
shapes = [mapping(geom) for geom in districts.geometry]
affine = pop_aligned.rio.transform()

# District labels, in the same positional order as `shapes`
district_ids = districts['district_id'].tolist()
district_names = districts['health_district_name'].tolist()

rows = []

# Walk every metric -> model -> ssp -> epoch combination
for metric_name, metric_dict in metric_dicts.items():
    for model in metric_dict:
        for ssp in metric_dict[model]:
            for epoch in metric_dict[model][ssp]:
                data = metric_dict[model][ssp][epoch]

                # Normalise to a list of (label, layer) pairs: metrics with
                # subcategories (e.g. hot days by threshold) give one pair per
                # subcategory labelled '<metric>_<sub>'; single-layer metrics give one pair.
                if isinstance(data, dict):
                    layers = [
                        (f'{metric_name}_{submetric}', subdata)
                        for submetric, subdata in data.items()
                    ]
                else:
                    layers = [(metric_name, data)]

                for metric_label, layer in layers:
                    # One weighted mean per district, in the order of `shapes`
                    weighted_means = agg.zonal_weighted_mean(layer, pop_aligned, shapes, affine)
                    for i, v in enumerate(weighted_means):
                        rows.append({
                            'district_id': district_ids[i],      # Unique identifier
                            'district_name': district_names[i],  # name
                            'model': model,
                            'ssp': ssp,
                            'epoch': epoch,
                            'metric': metric_label,
                            'value': v
                        })

# Final combined long-format DataFrame
results_df = pd.DataFrame(rows)
results_df.to_csv('heat_humidity_weighted_results.csv')

## Export

In [None]:
# Write the long-format results table to CSV (the DataFrame index is written as the
# first column). This repeats the write at the end of the cell that builds results_df,
# so the export can be re-run on its own.
results_df.to_csv('heat_humidity_weighted_results.csv')

### Save and reload workspace dictionaries

In [28]:
# Persist the combined metric dictionaries so they can be restored without recomputing.
# dill is used for serialisation (the metric dicts are defaultdicts built with lambda factories).
workspace = {'metrics': metric_dicts}

with open('netcdf-wrangle-workspace.pkl', 'wb') as workspace_file:
    dill.dump(workspace, workspace_file)

In [None]:
# Restore the metric dictionaries written by the save cell.
# NOTE: unpickling can execute arbitrary code — only load workspace files you created yourself.
with open('netcdf-wrangle-workspace.pkl', 'rb') as workspace_file:
    saved = dill.load(workspace_file)

metric_dicts = saved['metrics']