In [None]:
%load_ext autoreload
%autoreload 2

# Interpolate from netCDF

The purpose of this notebook is to perform the same interpolation that is found in 017-interpolate-at-stations.ipynb.
This time, however, we start from netcdf files generated with 019-prepare-dataset-pygrib.ipynb instead of raw grib files.
The hypothesis is that XArray will be much happier working from netCDF than from grib.

The plan is to
1. Load the netCDF datacube.
2. Load station coordinates.
3. Interpolate at stations.
4. Generate a dataset.

In [None]:
import dask
import dask.dataframe as dd
import dask.distributed
import dask_jobqueue
import datetime
import netCDF4
import numpy as np
import os
import pandas as pd
import pathlib
import pygrib
import pymongo
import seaborn as sns

import xarray as xr

In [None]:
# SLURM-backed dask cluster; the env_extra shell lines set up the conda
# environment used by the jobs.
cluster = dask_jobqueue.SLURMCluster(
    name='smc01-dask',
    env_extra=['source ~/.bash_profile', 'conda activate smc01'],
)

In [None]:
cluster.scale(jobs=8)

In [None]:
client = dask.distributed.Client(cluster)

In [None]:
client

# 1. Load the netCDF datacube

In [None]:
# Root of the project data. os.environ[...] fails fast with a KeyError naming
# the missing variable, whereas os.getenv would return None and make
# pathlib.Path raise an unrelated-looking TypeError.
DATA_DIR = pathlib.Path(os.environ['DATA_DIR'])
# Directory holding the netCDF files read by this notebook.
GDPS_DIR = DATA_DIR / '2021-02-10-one-month-more-vars/'

In [None]:
# GDPS_DIR is already a pathlib.Path (it is built with the `/` operator).
gdps_path = GDPS_DIR

In [None]:
# Every netCDF file of the GDPS directory, in sorted path order; preview the first ten.
gdps_files = sorted(gdps_path.glob("*.nc"))
gdps_files[:10]

In [None]:
# Group the files by pass identifier, i.e. characters 5 to 14 of the file stem.
passes = {}
for f in gdps_files:
    passes.setdefault(f.stem[5:15], []).append(f)

In [None]:
def nest_filenames(files):
    """Group files by pass and return the groups ordered by pass name.

    The pass name is read from characters 5 to 14 of each file's stem
    (``f.stem[5:15]``). Inside a group, files keep the order they had in
    ``files``.

    Returns a list of lists: one inner list per distinct pass name, with the
    outer list sorted by pass name.
    """
    grouped = {}
    for path in files:
        grouped.setdefault(path.stem[5:15], []).append(path)

    return [grouped[name] for name in sorted(grouped)]

In [None]:
nested_gdps = nest_filenames(gdps_files)

In [None]:
def drop_vars(dataset):
    """Return ``dataset`` without the 'r_850' and 'r_500' variables.

    Variables that are not present in ``dataset`` are ignored, so the
    function can be applied to files that carry only some (or none) of them.

    Uses ``Dataset.drop_vars`` rather than the deprecated ``Dataset.drop``.
    """
    to_drop = ['r_850', 'r_500']

    # errors='ignore' replaces the per-variable membership test.
    return dataset.drop_vars(to_drop, errors='ignore')

In [None]:
# The outer level of `nested_gdps` is concatenated along 'time' and the inner
# level along 'step'; every file goes through drop_vars before being combined.
gdps = xr.open_mfdataset(
    nested_gdps,
    preprocess=drop_vars,
    combine='nested',
    concat_dim=['time', 'step'],
    compat='no_conflicts',
    parallel=True,
)

In [None]:
gdps

In [None]:
# Forecast steps as hour-resolution timedeltas.
# No Python-datetime `times` list is built here: `times` is assigned from
# gdps.time.data before its first use, so such a list would be dead code.
deltas = gdps.step.data.astype('timedelta64[h]')

In [None]:
# Column vector of the `time` coordinate values, shape (n, 1), ready for broadcasting.
times = gdps.time.data.reshape(-1, 1)

In [None]:
# Row vector of the step offsets, shape (1, m), ready for broadcasting.
deltas = deltas.reshape(1, -1)

In [None]:
# Broadcast sum of the (n, 1) times and the (1, m) offsets: one entry per (time, step) pair.
valid_times = times + deltas

In [None]:
valid_times.shape

In [None]:
# Attach the 2-D array as a `valid_time` coordinate spanning the (time, step) dimensions.
gdps = gdps.assign_coords(valid_time=xr.DataArray(valid_times, dims=('time', 'step')))

