# Rechunk, flip, Q/A, and deliver tasmin & tasmax

In [45]:
HISTORY = '''
v1.1 : switch to additive QDM tasmin with swapping where tasmin > tasmax; also include
       regridding with nearest neighbor patch in QPLAD.
v1.0 : initial release; QDM tasmax (additive) and DTR (multiplicative)
'''.strip()

OUTPUT_VERSION = 'v1.1'

In [1]:
import os
import json
import fsspec
import requests
import contextlib
import xarray as xr
import pandas as pd
import numpy as np
import itertools
import zarr
import rechunker
import dask
import rhg_compute_tools.kubernetes as rhgk
import rhg_compute_tools.utils as rhgu
import dask.distributed as dd
from tqdm.auto import tqdm
import matplotlib
import matplotlib.pyplot as plt

  from distributed.utils import LoopRunner, format_bytes


In [27]:
! pwd

/home/jovyan/repositories/downscaleCMIP6/notebooks/downscaling_pipeline


In [28]:
DELIVERY_MODELS = [
    # 'BCC-CSM2-MR',
    # 'FGOALS-g3',
    # 'ACCESS-ESM1-5',
    # 'ACCESS-CM2',
    # 'INM-CM4-8',
    # 'INM-CM5-0',
    # 'MIROC-ES2L',
    # 'MIROC6',
    # 'NorESM2-LM',
    # 'NorESM2-MM',
    # 'GFDL-ESM4',
    # 'GFDL-CM4',
    'NESM3',
    #'CMCC-CM2-SR5',
]

In [29]:
INSTITUTIONS = {    
    'BCC-CSM2-MR': 'BCC',
    'FGOALS-g3': 'CAS',
    'ACCESS-ESM1-5': 'CSIRO',
    'ACCESS-CM2': 'CSIRO-ARCCSS',
    'INM-CM4-8': 'INM',
    'INM-CM5-0': 'INM',
    'MIROC-ES2L': 'MIROC',
    'MIROC6': 'MIROC',
    'NorESM2-LM': 'NCC',
    'NorESM2-MM': 'NCC',
    'GFDL-ESM4': 'NOAA-GFDL',
    'GFDL-CM4': 'NOAA-GFDL',
    'NESM3': 'NUIST',
    'CMCC-CM2-SR5':'CMCC',
}

In [31]:
ENSEMBLE_MEMBERS = {
    'BCC-CSM2-MR': 'r1i1p1f1',
    'FGOALS-g3': 'r1i1p1f1',
    'ACCESS-ESM1-5': 'r1i1p1f1',
    'ACCESS-CM2': 'r1i1p1f1',
    'INM-CM4-8': 'r1i1p1f1',
    'INM-CM5-0': 'r1i1p1f1',
    'MIROC-ES2L': 'r1i1p1f2',
    'MIROC6': 'r1i1p1f1',
    'NorESM2-LM': 'r1i1p1f1',
    'NorESM2-MM': 'r1i1p1f1',
    'GFDL-ESM4': 'r1i1p1f1',
    'GFDL-CM4': 'r1i1p1f1',
    'NESM3': 'r1i1p1f1',
    'CMCC-CM2-SR5':'r1i1p1f1',
}

In [32]:
GRID_SPECS = {
    "ACCESS-CM2": "gn",
    "MRI-ESM2-0": "gn",
    "CanESM5": "gn",
    "ACCESS-ESM1-5": "gn",
    "MIROC6": "gn",
    "EC-Earth3": "gr",
    "EC-Earth3-Veg-LR": "gr",
    "EC-Earth3-Veg": "gr",
    "MPI-ESM1-2-HR": "gn",
    "CMCC-ESM2": "gn",
    "INM-CM5-0": "gr1",
    "INM-CM4-8": "gr1",
    "MIROC-ES2L": "gn",
    "MPI-ESM1-2-LR": "gn",
    "FGOALS-g3": "gn",
    "BCC-CSM2-MR": "gn",
    "AWI-CM-1-1-MR": "gn",
    "NorESM2-LM": "gn",
    "GFDL-ESM4": "gr1",
    "GFDL-CM4": "gr1",
    "CAMS-CSM1-0": "gn",
    "NorESM2-MM": "gn",
    "NESM3": "gn",
    'CMCC-CM2-SR5':'gn',
}

In [33]:
downscaled_filepatt = (
    'gs://downscaled-288ec5ac/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}.zarr'
)

