In [1]:
import os
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe

from utils import city_list

## Preliminaries

In [2]:
################
#### Paths #####
################
# Update these for reproduction

# Root directory for this project's data; later cells read and write
# metrics/, metrics_regridded/, summary_indices/ and summary_raw/ beneath it
project_data_path = "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/"
# Root directory of the project code; later cells use code/logs and code/utils beneath it
project_code_path = "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/"
# Directory with one sub-directory of netCDF files per SSP
star_path = "/storage/group/pches/default/users/dcl5300/STAR-ESDM/" # raw STAR-ESDM outputs

In [3]:
##############
### Models ###
##############
# The model name is the second dot-separated field of each file name
ssp245_models = np.unique([file.split('/')[-1].split('.')[1] for file in glob(f"{star_path}/ssp245/*.nc")])
ssp585_models = np.unique([file.split('/')[-1].split('.')[1] for file in glob(f"{star_path}/ssp585/*.nc")])

# An empty ensemble would turn every later loop into a silent no-op
if len(ssp245_models) == 0:
    raise FileNotFoundError(f"No STAR-ESDM files found under {star_path}/ssp245/")

# np.array_equal also copes with ensembles of different sizes, for which an
# elementwise `==` cannot be reduced with `.all()`
if np.array_equal(ssp245_models, ssp585_models):
    models = ssp245_models
else:
    # Stop here with the offending names instead of leaving `models` undefined,
    # which would only surface later as a NameError
    mismatched = sorted(set(ssp245_models) ^ set(ssp585_models))
    raise ValueError(f"Model mismatch between ssp245 and ssp585: {mismatched}")

In [4]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Each SLURM job requests 1 core and 15 GiB for 10 minutes on the "pches"
# account (use account="open" instead to submit to the open allocation)
cluster = SLURMCluster(account="pches", cores=1, memory="15GiB", walltime="00:10:00")

# Ask for 20 jobs
cluster.scale(jobs=20)

# Connect this notebook to the cluster and display the client summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.158:40159,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Calculate metrics

In [19]:
###############################
# Metric calculation function #
###############################
def calculate_metric(model, ssp, var_id, metric, star_path, out_path):
    """
    Compute one annual metric from the STAR-ESDM output of one model and SSP
    and store it as a netCDF file in `out_path`.

    Inputs:
        model: model name as it appears in the STAR-ESDM file names
        ssp: sub-directory of `star_path` to read from (e.g. "ssp245")
        var_id: variable to aggregate; "tas" is derived as the mean of
            tasmin and tasmax
        metric: annual aggregation, one of "avg", "max" or "sum"
        star_path: directory holding the raw STAR-ESDM files
        out_path: directory the result is written to
    Outputs: None. The result is written to disk, and nothing is recomputed if
        the output file already exists. Errors are not raised: they are written
        with their traceback to {project_code_path}/code/logs so that a loop
        over models keeps going.
    """
    import traceback

    # xarray resample reduction behind each supported metric
    reducers = {"avg": "mean", "max": "max", "sum": "sum"}

    def read_star(file_path):
        # Times are decoded only after the calendar attribute has been set to
        # '365_day' (presumably the raw files do not declare a usable calendar
        # -- TODO confirm)
        ds = xr.open_mfdataset(file_path, decode_times=False, chunks={'time':365, 'latitude':-1, 'longitude':-1})
        ds.time.attrs['calendar'] = '365_day'
        return xr.decode_cf(ds, decode_times=True)

    try:
        # Fail before any expensive read if the metric is not supported
        if metric not in reducers:
            raise ValueError(f"Unsupported metric '{metric}', expected one of {sorted(reducers)}")

        in_pattern = f"{star_path}/{ssp}/downscaled.{model}.r1i1p1f1.{var_id}*"
        in_files = glob(in_pattern)
        if not in_files:
            raise FileNotFoundError(f"No STAR-ESDM files match {in_pattern}")

        ## First check if already exists
        # The output name reuses the tail of the input name, with '1950'
        # swapped for '2015'
        file_info = in_files[0].split(f".{ssp}.")[-1].replace('1950', '2015')
        out_str = f"{metric}_{var_id}.downscaled.{model}.r1i1p1f1.{var_id}.{ssp}.{file_info}"
        if os.path.isfile(f"{out_path}/{out_str}"):
            print(f"{ssp} {model} already done.")
            return None

        # Read
        ds_tmp = read_star(in_pattern)
        # Calculate tas if needed
        if var_id == "tas":
            ds_tmp['tas'] = (ds_tmp['tasmin'] + ds_tmp['tasmax']) / 2.0

        # Time slice
        if ssp != "historical":
            ds_tmp = ds_tmp.sel(time=slice("2015-01-01","2100-12-31"))

        # Select only var_id
        ds_tmp = ds_tmp[[var_id]]

        # Calculate metric
        ds_out = getattr(ds_tmp.resample(time="1Y"), reducers[metric])()
        if metric == "sum" and var_id == "pr":
            ds_out.pr.attrs["units"] = "mm"

        # Store
        ds_out.to_netcdf(f"{out_path}/{out_str}")
        print(f"{ssp} {model}")

    # Log if error
    except Exception:
        except_path = f"{project_code_path}/code/logs"
        # The handler must not fail itself because the log directory is missing
        os.makedirs(except_path, exist_ok=True)
        log_file = f"{except_path}/{model}_{ssp}_{var_id}_STAR.txt"
        with open(log_file, "w") as f:
            # The log name does not include the metric, so record it in the body
            f.write(f"{metric}_{var_id} {model} {ssp}\n")
            f.write(traceback.format_exc())
        # Make the failure visible in the notebook output as well
        print(f"{ssp} {model} failed, see {log_file}")

