# GRIB to XArray

The purpose of this notebook is to read a bunch of GRIB files with PyGrib, and directly build an XArray dataset.
The motivation is to avoid having an intermediary file format like NetCDF, and consequently save a lot on IO.

It would be nice to do it using the cfgrib engine for XArray, but cfgrib makes it impractical to open all the fields we want, because it doesn't allow us to open multiple fields with different vertical levels.

In [None]:
%load_ext autoreload
%autoreload 2

The overall strategy is to:

1. Read and filter a GRIB file
2. Put the fields we want in an XArray Dataset
3. Read many grib files in parallel.
4. Merge the XArray datasets from different files.

## 1. Read and filter a GRIB file

In [None]:
import datetime
import multiprocessing
import numpy as np
import os
import pathlib
import pandas as pd
import pygrib
from tqdm.notebook import tqdm
import xarray as xr

In [None]:
# Root of the local data tree, read from the DATA_DIR environment variable.
# NOTE(review): os.getenv returns None when the variable is unset, in which
# case pathlib.Path raises a TypeError here — export DATA_DIR before running.
DATA_DIR = pathlib.Path(os.getenv('DATA_DIR'))
# Directory that the cells below glob for '*.grib2' files.
GRIB_INPUT_DIR = DATA_DIR / 'data/gdps/2020020112'

In [None]:
def from_short_name(message):
    """Label a GRIB message by its short name when that field is wanted.

    Args:
        message: Object exposing a ``shortName`` attribute.

    Returns:
        ``message.shortName`` if it is one of the short names listed below,
        otherwise ``False``.
    """
    wanted = (
        'al',
        'hpbl',
        'prate',
        'prmsl',
        'thick',
        '10si',
        '10wdir',
        '10u',
        '10v',
        '2d',
        '2r',
        '2t',
    )
    short_name = message.shortName
    return short_name if short_name in wanted else False

In [None]:
class ShortNameLevelExtractor:
    """Callable that labels messages of one short name at selected levels.

    Calling an instance with a message returns ``'<name>_<level>'`` when the
    message's ``shortName`` equals ``name`` and its ``level`` equals one of
    ``levels``; otherwise it returns ``False``.
    """

    def __init__(self, name, levels):
        self.name, self.levels = name, levels

    def __call__(self, message):
        if message.shortName != self.name:
            return False
        # Lazily build the label for each matching level; the first one wins.
        labels = (
            f'{self.name}_{candidate}'
            for candidate in self.levels
            if message.level == candidate
        )
        return next(labels, False)

In [None]:
class CompositeExtractor:
    """Try several extractors in order and answer with the first truthy label.

    Each extractor is called with the message; extractors after the first
    truthy result are not called. Returns ``False`` when none of them
    produces a truthy value.
    """

    def __init__(self, extractors):
        self.extractors = extractors

    def __call__(self, message):
        # Generator keeps evaluation lazy so later extractors are skipped
        # once one of them has produced a label.
        candidates = (extract(message) for extract in self.extractors)
        return next(filter(None, candidates), False)

In [None]:
# Combined selector used by the cells below: a message is kept when it matches
# the short-name list in from_short_name, or one of the (short name, levels)
# combinations listed here. The first extractor that returns a label wins.
extractor = CompositeExtractor([
    from_short_name,
    ShortNameLevelExtractor('t', [850, 500]),
    ShortNameLevelExtractor('gh', [1000, 850, 500]),
    ShortNameLevelExtractor('q', [850, 500]),
    ShortNameLevelExtractor('u', [500]),
    ShortNameLevelExtractor('v', [500]),
])

In [None]:
def extract_fields(grib_iter, extractor):
    """Collect the messages selected by ``extractor``, keyed by their label.

    Args:
        grib_iter: Iterable of messages.
        extractor: Callable returning a truthy label for messages to keep and
            a falsy value for messages to drop.

    Returns:
        Dict mapping each label to a message. When several messages share a
        label, the last one encountered is kept.
    """
    labelled = ((extractor(message), message) for message in grib_iter)
    return {label: message for label, message in labelled if label}

