In [None]:
###############################################
### FOR USE ON MICROSOFT PLANETARY COMPUTER ###
###############################################

In [1]:
import planetary_computer
import pystac_client
import pystac

import numpy as np
import xarray as xr
import pandas as pd

import collections
import fsspec
import requests

import getpass
import azure.storage.blob
import zarr

### Preliminaries

In [2]:
######################
# Azure blob storage
######################
import os

# Connection string (from the Azure web login: select your storage account, then "Access keys").
# Taken from the AZURE_STORAGE_CONNECTION_STRING environment variable when it is set, so the
# notebook can be re-run non-interactively; otherwise prompt for it without echoing the secret.
connection_string = (os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
                     or getpass.getpass("Azure storage connection string: "))

# client for the container that all outputs below are written to
container_client = azure.storage.blob.ContainerClient.from_connection_string(
    connection_string, container_name="mpctransfer")

 ········


In [1]:
###################
# Models
###################

# NEX models that provide all SSPs and both variables (tas, pr)
complete_nex_models = [
    'ACCESS-CM2', 'ACCESS-ESM1-5', 'CanESM5', 'CMCC-ESM2', 'CNRM-CM6-1',
    'CNRM-ESM2-1', 'EC-Earth3', 'EC-Earth3-Veg-LR', 'FGOALS-g3', 'GFDL-CM4',
    'GFDL-ESM4', 'GISS-E2-1-G', 'INM-CM4-8', 'INM-CM5-0', 'IPSL-CM6A-LR',
    'KACE-1-0-G', 'MIROC-ES2L', 'MIROC6', 'MPI-ESM1-2-HR', 'MPI-ESM1-2-LR',
    'MRI-ESM2-0', 'NorESM2-LM', 'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL',
]

# CIL models that provide all SSPs and all variables
complete_cil_models = [
    'INM-CM4-8', 'INM-CM5-0', 'BCC-CSM2-MR', 'CMCC-CM2-SR5',
    'CMCC-ESM2', 'MIROC-ES2L', 'MIROC6', 'UKESM1-0-LL', 'MPI-ESM1-2-LR',
    'NorESM2-LM', 'NorESM2-MM', 'GFDL-ESM4', 'EC-Earth3',
    'EC-Earth3-Veg-LR', 'EC-Earth3-Veg', 'CanESM5',
]

# models present in both lists; np.intersect1d returns them sorted and de-duplicated
models = np.intersect1d(complete_cil_models, complete_nex_models)

In [2]:
# show the models present in both lists (sorted, as returned by np.intersect1d)
models

array(['CMCC-ESM2', 'CanESM5', 'EC-Earth3', 'EC-Earth3-Veg-LR',
       'GFDL-ESM4', 'INM-CM4-8', 'INM-CM5-0', 'MIROC-ES2L', 'MIROC6',
       'MPI-ESM1-2-LR', 'NorESM2-LM', 'NorESM2-MM', 'UKESM1-0-LL'],
      dtype='<U16')

In [4]:
#################
# Data access
#################

# Planetary Computer STAC API; `catalog` is the client that grab_model searches
PC_STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = pystac_client.Client.open(PC_STAC_URL)

# function to grab variables and SSPs for a single model
def grab_model(model_id, include_temp, include_prcp):
    """Open every non-historical CIL-GDPCIR experiment for one model as one Dataset.

    Searches the three CIL-GDPCIR collections (one per licence) for items whose
    ``cmip6:source_id`` equals ``model_id`` and whose ``cmip6:experiment_id`` is
    not ``historical``. For each item, the requested variable assets are opened
    with the asset's ``xarray:open_kwargs`` and merged. The per-experiment
    datasets are then concatenated along a new ``ssp`` dimension.

    Parameters
    ----------
    model_id : str
        CMIP6 source_id to select, e.g. ``"CanESM5"``.
    include_temp : bool
        If true, open the ``tasmin`` and ``tasmax`` assets.
    include_prcp : bool
        If true, open the ``pr`` asset.

    Returns
    -------
    xarray.Dataset
        The requested variables with an ``ssp`` dimension labelled by each
        item's ``experiment_id`` attribute, sorted by that label.

    Raises
    ------
    ValueError
        If neither variable group is requested, or the search returns no items.
    """
    # define vars to grab (tasmin/tasmax before pr, as before)
    vars_to_grab = []
    if include_temp:
        vars_to_grab += ['tasmin', 'tasmax']
    if include_prcp:
        vars_to_grab.append('pr')
    if not vars_to_grab:
        raise ValueError("grab_model: at least one of include_temp / include_prcp must be True")

    # Search across all licences in CIL-GDPCIR
    search = catalog.search(
        collections=["cil-gdpcir-cc0", "cil-gdpcir-cc-by", "cil-gdpcir-cc-by-sa"],
        query={"cmip6:source_id": {"eq": model_id},
               "cmip6:experiment_id": {"neq": "historical"}},  # omit historical
    )
    ensemble = search.get_all_items()
    if len(ensemble) == 0:
        # otherwise xr.concat fails on an empty list with an unhelpful message
        raise ValueError(f"grab_model: no non-historical CIL-GDPCIR items found for {model_id!r}")

    # one merged dataset per experiment
    ds_ssp = []
    for item in ensemble:
        signed = planetary_computer.sign(item)
        ds_vars = []
        for variable_id in vars_to_grab:
            asset = signed.assets[variable_id]
            ds_tmp = xr.open_dataset(asset.href, **asset.extra_fields["xarray:open_kwargs"])
            ds_vars.append(ds_tmp.assign_coords(ssp=ds_tmp.attrs['experiment_id']))
        ds_ssp.append(xr.merge(ds_vars))

    # STAC search result order is not guaranteed; sort so the ssp axis is identical
    # across models and across re-runs.
    return xr.concat(ds_ssp, dim='ssp').sortby('ssp')