In [9]:
%%time
#########################
## Average Temperature ##
#########################
var_id = "tas"
metric = "avg"

out_path = f"{project_data_path}/metrics/STAR-ESDM/"

# Loop through all
for ssp in ["ssp245"]:
    for model in models:
        # Calculate metric
        calculate_metric(model = model,
                         ssp = ssp,
                         var_id = var_id,
                         metric = metric,
                         star_path = star_path,
                         out_path = out_path)

ssp245 ACCESS-CM2 already done.
ssp245 ACCESS-ESM1-5 already done.
ssp245 BCC-CSM2-MR already done.
ssp245 CMCC-ESM2 already done.
ssp245 CanESM5 already done.
ssp245 EC-Earth3 already done.
ssp245 EC-Earth3-Veg already done.
ssp245 EC-Earth3-Veg-LR already done.
ssp245 FGOALS-g3 already done.
ssp245 GFDL-CM4 already done.
ssp245 GFDL-ESM4 already done.
ssp245 INM-CM4-8 already done.
ssp245 INM-CM5-0 already done.
ssp245 IPSL-CM6A-LR already done.
ssp245 KACE-1-0-G already done.
ssp245 KIOST-ESM already done.
ssp245 MIROC6 already done.
ssp245 MPI-ESM1-2-HR already done.
ssp245 MPI-ESM1-2-LR already done.
ssp245 MRI-ESM2-0 already done.
ssp245 NESM3 already done.
ssp245 NorESM2-LM already done.
ssp245 NorESM2-MM already done.
ssp245 TaiESM1 already done.
CPU times: user 2.36 ms, sys: 1.18 ms, total: 3.54 ms
Wall time: 12.1 ms


In [10]:
%%time
#########################
## Total Precipitation ##
#########################
var_id = "pr"
metric = "sum"

out_path = f"{project_data_path}/metrics/STAR-ESDM/"

# Loop through all
for ssp in ["ssp245"]:
    for model in models:
        # Calculate metric
        calculate_metric(model = model,
                         ssp = ssp,
                         var_id = var_id,
                         metric = metric,
                         star_path = star_path,
                         out_path = out_path)

ssp245 ACCESS-CM2 already done.
ssp245 ACCESS-ESM1-5 already done.
ssp245 BCC-CSM2-MR already done.
ssp245 CMCC-ESM2 already done.
ssp245 CanESM5 already done.
ssp245 EC-Earth3 already done.
ssp245 EC-Earth3-Veg already done.
ssp245 EC-Earth3-Veg-LR already done.
ssp245 FGOALS-g3 already done.
ssp245 GFDL-CM4 already done.
ssp245 GFDL-ESM4 already done.
ssp245 INM-CM4-8 already done.
ssp245 INM-CM5-0 already done.
ssp245 IPSL-CM6A-LR already done.
ssp245 KACE-1-0-G already done.
ssp245 KIOST-ESM already done.
ssp245 MIROC6 already done.
ssp245 MPI-ESM1-2-HR already done.
ssp245 MPI-ESM1-2-LR already done.
ssp245 MRI-ESM2-0 already done.
ssp245 NESM3 already done.
ssp245 NorESM2-LM already done.
ssp245 NorESM2-MM already done.
ssp245 TaiESM1 already done.
CPU times: user 5.52 ms, sys: 3.76 ms, total: 9.28 ms
Wall time: 23.7 ms


