In [None]:
# Automatically re-import edited Python modules before each cell is executed.
# NOTE(review): no project-local modules are imported in this notebook, so this
# currently has no visible effect.
%load_ext autoreload
%autoreload 2

# GDPS Error by first interpolating the model at station coordinates

The purpose of this notebook is to try out a new strategy for computing the error of a model's output.
The idea is as follows:

1. Load the model array
2. Fetch all station coordinates
3. Interpolate the model at station coordinates
4. Flatten the interpolated values into a pandas DataFrame (or a Dask DataFrame if necessary).

The resulting DataFrame should have columns such as:
- model
- step
- time
- t2m_gdps
- station
- lat
- lon
- t2m_obs

Then it should be easy to compute the error and draw conclusions.

In [None]:
import dask
import dask.array as da
import dask.bag as db
import dask_jobqueue
import dask.distributed
import datetime
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import pandas as pd
import pymongo
import seaborn as sns
import time
import xarray as xr

In [None]:
# Root of the data tree. Fail early with a clear message instead of the opaque
# TypeError that pathlib.Path(None) raises when the variable is unset.
_data_dir = os.getenv('DATA_DIR')
if _data_dir is None:
    raise RuntimeError('Set the DATA_DIR environment variable to the root data directory.')
DATA_DIR = pathlib.Path(_data_dir)
GDPS_DIR = DATA_DIR / 'data/2021-02-02-one-week-sample/'

# MongoDB instance holding the raw station observations.
MONGO_URL = 'localhost'
MONGO_PORT = 27017
# Credentials are read from the environment so they never get written into the
# notebook; both default to None (no authentication), as before.
USERNAME = os.getenv('MONGO_USERNAME')
PASSWORD = os.getenv('MONGO_PASSWORD')
ADMIN_DB = 'admin'
DB = 'smc01_raw_obs_test'
COLLECTION = 'iem'

In [None]:
cluster = dask_jobqueue.SLURMCluster(
    env_extra=[
        'source ~/.bash_profile','conda activate smc01'],
    name='smc01-dask',
)

In [None]:
cluster.scale(jobs=2)

In [None]:
client = dask.distributed.Client(cluster)

In [None]:
client

# 1. Load model array

In [None]:
gdps_files = sorted(list(pathlib.Path(GDPS_DIR).glob('CMC_glb_latlon.24x.24_*.grib2')))

In [None]:
gdps_files[0:10]

In [None]:
def nest_filenames(files, pass_slice=slice(22, 32)):
    """Group GRIB files by forecast pass for ``xr.open_mfdataset(combine='nested')``.

    The pass identifier of a file is ``f.stem[pass_slice]``. With the default
    slice and the 22-character ``CMC_glb_latlon.24x.24_`` prefix, it is the
    10 characters right after the prefix (presumably ``YYYYMMDDHH``; TODO
    confirm against the file naming convention).

    Args:
        files: iterable of ``pathlib.Path`` objects.
        pass_slice: slice of the file stem that identifies the pass.

    Returns:
        A list of lists of paths. The outer list is ordered by pass identifier.
        Each inner list is sorted by path, so that steps line up across passes
        regardless of the order of ``files``.
    """
    passes = {}
    for f in files:
        passes.setdefault(f.stem[pass_slice], []).append(f)

    # Sort inside each pass as well: nested concatenation along 'step' needs a
    # consistent order, which previously relied on the caller pre-sorting.
    return [sorted(passes[key]) for key in sorted(passes)]

In [None]:
nested_files = nest_filenames(gdps_files)

In [None]:
gdps = xr.open_mfdataset(
    nested_files, engine='cfgrib', concat_dim=['time', 'step'], 
    combine='nested', parallel=True, compat='no_conflicts',
    backend_kwargs={'filter_by_keys': {
        'typeOfLevel': 'heightAboveGround',
        'stepType': 'instant',
}})

In [None]:
#gdps = gdps.persist()

In [None]:
gdps

# 2. Fetch station coordinates

