In [None]:
# Reload modules imported from edited source files before each cell runs.
%load_ext autoreload
%autoreload 2

# Prepare dataset using pygrib

The purpose of this notebook is to prepare a statistical downscaling dataset
using pygrib and xarray. Notebook *018-prepare-dataset.ipynb* showed that
it is not really practical to do so with xarray and cfgrib alone. Opening
a large collection of GRIB files with cfgrib is very slow. In addition, the
limited API for filtering fields into consistent datacubes forces us to
open the same files multiple times.

Our overall strategy here is:
1. Extract the relevant fields from GRIB to NetCDF using pygrib.
2. The NetCDF files will not be CF compliant; they are only an
   intermediate file format.
3. Open the NetCDF files using xarray, and then perform the
   interpolation.
  
Our hope is that xarray will handle NetCDF files much better than GRIB
files. By dropping strict CF compliance, we also hope to open and
interpolate the data only once using xarray.

# 1. Extract relevant data using pygrib

In [None]:
import dask
import dask.distributed
import dask_jobqueue
import datetime
import multiprocessing
import netCDF4
import os
import pathlib
import pygrib
import xarray as xr

In [None]:
# Root of the data tree, taken from the environment so the notebook is
# portable between machines. Fail early with a clear message: with the
# variable unset, pathlib.Path(None) raises an unhelpful TypeError, and an
# empty value would silently resolve to the current directory.
_data_dir_env = os.getenv('DATA_DIR')
if not _data_dir_env:
    raise RuntimeError('The DATA_DIR environment variable must be set to the data root.')
DATA_DIR = pathlib.Path(_data_dir_env)
GDPS_DIR = DATA_DIR / '2021-02-08-one-month-sample/'
NETCDF_DIR = DATA_DIR / 'hdd_scratch/smc01/2021-02-09-one-month-single-file'

# GRIB ``shortName`` values of the fields copied into the intermediate
# NetCDF files (matched against ``message.shortName`` during extraction).
# Presumably ecCodes names for MSL pressure, boundary-layer height,
# precipitation rate, 2 m temperature / dewpoint / relative humidity and
# 10 m wind speed / direction -- TODO confirm against the GRIB inventory.
TO_EXTRACT = [
    'prmsl',
    'hpbl',
    'prate',
    '2t',
    '2d',
    '2r',
    '10si',
    '10wdir',
]

In [None]:
# Shell commands run at the start of every SLURM job so that the workers
# use the project's conda environment.
# NOTE(review): newer dask-jobqueue releases rename ``env_extra`` to
# ``job_script_prologue``; update if the installed version warns.
worker_prologue = [
    'source ~/.bash_profile',
    'conda activate smc01',
]
cluster = dask_jobqueue.SLURMCluster(
    name='smc01-dask',
    env_extra=worker_prologue,
)

In [None]:
# Number of SLURM jobs to request for the dask cluster.
N_SLURM_JOBS = 10
cluster.scale(jobs=N_SLURM_JOBS)

In [None]:
# Connect a distributed client to the SLURM cluster; it becomes the default
# scheduler for dask.compute below.
client = dask.distributed.Client(address=cluster)

In [None]:
# Show the client's rich repr as a quick status check.
client

In [None]:
# Input GRIB directory and output NetCDF directory under the names used by
# the rest of the notebook.
gdps_path, output_path = pathlib.Path(GDPS_DIR), pathlib.Path(NETCDF_DIR)

In [None]:
# parents=True: the scratch directory tree may not exist yet on a fresh machine.
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
grib_files = sorted(gdps_path.glob('*.grib2'))
if not grib_files:
    # Fail here with a clear message rather than an IndexError further down.
    raise FileNotFoundError(f'No *.grib2 files found in {gdps_path}')

In [None]:
# First few GRIB files, to check the naming convention.
grib_files[:10]

### Figure out how many dates and steps there are

In [None]:
# File names follow 'CMC_glb_latlon.24x.24_<pass>_<step>.grib2'; characters
# 22-31 of the stem hold the 'YYYYMMDDHH' forecast pass.
PASS_SLICE = slice(22, 32)
pass_strings = [grib_path.stem[PASS_SLICE] for grib_path in grib_files]

In [None]:
# Distinct forecast passes, in chronological (lexicographic) order.
unique_passes = sorted(set(pass_strings))

In [None]:
# Forecast passes found on disk.
unique_passes

In [None]:
# Number of distinct forecast passes (not used again in the cells shown).
n_pass = len(unique_passes)

