In [None]:
from getpass import getuser # Look up the login name of the current user
from pathlib import Path # Object-oriented library to deal with paths
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory # Creating temporary files/dirs
from subprocess import run, PIPE
import sys

import dask # Distributed data library
from dask_jobqueue import SLURMCluster # Setting up distributed memories via slurm
from distributed import Client, progress, wait # Library to orchestrate distributed resources
import xarray as xr # Library to work with labeled n-dimensional data and dask

In [None]:
import warnings
warnings.filterwarnings(action='ignore')

In [None]:
# Set some user specific variables
scratch_dir = Path('/scratch') / getuser()[0] / getuser()  # The user's scratch dir: /scratch/<first letter>/<user>
# Temporary directory that serves as working directory of the SLURM jobs and
# receives the worker files and the job logs. It is deleted when
# `dask_tmp_dir` is cleaned up, at the latest when the kernel shuts down.
dask_tmp_dir = TemporaryDirectory(dir=scratch_dir, prefix='PostProc')
cluster = SLURMCluster(memory='500GiB',
                       cores=72,
                       project='mh0731',
                       walltime='1:00:00',
                       queue='gpu',
                       name='PostProc',
                       scheduler_options={'dashboard_address': ':12435'},
                       local_directory=dask_tmp_dir.name,
                       # Extra sbatch directives. '--output' is given only once
                       # (it used to be listed twice with the same value); the
                       # log file name carries the job id (%j).
                       job_extra=['-J PostProc',
                                  f'-D {dask_tmp_dir.name}',
                                  '--begin=now',
                                  f'--output={dask_tmp_dir.name}/LOG_cluster.%j.o',
                                 ],
                       interface='ib0')
# Submit two SLURM jobs and connect a client to the scheduler.
cluster.scale(jobs=2)
dask_client = Client(cluster)
# Block until 18 workers have registered.
# NOTE(review): presumably 2 jobs x 9 worker processes each - confirm against
# the `processes` default of the installed dask_jobqueue version.
dask_client.wait_for_workers(18)
data_path = Path('/work/mh0287/k203123/GIT/icon-aes-dyw_albW/experiments/dpp0016/')
glob_pattern_2d = 'atm2_2d_ml'

# Recursively collect all matching netCDF files below data_path (pathlib's
# rglob) and sort their names.
file_names = sorted(map(str, data_path.rglob(f'*{glob_pattern_2d}*.nc')))
# Open all files as one dataset, combined by their coordinates.
dset = xr.open_mfdataset(file_names, combine='by_coords', parallel=True)
# Keep only the variables of interest and persist them on the cluster.
var_names = ['pr']
dset_subset = dset[var_names].persist()
dset_subset

In [None]:
# Mean over the 'time' dimension and mean over the 'ncells' dimension; both
# results are persisted.
time_mean = dset_subset.mean('time').persist()
field_mean = dset_subset.mean('ncells').persist()

In [None]:
from dask.utils import format_bytes

In [None]:
# Show the size (nbytes) of the persisted subset as a human-readable string.
format_bytes(dset_subset.nbytes)

In [None]:
# Display the 'pr' variable of the subset.
dset_subset['pr']

In [None]:
def get_griddes(y_res, x_res, x_first=-180, y_first=-90):
    """Create a description for a regular global grid at given x, y resolution.

    Parameters
    ==========
    y_res : float
        Latitude increment in degrees (note: the y resolution comes first).
    x_res : float
        Longitude increment in degrees.
    x_first : float
        Western edge of the grid; the first cell centre is placed half a
        longitude increment east of it.
    y_first : float
        Southern edge of the grid; the first cell centre is placed half a
        latitude increment north of it.

    Returns
    =======
    str : lonlat grid description text with 360 / x_res columns and
          180 / y_res rows

    Raises
    ======
    ValueError : if a resolution is not positive
    """
    if x_res <= 0 or y_res <= 0:
        raise ValueError(f'Grid resolutions must be positive, got x_res={x_res}, y_res={y_res}')

    xsize = 360 / x_res
    ysize = 180 / y_res
    # Cell centres sit half an increment inside the grid edges. Each axis
    # uses its own increment and the edge that was passed in.
    xfirst = x_first + x_res / 2
    yfirst = y_first + y_res / 2

    return f'''
#
# gridID 1
#
gridtype  = lonlat
gridsize  = {int(xsize * ysize)}
xsize     = {int(xsize)}
ysize     = {int(ysize)}
xname     = lon
xlongname = "longitude"
xunits    = "degrees_east"
yname     = lat
ylongname = "latitude"
yunits    = "degrees_north"
xfirst    = {xfirst}
xinc      = {x_res}
yfirst    = {yfirst}
yinc      = {y_res}
 
 
    '''