rechunked_temp_store_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-rechunked-temp-store.zarr'
)

rechunked_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-rechunked.zarr'
)

fipped_pattern = (
    'gs://scratch-170cd6ec/stage/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{grid_spec}/'
    '{run_version}-tasminmax-flipped.zarr'
)

OUTPUT_PATTERN = (
    'gs://downscaled-288ec5ac/outputs/{activity_id}/{institution_id}/{source_id}/'
    '{experiment_id}/{member_id}/{table_id}/{variable_id}/{delivery_version}.zarr'
)

In [34]:
fs = fsspec.filesystem('gs', timeout=360, cache_timeout=360, requests_timeout=360, read_timeout=360, conn_timeout=360, token='/opt/gcsfuse_tokens/impactlab-data.json')

In [35]:
spec_file = f'version_specs/{OUTPUT_VERSION}.json'

if os.path.isfile(spec_file):
    with open(spec_file, 'r') as f:
        INPUT_FILE_VERSIONS = json.load(f)

else:
    tasmax_fps = {m: {} for m in DELIVERY_MODELS}

    for m in tqdm(DELIVERY_MODELS, desc='tasmax'):

        inst = INSTITUTIONS[m]

        for act, scen in [
            ('CMIP', 'historical'),
            ('ScenarioMIP', 'ssp245'),
            ('ScenarioMIP', 'ssp370'),
        ]:
            tasmax_fps[m][scen] = list(
                fs.glob(
                    downscaled_filepatt.format(
                        activity_id=act,
                        institution_id=inst,
                        source_id=m,
                        experiment_id=scen,
                        member_id=ENSEMBLE_MEMBERS[m],
                        table_id='day',
                        variable_id='tasmax',
                        grid_spec=GRID_SPECS[m],
                        run_version='*',
                    )
                )
            )

    tasmax_max_versions = {
        m: {s: max(vs) for s, vs in mspec.items() if len(vs) > 0}
        for m, mspec in tasmax_fps.items()
    }

    tasmin_fps = {m: {} for m in DELIVERY_MODELS}

    for m in tqdm(DELIVERY_MODELS, desc='tasmin'):

        inst = INSTITUTIONS[m]

        for act, scen in [
            ('CMIP', 'historical'),
            ('ScenarioMIP', 'ssp245'),
            ('ScenarioMIP', 'ssp370'),
        ]:
            tasmin_fps[m][scen] = list(
                fs.glob(
                    downscaled_filepatt.format(
                        activity_id=act,
                        institution_id=inst,
                        source_id=m,
                        experiment_id=scen,
                        member_id=ENSEMBLE_MEMBERS[m],
                        table_id='day',
                        variable_id='tasmin',
                        grid_spec=GRID_SPECS[m],
                        run_version='*',
                    )
                )
            )

    tasmin_max_versions = {
        m: {s: max(vs) for s, vs in mspec.items() if len(vs) > 0}
        for m, mspec in tasmin_fps.items()
    }

    INPUT_FILE_VERSIONS = {
        'version': OUTPUT_VERSION,
        'created': pd.Timestamp.now(tz='US/Pacific').strftime('%c'),
        'history': HISTORY,
        'file_paths': {
            'tasmin': tasmin_max_versions,
            'tasmax': tasmax_max_versions,
        },
    }

    ! mkdir -p version_specs

    with open(spec_file, 'w') as f:
        f.write(json.dumps(INPUT_FILE_VERSIONS))

In [36]:
for m in DELIVERY_MODELS:
    for v in ['tasmin', 'tasmax']:
        if m not in INPUT_FILE_VERSIONS['file_paths'][v]:
            raise ValueError(f"model {m} not found for {v}")

        for s, fp in INPUT_FILE_VERSIONS['file_paths'][v][m].items():
            assert m in fp, f"model name '{m}' not found in filepath '{fp}'"
            assert s in fp, f"scenario '{s}' not found in filepath '{fp}'"

In [37]:
CC0_LICENSE_MODELS = ['FGOALS-g3', 'INM-CM4-8', 'INM-CM5-0']

CC_BY_LICENSE_MODELS = [
    'BCC-CSM2-MR',
    'ACCESS-ESM1-5',
    'ACCESS-CM2',
    'MIROC-ES2L',
    'MIROC6',
    'NorESM2-LM',
    'NorESM2-MM',
    'GFDL-CM4',
    'GFDL-ESM4',
    'NESM3',
    'CMCC-CM2-SR5',
]

