# Obtain NetCDF files from ERA5-Land using the CDS API (`cdsapi`)
ERA5-Land data documentation: https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation

In [None]:
import xarray as xr
import hvplot.xarray
import pandas as pd
import dask
import fsspec
import cdsapi

In [None]:
# CDS API client; get_chunk() calls c.retrieve() on it to download ERA5-Land data.
c = cdsapi.Client()

#### Spin up Dask Cluster

In [None]:
def configure_cluster(resource):
    """Start a Dask cluster for the named compute resource and connect to it.

    Parameters
    ----------
    resource : str
        One of ``'denali'``, ``'tallgrass'``, ``'local'`` or
        ``'esip-qhub-gateway-v0.4'``.

    Returns
    -------
    client, cluster
        The Dask client and the cluster object it is connected to.

    Raises
    ------
    ValueError
        If ``resource`` is not one of the supported names.
    """
    if resource == 'denali':
        # Client/LocalCluster are not imported at notebook level, so bring
        # them into scope here.
        from dask.distributed import Client, LocalCluster
        cluster = LocalCluster(threads_per_worker=1)
        client = Client(cluster)

    elif resource == 'tallgrass':
        from dask.distributed import Client
        from dask_jobqueue import SLURMCluster
        cluster = SLURMCluster(queue='cpu', cores=1, interface='ib0',
                               job_extra=['--nodes=1', '--ntasks-per-node=1', '--cpus-per-task=1'],
                               memory='6GB')
        cluster.adapt(maximum_jobs=30)
        client = Client(cluster)

    elif resource == 'local':
        import os
        import warnings
        from dask.distributed import Client, LocalCluster
        warnings.warn("Running locally can result in costly data transfers!\n")
        n_cores = os.cpu_count()  # set to match your machine
        cluster = LocalCluster(threads_per_worker=n_cores)
        client = Client(cluster)

    elif resource == 'esip-qhub-gateway-v0.4':
        import os
        import sys
        # ebdpy lives in a shared directory under $HOME rather than in
        # site-packages, so extend the import path first.
        sys.path.append(os.path.join(os.environ['HOME'], 'shared', 'users', 'lib'))
        import ebdpy as ebd
        aws_profile = 'esip-qhub'
        ebd.set_credentials(profile=aws_profile)  # sets credentials for notebook
        aws_region = 'us-west-2'
        endpoint = f's3.{aws_region}.amazonaws.com'
        ebd.set_credentials(profile=aws_profile, region=aws_region, endpoint=endpoint)
        worker_max = 3
        client, cluster = ebd.start_dask_cluster(profile=aws_profile, worker_max=worker_max,
                                                 region=aws_region, use_existing_cluster=True,
                                                 adaptive_scaling=False, wait_for_cluster=False,
                                                 worker_profile='Small Worker', propagate_env=True)

    else:
        # Fail with a clear message instead of an UnboundLocalError at `return`.
        raise ValueError(
            f"Unknown resource {resource!r}; expected one of "
            "'denali', 'tallgrass', 'local', 'esip-qhub-gateway-v0.4'"
        )

    return client, cluster

In [None]:
# Name of the compute environment; selects the branch taken in configure_cluster().
resource = 'esip-qhub-gateway-v0.4' #denali, tallgrass, local, esip-qhub-gateway-v0.4

In [None]:
# Start the Dask cluster for the selected resource and connect a client to it.
client, cluster = configure_cluster(resource)

In [None]:
# Ask the cluster for 30 workers.
cluster.scale(30)

In [None]:
# client.wait_for_workers(n_workers=30)

#### Specify variables, spatial and temporal extents

In [None]:
# ERA5-Land variable names sent as the 'variable' field of the CDS request in get_chunk().
var_list = ['snow_depth_water_equivalent', 'soil_temperature_level_1']

In [None]:
# Bounding box of the contiguous United States (CONUS); presumably decimal
# degrees. get_chunk() sends it to the CDS as north/west/south/east.
north, south = 49.3457868, 24.7433195    # latitude limits
west, east = -124.7844079, -66.9513812   # longitude limits

In [None]:
# S3 filesystem using credentials (anon=False); skip_instance_cache=True asks
# fsspec for a fresh instance rather than a cached one.
fs = fsspec.filesystem('s3', anon=False,  skip_instance_cache=True)

In [None]:
# Window boundaries every 14 days between 1980-12-01 and 2022-01-31.
dates = pd.date_range('1980-12-01','2022-01-31', freq='14D')
print(dates)

