# Rechunk, cap, Q/A, and deliver precip

In [1]:
# Changelog text for this delivery; it is stored under the 'history' key of the
# version spec that is built (or loaded) further down.
HISTORY = '''
v1.1 : initial release (version number set to match temperature). 
'''.strip()

# Version label: used in the version-spec file name and as the `delivery_version`
# component of the final output paths.
OUTPUT_VERSION = 'v1.1'

In [2]:
import os
import json
import fsspec
import requests
import contextlib
import xarray as xr
import pandas as pd
import numpy as np
import itertools
import zarr
import rechunker
import dask
import rhg_compute_tools.kubernetes as rhgk
import rhg_compute_tools.utils as rhgu
import dask.distributed as dd
from tqdm.auto import tqdm
import matplotlib
import matplotlib.pyplot as plt

  from distributed.utils import LoopRunner, format_bytes


In [3]:
! pwd

/home/jovyan/downscaling


In [4]:
# Models (used as `source_id` in every path pattern) included in this delivery.
# The per-model lookup tables below are indexed by these names.
DELIVERY_MODELS = [
    'BCC-CSM2-MR',
    'FGOALS-g3',
    'ACCESS-ESM1-5',
    'ACCESS-CM2',
    'INM-CM4-8',
    'INM-CM5-0',
    'MIROC-ES2L',
    'MIROC6',
    'NorESM2-LM',
    'NorESM2-MM',
    'GFDL-ESM4',
    'GFDL-CM4',
    'NESM3',
]

In [5]:
# `institution_id` path component for each delivery model.
INSTITUTIONS = {    
    'BCC-CSM2-MR': 'BCC',
    'FGOALS-g3': 'CAS',
    'ACCESS-ESM1-5': 'CSIRO',
    'ACCESS-CM2': 'CSIRO-ARCCSS',
    'INM-CM4-8': 'INM',
    'INM-CM5-0': 'INM',
    'MIROC-ES2L': 'MIROC',
    'MIROC6': 'MIROC',
    'NorESM2-LM': 'NCC',
    'NorESM2-MM': 'NCC',
    'GFDL-ESM4': 'NOAA-GFDL',
    'GFDL-CM4': 'NOAA-GFDL',
    'NESM3': 'NUIST',
}

In [6]:
# `member_id` path component for each delivery model.
# Note that MIROC-ES2L is the only entry that is not 'r1i1p1f1'.
ENSEMBLE_MEMBERS = {
    'BCC-CSM2-MR': 'r1i1p1f1',
    'FGOALS-g3': 'r1i1p1f1',
    'ACCESS-ESM1-5': 'r1i1p1f1',
    'ACCESS-CM2': 'r1i1p1f1',
    'INM-CM4-8': 'r1i1p1f1',
    'INM-CM5-0': 'r1i1p1f1',
    'MIROC-ES2L': 'r1i1p1f2',
    'MIROC6': 'r1i1p1f1',
    'NorESM2-LM': 'r1i1p1f1',
    'NorESM2-MM': 'r1i1p1f1',
    'GFDL-ESM4': 'r1i1p1f1',
    'GFDL-CM4': 'r1i1p1f1',
    'NESM3': 'r1i1p1f1',
}

In [7]:
# `grid_spec` path component per model. This table also lists models that are
# not in DELIVERY_MODELS; only the delivery models are looked up in this notebook.
GRID_SPECS = {
    "ACCESS-CM2": "gn",
    "MRI-ESM2-0": "gn",
    "CanESM5": "gn",
    "ACCESS-ESM1-5": "gn",
    "MIROC6": "gn",
    "EC-Earth3": "gr",
    "EC-Earth3-Veg-LR": "gr",
    "EC-Earth3-Veg": "gr",
    "MPI-ESM1-2-HR": "gn",
    "CMCC-ESM2": "gn",
    "INM-CM5-0": "gr1",
    "INM-CM4-8": "gr1",
    "MIROC-ES2L": "gn",
    "MPI-ESM1-2-LR": "gn",
    "FGOALS-g3": "gn",
    "BCC-CSM2-MR": "gn",
    "AWI-CM-1-1-MR": "gn",
    "NorESM2-LM": "gn",
    "GFDL-ESM4": "gr1",
    "GFDL-CM4": "gr1",
    "CAMS-CSM1-0": "gn",
    "NorESM2-MM": "gn",
    "NESM3": "gn",
}

In [8]:
# Projection scenario paired with each model's historical run. `cap_precip`
# uses it to pick the projection data that is concatenated with the historical
# record, and a later cell asserts that this scenario is present in the inputs.
# GFDL-CM4 and NESM3 are the only entries that use ssp245 instead of ssp370.
HIST_EXTENSION_SCENARIO = {
    "ACCESS-CM2": "ssp370",
    "MRI-ESM2-0": "ssp370",
    "CanESM5": "ssp370",
    "ACCESS-ESM1-5": "ssp370",
    "MIROC6": "ssp370",
    "EC-Earth3": "ssp370",
    "EC-Earth3-Veg-LR": "ssp370",
    "EC-Earth3-Veg": "ssp370",
    "MPI-ESM1-2-HR": "ssp370",
    "CMCC-ESM2": "ssp370",
    "INM-CM5-0": "ssp370",
    "INM-CM4-8": "ssp370",
    "MIROC-ES2L": "ssp370",
    "MPI-ESM1-2-LR": "ssp370",
    "FGOALS-g3": "ssp370",
    "BCC-CSM2-MR": "ssp370",
    "AWI-CM-1-1-MR": "ssp370",
    "NorESM2-LM": "ssp370",
    "GFDL-ESM4": "ssp370",
    "GFDL-CM4": "ssp245",
    "CAMS-CSM1-0": "ssp370",
    "NorESM2-MM": "ssp370",
    "NESM3": "ssp245",
}

In [9]:
# Reference precipitation store; `cap_precip` takes its per-cell maximum as the
# base of the cap.
CLEANED_REF_0p25deg_FP = 'gs://support-c23ff1a3/qplad-fine-reference/pr/v20220201000555.zarr'