for m in (CC0_LICENSE_MODELS + CC_BY_LICENSE_MODELS):
    for v in ['tasmin', 'tasmax']:
        if m not in INPUT_FILE_VERSIONS['file_paths'][v]:
            print(m)

for m in DELIVERY_MODELS:
    assert m in (CC0_LICENSE_MODELS + CC_BY_LICENSE_MODELS)

FGOALS-g3
FGOALS-g3
INM-CM4-8
INM-CM4-8
INM-CM5-0
INM-CM5-0
BCC-CSM2-MR
BCC-CSM2-MR
ACCESS-ESM1-5
ACCESS-ESM1-5
ACCESS-CM2
ACCESS-CM2
MIROC-ES2L
MIROC-ES2L
MIROC6
MIROC6
NorESM2-LM
NorESM2-LM
NorESM2-MM
NorESM2-MM
GFDL-CM4
GFDL-CM4
GFDL-ESM4
GFDL-ESM4
NESM3
NESM3


# Function Definitions

## Support functions

In [38]:
@rhgu.block_globals(whitelist=[
    'downscaled_filepatt',
    'rechunked_temp_store_pattern',
    'rechunked_pattern',
    'fipped_pattern',
    'OUTPUT_PATTERN',
])
def get_spec_from_input_fp(fp, output_version=OUTPUT_VERSION):
    (
        bucket,
        stage,
        activity,
        institution,
        model,
        scenario,
        ensemble,
        table,
        variable,
        grid,
        run_version,
    ) = os.path.splitext(fp)[0].replace('gs://', '').split('/')

    spec = dict(
        bucket=bucket,
        stage=stage,
        activity=activity,
        institution=institution,
        model=model,
        scenario=scenario,
        ensemble=ensemble,
        table=table,
        variable=variable,
        grid=grid,
        run_version=run_version,
    )

    for (name, fpatt) in [
        ('downscaled_fp', downscaled_filepatt),
        ('rechunk_temp_store_fp', rechunked_temp_store_pattern),
        ('rechunked_fp', rechunked_pattern),
        ('flipped_fp', fipped_pattern),
        ('output_fp', OUTPUT_PATTERN),
    ]:
        spec[name] = fpatt.format(
            activity_id=activity,
            institution_id=institution,
            source_id=model,
            experiment_id=scenario,
            member_id=ensemble,
            variable_id=variable,
            table_id=table,
            grid_spec=grid,
            run_version=run_version,
            delivery_version=output_version,
        )

    return spec

@rhgu.block_globals
def get_spec_from_output_fp(fp, output_pattern=OUTPUT_PATTERN):
    
    (
        bucket,
        stage,
        activity,
        institution,
        model,
        scenario,
        ensemble,
        table,
        variable,
        output_version,
    ) = os.path.splitext(fp)[0].replace('gs://', '').split('/')

    output_fp = output_pattern.format(
        activity_id=activity,
        institution_id=institution,
        source_id=model,
        experiment_id=scenario,
        member_id=ensemble,
        variable_id=variable,
        table_id=table,
        delivery_version=output_version,
    )

    return dict(
        activity=activity,
        institution=institution,
        model=model,
        scenario=scenario,
        ensemble=ensemble,
        table=table,
        variable=variable,
        output_version=output_version,
    )


## Stage 1: Rechunk

In [39]:
@rhgu.block_globals(whitelist=[
    'INPUT_FILE_VERSIONS',
])
def rechunk_data(varname, model, scenario, worker_memory_limit):

    fs = fsspec.filesystem('gs', timeout=360, cache_timeout=360, requests_timeout=360, read_timeout=360, conn_timeout=360, token='/opt/gcsfuse_tokens/impactlab-data.json')

    target_chunks = {
        varname: {'time': 365, 'lat': 360, 'lon': 360},
        'time': {'time': 365},
        'lat': {'lat': 360},
        'lon': {'lon': 360},
    }

    input_fp = INPUT_FILE_VERSIONS['file_paths'][varname][model][scenario]
    input_spec = get_spec_from_input_fp(input_fp)

    rechunked_temp_store_fp = input_spec['rechunk_temp_store_fp']
    rechunked_fp = input_spec['rechunked_fp']

    if fs.isdir(rechunked_fp):
        return

    mapper = fs.get_mapper(input_fp)
    with xr.open_zarr(mapper) as ds:

        rechunked_mapper = fs.get_mapper(rechunked_fp)

        chunk_job = rechunker.rechunk(
            source=ds,
            target_chunks=target_chunks,
            max_mem=worker_memory_limit,
            target_store=rechunked_mapper,
            temp_store=fs.get_mapper(rechunked_temp_store_fp),
        )

        chunk_job_persist = chunk_job._plan.persist()
        dd.wait(chunk_job_persist)

    zarr.convenience.consolidate_metadata(rechunked_mapper)