In [11]:
%%time
#########################
## Maximum Temperature ##
#########################
var_id = "tasmax"
metric = "max"

out_path = f"{project_data_path}/metrics/STAR-ESDM/"

# Loop through all
for ssp in ["ssp245"]:
    for model in models:
        # Calculate metric
        calculate_metric(model = model,
                         ssp = ssp,
                         var_id = var_id,
                         metric = metric,
                         star_path = star_path,
                         out_path = out_path)

ssp245 ACCESS-CM2
ssp245 ACCESS-ESM1-5
ssp245 BCC-CSM2-MR
ssp245 CMCC-ESM2
ssp245 CanESM5
ssp245 EC-Earth3
ssp245 EC-Earth3-Veg
ssp245 EC-Earth3-Veg-LR
ssp245 FGOALS-g3
ssp245 GFDL-CM4
ssp245 GFDL-ESM4
ssp245 INM-CM4-8
ssp245 INM-CM5-0
ssp245 IPSL-CM6A-LR
ssp245 KACE-1-0-G
ssp245 KIOST-ESM
ssp245 MIROC6
ssp245 MPI-ESM1-2-HR
ssp245 MPI-ESM1-2-LR
ssp245 MRI-ESM2-0
ssp245 NESM3
ssp245 NorESM2-LM
ssp245 NorESM2-MM
ssp245 TaiESM1
CPU times: user 3min 4s, sys: 11.8 s, total: 3min 16s
Wall time: 15min 29s


In [12]:
%%time
#########################
# Maximum Precipitation #
#########################
var_id = "pr"
metric = "max"

out_path = f"{project_data_path}/metrics/STAR-ESDM/"

# Loop through all
for ssp in ["ssp245"]:
    for model in models:
        # Calculate metric
        calculate_metric(model = model,
                         ssp = ssp,
                         var_id = var_id,
                         metric = metric,
                         star_path = star_path,
                         out_path = out_path)

ssp245 ACCESS-CM2
ssp245 ACCESS-ESM1-5
ssp245 BCC-CSM2-MR
ssp245 CMCC-ESM2
ssp245 CanESM5
ssp245 EC-Earth3
ssp245 EC-Earth3-Veg
ssp245 EC-Earth3-Veg-LR
ssp245 FGOALS-g3
ssp245 GFDL-CM4
ssp245 GFDL-ESM4
ssp245 INM-CM4-8
ssp245 INM-CM5-0
ssp245 IPSL-CM6A-LR
ssp245 KACE-1-0-G
ssp245 KIOST-ESM
ssp245 MIROC6
ssp245 MPI-ESM1-2-HR
ssp245 MPI-ESM1-2-LR
ssp245 MRI-ESM2-0
ssp245 NESM3
ssp245 NorESM2-LM
ssp245 NorESM2-MM
ssp245 TaiESM1
CPU times: user 2min 25s, sys: 10.2 s, total: 2min 35s
Wall time: 13min 19s


# Regrid

In [4]:
# Target grid: the LOCA grid, given by its latitude/longitude values
loca_lat_grid = np.linspace(23.90625, 53.46875, 474)
loca_lon_grid = np.linspace(234.53125, 293.46875, 944)

ds_out = xr.Dataset(
    dict(
        lat=(["lat"], loca_lat_grid, dict(standard_name="latitude", units="degrees_north")),
        lon=(["lon"], loca_lon_grid, dict(standard_name="longitude", units="degrees_east")),
    )
)

# Mask taken from the LOCA output: the stored array is inverted before use
# (presumably it flags the NaN cells, so the mask marks valid ones -- TODO confirm)
loca_nans = np.load(f'{project_code_path}/code/utils/LOCA2_NaNs.npy')
ds_out["mask"] = xr.DataArray(np.invert(loca_nans), dims=['lat','lon'])

# Any one of the STAR metric files supplies the source grid
example_file = glob(f"{project_data_path}/metrics/STAR-ESDM/*.nc")[0]
ds_in = xr.open_dataset(example_file)

# Conservative regridder from the STAR grid to the LOCA grid
conservative_regridder = xe.Regridder(ds_in, ds_out, "conservative")
# nn_s2d_regridder = xe.Regridder(ds_in, ds_out, "nearest_s2d")