In [None]:
@dask.delayed
def gen_dis(dataset, xres, yres, gridfile):
    '''Create distance weights using cdo.

    Parameters
    ==========
    dataset : xarray.Dataset
        Data on the source grid; it is written to a temporary netCDF file
        that serves as input for cdo.
    xres : float
        Longitude resolution of the target lonlat grid.
    yres : float
        Latitude resolution of the target lonlat grid.
    gridfile : Path, str
        Grid file that is attached to the input with cdo's -setgrid.

    Returns
    =======
    xarray.Dataset : the weights written by cdo's gendis, loaded into memory

    Raises
    ======
    RuntimeError : if the cdo call fails (raised by run_cmd)
    '''
    scratch_dir = Path('/scratch') / getuser()[0] / getuser() # Define the users scratch dir
    # Work in a private temporary directory below the user's scratch dir, so
    # that no path of a specific user is hard-coded and nothing is left behind.
    with TemporaryDirectory(dir=scratch_dir, prefix='Weights_') as td:
        in_file = Path(td) / 'in_file.nc'
        weightfile = Path(td) / 'weight_file.nc'
        griddes = Path(td) / 'griddes.txt'
        # get_griddes takes the y resolution first; pass by keyword so the
        # two resolutions cannot be swapped.
        griddes.write_text(get_griddes(y_res=yres, x_res=xres))
        dataset.to_netcdf(in_file, mode='w') # Write the data to a temporary netcdf file
        cmd = ('cdo', '-O', f'gendis,{griddes}', f'-setgrid,{gridfile}', str(in_file), str(weightfile))
        run_cmd(cmd)
        # Load the weights into memory and close the file before the
        # temporary directory is removed.
        with xr.open_dataset(weightfile) as weights:
            return weights.load()
 
def run_cmd(cmd, path_extra=Path(sys.exec_prefix)/'bin'):
    '''Run an external command and return its standard output.

    Parameters
    ==========
    cmd : sequence of str or Path
        The program and its arguments (executed without a shell).
    path_extra : Path, str
        Directory that is put in front of PATH for the child process; by
        default the bin directory of the running python installation.

    Returns
    =======
    str : the decoded standard output of the command

    Raises
    ======
    RuntimeError : if the command exits with a non-zero status; the message
                   holds the command line and its standard error output
    '''
    env_extra = os.environ.copy()
    # Prepend path_extra to PATH; cope with an environment without PATH and
    # use the platform's path separator.
    search_path = [str(path_extra)]
    if env_extra.get('PATH'):
        search_path.append(env_extra['PATH'])
    env_extra['PATH'] = os.pathsep.join(search_path)
    status = run(cmd, check=False, stderr=PIPE, stdout=PIPE, env=env_extra)
    if status.returncode != 0:
        # Convert every part to str so that Path arguments do not break the
        # error message (which would hide the real failure).
        cmd_str = cmd if isinstance(cmd, str) else ' '.join(str(part) for part in cmd)
        error = f'''{cmd_str}: {status.stderr.decode('utf-8', errors='replace')}'''
        raise RuntimeError(f'{error}')
    return status.stdout.decode('utf-8', errors='replace')

