In [1]:
'''Script to regrid CMIP6 datasets to a target grid and store them.'''

import numpy as np
import xarray as xr
import dask
import intake
import pandas as pd
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm
from xmip.utils import google_cmip_col
from xmip.postprocessing import combine_datasets, _match_datasets,_concat_sorted_time
from cmip_catalogue_operations import reduce_cat_to_max_num_realizations, drop_vars_from_cat, drop_older_versions
from cmip_ds_dict_operations import select_period, pr_flux_to_m, drop_duplicate_timesteps, drop_coords, drop_incomplete
import xesmf as xe
import gcsfs
fs = gcsfs.GCSFileSystem() # Google Cloud Storage filesystem (to list stores, strip zarr from filename, load)

  from tqdm.autonotebook import tqdm


In [2]:
from typing import Dict
# Instead of creating a regridder for every dataset, only create one per source_id
# (the grid does not change between datasets of the same model).
# This creates ~tens of regridders instead of hundreds!
def create_regridder_dict(dataset_dict: Dict[str, xr.Dataset], target_grid_ds: xr.Dataset) -> Dict[str, xe.Regridder]:
    """Build one bilinear xESMF regridder per model (``source_id``).

    Parameters
    ----------
    dataset_dict : dict of str -> xr.Dataset
        Datasets to be regridded; each must carry a ``source_id`` attribute.
        Assumes all datasets of one ``source_id`` share the same horizontal grid
        (``reduce_cat_to_max_num_realizations`` selects one grid per model).
    target_grid_ds : xr.Dataset
        Dataset defining the target grid.

    Returns
    -------
    dict of str -> xe.Regridder
        Regridders keyed by ``source_id``.
    """
    # Use the first dataset (in dict order) of each model as the grid template.
    # Match the ``source_id`` attribute exactly: a substring test on the keys would
    # e.g. match 'CESM2-WACCM' keys for 'CESM2' (or 'EC-Earth3-Veg' for 'EC-Earth3')
    # and build the regridder from the wrong model's grid.
    template_per_source = {}
    for ds in dataset_dict.values():
        template_per_source.setdefault(ds.attrs['source_id'], ds)

    regridders = {}
    for si in tqdm(sorted(template_per_source)):
        template_ds = template_per_source[si]
        regridders[si] = xe.Regridder(template_ds, target_grid_ds, 'bilinear', ignore_degenerate=True, periodic=True)
    return regridders

In [3]:
# Configure settings
# output_path = 'gs://leap-persistent/timh37/CMIP6/timeseries_eu_1p5/'
output_path = 'gs://leap-scratch/jbusecke/CMIPcex/timeseries_eu_1p5/'
# Whether or not to re-process datasets whose output store already exists
# (NOTE: only effective where the store writer is actually passed this flag)
overwrite_existing = False

# Grid that the CMIP6 simulations are interpolated to: regular 1.5-degree grid,
# longitude -30..21 degrees east (ascending), latitude 70..31 degrees north (descending)
target_grid = xr.Dataset(
    data_vars={
        "longitude": (["longitude"], np.arange(-30, 22.5, 1.5), {"units": "degrees_east"}),
        "latitude": (["latitude"], np.arange(70, 30, -1.5), {"units": "degrees_north"}),
    }
)

query_vars = ['sfcWind', 'pr', 'psl']     # variables to process
required_vars = ['sfcWind', 'pr', 'psl']  # variables that every included model must provide

ssps = ['ssp245', 'ssp585']  # scenarios to process (each paired with the historical experiment)

In [4]:
# Query simulations & manipulate the data catalogue
col = google_cmip_col()  # Google Cloud catalogue
lcol = intake.open_esm_datastore("https://storage.googleapis.com/leap-persistent-ro/data-library/catalogs/cmip6-test/leap-pangeo-cmip6-test.json")  # temporary pangeo-leap-forge catalogue
col.esmcat._df = pd.concat([col.df, lcol.df], ignore_index=True)  # merge these catalogues

# Search the catalogue per SSP (availability may differ between scenarios): find instances
# providing all required variables for both the historical and the SSP experiment.
ssp_cats = defaultdict(dict)
for ssp in ssps:
    ssp_cats[ssp] = col.search(
        experiment_id=['historical', ssp],
        table_id='day',
        variable_id=required_vars,
        require_all_on=['source_id', 'member_id', 'grid_label'],
    )