## Stage 2: Flip Negative DTR

In [40]:
@rhgu.block_globals(whitelist=['INPUT_FILE_VERSIONS', 'rechunked_pattern', 'fipped_pattern'])
def flip_negative_dtr(model, scenario):

    fs = fsspec.filesystem('gs', timeout=360, cache_timeout=360, requests_timeout=360, read_timeout=360, conn_timeout=360, token='/opt/gcsfuse_tokens/impactlab-data.json')
    client = dd.get_client()

    tasmin_input_fp = INPUT_FILE_VERSIONS['file_paths']['tasmin'][model][scenario]
    tasmin_input_spec = get_spec_from_input_fp(tasmin_input_fp)
    tasmin_rechunked_fp = tasmin_input_spec['rechunked_fp']
    tasmin_fipped_fp = tasmin_input_spec['flipped_fp']

    tasmax_input_fp = INPUT_FILE_VERSIONS['file_paths']['tasmax'][model][scenario]
    tasmax_input_spec = get_spec_from_input_fp(tasmax_input_fp)
    tasmax_rechunked_fp = tasmax_input_spec['rechunked_fp']
    tasmax_fipped_fp = tasmax_input_spec['flipped_fp']

    if fs.isdir(tasmin_fipped_fp) and fs.isdir(tasmax_fipped_fp):
        return

    tasmin_in_mapper = fs.get_mapper(tasmin_rechunked_fp)
    tasmax_in_mapper = fs.get_mapper(tasmax_rechunked_fp)

    with xr.open_zarr(tasmin_in_mapper) as tasmin_ds, xr.open_zarr(tasmax_in_mapper) as tasmax_ds:

        tasmin_ds_out = tasmin_ds.copy(deep=False)
        tasmax_ds_out = tasmax_ds.copy(deep=False)

        tasmin_ds_out['tasmin'] = np.minimum(tasmin_ds['tasmin'], tasmax_ds['tasmax'])
        tasmax_ds_out['tasmax'] = np.maximum(tasmin_ds['tasmin'], tasmax_ds['tasmax'])

        tasmin_ds_out['tasmin'].attrs.update(tasmin_ds['tasmin'].attrs)
        tasmax_ds_out['tasmax'].attrs.update(tasmax_ds['tasmax'].attrs)

        dtr = (tasmax_ds_out['tasmax'] - tasmin_ds_out['tasmin'])
        min_dtr = dtr.min()
        dtr_usually_positive = (dtr > 0.1).mean()

        tasmin_out_mapper = fs.get_mapper(tasmin_fipped_fp)
        tasmax_out_mapper = fs.get_mapper(tasmax_fipped_fp)

        write_tasmin = tasmin_ds_out.to_zarr(tasmin_out_mapper, consolidated=True, compute=False)
        write_tasmax = tasmax_ds_out.to_zarr(tasmax_out_mapper, consolidated=True, compute=False)

        min_dtr, dtr_usually_positive, write_tasmin, write_tasmax = client.compute(
            [min_dtr, dtr_usually_positive, write_tasmin, write_tasmax],
            optimize_graph=True,
            retries=3,
            sync=True,
        )

        assert (min_dtr >= 0).item() is True, (
            f"DTR not always positive after flip: min DTR: {min_dtr.item()}"
        )

        assert (dtr_usually_positive > 0.99).item() is True, (
            f"DTR not almost always > 0.1: {dtr_usually_positive.item()}"
        )

## Stage 3: copy to destination directory & validate