In [None]:
def do_one_file(grib_file_path, extractor):
    """Open one GRIB file and convert its selected messages with file_to_xarray.

    Args:
        grib_file_path: Path of the GRIB file; converted with ``str`` before
            being handed to ``pygrib.open``.
        extractor: Callable forwarded to ``file_to_xarray`` to select and
            label messages.

    Returns:
        Whatever ``file_to_xarray`` returns for this file.
    """
    print(grib_file_path)
    gribfile = pygrib.open(str(grib_file_path))
    try:
        return file_to_xarray(gribfile, extractor)
    finally:
        # Close the handle even when the conversion raises, so a bad file
        # does not leak an open GRIB handle in the worker process.
        gribfile.close()

In [None]:
def pass_to_xarray(pass_dir, extractor, max_files=8):
    """Read the GRIB2 files of one directory in parallel and concatenate them.

    Args:
        pass_dir: Directory searched (non-recursively) for ``*.grib2`` files.
        extractor: Callable forwarded to ``do_one_file`` to select and label
            messages.
        max_files: Maximum number of files to read, taken from the start of
            the name-sorted list. ``None`` reads every file. Defaults to 8,
            the limit this function previously hard-coded.

    Returns:
        The per-file datasets concatenated along the ``'step'`` dimension.

    Raises:
        ValueError: If ``pass_dir`` contains no ``*.grib2`` file.
    """
    # Use the directory the caller asked for rather than a module-level global.
    input_files = sorted(pathlib.Path(pass_dir).glob('*.grib2'))
    if max_files is not None:
        input_files = input_files[:max_files]

    if not input_files:
        raise ValueError(f'No .grib2 files found in {pass_dir}')

    with multiprocessing.Pool() as pool:
        datasets = pool.starmap(
            do_one_file, [(f, extractor) for f in input_files]
        )

    return xr.concat(datasets, dim='step')

In [None]:
def file_to_xarray(grib_file, extractor):
    """Build an xarray Dataset from the selected messages of one GRIB file.

    The first message is used as a sample for the grid, the reference time
    and the forecast step, which are then applied to every selected message.
    NOTE(review): this assumes all messages in the file share the same grid,
    date/time and step — confirm for the files being read.

    Args:
        grib_file: Iterable of GRIB messages (e.g. an open ``pygrib`` file).
        extractor: Callable returning a truthy label for messages to keep and
            a falsy value for messages to skip.

    Returns:
        ``xr.Dataset`` with one variable per label.

    Raises:
        ValueError: If the file has no messages, or if the sample message's
            ``stepUnits`` is not 1.
    """
    from itertools import chain

    messages = iter(grib_file)
    try:
        sample = next(messages)
    except StopIteration:
        raise ValueError('GRIB file contains no messages') from None

    lats, lons = sample.latlons()
    # Reduce the 2-D coordinate arrays to 1-D axes.
    # NOTE(review): assumes a regular grid where latitude varies along axis 0
    # and longitude along axis 1 — confirm.
    lats, lons = lats[:, 0], lons[0]

    # Named init_time so the imported ``datetime`` module is not shadowed.
    init_time = [grib_date_to_pandas(sample.dataDate, sample.dataTime)]

    # Only stepUnits == 1 is handled; the step is then interpreted as hours.
    if sample.stepUnits != 1:
        raise ValueError('Unhandled step units')
    step = [pd.Timedelta(sample.step, 'h')]

    fields = {}
    # Reading the sample consumed the first message of a one-shot iterator,
    # so put it back in front: otherwise it would never reach the extractor.
    for msg in chain([sample], messages):
        label = extractor(msg)
        if label:
            fields[label] = message_to_xarray(msg, lats, lons, step, init_time)

    return xr.Dataset(fields)

In [None]:
def message_to_xarray(grib_message, lats, lons, step, datetime):
    """Wrap the values of one GRIB message in a 4-D ``xr.DataArray``.

    Args:
        grib_message: Object exposing ``values`` (2-D array indexed as
            lat, lon) and ``units``.
        lats: Coordinate values for the ``'lat'`` dimension.
        lons: Coordinate values for the ``'lon'`` dimension.
        step: Length-1 sequence used as the ``'step'`` coordinate.
        datetime: Length-1 sequence used as the ``'datetime'`` coordinate.
            (The name shadows the ``datetime`` module inside this function;
            it is kept for compatibility with existing callers.)

    Returns:
        ``xr.DataArray`` of dtype float32 with dims
        ``('datetime', 'step', 'lat', 'lon')`` and a ``units`` attribute.
        Masked (missing) grid points are returned as NaN.
    """
    raw = grib_message.values
    if np.ma.isMaskedArray(raw):
        # np.array() on a masked array silently discards the mask and keeps
        # the underlying values, so fill missing points with NaN first.
        raw = raw.astype(np.float32).filled(np.nan)

    # Prepend two length-1 axes for the 'datetime' and 'step' dimensions.
    values = np.array(raw, dtype=np.float32)[np.newaxis, np.newaxis]

    da = xr.DataArray(
        values,
        dims=['datetime', 'step', 'lat', 'lon'],
        coords={'lat': lats, 'lon': lons, 'datetime': datetime, 'step': step},
    )

    da.attrs['units'] = grib_message.units

    return da