# Merge the per-SSP catalogues and drop duplicate historical simulations.
# A fresh search result serves as the container. Re-using `ssp_cats[ssp]` would replace
# that catalogue's dataframe in place, silently turning the last per-SSP catalogue into
# the merged one.
ssp_cats_merged = col.search(experiment_id=['historical', *ssps], table_id='day', variable_id=required_vars)
ssp_cats_merged.esmcat._df = pd.concat([cat.df for cat in ssp_cats.values()], ignore_index=True).drop_duplicates(ignore_index=True)

# If the google cloud and leap-pangeo catalogues provide duplicate datasets, keep the newest version;
# if the versions are identical, keep the leap-pangeo dataset
ssp_cats_merged = drop_older_versions(ssp_cats_merged)
# Per model, select the grid and 'ipf' combination providing most realizations (applied to both
# SSPs together to ensure the same variants are used under both scenarios)
ssp_cats_merged = reduce_cat_to_max_num_realizations(ssp_cats_merged)

In [5]:
# Display a summary of the merged catalogue (number of unique values per column)
ssp_cats_merged

Unnamed: 0,unique
activity_id,2
institution_id,20
source_id,28
experiment_id,3
member_id,115
table_id,1
variable_id,3
grid_label,4
zstore,2274
dcpp_init_year,0


## @jbusecke: I propose to regrid every store separately here, since the concatenation can lead to chunking issues.

Once the regridding is done, the datasets are tiny (a few hundred MB)! So my proposed strategy is as follows:
- I have slightly optimized the regridding step and applied it to all of the datasets in `ssp_cats_merged`. This is much less error-prone because it avoids the problems that arise when concatenating the two experiments.
- Currently this step only regrids and rechunks in time, but it applies these operations efficiently across many datasets.

In [6]:
# Alternative to the per-SSP opening in the original code (see "Back to Tim" below):
# with no groupby attributes intake-esm does not combine stores, so every single
# zarr store becomes its own entry (perhaps we don't need some of them, but at
# this point we do not really care).
ssp_cats_merged.esmcat.aggregation_control.groupby_attrs = list()
ddict_all = ssp_cats_merged.to_dataset_dict(
    aggregate=True,
    zarr_kwargs=dict(use_cftime=True),
)


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'


In [12]:
# One regridder per model (source_id), reused for all datasets of that model
regridder_dict = create_regridder_dict(
    dataset_dict=ddict_all,
    target_grid_ds=target_grid,
)

  0%|          | 0/28 [00:00<?, ?it/s]

In [98]:
# Alternative: a local cluster
# from distributed import Client
# client = Client()
# client

## Try a distributed, big ol' cluster via dask-gateway
import dask
from dask_gateway import Gateway

gateway = Gateway()

# Shut down all existing clusters first (be careful if you have multiple clusters/servers open!)
open_clusters = gateway.list_clusters()
print(list(open_clusters))
for report in open_clusters:
    gateway.connect(report.name).shutdown()

# Worker settings (check the gateway's unit for worker_memory)
options = gateway.cluster_options()
options.worker_memory = 18
# options.worker_cores = 12

# Create a cluster with those options and let it scale adaptively between 20 and 100 workers
cluster = gateway.new_cluster(options)
client = cluster.get_client()
cluster.adapt(20, 100)
client

[ClusterReport<name=prod.9a4a6cd4931943d3908d2797a8f1b2c4, status=RUNNING>]


0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: /services/dask-gateway/clusters/prod.486275becf4e4c1e99b84ce58a36d29a/status,


Exception in callback None()
handle: <Handle cancelled>
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/iostream.py", line 1367, in _do_ssl_handshake
    self.socket.do_handshake()
  File "/srv/conda/envs/notebook/lib/python3.10/ssl.py", line 1342, in do_handshake
    self._sslobj.do_handshake()
ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1007)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 192, in _handle_events
    handler_func(fileobj, events)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/iostream.py", line 691, in _handle_events
    self._h

In [76]:
def make_filepath(output_path, ds, suffix=None):
    """Return the output zarr path of a (regridded) dataset.

    The path is ``<output_path>/<variable_id>/<source_id>/<time_concat_key><suffix>``.

    Parameters
    ----------
    output_path : str
        Root output location.
    ds : xr.Dataset
        Dataset whose attrs contain ``time_concat_key``, ``variable_id`` and ``source_id``.
    suffix : str, optional
        String appended to the key. Defaults to ``'.hist_' + ssp``, where ``ssp`` is
        whatever value the notebook-level ``ssp`` variable currently holds. This
        reproduces the paths of previously written stores.
        NOTE(review): with that default every store (including ssp245 and historical
        ones) is labelled with the same SSP. Pass ``suffix`` explicitly to avoid
        depending on that leftover global. Also note that the key produced by
        ``to_dataset_dict`` embeds the full source zstore URL.
    """
    if suffix is None:
        # Backwards-compatible default (implicitly reads the global `ssp`)
        suffix = '.hist_' + ssp
    key = ds.attrs["time_concat_key"]
    variable = ds.attrs['variable_id']
    source_id = ds.attrs['source_id']
    return os.path.join(output_path, variable, source_id, key + suffix)