# Cleaned (pre-downscaling) GCM stores; `cap_precip` fills `source_version` from
# the `version_id` attribute of the downscaled data.
cleaned_gcm_pattern = (
    'gs://clean-b1dbca25/cmip6/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{source_version}.zarr'
)

# Downscaled inputs to this notebook; globbed with run_version='*' to discover
# the available runs.
downscaled_filepatt = (
    'gs://downscaled-288ec5ac/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}.zarr'
)

# Scratch store handed to rechunker as its `temp_store`.
rechunked_temp_store_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-rechunked-temp-store.zarr'
)

# Stage 1 output: the rechunked copy of the downscaled input.
rechunked_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-rechunked.zarr'
)

# Stage 2 output: the rechunked data with precipitation capped.
capped_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-pr-capped.zarr'
)

# Stage 3 output: final delivery location. Unlike the patterns above it has no
# grid_spec or run_version component; it is keyed by `delivery_version` instead.
OUTPUT_PATTERN = (
    'gs://downscaled-288ec5ac/outputs/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{delivery_version}.zarr'
)

In [10]:
# Notebook-level GCS filesystem handle, used by the top-level cells (globbing
# inputs, checking whether outputs exist). The functions below create their own.
fs = fsspec.filesystem('gs')

In [11]:
# Pin the exact input stores used for this OUTPUT_VERSION.
#
# If a spec file for this version already exists it is loaded as-is, so re-runs
# keep using the same inputs. Otherwise the staged downscaled runs are
# discovered on GCS, one run is selected per (model, scenario), and the result
# is written to the spec file.
precip_spec_file = f'version_specs/precip_{OUTPUT_VERSION}.json'

if os.path.isfile(precip_spec_file):
    with open(precip_spec_file, 'r') as f:
        INPUT_FILE_VERSIONS = json.load(f)

else:
    pr_fps = {m: {} for m in DELIVERY_MODELS}

    for m in tqdm(DELIVERY_MODELS, desc='pr'):

        inst = INSTITUTIONS[m]

        for act, scen in [
            ('CMIP', 'historical'),
            ('ScenarioMIP', 'ssp245'),
            ('ScenarioMIP', 'ssp370'),
        ]:
            # every staged run of this model/scenario, whatever its run_version
            pr_fps[m][scen] = list(
                fs.glob(
                    downscaled_filepatt.format(
                        activity_id=act,
                        institution_id=inst,
                        source_id=m,
                        experiment_id=scen,
                        member_id=ENSEMBLE_MEMBERS[m],
                        table_id='day',
                        variable_id='pr',
                        grid_spec=GRID_SPECS[m],
                        run_version='*',
                    )
                )
            )

    # Keep the maximum path string per (model, scenario); scenarios without any
    # match are left out. This assumes run_version strings sort chronologically
    # (TODO confirm).
    pr_max_versions = {
        m: {s: max(vs) for s, vs in mspec.items() if len(vs) > 0}
        for m, mspec in pr_fps.items()
    }

    INPUT_FILE_VERSIONS = {
        'version': OUTPUT_VERSION,
        'created': pd.Timestamp.now(tz='US/Pacific').strftime('%c'),
        'history': HISTORY,
        'file_paths': {
            'pr': pr_max_versions,
        },
    }

    # Pure-Python replacement for the `! mkdir -p` shell magic, so the cell does
    # not depend on IPython or on a shell being available.
    os.makedirs(os.path.dirname(precip_spec_file), exist_ok=True)

    with open(precip_spec_file, 'w') as f:
        f.write(json.dumps(INPUT_FILE_VERSIONS))

pr:   0%|          | 0/13 [00:00<?, ?it/s]

In [12]:
# Sanity-check the version spec: every delivery model must have an entry, and
# each selected path must mention both its model and its scenario.
for model in DELIVERY_MODELS:
    for varname in ('pr',):
        paths_by_model = INPUT_FILE_VERSIONS['file_paths'][varname]

        if model not in paths_by_model:
            raise ValueError(f"model {model} not found for {varname}")

        for scenario, path in paths_by_model[model].items():
            assert model in path, f"model name '{model}' not found in filepath '{path}'"
            assert scenario in path, f"scenario '{scenario}' not found in filepath '{path}'"

In [13]:
os.environ['CRS_SUPPORT_BUCKET']

'support-data-cc7330d0'

In [14]:
# Models whose license text is expected to contain 'CC0 1.0 Universal'
# (checked in quick_check_file).
CC0_LICENSE_MODELS = ['FGOALS-g3', 'INM-CM4-8', 'INM-CM5-0']

# Models whose license text is expected to contain 'Attribution 4.0 International'
# (checked in quick_check_file).
CC_BY_LICENSE_MODELS = [
    'BCC-CSM2-MR',
    'ACCESS-ESM1-5',
    'ACCESS-CM2',
    'MIROC-ES2L',
    'MIROC6',
    'NorESM2-LM',
    'NorESM2-MM',
    'GFDL-CM4',
    'GFDL-ESM4',
    'NESM3',
]

# Report (without failing) any licensed model that has no precip inputs.
for licensed_model in [*CC0_LICENSE_MODELS, *CC_BY_LICENSE_MODELS]:
    if licensed_model not in INPUT_FILE_VERSIONS['file_paths']['pr']:
        print(licensed_model)

# Every model that will be delivered must have a known license.
for delivery_model in DELIVERY_MODELS:
    assert (
        delivery_model in CC0_LICENSE_MODELS
        or delivery_model in CC_BY_LICENSE_MODELS
    )

In [15]:
# Each model with inputs must include its historical-extension scenario, and a
# model extended with something other than ssp370 must not ship ssp370 at all.
for model in DELIVERY_MODELS:
    extension_scenario = HIST_EXTENSION_SCENARIO[model]
    model_paths = INPUT_FILE_VERSIONS['file_paths']['pr'][model]

    # models without any inputs are skipped here
    if len(model_paths) == 0:
        continue

    assert extension_scenario in model_paths, (
        f"{extension_scenario} not in {model_paths} for model {model}"
    )

    if extension_scenario != 'ssp370':
        assert 'ssp370' not in model_paths

