In [1]:
import os
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe

from utils import city_list

## Preliminaries

In [2]:
################
#### Paths #####
################
# NOTE: this is run on a different system from other datasets
# Update these for reproduction

# Project output/data root and project code root. Both end with a trailing "/",
# and later cells join sub-paths with another "/", which yields "//" in the
# resulting paths.
project_data_path = "/home/fs01/dcl257/data/conus_comparison_lafferty-etal-2024/"
project_code_path = "/home/fs01/dcl257/code/conus_comparison_lafferty-etal-2024/"
gard_path = "/home/shared/vs498_0001/GARD-LENS" # GARD-LENS raw
# GCM labels as they appear in the raw GARD-LENS file names (GARDLENS_{gcm}_...)
gard_gcms = ['canesm5', 'cesm2', 'ecearth3']

In [3]:
# Check that every GCM has the same number of files for each raw variable and
# record which ensemble members are available for each GCM.
gardlens_info = {}

for gcm in gard_gcms:
    t_mean_files = glob(f"{gard_path}/t_mean/GARDLENS_{gcm}_*.nc")
    t_range_files = glob(f"{gard_path}/t_range/GARDLENS_{gcm}_*.nc")
    pcp_files = glob(f"{gard_path}/pcp/GARDLENS_{gcm}_*.nc")
    assert len(t_mean_files) == len(t_range_files), f"{gcm}: t_mean/t_range file counts differ"
    assert len(t_mean_files) == len(pcp_files), f"{gcm}: t_mean/pcp file counts differ"

    # File names follow GARDLENS_{gcm}_{member}_t_mean_{start}_{end}_CONUS.nc,
    # so the member id is whatever sits between the GCM prefix and "_t_mean".
    prefix = f"GARDLENS_{gcm}_"
    gardlens_info[gcm] = sorted(
        os.path.basename(path)[len(prefix):].split("_t_mean")[0]
        for path in t_mean_files
    )

# Later cells look up the model -> members mapping under the name `gard_info`.
# NOTE(review): assumes the member ids parsed from the file names match the
# `ens` labels in the GARDLENS_*_stats_CONUS.nc files -- confirm.
gard_info = gardlens_info

In [4]:
# Get all model members as "{gcm}_{member}" strings.
# The name is parsed from the file's basename (GARDLENS_{gcm}_{member}_t_mean_...)
# so directory names cannot interfere, and the list is sorted because glob
# returns files in arbitrary order.
models_members = sorted(
    os.path.basename(file)[len("GARDLENS_"):].split("_t_mean")[0]
    for file in glob(f"{gard_path}/t_mean/GARDLENS_*.nc")
)

In [5]:
############
### Dask ###
############
# Start a local Dask cluster with 10 workers and attach a client to it.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
from dask.distributed import LocalCluster
cluster = LocalCluster(n_workers = 10)
client = cluster.get_client()
# Bare expression so the notebook displays the client/cluster summary
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 10
Total threads: 50,Total memory: 92.79 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36845,Workers: 10
Dashboard: /proxy/8787/status,Total threads: 50
Started: Just now,Total memory: 92.79 GiB

0,1
Comm: tcp://127.0.0.1:34493,Total threads: 5
Dashboard: /proxy/36803/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:33087,
Local directory: /tmp/dask-scratch-space/worker-g7tuur48,Local directory: /tmp/dask-scratch-space/worker-g7tuur48

0,1
Comm: tcp://127.0.0.1:35881,Total threads: 5
Dashboard: /proxy/38233/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:42899,
Local directory: /tmp/dask-scratch-space/worker-__g4otal,Local directory: /tmp/dask-scratch-space/worker-__g4otal

0,1
Comm: tcp://127.0.0.1:42973,Total threads: 5
Dashboard: /proxy/38215/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:34597,
Local directory: /tmp/dask-scratch-space/worker-fk12qstj,Local directory: /tmp/dask-scratch-space/worker-fk12qstj

0,1
Comm: tcp://127.0.0.1:40105,Total threads: 5
Dashboard: /proxy/38681/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:34559,
Local directory: /tmp/dask-scratch-space/worker-q4yz2ejc,Local directory: /tmp/dask-scratch-space/worker-q4yz2ejc