In [41]:
@rhgu.block_globals(whitelist=['CC0_LICENSE_MODELS', 'CC_BY_LICENSE_MODELS'])
def quick_check_file(fp, ds, spec):
    """
    """

    # check that metadata matches file spec

    assert ds.attrs['institution_id'] == spec['institution'], (
        f"invalid attrs in {fp}: {ds.attrs['institution_id']} ≠ {spec['institution']}"
    )

    assert ds.attrs['source_id'] == spec['model'], (
        f"invalid attrs in {fp}: {ds.attrs['source_id']} ≠ {spec['model']}"
    )

    assert spec['activity'] in ds.attrs['activity_id'], (
        f"invalid attrs in {fp}: {spec['activity']} not in {ds.attrs['activity_id']}"
    )

    assert ds.attrs['experiment_id'] == spec['scenario'], (
        f"invalid attrs in {fp}: {ds.attrs['experiment_id']} ≠ {spec['scenario']}"
    )

    assert ds.attrs['variant_label'] == spec['ensemble'], (
        f"invalid attrs in {fp}: {ds.attrs['variant_label']} ≠ {spec['ensemble']}"
    )

    if spec['variable'] == 'tasmax':
        assert ds['tasmax'].attrs['long_name'] == 'Daily Maximum Near-Surface Air Temperature'
        assert ds['tasmax'].attrs['units'] == 'K'
    elif spec['variable'] == 'tasmin':
        assert ds['tasmin'].attrs['long_name'] == 'Daily Minimum Near-Surface Air Temperature'
        assert ds['tasmin'].attrs['units'] == 'K'
    elif spec['variable'] == 'pr':
        raise NotImplementedError()
#         assert ds['tasmax'].attrs['units'] == 'mm/day'
    else:
        raise ValueError(f'variable not recognized: {spec["variable"]}')

    # Check licensing fields & endpoint URL

    # check that license URL points to a real location and it exists
    license_url = ds.attrs['license']
    assert ds.attrs['source_id'] in license_url, (
        f'model "{ds.attrs["source_id"]}" not found in license url: {license_url}'
    )
    r = requests.get(license_url)
    r.raise_for_status()

    # check that "Creaive Commons" and the model name show up on the page
    assert ds.attrs['source_id'] in r.text, (
        f'model "{ds.attrs["source_id"]}" not found on license page: {license_url}'
    )

    assert "Creative Commons" in r.text, (
        f'"Creative Commons" not found on license page: {license_url}'
    )

    # check that "Creative Commons" appears in the raw license text

    raw_license_url = (
        ds.attrs['license']
        .replace('github.com', 'raw.githubusercontent.com')
        .replace('/blob/', '/')
        .replace('/tree/', '/')
    )

    assert ds.attrs['source_id'] in raw_license_url, (
        f'model "{ds.attrs["source_id"]}" not found in license url: {raw_license_url}'
    )
    r = requests.get(raw_license_url)
    r.raise_for_status()
    assert 'Creative Commons' in r.text, (
        f'"Creative Commons" not found in license text: {raw_license_url}'
    )

    if spec['model'] in CC0_LICENSE_MODELS:
        assert 'CC0 1.0 Universal' in r.text, (
            f"expected CC0 license for {spec['model']} at {fp}"
        )
    elif spec['model'] in CC_BY_LICENSE_MODELS:
        assert 'Attribution 4.0 International' in r.text, (
            f"expected CC-BY 4.0 license for {spec['model']} at {fp}"
        )
    else:
        raise ValueError(
            f"deploying model with unknown license: {spec['model']} at {fp}"
        )

    # Check dimension size & membership

    for c in ds.coords.keys():
        assert ds.coords[c].notnull().all().item() is True, f"NaNs found in coordinate '{c}' in {fp}"

    if spec['activity'] == 'ScenarioMIP':
        date_range = xr.cftime_range("2015-01-01", "2099-12-31", freq="D", calendar="noleap")
        if len(ds.time) > len(date_range):
            date_range = xr.cftime_range("2015-01-01", "2100-12-31", freq="D", calendar="noleap")
    else:
        date_range = xr.cftime_range("1950-01-01", "2014-12-31", freq="D", calendar="noleap")

    assert ds.sizes['time'] == len(date_range), (
        f"unexpected length of dimension 'time': length {len(ds.time)}; "
        f"expected {len(date_range)} in {fp}"
    )

    assert date_range.isin(ds.time.dt.floor('D').values).all(), f"invalid coords in {fp}"

    assert pd.Series(np.arange(-179.875, 180, 0.25)).isin(ds.lon.values).all(), (
        f"invalid coords in {fp}"
    )
    assert pd.Series(np.arange(-89.875, 90, 0.25)).isin(ds.lat.values).all(), (
        f"invalid coords in {fp}"
    )

    varnames = list(ds.data_vars.keys())
    assert len(varnames) == 1
    varname = varnames[0]

    assert ds[varname].sizes['lat'] == 720, f"lat not length 720 in {fp}:\n{ds}"
    assert ds[varname].sizes['lon'] == 1440, f"lon not length 1440 in {fp}:\n{ds}"

