In [1]:
import os
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr

from utils import city_list
import metric_funcs as mf

## Preliminaries

In [2]:
################
#### Paths #####
################
# Update these for reproduction
# These are machine-specific absolute paths. Note that the two project paths end
# with "/" while `loca_path` does not; the rest of the notebook appends "/..."
# to each of them.

# Root for everything this notebook writes (metrics/LOCA2, summary_indices, summary_raw*)
project_data_path = "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/"
# Code checkout; error logs are written under <project_code_path>/code/logs
project_code_path = "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/"
loca_path = "/storage/group/pches/default/public/LOCA2" # raw loca outputs

In [3]:
##############
### Models ###
##############

# Entries under the LOCA2 root that are not GCM directories
non_model_dirs = {'training_data', 'scripts'}

# os.listdir returns entries in arbitrary order, so sort for a reproducible
# processing order. Filtering (instead of list.remove) also avoids a ValueError
# when one of the non-model directories is absent.
gcms = sorted(entry for entry in os.listdir(f"{loca_path}/") if entry not in non_model_dirs)

# Nested lookup of what is available on disk:
# loca_all[gcm][member] -> list of experiment directory names
loca_all = {}

# Loop through gcms
for gcm in gcms:
    loca_all[gcm] = {}
    # Loop through members
    members = sorted(os.listdir(f"{loca_path}/{gcm}/0p0625deg/"))
    for member in members:
        # Append SSPs
        ssps = sorted(os.listdir(f"{loca_path}/{gcm}/0p0625deg/{member}/"))
        loca_all[gcm][member] = ssps

In [4]:
##############
### Models ###
##############

# Matches website (https://loca.ucsd.edu/loca-version-2-for-north-america-ca-jan-2023/) as of Jan 2023
print(f"# gcm: {len(gcms)}")

# Distinct experiments per GCM, pooled over all of that GCM's members
expts_per_gcm = [
    len({expt for member_expts in loca_all[gcm].values() for expt in member_expts})
    for gcm in gcms
]
print(f"# gcm/expts: {np.sum(expts_per_gcm)}")

# One entry per (GCM, member): how many experiments that member provides
expts_per_member = [
    len(member_expts) for gcm in gcms for member_expts in loca_all[gcm].values()
]
print(f"# gcm/expts/ens: {np.sum(expts_per_member)}")

# gcm: 27
# gcm/expts: 99
# gcm/expts/ens: 329


In [5]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Resources requested for each SLURM job.
# (account="pches" is the alternative allocation.)
cluster = SLURMCluster(account="open", cores=1, memory="20GiB", walltime="06:00:00")

# ask for jobs
cluster.scale(jobs=30)

# Connect this notebook to the cluster; the bare name below displays its summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.8.19:37907,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


2024-10-28 17:55:12,758 - distributed.batched - INFO - Batched Comm Closed <TCP (closed)  local=tcp://10.6.8.19:37907 remote=tcp://10.6.8.19:33174>
Traceback (most recent call last):
  File "/storage/home/dcl5300/miniforge3/envs/climate-stack-2024-10/lib/python3.12/site-packages/distributed/batched.py", line 115, in _background_send
    nbytes = yield coro
             ^^^^^^^^^^
  File "/storage/home/dcl5300/miniforge3/envs/climate-stack-2024-10/lib/python3.12/site-packages/tornado/gen.py", line 766, in run
    value = future.result()
            ^^^^^^^^^^^^^^^
  File "/storage/home/dcl5300/miniforge3/envs/climate-stack-2024-10/lib/python3.12/site-packages/distributed/comm/tcp.py", line 262, in write
    raise CommClosedError()
distributed.comm.core.CommClosedError


# Calculate metrics

In [6]:
## File path function
def make_loca_file_path(loca_path, gcm, member, ssp, var):
    """
    Collect the files holding one downscaled LOCA variable for a given
    GCM / member / experiment.

    Searches `{loca_path}/{gcm}/0p0625deg/{member}/{ssp}/{var}` and returns the
    paths matching `*_v2024*`, or the paths matching `*_v2022*` when there are
    none. Returns an empty list if the directory does not exist.
    """
    var_dir = f"{loca_path}/{gcm}/0p0625deg/{member}/{ssp}/{var}"

    # Nothing to search when this combination has no directory
    if not os.path.isdir(var_dir):
        return []

    # Prefer the latest version; an empty result falls through to the earlier one
    return glob(f"{var_dir}/*_v2024*") or glob(f"{var_dir}/*_v2022*")
    