In [16]:
# Flat list of every selected input path, in model order then scenario order.
pr_files = [
    input_fp
    for scenario_paths in INPUT_FILE_VERSIONS['file_paths']['pr'].values()
    for input_fp in scenario_paths.values()
]

# Function Definitions

## Support functions

In [17]:
@rhgu.block_globals(whitelist=[
    'downscaled_filepatt',
    'rechunked_temp_store_pattern',
    'rechunked_pattern',
    'capped_pattern',
    'OUTPUT_PATTERN',
])
def get_spec_from_input_fp(fp, output_version=OUTPUT_VERSION):
    """
    Parse a staged (downscaled) input path and derive every related path.

    Parameters
    ----------
    fp : str
        Path with exactly eleven '/'-separated components once the file
        extension and any 'gs://' prefix are removed: bucket, stage, activity,
        institution, model, scenario, ensemble, table, variable, grid and
        run_version.
    output_version : str, optional
        Value substituted for ``delivery_version`` in ``OUTPUT_PATTERN``.

    Returns
    -------
    dict
        The eleven parsed components, followed by ``downscaled_fp``,
        ``rechunk_temp_store_fp``, ``rechunked_fp``, ``capped_fp`` and
        ``output_fp``.

    Raises
    ------
    ValueError
        If the path does not split into exactly eleven components.
    """
    path_without_ext = os.path.splitext(fp)[0]
    (
        bucket, stage, activity, institution, model, scenario,
        ensemble, table, variable, grid, run_version,
    ) = path_without_ext.replace('gs://', '').split('/')

    spec = {
        'bucket': bucket,
        'stage': stage,
        'activity': activity,
        'institution': institution,
        'model': model,
        'scenario': scenario,
        'ensemble': ensemble,
        'table': table,
        'variable': variable,
        'grid': grid,
        'run_version': run_version,
    }

    # One shared set of template fields; each pattern ignores the keys it
    # does not use.
    template_fields = {
        'activity_id': activity,
        'institution_id': institution,
        'source_id': model,
        'experiment_id': scenario,
        'member_id': ensemble,
        'variable_id': variable,
        'table_id': table,
        'grid_spec': grid,
        'run_version': run_version,
        'delivery_version': output_version,
    }

    templates = {
        'downscaled_fp': downscaled_filepatt,
        'rechunk_temp_store_fp': rechunked_temp_store_pattern,
        'rechunked_fp': rechunked_pattern,
        'capped_fp': capped_pattern,
        'output_fp': OUTPUT_PATTERN,
    }
    for key, template in templates.items():
        spec[key] = template.format(**template_fields)

    return spec

@rhgu.block_globals
def get_spec_from_output_fp(fp, output_pattern=OUTPUT_PATTERN):
    """
    Parse a delivered output path into its components.

    Parameters
    ----------
    fp : str
        Path with exactly ten '/'-separated components once the file extension
        and any 'gs://' prefix are removed: bucket, stage, activity,
        institution, model, scenario, ensemble, table, variable and
        output_version.
    output_pattern : str, optional
        Template used to rebuild the output path from the parsed components.

    Returns
    -------
    dict
        Keys ``activity``, ``institution``, ``model``, ``scenario``,
        ``ensemble``, ``table``, ``variable`` and ``output_version``, plus
        ``bucket``, ``stage`` and ``output_fp`` (the path rebuilt from
        ``output_pattern``). The last three were parsed/computed but dropped
        before; returning them only adds keys.

    Raises
    ------
    ValueError
        If the path does not split into exactly ten components.
    """

    (
        bucket,
        stage,
        activity,
        institution,
        model,
        scenario,
        ensemble,
        table,
        variable,
        output_version,
    ) = os.path.splitext(fp)[0].replace('gs://', '').split('/')

    # Rebuilt from the pattern, so the bucket comes from `output_pattern`
    # rather than from `fp`.
    output_fp = output_pattern.format(
        activity_id=activity,
        institution_id=institution,
        source_id=model,
        experiment_id=scenario,
        member_id=ensemble,
        variable_id=variable,
        table_id=table,
        delivery_version=output_version,
    )

    return dict(
        activity=activity,
        institution=institution,
        model=model,
        scenario=scenario,
        ensemble=ensemble,
        table=table,
        variable=variable,
        output_version=output_version,
        bucket=bucket,
        stage=stage,
        output_fp=output_fp,
    )


## Stage 1: Rechunk

In [18]:
@rhgu.block_globals(whitelist=[
    'INPUT_FILE_VERSIONS',
])
def rechunk_data(varname, model, scenario, worker_memory_limit):
    """
    Rechunk one downscaled input store and write it to its `rechunked_fp`.

    Parameters
    ----------
    varname : str
        Variable to rechunk; also the key used in
        ``INPUT_FILE_VERSIONS['file_paths']``.
    model, scenario : str
        Select the input path from ``INPUT_FILE_VERSIONS``.
    worker_memory_limit : str
        Passed to ``rechunker.rechunk`` as ``max_mem``.

    Returns
    -------
    None
        Returns immediately, without doing any work, when the rechunked
        directory already exists.
    """

    fs = fsspec.filesystem('gs', timeout=120, cache_timeout=120, requests_timeout=120, read_timeout=120, conn_timeout=120)

    # chunk sizes requested for the data variable and for each coordinate
    target_chunks = {
        varname: {'time': 365, 'lat': 360, 'lon': 360},
        'time': {'time': 365},
        'lat': {'lat': 360},
        'lon': {'lon': 360},
    }

    input_fp = INPUT_FILE_VERSIONS['file_paths'][varname][model][scenario]
    input_spec = get_spec_from_input_fp(input_fp)

    rechunked_temp_store_fp = input_spec['rechunk_temp_store_fp']
    rechunked_fp = input_spec['rechunked_fp']

    # Existence of the target directory is the only completeness check.
    # NOTE(review): a partially written store left by an interrupted run would
    # presumably also pass this check - confirm.
    if fs.isdir(rechunked_fp):
        return

    mapper = fs.get_mapper(input_fp)
    with xr.open_zarr(mapper) as ds:

        rechunked_mapper = fs.get_mapper(rechunked_fp)

        chunk_job = rechunker.rechunk(
            source=ds,
            target_chunks=target_chunks,
            max_mem=worker_memory_limit,
            target_store=rechunked_mapper,
            temp_store=fs.get_mapper(rechunked_temp_store_fp),
        )

        # NOTE(review): `_plan` is an underscore-prefixed (private) attribute of
        # the rechunker result; the plan is persisted and waited on here.
        chunk_job_persist = chunk_job._plan.persist()
        dd.wait(chunk_job_persist)

    # NOTE(review): the temp store is not removed by this function.
    zarr.convenience.consolidate_metadata(rechunked_mapper)