In [42]:
spec = get_spec_from_input_fp(INPUT_FILE_VERSIONS['file_paths']['tasmax']['CMCC-CM2-SR5']['historical'])

In [43]:
spec['output_fp']

'gs://downscaled-288ec5ac/outputs/CMIP/CMCC/CMCC-CM2-SR5/historical/r1i1p1f1/day/tasmax/v1.1.zarr'

In [44]:
@rhgu.block_globals(whitelist=['INPUT_FILE_VERSIONS'])
def validate_outputs(fp, quick=False, check_dtr=False):
    spec = get_spec_from_output_fp(fp)

    fs = fsspec.filesystem('gs', timeout=360, cache_timeout=360, requests_timeout=360, read_timeout=360, conn_timeout=360, token='/opt/gcsfuse_tokens/impactlab-data.json')

    mapper = fs.get_mapper(fp)

    if check_dtr and (spec['variable'] != 'tasmin'):
        raise ValueError('check_dtr can only be used with variable == "tasmin"')

    if check_dtr:
        tasmax_spec = get_spec_from_input_fp(
            INPUT_FILE_VERSIONS['file_paths']['tasmax'][spec['model']][spec['scenario']]
        )

        tasmax_fp = tasmax_spec['output_fp']
        tasmax_mapper = fs.get_mapper(tasmax_fp)
        tasmax_opener = xr.open_zarr(tasmax_mapper)

    else:
        tasmax_opener = contextlib.nullcontext()

    with xr.open_zarr(mapper) as ds, tasmax_opener as tasmax_ds:

        quick_check_file(fp, ds, spec)

        if quick:
            return

        # check variable contents

        varnames = list(ds.data_vars.keys())
        assert len(varnames) == 1
        varname = varnames[0]

        to_check = ds[varname].sel(lat=slice(-80, 80))

        nans = to_check.isnull().any()
        min_val = to_check.min()
        max_val = to_check.max()

        if check_dtr:
            assert varname == 'tasmin'
            min_dtr = (tasmax_ds.tasmax.sel(lat=slice(-80, 80)) - to_check).min()
        else:
            min_dtr = dask.delayed(lambda: None)()

        nans, vmin, vmax, min_dtr = dd.get_client().compute(
            [nans, min_val, max_val, min_dtr],
            optimize_graph=True,
            sync=True,
            retries=3,
        )

        assert nans.item() is False, f"NaNs found in {fp}"

        if varname == 'tasmax':
            allowed_min = 150
            allowed_max = 360
        elif varname == 'tasmin':
            allowed_min = 150
            allowed_max = 360
        elif varname == 'pr':
            allowed_min = 0
            allowed_max = 3000
        else:
            raise ValueError(f'Variable name not recognized: {varname}\nin file: {fp}')

        assert (vmin >= allowed_min).item() is True, (
            f"min value {vmin} outside allowed range [{allowed_min}, {allowed_max}] "
            f"for {varname} in {fp}"
        )
        assert (vmax <= allowed_max).item() is True, (
            f"max value {vmax} outside allowed range [{allowed_min}, {allowed_max}] "
            f"for {varname} in {fp}"
        )

        if check_dtr:
            assert (min_dtr >= 0).item() is True, (
                f"DTR not greater than zero - min DTR: {min_dtr.item()} in {fp}"
            )


