In [7]:
import os
from contextlib import ExitStack
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr

from utils import city_list

## Preliminaries

In [2]:
################
#### Paths #####
################
# Defaults are the original cluster locations. To reproduce elsewhere, set the
# environment variables below (or edit the defaults) instead of the code.

project_data_path = os.environ.get(
    "CONUS_COMPARISON_DATA_PATH",
    "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/",
)
project_code_path = os.environ.get(
    "CONUS_COMPARISON_CODE_PATH",
    "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/",
)
loca_path = os.environ.get(
    "LOCA2_PATH",
    "/storage/group/pches/default/public/LOCA2",
)  # raw loca outputs

In [3]:
##############
### Models ###
##############
# Entries at the top of the LOCA2 archive that are not models
non_model_dirs = {"training_data", "scripts"}

# Sorted so model/member/SSP order (and everything later concatenated from it)
# does not depend on the filesystem's listing order; non-directories are ignored.
models = sorted(
    name for name in os.listdir(f"{loca_path}/")
    if name not in non_model_dirs and os.path.isdir(f"{loca_path}/{name}")
)

# loca_all[model][member] -> list of experiment directory names (SSPs, historical)
loca_all = {}
for model in models:
    members = sorted(os.listdir(f"{loca_path}/{model}/0p0625deg/"))
    loca_all[model] = {
        member: sorted(os.listdir(f"{loca_path}/{model}/0p0625deg/{member}/"))
        for member in members
    }

In [4]:
##############
### Models ###
##############

# Matches website (https://loca.ucsd.edu/loca-version-2-for-north-america-ca-jan-2023/) as of Jan 2023

# Distinct experiments per model: union of the experiment lists of all its members
n_expts_per_model = [
    len(np.unique([expt for member in loca_all[model].keys() for expt in loca_all[model][member]]))
    for model in models
]

# Experiment count of every model/member pair (sums to all model/expt/member runs)
n_expts_per_member = [
    len(loca_all[model][member])
    for model in models
    for member in loca_all[model]
]

print(f"# models: {len(models)}")
print(f"# model/expts: {np.sum(n_expts_per_model)}")
print(f"# model/expts/ens: {np.sum(n_expts_per_member)}")

# models: 27
# model/expts: 99
# model/expts/ens: 329


In [5]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Each SLURM job is one single-core worker with 5 GiB and a 30-minute limit.
# Alternative allocation: account="open".
cluster = SLURMCluster(
    account="pches",
    cores=1,
    memory="5GiB",
    walltime="00:30:00",
)
cluster.scale(jobs=20)  # ask for jobs

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.157:45441,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Calculate metrics

In [6]:
## File path function
def make_loca_file_path(loca_path, model, member, ssp, var):
    """
    List the LOCA2 output files for one model/member/experiment/variable.

    Looks in f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/{var}".

    Returns
    -------
    list of str
        Sorted file *names* (not full paths), excluding names ending in
        'ORIG.nc'. An empty list if the directory does not exist (some
        variables are missing for some outputs).
    """
    var_dir = f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/{var}"
    if not os.path.isdir(var_dir):
        return []
    # Skip ORIGs (left over from fixing tasmin naming errors). Sorted so the
    # processing order is deterministic rather than filesystem-dependent.
    return sorted(name for name in os.listdir(var_dir) if not name.endswith('ORIG.nc'))