## Stage 2: Cap precip at max(reference) × max(1, rolling 21-year max(GCM) / max(GCM historical))

In [19]:
@rhgu.block_globals(whitelist=['cleaned_gcm_pattern', 'HIST_EXTENSION_SCENARIO', 'INPUT_FILE_VERSIONS', 'CLEANED_REF_0p25deg_FP'])
def cap_precip(pr_input_fp):
    """
    Cap downscaled precipitation at a per-cell, per-year upper bound and write
    the result to the input's `capped_fp`.

    The upper bound is ``max(1, gcm_factor) * ref_maxpr`` where

    * ``ref_maxpr`` is the reference maximum over 1994-12-16 to 2015-01-15,
    * ``gcm_factor`` is the centered rolling 21-year maximum of the cleaned
      GCM's annual maxima (historical and projection concatenated) divided by
      the cleaned GCM's historical maximum over the same 1994-12-16 to
      2015-01-15 window, taken at the GCM cell nearest to each reference cell.

    Parameters
    ----------
    pr_input_fp : str
        Staged input path, parsed with ``get_spec_from_input_fp``.

    Returns
    -------
    None
        Returns immediately, without doing any work, when the capped
        directory already exists.

    Raises
    ------
    FileNotFoundError
        If the cleaned historical or projection GCM store cannot be opened.
    """

    fs = fsspec.filesystem('gs')

    pr_spec = get_spec_from_input_fp(pr_input_fp)

    # existence of the destination directory is the only completeness check
    dest_fp = pr_spec['capped_fp']
    if fs.isdir(dest_fp):
        return

    gcm_rechunked = xr.open_zarr(fs.get_mapper(pr_spec['rechunked_fp']))
    
    source_file_versions = INPUT_FILE_VERSIONS['file_paths']['pr'][pr_spec['model']]

    # The cleaned GCM stores are located through the `version_id` attribute of
    # the rechunked data, so both a historical and a projection version are
    # needed regardless of which of the two is being capped.
    if pr_spec['scenario'] == 'historical':
        # historical runs are paired with the model's extension scenario
        proj_scen = HIST_EXTENSION_SCENARIO[pr_spec['model']]
        proj_fp = get_spec_from_input_fp(source_file_versions[proj_scen])['rechunked_fp']
        with xr.open_zarr(fs.get_mapper(proj_fp)) as proj:
            source_version_proj = proj.attrs['version_id']

        source_version_hist = gcm_rechunked.attrs['version_id']
        
    else:
        # projection runs are paired with the same model's historical run
        proj_scen = pr_spec['scenario']
        hist_fp = get_spec_from_input_fp(source_file_versions['historical'])['rechunked_fp']
        with xr.open_zarr(fs.get_mapper(hist_fp)) as hist:
            source_version_hist = hist.attrs['version_id']

        source_version_proj = gcm_rechunked.attrs['version_id']

    # NOTE(review): `source_version_id` is not referenced again in this function.
    source_version_id = gcm_rechunked.attrs['version_id']

    clean_fp_hist = cleaned_gcm_pattern.format(
        activity_id='CMIP',
        institution_id=pr_spec['institution'],
        source_id=pr_spec['model'],
        experiment_id='historical',
        member_id=pr_spec['ensemble'],
        table_id=pr_spec['table'],
        variable_id=pr_spec['variable'],
        grid_spec=pr_spec['grid'],
        source_version=source_version_hist,
    )

    clean_fp_proj = cleaned_gcm_pattern.format(
        activity_id='ScenarioMIP',
        institution_id=pr_spec['institution'],
        source_id=pr_spec['model'],
        experiment_id=proj_scen,
        member_id=pr_spec['ensemble'],
        table_id=pr_spec['table'],
        variable_id=pr_spec['variable'],
        grid_spec=pr_spec['grid'],
        source_version=source_version_proj,
    )

    ref_fp = CLEANED_REF_0p25deg_FP

    # report a missing store by its path rather than as a zarr group error
    try:
        clean_hist = xr.open_zarr(fs.get_mapper(clean_fp_hist))
    except zarr.errors.GroupNotFoundError:
        raise FileNotFoundError(clean_fp_hist)

    try:
        clean_proj = xr.open_zarr(fs.get_mapper(clean_fp_proj))
    except zarr.errors.GroupNotFoundError:
        raise FileNotFoundError(clean_fp_proj)

    ref = xr.open_zarr(fs.get_mapper(ref_fp))

    # per-cell maxima over the same window for the reference and the GCM
    ref_maxpr = ref.sel(time=slice('1994-12-16', '2015-01-15')).pr.max(dim='time').compute()

    gcm_hist_maxpr = clean_hist.sel(time=slice('1994-12-16', '2015-01-15')).pr.max(dim='time').compute()

    # annual maxima over the concatenated historical + projection record
    gcm_proj_maxpr = (
        xr.concat([clean_hist, clean_proj], dim='time')
        .pr
        .groupby('time.year')
        .max(dim='time')
        .compute()
    )

    # convert lons to [-180, 180]
    # NOTE(review): only the GCM longitudes are converted; the reference grid
    # is presumably already on [-180, 180] - confirm.

    gcm_hist_maxpr = (
        gcm_hist_maxpr
        .assign_coords(lon=((gcm_hist_maxpr.lon % 360 + 180) % 360 - 180))
        .sortby('lon')
    )

    gcm_proj_maxpr = (
        gcm_proj_maxpr
        .assign_coords(lon=((gcm_proj_maxpr.lon % 360 + 180) % 360 - 180))
        .sortby('lon')
    )
    
    # centered 21-year rolling maximum; years without a full window become NaN
    gcm_proj_maxpr_rolled = (
        gcm_proj_maxpr
        .rolling(year=21, center=True, min_periods=21).max()
    )

    # Ratio of the rolling maximum to the historical maximum, with all-NaN
    # years dropped, evaluated at the GCM cell nearest to each reference cell.
    # NOTE(review): a zero historical maximum would make this ratio inf/NaN -
    # confirm this cannot happen or is handled downstream.
    gcm_factor = (
        (gcm_proj_maxpr_rolled.dropna(dim='year', how='all') / gcm_hist_maxpr)
        .rename({'lat': 'lat_coarse', 'lon': 'lon_coarse'})
        .sel(lat_coarse=ref_maxpr.lat, lon_coarse=ref_maxpr.lon, method='nearest')
        .drop(['lat_coarse', 'lon_coarse'])
    )

    # the cap is never lower than the reference maximum (factor floored at 1)
    upper_bound = np.maximum(1, gcm_factor) * ref_maxpr
    
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        # Years dropped above take the bound of the nearest remaining year;
        # the yearly bound is then expanded to one value per time step.
        upper_bound_full = (
            upper_bound
            .reindex(year=np.unique(gcm_rechunked.time.dt.year), method='nearest')
            .chunk({'year': 1})
            .sel(year=gcm_rechunked.time.dt.year)
            .drop('year')
        )

        gcm_capped = gcm_rechunked.copy(deep=False)

        # re-attach the attributes of the source variable and dataset
        gcm_capped['pr'] = np.minimum(upper_bound_full, gcm_rechunked['pr'])
        gcm_capped['pr'].attrs = gcm_rechunked['pr'].attrs
        gcm_capped.attrs = gcm_rechunked.attrs

        out_mapper = fs.get_mapper(dest_fp)
        gcm_capped.to_zarr(out_mapper)

