In [1]:
import os
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr

from utils import city_list, gev_metric_ids
import metric_funcs as mf

## Preliminaries

In [2]:
################
#### Paths #####
################
# Update these for reproduction
from utils import roar_code_path as project_code_path
from utils import roar_data_path as project_data_path

# Raw LOCA2 outputs; set the LOCA2_PATH environment variable to point at another copy
loca_path = os.environ.get("LOCA2_PATH", "/storage/group/pches/default/public/LOCA2")

In [3]:
##############
### Models ###
##############
# Top-level entries under loca_path that are not GCM directories
non_gcm_dirs = {'training_data', 'scripts'}

# GCM directories, sorted so every loop over them runs in a reproducible order
# (os.listdir order is arbitrary). Missing helper dirs no longer raise, and
# stray files are ignored instead of crashing the listdir below.
gcms = sorted(
    entry for entry in os.listdir(loca_path)
    if entry not in non_gcm_dirs and os.path.isdir(f"{loca_path}/{entry}")
)

# loca_all[gcm][member] -> list of experiment (SSP) directory names
loca_all = {}
for gcm in gcms:
    member_root = f"{loca_path}/{gcm}/0p0625deg"
    loca_all[gcm] = {
        member: sorted(os.listdir(f"{member_root}/{member}/"))
        for member in sorted(os.listdir(f"{member_root}/"))
    }

In [4]:
##############
### Models ###
##############
# Expected to match the LOCA2 website
# (https://loca.ucsd.edu/loca-version-2-for-north-america-ca-jan-2023/) as of Jan 2023
print(f"# gcm: {len(gcms)}")
# Distinct experiments per GCM (union over its members), summed across GCMs
print(f"# gcm/expts: {np.sum([len(set(ssp for member_ssps in loca_all[gcm].values() for ssp in member_ssps)) for gcm in gcms])}")
# One run per (gcm, member, experiment)
print(f"# gcm/expts/ens: {np.sum([len(member_ssps) for gcm in gcms for member_ssps in loca_all[gcm].values()])}")

# gcm: 27
# gcm/expts: 99
# gcm/expts/ens: 329


In [5]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Each SLURM job gets 1 core and 10 GiB of memory for up to 6 hours.
cluster = SLURMCluster(
    cores=1,
    memory="10GiB",
    walltime="6:00:00",
    account="open",  # alternative allocation: "pches"
)
# Request 20 jobs; workers join only as SLURM starts the jobs, so the client
# may report 0 workers right after this cell.
cluster.scale(jobs=20)

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.8.28:35225,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


2025-03-01 10:57:39,915 - distributed.scheduler - ERROR - Couldn't gather keys: {('store-map-937886fe56751b585e21986b76d745a8', 15, 0, 0): 'waiting', ('store-map-937886fe56751b585e21986b76d745a8', 18, 0, 0): 'waiting', ('store-map-937886fe56751b585e21986b76d745a8', 1, 0, 0): 'waiting', ('store-map-937886fe56751b585e21986b76d745a8', 10, 0, 0): 'waiting'}


# Calculate metrics

In [6]:
## File path function
def make_loca_file_path(loca_path, gcm, member, ssp, var):
    """
    Return the LOCA2 NetCDF file paths for one GCM/member/SSP/variable.

    Files from the newest release (names containing ``_v2024``) are preferred;
    if there are none, the ``_v2022`` files are returned instead. Paths are
    sorted so the result does not depend on filesystem listing order.

    Returns an empty list when the variable directory does not exist or
    contains no files of either version.
    """
    var_dir = f"{loca_path}/{gcm}/0p0625deg/{member}/{ssp}/{var}"
    if not os.path.isdir(var_dir):
        return []

    # Take the latest version if present, otherwise fall back to the earlier one
    for version in ("v2024", "v2022"):
        files = sorted(glob(f"{var_dir}/*_{version}*"))
        if files:
            return files
    return []
    
## Unit conversion
def convert_units(ds):
    """
    Convert LOCA2 variables to the units used by the metric functions.

    * ``K`` -> ``C`` (subtract 273.15)
    * ``kg m-2 s-1`` -> ``mm/day`` (multiply by 86400)

    Variables with any other ``units`` attribute, or with no ``units``
    attribute at all, are left unchanged instead of raising ``KeyError``.
    The input is not modified; a shallow copy with the converted variables
    is returned.
    """
    out = ds.copy()
    for var in list(out.keys()):
        units = out[var].attrs.get('units')
        if units == 'K':
            converted = out[var] - 273.15
            new_units = 'C'
        elif units == 'kg m-2 s-1':
            converted = out[var] * 86400
            new_units = 'mm/day'
        else:
            continue
        converted.attrs['units'] = new_units
        out[var] = converted

    return out