# Regrid all datasets to the target grid (@jbusecke: but do not actually store them yet!)
regridded_datasets = []
for key, ds in tqdm(ddict_all.items()):
    # Keep the catalogue key with the data; `make_filepath` builds the output path from it
    ds.attrs["time_concat_key"] = key
    # Remove the `dcpp_init_year` dimension (presumably a length-1 dimension from the
    # intake-esm aggregation; TODO confirm)
    ds = ds.isel(dcpp_init_year=0, drop=True)

    # Reuse the single regridder built for this model
    regridder = regridder_dict[ds.attrs['source_id']]
    regridded_ds = regridder(ds, keep_attrs=True)

    # Make chunks consistent across variables, then use large time chunks (the regridded
    # fields are small, so a few big chunks keep the task graph small)
    regridded_datasets.append(regridded_ds.unify_chunks().chunk({'time': 40000}))

  0%|          | 0/2274 [00:00<?, ?it/s]

In [99]:
# following https://stackoverflow.com/questions/66769922/concurrently-write-xarray-datasets-to-zarr-how-to-efficiently-scale-with-dask
from distributed import worker_client, as_completed
fs = gcsfs.GCSFileSystem()

def write_wrapper(ds, overwrite=False, fs=None):
    """Write one dataset to its zarr store; meant to run as a task on a dask worker.

    Parameters
    ----------
    ds : xr.Dataset
        Regridded dataset; its attrs must allow ``make_filepath`` to build the target path.
    overwrite : bool
        If False, skip datasets whose target store already exists.
    fs : gcsfs.GCSFileSystem, optional
        Filesystem used for the existence check. A new ``gcsfs.GCSFileSystem()`` is
        created when it is omitted and the check is needed.

    Returns
    -------
    tuple of (str, str)
        The target path and a status message. Exceptions raised while checking or
        writing are caught and reported in the message (best effort), so one failing
        store does not abort the batch.
    """
    target = make_filepath(output_path, ds)
    # worker_client() secedes this task from the worker's thread pool while the nested
    # to_zarr computation runs (pattern from the StackOverflow answer linked above);
    # the client object itself is not needed here.
    with worker_client():
        try:
            if not overwrite:
                # Fall back to a fresh GCS filesystem instead of failing on fs=None
                filesystem = fs if fs is not None else gcsfs.GCSFileSystem()
                if filesystem.exists(target):
                    return target, 'already written, skipped'
            # Store doesn't exist yet (or overwrite is requested): (re)write it completely
            ds.to_zarr(store=target, mode='w')
            return target, 'written freshly'
        except Exception as e:
            return target, f"Failed with: {e}"

# There is a more advanced way of doing this with the `as_completed` iterator, to achieve a 'steady'
# supply of submissions to the client (see answers in
# https://stackoverflow.com/questions/66769922/concurrently-write-xarray-datasets-to-zarr-how-to-efficiently-scale-with-dask),
# but for our intents and purposes we can just submit medium-sized batches here.
# This seems to scale ok (there is still downtime between the batch submissions). For comparison, just
# using a big cluster and looping over the datasets achieved a ~3x speed-up (not bad), but here we are
# looking at 10+x.

# Batch size: 50 works fine, apart from a few warnings about a large graph (the datasets are shipped
# inside the task graph). You could play with this, but higher numbers seemed to make the scheduler
# quite unstable.
interval = 50
regridded_datasets_batches = [regridded_datasets[a:a + interval] for a in range(0, len(regridded_datasets), interval)]

written_stores = []  # (target path, status message) per dataset

for ds_batch in tqdm(regridded_datasets_batches):
    # Honour the `overwrite_existing` setting from the configuration cell instead of hard-coding False
    futures = client.map(write_wrapper, ds_batch, overwrite=overwrite_existing, fs=fs)
    for future, result in as_completed(futures, with_results=True):
        written_stores.append(result)
        future.release()
    # Failed writes show up as "Failed with: ..." statuses in `written_stores`.
    # Explicitly drop the batch's futures to ease pressure on the client.
    del futures

  0%|          | 0/46 [00:00<?, ?it/s]

This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Co