In [None]:
def grib_date_to_pandas(date, time):
    """Convert a GRIB date/time pair to a ``pd.Timestamp``.

    Args:
        date: Date encoded as YYYYMMDD, as an int or a digit string.
        time: Time encoded as HHMM, as an int or a digit string. Integer
            values lose their leading zeros (0 for 00:00, 600 for 06:00),
            so the value is decoded arithmetically rather than by slicing
            its string form.

    Returns:
        Timezone-naive ``pd.Timestamp`` with minute resolution.
    """
    year, month_day = divmod(int(date), 10_000)
    month, day = divmod(month_day, 100)
    hour, minute = divmod(int(time), 100)
    return pd.Timestamp(
        year=year, month=month, day=day, hour=hour, minute=minute
    )

In [None]:
# Read the GRIB files of GRIB_INPUT_DIR in parallel and concatenate them.
big_dataset = pass_to_xarray(GRIB_INPUT_DIR, extractor)

In [None]:
big_dataset

In [None]:
# Attach a 'valid' coordinate computed as datetime + step.
big_dataset = big_dataset.assign_coords(valid=lambda x: x.datetime + x.step)

In [None]:
# Sequential (single-process) read of a few files, for comparison with the
# parallel path.
input_path = pathlib.Path(GRIB_INPUT_DIR)
input_files = sorted(input_path.glob('*.grib2'))

datasets = []

for f in input_files[8:12]:
    gribfile = pygrib.open(str(f))
    try:
        dataset = file_to_xarray(gribfile, extractor)
    finally:
        # Release each GRIB handle as soon as its file has been converted.
        gribfile.close()
    datasets.append(dataset)

In [None]:
# Combine the per-file datasets read above into a single dataset.
d = xr.merge(datasets)

In [None]:
d

In [None]:
# Plot the '2t' variable at the first index along 'step'.
d.isel(step=0)['2t'].plot()

In [None]:
# Plot the 'thick' variable at the second index along 'step'.
d.isel(step=1)['thick'].plot()

In [None]:
def sizeof_dataset(dataset):
    """Return the total size in bytes of every variable in ``dataset``.

    For each name in ``dataset.variables`` the array is looked up with
    ``dataset[name]`` and its element count is multiplied by its dtype's
    item size; the products are summed.
    """
    arrays = (dataset[name] for name in dataset.variables)
    return sum(arr.size * arr.dtype.itemsize for arr in arrays)

In [None]:
# Size of the merged dataset, converted from bytes by dividing by 1024**2.
sizeof_dataset(d) / 1024**2 

In [None]:
d.sizes

In [None]:
# Open the first input file for interactive inspection.
# NOTE(review): this handle is not closed in any of the visible cells.
gribfile = pygrib.open(str(input_files[0]))

In [None]:
# NOTE(review): `data_arrays` is not defined in any visible cell of this
# notebook, so this cell raises NameError on a fresh run — presumably a
# leftover from an earlier draft; confirm before relying on it.
dataset = xr.Dataset(data_vars=data_arrays)

In [None]:
dataset

## 3. Read many grib files in parallel. 

In [None]:
import os
import multiprocessing

In [None]:
os.getenv('SLURM_CPUS_PER_TASK')

In [None]:
len(input_files)

In [None]:
def do_one_file(path):
    """Open one GRIB file and convert it with the module-level ``extractor``.

    NOTE: an earlier cell defines ``do_one_file(grib_file_path, extractor)``,
    which ``pass_to_xarray`` calls with two arguments. Running this cell
    rebinds the name to this one-argument version.

    Args:
        path: Path of the GRIB file; converted with ``str`` before being
            handed to ``pygrib.open``.

    Returns:
        Whatever ``file_to_xarray`` returns for this file.
    """
    gribfile = pygrib.open(str(path))
    try:
        return file_to_xarray(gribfile, extractor)
    finally:
        # Close the handle even on failure so workers do not leak open files.
        gribfile.close()