## Stage 3: copy to destination directory & validate

In [20]:
@rhgu.block_globals(whitelist=['CC0_LICENSE_MODELS', 'CC_BY_LICENSE_MODELS'])
def quick_check_file(fp, ds, spec, request_timeout=60):
    """
    Run the inexpensive checks on a delivered dataset: global and variable
    attributes, license links, and coordinate values. No data values are read.

    Parameters
    ----------
    fp : str
        Path of the store being checked; only used in error messages.
    ds : xarray.Dataset
        The opened dataset.
    spec : dict
        Path components to compare against; the keys ``institution``,
        ``model``, ``activity``, ``scenario``, ``ensemble`` and ``variable``
        are read.
    request_timeout : float, optional
        Seconds to wait on each license HTTP request. Without a timeout a
        stalled connection would block the calling worker indefinitely.

    Raises
    ------
    AssertionError
        If any attribute, license or coordinate check fails.
    ValueError
        If the variable is not recognized or the model has no known license.
    requests.exceptions.RequestException
        If a license URL cannot be fetched, returns an HTTP error status, or
        times out.
    """

    # check that metadata matches file spec

    assert ds.attrs['institution_id'] == spec['institution'], (
        f"invalid attrs in {fp}: {ds.attrs['institution_id']} ≠ {spec['institution']}"
    )

    assert ds.attrs['source_id'] == spec['model'], (
        f"invalid attrs in {fp}: {ds.attrs['source_id']} ≠ {spec['model']}"
    )

    # substring test: the attribute only has to contain the activity name
    assert spec['activity'] in ds.attrs['activity_id'], (
        f"invalid attrs in {fp}: {spec['activity']} not in {ds.attrs['activity_id']}"
    )

    assert ds.attrs['experiment_id'] == spec['scenario'], (
        f"invalid attrs in {fp}: {ds.attrs['experiment_id']} ≠ {spec['scenario']}"
    )

    assert ds.attrs['variant_label'] == spec['ensemble'], (
        f"invalid attrs in {fp}: {ds.attrs['variant_label']} ≠ {spec['ensemble']}"
    )

    if spec['variable'] == 'tasmax':
        assert ds['tasmax'].attrs['long_name'] == 'Daily Maximum Near-Surface Air Temperature'
        assert ds['tasmax'].attrs['units'] == 'K'
    elif spec['variable'] == 'tasmin':
        assert ds['tasmin'].attrs['long_name'] == 'Daily Minimum Near-Surface Air Temperature'
        assert ds['tasmin'].attrs['units'] == 'K'
    elif spec['variable'] == 'pr':
        # long_name is not checked for pr; only the units are
        assert ds['pr'].attrs['units'] == 'mm day-1'
    else:
        raise ValueError(f'variable not recognized: {spec["variable"]}')

    # Check licensing fields & endpoint URL

    # check that license URL points to a real location and it exists
    license_url = ds.attrs['license']
    assert ds.attrs['source_id'] in license_url, (
        f'model "{ds.attrs["source_id"]}" not found in license url: {license_url}'
    )
    r = requests.get(license_url, timeout=request_timeout)
    r.raise_for_status()

    # check that "Creative Commons" and the model name show up on the page
    assert ds.attrs['source_id'] in r.text, (
        f'model "{ds.attrs["source_id"]}" not found on license page: {license_url}'
    )

    assert "Creative Commons" in r.text, (
        f'"Creative Commons" not found on license page: {license_url}'
    )

    # check that "Creative Commons" appears in the raw license text

    # rewrite the GitHub page URL into the corresponding raw-content URL
    raw_license_url = (
        ds.attrs['license']
        .replace('github.com', 'raw.githubusercontent.com')
        .replace('/blob/', '/')
        .replace('/tree/', '/')
    )

    assert ds.attrs['source_id'] in raw_license_url, (
        f'model "{ds.attrs["source_id"]}" not found in license url: {raw_license_url}'
    )
    r = requests.get(raw_license_url, timeout=request_timeout)
    r.raise_for_status()
    assert 'Creative Commons' in r.text, (
        f'"Creative Commons" not found in license text: {raw_license_url}'
    )

    # the raw text must match the license family expected for this model
    if spec['model'] in CC0_LICENSE_MODELS:
        assert 'CC0 1.0 Universal' in r.text, (
            f"expected CC0 license for {spec['model']} at {fp}"
        )
    elif spec['model'] in CC_BY_LICENSE_MODELS:
        assert 'Attribution 4.0 International' in r.text, (
            f"expected CC-BY 4.0 license for {spec['model']} at {fp}"
        )
    else:
        raise ValueError(
            f"deploying model with unknown license: {spec['model']} at {fp}"
        )

    # Check dimension size & membership

    for c in ds.coords.keys():
        assert ds.coords[c].notnull().all().item() is True, f"NaNs found in coordinate '{c}' in {fp}"

    if spec['activity'] == 'ScenarioMIP':
        date_range = xr.cftime_range("2015-01-01", "2099-12-31", freq="D", calendar="noleap")
        # projections longer than 2015-2099 are compared against 2015-2100 instead
        if len(ds.time) > len(date_range):
            date_range = xr.cftime_range("2015-01-01", "2100-12-31", freq="D", calendar="noleap")
    else:
        date_range = xr.cftime_range("1950-01-01", "2014-12-31", freq="D", calendar="noleap")

    assert ds.sizes['time'] == len(date_range), (
        f"unexpected length of dimension 'time': length {len(ds.time)}; "
        f"expected {len(date_range)} in {fp}"
    )

    # time stamps are floored to the day before comparing against the range
    assert date_range.isin(ds.time.dt.floor('D').values).all(), f"invalid coords in {fp}"

    # every quarter-degree cell center must be present in the coordinates
    assert pd.Series(np.arange(-179.875, 180, 0.25)).isin(ds.lon.values).all(), (
        f"invalid coords in {fp}"
    )
    assert pd.Series(np.arange(-89.875, 90, 0.25)).isin(ds.lat.values).all(), (
        f"invalid coords in {fp}"
    )

    varnames = list(ds.data_vars.keys())
    assert len(varnames) == 1
    varname = varnames[0]

    assert ds[varname].sizes['lat'] == 720, f"lat not length 720 in {fp}:\n{ds}"
    assert ds[varname].sizes['lon'] == 1440, f"lon not length 1440 in {fp}:\n{ds}"