0,1
Comm: tcp://127.0.0.1:34125,Total threads: 5
Dashboard: /proxy/38359/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:37955,
Local directory: /tmp/dask-scratch-space/worker-_e7xr_1f,Local directory: /tmp/dask-scratch-space/worker-_e7xr_1f

0,1
Comm: tcp://127.0.0.1:45685,Total threads: 5
Dashboard: /proxy/37317/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:40575,
Local directory: /tmp/dask-scratch-space/worker-q7t2yk28,Local directory: /tmp/dask-scratch-space/worker-q7t2yk28

0,1
Comm: tcp://127.0.0.1:35529,Total threads: 5
Dashboard: /proxy/42669/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:33731,
Local directory: /tmp/dask-scratch-space/worker-4gxa11_f,Local directory: /tmp/dask-scratch-space/worker-4gxa11_f

0,1
Comm: tcp://127.0.0.1:34257,Total threads: 5
Dashboard: /proxy/34573/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:42361,
Local directory: /tmp/dask-scratch-space/worker-5_zxnzog,Local directory: /tmp/dask-scratch-space/worker-5_zxnzog

0,1
Comm: tcp://127.0.0.1:42535,Total threads: 5
Dashboard: /proxy/35371/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:44403,
Local directory: /tmp/dask-scratch-space/worker-995zk1du,Local directory: /tmp/dask-scratch-space/worker-995zk1du

0,1
Comm: tcp://127.0.0.1:35599,Total threads: 5
Dashboard: /proxy/44375/status,Memory: 9.28 GiB
Nanny: tcp://127.0.0.1:39083,
Local directory: /tmp/dask-scratch-space/worker-9gx7v6_l,Local directory: /tmp/dask-scratch-space/worker-9gx7v6_l


# Calculate metrics 

In [6]:
###############################
# Metric calculation function #
###############################
def calculate_metric(model_member, var_id, metric, gard_path, out_path):
    """
    Calculate an annual metric from daily GARD-LENS output and store it as netCDF.

    Inputs:
        model_member: "{gcm}_{member}" string as used in the GARD-LENS file names
        var_id: 't_max' or 't_min' (derived as t_mean +/- t_range / 2), otherwise
                the name of a raw GARD-LENS variable/directory (e.g. 'pcp')
        metric: annual aggregation, one of 'avg', 'max', 'min', 'sum'
        gard_path: directory holding the raw GARD-LENS files
        out_path: directory the result is written to
    Outputs: None. Writes GARDLENS_{model_member}_{metric}_{var_id}_1950_2100_CONUS.nc
        to `out_path` unless that file already exists.

    Errors are not raised: they are written to a log file under the global
    `project_code_path` and reported with a printed message, so a loop over
    many members keeps going.
    """
    # Metric name -> name of the resample reduction to apply
    reducers = {"avg": "mean", "max": "max", "min": "min", "sum": "sum"}
    # Every dataset opened below, so they can be closed when we are done
    opened = []

    def _open(name, time_range):
        ds = xr.open_dataset(
            f"{gard_path}/{name}/GARDLENS_{model_member}_{name}_{time_range}_CONUS.nc",
            chunks='auto',
        )
        opened.append(ds)
        return ds

    try:
        ## First check if already exists
        # NOTE: the output name always says 1950_2100, even for ecearth3 input
        # files (named 1970_2100); kept so existing outputs are still recognised.
        metric_id = f"{metric}_{var_id}"
        out_str = f"GARDLENS_{model_member}_{metric_id}_1950_2100_CONUS.nc"
        if os.path.isfile(f"{out_path}/{out_str}"):
            print(f"{model_member} already done.")
            return None

        if metric not in reducers:
            raise ValueError(f"Unknown metric '{metric}', expected one of {sorted(reducers)}")

        # Read (ecearth3 file names carry a different time range)
        if model_member.split('_')[0] == 'ecearth3':
            time_range = '1970_2100'
        else:
            time_range = '1950_2100'

        if var_id in ['t_max', 't_min']:
            # Calculate tmax, tmin from the daily mean and the daily range
            t_mean = _open('t_mean', time_range)['t_mean']
            half_range = _open('t_range', time_range)['t_range'] / 2.0
            if var_id == 't_max':
                ds_tmp = xr.Dataset({var_id: t_mean + half_range})
            else:
                ds_tmp = xr.Dataset({var_id: t_mean - half_range})
        else:
            # Use the member's own time range here too (was hard-coded to
            # 1950_2100) and read lazily like the temperature branch
            ds_tmp = _open(var_id, time_range)

        # Select only var_id
        ds_tmp = ds_tmp[[var_id]]

        # Calculate metric (annual aggregation)
        ds_out = getattr(ds_tmp.resample(time="YE"), reducers[metric])()
        if metric == "sum" and var_id == "pcp":
            ds_out.pcp.attrs["units"] = "mm"

        # Store
        ds_out.to_netcdf(f"{out_path}/{out_str}")
        print(f"{model_member}")

    # Log if error
    except Exception as e:
        except_path = f"{project_code_path}/code/logs"
        os.makedirs(except_path, exist_ok=True)
        with open(f"{except_path}/{model_member}_{var_id}_GARDLENS.txt", "w") as f:
            f.write(str(e))
        # Make the failure visible in the cell output as well as in the log
        print(f"{model_member} failed: {e}")
    finally:
        for ds in opened:
            ds.close()