In [41]:
%%time

# Out path
out_path = f"{project_data_path}/metrics_regridded/STAR-ESDM/"

# Regridder
regridder_names = ["conservative"]
regridders = [conservative_regridder]

# Metrics
metrics_ids = ["avg_tas", "sum_pr", "max_tasmax", "max_pr"]

# Loop through all
for regridder_name, regridder in zip(regridder_names, regridders):
    for ssp in ["ssp245"]:
        for model in models:
            for metric_id in metrics_ids:
                if not os.path.isfile(f"{out_path}/{regridder_name}/{metric_id}_{model}_{ssp}.nc"):
                    # Read
                    metric, var_id = metric_id.split('_')
                    file_path = glob(f"{project_data_path}/metrics/STAR-ESDM/{metric_id}.downscaled.{model}.r1i1p1f1.{var_id}.{ssp}*.nc")[0]
                    ds_star_in = xr.open_dataset(file_path)
                    ds_star_in = ds_star_in.rename({var_id: metric_id})
        
                    # NOTE: use high NaN threshold to try to not introduce NaNs
                    # not already present in the LOCA2 grid
                    ds_star_out = regridder(ds_star_in, skipna=True, na_thres=0.99)
        
                    # Store
                    ds_star_out.to_netcdf(f"{out_path}/{regridder_name}/{metric_id}_{model}_{ssp}.nc")
                    print(f"{metric_id} {model} {ssp}")

sum_pr ACCESS-CM2 ssp245
max_tasmax ACCESS-CM2 ssp245
max_pr ACCESS-CM2 ssp245
sum_pr ACCESS-ESM1-5 ssp245
max_tasmax ACCESS-ESM1-5 ssp245
max_pr ACCESS-ESM1-5 ssp245
sum_pr BCC-CSM2-MR ssp245
max_tasmax BCC-CSM2-MR ssp245
max_pr BCC-CSM2-MR ssp245
sum_pr CMCC-ESM2 ssp245
max_tasmax CMCC-ESM2 ssp245
max_pr CMCC-ESM2 ssp245
sum_pr CanESM5 ssp245
max_tasmax CanESM5 ssp245
max_pr CanESM5 ssp245
sum_pr EC-Earth3 ssp245
max_tasmax EC-Earth3 ssp245
max_pr EC-Earth3 ssp245
sum_pr EC-Earth3-Veg ssp245
max_tasmax EC-Earth3-Veg ssp245
max_pr EC-Earth3-Veg ssp245
sum_pr EC-Earth3-Veg-LR ssp245
max_tasmax EC-Earth3-Veg-LR ssp245
max_pr EC-Earth3-Veg-LR ssp245
sum_pr FGOALS-g3 ssp245
max_tasmax FGOALS-g3 ssp245
max_pr FGOALS-g3 ssp245
sum_pr GFDL-CM4 ssp245
max_tasmax GFDL-CM4 ssp245
max_pr GFDL-CM4 ssp245
sum_pr GFDL-ESM4 ssp245
max_tasmax GFDL-ESM4 ssp245
max_pr GFDL-ESM4 ssp245
sum_pr INM-CM4-8 ssp245
max_tasmax INM-CM4-8 ssp245
max_pr INM-CM4-8 ssp245
sum_pr INM-CM5-0 ssp245
max_tasmax INM-CM5-

# Summaries

## Indices

In [43]:
# Simple preprocessing function to add model and year coordinates
def _preprocess(ds):
    # Add model and SSP as coordinates
    model = ds.encoding['source'].split('/')[-1].split('_')[2]
    ds = ds.assign_coords(model = model)

    # Time -> year
    ds['time'] = ds['time'].dt.year

    return ds

