In [1]:
import os
from glob import glob

import dask
import numpy as np
import xarray as xr

## Preliminaries

In [2]:
################
#### Paths #####
################
# Defaults are the original author's cluster locations. For reproduction, either
# edit the defaults below or export the matching environment variable before
# starting the kernel; unset (or empty) variables fall back to the defaults.

# Root folder for project data (metric outputs are written beneath it)
project_data_path = (
    os.environ.get("CONUS_PROJECT_DATA_PATH")
    or "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/"
)
# Root folder of the project code (error logs are written to <root>/code/logs)
project_code_path = (
    os.environ.get("CONUS_PROJECT_CODE_PATH")
    or "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/"
)
# location of NEX-GDDP models
nex_path = (
    os.environ.get("CONUS_NEX_PATH")
    or "/storage/group/pches/default/public/NEX-GDDP-CMIP6/models/"
)

In [3]:
##############
### Models ###
##############
# Map every model directory under nex_path to the part of its file names that sits
# between "<var>_day_<model>_ssp126" and the year, read from its 2015 ssp126 tasmax file.

model_info = {}
for model in sorted(os.listdir(nex_path)):  # sorted -> reproducible model order
    matches = glob(f"{nex_path}/{model}/ssp126/tasmax/*_2015.nc")
    if not matches:
        # No 2015 ssp126 tasmax file here, so this entry is skipped
        continue
    model_info[model] = (
        matches[0]
        .replace(f"{nex_path}/{model}", "")
        .replace("/ssp126/tasmax/tasmax_day_" + model + "_ssp126", "")
        .replace("2015.nc", "")
    )

print(f"# models: {len(model_info)}")

# models: 29


In [4]:
###############################
# Metric calculation function #
###############################
def calculate_metric(model, ssp, year, var, metric, model_info, nex_path, out_path):
    """
    Calculate one annual metric from a single NEX-GDDP-CMIP6 daily file and store it.

    Inputs:
        model, ssp, year, var: select the daily input file below nex_path
        metric: annual aggregation to apply: "avg" (mean), "max" or "sum"
        model_info: part of the file name between the SSP and the year
        nex_path: folder containing one sub-folder per model
        out_path: folder the result is written to
    Outputs: always returns None. The result is written to
        {out_path}/{metric}_{var}_day_{model}_{ssp}{model_info}{year}.nc.
        Nothing is done if that file already exists or if no input file is found
        (the "_v1.1" version of the input is preferred when present).
        Exceptions are not raised: they are appended to a text file in
        {project_code_path}/code/logs (project_code_path is a notebook global).
    """
    ## First check if done
    out_str = f"{var}_day_{model}_{ssp}{model_info}{str(year)}.nc"
    out_file = f"{out_path}/{metric}_{out_str}"
    if os.path.isfile(out_file):
        return None

    try:
        ## Read correct file (use v1.1 if available)
        file_path = f"{nex_path}/{model}/{ssp}/{var}/{var}_day_{model}_{ssp}{model_info}{str(year)}_v1.1.nc"
        if not os.path.isfile(file_path):
            file_path = file_path.replace("_v1.1", "")
            if not os.path.isfile(file_path):
                return None

        # The context manager closes the input file even if a later step fails
        with xr.open_dataset(file_path) as ds:
            ## Convert units
            # Temperature: K -> C (same offset for all three temperature variables)
            if var in ("tas", "tasmax", "tasmin") and ds[var].attrs["units"] == "K":
                ds[var] = ds[var] - 273.15
                ds[var].attrs["units"] = "degC"

            # Precip: kg m-2 s-1 -> mm day-1
            if var == "pr" and ds[var].attrs["units"] == "kg m-2 s-1":
                ds[var] = ds[var] * 86400
                ds[var].attrs["units"] = "mm/day"

            ## Calculate metric
            if metric == "avg":
                ds_out = ds.resample(time="1Y").mean()
            elif metric == "max":
                ds_out = ds.resample(time="1Y").max()
            elif metric == "sum":
                ds_out = ds.resample(time="1Y").sum()
            else:
                raise ValueError(f"Unknown metric '{metric}': expected 'avg', 'max' or 'sum'")

            ## Store
            # Write to a temporary name and rename afterwards, so an interrupted
            # write never leaves a partial file that the "done" check would accept
            tmp_file = f"{out_file}.tmp"
            ds_out.to_netcdf(tmp_file)
            os.replace(tmp_file, out_file)

    # Log if error (one failing file must not stop the remaining computations)
    except Exception as e:
        except_path = f"{project_code_path}/code/logs"
        # Append with metric and year so failures do not overwrite each other
        with open(f"{except_path}/{model}_{ssp}_{var}_NEX.txt", "a") as f:
            f.write(f"{metric} {year}: {e}\n")