In [21]:
# Spot-check the derived paths for one input (its output path is displayed in
# the next cell).
spec = get_spec_from_input_fp(INPUT_FILE_VERSIONS['file_paths']['pr']['FGOALS-g3']['historical'])

In [22]:
spec['output_fp']

'gs://downscaled-288ec5ac/outputs/CMIP/CAS/FGOALS-g3/historical/r1i1p1f1/day/pr/v1.1.zarr'

In [23]:
@rhgu.block_globals(whitelist=['INPUT_FILE_VERSIONS'])
def validate_outputs(fp, quick=False):
    """
    Validate a delivered output store.

    Always runs ``quick_check_file``. Unless ``quick`` is set, it then scans
    the single data variable between latitudes -80 and 80 for NaNs and for
    values outside the range allowed for that variable. The scan is computed
    through the current dask client.

    Parameters
    ----------
    fp : str
        Output path, parsed with ``get_spec_from_output_fp``.
    quick : bool, optional
        Stop after the quick checks.

    Raises
    ------
    AssertionError
        If any check fails.
    ValueError
        If the variable name is not recognized.
    """
    spec = get_spec_from_output_fp(fp)

    gcs = fsspec.filesystem(
        'gs',
        timeout=60,
        cache_timeout=60,
        requests_timeout=60,
        read_timeout=60,
        conn_timeout=60,
    )

    with xr.open_zarr(gcs.get_mapper(fp)) as ds:

        quick_check_file(fp, ds, spec)

        if quick:
            return

        # check variable contents

        data_var_names = list(ds.data_vars)
        assert len(data_var_names) == 1
        varname = data_var_names[0]

        to_check = ds[varname].sel(lat=slice(-80, 80))

        # build all three reductions lazily, then evaluate them in one call
        lazy_stats = [to_check.isnull().any(), to_check.min(), to_check.max()]
        has_nans, observed_min, observed_max = dd.get_client().compute(
            lazy_stats,
            optimize_graph=True,
            sync=True,
            retries=3,
        )

        assert has_nans.item() is False, f"NaNs found in {fp}"

        # (allowed_min, allowed_max) per variable
        allowed_ranges = {
            'tasmax': (150, 360),
            'tasmin': (150, 360),
            'pr': (0, 3000),
        }
        if varname not in allowed_ranges:
            raise ValueError(f'Variable name not recognized: {varname}\nin file: {fp}')
        allowed_min, allowed_max = allowed_ranges[varname]

        assert (observed_min >= allowed_min).item() is True, (
            f"min value {observed_min} outside allowed range [{allowed_min}, {allowed_max}] "
            f"for {varname} in {fp}"
        )
        assert (observed_max <= allowed_max).item() is True, (
            f"max value {observed_max} outside allowed range [{allowed_min}, {allowed_max}] "
            f"for {varname} in {fp}"
        )


def _verify_copied_files(fs, src_root, dst_root, pbar=False, max_attempts=5):
    """
    Compare the md5 hash of every file under `src_root` with its copy under
    `dst_root`, re-copying any file that is missing or differs.

    Raises FileNotFoundError or AssertionError if a file still does not match
    on the last of `max_attempts` attempts.
    """
    files = [(d, f) for d, _, fps in fs.walk(src_root) for f in fps]
    if pbar:
        files = tqdm(files)

    for d, f in files:
        # As in the original code, the first five characters of the root
        # ('gs://') are re-attached - presumably because fs.walk returns paths
        # without the protocol prefix.
        src = src_root[:5] + os.path.join(d, f)
        dst = os.path.join(dst_root, os.path.relpath(src, src_root))
        assert '..' not in dst
        src_hash = fs.stat(src)['md5Hash']

        for attempt in range(max_attempts):
            try:
                dst_hash = fs.stat(dst)['md5Hash']
                if src_hash != dst_hash:
                    raise AssertionError(
                        f'md5 mismatch between {src} and {dst}: {src_hash} != {dst_hash}'
                    )
                break
            except (FileNotFoundError, AssertionError):
                if attempt == max_attempts - 1:
                    raise

                # Only delete what exists: when the destination is missing,
                # removing it would raise instead of repairing the copy.
                if fs.exists(dst):
                    fs.rm(dst)
                fs.copy(src, dst)