In [7]:
%%time
##############################
# Maximum temperature
##############################
var_id = 't_max'
metric = 'max'

out_path = f'{project_data_path}/metrics/GARD-LENS/'

for model_member in models_members:
    calculate_metric(model_member, var_id, metric, gard_path, out_path)

cesm2_1231_11 already done.
cesm2_1251_07 already done.
cesm2_1301_03 already done.
ecearth3_r134i1p1f1
canesm5_r18i1p2f1 already done.
cesm2_1281_18 already done.
cesm2_1191_10 already done.
ecearth3_r121i1p1f1
canesm5_r12i1p1f1 already done.
cesm2_1231_04 already done.
cesm2_1251_12 already done.
ecearth3_r148i1p1f1
canesm5_r7i1p2f1 already done.
cesm2_1301_16 already done.
cesm2_1061_04 already done.
ecearth3_r149i1p1f1
canesm5_r6i1p2f1 already done.
canesm5_r13i1p1f1 already done.
cesm2_1231_12 already done.
cesm2_1251_04 already done.
ecearth3_r120i1p1f1
cesm2_1231_07 already done.
cesm2_1251_11 already done.
cesm2_1301_15 already done.
canesm5_r19i1p2f1 already done.
ecearth3_r135i1p1f1
ecearth3_r109i1p1f1
ecearth3_r137i1p1f1
cesm2_1251_14 already done.
cesm2_1231_02 already done.
cesm2_1081_05 already done.
cesm2_1301_10 already done.
canesm5_r25i1p2f1 already done.
cesm2_1251_01 already done.
cesm2_1231_17 already done.
ecearth3_r122i1p1f1
cesm2_1301_05 already done.
canesm5_r4

In [8]:
%%time
##############################
# Minimum temperature
##############################
var_id = 't_min'
metric = 'min'

out_path = f'{project_data_path}/metrics/GARD-LENS/'

for model_member in models_members:
    calculate_metric(model_member, var_id, metric, gard_path, out_path)

cesm2_1231_11 already done.
cesm2_1251_07 already done.
cesm2_1301_03 already done.
ecearth3_r134i1p1f1
canesm5_r18i1p2f1 already done.
cesm2_1281_18 already done.
cesm2_1191_10 already done.
ecearth3_r121i1p1f1
canesm5_r12i1p1f1 already done.
cesm2_1231_04 already done.
cesm2_1251_12 already done.
ecearth3_r148i1p1f1
canesm5_r7i1p2f1 already done.
cesm2_1301_16 already done.
cesm2_1061_04 already done.
ecearth3_r149i1p1f1
canesm5_r6i1p2f1 already done.
canesm5_r13i1p1f1 already done.
cesm2_1231_12 already done.
cesm2_1251_04 already done.
ecearth3_r120i1p1f1
cesm2_1231_07 already done.
cesm2_1251_11 already done.
cesm2_1301_15 already done.
canesm5_r19i1p2f1 already done.
ecearth3_r135i1p1f1
ecearth3_r109i1p1f1
ecearth3_r137i1p1f1
cesm2_1251_14 already done.
cesm2_1231_02 already done.
cesm2_1081_05 already done.
cesm2_1301_10 already done.
canesm5_r25i1p2f1 already done.
cesm2_1251_01 already done.
cesm2_1231_17 already done.
ecearth3_r122i1p1f1
cesm2_1301_05 already done.
canesm5_r4