In [7]:
###############################
# Metric calculation function #
###############################
def calculate_metric(metric_func, var_id, needed_vars, gcm, member, ssp, loca_path, out_path):
    """
    Calculate a metric from LOCA2 output for one GCM/member/SSP and save it.

    One NetCDF file is written per LOCA2 time slice; its path is ``out_path``
    with ``_{time_slice}`` inserted before ``.nc``. Slices whose output file
    already exists are skipped, so the function can be re-run to resume a
    partially finished batch.

    Any exception raised while processing a slice is caught, and its traceback
    is written to
    ``{project_code_path}/scripts/logs/metric_calcs/LOCA2_{gcm}_{member}_{ssp}_{var_id}_{time_slice}.txt``
    so that one failing slice does not stop the remaining ones.

    Parameters
    ----------
    metric_func : callable
        Called as ``metric_func(ds, var_id)`` on the merged, unit-converted
        input; its result must provide ``to_netcdf``.
    var_id : str
        Metric variable id, passed to ``metric_func`` and used in log names.
    needed_vars : list of str
        LOCA2 variables to load, e.g. ``['tasmin', 'tasmax']``.
    gcm, member, ssp : str
        LOCA2 model, ensemble member and experiment directory names.
    loca_path : str
        Root directory of the raw LOCA2 outputs.
    out_path : str
        Output path template ending in ``.nc``.
    """
    import traceback

    # Candidate input files for every needed variable (all time slices)
    files = {var: make_loca_file_path(loca_path, gcm, member, ssp, var) for var in needed_vars}

    # LOCA2 splits each experiment into fixed time-slice files
    if ssp == "historical":
        time_slices = ["1950-2014"]
    else:
        time_slices = ["2015-2044", "2045-2074", "2075-2100"]

    for time_slice in time_slices:
        save_path = out_path.replace('.nc', f'_{time_slice}.nc')
        # Already computed on a previous run
        if os.path.isfile(save_path):
            continue

        opened = []
        try:
            slice_files = {var: [f for f in files[var] if time_slice in f] for var in needed_vars}
            # Fail loudly here rather than merging an incomplete dataset and
            # hoping metric_func notices the missing variable.
            missing = [var for var, paths in slice_files.items() if not paths]
            if missing:
                raise FileNotFoundError(f"No LOCA2 input files for {missing} in time slice {time_slice}")

            # Open one by one so files opened before a failure still get closed
            for var in needed_vars:
                for path in slice_files[var]:
                    opened.append(xr.open_dataset(path, chunks='auto'))
            ds_in = convert_units(xr.merge(opened, combine_attrs='drop_conflicts'))

            # Calculate metric
            ds_out = metric_func(ds_in, var_id)

            # Write to a temporary file and rename only after a successful
            # write: a partial file at save_path would otherwise be skipped as
            # "done" by the check above on every later run. A stale .tmp left
            # by a failed write is simply overwritten next time.
            tmp_path = f"{save_path}.tmp"
            ds_out.to_netcdf(tmp_path)
            os.replace(tmp_path, save_path)

        # Log if error (with full traceback), then continue with the next slice
        except Exception:
            log_dir = f"{project_code_path}/scripts/logs/metric_calcs"
            os.makedirs(log_dir, exist_ok=True)
            with open(f"{log_dir}/LOCA2_{gcm}_{member}_{ssp}_{var_id}_{time_slice}.txt", "w") as f:
                f.write(traceback.format_exc())
        finally:
            for ds in opened:
                ds.close()