In [None]:
# Each window starts on an entry of `dates` (the last entry is not used as a
# start) and ends 13 days later, so consecutive windows do not overlap.
start_dates = dates[:-1].strftime('%Y-%m-%d').tolist()
stop_dates = (dates[:-1] + pd.offsets.Day(13)).strftime('%Y-%m-%d').tolist()

In [None]:
# One target S3 key per window, named after the window's start date.
s3_files = [f'esip-qhub/usgs/era5_land/conus_{start_date}.nc' for start_date in start_dates]

In [None]:
# Keys that already exist in the bucket (windows processed in earlier runs).
s3_files_processed = fs.glob('esip-qhub/usgs/era5_land/conus_*.nc')

In [None]:
# Windows still missing from S3. Sorted so the order is reproducible between
# runs (a bare set difference has arbitrary order); because the keys embed
# ISO dates, lexical order is also chronological.
s3_files_to_create = sorted(set(s3_files) - set(s3_files_processed))

In [None]:
# Total number of windows vs. number still to be downloaded.
print(len(s3_files))
print(len(s3_files_to_create))

#### test generation of start_date and stop_date from s3file

In [None]:
import datetime as dt

# Sanity check of the date parsing used inside get_chunk(), on the first
# pending key.
s3file = s3_files_to_create[0]
# '.../conus_YYYY-MM-DD.nc' -> 'YYYY-MM-DD'
start_date = s3file.split('_')[-1].split('.')[0]
s = start_date.split('-')
# The window ends 13 days after its start date.
stop_date = (dt.datetime(*map(int, s)) + pd.offsets.Day(13)).strftime('%Y-%m-%d')
print(start_date, stop_date, sep='\n')