@rhgu.block_globals
def copy_and_validate(
    source_fp,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=False,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    check_dtr=False,
    pbar=False,
):

    spec = get_spec_from_input_fp(source_fp, output_version=output_version)
    flipped_fp = spec['flipped_fp']
    output_fp = spec['output_fp']
    model = spec['model']
    scenario = spec['scenario']

    fs = fsspec.filesystem('gs', timeout=360, cache_timeout=360, requests_timeout=360, read_timeout=360, conn_timeout=360, token='/opt/gcsfuse_tokens/impactlab-data.json')

    if fs.exists(output_fp):
        if overwrite:
            fs.remove(output_fp, recursive=True)

        else:
            if deep_copy_check:
                dirs = list([(d, f) for d, dirs, fps in fs.walk(flipped_fp) for f in fps])
                if pbar:
                    dirs = tqdm(dirs)

                for d, f in dirs:
                    src = flipped_fp[:5] + os.path.join(d, f)
                    dst = os.path.join(output_fp, os.path.relpath(src, flipped_fp))
                    assert '..' not in dst
                    src_hash = fs.stat(src)['md5Hash']

                    for i in range(5):
                        try:
                            assert (src_hash == fs.stat(dst)['md5Hash'])
                            break
                        except (FileNotFoundError, AssertionError):
                            if i == 4:
                                raise

                            fs.rm(dst)
                            fs.copy(src, dst)

            if check:
                try:
                    validate_outputs(
                        output_fp,
                        check_dtr=(check_dtr and (spec['variable'] == 'tasmin')),
                    )
                    return
                except (
                    AssertionError,
                    FileNotFoundError,
                    ValueError,
                    IOError,
                    xr.coding.times.OutOfBoundsDatetime,
                    OverflowError,
                ):
                    if overwrite_on_failure:
                        fs.rm(output_fp, recursive=True)
                    else:
                        raise

            elif quick_check_and_retry:
                try:
                    validate_outputs(output_fp, quick=True)
                    return
                except (
                    OverflowError,
                    IOError,
                    zarr.errors.GroupNotFoundError,
                    FileNotFoundError,
                    AssertionError,
                    ValueError,
                ):
                    pass

                fs.rm(output_fp, recursive=True)
            else:
                return

    print(f'copying:\n\tsrc:\t{flipped_fp}\n\tdst:\t{output_fp}')
    fs.copy(flipped_fp, output_fp, recursive=True, batch_size=1000)

    if deep_copy_check:
        for d, f in list([(d, f) for d, dirs, fps in fs.walk(flipped_fp) for f in fps]):
            src = flipped_fp[:5] + os.path.join(d, f)
            dst = os.path.join(output_fp, os.path.relpath(src, flipped_fp))
            assert '..' not in dst
            src_hash = fs.stat(src)['md5Hash']

            for i in range(5):
                try:
                    assert (src_hash == fs.stat(dst)['md5Hash'])
                    break
                except (FileNotFoundError, AssertionError):
                    if i == 4:
                        raise

                    fs.rm(dst)
                    fs.copy(src, dst)

    if check:
        validate_outputs(
            output_fp,
            check_dtr=(check_dtr and (spec['variable'] == 'tasmin')),
        )
    elif quick_check_and_retry:
        validate_outputs(output_fp, quick=True)

# Full workflow

In [None]:
client, cluster = rhgk.get_giant_cluster()
cluster.scale(60)

MAX_MEM = '16GB' # for standard cluster

cluster

In [None]:
print('https://compute.impactlab.org' + cluster.dashboard_link)

# Prepare final outputs

In [34]:
with tqdm(DELIVERY_MODELS) as pbar:
    for model in pbar:
        for scenario in INPUT_FILE_VERSIONS['file_paths']['tasmin'][model].keys():
            
            tasmin_input_fp = INPUT_FILE_VERSIONS['file_paths']['tasmin'][model][scenario]
            tasmin_spec = get_spec_from_input_fp(tasmin_input_fp)

            tasmax_input_fp = INPUT_FILE_VERSIONS['file_paths']['tasmax'][model][scenario]
            tasmax_spec = get_spec_from_input_fp(tasmax_input_fp)

            # comment out this block to reproduce rechunked/flipped data on scratch bucket
            if fs.exists(tasmin_spec['output_fp']) and fs.exists(tasmax_spec['output_fp']):
                print(f'skipping {model} {scenario} - output already exists')
                continue

            pbar.set_postfix({'model': model, 'scen': scenario, 'stage': 'rechunk tasmin'})
            rechunk_data('tasmin', model, scenario, worker_memory_limit=MAX_MEM)
            pbar.set_postfix({'model': model, 'scen': scenario, 'stage': 'rechunk tasmax'})
            rechunk_data('tasmax', model, scenario, worker_memory_limit=MAX_MEM)
            pbar.set_postfix({'model': model, 'scen': scenario, 'stage': 'flip negative DTR'})
            flip_negative_dtr(model, scenario)

  0%|          | 0/13 [00:00<?, ?it/s]