In [None]:
# Convert the first 20 files with 4 worker processes, showing a progress bar
# as results arrive in input order.
in_list = input_files[:20]

with multiprocessing.Pool(processes=4) as pool:
    datasets = [
        ds for ds in tqdm(pool.imap(do_one_file, in_list), total=len(in_list))
    ]

## 4. Interpolate at stations

See notebooks 1901 and 1902.


In [None]:
import pymongo

from smc01.interpolate.obs import MongoIEMDatabase

In [None]:
# Connection settings consumed by pymongo.MongoClient and MongoIEMDatabase in
# the following cells.
MONGO_URL = 'localhost'
MONGO_PORT = 27017
# No credentials are supplied.
USERNAME = None
PASSWORD = None
ADMIN_DB = 'admin'  # passed as authSource
COLLECTION = 'iem'
DB = 'smc01_raw_obs'

In [None]:
# Create the MongoDB client from the connection constants defined above.
client = pymongo.MongoClient(
    host=MONGO_URL,
    port=MONGO_PORT,
    tz_aware=True,
    authSource=ADMIN_DB,
    username=USERNAME,
    password=PASSWORD,
)

In [None]:
# Wrap the client in the project's observation-database helper, pointing it
# at the DB database and COLLECTION collection.
db = MongoIEMDatabase(client, DB, COLLECTION)

In [None]:
def _ns_to_naive_utc(nanoseconds):
    """Convert integer nanoseconds since the Unix epoch to a naive UTC datetime.

    The value is truncated to whole seconds. ``datetime.utcfromtimestamp`` is
    deprecated, so the conversion goes through an aware UTC datetime and then
    drops the tzinfo to return the same naive value.
    """
    seconds = nanoseconds // 10**9
    aware = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
    return aware.replace(tzinfo=None)


# Earliest and latest valid times covered by the dataset.
# NOTE(review): assumes `.data.item()` yields integer nanoseconds, i.e. that
# `valid` is stored as datetime64[ns] — confirm.
begin_date = _ns_to_naive_utc(big_dataset.valid.min().data.item())
end_date = _ns_to_naive_utc(big_dataset.valid.max().data.item())

In [None]:
# Station metadata; the cells below read its "lat", "lon" and "station" entries.
station_info = db.station_info()

In [None]:
station_info

In [None]:
# Interpolate every field at the station locations: lat and lon are both
# indexed by a shared "station" dimension, so the result has one point per
# station instead of a lat/lon grid.
at_stations = big_dataset.interp(
    lat=xr.DataArray(station_info["lat"], dims="station"),
    lon=xr.DataArray(station_info["lon"], dims="station"),
)

In [None]:
big_dataset

In [None]:
at_stations

In [None]:
at_stations.station

In [None]:
# Label the "station" dimension with the station identifiers.
at_stations = at_stations.assign_coords(
    station=xr.DataArray(station_info["station"], dims="station")
)

In [None]:
at_stations.station[0].item()

In [None]:
at_stations

In [None]:
# Map each distinct valid time to the subset of forecasts valid at that time.
by_valid = dict(iter(at_stations.groupby("valid")))

In [None]:
# Take the first key of by_valid as an example.
one_valid_key = next(iter(by_valid))

In [None]:
# Group of forecasts that share the example valid time.
one_valid = by_valid[one_valid_key]

In [None]:
one_valid

In [None]:
# Select the first entry along the stacked (datetime, step) dimension.
one_step = one_valid.isel(stacked_datetime_step=0)

In [None]:
# Unpack the stacked index entry into its two components
# (presumably the datetime and the step, in that order — confirm).
date, step = one_step.stacked_datetime_step.item()

In [None]:
type(date)

In [None]:
step

In [None]:
step.total_seconds() / 3600

In [None]:
# Convert `date` through pd.to_datetime and rebuild `step` as a
# standard-library timedelta with the same duration expressed in hours.
# NOTE(review): unit="s" only matters if `date` is numeric — confirm its type.
date = pd.to_datetime(date, unit="s")
step = datetime.timedelta(hours=step.total_seconds() / 3600)

In [None]:
date

In [None]:
step

In [None]:
date

In [None]:
date.to_pydatetime()

In [None]:
big_dataset

In [None]:
big_dataset.data_vars.keys()

In [None]:
# Print the name of each data variable, one per line.
for v in big_dataset.data_vars:
    print(v)