In [1]:
###############################################
### TO RUN ON MICROSOFT PLANETARY COMPUTER ####
###############################################

In [1]:
import collections
import getpass
import io

import azure.storage.blob
import fsspec
import numpy as np
import pandas as pd
import planetary_computer
import pystac
import pystac_client
import requests
import xarray as xr
import zarr

# import regionmask

In [4]:
#################
# Data access
#################

# Complete catalog: open the Planetary Computer STAC API root. `catalog` is
# used below both for item searches and for collection look-ups.
catalog = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")


# Function to grab the requested variables and all SSPs for a single model
def grab_model(model_id, vars_to_grab, subset_US):
    """
    Open every non-historical CIL-GDPCIR experiment of one model as a single
    dataset with an `ssp` dimension.

    Parameters
    ----------
    model_id : value matched against the `cmip6:source_id` STAC property
    vars_to_grab : iterable of asset keys (variable ids) to open for each item
    subset_US : if truthy, restrict the result to lon -130..-50, lat 20..60

    Returns
    -------
    Dataset obtained by merging the requested variables of each item and
    concatenating the items along `ssp`.
    """
    # Search across all licences in CIL-GDPCIR, omitting the historical runs
    licences = ["cil-gdpcir-cc0", "cil-gdpcir-cc-by", "cil-gdpcir-cc-by-sa"]
    stac_query = {
        "cmip6:source_id": {"eq": model_id},
        "cmip6:experiment_id": {"neq": "historical"},
    }
    items = catalog.search(collections=licences, query=stac_query).item_collection()

    def _open_item(item):
        # Sign the item so that its asset hrefs can be read, then open each
        # requested variable and label it with the item's experiment id.
        signed_item = planetary_computer.sign(item)
        per_variable = []
        for name in vars_to_grab:
            asset = signed_item.assets[name]
            opened = xr.open_dataset(asset.href, **asset.extra_fields["xarray:open_kwargs"])
            per_variable.append(opened.assign_coords(ssp=opened.attrs["experiment_id"]))
        return xr.merge(per_variable)

    # One merged dataset per experiment, stacked along `ssp`
    combined = xr.concat([_open_item(item) for item in items], dim="ssp")

    # Subset US if desired
    if not subset_US:
        return combined
    return combined.sel(lon=slice(-130,-50), lat=slice(20,60))

In [5]:
# Get all models: gather the `cmip6:source_id` summary of each licence
# collection and flatten the per-licence lists into one array
licence_ids = ["cil-gdpcir-cc0", "cil-gdpcir-cc-by", "cil-gdpcir-cc-by-sa"]
models = np.hstack(
    [
        catalog.get_collection(licence_id).summaries.to_dict()['cmip6:source_id']
        for licence_id in licence_ids
    ]
)

In [6]:
#########
# Dask
#########
import dask_gateway

# No address is passed explicitly, so the gateway connection relies on
# whatever dask-gateway configuration the environment provides.
gateway = dask_gateway.Gateway()

# cluster options: per-worker resources
# NOTE(review): worker_memory is presumably in GiB — confirm against the
# options exposed by this gateway.
cluster_options = gateway.cluster_options()
cluster_options["worker_memory"] = 30
cluster_options["worker_cores"] = 1

# start cluster, attach a client to it, then request 30 workers
cluster = gateway.new_cluster(cluster_options)
client = cluster.get_client()
cluster.scale(30)

# dashboard link
print(cluster.dashboard_link)

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.82f79b92a1d648a788af1d4037732aa0/status