In [5]:
#########
# Dask
#########
import dask_gateway

gateway = dask_gateway.Gateway()

# cluster options: worker_memory=16 and a single core per worker
cluster_options = gateway.cluster_options()
for option_name, option_value in (("worker_memory", 16), ("worker_cores", 1)):
    cluster_options[option_name] = option_value

# start the cluster, attach a client to it, then ask for 40 workers
cluster = gateway.new_cluster(cluster_options)
client = cluster.get_client()
cluster.scale(40)

# dashboard link (open it to monitor the computations below)
print(cluster.dashboard_link)

https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.b66642175a9e4725b8a36c5bfdfeeb28/status


## Annual averages

In [None]:
# loop through models: RUNTIME IS AROUND 15 MINS PER MODEL WITH 40 DASK WORKERS
# Writes annual means of tasmin, tasmax, pr and the derived tasavg to
# cil-gdpcir/annual_avgs/<model> in the Azure container.
compressor = zarr.Blosc(cname='zstd', clevel=3)  # loop-invariant
for model in models:
    # load data (lazy)
    ds = grab_model(model, True, True)

    # tasavg = midpoint of tasmax and tasmin, then one mean per resampled year
    ds['tasavg'] = (ds['tasmax'] + ds['tasmin']) / 2.
    ds_final = ds.resample(time='1Y').mean()

    # storage options: build the encoding from the dataset actually written, so
    # the derived `tasavg` also gets the zstd compressor (building it from `ds`
    # before `tasavg` existed left that variable on zarr's default compressor)
    encoding = {vname: {'compressor': compressor} for vname in ds_final.data_vars}

    azure_prefix = 'cil-gdpcir/annual_avgs/' + model
    store = zarr.ABSStore(client=container_client, prefix=azure_prefix)

    # compute and store
    ds_final.to_zarr(store=store, encoding=encoding, consolidated=True, mode='w')
    print(model)

## Annual maxima

In [None]:
# loop through models: RUNTIME IS AROUND 12 MINS PER MODEL WITH 40 DASK WORKERS
# Writes annual maxima of tasmin, tasmax and pr to cil-gdpcir/annual_maxs/<model>.
compressor = zarr.Blosc(cname='zstd', clevel=3)
for model in models:
    # lazily open temperature and precipitation for every non-historical experiment
    ds = grab_model(model, include_temp=True, include_prcp=True)

    # same zstd compressor for every variable of the source dataset
    encoding = {vname: {'compressor': compressor} for vname in ds.data_vars}

    # one maximum per resampled year
    ds_final = ds.resample(time='1Y').max()

    azure_prefix = f'cil-gdpcir/annual_maxs/{model}'
    store = zarr.ABSStore(client=container_client, prefix=azure_prefix)
    ds_final.to_zarr(store=store, encoding=encoding, consolidated=True, mode='w')
    print(model)

## Annual minima (temperature only)

In [None]:
# loop through models: RUNTIME IS AROUND 8 MINS PER MODEL WITH 40 DASK WORKERS
# Writes annual minima of tasmin and tasmax to cil-gdpcir/annual_mins/<model>.
compressor = zarr.Blosc(cname='zstd', clevel=3)
for model in models:
    # lazily open temperature only (no precipitation)
    ds = grab_model(model, include_temp=True, include_prcp=False)

    # same zstd compressor for every variable of the source dataset
    encoding = {vname: {'compressor': compressor} for vname in ds.data_vars}

    # one minimum per resampled year
    ds_final = ds.resample(time='1Y').min()

    azure_prefix = f'cil-gdpcir/annual_mins/{model}'
    store = zarr.ABSStore(client=container_client, prefix=azure_prefix)
    ds_final.to_zarr(store=store, encoding=encoding, consolidated=True, mode='w')
    print(model)