# Regrid

In [8]:
# Target grid for regridding is the LOCA grid (474 latitudes x 944 longitudes)
loca_lat_grid = np.linspace(23.90625, 53.46875, 474)
loca_lon_grid = np.linspace(234.53125, 293.46875, 944)

ds_out = xr.Dataset(
    dict(
        lat=(["lat"], loca_lat_grid,
             dict(standard_name="latitude", units="degrees_north")),
        lon=(["lon"], loca_lon_grid,
             dict(standard_name="longitude", units="degrees_east")),
    )
)

# Target mask comes from the stored LOCA2 NaN array: True where LOCA2 is not NaN
loca_nans = np.load(f'{project_code_path}/code/utils/LOCA2_NaNs.npy')
ds_out["mask"] = xr.DataArray(~loca_nans, dims=['lat','lon'])

# Source grid: a single year / ensemble slice of a GARD-LENS stats file is
# enough to describe the GARD-LENS grid to the regridder
example_file = f'{gard_path}/GARDLENS_t_mean_stats_CONUS.nc'
ds_in = xr.open_dataset(example_file).isel(year=0, n_ens=0)

# Conservative regridder from the GARD-LENS grid to the LOCA grid
# (a "nearest_s2d" regridder was tried as an alternative and is not used)
conservative_regridder = xe.Regridder(ds_in, ds_out, "conservative")

In [9]:
# Light tidy-up of a GARD-LENS stats dataset
def _preprocess(ds, gard_stat_id, metric_id):
    """
    Unstack the `n_ens` dimension into separate model / ssp / member dimensions
    and return only the `gard_stat_id` variable, renamed to `metric_id`.
    """
    # Old coordinate name -> new dimension name (order defines the index levels)
    relabel = {'gcm': 'model', 'scen': 'ssp', 'ens': 'member'}

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        unstacked = ds.set_index(n_ens=list(relabel)).unstack('n_ens')
        relabelled = unstacked.rename(relabel)

    return relabelled.rename({gard_stat_id: metric_id})[[metric_id]]

# Regridding function
def regrid_gard(ds_in, gard_var_id, gard_stat_id, model, member, metric_id, regridder, regridder_name, out_path):
    """
    Regrid one GARD-LENS model/member from a stats dataset and store the result.

    Inputs:
        ds_in: GARD-LENS stats dataset with `gcm` and `ens` coordinates
        gard_var_id: raw GARD-LENS variable name (not used inside this function)
        gard_stat_id: name of the statistic variable to extract from `ds_in`
        model, member: GCM and ensemble member to select
        metric_id: name given to the extracted variable in the output
        regridder: callable applied as regridder(ds, skipna=True, na_thres=0.99)
        regridder_name: sub-directory of `out_path` the file is written to
        out_path: root output directory
    Outputs: None. Writes {out_path}/{regridder_name}/{metric_id}_{model}_{member}_ssp370.nc
        unless it already exists or the selection is empty.
    """
    out_file = f"{out_path}/{regridder_name}/{metric_id}_{model}_{member}_ssp370.nc"

    # Check if done
    if os.path.exists(out_file):
        return None

    # Select GCM and member
    ds_sel = ds_in.where((ds_in.gcm == model) & (ds_in.ens == member), drop=True)

    # Nothing matched: writing an empty file would mark this member as done
    # forever, so report it and skip instead
    if ds_sel.sizes.get('n_ens', 1) == 0:
        print(f"No GARD-LENS data found for {model} {member}; skipping {metric_id}.")
        return None

    # Tidy
    ds_sel = _preprocess(ds_sel, gard_stat_id, metric_id)

    # Regrid
    # NOTE: use high NaN threshold to try to not introduce NaNs
    # not already present in the LOCA2 grid
    ds_out = regridder(ds_sel, skipna=True, na_thres=0.99)

    # Store (make sure the regridder sub-directory exists first)
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    ds_out.to_netcdf(out_file)

In [10]:
# Run it
out_path = f"{project_data_path}/metrics_regridded/GARD-LENS/"