In [7]:
%%time
#########################
### Calculate metrics ###
#########################
# loop through models: RUNTIME IS AROUND 10 MINS PER MODEL WITH 30 DASK WORKERS
for model in models:
    # FGOALS-g3 missing pr
    has_pr = model != 'FGOALS-g3'
    variable_ids = ["tasmin", "tasmax", "pr"] if has_pr else ["tasmin", "tasmax"]

    # load data (lazy), restricted to the US box
    ds = grab_model(model, variable_ids, True)

    # unit conversions: K -> C
    for temp_name in ("tasmax", "tasmin"):
        ds[temp_name] = ds[temp_name] - 273.15

    # mean temperature taken as the midpoint of tasmax and tasmin
    ds["tas"] = (ds["tasmax"] + ds["tasmin"]) / 2.0

    # yearly metrics keyed by output variable name; precipitation metrics are
    # only added when `pr` was loaded (insertion order sets the variable order)
    yearly = {"tas_avg": ds["tas"].resample(time="1Y").mean()}
    if has_pr:
        yearly["pr_sum"] = ds["pr"].resample(time="1Y").sum()
    yearly["tasmax_max"] = ds["tasmax"].resample(time="1Y").max()
    if has_pr:
        yearly["pr_max"] = ds["pr"].resample(time="1Y").max()

    # merge + storage options
    ds_final = xr.Dataset(yearly).chunk({"ssp": 1, "time": 10, "lat": 720, "lon": 1440})

    compressor = zarr.Blosc(cname="zstd", clevel=3)
    encoding = {}
    for vname in ds_final.data_vars:
        encoding[vname] = {"compressor": compressor}

    # NOTE(review): `container_client` is not created in any of the cells shown
    # here; it has to exist in the kernel before this cell is run.
    store = zarr.ABSStore(client=container_client, prefix=model)

    # store (this is where the lazy computation is actually triggered)
    ds_final.to_zarr(store=store, encoding=encoding, consolidated=True, mode="w")
    print(model)

FGOALS-g3
INM-CM4-8
INM-CM5-0
BCC-CSM2-MR
ACCESS-ESM1-5
ACCESS-CM2
MIROC-ES2L
MIROC6
NorESM2-LM
NorESM2-MM
GFDL-CM4
GFDL-ESM4
NESM3
MPI-ESM1-2-HR
HadGEM3-GC31-LL
UKESM1-0-LL
MPI-ESM1-2-LR
EC-Earth3
EC-Earth3-AerChem
EC-Earth3-CC
EC-Earth3-Veg
EC-Earth3-Veg-LR
CMCC-CM2-SR5
CMCC-ESM2
CanESM5
CPU times: user 3min 54s, sys: 14.3 s, total: 4min 8s
Wall time: 1h 41min 38s


2024-01-02 19:15:12,166 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


In [8]:
import os
from glob import glob

import dask
import numpy as np
import xarray as xr
import xesmf as xe

from functools import partial

In [4]:
################
#### Paths #####
################
# Update these for reproduction (absolute, machine-specific locations):
#   project_data_path - root under which the metrics/, metrics_regridded/ and
#                       summary_indices/ directories used below are located
#   project_code_path - root of the project code (used to find code/utils/LOCA2_NaNs.npy)

project_data_path = "/storage/group/pches/default/users/dcl5300/conus_comparison_lafferty-etal-2024/"
project_code_path = "/storage/home/dcl5300/work/current_projects/conus_comparison_lafferty-etal-2024/"

In [4]:
############
### Dask ###
############
from dask_jobqueue import SLURMCluster

# Resources requested for each SLURM job: 1 core, 8 GiB memory, 45 h walltime,
# charged to the "pches" account.
cluster = SLURMCluster(
    account="pches",
    # account="open",
    cores=1,
    memory="8GiB",
    walltime="45:00:00"
)

cluster.scale(jobs=20)  # ask for jobs

from dask.distributed import Client

# Attach a client to the cluster
client = Client(cluster)

# Bare expression: displays the client summary as the cell output
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.162:43201,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Regridding