## Unit conversion
def convert_units(ds):
    """
    Convert variables given in K or kg m-2 s-1 to C or mm/day.

    A variable whose `units` attribute is 'K' has 273.15 subtracted and is
    relabelled 'C'. A variable whose `units` attribute is 'kg m-2 s-1' is
    multiplied by 86400 and relabelled 'mm/day'. Every other variable,
    including one with no `units` attribute at all, is left untouched.

    `ds` is modified in place and also returned.
    """
    for var in ds.keys():
        # .get(): a variable without a units attribute is skipped rather than
        # aborting the whole conversion with a KeyError
        units = ds[var].attrs.get('units')
        if units == 'K':
            ds[var] = ds[var] - 273.15
            # Set after the arithmetic so the label lands on the converted variable
            ds[var].attrs["units"] = 'C'
        elif units == 'kg m-2 s-1':
            ds[var] = ds[var] * 86400
            ds[var].attrs["units"] = 'mm/day'

    return ds

In [7]:
###############################
# Metric calculation function #
###############################
def calculate_metric(metric_func, var_id, needed_vars, gcm, member, ssp, loca_path, out_path):
    """
    Calculate one metric from LOCA output for a single gcm / member / ssp and
    store it, one netCDF file per LOCA2 time slice.

    Inputs:
      metric_func : callable invoked as metric_func(ds, var_id); its result is
                    written with .to_netcdf()
      var_id      : passed through to metric_func and used in the log file name
      needed_vars : LOCA variable names whose files are loaded and merged
      gcm, member, ssp : which run to process; ssp == "historical" selects the
                    single 1950-2014 slice, anything else the three future slices
      loca_path   : root directory of the raw LOCA outputs
      out_path    : output file name ending in '.nc'; each slice is written to
                    out_path with '.nc' replaced by '_<time_slice>.nc'
    Outputs: None. A slice whose output file already exists is skipped.

    Any exception raised while processing a slice is not propagated: its message
    is written to {project_code_path}/code/logs and the next slice is processed.
    """
    # Get all file paths
    files = {}
    for var in needed_vars:
        files[var] = make_loca_file_path(loca_path, gcm, member, ssp, var)

    # Loop through LOCA2 time slices
    if ssp == "historical":
        time_slices = ["1950-2014"]
    else:
        time_slices = ["2015-2044", "2045-2074", "2075-2100"]

    for time_slice in time_slices:
        # Check if done. Skip only this slice: returning here would leave any
        # later, still-missing slices uncomputed.
        save_path = out_path.replace('.nc', f'_{time_slice}.nc')
        if os.path.isfile(save_path):
            continue

        opened = []
        try:
            # Load every needed variable's file for this slice
            for var in needed_vars:
                for file in files[var]:
                    if time_slice in file:
                        opened.append(xr.open_dataset(file, chunks='auto'))
            if len(opened) == 0:
                # Fail with a clear message instead of handing an empty
                # dataset to metric_func
                raise FileNotFoundError(
                    f"No input files found for {needed_vars} in time slice {time_slice}"
                )
            ds_in = xr.merge(opened, combine_attrs='drop_conflicts')
            # Convert units
            ds_in = convert_units(ds_in)

            # Calculate metric
            ds_out = metric_func(ds_in, var_id)

            # Store
            ds_out.to_netcdf(save_path)

        # Log if error
        except Exception as e:
            except_path = f"{project_code_path}/code/logs"
            # Make sure the handler itself cannot fail on a missing log directory
            os.makedirs(except_path, exist_ok=True)
            with open(f"{except_path}/{gcm}_{member}_{ssp}_{var_id}_{time_slice}_LOCA.txt", "w") as f:
                f.write(str(e))
        finally:
            # Release the input file handles whether or not the slice succeeded
            for ds_opened in opened:
                ds_opened.close()