@rhgu.block_globals
def copy_and_validate(
    source_fp,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=False,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    pbar=False,
):
    """
    Copy the capped store derived from `source_fp` to its delivery location
    and validate the result.

    Parameters
    ----------
    source_fp : str
        Staged input path; the capped source and the output destination are
        derived from it with ``get_spec_from_input_fp``.
    output_version : str, optional
        Delivery version used in the output path.
    check : bool, optional
        Run the full ``validate_outputs`` (including the data scan).
    deep_copy_check : bool, optional
        Compare md5 hashes of every file against the capped source and re-copy
        files that are missing or differ.
    quick_check_and_retry : bool, optional
        Used when `check` is False: run the quick validation. An existing
        output that fails it is deleted and copied again.
    overwrite : bool, optional
        Delete an existing output and copy again, without validating it first.
    overwrite_on_failure : bool, optional
        Used when `check` is True: delete and re-copy an existing output that
        fails validation instead of re-raising the failure.
    pbar : bool, optional
        Show a progress bar over files during the deep copy check.
    """

    spec = get_spec_from_input_fp(source_fp, output_version=output_version)
    capped_fp = spec['capped_fp']
    output_fp = spec['output_fp']

    fs = fsspec.filesystem(
        'gs',
        timeout=360,
        cache_timeout=360,
        requests_timeout=360, read_timeout=360, conn_timeout=360)

    if fs.exists(output_fp):
        if overwrite:
            # `rm` is the deletion method used everywhere else in this
            # function; the original called `fs.remove`, which fsspec does not
            # appear to provide.
            fs.rm(output_fp, recursive=True)

        else:
            if deep_copy_check:
                _verify_copied_files(fs, capped_fp, output_fp, pbar=pbar)

            if check:
                try:
                    validate_outputs(
                        output_fp,
                    )
                    return
                except (
                    AssertionError,
                    FileNotFoundError,
                    ValueError,
                    IOError,
                    xr.coding.times.OutOfBoundsDatetime,
                    OverflowError,
                ):
                    if overwrite_on_failure:
                        # fall through to a fresh copy below
                        fs.rm(output_fp, recursive=True)
                    else:
                        raise

            elif quick_check_and_retry:
                try:
                    validate_outputs(output_fp, quick=True)
                    return
                except (
                    OverflowError,
                    IOError,
                    zarr.errors.GroupNotFoundError,
                    FileNotFoundError,
                    AssertionError,
                    ValueError,
                ):
                    # deliberate: a failed quick check triggers one re-copy
                    pass

                fs.rm(output_fp, recursive=True)
            else:
                # output exists and no validation was requested
                return

    print(f'copying:\n\tsrc:\t{capped_fp}\n\tdst:\t{output_fp}')
    fs.copy(capped_fp, output_fp, recursive=True, batch_size=1000)

    if deep_copy_check:
        _verify_copied_files(fs, capped_fp, output_fp, pbar=pbar)

    # failures after a fresh copy are not retried; they propagate to the caller
    if check:
        validate_outputs(
            output_fp,
        )
    elif quick_check_and_retry:
        validate_outputs(output_fp, quick=True)

In [24]:
@contextlib.contextmanager
def kill_cluster_on_error():
    """
    Context manager that shuts down the dask cluster if the wrapped block
    raises, then re-raises the original exception.

    Uses the notebook-level `client` and `cluster` objects. Every teardown
    step is attempted even if an earlier one fails, so a failing
    `client.restart()` can neither leave the cluster scaled up nor replace the
    error that triggered the shutdown.
    """
    try:
        yield
    except Exception:
        # kill the cluster if something unexpected happens during a long-running job
        teardown_steps = [
            ('client.restart()', lambda: client.restart()),
            ('cluster.scale(0)', lambda: cluster.scale(0)),
            ('client.close()', lambda: client.close()),
            ('cluster.close()', lambda: cluster.close()),
        ]
        for label, step in teardown_steps:
            try:
                step()
            except Exception as teardown_error:
                # report and carry on: the remaining steps still need to run
                print(f'warning: {label} failed during cluster teardown: {teardown_error!r}')
        raise

In [25]:
def blocking_pbar(futures):
    """
    Block until all `futures` have finished, showing a progress bar that also
    counts futures ending in the 'error', 'killed' or 'lost' state.

    Once everything has finished, ``.result()`` is called on the first future
    found in each failure state that was counted, which surfaces that
    future's failure to the caller.
    """
    failure_counts = dict.fromkeys(('error', 'killed', 'lost'), 0)

    with tqdm(dd.as_completed(futures), total=len(futures)) as progress:
        for finished in progress:
            if finished.status not in failure_counts:
                continue
            failure_counts[finished.status] += 1
            progress.set_postfix(failure_counts)

    dd.wait(futures)
    for state, count in failure_counts.items():
        if count > 0:
            in_state = [fut for fut in futures if fut.status == state]
            in_state[0].result()

# Full workflow

In [26]:
# Start the dask cluster used by every stage below and request 60 workers.
# `client` and `cluster` are also the objects that kill_cluster_on_error
# shuts down.
client, cluster = rhgk.get_giant_cluster()
cluster.scale(60)

# Passed to rechunk_data as `worker_memory_limit`.
# NOTE(review): the inline comment mentions a standard cluster while
# get_giant_cluster() is called above - confirm the limit matches the workers.
MAX_MEM = '12GB' # for standard cluster

cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

# Prepare final outputs