In [None]:
@dask.delayed
def remap(dataset, x_res, y_res, weights, gridfile):
    """Perform a weighted remapping with cdo.

    Parameters
    ==========

    dataset : xarray.Dataset, xarray.DataArray
        The data that will be regridded; a DataArray is wrapped into a
        Dataset first
    x_res : float
        Longitude resolution of the target lonlat grid
    y_res : float
        Latitude resolution of the target lonlat grid
    weights : xarray.Dataset
        Distance weights; written to disk and passed to cdo's remap operator
    gridfile : Path, str
        Grid file that is attached to the input with cdo's -setgrid

    Returns
    =======
    xarray.Dataset : Remapped dataset, loaded into memory

    Raises
    ======
    RuntimeError : if the cdo call fails (raised by run_cmd)
    """
    if isinstance(dataset, xr.DataArray):
        # If a dataArray is given create a dataset
        dataset = xr.Dataset(data_vars={dataset.name: dataset})
    scratch_dir = Path('/scratch') / getuser()[0] / getuser() # Define the users scratch dir
    # Every call gets its own temporary directory: remap tasks that run at
    # the same time must not overwrite each other's input and output files.
    with TemporaryDirectory(dir=scratch_dir, prefix='Remap_') as td:
        infile = Path(td) / 'input_file.nc'
        weightfile = Path(td) / 'weight_file.nc'
        griddes = Path(td) / 'griddes.txt'
        outfile = Path(td) / 'remaped_file.nc'
        # get_griddes takes the y resolution first; pass by keyword so the
        # two resolutions cannot be swapped.
        griddes.write_text(get_griddes(y_res=y_res, x_res=x_res))
        dataset.to_netcdf(infile, mode='w') # Write the data to a temporary netcdf file
        weights.to_netcdf(weightfile, mode='w')
        cmd = ('cdo', '-O', f'remap,{griddes},{weightfile}', f'-setgrid,{gridfile}',
               str(infile), str(outfile))
        run_cmd(cmd)
        # Load the result into memory and close the file before the
        # temporary directory is removed.
        with xr.open_dataset(outfile) as remapped:
            return remapped.load()

In [None]:
# Grid file that is handed to gen_dis (and later to remap) as `gridfile`.
grid_file = '/pool/data/ICON/grids/public/mpim/0015/icon_grid_0015_R02B09_G.nc'
# Delayed task that generates the distance weights from the time mean for a
# target grid with an x and y resolution of 0.0225.
weights_future = gen_dis(
    time_mean,
    0.0225,
    0.0225,
    grid_file,
)
weights_future

In [None]:
# Preview the first two entries of 'pr' along its leading dimension.
dset_subset['pr'][:2]

In [None]:
# Build one delayed remap task per snapshot (only the first two entries of
# 'pr' are used), so that the snapshots can be processed in parallel.
remap_futures = [
    remap(
        dset_subset['pr'].sel(time=snapshot.time.values.astype(str)),
        0.0225,
        0.0225,
        weights_future,
        grid_file,
    )
    for snapshot in dset_subset['pr'][:2]
]
remap_futures

In [None]:
# Hand the delayed remap tasks to dask for execution and keep the persisted
# results in remap_jobs.
remap_jobs = dask.persist(remap_futures)
# Show a text (non-widget) progress bar for the persisted tasks.
progress(remap_jobs, notebook=False)

In [None]:
# Collect the remapped snapshots and join them along the time coordinate.
remapped_steps = list(dask.compute(*remap_futures))
# Take exactly as many time stamps as there are remapped snapshots; a fixed
# slice (previously [:3]) does not match the number of objects to concatenate.
dset_remap = xr.concat(remapped_steps, dim=dset_subset.time[:len(remapped_steps)])
dset_remap

In [None]:
# 1 Save the time-series
# The remapped data is written as netCDF into the user's scratch directory;
# an existing file is overwritten (mode='w').
out_file = Path(scratch_dir, 'dpp0016_precip.nc')
dset_remap.to_netcdf(out_file, mode='w')