In [None]:
# The last four characters of each stem are the step token: a one-character
# prefix followed by the step number (see handle_one_file).
step_strings = [grib_path.stem[-4:] for grib_path in grib_files]

In [None]:
# First few step tokens.
step_strings[:10]

In [None]:
# Distinct step tokens, sorted.
unique_steps = sorted(set(step_strings))

In [None]:
# Number of distinct forecast steps.
n_step = len(unique_steps)

In [None]:
# Number of distinct forecast steps found on disk.
n_step

In [None]:
# Example stem; also used below to name the scratch output file.
grib_files[0].stem

In [None]:
# Scratch single-file output named after the first GRIB file.
target_file_path = output_path / f'{grib_files[0].stem}_filtered.nc'
target_file_name = target_file_path.name

In [None]:
# Path of the scratch single-file output built later in the notebook.
target_file_path

In [None]:
def _is_surface_st(message):
    """Match the 'st' field on the 'surface' level."""
    return message.shortName == 'st' and message.typeOfLevel == 'surface'


# Extra predicates (beyond TO_EXTRACT) selecting GRIB messages of interest.
lambdas = [_is_surface_st]

In [None]:
def compound_lambda(message):
    """Return True if any predicate in the module-level ``lambdas`` matches.

    Predicates are evaluated in order and evaluation stops at the first match.
    """
    return any(predicate(message) for predicate in lambdas)

In [None]:
def pass_to_datetime(pass_string):
    """Convert a forecast pass string 'YYYYMMDDHH' into a naive datetime.

    Parameters
    ----------
    pass_string : str
        Ten-digit pass identifier, e.g. '2020072212' for 2020-07-22 12:00.

    Returns
    -------
    datetime.datetime

    Raises
    ------
    ValueError
        If ``pass_string`` is not ten digits or is not a valid date/hour.
    """
    # Reject trailing or garbage characters instead of silently ignoring them.
    if len(pass_string) != 10 or not pass_string.isdigit():
        raise ValueError(f'Expected a YYYYMMDDHH pass string, got {pass_string!r}')
    return datetime.datetime.strptime(pass_string, '%Y%m%d%H')

In [None]:
def prepare_netcdf_file(output_file, date, step, n_lat=751, n_lon=1500):
    """Create the dimensions and variables of one intermediate NetCDF file.

    The file holds a single forecast (one ``time`` and one ``step``) on a
    latitude/longitude grid, with one float32 variable per entry of the
    module-level ``TO_EXTRACT``. Coordinate values for latitude/longitude
    are not written here.

    Parameters
    ----------
    output_file : netCDF4.Dataset
        Dataset opened in write mode that does not yet define these names.
    date : datetime.datetime
        Forecast pass time, stored as seconds since 1970-01-01.
    step : int
        Value stored in the ``step`` coordinate (no units attribute is set;
        presumably hours -- TODO confirm).
    n_lat, n_lon : int, optional
        Grid size. The defaults are the 751 x 1500 grid used so far.

    Returns
    -------
    dict
        ``netCDF4.Variable`` objects keyed by GRIB shortName, each with
        dimensions ``(time, step, latitude, longitude)``.
    """
    time_units = 'seconds since 1970-01-01 00:00:00.0'

    output_file.createDimension('latitude', size=n_lat)
    output_file.createDimension('longitude', size=n_lon)
    output_file.createDimension('time', size=1)
    output_file.createDimension('step', size=1)

    # Trailing commas matter: ('step') is the plain string 'step', not a tuple.
    step_var = output_file.createVariable('step', 'i4', dimensions=('step',))
    step_var[0] = step

    time_var = output_file.createVariable('time', 'f8', dimensions=('time',))
    time_var[0] = netCDF4.date2num(date, time_units)
    time_var.units = time_units

    data_dims = ('time', 'step', 'latitude', 'longitude')
    return {
        short_name: output_file.createVariable(
            short_name, 'f4', dimensions=data_dims, zlib=True)
        for short_name in TO_EXTRACT
    }

In [None]:
def add_latlon_to_file(output_file, lat, lon):
    """Write 1-D latitude and longitude coordinate variables (stored as float32).

    Parameters
    ----------
    output_file : netCDF4.Dataset
        Dataset in write mode that already defines ``latitude`` and
        ``longitude`` dimensions whose lengths match ``lat`` and ``lon``.
    lat, lon : array-like
        1-D coordinate values.
    """
    # Dimensions are passed as real 1-tuples; ('latitude') would be a string.
    for name, values in (('latitude', lat), ('longitude', lon)):
        coord_var = output_file.createVariable(name, 'f4', dimensions=(name,))
        coord_var[:] = values

