In [1]:
import os
from glob import glob

import dask
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe

from functools import partial

from utils import city_list

In [2]:
################
#### Paths #####
################
# Update these defaults for reproduction, or override them without editing the
# notebook by setting the CONUS_COMPARISON_DATA_PATH / CONUS_COMPARISON_CODE_PATH
# environment variables. Both are used below as prefixes in f"{path}/..." joins.

project_data_path = os.environ.get(
    "CONUS_COMPARISON_DATA_PATH",
    "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/",
)
project_code_path = os.environ.get(
    "CONUS_COMPARISON_CODE_PATH",
    "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/",
)

In [3]:
############
### Dask ###
############
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Per-job resources: 1 core, 4 GiB, 10-minute walltime.
# Billed to the "open" allocation; use account="pches" for the group allocation.
cluster = SLURMCluster(
    account="open",
    cores=1,
    memory="4GiB",
    walltime="00:10:00",
)
cluster.scale(jobs=10)  # submit 10 SLURM jobs

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.160:40655,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Regridding

In [5]:
# Get models: one entry per model directory in the CIL-GDPCIR metrics folder.
# Sorted so that models[0] (used as the regridding template) and the processing
# order are reproducible across runs; hidden entries (e.g. ".DS_Store") are skipped.
models = sorted(
    name
    for name in os.listdir(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/")
    if not name.startswith(".")
)

### Conservative

In [6]:
# Open the first model's store; it is only passed to the regridder as the
# source-grid template.
ds_in = xr.open_zarr(
    store=f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/{models[0]}"
)

In [7]:
# Target grid: the LOCA2 grid. Both axes are evenly spaced at 1/16 degree
# (474 latitudes, 944 longitudes); longitudes use the 0-360 convention.
loca_lat_grid = np.linspace(23.90625, 53.46875, 474)
loca_lon_grid = np.linspace(234.53125, 293.46875, 944)

ds_out = xr.Dataset(
    data_vars={
        "lat": (
            ["lat"],
            loca_lat_grid,
            {"standard_name": "latitude", "units": "degrees_north"},
        ),
        "lon": (
            ["lon"],
            loca_lon_grid,
            {"standard_name": "longitude", "units": "degrees_east"},
        ),
    }
)

# xESMF uses a variable named "mask" to restrict the target grid to active cells.
# NOTE(review): assumes LOCA2_NaNs.npy is a boolean (lat, lon) array that is True
# where LOCA2 has no data, so `~` marks the valid LOCA2 cells — confirm.
loca_nans = np.load(f"{project_code_path}/code/utils/LOCA2_NaNs.npy")
ds_out["mask"] = xr.DataArray(~loca_nans, dims=["lat", "lon"])

# Conservative remapping from the CIL grid (ds_in) onto the masked LOCA2 grid
regridder = xe.Regridder(ds_in, ds_out, "conservative")

In [8]:
%%time
# Out path
out_path = f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative"
os.makedirs(out_path, exist_ok=True)

# Regrid every model that does not yet have an output file
for model in models:
    out_file = f"{out_path}/{model}.nc"
    if os.path.isfile(out_file):
        continue

    with xr.open_zarr(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/{model}") as ds_cil_in:
        # Regrid lazy
        # NOTE: use high NaN threshold to try to not introduce NaNs
        # not already present in the LOCA2 grid
        ds_cil_out = regridder(ds_cil_in, skipna=True, na_thres=0.99)

        # Write under a temporary name and rename only once the write has
        # finished. Otherwise an interrupted run would leave a truncated
        # {model}.nc that the isfile() check above would then skip forever.
        # The ".nc.tmp" suffix is also not matched by the "*.nc" globs used later.
        tmp_file = f"{out_file}.tmp"
        ds_cil_out.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)
    print(model)

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


GFDL-ESM4


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


HadGEM3-GC31-LL


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM4-8


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM5-0


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MIROC-ES2L


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MIROC6


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MPI-ESM1-2-HR


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MPI-ESM1-2-LR


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NESM3


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NorESM2-LM


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NorESM2-MM


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


UKESM1-0-LL
CPU times: user 8min 55s, sys: 33.5 s, total: 9min 29s
Wall time: 26min 15s


# Summaries

## Indices

In [None]:
# Simple preprocessing function to add model and year coordinates
def _preprocess(ds, ssp):
    # Add model coordinate
    model = ds.encoding['source'].split('/')[-1].split('_')[0][:-3]
    ds = ds.assign_coords(model = model)

    # Select SSP
    if ssp in ds.ssp:
        ds = ds.sel(ssp=ssp)
    else:
        return None

    # for some reason CanESM5 has 2 of each ssp (but identical)
    if model == 'CanESM5':
        ds = ds.isel(ssp=0)

    # Time -> year
    ds['time'] = ds['time'].dt.year

    return ds

In [10]:
# Calculates summary indices for CIL-GDPCIR ensemble for given SSP
def get_summary_indices(ssp, years, out_path, out_str):
    """
    Pool all CIL-GDPCIR models and all years in `years` for one SSP and compute
    per-grid-cell summary indices, written to f"{out_path}/{out_str}.nc".

    Indices, stacked along a new 'indice' dimension:
      - 'mean'    : mean over models and years
      - '99range' : 99.5th minus 0.5th percentile (central 99% range)
      - 'q99'     : 99th percentile

    Parameters
    ----------
    ssp : str
        Scenario to summarize, e.g. 'ssp245'. Models without it are skipped.
    years : sequence of two ints
        Inclusive [start, end] year window over which all outputs are pooled.
    out_path, out_str : str
        Output directory and file stem. Nothing is done if the file already exists.

    Raises
    ------
    ValueError
        If none of the regridded model files contains `ssp`.
    """
    out_file = f"{out_path}/{out_str}.nc"
    # Check if done
    if os.path.isfile(out_file):
        return

    # Partial preprocessing
    _partial_preprocess = partial(_preprocess, ssp=ssp)

    # Read all files (sorted so the model order is reproducible)
    files = sorted(glob(f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/*.nc"))
    ds_list = [_partial_preprocess(xr.open_dataset(file, chunks='auto')) for file in files]
    # Remove Nones (missing SSP). `is not None` is required: `ds != None` on an
    # xarray Dataset builds an element-wise comparison instead of a plain bool.
    ds_list = [ds for ds in ds_list if ds is not None]
    if not ds_list:
        raise ValueError(f"No regridded CIL-GDPCIR model provides {ssp}; cannot build {out_str}")
    # Combine
    ds = xr.combine_nested(ds_list, concat_dim='model')

    # Time slice
    ds_sel = ds.sel(time=slice(years[0], years[1]))

    ## Summary indices
    # Mean
    ds_mean = ds_sel.mean(dim=['model', 'time']).assign_coords(indice='mean')
    # Quantiles: xarray's quantile on dask data needs each reduced dimension in a
    # single chunk, so rechunk once and reuse it for all three quantile calls.
    ds_pooled = ds_sel.chunk(dict(model=-1, time=-1))
    ds_qlow = ds_pooled.quantile(0.005, dim=['model', 'time'])
    ds_qhigh = ds_pooled.quantile(0.995, dim=['model', 'time'])
    ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice='99range')

    ds_q99 = ds_pooled.quantile(0.99, dim=['model', 'time']).assign_coords(indice='q99')

    # Store: write to a temporary name first so an interrupted write never leaves
    # a partial file that the isfile() check above would treat as finished.
    ds_indices = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
    os.makedirs(out_path, exist_ok=True)
    tmp_file = f"{out_file}.tmp"
    ds_indices.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)

In [None]:
%%time
# Summarize every (period, scenario) pair; existing outputs are skipped inside
# get_summary_indices.
for years in [[2020, 2040], [2050, 2070], [2080, 2100]]:
    for ssp in ['ssp245', 'ssp370', 'ssp585']:
        get_summary_indices(
            ssp=ssp,
            years=years,
            out_path=f"{project_data_path}/summary_indices",
            out_str=f"CIL-GDPCIR_{ssp}_{years[0]}-{years[1]}",
        )

## Raw data

In [4]:
# Extracts raw CIL-GDPCIR ensemble values at a single location for a given SSP
def get_raw_data(ssp, years, lat, lon, out_path, out_str):
    """
    Extract raw (unsummarized) CIL-GDPCIR metric values at the grid cell nearest
    to (`lat`, `lon`) for every model that provides `ssp`. The results are written
    as one CSV to f"{out_path}/{out_str}.csv".

    Parameters
    ----------
    ssp : str
        Scenario to select, e.g. 'ssp245'. Models without it are skipped.
    years : sequence of two ints
        Inclusive [start, end] year window.
    lat, lon : float
        Target location. A negative longitude is shifted by +360 to match the
        regridded files.
    out_path, out_str : str
        Output directory and file stem. Nothing is done if the file already exists.

    Raises
    ------
    ValueError
        If none of the regridded model files contains `ssp`.
    """
    # CIL-GDPCIR variable names, and the names they are renamed to in the output
    metrics_cil = ['tas_avg', 'pr_sum', 'tasmax_max', 'pr_max']
    metrics_correct = ['avg_tas', 'sum_pr', 'max_tasmax', 'max_pr']

    def read_and_process(model, ssp, years, lat, lon):
        """Return a DataFrame of one model's values at the point, or None if it lacks `ssp`."""
        path = f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/{model}.nc"
        # The context manager releases the file handle once the point data have
        # been loaded into pandas by to_dataframe().
        with xr.open_dataset(path) as ds_file:
            ds_file['time'] = ds_file["time"].dt.year

            # select SSP if exists
            if ssp not in ds_file.ssp:
                return None
            ds_tmp = ds_file.sel(ssp=ssp)

            # for some reason CanESM5 has 2 of each ssp (but identical)
            if model == 'CanESM5':
                ds_tmp = ds_tmp.isel(ssp=0)

            # Time slice, then nearest grid cell
            ds_sel = ds_tmp.sel(time=slice(years[0], years[1]))
            ds_sel = ds_sel.sel(lat=lat, lon=lon, method='nearest')
            df_tmp = ds_sel.to_dataframe()

        # When the file has no precipitation metrics, add empty columns so that
        # dropna/rename below see the same column set as for other models.
        if 'pr_sum' not in df_tmp.columns:
            df_tmp['pr_sum'] = np.nan
            df_tmp['pr_max'] = np.nan
        df_tmp = (
            df_tmp.dropna(subset=metrics_cil, how='all')
            .drop(columns=["lat", "lon"])
            .reset_index()
            .rename(columns=dict(zip(metrics_cil, metrics_correct)))
        )
        df_tmp["ssp"] = ssp
        df_tmp["model"] = model
        return df_tmp

    out_file = f"{out_path}/{out_str}.csv"
    # Check if done
    if os.path.isfile(out_file):
        return

    # Location is the same for every model, so convert the longitude once
    if lon < 0:
        lon = 360 + lon

    # One delayed task per regridded model file (sorted for a reproducible row order)
    files = sorted(glob(f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/*.nc"))
    models = [os.path.splitext(os.path.basename(file))[0] for file in files]
    df_delayed = [dask.delayed(read_and_process)(model, ssp, years, lat, lon) for model in models]

    # Compute and store
    df_list = [df for df in dask.compute(*df_delayed) if df is not None]
    if not df_list:
        raise ValueError(f"No regridded CIL-GDPCIR model provides {ssp}; cannot build {out_str}")
    os.makedirs(out_path, exist_ok=True)
    # Write to a temporary name first so an interrupted write never leaves a
    # partial CSV that the isfile() check above would treat as finished.
    tmp_file = f"{out_file}.tmp"
    pd.concat(df_list).to_csv(tmp_file, index=False)
    os.replace(tmp_file, out_file)

In [5]:
%%time
# Extract raw point data for every city, period and scenario; existing CSVs are
# skipped inside get_raw_data.
for city, (lat, lon) in city_list.items():
    for years in [[2020, 2040], [2050, 2070], [2080, 2100]]:
        for ssp in ['ssp245', 'ssp370', 'ssp585']:
            get_raw_data(
                ssp=ssp,
                years=years,
                lat=lat,
                lon=lon,
                out_path=f"{project_data_path}/summary_raw",
                out_str=f"{city}_CIL-GDPCIR_{ssp}_{years[0]}-{years[1]}",
            )

CPU times: user 6.4 s, sys: 408 ms, total: 6.8 s
Wall time: 38.2 s