In [7]:
###############################
# Metric calculation function #
###############################
def calculate_metric(model, member, ssp, var, metric, loca_path, out_path):
    """
    Compute an annual metric for every LOCA2 file of one model/member/SSP/variable.

    One output NetCDF is written per input file, to f"{out_path}/{metric}_{file}"
    ("tasmin" is replaced by "tas" in the name when var == "tas"). Outputs that
    already exist are skipped, so the function can be re-run to resume.

    Parameters
    ----------
    model, member, ssp : str
        Directory names under f"{loca_path}/{model}/0p0625deg/{member}/{ssp}".
    var : str
        LOCA2 variable. "tas" is derived as (tasmin + tasmax) / 2 from each
        tasmin file and the tasmax file of the same name. "tas", "tasmax" and
        "tasmin" are shifted by -273.15 and labelled "C"; "pr" is multiplied by
        86400 and labelled "mm/day"; other variables are used as read.
    metric : str
        Yearly aggregation (resample time="1Y"): "avg" (mean), "max" or "sum".
        For var == "pr" with "sum", units are relabelled "mm".
    loca_path : str
        Root directory of the raw LOCA2 outputs.
    out_path : str
        Directory for the metric files.

    Returns
    -------
    None. Per-file errors (including an unknown `metric`) are not raised; they
    are appended to f"{project_code_path}/code/logs/{model}_{member}_{ssp}_{var}_LOCA.txt".
    """
    # "tas" is built from tasmin/tasmax pairs, so its files are listed from tasmin
    src_var = "tasmin" if var == "tas" else var
    src_dir = f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/{src_var}"
    files = make_loca_file_path(loca_path, model, member, ssp, src_var)

    # xarray resample method used for each supported metric
    resample_methods = {"avg": "mean", "max": "max", "sum": "sum"}

    ## Loop through files
    for file in files:
        try:
            out_str = file.replace("tasmin", "tas") if var == "tas" else file
            out_file = f"{out_path}/{metric}_{out_str}"

            ## First check if already exists
            if os.path.isfile(out_file):
                continue
            if metric not in resample_methods:
                raise ValueError(f"Unknown metric '{metric}'; expected one of {sorted(resample_methods)}")

            file_path = f"{src_dir}/{file}"
            # ExitStack closes every opened dataset, even if a later step fails
            with ExitStack() as stack:
                # Read & calculate
                if var == "tas":
                    ds_tasmin = stack.enter_context(xr.open_dataset(file_path))
                    ds_tasmax = stack.enter_context(xr.open_dataset(file_path.replace("tasmin", "tasmax")))
                    ds_tmp = xr.merge([ds_tasmin, ds_tasmax], combine_attrs="drop_conflicts")
                    ds_tmp["tas"] = (ds_tmp["tasmin"] + ds_tmp["tasmax"]) / 2.0
                    ds_tmp = ds_tmp.drop_vars(["tasmin", "tasmax"])
                else:
                    ds_tmp = stack.enter_context(xr.open_dataset(file_path))

                # Convert units (inputs presumably K and kg m-2 s-1; not checked against file attrs)
                if var in ["tas", "tasmax", "tasmin"]:
                    ds_tmp[var] = ds_tmp[var] - 273.15
                    ds_tmp[var].attrs["units"] = "C"
                elif var == "pr":
                    ds_tmp[var] = ds_tmp[var] * 86400
                    ds_tmp[var].attrs["units"] = "mm/day"

                # Calculate metric
                ds_out = getattr(ds_tmp.resample(time="1Y"), resample_methods[metric])()
                if metric == "sum" and var == "pr":
                    ds_out["pr"].attrs["units"] = "mm"

                # Store under a temporary name first: a job killed mid-write (e.g. at
                # the SLURM walltime) must not leave a partial file that the
                # existence check above would later treat as finished.
                tmp_file = f"{out_file}.tmp"
                ds_out.to_netcdf(tmp_file)
            os.replace(tmp_file, out_file)

        # Log if error: append (not overwrite) so every failure is kept, with its metric and file
        except Exception as e:
            except_path = f"{project_code_path}/code/logs"
            with open(f"{except_path}/{model}_{member}_{ssp}_{var}_LOCA.txt", "a") as f:
                f.write(f"{metric} {file}: {e}\n")

In [8]:
%%time
#########################
## Average Temperature ##
#########################
var = "tas"
metric = "avg"

out_path = f"{project_data_path}/metrics/LOCA2/"

# Parallelize over dask delayed
delayed = []

# Loop through models
for model in models:
    # Loop through members
    for member in loca_all[model].keys():
        # Loop through SSPs
        for ssp in loca_all[model][member]:
            if ssp == "historical":
                continue
            # Some vars are missing for some outputs: skip
            file_paths = make_loca_file_path(loca_path, model, member, ssp, "tasmin")
            if len(file_paths) == 0:
                print(f"{model} {ssp} {member}")
                    
            # Calculate metric
            delayed.append(dask.delayed(calculate_metric)(model = model,
                                                          member = member,
                                                          ssp = ssp,
                                                          var = var,
                                                          metric = metric,
                                                          loca_path = loca_path,
                                                          out_path = out_path))
                