In [27]:
# For each model, rechunk every scenario first and only then cap: capping one
# scenario reads attributes from the rechunked store of its historical /
# projection counterpart.
with kill_cluster_on_error(), tqdm(DELIVERY_MODELS) as pbar:
    for model in pbar:
        model_inputs = INPUT_FILE_VERSIONS['file_paths']['pr'][model]

        for stage in ('rechunk pr', 'cap precip'):
            for scenario, pr_input_fp in model_inputs.items():
                pr_spec = get_spec_from_input_fp(pr_input_fp)

                # comment out this block to reproduce rechunked/capped data on scratch bucket
                if fs.exists(pr_spec['output_fp']):
                    print(f'skipping {model} {scenario} - output already exists')
                    continue

                pbar.set_postfix({'model': model, 'scen': scenario, 'stage': stage})

                if stage == 'rechunk pr':
                    rechunk_data('pr', model, scenario, worker_memory_limit=MAX_MEM)
                else:
                    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
                        cap_precip(pr_input_fp)

  0%|          | 0/13 [00:00<?, ?it/s]

skipping BCC-CSM2-MR historical - output already exists
skipping BCC-CSM2-MR ssp245 - output already exists
skipping BCC-CSM2-MR ssp370 - output already exists
skipping BCC-CSM2-MR historical - output already exists
skipping BCC-CSM2-MR ssp245 - output already exists
skipping BCC-CSM2-MR ssp370 - output already exists
skipping FGOALS-g3 historical - output already exists
skipping FGOALS-g3 ssp245 - output already exists
skipping FGOALS-g3 ssp370 - output already exists
skipping FGOALS-g3 historical - output already exists
skipping FGOALS-g3 ssp245 - output already exists
skipping FGOALS-g3 ssp370 - output already exists
skipping ACCESS-ESM1-5 historical - output already exists
skipping ACCESS-ESM1-5 ssp245 - output already exists
skipping ACCESS-ESM1-5 ssp370 - output already exists
skipping ACCESS-ESM1-5 historical - output already exists
skipping ACCESS-ESM1-5 ssp245 - output already exists
skipping ACCESS-ESM1-5 ssp370 - output already exists
skipping ACCESS-CM2 historical - output 

# Copy files to final destination

In [28]:
# Copy every capped store to its delivery location, one cluster task per input
# path. Only the quick validation runs here (check=False,
# quick_check_and_retry=True); existing outputs that pass it are left alone.
with kill_cluster_on_error():
    pr_futures = client.map(
        copy_and_validate,
        pr_files,
        output_version=OUTPUT_VERSION,
        check=False,
        deep_copy_check=False,
        quick_check_and_retry=True,
        overwrite=False,
        overwrite_on_failure=False,
        pbar=False,
    )

    # block until all tasks are done and surface the first failure, if any
    blocking_pbar(pr_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

# Deep copy check
Check every file against source to ensure a complete copy

In [29]:
# Same call as the copy step, but with deep_copy_check=True: each file of every
# output is compared by md5 hash against the capped source and re-copied when
# missing or different.
with kill_cluster_on_error():
    pr_futures = client.map(
        copy_and_validate,
        pr_files,
        output_version=OUTPUT_VERSION,
        check=False,
        deep_copy_check=True,
        quick_check_and_retry=True,
        overwrite=False,
        overwrite_on_failure=False,
        pbar=False,
    )

    # block until all tasks are done and surface the first failure, if any
    blocking_pbar(pr_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

### Check pr data in final location
Check all pr values, including bounds and NaN checks.

In [30]:
# Full validation (check=True) of every output, including the NaN and value
# range scan. The loop runs in the notebook process, one store at a time;
# validate_outputs itself submits the scan to the cluster through the current
# dask client. With overwrite_on_failure=False a failing output raises here.
with kill_cluster_on_error():
    for f in tqdm(pr_files):
        copy_and_validate(
            f,
            output_version=OUTPUT_VERSION,
            check=True,
            deep_copy_check=False,
            quick_check_and_retry=False,
            overwrite=False,
            overwrite_on_failure=False,
            pbar=False,
        )

  0%|          | 0/37 [00:00<?, ?it/s]

In [31]:
# All stages are done: release the workers and close the cluster. These are the
# same four calls kill_cluster_on_error makes on failure.
client.restart()
cluster.scale(0)
client.close()
cluster.close()

In [32]:
# Delivery path of every input, and the directory they all share.
outfiles = [get_spec_from_input_fp(input_fp)['output_fp'] for input_fp in pr_files]

# os.path.commonpath collapses the '//' after the protocol, so it is restored.
output_root = os.path.commonpath(outfiles).replace("gs:/", "gs://").replace(":///", "://")

print(f'outputs are located in the following directory: {output_root}')

outputs are located in the following directory: gs://downscaled-288ec5ac/outputs


To transfer the data elsewhere (for example, to prepare a public delivery or a delivery to the Catalyst buckets), contact Mike for help with Google's transfer utility.

In [33]:
# Print the delivery path of every (model, scenario) input.
with tqdm(DELIVERY_MODELS) as pbar:
    for model in pbar:
        for pr_input_fp in INPUT_FILE_VERSIONS['file_paths']['pr'][model].values():
            print(get_spec_from_input_fp(pr_input_fp)['output_fp'])

  0%|          | 0/13 [00:00<?, ?it/s]

gs://downscaled-288ec5ac/outputs/CMIP/BCC/BCC-CSM2-MR/historical/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/BCC/BCC-CSM2-MR/ssp245/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/BCC/BCC-CSM2-MR/ssp370/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/CMIP/CAS/FGOALS-g3/historical/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/CAS/FGOALS-g3/ssp245/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/CAS/FGOALS-g3/ssp370/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/CMIP/CSIRO/ACCESS-ESM1-5/historical/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/CSIRO/ACCESS-ESM1-5/ssp245/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/CSIRO/ACCESS-ESM1-5/ssp370/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/CMIP/CSIRO-ARCCSS/ACCESS-CM2/historical/r1i1p1f1/day/pr/v1.1.zarr
gs://downscaled-288ec5ac/outputs/ScenarioMIP/CSIRO-ARCCSS/