In [5]:
# Get models: every entry of the conus-comparison metrics directory is treated
# as one model store. Note that os.listdir returns entries in arbitrary order.
models = os.listdir(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/")

### Conservative

In [6]:
# Get CIL grid: the first listed model's store is opened to supply the source
# grid for the regridder built below.
ds_in = xr.open_zarr(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/{models[0]}")

In [7]:
# We use LOCA grid as target
lat_attrs = {"standard_name": "latitude", "units": "degrees_north"}
lon_attrs = {"standard_name": "longitude", "units": "degrees_east"}

loca_lat_grid = np.linspace(23.90625, 53.46875, 474)
loca_lon_grid = np.linspace(234.53125, 293.46875, 944)

ds_out = xr.Dataset(
    {
        "lat": (["lat"], loca_lat_grid, lat_attrs),
        "lon": (["lon"], loca_lon_grid, lon_attrs),
    }
)

# Add mask from LOCA output: the inverse of the stored LOCA2 NaN array
# NOTE(review): presumably a boolean (lat, lon) array marking NaN cells — confirm
loca_nans = np.load(f'{project_code_path}/code/utils/LOCA2_NaNs.npy')
ds_out["mask"] = xr.DataArray(~loca_nans, dims=['lat','lon'])

# Conservative regridder from the CIL grid to the masked LOCA grid
regridder = xe.Regridder(ds_in, ds_out, "conservative")

In [8]:
%%time
# Out path
out_path = f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative"

# Loop through all, skipping models whose regridded file already exists
for model in models:
    regridded_file = f"{out_path}/{model}.nc"
    if os.path.isfile(regridded_file):
        continue

    # Read
    ds_cil_in = xr.open_zarr(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/{model}")

    # Regrid lazy
    # NOTE: use high NaN threshold to try to not introduce NaNs
    # not already present in the LOCA2 grid
    ds_cil_out = regridder(ds_cil_in, skipna=True, na_thres=0.99)

    # Store, then report the finished model
    ds_cil_out.to_netcdf(regridded_file)
    print(model)

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


GFDL-ESM4


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


HadGEM3-GC31-LL


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM4-8


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM5-0


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MIROC-ES2L


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MIROC6


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MPI-ESM1-2-HR


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


MPI-ESM1-2-LR


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NESM3


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NorESM2-LM


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


NorESM2-MM


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


UKESM1-0-LL
CPU times: user 8min 55s, sys: 33.5 s, total: 9min 29s
Wall time: 26min 15s


# Summary indices

In [9]:
# Simple preprocessing function to add model and year coordinates
def _preprocess(ds, ssp):
    # Add model coordinate
    model = ds.encoding['source'].split('/')[-1].split('_')[0]
    ds = ds.assign_coords(model = model)

    # Select SSP
    if ssp in ds.
    ds = ds.sel(ssp=ssp)

    # Time -> year
    ds['time'] = ds['time'].dt.year

    return ds

In [None]:
# NOTE(review): unfinished scratch cell — the call below is missing its closing
# parenthesis, so this cell raises a SyntaxError if run as written.
ds = xr.open_mfdataset(f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/AC.nc"

In [23]:
# Scratch: open the un-regridded CanESM5 metrics store for inspection
ds = xr.open_zarr(f"{project_data_path}/metrics/CIL-GDPCIR/conus-comparison/CanESM5")

In [24]:
# Scratch: display the dataset summary as the cell output
ds

Unnamed: 0,Array,Chunk
Bytes,268.75 MiB,3.91 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 268.75 MiB 3.91 MiB Shape (8, 86, 160, 320) (1, 10, 160, 320) Dask graph 72 chunks in 2 graph layers Data type float64 numpy.ndarray",8  1  320  160  86,

Unnamed: 0,Array,Chunk
Bytes,268.75 MiB,3.91 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,268.75 MiB,3.91 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 268.75 MiB 3.91 MiB Shape (8, 86, 160, 320) (1, 10, 160, 320) Dask graph 72 chunks in 2 graph layers Data type float64 numpy.ndarray",8  1  320  160  86,

Unnamed: 0,Array,Chunk
Bytes,268.75 MiB,3.91 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.38 MiB,1.95 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 134.38 MiB 1.95 MiB Shape (8, 86, 160, 320) (1, 10, 160, 320) Dask graph 72 chunks in 2 graph layers Data type float32 numpy.ndarray",8  1  320  160  86,

Unnamed: 0,Array,Chunk
Bytes,134.38 MiB,1.95 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,134.38 MiB,1.95 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 134.38 MiB 1.95 MiB Shape (8, 86, 160, 320) (1, 10, 160, 320) Dask graph 72 chunks in 2 graph layers Data type float32 numpy.ndarray",8  1  320  160  86,

Unnamed: 0,Array,Chunk
Bytes,134.38 MiB,1.95 MiB
Shape,"(8, 86, 160, 320)","(1, 10, 160, 320)"
Dask graph,72 chunks in 2 graph layers,72 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [13]:
# Scratch: open the regridded CanESM5 file directly.
# NOTE: `_partial_preprocess` is built here but not used in this cell.
ssp='ssp585'
_partial_preprocess = partial(_preprocess, ssp=ssp)
ds = xr.open_dataset(f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/CanESM5.nc", chunks='auto')

In [19]:
# Scratch: check that selecting a single scenario by label works
ds.sel(ssp='ssp585')

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 587.18 MiB 84.83 MiB Shape (2, 86, 474, 944) (2, 43, 253, 511) Dask graph 8 chunks in 3 graph layers Data type float64 numpy.ndarray",2  1  944  474  86,

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 587.18 MiB 84.83 MiB Shape (2, 86, 474, 944) (2, 43, 253, 511) Dask graph 8 chunks in 3 graph layers Data type float64 numpy.ndarray",2  1  944  474  86,

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 587.18 MiB 84.83 MiB Shape (2, 86, 474, 944) (2, 43, 253, 511) Dask graph 8 chunks in 3 graph layers Data type float64 numpy.ndarray",2  1  944  474  86,

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 587.18 MiB 84.83 MiB Shape (2, 86, 474, 944) (2, 43, 253, 511) Dask graph 8 chunks in 3 graph layers Data type float64 numpy.ndarray",2  1  944  474  86,

Unnamed: 0,Array,Chunk
Bytes,587.18 MiB,84.83 MiB
Shape,"(2, 86, 474, 944)","(2, 43, 253, 511)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [12]:
# Scratch: display `ds`.
# NOTE(review): the saved output of this cell is a NameError, i.e. it was last
# run in a kernel where `ds` had not been defined yet (out-of-order execution).
ds

NameError: name 'ds' is not defined

In [48]:
# Calculates summary indices for the regridded CIL-GDPCIR ensemble for a given SSP
def get_summary_indices(ssp, years, out_path, out_str):
    """
    Calculate and store summary indices of the regridded CIL-GDPCIR ensemble
    for a given SSP.

    Current summary indices calculated: mean, 99th quantile, 99% quantile range
    (0.995 quantile minus 0.005 quantile).
    `years` define the window over which all outputs are pooled: every model
    and every year in the window contribute to each index.

    Parameters
    ----------
    ssp : scenario label handed to `_preprocess` (e.g. 'ssp585')
    years : pair (first_year, last_year) used to slice `time`
    out_path : directory in which the result is written
    out_str : file name (without extension) of the result

    Returns
    -------
    None. Writes `{out_path}/{out_str}.nc`; does nothing if that file exists.
    """
    out_file = f"{out_path}/{out_str}.nc"

    # Check if done
    if os.path.isfile(out_file):
        return

    # Read all. The SSP has to be bound into the preprocess callable, since
    # open_mfdataset calls it with the dataset as its only argument.
    partial_func = partial(_preprocess, ssp=ssp)
    ds = xr.open_mfdataset(f"{project_data_path}/metrics_regridded/CIL-GDPCIR/conservative/*.nc", chunks='auto',
                           preprocess=partial_func, combine='nested', concat_dim=['model'])

    # Time slice
    ds_sel = ds.sel(time=slice(years[0],years[1]))

    ## Summary indices
    # Mean
    ds_mean = ds_sel.mean(dim=['model', 'time']).assign_coords(indice = 'mean')

    # Quantiles: every dimension that is reduced must sit in a single chunk,
    # so both `model` and `time` are rechunked (chunks='auto' may split `time`)
    ds_pooled = ds_sel.chunk(dict(model=-1, time=-1))
    ds_qlow = ds_pooled.quantile(0.005, dim=['model', 'time'])
    ds_qhigh = ds_pooled.quantile(0.995, dim=['model', 'time'])
    ds_qrange = (ds_qhigh - ds_qlow).assign_coords(indice = '99range')

    ds_q99 = ds_pooled.quantile(0.99, dim=['model', 'time']).assign_coords(indice = 'q99')

    # Store
    ds_out = xr.concat([ds_mean, ds_qrange, ds_q99], dim='indice')
    ds_out.to_netcdf(out_file)

In [None]:
%%time
# Compute the summary indices for every (time window, SSP) combination.
# The output name is prefixed with the ensemble actually being read
# (CIL-GDPCIR); get_summary_indices skips any file that already exists, so a
# prefix shared with another ensemble's results would silently reuse those.
for years in [[2020,2040], [2050,2070], [2080,2100]]:
    for ssp in ['ssp245', 'ssp370', 'ssp585']:
        get_summary_indices(ssp=ssp, years=years,
                            out_path=f"{project_data_path}/summary_indices",
                            out_str=f"CIL-GDPCIR_{ssp}_{str(years[0])}-{str(years[1])}")