In [8]:
%%time
#########################
## Cooling Degree Days ##
#########################
var_id = "cdd"
metric_func = mf.calculate_dd
needed_vars = ['tasmin', 'tasmax']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 19.5 s, sys: 1.68 s, total: 21.2 s
Wall time: 4min 7s


In [9]:
%%time
#########################
## Heating Degree Days ##
#########################
var_id = "hdd"
metric_func = mf.calculate_dd
needed_vars = ['tasmin', 'tasmax']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 13 s, sys: 978 ms, total: 14 s
Wall time: 3min 30s


In [10]:
%%time
#########################
## Average Temperature ##
#########################
var_id = "tas"
metric_func = mf.calculate_avg
needed_vars = ['tasmin', 'tasmax']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/avg_{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 12.3 s, sys: 962 ms, total: 13.3 s
Wall time: 2min 31s


In [11]:
%%time
#########################
## Maximum Temperature ##
#########################
var_id = "tasmax"
metric_func = mf.calculate_max
needed_vars = ['tasmax']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 10.1 s, sys: 747 ms, total: 10.9 s
Wall time: 1min 30s


In [12]:
%%time
#########################
## Minimum Temperature ##
#########################
var_id = "tasmin"
metric_func = mf.calculate_min
needed_vars = ['tasmin']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/min_{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 10.6 s, sys: 713 ms, total: 11.3 s
Wall time: 1min 28s


In [13]:
%%time
#########################
## Maximum Precip ##
#########################
var_id = "pr"
metric_func = mf.calculate_max
needed_vars = ['pr']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 68 ms, sys: 61.8 ms, total: 130 ms
Wall time: 1.66 s


In [14]:
%%time
################
## Sum Precip ##
################
var_id = "pr"
metric_func = mf.calculate_sum
needed_vars = ['pr']

# Output file for one run; `var_id` is read from the notebook namespace at call time
def out_path(gcm, ssp, member):
    return f"{project_data_path}/metrics/LOCA2/sum_{var_id}_{gcm}_{member}_{ssp}.nc"

# Never appended to in this cell: the calls below run one after another
delayed = []

# Visit every available gcm / member / ssp combination
for gcm, member, ssp in (
    (g, m, s) for g in gcms for m in loca_all[g].keys() for s in loca_all[g][m]
):
    # Calculate metric
    calculate_metric(
        metric_func=metric_func,
        var_id=var_id,
        needed_vars=needed_vars,
        gcm=gcm,
        member=member,
        ssp=ssp,
        loca_path=loca_path,
        out_path=out_path(gcm, ssp, member),
    )

CPU times: user 25.8 ms, sys: 30.7 ms, total: 56.5 ms
Wall time: 335 ms


In [15]:
# Metric calculations are finished: shut down the Dask client created in the Dask cell
client.shutdown()

# Summaries

## Indices

In [5]:
# Simple preprocessing function to add model and year coordinates
def _preprocess(ds):
    # Add model coordinate
    model = ds.encoding['source'].split('/')[-1].split('.')[1]
    ds = ds.assign_coords(model = model)

    # Add member
    member = ds.encoding['source'].split('/')[-1].split('.')[3]
    ds = ds.assign_coords(member = member)

    # Time -> year
    ds['time'] = ds['time'].dt.year

    return ds

In [11]:
# Calculates summary indices for the LOCA2 ensemble for given SSP
def get_summary_indices(metric_id, ssp, years, out_path, out_str):
    """
    Pool a metric over all LOCA2 models, members and years in a window and
    store summary indices of the pooled values.

    Indices written (along a new `indice` dimension):
      'mean'    : mean over model, member and time
      '99range' : 0.995 quantile minus 0.005 quantile
      'q99'     : 0.99 quantile

    Inputs:
      metric_id : file-name prefix of the metric; also the name given to the
                  output variable
      ssp       : experiment to read
      years     : [first, last]; defines the window over which all outputs
                  are pooled
      out_path, out_str : the result is written to {out_path}/{out_str}.nc;
                  nothing is done if that file already exists

    Raises ValueError if no input could be read for any model.

    NOTE(review): the metric cells in this notebook write files named
    `<metric>_<gcm>_<member>_<ssp>_<slice>.nc` (underscores), which the
    dot-separated patterns used here do not match - confirm which naming
    convention is current.
    """
    # Check if done
    out_file = f"{out_path}/{out_str}.nc"
    if os.path.isfile(out_file):
        return

    metrics_dir = f"{project_data_path}/metrics/LOCA2"

    # Read
    ds_models = []
    # `models` is not defined anywhere in this notebook; the keys of `loca_all`
    # are the model names and are what the member lookup below indexes anyway.
    for model in loca_all:
        # Check files exist
        if len(glob(f"{metrics_dir}/{metric_id}.{model}.{ssp}.r*.nc")) == 0:
            continue

        # Read all members
        ds_members = []
        for member in loca_all[model].keys():
            member_pattern = f"{metrics_dir}/{metric_id}.{model}.{ssp}.{member}.*.nc"
            if len(glob(member_pattern)) == 0:
                continue
            try:
                ds_members.append(xr.open_mfdataset(member_pattern, preprocess=_preprocess))
            except Exception as e:
                except_path = f"{project_code_path}/code/logs"
                # Make sure the handler itself cannot fail on a missing log directory
                os.makedirs(except_path, exist_ok=True)
                with open(f"{except_path}/{model}_{member}_{ssp}_{metric_id}_LOCA.txt", "w") as f:
                    f.write(str(e))

        # Combine & append members. This stays inside the per-model loop and
        # only runs when this model produced members, so a model without files
        # can neither hit an undefined `ds_members` nor re-append the previous
        # model's members.
        if len(ds_members) > 0:
            ds_models.append(xr.concat(ds_members, dim="member", fill_value=np.nan))

    if len(ds_models) == 0:
        raise ValueError(f"No LOCA2 files could be read for metric '{metric_id}' and {ssp}")

    # Combine models
    ds = xr.concat(ds_models, dim="model", fill_value=np.nan)

    # Rename
    ds = ds.rename({list(ds.data_vars)[0]: metric_id})

    # Time slice
    ds_sel = ds.sel(time=slice(years[0],years[1]))#.chunk(dict(model=-1, time=-1, member=-1, lat=50, lon=100))

    ## Summary indices
    pooled_dims = ['model', 'time', 'member']
    # Mean
    ds_mean = ds_sel.mean(dim=pooled_dims).assign_coords(indice = 'mean')
    # Quantiles
    ds_qlow = ds_sel.quantile(0.005, dim=pooled_dims)
    ds_qhigh = ds_sel.quantile(0.995, dim=pooled_dims)
    ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice = '99range')

    ds_q99 = ds_sel.quantile(0.99, dim=pooled_dims).assign_coords(indice = 'q99')

    # Store
    ds_out = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
    ds_out.to_netcdf(out_file)

In [None]:
%%time
# One summary file per pooling window, scenario and metric
for years, ssp, metric_id in (
    (window, scenario, name)
    for window in [[2020, 2040], [2050, 2070], [2080, 2100]]
    for scenario in ['ssp245', 'ssp370', 'ssp585']
    for name in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr', 'max_tas']
):
    get_summary_indices(
        metric_id=metric_id,
        ssp=ssp,
        years=years,
        out_path=f"{project_data_path}/summary_indices",
        out_str=f"LOCA2_{ssp}_{years[0]}-{years[1]}_{metric_id}",
    )

## Timeseries

In [6]:
# Extracts the LOCA2 ensemble values of a metric at one location for given SSP
def get_raw_data(metric, ssp, years, lat, lon, out_path, out_str):
    """
    Write a CSV with the values of `metric` at the grid cell nearest to
    (lat, lon), for every LOCA2 model / member that has files for `ssp`.

    Inputs:
      metric : file-name prefix of the metric; also the name of the value column
      ssp    : experiment to read
      years  : [first, last] to keep only that window, or None to keep all years
      lat, lon : target location; a negative lon has 360 added before the lookup
      out_path, out_str : the table is written to {out_path}/{out_str}.csv;
               nothing is done if that file already exists

    The CSV holds one row per model / member / year, with `ssp`, `model` and
    `member` columns added. Raises ValueError if no matching files are found.

    NOTE(review): the metric cells in this notebook write files named
    `<metric>_<gcm>_<member>_<ssp>_<slice>.nc` (underscores), which the
    dot-separated patterns used here do not match - confirm which naming
    convention is current.
    """
    def read_and_process(metric, model, member, ssp, years, lat, lon):
        # Read. Sorted so the time slices are concatenated in file-name order
        # rather than in glob's arbitrary order.
        files = sorted(glob(f"{project_data_path}/metrics/LOCA2/{metric}.{model}.{ssp}.{member}.*.nc"))
        opened = []
        try:
            for file in files:
                opened.append(xr.open_dataset(file))
            ds_tmp = xr.concat(opened, dim='time')
            ds_tmp['time'] = ds_tmp["time"].dt.year

            # Rename
            ds_tmp = ds_tmp.rename({list(ds_tmp.data_vars)[0]: metric})

            # Time slice
            if years is not None:
                ds_sel = ds_tmp.sel(time=slice(years[0],years[1]))
            else:
                ds_sel = ds_tmp

            # Location selection
            if lon < 0:
                lon = 360 + lon
            ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')

            # Construct dataframe (this pulls the selected values into memory,
            # so the files can be closed afterwards)
            df_tmp = ds_sel.to_dataframe().dropna().drop(columns=["lat", "lon"]).reset_index()
        finally:
            # Release the file handles whether or not processing succeeded
            for ds_opened in opened:
                ds_opened.close()

        df_tmp["ssp"] = ssp
        df_tmp["model"] = model
        df_tmp["member"] = member

        # Return
        return df_tmp

    # Check if done
    out_file = f"{out_path}/{out_str}.csv"
    if os.path.isfile(out_file):
        return

    df_delayed = []
    # Loop through models. `models` is not defined anywhere in this notebook;
    # the keys of `loca_all` are the model names and are what the member lookup
    # below indexes anyway.
    for model in loca_all:
        # Loop through members
        for member in loca_all[model].keys():
            # Skip model / member combinations that have no files for this metric and ssp
            check = glob(f"{project_data_path}/metrics/LOCA2/{metric}.{model}.{ssp}.{member}.*.nc")
            if len(check) > 0:
                df_tmp = dask.delayed(read_and_process)(metric, model, member, ssp, years, lat, lon)
                df_delayed.append(df_tmp)

    # pd.concat would fail on an empty input with an unhelpful message
    if len(df_delayed) == 0:
        raise ValueError(f"No LOCA2 files found for metric '{metric}' and {ssp}")

    # Compute and store
    df_out = dask.compute(*df_delayed)
    pd.concat(df_out).to_csv(out_file, index=False)

### Raw

In [7]:
%%time
# Full-length series (years=None keeps every year) for three cities
for city, ssp, metric in (
    (place, scenario, name)
    for place in ['chicago', 'nyc', 'denver']
    for scenario in ['ssp245', 'ssp370', 'ssp585']
    for name in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr', 'max_tas']
):
    lat, lon = city_list[city]
    get_raw_data(
        metric=metric,
        ssp=ssp,
        years=None,
        lat=lat,
        lon=lon,
        out_path=f"{project_data_path}/summary_raw_original_grid/",
        out_str=f"{city}_LOCA2_{ssp}_{metric}",
    )

CPU times: user 1min 2s, sys: 8.31 s, total: 1min 10s
Wall time: 8min 6s


### Regridded

In [None]:
%%time
# One CSV per city, window, scenario and metric
for city, (lat, lon) in city_list.items():
    for years, ssp, metric in (
        (window, scenario, name)
        for window in [[2020, 2040], [2050, 2070], [2080, 2100]]
        for scenario in ['ssp245', 'ssp370', 'ssp585']
        for name in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr', 'max_tas']
    ):
        get_raw_data(
            metric=metric,
            ssp=ssp,
            years=years,
            lat=lat,
            lon=lon,
            out_path=f"{project_data_path}/summary_raw",
            out_str=f"{city}_LOCA2_{ssp}_{years[0]}-{years[1]}_{metric}",
        )