# Compute
print(f"# computations: {len(delayed)} \n")
out = dask.compute(*delayed)

MPI-ESM1-2-LR ssp585 r10i1p1f1
MPI-ESM1-2-LR ssp585 r5i1p1f1
MPI-ESM1-2-LR ssp585 r6i1p1f1
MPI-ESM1-2-LR ssp585 r7i1p1f1
MPI-ESM1-2-LR ssp585 r8i1p1f1
# computations: 221 

CPU times: user 1.17 s, sys: 60.5 ms, total: 1.23 s
Wall time: 11.6 s


In [9]:
%%time
#########################
## Total Precipitation ##
#########################
var = "pr"
metric = "sum"

out_path = f"{project_data_path}/metrics/LOCA2/"

# Parallelize over dask delayed
delayed = []

# Loop through models
for model in models:
    # Loop through members
    for member in loca_all[model].keys():
        # Loop through SSPs
        for ssp in loca_all[model][member]:
            if ssp == "historical":
                continue
            # Some vars are missing for some outputs: skip
            file_paths = make_loca_file_path(loca_path, model, member, ssp, var)
            if len(file_paths) == 0:
                print(f"{model} {ssp} {member}")
                    
            # Calculate metric
            delayed.append(dask.delayed(calculate_metric)(model = model,
                                                          member = member,
                                                          ssp = ssp,
                                                          var = var,
                                                          metric = metric,
                                                          loca_path = loca_path,
                                                          out_path = out_path))
                
# Compute
print(f"# computations: {len(delayed)} \n")
_ = dask.compute(*delayed)

# computations: 221 

CPU times: user 582 ms, sys: 43.1 ms, total: 625 ms
Wall time: 3.34 s


In [10]:
%%time
#########################
## Maximum Temperature ##
#########################
var = "tasmax"
metric = "max"

out_path = f"{project_data_path}/metrics/LOCA2/"

# Parallelize over dask delayed
delayed = []

# Loop through models
for model in models:
    # Loop through members
    for member in loca_all[model].keys():
        # Loop through SSPs
        for ssp in loca_all[model][member]:
            if ssp == "historical":
                continue
            # Some vars are missing for some outputs: skip
            file_paths = make_loca_file_path(loca_path, model, member, ssp, var)
            if len(file_paths) == 0:
                print(f"{model} {ssp} {member}")
                    
            # Calculate metric
            delayed.append(dask.delayed(calculate_metric)(model = model,
                                                          member = member,
                                                          ssp = ssp,
                                                          var = var,
                                                          metric = metric,
                                                          loca_path = loca_path,
                                                          out_path = out_path))
                
                
# Compute           
print(f"# computations: {len(delayed)} \n")
out = dask.compute(*delayed)

MPI-ESM1-2-LR ssp585 r10i1p1f1
MPI-ESM1-2-LR ssp585 r5i1p1f1
MPI-ESM1-2-LR ssp585 r6i1p1f1
MPI-ESM1-2-LR ssp585 r7i1p1f1
MPI-ESM1-2-LR ssp585 r8i1p1f1
# computations: 221 

CPU times: user 303 ms, sys: 28.8 ms, total: 331 ms
Wall time: 847 ms


In [11]:
%%time
#########################
# Maximum Precipitation #
#########################
var = "pr"
metric = "max"

out_path = f"{project_data_path}/metrics/LOCA2/"

# Parallelize over dask delayed
delayed = []

# Loop through models
for model in models:
    # Loop through members
    for member in loca_all[model].keys():
        # Loop through SSPs
        for ssp in loca_all[model][member]:
            if ssp == "historical":
                continue
            # Some vars are missing for some outputs: skip
            file_paths = make_loca_file_path(loca_path, model, member, ssp, var)
            if len(file_paths) == 0:
                print(f"{model} {ssp} {member}")
                    
            # Calculate metric
            delayed.append(dask.delayed(calculate_metric)(model = model,
                                                          member = member,
                                                          ssp = ssp,
                                                          var = var,
                                                          metric = metric,
                                                          loca_path = loca_path,
                                                          out_path = out_path))
                