In [5]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Settings passed to SLURMCluster for the jobs it submits
slurm_job_spec = dict(
    account="pches",
    # account="open",
    cores=1,
    memory="12GiB",
    walltime="00:40:00",
)

cluster = SLURMCluster(**slurm_job_spec)
cluster.scale(jobs=30)  # ask for jobs

# Connect this notebook to the cluster; the last line displays the client summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.159:37047,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Calculate metrics

In [6]:
%%time
#########################
## Average Temperature ##
#########################
var, metric = "tas", "avg"
out_path = f"{project_data_path}/metrics/NEX-GDDP-CMIP6/"

# Build one lazy task per (model, SSP, year); nothing runs until dask.compute
models = list(model_info.keys())
lazy_metric = dask.delayed(calculate_metric)
delayed = [
    lazy_metric(
        model=model,
        ssp=ssp,
        year=year,
        var=var,
        metric=metric,
        model_info=model_info[model],
        nex_path=nex_path,
        out_path=out_path,
    )
    for model in models
    for ssp in os.listdir(f"{nex_path}/{model}")
    if ssp != "historical"  # the historical folder is not processed here
    for year in range(2015, 2101)
]

# Compute
print(f"# computations: {len(delayed)} \n")
_ = dask.compute(*delayed)

# computations: 9546 

CPU times: user 13.4 s, sys: 569 ms, total: 14 s
Wall time: 29.4 s


In [44]:
%%time
#########################
## Maximum Temperature ##
#########################
var, metric = "tasmax", "max"
out_path = f"{project_data_path}/metrics/NEX-GDDP-CMIP6/"

# Build one lazy task per (model, SSP, year); nothing runs until dask.compute
models = list(model_info.keys())
lazy_metric = dask.delayed(calculate_metric)
delayed = [
    lazy_metric(
        model=model,
        ssp=ssp,
        year=year,
        var=var,
        metric=metric,
        model_info=model_info[model],
        nex_path=nex_path,
        out_path=out_path,
    )
    for model in models
    for ssp in os.listdir(f"{nex_path}/{model}")
    if ssp != "historical"  # the historical folder is not processed here
    for year in range(2015, 2101)
]

# Compute
print(f"# computations: {len(delayed)} \n")
_ = dask.compute(*delayed)

# computations: 9546 

CPU times: user 2min 16s, sys: 10.9 s, total: 2min 27s
Wall time: 43min


In [7]:
%%time
#########################
## Total Precipitation ##
#########################
var, metric = "pr", "sum"
out_path = f"{project_data_path}/metrics/NEX-GDDP-CMIP6/"

# Build one lazy task per (model, SSP, year); nothing runs until dask.compute
models = list(model_info.keys())
lazy_metric = dask.delayed(calculate_metric)
delayed = [
    lazy_metric(
        model=model,
        ssp=ssp,
        year=year,
        var=var,
        metric=metric,
        model_info=model_info[model],
        nex_path=nex_path,
        out_path=out_path,
    )
    for model in models
    for ssp in os.listdir(f"{nex_path}/{model}")
    if ssp != "historical"  # the historical folder is not processed here
    for year in range(2015, 2101)
]

# Compute
print(f"# computations: {len(delayed)} \n")
_ = dask.compute(*delayed)

# computations: 9546 

CPU times: user 12min 45s, sys: 45.9 s, total: 13min 31s
Wall time: 1h 10min 12s


In [6]:
%%time
#########################
## Max. Precipitation ##
#########################
var, metric = "pr", "max"
out_path = f"{project_data_path}/metrics/NEX-GDDP-CMIP6/"

# Build one lazy task per (model, SSP, year); nothing runs until dask.compute
models = list(model_info.keys())
lazy_metric = dask.delayed(calculate_metric)
delayed = [
    lazy_metric(
        model=model,
        ssp=ssp,
        year=year,
        var=var,
        metric=metric,
        model_info=model_info[model],
        nex_path=nex_path,
        out_path=out_path,
    )
    for model in models
    for ssp in os.listdir(f"{nex_path}/{model}")
    if ssp != "historical"  # the historical folder is not processed here
    for year in range(2015, 2101)
]

# Compute
print(f"# computations: {len(delayed)} \n")
_ = dask.compute(*delayed)

# computations: 9546 

CPU times: user 1min 15s, sys: 3.2 s, total: 1min 19s
Wall time: 4min 28s