## Precipitation indices

In [None]:
# loop through models: RUNTIME IS AROUND 8 MINS PER MODEL WITH 40 DASK WORKERS
# Writes two annual precipitation indices (named after the ETCCDI SDII and R20mm)
# to cil-gdpcir/precip_inds/<model>.
for model in models:
    # lazily open precipitation only
    ds = grab_model(model, include_temp=False, include_prcp=True)

    # output location
    azure_prefix = f'cil-gdpcir/precip_inds/{model}'
    store = zarr.ABSStore(client=container_client, prefix=azure_prefix)

    # NOTE(review): the 1 and 20 thresholds assume pr is in mm/day -- confirm units
    # SDII : mean pr over the days with pr >= 1 in each resampled year
    # R20mm: number of days with pr >= 20 in each resampled year
    wet_day_mean = ds.where(ds.pr >= 1.).resample(time='1Y').mean()
    heavy_day_count = ds.where(ds.pr >= 20.).resample(time='1Y').count()

    ds_final = xr.combine_by_coords([
        wet_day_mean.rename({'pr': 'SDII'}),
        heavy_day_count.rename({'pr': 'R20mm'}),
    ])

    # zstd-compress both indices and write
    compressor = zarr.Blosc(cname='zstd', clevel=3)
    encoding = {vname: {'compressor': compressor} for vname in ds_final.data_vars}
    ds_final.to_zarr(store=store, encoding=encoding, consolidated=True, mode='w')
    print(model)

## Re-chunk

In [None]:
# Re-chunk every stored output for easier downstream access: copy
# cil-gdpcir/<metric>/<model> to cil-gdpcir_rechunked/<metric>/<model>.
compressor = zarr.Blosc(cname='zstd', clevel=3)
for model in models:
    print(model)
    for metric in ['annual_avgs', 'annual_maxs', 'annual_mins', 'precip_inds']:
        # read the output written by the loops above
        src_store = zarr.ABSStore(client=container_client,
                                  prefix=f'cil-gdpcir/{metric}/{model}')
        ds_cil = xr.open_zarr(store=src_store)

        # destination for the re-chunked copy
        dst_store = zarr.ABSStore(client=container_client,
                                  prefix=f'cil-gdpcir_rechunked/{metric}/{model}')

        # chunks of 1 ssp x 10 time steps (years) x 720 lat x 1440 lon
        ds_cil = ds_cil.chunk({'ssp': 1, 'time': 10, 'lat': 720, 'lon': 1440})

        encoding = {vname: {'compressor': compressor} for vname in ds_cil.data_vars}
        ds_cil.to_zarr(store=dst_store, encoding=encoding, consolidated=True, mode='w')

In [18]:
##################################################
# OLD: all SSPs and variables
##################################################

In [19]:
# def grab_ssp_var(ssp_id, variable_id):
#     # Search across all licences in CIL-GDPCIR
#     search = catalog.search(
#         collections=["cil-gdpcir-cc0", "cil-gdpcir-cc-by", "cil-gdpcir-cc-by-sa"],
#         query={"cmip6:experiment_id": {"eq": ssp_id}},
#     )
#     # How many models?
#     ensemble = search.get_all_items()
    
#     # grab all into one dataset
#     datasets_by_model = []

#     for item in tqdm(ensemble[:2]):
#         try:
#             signed = planetary_computer.sign(item)
#             asset = signed.assets[variable_id]
#             datasets_by_model.append(
#                 xr.open_dataset(asset.href, **asset.extra_fields["xarray:open_kwargs"])
#             )
#         except: 
#             print(variable_id + ' error for ' + item.id)

#     all_datasets = xr.concat(
#         datasets_by_model,
#         dim=pd.Index([ds.attrs["source_id"] for ds in datasets_by_model], name="model"),
#         combine_attrs="drop_conflicts",
#     )
    
#     return all_datasets

In [20]:
# ssp_id = 'ssp126'

# # tmin, tmax, prcp
# tmax_ssp126 = grab_ssp_var(ssp_id, 'tasmax')
# tmin_ssp126 = grab_ssp_var(ssp_id, 'tasmin')
# prcp_ssp126 = grab_ssp_var(ssp_id, 'pr')

# # merge and assign ssp coordinate
# ssp126_all = xr.merge([tmax_ssp126, tmin_ssp126, prcp_ssp126])
# ssp126_all = ssp126_all.assign_coords(ssp=ssp_id)