In [8]:
%%time
#############
## CDD sum ##
#############
var_id = "cdd"
metric_func = mf.calculate_dd_sum
needed_vars = ['tasmin', 'tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/sum_{var_id}_{gcm}_{member}_{ssp}.nc"


# Visit every (gcm, member, ssp) run discovered under loca_path, one at a time
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 409 ms, sys: 120 ms, total: 529 ms
Wall time: 13.6 s


In [9]:
%%time
#############
## CDD max ##
#############
var_id = "cdd"
metric_func = mf.calculate_dd_max
needed_vars = ['tasmin', 'tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"


# Visit every (gcm, member, ssp) run discovered under loca_path, one at a time
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 200 ms, sys: 42.7 ms, total: 243 ms
Wall time: 6.49 s


In [10]:
%%time
#############
## HDD sum ##
#############
var_id = "hdd"
metric_func = mf.calculate_dd_sum
needed_vars = ['tasmin', 'tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/sum_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 203 ms, sys: 42.5 ms, total: 246 ms
Wall time: 6.98 s


In [11]:
%%time
#############
## HDD max ##
#############
var_id = "hdd"
metric_func = mf.calculate_dd_max
needed_vars = ['tasmin', 'tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 201 ms, sys: 55.7 ms, total: 256 ms
Wall time: 6.85 s


In [12]:
%%time
#########################
## Average Temperature ##
#########################
var_id = "tas"
metric_func = mf.calculate_avg
needed_vars = ['tasmin', 'tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/avg_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 2.14 s, sys: 314 ms, total: 2.46 s
Wall time: 41.8 s


In [13]:
%%time
#########################
## Maximum Temperature ##
#########################
var_id = "tasmax"
metric_func = mf.calculate_max
needed_vars = ['tasmax']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 139 ms, sys: 38.1 ms, total: 177 ms
Wall time: 4.83 s


In [14]:
%%time
#########################
## Minimum Temperature ##
#########################
var_id = "tasmin"
metric_func = mf.calculate_min
needed_vars = ['tasmin']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/min_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 128 ms, sys: 42.2 ms, total: 171 ms
Wall time: 5.32 s


In [15]:
%%time
#########################
## Maximum Precip ##
#########################
var_id = "pr"
metric_func = mf.calculate_max
needed_vars = ['pr']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/max_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 166 ms, sys: 46.5 ms, total: 213 ms
Wall time: 6.85 s


In [16]:
%%time
################
## Sum Precip ##
################
var_id = "pr"
metric_func = mf.calculate_sum
needed_vars = ['pr']


def out_path(gcm, ssp, member):
    """Output file template; calculate_metric inserts the time slice before '.nc'."""
    return f"{project_data_path}/metrics/LOCA2/sum_{var_id}_{gcm}_{member}_{ssp}.nc"


# Runs one (gcm, member, ssp) combination at a time; the dask cluster only
# parallelises the chunked-array work inside each calculate_metric call.
for gcm in gcms:
    for member in loca_all[gcm]:
        for ssp in loca_all[gcm][member]:
            calculate_metric(metric_func=metric_func,
                             var_id=var_id,
                             needed_vars=needed_vars,
                             gcm=gcm,
                             member=member,
                             ssp=ssp,
                             loca_path=loca_path,
                             out_path=out_path(gcm, ssp, member))

CPU times: user 11min 29s, sys: 30.3 s, total: 11min 59s
Wall time: 1h 38min 37s


In [17]:
# Shut down the dask scheduler and workers started for this session
client.shutdown()

# OLD

In [9]:
# # Calculates summary indices for LOCA2 ensemble for given SSP
# def get_raw_data(metric_id, ssp, years, lat, lon, out_path, out_str):
#     """
#     """
#     def read_and_process(metric, gcm, member, ssp, years, lat, lon):
#         # Read
#         files = glob(f"{project_data_path}/metrics/LOCA2/{metric_id}_{gcm}_{member}_{ssp}_*.nc")
#         ds_tmp = xr.concat([xr.open_dataset(file) for file in files], dim='time')
#         ds_tmp['time'] = ds_tmp["time"].dt.year

#         # Time slice
#         if years is not None:
#             ds_sel = ds_tmp.sel(time=slice(years[0],years[1]))
#         else:
#             ds_sel = ds_tmp.copy()
    
#         # Location selection
#         if lon < 0:
#             lon = 360 + lon
#         ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')
        
#         # Construct dataframe
#         df_tmp = ds_sel.to_dataframe().dropna().drop(columns=["lat", "lon"]).reset_index()
#         df_tmp["ssp"] = ssp
#         df_tmp["gcm"] = gcm
#         df_tmp["member"] = member

#         # Return 
#         return df_tmp

#     # Check if done
#     if not os.path.isfile(f"{out_path}/{out_str}.csv"):
#         df_delayed = []
#         # Loop through gcms
#         for gcm in gcms:
#             # Loop through members
#             for member in loca_all[gcm].keys():
#                 # Some missing combinations as reported above
#                 check = glob(f"{project_data_path}/metrics/LOCA2/{metric_id}_{gcm}_{member}_{ssp}_*.nc")
#                 if len(check) > 0:
#                     df_tmp = dask.delayed(read_and_process)(metric_id, gcm, member, ssp, years, lat, lon)
#                     df_delayed.append(df_tmp)
        
#         # Compute and store
#         df_out = dask.compute(*df_delayed)
#         pd.concat(df_out).to_csv(f"{out_path}/{out_str}.csv", index=False)

In [11]:
# %%time
# # Compute and store timeseries
# ssps = ['historical', 'ssp245', 'ssp370', 'ssp585']

# # Loop through cities
# for city in city_list:
#     lat, lon = city_list[city]
#     # Loop through SSPs
#     for ssp in ssps:
#         # Loop through metrics
#         for metric_id in metric_ids:
#             # Compute
#             get_raw_data(metric_id = metric_id,
#                          ssp = ssp,
#                          years = None,
#                          lat = lat,
#                          lon = lon,
#                          out_path=f"{project_data_path}/timeseries/original_grids/",
#                          out_str=f"{metric_id}_LOCA2_{ssp}_{city}")

CPU times: user 2min 40s, sys: 21.7 s, total: 3min 1s
Wall time: 9min 24s