In [None]:
begin_date = gdps.valid_time.min().data.item()
begin_date = datetime.datetime.utcfromtimestamp(begin_date // 1e9)

In [None]:
end_date = gdps.valid_time.max().data.item()
end_date = datetime.datetime.utcfromtimestamp(end_date // 1e9)

In [None]:
mongo_client = pymongo.MongoClient(host=MONGO_URL, port=MONGO_PORT, username=USERNAME, password=PASSWORD, authSource=ADMIN_DB)

In [None]:
db = mongo_client.smc01_raw_obs_test

In [None]:
collection = db.iem

In [None]:
query = {
    'valid': {
        '$gte': begin_date + datetime.timedelta(days=1),
        '$lt': end_date
}}

In [None]:
stations = collection.distinct('station')

In [None]:
station_infos = []

for station in stations:
    one_obs = collection.find_one({'station': station})
    station_infos.append({
        'station': station,
        'lat': one_obs['lat'],
        'lon': one_obs['lon']
    })

In [None]:
station_df = pd.DataFrame(station_infos)

In [None]:
station_df

# 3. Interpolate model at stations

In [None]:
gdps.t2m.data

In [None]:
at_stations = gdps.interp({
    'latitude': xr.DataArray(station_df['lat'], dims='station'),
    'longitude': xr.DataArray(station_df['lon'], dims='station'),
})

In [None]:
at_stations.t2m.data

In [None]:
at_stations = at_stations.compute()

# 4. Fetch observations and find appropriate model output

In [None]:
pipeline = [
    {
        '$addFields': {
            'minute': {
                '$minute': '$valid'
            },
            'hour': {
                '$hour': '$valid'
            }
        }
    },
    {
        '$match': {
            'minute': 0,
            'valid': {
                '$gte': begin_date,
                '$lt': end_date
            },
            'hour': {
                '$in': [0, 3, 6, 9, 12, 15, 18, 21]
            },
            'tmpf': {
                '$exists': True,
            }
        }
    },
    {
        '$group': {
            '_id': "$station",
            'obs': {
                '$push': '$$ROOT'
            }
        }
    }
]

In [None]:
mongo_obs_by_station = list(collection.aggregate(pipeline))

In [None]:
collection.find_one({'valid': begin_date})

In [None]:
observations_by_station = {d['_id']: d['obs'] for d in mongo_obs_by_station}

In [None]:
observations_by_station['CYUL']

In [None]:
for station in observations_by_station:
    for obs in observations_by_station[station]:
        if 'tmpf' not in obs:
            print(obs)

In [None]:
at_stations.t2m

In [None]:
def compute_station_reports(observations, at_station):
    """Pair each observation with every forecast valid at the same time.

    Args:
        observations: list of observation dicts for one station. Reads the
            keys 'valid', 'tmpf' (Fahrenheit), 'station', 'lat', 'lon' and
            'elevation'.
        at_station: model output at this station. It must support
            ``groupby('valid_time')``, and each group must expose ``time``,
            ``step`` and ``t2m`` (Kelvin) as 1-D sequences of equal length.

    Returns:
        A list of report dicts, one per (observation, forecast) pair sharing a
        valid time. Temperatures are in degrees Celsius; 'date' is a
        ``datetime.datetime`` and 'step' a ``datetime.timedelta``.
    """
    groups = dict(at_station.groupby('valid_time'))

    reports = []
    for obs in observations:
        group = groups.get(np.datetime64(obs['valid'], 'ns'))
        if group is None:
            continue

        obs_temp = (obs['tmpf'] - 32) * (5 / 9)  # Fahrenheit -> Celsius

        # Extract plain NumPy arrays once per group and zip over them. This
        # avoids relying on the name of the stacked dimension (which varies
        # across xarray versions) and on datetime64 values being integer ns.
        run_dates = np.asarray(group.time)
        lead_times = np.asarray(group.step)
        temps_kelvin = np.asarray(group.t2m)

        for run_date, lead_time, temp_kelvin in zip(run_dates, lead_times, temps_kelvin):
            reports.append({
                'station': obs['station'],
                'date': pd.Timestamp(run_date).to_pydatetime(),
                'step': pd.Timedelta(lead_time).to_pytimedelta(),
                'gdps_temp': float(temp_kelvin) - 273.15,  # Kelvin -> Celsius
                'obs_temp': obs_temp,
                'lat': obs['lat'],
                'lon': obs['lon'],
                'elevation': obs['elevation']
            })

    return reports

In [None]:
# Limit on how many stations are processed; set to None to process all of them.
MAX_STATIONS = 30

delayed_reports = []
compute_delayed = dask.delayed(compute_station_reports)
# To debug without Dask, use `compute_delayed = compute_station_reports` instead.

for i, station in enumerate(stations[:MAX_STATIONS]):
    if station not in observations_by_station:
        continue

    # Positional selection: the 'station' dimension of at_stations follows the
    # row order of station_df, which was built by iterating over `stations`.
    at_station = at_stations.isel(station=i)

    # Grouping by valid time happens inside compute_station_reports; the copy
    # previously built here in the driver was never used.
    delayed_reports.append(
        compute_delayed(observations_by_station[station], at_station)
    )

In [None]:
delayed_reports[0:10]

In [None]:
begin_compute = time.time()
reports = dask.compute(*delayed_reports)
print('Reports took {} to compute'.format(time.time() - begin_compute))

In [None]:
reports = [report for station_reports in reports for report in station_reports]

In [None]:
reports[-1]

# 5. Analyze

In [None]:
reports[0:10]

In [None]:
df = pd.DataFrame(reports)
df['hours'] = df['step'].dt.total_seconds() / 3600
df['pass'] = df['date'].dt.hour

In [None]:
df['error'] = df['gdps_temp'] - df['obs_temp']
df['squared_error'] = df['error']**2
df['abs_error'] = df['error'].abs()

In [None]:
df[(df['date'] == datetime.datetime(2020, 7, 20, 0)) & (df['hours'] < 48)].groupby('hours').mean()['abs_error'].plot()

In [None]:
sns.lineplot(x='hours', y='abs_error', data=df)

In [None]:
sns.boxplot(x='pass', y='abs_error', data=df, showfliers=False)

In [None]:
df.groupby('station').mean()

In [None]:
sns.boxplot(x='lat', y='abs_error', data=df)