In [44]:
# Calculates summary indices for STAR-ESDM ensemble for given SSP
def get_summary_indices(metric_id, ssp, years, out_path, out_str):
    """
    Pool all models and all years of the `years` window and reduce them over
    the `model` and `time` dimensions to three summary indices:
        'mean'    : mean
        '99range' : 99.5th minus 0.5th quantile
        'q99'     : 99th quantile
    The indices are concatenated along `indice` and written to
    {out_path}/{out_str}.nc. Nothing is done if that file already exists.

    `years` is [first_year, last_year], used as a label slice on the year
    coordinate produced by `_preprocess`.
    """
    out_file = f"{out_path}/{out_str}.nc"

    # Check if done
    if os.path.isfile(out_file):
        return

    os.makedirs(out_path, exist_ok=True)

    # Read all; the context manager closes the input files once the result is stored
    with xr.open_mfdataset(f"{project_data_path}/metrics_regridded/STAR-ESDM/conservative/{metric_id}_*_{ssp}.nc", chunks='auto',
                           preprocess=_preprocess, combine='nested', concat_dim=['model']) as ds:

        # Time slice
        ds_sel = ds.sel(time=slice(years[0],years[1]))

        # Quantiles reduce over both model and time, so rechunk once so that
        # neither dimension is split across chunks
        ds_pooled = ds_sel.chunk(dict(model=-1, time=-1))

        ## Summary indices
        # Mean
        ds_mean = ds_sel.mean(dim=['model', 'time']).assign_coords(indice = 'mean')
        # Quantiles
        ds_qlow = ds_pooled.quantile(0.005, dim=['model', 'time'])
        ds_qhigh = ds_pooled.quantile(0.995, dim=['model', 'time'])
        ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice = '99range')

        ds_q99 = ds_pooled.quantile(0.99, dim=['model', 'time']).assign_coords(indice = 'q99')

        # Store
        ds_out = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
        ds_out.to_netcdf(out_file)

In [45]:
%%time
for years in [[2020,2040], [2050,2070], [2080,2100]]:
    for ssp in ['ssp245']:
        for metric_id in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr']:
            get_summary_indices(metric_id = metric_id,
                                ssp = ssp,
                                years = years,
                                out_path=f"{project_data_path}/summary_indices",
                                out_str=f"STAR-ESDM_{ssp}_{str(years[0])}-{str(years[1])}_{metric_id}")

CPU times: user 36.8 s, sys: 4.12 s, total: 40.9 s
Wall time: 3min 59s


## Raw data

In [5]:
# Extracts the per-model values of one metric at one location for given SSP
def get_raw_data(metric_id, ssp, years, lat, lon, out_path, out_str):
    """
    For every model in `models`, take the regridded `metric_id` output at the
    grid cell nearest to (`lat`, `lon`) over the `years` window, and store all
    models together as {out_path}/{out_str}.csv with added `ssp` and `model`
    columns. Nothing is done if that file already exists.

    `years` is [first_year, last_year], used as a label slice on the year.
    A negative `lon` is shifted by 360 before the lookup.
    """
    def read_and_process(metric_id, model, ssp, years, lat, lon):
        # Negative longitudes are shifted onto the 0-360 range
        if lon < 0:
            lon = 360 + lon

        # Read; the context manager closes the file once the values are in the dataframe
        with xr.open_dataset(f"{project_data_path}/metrics_regridded/STAR-ESDM/conservative/{metric_id}_{model}_{ssp}.nc") as ds_tmp:
            ds_tmp['time'] = ds_tmp["time"].dt.year

            # Time slice
            ds_sel = ds_tmp.sel(time=slice(years[0],years[1]))

            # Location selection
            ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')

            # Construct dataframe
            df_tmp = ds_sel.to_dataframe().drop(columns=["lat", "lon"]).reset_index()

        df_tmp["ssp"] = ssp
        df_tmp["model"] = model

        # Return
        return df_tmp

    out_file = f"{out_path}/{out_str}.csv"

    # Check if done
    if os.path.isfile(out_file):
        return

    os.makedirs(out_path, exist_ok=True)

    # Read all: one delayed task per model
    df_delayed = [dask.delayed(read_and_process)(metric_id, model, ssp, years, lat, lon)
                  for model in models]

    # Compute and store
    df_out = dask.compute(*df_delayed)
    pd.concat(df_out).to_csv(out_file, index=False)

In [6]:
%%time
for city in city_list.keys():
    lat, lon = city_list[city]
    for years in [[2020,2040], [2050,2070], [2080,2100]]:
        for ssp in ['ssp245']:
            for metric_id in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr']:
                get_raw_data(metric_id = metric_id, 
                             ssp=ssp, years=years,
                             lat=lat, lon=lon,
                             out_path=f"{project_data_path}/summary_raw",
                             out_str=f"{city}_STAR-ESDM_{ssp}_{str(years[0])}-{str(years[1])}_{metric_id}")

CPU times: user 3.15 s, sys: 278 ms, total: 3.43 s
Wall time: 13.6 s