# Compute           
print(f"# computations: {len(delayed)} \n")
out = dask.compute(*delayed)

# computations: 221 

CPU times: user 399 ms, sys: 37 ms, total: 435 ms
Wall time: 512 ms


# Summaries

## Indices

In [6]:
# Simple preprocessing function to add model and year coordinates
def _preprocess(ds):
    # Add model coordinate
    model = ds.encoding['source'].split('/')[-1].split('.')[1]
    ds = ds.assign_coords(model = model)

    # Add member
    member = ds.encoding['source'].split('/')[-1].split('.')[3]
    ds = ds.assign_coords(member = member)

    # Time -> year
    ds['time'] = ds['time'].dt.year

    return ds

In [17]:
# Calculates summary indices for the LOCA2 ensemble for a given SSP
def get_summary_indices(metric, ssp, years, out_path, out_str):
    """
    Pool all LOCA2 models, members and years in `years` for one metric/SSP and
    store gridded summary indices to f"{out_path}/{out_str}.nc".

    Indices (stacked along a new `indice` dimension):
      - 'mean'    : mean over model, member and time
      - '99range' : 0.995 quantile minus 0.005 quantile
      - 'q99'     : 0.99 quantile

    Parameters
    ----------
    metric : str
        Metric file prefix, e.g. 'avg_tas' (files are read from
        f"{project_data_path}/metrics/LOCA2/{metric}.{model}.{ssp}.{member}.*.nc").
    ssp : str
        Experiment name, e.g. 'ssp245'.
    years : [start, end]
        Inclusive year window over which all outputs are pooled.
    out_path, out_str : str
        Output directory and file stem. Nothing is done if the file exists.

    Raises
    ------
    ValueError
        If no metric files exist for this metric/SSP.
    """
    out_file = f"{out_path}/{out_str}.nc"

    # Check if done
    if os.path.isfile(out_file):
        return

    metric_dir = f"{project_data_path}/metrics/LOCA2"

    # Read: one dataset per model, members concatenated along `member`
    ds_models = []
    for model in models:
        ds_members = []
        for member in loca_all[model].keys():
            pattern = f"{metric_dir}/{metric}.{model}.{ssp}.{member}.*.nc"
            if len(glob(pattern)) > 0:
                ds_members.append(xr.open_mfdataset(pattern, preprocess=_preprocess))
        # Only models that have this SSP contribute. Appending unconditionally
        # would re-add the previous model's members (duplicating data) or raise
        # NameError when the first model lacks the SSP.
        if ds_members:
            ds_models.append(xr.concat(ds_members, dim="member", fill_value=np.nan))

    if not ds_models:
        raise ValueError(f"No LOCA2 metric files found for {metric} / {ssp}")

    # Combine models
    ds = xr.concat(ds_models, dim="model", fill_value=np.nan)

    # Rename
    ds = ds.rename({list(ds.data_vars)[0]: metric})

    # Time slice; pooled dims are kept in single chunks for the quantiles
    ds_sel = ds.sel(time=slice(years[0], years[1])).chunk(dict(model=-1, time=-1, member=-1, lat=50, lon=100))
    pool_dims = ['model', 'time', 'member']

    ## Summary indices
    # Mean
    ds_mean = ds_sel.mean(dim=pool_dims).assign_coords(indice='mean')
    # Quantiles
    ds_qlow = ds_sel.quantile(0.005, dim=pool_dims)
    ds_qhigh = ds_sel.quantile(0.995, dim=pool_dims)
    ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice='99range')

    ds_q99 = ds_sel.quantile(0.99, dim=pool_dims).assign_coords(indice='q99')

    # Store
    ds_out = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
    ds_out.to_netcdf(out_file)