In [None]:
def handle_one_file(input_dir, pass_string, step_string, output_dir):
    """Extract the ``TO_EXTRACT`` fields of one GRIB file into one NetCDF file.

    Reads ``CMC_glb_latlon.24x.24_<pass>_<step>.grib2`` from ``input_dir`` and
    writes ``gdps_<pass>_<step>.nc`` to ``output_dir``. If the output already
    exists it is left untouched, so a partially completed batch can be re-run.

    The NetCDF file is written under a temporary ``.part`` name and renamed
    only once complete. An interrupted run therefore never leaves a truncated
    ``.nc`` that a later run would skip as "already done".

    Parameters
    ----------
    input_dir, output_dir : pathlib.Path
        Directory containing the GRIB files / directory receiving NetCDF files.
    pass_string : str
        Forecast pass as 'YYYYMMDDHH'.
    step_string : str
        Step token from the GRIB file name; its first character is dropped and
        the remainder parsed as an integer.

    Returns
    -------
    pathlib.Path
        Path of the (new or pre-existing) NetCDF file.
    """
    target_file_path = output_dir / f'gdps_{pass_string}_{step_string}.nc'
    if target_file_path.is_file():
        print(f'Skipping {pass_string} {step_string} because output file already exists')
        return target_file_path

    grib_file_name = f'CMC_glb_latlon.24x.24_{pass_string}_{step_string}.grib2'
    grib_file_path = input_dir / grib_file_name
    partial_path = target_file_path.with_name(target_file_path.name + '.part')

    date = pass_to_datetime(pass_string)
    step = int(step_string[1:])

    # Open the input first, so a missing GRIB file creates no output at all.
    grib_file = pygrib.open(str(grib_file_path))
    try:
        # pygrib message indices are 1-based. Taking one column of latitudes
        # and one row of longitudes assumes a regular lat/lon grid.
        lat, lon = grib_file[1].latlons()
        lat = lat[:, 0]
        lon = lon[0]

        with netCDF4.Dataset(str(partial_path), 'w') as output_file:
            variables = prepare_netcdf_file(output_file, date, step)
            add_latlon_to_file(output_file, lat, lon)

            # Rewind explicitly so the loop starts at the first message
            # regardless of where indexing left the file position.
            grib_file.seek(0)
            for message in grib_file:
                if message.shortName in variables:
                    # NOTE(review): if a shortName occurs in several messages
                    # (e.g. several levels), the last one wins.
                    variables[message.shortName][0, 0, :, :] = message.values
    except Exception:
        if partial_path.exists():
            partial_path.unlink()
        raise
    finally:
        grib_file.close()

    os.replace(partial_path, target_file_path)
    return target_file_path

### Manually

In [None]:
# Smoke test on the first (pass, step) pair only.
first_passes = unique_passes[:1]
first_steps = unique_steps[:1]
for pass_string in first_passes:
    for step_string in first_steps:
        handle_one_file(gdps_path, pass_string, step_string, output_path)

### With multiprocessing

In [None]:
# One task per (pass, step) GRIB file, matching handle_one_file's signature.
# NOTE(review): functions defined in a notebook can only be sent to workers
# with the 'fork' start method (the Linux default).
file_tasks = [
    (gdps_path, pass_string, step_string, output_path)
    for pass_string in unique_passes
    for step_string in unique_steps
]
with multiprocessing.Pool(processes=8) as pool:
    pool.starmap(handle_one_file, file_tasks)

### With dask

In [None]:
# Lazy wrapper: calling it builds a dask task instead of running the function.
handle_one_file_delayed = dask.delayed(handle_one_file)

In [None]:
# One lazy task per (pass, step) pair, passes in the outer loop.
delayeds = []
for pass_string in unique_passes:
    for step_string in unique_steps:
        delayeds.append(
            handle_one_file_delayed(gdps_path, pass_string, step_string, output_path))

In [None]:
# First few lazy tasks.
delayeds[:10]

In [None]:
# Expected: one task per (pass, step) pair, i.e. n_pass * n_step.
len(delayeds)

In [None]:
# Keep the results in a variable: displaying the whole tuple (one entry per
# file) would flood the notebook output.
computed_outputs = dask.compute(*delayeds)
len(computed_outputs)