In [None]:
def get_chunk(s3file, keepbits):
    """Download one ERA5-Land window, bitround it, and upload it to S3.

    The window start date is parsed from ``s3file`` (``..._YYYY-MM-DD.nc``)
    and the window runs through the start date plus 13 days. The data are
    retrieved with the CDS API, bitrounded, written as compressed float32
    NetCDF and uploaded to ``s3file``. Temporary local files are removed
    afterwards, even when a step fails.

    Relies on notebook-level names: ``c`` (CDS API client), ``var_list``,
    ``north``/``west``/``south``/``east``, ``fs`` (S3 filesystem) and ``xr``.

    Parameters
    ----------
    s3file : str
        Destination S3 key whose name ends in ``_YYYY-MM-DD.nc``.
    keepbits : int, dict of {str: int}, xarray.DataArray or xarray.Dataset
        Number of bits to keep per variable when bitrounding.

    Returns
    -------
    None
    """
    import contextlib
    import datetime as dt
    import os

    import pandas as pd
    from numcodecs.bitround import BitRound

    def _np_bitround(data, keepbits):
        """Bitround for Arrays."""
        codec = BitRound(keepbits=keepbits)
        data = data.copy()  # otherwise overwrites the input
        encoded = codec.encode(data)
        return codec.decode(encoded)

    def _keepbits_interface(da, keepbits):
        """Common interface to allowed keepbits types
        Parameters
        ----------
        da : :py:class:`xarray.DataArray`
          Input data to bitround
        keepbits : int, dict of {str: int}, :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`
          How many bits to keep as int
        Returns
        -------
        keep : int
          Number of keepbits for variable given in ``da``
        """
        assert isinstance(da, xr.DataArray)
        if isinstance(keepbits, int):
            keep = keepbits
        elif isinstance(keepbits, dict):
            v = da.name
            if v in keepbits.keys():
                keep = keepbits[v]
            else:
                raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}")
        elif isinstance(keepbits, xr.Dataset):
            assert keepbits.coords["inflevel"].shape <= (
                1,
            ), "Information content is only allowed for one 'inflevel' here. Please make a selection."
            if "dim" in keepbits.coords:
                assert keepbits.coords["dim"].shape <= (
                    1,
                ), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`"
            v = da.name
            if v in keepbits.keys():
                keep = int(keepbits[v])
            else:
                raise ValueError(f"name {v} not for in keepbits: {keepbits.keys()}")
        elif isinstance(keepbits, xr.DataArray):
            assert keepbits.coords["inflevel"].shape <= (
                1,
            ), "Information content is only allowed for one 'inflevel' here. Please make a selection."
            assert keepbits.coords["dim"].shape <= (
                1,
            ), "Information content is only allowed along one dimension here. Please select one `dim`. To find the maximum keepbits, simply use `keepbits.max(dim='dim')`"
            v = da.name
            if v == keepbits.name:
                keep = int(keepbits)
            else:
                raise KeyError(f"no keepbits found for variable {v}")
        else:
            raise TypeError(f"type {type(keepbits)} is not a valid type for keepbits.")
        return keep

    def xr_bitround(da, keepbits):
        """Apply bitrounding based on keepbits from :py:func:`xbitinfo.xbitinfo.get_keepbits` for :py:class:`xarray.Dataset` or :py:class:`xarray.DataArray` wrapping ``numcodecs.bitround``
        Parameters
        ----------
        da : :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`
          Input data to bitround
        keepbits : int, dict of {str: int}, :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`
          How many bits to keep as int. Fails if dict or :py:class:`xarray.Dataset` and key or variable not present.
        Returns
        -------
        da_bitrounded : :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`
        Example
        -------
        >>> ds = xr.tutorial.load_dataset("air_temperature")
        >>> info_per_bit = xb.get_bitinformation(ds, dim="lon")
        >>> keepbits = xb.get_keepbits(info_per_bit, 0.99)
        >>> ds_bitrounded = xb.xr_bitround(ds, keepbits)
        """
        if isinstance(da, xr.Dataset):
            # Recurse over every data variable of a Dataset.
            da_bitrounded = da.copy()
            for v in da.data_vars:
                da_bitrounded[v] = xr_bitround(da[v], keepbits)
            return da_bitrounded

        assert isinstance(da, xr.DataArray)
        keep = _keepbits_interface(da, keepbits)

        da = xr.apply_ufunc(_np_bitround, da, keep, dask="parallelized", keep_attrs=True)
        da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keep
        return da

    # The window start date is encoded in the key: '.../conus_YYYY-MM-DD.nc'.
    start_date = s3file.split('_')[-1].split('.')[0]
    s = start_date.split('-')
    stop_date = (dt.datetime(int(s[0]), int(s[1]), int(s[2])) + pd.offsets.Day(13)).strftime('%Y-%m-%d')
    local_ncfile = f'era5land_{start_date}.nc'    # raw download from the CDS
    local_nc4file = f'era5_land_{start_date}.nc'  # bitrounded, compressed copy

    try:
        c.retrieve(
            'reanalysis-era5-land',
            {
                'variable': var_list,
                'area'    : f'{north}/{west}/{south}/{east}',
                'date'    : f'{start_date}/{stop_date}',
                'time': ['00:00', '01:00', '02:00', '03:00', '04:00', '05:00',
                         '06:00', '07:00', '08:00', '09:00', '10:00', '11:00',
                         '12:00', '13:00', '14:00', '15:00', '16:00', '17:00',
                         '18:00', '19:00', '20:00', '21:00', '22:00', '23:00'],
                'format':'netcdf'
            },
            local_ncfile)

        # Close the source dataset before its file is deleted below; the
        # output is written while the source is still open.
        with xr.open_dataset(local_ncfile) as ds:
            ds_bitrounded = xr_bitround(ds, keepbits)
            encoding = {
                data_var: dict(dtype='float32', zlib=True)
                for data_var in ds.data_vars
            }
            encoding['latitude'] = {'_FillValue': None}
            encoding['longitude'] = {'_FillValue': None}
            ds_bitrounded.to_netcdf(local_nc4file, engine='netcdf4', encoding=encoding, mode='w')

        fs.upload(local_nc4file, s3file)
    finally:
        # Always remove the temporary files so failed or retried tasks do not
        # leave partial downloads behind; either file may not exist yet.
        for path in (local_ncfile, local_nc4file):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

In [None]:
# Keepbits dataset read from a local file; it is passed to get_chunk() below.
keepbits = xr.open_dataset('keepbits.nc')

In [None]:
# Display the keepbits dataset.
keepbits

In [None]:
%%time
# Run get_chunk() for every missing window as Dask delayed tasks; retries=10
# lets a failing task be retried up to 10 times.
_ = dask.compute(*[dask.delayed(get_chunk)(s3file,keepbits) for s3file in s3_files_to_create], retries=10);

In [None]:
# List every NetCDF object under the ERA5-Land prefix.
flist = fs.glob('esip-qhub/usgs/era5_land/*.nc')

In [None]:
# Show the metadata of the last listed object.
fs.info(flist[-1])

In [None]:
# Open the last listed file directly from S3; chunks={} makes xarray load the
# variables as dask arrays.
ds = xr.open_dataset(fs.open(flist[-1]), chunks={})

In [None]:
# Display the dataset summary.
ds

In [None]:
# Quick-look plot of the first time step of `sd`
# (presumably snow depth water equivalent; confirm against the ERA5-Land short names).
ds.sd.isel(time=0).plot()