In [16]:
%%time
for years in [[2020,2040], [2050,2070], [2080,2100]]:
    for ssp in ['ssp245', 'ssp370', 'ssp585']:
        for metric in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr']:
            get_summary_indices(metric = metric,
                                ssp = ssp,
                                years = years,
                                out_path=f"{project_data_path}/summary_indices",
                                out_str=f"LOCA2_{ssp}_{str(years[0])}-{str(years[1])}_{metric}")

CPU times: user 22.5 s, sys: 1.21 s, total: 23.7 s
Wall time: 40.8 s


## Raw data

In [6]:
# Extracts raw (un-pooled) LOCA2 metric values at one location for a given SSP
def get_raw_data(metric, ssp, years, lat, lon, out_path, out_str):
    """
    Store the nearest-grid-cell LOCA2 values for one metric/SSP at (lat, lon).

    For every model/member with files matching
    f"{project_data_path}/metrics/LOCA2/{metric}.{model}.{ssp}.{member}.*.nc",
    the files are concatenated along time, time is converted to integer years
    and restricted to the inclusive window `years`, and the grid cell nearest to
    (lat, lon) is selected. The long-format rows (NaN rows dropped, first data
    variable renamed to `metric`, plus 'ssp', 'model' and 'member' columns) are
    written to f"{out_path}/{out_str}.csv". Nothing is done if it already exists.

    Negative `lon` values are shifted by +360 (the files are assumed to use
    0-360 longitudes -- TODO confirm).

    Raises
    ------
    ValueError
        If no metric files exist for this metric/SSP.
    """
    metric_dir = f"{project_data_path}/metrics/LOCA2"

    def read_and_process(metric, model, member, ssp, years, lat, lon):
        """Return the point time series of one model/member as a DataFrame."""
        # Read (sorted for a deterministic open order)
        files = sorted(glob(f"{metric_dir}/{metric}.{model}.{ssp}.{member}.*.nc"))
        # ExitStack closes every opened file once the DataFrame has been built
        with ExitStack() as stack:
            datasets = [stack.enter_context(xr.open_dataset(file)) for file in files]
            # sortby keeps time monotonic so the label-based year slice below is correct
            ds_tmp = xr.concat(datasets, dim='time').sortby('time')
            ds_tmp['time'] = ds_tmp["time"].dt.year

            # Rename
            ds_tmp = ds_tmp.rename({list(ds_tmp.data_vars)[0]: metric})

            # Time slice
            ds_sel = ds_tmp.sel(time=slice(years[0], years[1]))

            # Location selection
            if lon < 0:
                lon = 360 + lon
            ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')

            # Construct dataframe (to_dataframe loads the values before the files close)
            df_tmp = ds_sel.to_dataframe().dropna().drop(columns=["lat", "lon"]).reset_index()

        df_tmp["ssp"] = ssp
        df_tmp["model"] = model
        df_tmp["member"] = member
        return df_tmp

    out_file = f"{out_path}/{out_str}.csv"

    # Check if done
    if os.path.isfile(out_file):
        return

    df_delayed = []
    # Loop through models
    for model in models:
        # Loop through members
        for member in loca_all[model].keys():
            # Some missing combinations as reported above
            check = glob(f"{metric_dir}/{metric}.{model}.{ssp}.{member}.*.nc")
            if len(check) > 0:
                df_delayed.append(dask.delayed(read_and_process)(metric, model, member, ssp, years, lat, lon))

    if not df_delayed:
        raise ValueError(f"No LOCA2 metric files found for {metric} / {ssp}")

    # Compute and store
    df_out = dask.compute(*df_delayed)
    pd.concat(df_out).to_csv(out_file, index=False)

In [16]:
%%time
for city in city_list.keys():
    lat, lon = city_list[city]
    for years in [[2020,2040], [2050,2070], [2080,2100]]:
        for ssp in ['ssp245', 'ssp370', 'ssp585']:
            for metric in ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr']:
                get_raw_data(metric = metric,
                             ssp = ssp,
                             years = years,
                             lat = lat,
                             lon = lon,
                             out_path=f"{project_data_path}/summary_raw",
                             out_str=f"{city}_LOCA2_{ssp}_{str(years[0])}-{str(years[1])}_{metric}")

CPU times: user 1min 29s, sys: 13.7 s, total: 1min 43s
Wall time: 15min 31s