In [None]:
target_file_name = grib_files[0].stem + '_filtered.nc'
target_file_path = output_path / target_file_name

# Inventory of the messages that would be extracted from every GRIB file.
for grib_file_path in grib_files:
    grib_file = pygrib.open(str(grib_file_path))
    try:
        for message in grib_file:
            if message.shortName in TO_EXTRACT or compound_lambda(message):
                print(message)
    finally:
        grib_file.close()

In [None]:
# Open the first GRIB file explicitly: any grib_file left over from an
# earlier loop may already be closed. It stays open for the cells below.
grib_file = pygrib.open(str(grib_files[0]))
lat, lon = grib_file[1].latlons()
lat = lat[:, 0]
lon = lon[0]

In [None]:
# Start from a clean file; on a fresh run it does not exist yet.
if target_file_path.exists():
    target_file_path.unlink()

In [None]:
# Scratch single-file output, built interactively in the cells below.
root = netCDF4.Dataset(str(target_file_path), mode='w')

In [None]:
# First longitude of the grid, as a sanity check of the longitude convention.
lon[0]

In [None]:
# Inspect the still-empty dataset.
root

In [None]:
# One latitude entry per grid row.
root.createDimension('latitude', size=len(lat))

In [None]:
# One longitude entry per grid column.
root.createDimension('longitude', size=len(lon))

In [None]:
# Unlimited time dimension (netCDF4 treats size None and size 0 alike).
root.createDimension('time', size=None)

In [None]:
# NOTE(review): 81 is hard-coded; it presumably equals the number of forecast
# steps (n_step) -- TODO confirm, or use n_step directly.
root.createDimension('step', size=81)

In [None]:
# Trailing comma: ('latitude') alone would be a string, not a 1-tuple.
lat_var = root.createVariable('latitude', 'f4', dimensions=('latitude',))

In [None]:
# Write the latitude coordinate values.
lat_var[:] = lat

In [None]:
# Trailing comma: ('longitude') alone would be a string, not a 1-tuple.
lon_var = root.createVariable('longitude', 'f4', dimensions=('longitude',))

In [None]:
# Write the longitude coordinate values.
lon_var[:] = lon

In [None]:
# Copy every selected field as a 2-D (latitude, longitude) variable.
grib_file.seek(0)
for message in grib_file:
    if message.shortName in TO_EXTRACT or compound_lambda(message):
        if message.shortName in root.variables:
            # createVariable fails on a name already in use (e.g. a field on
            # several levels, or re-running this cell); keep the first one.
            continue
        var = root.createVariable(message.shortName, 'f4', dimensions=('latitude', 'longitude'))
        var[:] = message.values

In [None]:
# Inspect the dataset after adding the field variables.
root

In [None]:
# Close the scratch NetCDF file so its contents are written to disk.
root.close()

## Open using xarray

In [None]:
def nest_filenames(files):
    """Group NetCDF paths by forecast pass for a nested ``open_mfdataset``.

    Characters 5-14 of each stem ('gdps_<YYYYMMDDHH>_<step>') identify the
    pass. Paths keep their input order within a group, and the groups are
    ordered by pass string.

    Returns
    -------
    list of list of pathlib.Path
        One inner list per pass.
    """
    groups = {}
    for path in files:
        groups.setdefault(path.stem[5:15], []).append(path)
    return [groups[key] for key in sorted(groups)]

In [None]:
# All extracted NetCDF files, sorted by name (pass, then step).
nc_filenames = sorted(output_path.glob('*.nc'))

In [None]:
# Preview the first files; start at 0 so the very first file is included.
nc_filenames[:10]

In [None]:
# One inner list of files per forecast pass, passes in sorted order.
nested_nc = nest_filenames(nc_filenames)

In [None]:
# Outer list (passes) is concatenated along 'time', inner lists (steps)
# along 'step'.
gdps = xr.open_mfdataset(
    nested_nc,
    combine='nested',
    concat_dim=['time', 'step'],
    compat='no_conflicts',
    parallel=True,
)

In [None]:
# Inspect the combined dataset (time x step x latitude x longitude fields).
gdps

In [None]:
# 2t field for one pass and one step.
gdps['2t'].sel(time=datetime.datetime(2020, 7, 22, 12), step=15).plot()

In [None]:
# Time-mean 2t at step 0; select the variable and step first, then average.
gdps['2t'].sel(step=0).mean(dim='time').plot()