In [100]:
# Inspect the last few (target path, status message) results
written_stores[-4:]

[('gs://leap-scratch/jbusecke/CMIPcex/timeseries_eu_1p5/sfcWind/ACCESS-CM2/ScenarioMIP.CSIRO-ARCCSS.ACCESS-CM2.ssp245.r1i1p1f1.day.sfcWind.gn.gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCESS-CM2/ssp245/r1i1p1f1/day/sfcWind/gn/v20191108/.20191108.hist_ssp585',
  'written freshly'),
 ('gs://leap-scratch/jbusecke/CMIPcex/timeseries_eu_1p5/sfcWind/MIROC6/CMIP.MIROC.MIROC6.historical.r23i1p1f1.day.sfcWind.gn.gs://cmip6/CMIP6/CMIP/MIROC/MIROC6/historical/r23i1p1f1/day/sfcWind/gn/v20200519/.20200519.hist_ssp585',
  'written freshly'),
 ('gs://leap-scratch/jbusecke/CMIPcex/timeseries_eu_1p5/psl/EC-Earth3/ScenarioMIP.EC-Earth-Consortium.EC-Earth3.ssp245.r139i1p1f1.day.psl.gr.gs://cmip6/CMIP6/ScenarioMIP/EC-Earth-Consortium/EC-Earth3/ssp245/r139i1p1f1/day/psl/gr/v20210401/.20210401.hist_ssp585',
  'written freshly'),
 ('gs://leap-scratch/jbusecke/CMIPcex/timeseries_eu_1p5/pr/CMCC-ESM2/CMIP.CMCC.CMCC-ESM2.historical.r1i1p1f1.day.pr.gn.gs://cmip6/CMIP6/CMIP/CMCC/CMCC-ESM2/historical/r1i1p1f1/day/p

In [101]:
# Lazily (dask-backed) re-open the last written store to check its contents
xr.open_dataset(written_stores[-1][0], engine='zarr', chunks={})

Unnamed: 0,Array,Chunk
Bytes,0.92 MiB,235.26 kiB
Shape,"(60225, 2)","(30113, 1)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 0.92 MiB 235.26 kiB Shape (60225, 2) (30113, 1) Dask graph 4 chunks in 2 graph layers Data type object numpy.ndarray",2  60225,

Unnamed: 0,Array,Chunk
Bytes,0.92 MiB,235.26 kiB
Shape,"(60225, 2)","(30113, 1)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,217.10 MiB,144.20 MiB
Shape,"(1, 60225, 27, 35)","(1, 40000, 27, 35)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 217.10 MiB 144.20 MiB Shape (1, 60225, 27, 35) (1, 40000, 27, 35) Dask graph 2 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  35  27  60225,

Unnamed: 0,Array,Chunk
Bytes,217.10 MiB,144.20 MiB
Shape,"(1, 60225, 27, 35)","(1, 40000, 27, 35)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [102]:
# Total size of everything stored under `output_path`, in GB
fs.du(output_path)/1e9

262.155190536

What you have now is a regridded/rechunked store of all the datasets. You will need to reload these, and perform your filtering/concatenation.
Doing that afterwards turns out to be critical to avoid a bunch of issues with inconsistent time chunks etc...

If there are any steps that you can apply to each dataset individually (e.g. `ddict = drop_coords(ddict,['bnds','nbnd','height']) #remove some unused auxiliary coordinates`), I would consider adding them above and overwriting these stores.

After that, I wonder how much of the computation you could layer on lazily. If I understand correctly, one of the issues with the old processing was that the historical run was often written out twice.

> FYI, most single datasets are small enough at this point to load them directly into memory when processing. Maybe keep that in mind for later steps (eofs etc).

## Some random notes
- Check out my refactor of the regridder creation! This should be much faster now, because we only build one regridder per model (~30), not one per dataset (~2000).
- ...

## Back to Tim

This is your old code, which I picked apart quite heavily....

In [8]:
for ssp in tqdm(ssps):  # for each SSP (only the first is processed, see `break` below)
    # Select historical and SSP data in the merged catalogue for this particular SSP
    cat_to_open = ssp_cats_merged.search(
        experiment_id=['historical', ssp],
        table_id='day',
        variable_id=required_vars,
        require_all_on=['source_id', 'member_id', 'grid_label'])

    # Out of the required variables, only process the query variables
    cat_to_open = drop_vars_from_cat(cat_to_open, [k for k in required_vars if k not in query_vars])

    # Open the datasets into a dictionary.
    # Setting groupby_attrs = [] circumvents the aggregate=False bug
    # (https://github.com/intake/intake-esm/issues/496): with no groupby attributes the
    # datasets are not actually aggregated, even with aggregate=True.
    cat_to_open.esmcat.aggregation_control.groupby_attrs = []
    kwargs = {'zarr_kwargs': {'consolidated': True, 'use_cftime': True}, 'aggregate': True}
    # @jbusecke: Curious why you are not using xMIP here?
    ddict = cat_to_open.to_dataset_dict(**kwargs)

    # Preprocess the datasets in the dictionary
    ddict = pr_flux_to_m(ddict)  # convert pr flux to accumulated pr
    ddict = drop_duplicate_timesteps(ddict)  # remove duplicate timesteps if datasets have them
    # ddict = select_period(ddict,1850,2100)  # preselect time periods, do this at a later stage in the chain?
    ddict = drop_coords(ddict, ['bnds', 'nbnd', 'height'])  # remove some unused auxiliary coordinates

    # Debugging: stop after the first SSP so that `ddict` holds its datasets for the cells below
    break

0it [00:00, ?it/s]


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'



KeyboardInterrupt



In [48]:
# Concatenate the historical and SSP datasets of each member/variable in time.
# with dask.config.set(**{'array.slicing.split_large_chunks': True}):
# trying to get to the bottom of this chunking issue...wondering why this is necessary...
hist_ssp = combine_datasets(
    ddict,
    _concat_sorted_time,
    match_attrs=['source_id', 'grid_label', 'table_id', 'variant_label', 'variable_id'],
    combine_func_kwargs={'join': 'inner', 'coords': 'minimal'},
)

In [49]:
# Keep only combined datasets whose first and last timestamps have the same dtype
# (probably there is a better way to do this, but for approximately one dataset the time
# units are inconsistent between the historical and SSP runs)
hist_ssp_ = defaultdict(dict)
for key, combined in hist_ssp.items():
    start_dtype = combined.time[0].values.dtype
    end_dtype = combined.time[-1].values.dtype
    if start_dtype == end_dtype:
        hist_ssp_[key] = combined
    else:
        print('dropping ' + key +' due to inconsistent timestamps in historical and ssp runs')

# @jbusecke: something here takes a long time...possibly another opportunity for optimization
# Remove the overlap between historical and SSP experiments, which sometimes exists
hist_ssp_ = drop_duplicate_timesteps(hist_ssp_)
# Remove historical+SSP timeseries that are not monotonically increasing or have large time gaps
# (based on Julius Busecke's rudimentary testing in CMIP6-LEAP-feedstock)
hist_ssp_complete = drop_incomplete(hist_ssp_)

  0%|          | 0/27 [00:00<?, ?it/s]

dropping EC-Earth3.gr.day.r24i1p1f1.psl due to inconsistent timestamps in historical and ssp runs
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r2i1p1f1.psl
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r2i1p1f1.pr
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r1i1p1f1.sfcWind
Dropping duplicate timesteps for:CESM2.gn.day.r4i1p1f1.psl
Dropping duplicate timesteps for:CESM2.gn.day.r11i1p1f1.pr
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r3i1p1f1.psl
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r3i1p1f1.pr
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r2i1p1f1.sfcWind
Dropping duplicate timesteps for:CESM2.gn.day.r11i1p1f1.sfcWind
Dropping duplicate timesteps for:CESM2.gn.day.r4i1p1f1.pr
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r1i1p1f1.psl
Dropping duplicate timesteps for:CESM2-WACCM.gn.day.r3i1p1f1.sfcWind
Dropping duplicate timesteps for:EC-Earth3-Veg.gr.day.r5i1p1f1.sfcWind
Dropping duplicate timesteps for:CESM2.gn.day.r11i1p1f1.psl
D

Calculate total size of datasets

In [None]:
# Total size (in GB) of the ssp245 datasets in `ddict`
x = sum(v.nbytes / 1000000000 for k, v in ddict.items() if 'ssp245' in k)
x

List available members per model

In [None]:
# Count, per model, the number of combined (member, variable) datasets that survived all filtering
models = ssp_cats_merged.df.source_id.unique()
count_members = np.zeros(len(models))

# Look up each model's position once, instead of running `np.where` twice per dataset
model_position = {model: i for i, model in enumerate(models)}
for k, ds in hist_ssp_complete.items():
    count_members[model_position[ds.source_id]] += 1

In [None]:
print(models)
# Each member contributes one dataset per processed variable, so divide by the number of
# query variables (instead of a hard-coded 3) to get complete members per model
np.floor_divide(count_members, len(query_vars))