In [None]:
gdps

# 2. Load station coordinates

In [None]:
# MongoDB connection settings.
MONGO_URL = 'localhost'
MONGO_PORT = 27017
# No credentials are configured.
USERNAME = None
PASSWORD = None
ADMIN_DB = 'admin'  # passed as authSource when the client is created
# Database and collection names.
DB = 'smc01_raw_obs_test'
COLLECTION = 'iem'

In [None]:
# Extremes of valid_time, read as nanoseconds since the epoch and converted
# to naive UTC datetimes (floored to whole seconds).
first_valid_ns = gdps.valid_time.min().data.item()
last_valid_ns = gdps.valid_time.max().data.item()

begin_date = datetime.datetime.utcfromtimestamp(first_valid_ns // 1e9)
end_date = datetime.datetime.utcfromtimestamp(last_valid_ns // 1e9)

In [None]:
begin_date

In [None]:
# Client for the MongoDB instance described by the connection constants.
mongo_client = pymongo.MongoClient(
    host=MONGO_URL,
    port=MONGO_PORT,
    username=USERNAME,
    password=PASSWORD,
    authSource=ADMIN_DB,
)

In [None]:
# Select the database through the DB constant instead of repeating its name here.
db = mongo_client[DB]

In [None]:
# Select the collection through the COLLECTION constant instead of repeating its name here.
collection = db[COLLECTION]

In [None]:
# Filter on documents whose `valid` timestamp lies in [begin_date, end_date).
query = {'valid': {'$gte': begin_date, '$lt': end_date}}

In [None]:
# Every distinct station identifier found in the collection.
# NOTE(review): the `query` time-window filter is not passed here, so stations
# without any observation in the window are listed too; confirm whether
# `collection.distinct('station', query)` was intended.
stations = collection.distinct('station')

In [None]:
# One record per station; the position fields are copied from whichever
# observation find_one returns for that station.
station_infos = []

for station in stations:
    one_obs = collection.find_one({'station': station})

    info = {'station': station}
    info.update({field: one_obs[field] for field in ('lat', 'lon', 'elevation')})
    station_infos.append(info)

In [None]:
station_df = pd.DataFrame(station_infos)

In [None]:
station_df

# 3. Interpolate at stations

In [None]:
# Both target coordinates share the 'station' dimension, so the result has one
# interpolated point per station rather than a latitude x longitude grid.
station_lats = xr.DataArray(station_df['lat'], dims='station')
station_lons = xr.DataArray(station_df['lon'], dims='station')

at_stations = gdps.interp(latitude=station_lats, longitude=station_lons)

In [None]:
at_stations

In [None]:
# Label the 'station' dimension with the station identifiers from station_df.
at_stations = at_stations.assign_coords(station=xr.DataArray(station_df['station'], dims='station'))

In [None]:
at_stations_compute = at_stations.compute()

In [None]:
at_stations_compute

In [None]:
# Save the computed station-interpolated dataset as a netCDF file under DATA_DIR.
at_stations_compute.to_netcdf(DATA_DIR / '2021-02-10-march-interpolated-at-stations.nc')

In [None]:
at_stations_compute.nbytes / 1024 / 1024

In [None]:
at_stations_compute.valid_time

In [None]:
groups = list(at_stations_compute.groupby('valid_time'))

In [None]:
groups[20]

# 4. Compute dataset

In [None]:
def pipeline_of_station(station_name, begin_date, end_date):
    """Build the MongoDB aggregation pipeline selecting one station's observations.

    The first stage exposes the minute and the hour of each document's
    ``valid`` timestamp. The second stage keeps only the documents that

    - belong to ``station_name``,
    - satisfy ``begin_date <= valid < end_date``,
    - fall on minute 0 of hours 0, 3, 6, ..., 21,
    - have a ``tmpf`` field.

    Returns the pipeline as a list of two stage dictionaries.
    """
    add_time_parts = {
        '$addFields': {
            'minute': {'$minute': '$valid'},
            'hour': {'$hour': '$valid'},
        }
    }

    keep_matching = {
        '$match': {
            'minute': 0,
            'station': station_name,
            'valid': {'$gte': begin_date, '$lt': end_date},
            'hour': {'$in': list(range(0, 24, 3))},
            'tmpf': {'$exists': True},
        }
    }

    return [add_time_parts, keep_matching]

In [None]:
mongo_obs_of_station = list(collection.aggregate(pipeline_of_station('CYUL', begin_date, end_date)))

In [None]:
mongo_obs_of_station[0]

In [None]:
def _observation_fields(obs):
    """Convert one observation document into the ``obs_*`` report fields.

    ``tmpf`` is required. ``dwpt``, ``sknt``, ``mslp``, ``drct`` and ``relh``
    are optional and yield NaN when the key is absent from ``obs``.
    ``tmpf`` and ``dwpt`` go through the Fahrenheit-to-Celsius formula,
    ``sknt`` is divided by 1.94384, and the other values are copied unchanged.
    """
    def fahrenheit_to_celsius(value):
        return (value - 32) * (5/9)

    def optional(key, convert=lambda value: value):
        return convert(obs[key]) if key in obs else np.nan

    return {
        'obs_2t': fahrenheit_to_celsius(obs['tmpf']),
        'obs_2d': optional('dwpt', fahrenheit_to_celsius),
        'obs_10si': optional('sknt', lambda knots: knots / 1.94384),
        'obs_prmsl': optional('mslp'),
        'obs_10wdir': optional('drct'),
        'obs_2r': optional('relh'),
    }


def compute_reports_of_station(station_name, begin_date, end_date, model_at_station):
    """Pair the observations of one station with the model values valid at the same time.

    Observations are fetched from the module-level ``collection`` with
    ``pipeline_of_station(station_name, begin_date, end_date)``. For each
    observation whose ``valid`` time matches a ``valid_time`` of
    ``model_at_station``, one report is produced per (time, step) entry of the
    model sharing that valid time.

    Parameters:
        station_name: station identifier used to query the observations.
        begin_date, end_date: bounds of the observation window (``end_date``
            excluded).
        model_at_station: model dataset for that station, with a
            ``valid_time`` coordinate and the variables ``2t``, ``2d``,
            ``prmsl``, ``10si``, ``10wdir``, ``2r``, ``hpbl`` and ``prate``.

    Returns:
        A pandas DataFrame with one row per report. It has no columns when no
        observation matches.
    """
    by_valid = {valid_time: group for valid_time, group in model_at_station.groupby('valid_time')}

    station_obs = list(
        collection.aggregate(
            pipeline_of_station(station_name, begin_date, end_date)))

    reports = []

    for obs in station_obs:
        obs_time = np.datetime64(obs['valid'], 'ns')
        if obs_time not in by_valid:
            continue

        group_of_time = by_valid[obs_time]

        # The observation-side values are the same for every model run that
        # is valid at this time, so convert them once per observation.
        obs_fields = _observation_fields(obs)

        for i in range(len(group_of_time.stacked_time_step)):
            date = datetime.datetime.utcfromtimestamp(group_of_time.time[i].item() / 1e9)
            step = datetime.timedelta(hours=group_of_time.step[i].item())

            # Key order defines the column order of the resulting DataFrame:
            # keep it stable.
            report = {
                'station': obs['station'],
                'valid': date + step,
                'lat': obs['lat'],
                'lon': obs['lon'],
                'elevation': obs['elevation'],
                'obs_2t': obs_fields['obs_2t'],
                'date': date,
                'step': step,
                # 2d and 2t are shifted by -273.15 (Kelvin-to-Celsius offset).
                'gdps_2d': group_of_time['2d'][i].item() - 273.15,
                'gdps_2t': group_of_time['2t'][i].item() - 273.15,
                'obs_2d': obs_fields['obs_2d'],
                'obs_10si': obs_fields['obs_10si'],
                'obs_prmsl': obs_fields['obs_prmsl'],
                'obs_10wdir': obs_fields['obs_10wdir'],
                'obs_2r': obs_fields['obs_2r'],
                'gdps_prmsl': group_of_time['prmsl'][i].item() / 100.,
            }

            for key in ['10si', '10wdir', '2r', 'hpbl', 'prate']:
                report['gdps_' + key] = group_of_time[key][i].item()

            reports.append(report)

    return pd.DataFrame(reports)

In [None]:
station_name = 'CYUL'
model_at_station = at_stations_compute.sel(station=station_name)

reports = compute_reports_of_station(station_name, begin_date, end_date, model_at_station)

In [None]:
reports['step'].value_counts()

In [None]:
reports

In [None]:
reports['2t_square_error'] = (reports['gdps_2t'] - reports['obs_2t'])**2

In [None]:
# `reports` also holds string and datetime columns; restrict the mean to the
# numeric ones so the aggregation does not fail on them.
by_step = reports.groupby('step').mean(numeric_only=True)

In [None]:
sns.lineplot(x='step', y='2t_square_error', data=reports)

In [None]:
# First raw observation document fetched for CYUL.
mongo_obs_of_station[0]

In [None]:
reports

In [None]:
# `by_valid` is local to compute_reports_of_station, so build the same mapping
# here before looking up the model data valid at one instant.
by_valid = {valid_time: group for valid_time, group in model_at_station.groupby('valid_time')}
by_valid[np.datetime64('2020-03-01T00:00:00.000000000')]

## 4.2 Compute reports in parallel

In [None]:
# Raw CYUL observations (already fetched into mongo_obs_of_station) and the
# model data interpolated at the same station.
station_obs = mongo_obs_of_station
model_at_station = at_stations_compute.sel(station='CYUL')

In [None]:
compute_reports_delayed = dask.delayed(compute_reports_of_station)

In [None]:
def batch(iterable, n=1):
    """Yield successive slices of at most ``n`` items from ``iterable``.

    ``iterable`` must support ``len()`` and slicing. The last slice may be
    shorter than ``n``.
    """
    total = len(iterable)
    yield from (iterable[start:min(start + n, total)] for start in range(0, total, n))

In [None]:
# One datetime per whole day contained in end_date - begin_date, starting at begin_date.
days = [begin_date + datetime.timedelta(days=i) for i in range((end_date - begin_date).days)]

In [None]:
(end_date - begin_date).days

In [None]:
# One delayed report computation per (station, batch of days) pair.
BATCH_DAYS = 10
batch_length = datetime.timedelta(days=BATCH_DAYS)

# Ceiling division of the covered period by the batch length. A floor division
# would silently skip the trailing partial batch, and would produce no batch
# at all for a period shorter than BATCH_DAYS.
n_batches = -((begin_date - end_date) // batch_length)

# NOTE(review): a batch without any matching observation yields a DataFrame
# with no columns, which dd.from_delayed(verify_meta=True) may reject; confirm.
delayeds = []
for s in station_df['station']:
    # The selection depends on the station only: do it once per station.
    model_at_station = at_stations_compute.sel(station=s)

    for i in range(n_batches):
        begin_batch = begin_date + i * batch_length
        # The last batch is clipped to the end of the covered period.
        end_batch = min(begin_batch + batch_length, end_date)

        delayed = compute_reports_delayed(s, begin_batch, end_batch, model_at_station)
        delayeds.append(delayed)

In [None]:
model_at_station

In [None]:
begin_batch

In [None]:
# Keep the entries whose `time` is at or after begin_batch; the others are masked.
model_at_station.where(model_at_station.time >= np.datetime64(begin_batch))

In [None]:
len(delayeds)

In [None]:
delayeds[0:10]

In [None]:
# Reference frame passed as `meta` to dd.from_delayed. The model data and the
# observations must come from the same station.
sample_station = 'CYUL'
model_at_station = at_stations_compute.sel(station=sample_station)
sample = compute_reports_of_station(sample_station, begin_date, end_date, model_at_station)

In [None]:
sample

In [None]:
big_df = dd.from_delayed(delayeds, meta=sample, verify_meta=True)

In [None]:
big_df

In [None]:
big_df = big_df.persist()

In [None]:
# Replace the timedelta `step` column by a number of hours.
# Not idempotent: once `step` is numeric, re-running this cell fails on the `.dt` accessor.
big_df['step'] = big_df['step'].dt.total_seconds() / 3600

In [None]:
big_df.to_parquet(DATA_DIR / 'hdd_scratch/smc01/march.parquet')

In [None]:
big_df.npartitions

In [None]:
big_df = big_df.repartition(200)

In [None]:
big_df_compute = big_df.compute()

In [None]:
big_df_compute

In [None]:
big_df['2t_squared_error'] = (big_df['gdps_2t'] - big_df['obs_2t'])**2
big_df['2r_squared_error'] = (big_df['gdps_2r'] - big_df['obs_2r'])**2

In [None]:
by_step = big_df.groupby('step')

In [None]:
# Average only the squared-error columns used for the RMSE. A mean over every
# column would also try to aggregate the string and datetime columns of big_df.
by_step_compute = by_step[['2t_squared_error', '2r_squared_error']].mean()

In [None]:
# RMSE per step: square root of the per-step mean of the squared errors.
by_step_compute['2t_rmse'] = np.sqrt(by_step_compute['2t_squared_error'])
by_step_compute['2r_rmse'] = np.sqrt(by_step_compute['2r_squared_error'])

In [None]:
by_step_compute = by_step_compute.compute()

In [None]:
sns.lineplot(x='step', y='2r_rmse', data=by_step_compute.iloc[0:20])