# One entry per regridded metric:
# (raw GARD-LENS variable, statistic in its stats file, output metric name)
regrid_jobs = [
    ('t_mean', 'mean', 'avg_tas'),  # avg tas
    ('pcp', 'max', 'max_pr'),       # max precip
    ('pcp', 'sum', 'sum_pr'),       # sum precip
]

for gard_var_id, gard_stat_id, metric_id in regrid_jobs:
    ds_in = xr.open_dataset(f'{gard_path}/GARDLENS_{gard_var_id}_stats_CONUS.nc')

    # `gard_info` is expected to map each model to its list of ensemble members
    for model in gard_info.keys():
        for member in gard_info[model]:
            regrid_gard(ds_in = ds_in,
                        gard_var_id = gard_var_id,
                        gard_stat_id = gard_stat_id,
                        model = model,
                        member = member,
                        metric_id = metric_id,
                        regridder = conservative_regridder,
                        regridder_name = 'conservative',
                        out_path = out_path)

# Summaries

## Indices

In [17]:
# Calculates summary indices for the regridded GARD-LENS ensemble of a given model
def get_summary_indices(metric_id, model, years, out_path, out_str):
    """
    Current summary indices calculated: mean, 99th quantile, 99% quantile range
    (0.995 quantile minus 0.005 quantile).
    `years` define the window over which all outputs are pooled; all ensemble
    members and all years inside the window are pooled together.

    Inputs:
        metric_id: metric name used in the regridded file names
        model: model name used in the regridded file names
        years: [first_year, last_year] of the window (inclusive slice)
        out_path, out_str: result is written to {out_path}/{out_str}.nc
    Outputs: None. Nothing is done if the output file already exists.
    """
    out_file = f"{out_path}/{out_str}.nc"

    # Check if done
    if os.path.isfile(out_file):
        return None

    pooled_dims = ['member', 'time']

    # Read all members; the context manager closes the files once written
    with xr.open_mfdataset(f"{project_data_path}/metrics_regridded/GARD-LENS/conservative/{metric_id}_{model}_*.nc", chunks='auto') as ds:
        # Time slice
        ds_sel = ds.rename({'year':'time'}).sel(time=slice(years[0],years[1]))

        # Quantiles over dask arrays need every pooled dimension in a single
        # chunk; rechunk once (member AND time) instead of once per quantile
        ds_sel = ds_sel.chunk(dict(member=-1, time=-1))

        ## Summary indices
        # Mean
        ds_mean = ds_sel.mean(dim=pooled_dims).assign_coords(indice = 'mean')
        # Quantiles
        ds_qlow = ds_sel.quantile(0.005, dim=pooled_dims)
        ds_qhigh = ds_sel.quantile(0.995, dim=pooled_dims)
        ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice = '99range')

        ds_q99 = ds_sel.quantile(0.99, dim=pooled_dims).assign_coords(indice = 'q99')

        # Store
        ds_out = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
        ds_out.to_netcdf(out_file)

In [None]:
%%time
ssp = 'ssp370'

for years in [[2020,2040], [2050,2070], [2080,2100]]:
    for metric_id in ['avg_tas', 'max_pr', 'sum_pr']:
        for model in gard_info.keys():
            get_summary_indices(metric_id = metric_id,
                                model = model,
                                years = years,
                                out_path=f"{project_data_path}/summary_indices",
                                out_str=f"GARD-LENS_{model}_{ssp}_{years[0]}-{years[1]}_{metric_id}")

## Timeseries

### Raw

In [6]:
# Some small preprocessing for GARD-LENS
# NOTE(review): an equivalent `_preprocess` is also defined in the Regrid
# section of this notebook; whichever cell ran last is the one in effect.
# Consider keeping a single definition.
def _preprocess(ds, gard_stat_id, metric_id):
    """
    Unstack the `n_ens` dimension into model / ssp / member dimensions and
    return only the `gard_stat_id` variable, renamed to `metric_id`.
    """
    # Re-index: turn the gcm/scen/ens coordinates into an index on n_ens,
    # unstack it, then give the new dimensions their output names
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        ds = ds.set_index(n_ens=['gcm', 'scen', 'ens']).unstack('n_ens')
        ds = ds.rename({'gcm':'model', 'scen':'ssp', 'ens':'member'})
    # Rename the statistic and drop every other variable
    ds = ds.rename({gard_stat_id: metric_id})[[metric_id]]

    return ds