skipping BCC-CSM2-MR historical - output already exists
skipping BCC-CSM2-MR ssp245 - output already exists
skipping BCC-CSM2-MR ssp370 - output already exists
skipping FGOALS-g3 historical - output already exists
skipping FGOALS-g3 ssp245 - output already exists
skipping FGOALS-g3 ssp370 - output already exists
skipping ACCESS-ESM1-5 historical - output already exists
skipping ACCESS-ESM1-5 ssp245 - output already exists
skipping ACCESS-ESM1-5 ssp370 - output already exists
skipping ACCESS-CM2 historical - output already exists
skipping ACCESS-CM2 ssp245 - output already exists
skipping ACCESS-CM2 ssp370 - output already exists
skipping INM-CM4-8 historical - output already exists
skipping INM-CM4-8 ssp245 - output already exists
skipping INM-CM4-8 ssp370 - output already exists
skipping INM-CM5-0 historical - output already exists
skipping INM-CM5-0 ssp245 - output already exists
skipping INM-CM5-0 ssp370 - output already exists
skipping MIROC-ES2L historical - output already exists


In [22]:
tasmin_files = [fp for m, v in INPUT_FILE_VERSIONS['file_paths']['tasmin'].items() for s, fp in v.items()]
tasmax_files = [fp for m, v in INPUT_FILE_VERSIONS['file_paths']['tasmax'].items() for s, fp in v.items()]

In [24]:
def blocking_pbar(futures):
    status = {'error': 0, 'killed': 0, 'lost': 0}
    with tqdm(dd.as_completed(futures), total=len(futures)) as pbar:
        for f in pbar:
            if f.status in status.keys():
                status[f.status] += 1
                pbar.set_postfix(status)

    dd.wait(futures)

# Copy files to final destination

In [25]:
tasmax_futures = client.map(
    copy_and_validate,
    tasmax_files,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=False,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    check_dtr=False,
    pbar=False,
)

blocking_pbar(tasmax_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

In [26]:
tasmin_futures = client.map(
    copy_and_validate,
    tasmin_files,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=False,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    check_dtr=False,
    pbar=False,
)

blocking_pbar(tasmin_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

# Deep copy check
Check every file against source to ensure a complete copy

In [27]:
tasmax_futures = client.map(
    copy_and_validate,
    tasmax_files,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=True,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    check_dtr=False,
    pbar=False,
)

blocking_pbar(tasmax_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

In [28]:
tasmin_futures = client.map(
    copy_and_validate,
    tasmin_files,
    output_version=OUTPUT_VERSION,
    check=False,
    deep_copy_check=True,
    quick_check_and_retry=True,
    overwrite=False,
    overwrite_on_failure=False,
    check_dtr=False,
    pbar=False,
)

blocking_pbar(tasmin_futures)

  0%|          | 0/37 [00:00<?, ?it/s]

### Check tasmax data in final location
Check all tasmax values, including bounds & NAN checks

In [27]:
for f in tqdm(tasmax_files):
    copy_and_validate(
        f,
        output_version=OUTPUT_VERSION,
        check=True,
        deep_copy_check=False,
        quick_check_and_retry=False,
        overwrite=False,
        overwrite_on_failure=False,
        check_dtr=False,
        pbar=False,
    )

  0%|          | 0/37 [00:00<?, ?it/s]

### Check tasmin data & DTR in final location
Check all tasmin values, including bounds & NAN checks, plus check DTR implied by tasmin & tasmax for positivity

In [28]:
for f in tqdm(tasmin_files):
    copy_and_validate(
        f,
        output_version=OUTPUT_VERSION,
        check=True,
        deep_copy_check=False,
        quick_check_and_retry=False,
        overwrite=False,
        overwrite_on_failure=False,
        check_dtr=True,
        pbar=False,
    )

  0%|          | 0/37 [00:00<?, ?it/s]

In [29]:
client.restart()
cluster.scale(0)
client.close()
cluster.close()

In [37]:
outfiles = []
for f in (tasmin_files + tasmax_files):
    outfiles.append(get_spec_from_input_fp(f)['output_fp'])

print(f'outputs are located in the following directory: {os.path.commonpath(outfiles).replace("gs:/", "gs://")}')

outputs are located in the following directory: gs://downscaled-288ec5ac/outputs


To transfer data elsewhere, such as to prep for public delivery or delivery to Catalyst buckets, contact Mike for help with google transfer utility