In [7]:
# Extracts the values at one location from a GARD-LENS stats dataset for a given GCM
# NOTE(review): `get_raw_data` is defined again, with different signatures,
# further down this notebook, and no visible cell calls this version.
def get_raw_data(ds, gard_stat_id, metric_id, model, years, lat, lon, out_path, out_str):
    """
    Select `model` from `ds`, take the grid cell nearest to (lat, lon), tidy it
    with `_preprocess`, optionally slice to `years`, and write the result to
    {out_path}/{out_str}.csv. Nothing is done if that file already exists.

    Inputs:
        ds: dataset with a `gcm` coordinate and lat/lon coordinates
        gard_stat_id: variable in `ds` to extract
        metric_id: name given to that variable in the output
        model: value of `gcm` to keep
        years: [first, last] passed to a `time` slice, or None for no slicing
        lat, lon: location; the nearest grid cell is used
        out_path, out_str: output directory and file stem
    Outputs: None (writes a CSV).
    """
    # Check if done:
    if not os.path.isfile(f"{out_path}/{out_str}.csv"):
        
        # Select GCM
        ds = ds.where((ds.gcm == model), drop=True)
        
        # Location selection first
        ds_sel = ds.sel(lat=lat, lon=lon, method='nearest')
        
        # Tidy
        ds_sel = _preprocess(ds_sel, gard_stat_id, metric_id)
        
        # Time slice
        # NOTE(review): this slices on `time`, while other cells treat the
        # stats files' time axis as `year` (renaming it first) -- confirm.
        if years is not None:
            ds_sel = ds_sel.sel(time=slice(years[0],years[1]))
        
        # Construct dataframe: rows with missing values are dropped, as are the
        # lat/lon columns; ssp and model are set to fixed values
        df_out = ds_sel.to_dataframe().reset_index().dropna().drop(columns=["lat", "lon"])
        df_out["ssp"] = 'ssp370'
        df_out["model"] = model
            
        # Store
        df_out.to_csv(f"{out_path}/{out_str}.csv", index=False)

In [15]:
# Builds a per-location timeseries from the annual GARD-LENS metric files of one GCM
def get_raw_data(gard_metric_id, metric_id, model, years, lat, lon, out_path, out_str):
    """
    Read every annual metric file of `model`, take the grid cell nearest to
    (lat, lon) and write all members to {out_path}/{out_str}.csv.

    Inputs:
        gard_metric_id: metric name as used in the GARD-LENS metric file names
        metric_id: column name to rename `gard_metric_id` to in the output
        model: GCM whose files are read
        years: [first, last] year to keep, or None to keep all years
        lat, lon: location; the nearest grid cell is used
        out_path, out_str: output directory and file stem
    Outputs: None. Nothing is done if the CSV already exists. Files that cannot
        be processed are logged under the global `project_code_path` and
        skipped; if no file yields data, nothing is written.
    """
    def read_and_process(file_path, years, lat, lon):
        """Return the dataframe for one metric file, or None if it failed."""
        try:
            # Read (closed again as soon as the dataframe has been built)
            with xr.open_dataset(file_path) as ds_full:
                ds_full['time'] = ds_full['time'].dt.year

                # Time slice
                if years is not None:
                    ds_sel = ds_full.sel(time=slice(years[0],years[1]))
                else:
                    ds_sel = ds_full

                # Location selection
                ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')

                # Construct dataframe
                df_tmp = ds_sel.to_dataframe().drop(columns=["lat", "lon"]).reset_index()
            df_tmp["ssp"] = "ssp370"

            # Get model, member from the file's basename, which looks like
            # GARDLENS_{model}_{member}_{gard_metric_id}_... ; the member may
            # itself contain underscores (cesm2)
            stem = os.path.basename(file_path)[len("GARDLENS_"):].split(f"_{gard_metric_id}")[0]
            file_model, member = stem.split('_', 1)
            df_tmp["model"] = file_model
            df_tmp["member"] = member

            # Return
            return df_tmp
        # Log if error
        except Exception as e:
            except_path = f"{project_code_path}/code/logs"
            os.makedirs(except_path, exist_ok=True)
            with open(f"{except_path}/{file_path.split('/')[-1]}", "w") as f:
                f.write(str(e))
            print(f"{file_path} failed: {e}")
            return None

    out_file = f"{out_path}/{out_str}.csv"

    # Check if done
    if os.path.isfile(out_file):
        return None

    # Read all (sorted so the row order of the CSV is reproducible)
    file_paths = sorted(glob(f"{project_data_path}/metrics/GARD-LENS/GARDLENS_{model}_*_{gard_metric_id}_*.nc"))
    frames = [read_and_process(file_path, years, lat, lon) for file_path in file_paths]
    # Failed reads come back as None and were already logged
    frames = [df_tmp for df_tmp in frames if df_tmp is not None]

    if not frames:
        print(f"No data available for {out_str}; nothing written.")
        return None

    # Combine and store
    # NOTE(review): this rename only has an effect if the files' variable is
    # named `gard_metric_id` -- confirm against the metric files.
    df_out = pd.concat(frames).rename(columns={gard_metric_id: metric_id})
    df_out.to_csv(out_file, index=False)

In [16]:
# Run it
out_path = f"{project_data_path}/summary_timeseries/"

# (metric name in the GARD-LENS metric files, metric name used in the output)
metric_pairs = [
    ('max_t_max', 'max_tasmax'),  # max tasmax
    ('min_t_min', 'min_tasmin'),  # min tasmin
]

for gard_metric_id, metric_id in metric_pairs:
    for model in gard_gcms:
        for city in ['chicago', 'nyc', 'denver']:
            lat, lon = city_list[city]
            get_raw_data(gard_metric_id = gard_metric_id,
                         metric_id = metric_id,
                         model = model,
                         years = None,
                         lat=lat, lon=lon,
                         out_path = out_path,
                         out_str = f"{city}_GARD-LENS_{model}_ssp370_{metric_id}")

### Regridded

In [24]:
# Extracts the values at one location from the regridded GARD-LENS output of a given GCM
def get_raw_data(metric_id, model, years, lat, lon, out_path, out_str):
    """
    Read all regridded (conservative) files of `metric_id` for `model`, take the
    grid cell nearest to (lat, lon) and write the values to {out_path}/{out_str}.csv.

    Inputs:
        metric_id: metric name used in the regridded file names
        model: model name used in the regridded file names
        years: [first, last] year to keep, or None to keep all years
        lat, lon: location; a negative longitude is shifted by +360 first
        out_path, out_str: output directory and file stem
    Outputs: None. Nothing is done if the CSV already exists.
    """
    out_file = f"{out_path}/{out_str}.csv"

    # Check if done
    if os.path.isfile(out_file):
        return None

    # The regridding target grid uses longitudes in degrees east (234-293),
    # so negative longitudes are wrapped before the lookup
    lon_east = 360 + lon if lon < 0 else lon

    # Read all; the context manager closes the files once the frame is built
    with xr.open_mfdataset(f"{project_data_path}/metrics_regridded/GARD-LENS/conservative/{metric_id}_{model}_*.nc", chunks='auto') as ds:
        # Time slice (skipped when no window is given)
        ds_sel = ds.rename({'year':'time'})
        if years is not None:
            ds_sel = ds_sel.sel(time=slice(years[0],years[1]))

        # Location selection
        ds_sel = ds_sel.sel(lat=lat, lon=lon_east, method='nearest')

        # Construct dataframe
        df_out = ds_sel.to_dataframe().drop(columns=["lat", "lon"]).reset_index()

    df_out["ssp"] = 'ssp370'
    df_out["model"] = model

    # Store
    df_out.to_csv(out_file, index=False)

In [29]:
%%time
for city in city_list.keys():
    lat, lon = city_list[city]
    for years in [[2020,2040], [2050,2070], [2080,2100]]:
        for metric_id in ['avg_tas', 'sum_pr', 'max_pr']:
            for model in gard_info.keys():
                get_raw_data(metric_id = metric_id, 
                             model = model,
                             years=years,
                             lat=lat, lon=lon,
                             out_path=f"{project_data_path}/summary_raw",
                             out_str=f"{city}_GARD-LENS_{model}_ssp370_{years[0]}-{years[1]}_{metric_id}")

CPU times: user 3min 23s, sys: 31.6 s, total: 3min 55s
Wall time: 16min 26s
