# ABC paper: Neural network baseline (Fan et al., 2021)

Learned correction for SubX ensemble forecasts

In [None]:
import os, sys
from subseasonal_toolkit.utils.notebook_util import isnotebook
if isnotebook():
    # Autoreload packages that are modified
    %load_ext autoreload
    %autoreload 2
    #%cd "~/forecast_rodeo_ii"
    #%pwd
else:
    from argparse import ArgumentParser
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist, euclidean
from datetime import datetime, timedelta
from filelock import FileLock
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from IPython.display import Markdown, display
from subseasonal_data.utils import get_measurement_variable
from subseasonal_toolkit.utils.general_util import printf, tic, toc
from subseasonal_toolkit.utils.experiments_util import (get_first_year, get_start_delta,
                                                        get_forecast_delta, pandas2hdf)
from subseasonal_toolkit.utils.models_util import (get_submodel_name, start_logger, log_params, get_forecast_filename,
                                                   save_forecasts, get_selected_submodel_name)
from subseasonal_toolkit.utils.eval_util import get_target_dates, mean_rmse_to_score, save_metric
from sklearn.linear_model import *

from subseasonal_data import data_loaders

# Compact NumPy output: 3 decimals, no scientific notation
np.set_printoptions(suppress=True, precision=3)
# Record which TensorFlow version produced this run
print(tf.__version__)

In [None]:
#
# Specify model parameters
#
model_name = "nn-a"
if not isnotebook():
    # If notebook run as a script, parse command-line arguments
    parser = ArgumentParser()
    parser.add_argument("pos_vars", nargs="*")  # gt_id and horizon
    parser.add_argument('--target_dates', '-t', default="std_test")
    # Number of training epochs passed to Model.fit
    parser.add_argument('--num_epochs', '-e', type=int, default=10,
                        help="number of training epochs")
    args, opt = parser.parse_known_args()
    if len(args.pos_vars) < 2:
        # Fail with a usage message instead of an IndexError below
        parser.error("expected positional arguments: gt_id horizon")

    # Assign variables
    gt_id = args.pos_vars[0]  # ground-truth id, e.g. "us_precip_1.5x1.5"
    horizon = args.pos_vars[1]  # "12w", "34w", or "56w"
    target_dates = args.target_dates
    num_epochs = int(args.num_epochs)
else:
    # Otherwise, specify arguments interactively
    gt_id = "us_precip_1.5x1.5"
    horizon = "12w"
    target_dates = "20200101"  # or e.g. "std_paper_forecast"
    num_epochs = 10

#
# Process model parameters
#

# Get list of target date objects
target_date_objs = pd.Series(get_target_dates(date_str=target_dates, horizon=horizon))

# Identify measurement variable name
measurement_variable = get_measurement_variable(gt_id)  # 'tmp2m' or 'precip'

# Column names for gt_col, clim_col and anom_col
gt_col = measurement_variable
clim_col = measurement_variable + "_clim"
anom_col = measurement_variable + "_anom"  # 'tmp2m_anom' or 'precip_anom'

# For a given target date, the last observable training date is target date - gt_delta
# as gt_delta is the gap between the start of the target date and the start of the
# last ground truth period that's fully observable at the time of forecast issuance
gt_delta = timedelta(days=get_start_delta(horizon, gt_id))

base_shift = get_forecast_delta(horizon)

# Record model and submodel names
submodel_name = get_submodel_name(model=model_name, horizon=horizon, num_epochs=num_epochs)

if not isnotebook():
    # Save output to log file
    logger = start_logger(model=model_name, submodel=submodel_name, gt_id=gt_id,
                          horizon=horizon, target_dates=target_dates)
    # Store parameter values in log (explicit mapping rather than eval() on names)
    params = {'gt_id': gt_id, 'horizon': horizon,
              'target_dates': target_dates, 'num_epochs': num_epochs}
    params_names = list(params)
    params_values = list(params.values())
    log_params(params_names, params_values)
    
def get_input_precip(horizon = "34w"):
    """Load, or build and cache, the deb_cfsv2 precipitation forecasts used as NN inputs.

    If data/dataframes/nn-a_us_precip_1.5x1.5_{horizon}.h5 exists it is read directly.
    Otherwise every ``.h5`` file in
    models/deb_cfsv2/submodel_forecasts/{selected submodel}/{task} is read in sorted order.
    The files are concatenated, the "pred" column is renamed to ``precip_shift{base_shift}``,
    and the result is cached with ``pandas2hdf``.

    Args:
        horizon: forecast horizon string passed to the project helpers (e.g. "34w").

    Returns:
        DataFrame of precipitation forecasts.

    Raises:
        FileNotFoundError: if the cache must be built but the forecast directory
            contains no ``.h5`` files.
    """
    m = "deb_cfsv2"
    # Bias-corrected CFSv2 ensemble-mean total precipitation forecasts
    # (the per-date files written under models/deb_cfsv2/submodel_forecasts/)
    gt_id = "us_precip_1.5x1.5"
    base_shift = get_forecast_delta(horizon)
    measurement_variable = f"precip_shift{base_shift}"
    task = f"{gt_id}_{horizon}"
    sn = get_selected_submodel_name(model=m, gt_id=gt_id, horizon=horizon)
    # Input data:
    data_file = os.path.join("data", "dataframes", f"nn-a_{task}.h5")
    if os.path.isfile(data_file):
        printf(f"Loading {data_file}..."); tic()
        data_precip = pd.read_hdf(data_file)
        toc()
    else:
        printf(f"Creating {data_file}..."); tic()
        forecasts_dir = os.path.join("models", m, "submodel_forecasts", sn, task)
        filenames = sorted(f for f in os.listdir(forecasts_dir) if f.endswith(".h5"))
        if not filenames:
            raise FileNotFoundError(f"No .h5 forecast files found in {forecasts_dir}")
        # Collect all frames and concatenate once: DataFrame.append was removed in
        # pandas 2.0, and appending inside the loop is quadratic.
        frames = []
        for f in filenames:
            printf(f)
            frames.append(pd.read_hdf(os.path.join(forecasts_dir, f)))
        data_precip = pd.concat(frames).rename(columns={"pred": measurement_variable})
        toc()
        pandas2hdf(data_precip, data_file)
    return data_precip

def get_inp_tmp2m_anom(horizon="34w"):
    """Load, or build and cache, deb_cfsv2 tmp2m forecast anomalies used as NN inputs.

    Two caches under data/dataframes/ are used:
      * nn-a_us_tmp2m_1.5x1.5_{horizon}.h5: all ``.h5`` forecast files from
        models/deb_cfsv2/submodel_forecasts/{selected submodel}/{task}, concatenated,
        with "pred" renamed to ``tmp2m_shift{base_shift}``.
      * nn-a_us_tmp2m_1.5x1.5_{horizon}_anom.h5: those forecasts minus a
        (lat, lon, month, day) mean of ground truth with start_date before 2011-01-01.
        The value column is renamed to ``tmp2m_shift{base_shift}_anom``.

    Args:
        horizon: forecast horizon string passed to the project helpers (e.g. "34w").

    Returns:
        DataFrame of tmp2m forecast anomalies. The ``tmp2m_sqd_shift*`` and
        ``tmp2m_std_shift*`` columns are dropped.

    Raises:
        FileNotFoundError: if the raw forecast cache must be built but the forecast
            directory contains no ``.h5`` files.
    """
    m = "deb_cfsv2"
    # CFSv2 bias-corrected ensemble means for T2m; anomalies are formed below by
    # subtracting a ground-truth tmp2m climatology
    gt_id = "us_tmp2m_1.5x1.5"
    base_shift = get_forecast_delta(horizon)
    measurement_variable = f"tmp2m_shift{base_shift}"
    task = f"{gt_id}_{horizon}"
    sn = get_selected_submodel_name(model=m, gt_id=gt_id, horizon=horizon)
    forecasts_dir = os.path.join("models", m, "submodel_forecasts", sn, task)

    # Raw forecasts (cached)
    data_file = os.path.join("data", "dataframes", f"nn-a_{task}.h5")
    if os.path.isfile(data_file):
        printf(f"\nLoading {data_file}"); tic()
        printf(f"Based on forecasts in\n{forecasts_dir}")
        data = pd.read_hdf(data_file)
        toc()
    else:
        printf(f"\nCreating {data_file}"); tic()
        printf(f"Using forecasts in\n{forecasts_dir}")
        filenames = sorted(f for f in os.listdir(forecasts_dir) if f.endswith(".h5"))
        if not filenames:
            raise FileNotFoundError(f"No .h5 forecast files found in {forecasts_dir}")
        # Concatenate once: DataFrame.append was removed in pandas 2.0 and is
        # quadratic when used in a loop.
        frames = []
        for f in filenames:
            printf(f)
            frames.append(pd.read_hdf(os.path.join(forecasts_dir, f)))
        data = pd.concat(frames).rename(columns={"pred": measurement_variable})
        pandas2hdf(data, data_file)
        toc()

    # Anomalies (cached)
    anom_file = os.path.join("data", "dataframes", f"nn-a_{task}_anom.h5")
    if os.path.isfile(anom_file):
        printf(f"Loading {anom_file}"); tic()
        printf(f"Based on forecasts in\n{forecasts_dir}")
        data_tmp2m_anom = pd.read_hdf(anom_file)
        toc()
    else:
        gt = data_loaders.get_ground_truth(gt_id=gt_id, shift=base_shift)
        # Climatology = mean by (lat, lon, month, day) of ground truth with
        # start_date BEFORE 2011-01-01
        print("\nLoading anomalies"); tic()
        sub_gt = gt[gt.start_date < "2011-01-01"]
        clim = sub_gt.groupby(['lat', 'lon', sub_gt.start_date.dt.month,
                               sub_gt.start_date.dt.day]).mean()
        toc()
        print("\nComputing anomalies"); tic()
        # Clashing climatology columns get the "_clim" suffix
        data_anom = pd.merge(data, clim, left_on=['lat', 'lon',
                                                  data.start_date.dt.month,
                                                  data.start_date.dt.day],
                             right_index=True, how='left', suffixes=('', '_clim'))
        data_anom[measurement_variable] -= data_anom[measurement_variable + '_clim']
        data_anom.drop(columns=measurement_variable + '_clim', inplace=True)
        data_anom.rename(columns={measurement_variable: measurement_variable + '_anom'},
                         inplace=True)
        # NOTE(review): these columns presumably enter via the ground-truth climatology
        # merge; confirm against get_ground_truth's output schema.
        data_tmp2m_anom = data_anom.drop(columns=[f"tmp2m_sqd_shift{base_shift}",
                                                  f"tmp2m_std_shift{base_shift}"])
        pandas2hdf(data_tmp2m_anom, anom_file)
        toc()
    return data_tmp2m_anom

### 1. Create the training and testing datasets
#### Input model data

In [None]:
# Get input data: deb_cfsv2 precip forecasts and tmp2m forecast anomalies
data_precip = get_input_precip(horizon=horizon)
data_tmp2m_anom = get_inp_tmp2m_anom(horizon=horizon)

# Create input dataframe
data_input = pd.merge(data_precip, data_tmp2m_anom, on=['lat', 'lon', 'start_date'], how="left")

# Add ground-truth climatologies ("{gt_var}_clim", averaged by lat, lon, month, day)
# for tmp2m and then precip. After the loop gt_var and clim hold the precip values,
# as in the former copy-pasted version.
for gt_var in ('tmp2m', 'precip'):
    clim = data_loaders.get_climatology(gt_id=f"us_{gt_var}_1.5x1.5")
    clim.rename(columns={gt_var: f'{gt_var}_clim'}, inplace=True)
    clim = clim.groupby(['lat', 'lon', clim.start_date.dt.month,
                         clim.start_date.dt.day]).mean()
    data_input = pd.merge(data_input, clim, left_on=['lat', 'lon',
                                                     data_input.start_date.dt.month,
                                                     data_input.start_date.dt.day],
                          right_index=True, how='left', suffixes=('', '_clim'))
# data_input

#### Output model data

In [None]:
# Output data (prediction targets), both at the forecast shift for this horizon
base_shift = get_forecast_delta(horizon)

# Accumulated precip over the forecast period
gt_var = "precip"
gt_id = f"us_{gt_var}_1.5x1.5"
measurement_variable = f"{gt_var}_shift{base_shift}"
task = f"{gt_id}_{horizon}"
data_out_precip = (
    data_loaders.get_ground_truth(gt_id=gt_id, shift=base_shift)
    .rename(columns={measurement_variable: f'{measurement_variable}_out'})
)

# Mean tmp2m anomalies over the forecast period
gt_var = "tmp2m"
gt_id = f"us_{gt_var}_1.5x1.5"
measurement_variable = f"{gt_var}_shift{base_shift}"
task = f"{gt_id}_{horizon}"
data_out_tmp2m = (
    data_loaders.get_ground_truth_anomalies(gt_id=gt_id, shift=base_shift)
    [['lat', 'lon', 'start_date', f'{measurement_variable}_anom']]
    .rename(columns={f'{measurement_variable}_anom': f'{measurement_variable}_anom_out'})
)

# Merged output data (left join keeps every precip ground-truth row)
data_output = data_out_precip.merge(data_out_tmp2m, how="left", on=['lat', 'lon', 'start_date'])
# data_output

#### Merge raw dataset

In [None]:
# Join inputs with targets on grid cell and start date
dataset = data_input.merge(data_output, how="left", on=['lat', 'lon', 'start_date'])

# Wide format: one row per start_date, one column per (variable, lat, lon)
raw_dataset_wide = dataset.set_index(['lat', 'lon', 'start_date']).unstack(['lat', 'lon'])
raw_dataset_wide

#### Clean the data

Check whether the dataset contains any missing (unknown) values:

In [None]:
dataset.isnull().sum()  # number of missing values per column

#### Split the data into training and test sets

In [None]:
# Test set: every start date from the earliest requested target date onward;
# all earlier rows are training candidates.
if target_date_objs.empty:
    raise ValueError(f"No target dates found for target_dates={target_dates!r}")
first_target_date = target_date_objs.min()
dataset_test = raw_dataset_wide.loc[first_target_date:]
dataset_train = raw_dataset_wide.drop(dataset_test.index)

# For a given target date, the last observable training date is target date - gt_delta
# as gt_delta is the gap between the start of the target date and the start of the
# last ground truth period that's fully observable at the time of forecast issuance
gt_delta = timedelta(days=get_start_delta(horizon, gt_id))
# Anchor the cutoff on the first target date itself, not on the first row of
# dataset_test. If that date is missing from raw_dataset_wide, the first test row is
# later, and training would include data not yet observable at the first target date.
last_train_date = first_target_date - gt_delta
# Update training dataset accordingly
dataset_train = dataset_train.loc[:last_train_date]
display(dataset_train)
display(dataset_test)

#### Inspect the data

Review the joint distribution of a few pairs of columns from the training set (the plot is currently commented out).

In [None]:
# sns.pairplot(train_data_all[['precip_shift15','tmp2m_shift15_anom','tmp2m_clim','precip_clim']], diag_kind='kde')

Let's also check the overall statistics. Note how each feature covers a very different range:

In [None]:
dataset_train.describe().T  # one row of summary statistics per column

#### Split features from labels

In [None]:
# Variables (first level of the wide column MultiIndex) used as inputs and targets;
# every (variable, lat, lon) column of each variable is selected, in dataset order.
features = [f'precip_shift{base_shift}', f'tmp2m_shift{base_shift}_anom', 'tmp2m_clim', 'precip_clim']
labels = [f'precip_shift{base_shift}_out', f'tmp2m_shift{base_shift}_anom_out']
column_features = [c for c in dataset_train.columns if c[0] in features]
column_labels = [c for c in dataset_train.columns if c[0] in labels]

# Select first, then copy: avoids duplicating the entire frame before selecting columns
train_features = dataset_train[column_features].copy()
train_labels = dataset_train[column_labels].copy()

test_features = dataset_test[column_features].copy()
test_labels = dataset_test[column_labels].copy()

### Normalization

The table of statistics shows how different the ranges of the features are:

In [None]:
dataset_train.describe().T[['mean', 'std']]  # per-column mean and standard deviation

#### The Normalization layer

In [None]:
# Normalizer with one mean/variance per column (axis=-1), fitted on the training features
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(train_features.to_numpy())
print(normalizer.mean.numpy())

When the layer is called, it returns the input data, with each feature independently normalized:

In [None]:
# Show the first training row before and after normalization
first = train_features.iloc[:1].to_numpy()

with np.printoptions(precision=2, suppress=True):
    print('First example:', first)
    print()
    print('Normalized:', normalizer(first).numpy())

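First, create a NumPy array from the training features. Then instantiate a `tf.keras.layers.Normalization` layer and fit its state to the training data. With `axis=None`, this layer uses a single scalar mean and variance computed over all feature values, rather than one per feature: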

In [None]:
# Training features as a NumPy array: one row per start date, one column per (variable, lat, lon)
train_features_arr = np.array(train_features)
# axis=None: a single scalar mean and variance is computed over all entries, not one per column.
# The input width is read from the data instead of the previously hard-coded 1504,
# so a different grid or feature set does not silently mismatch.
baseline_nn_normalizer = layers.Normalization(input_shape=[train_features_arr.shape[1]], axis=None)
baseline_nn_normalizer.adapt(train_features_arr)

In [None]:
for name, frame in (("train_features", train_features), ("train_labels", train_labels)):
    printf(f"{name}.shape: {frame.shape}")

### 2. Create the baseline NN-A (Fan et al., 2021)
#### Reference: 
Fan, Y., Krasnopolsky, V., van den Dool, H., Wu, C. Y., & Gottschalck, J. (2021). Using Artificial Neural Networks to Improve CFS Week 3-4 Precipitation and 2-Meter Air Temperature Forecasts. Weather and Forecasting.

In [None]:

def build_and_compile_model(norm, n_outputs=752, n_hidden=200, learning_rate=0.001):
    """Build and compile the NN-A baseline: norm -> Dense(n_hidden, relu) -> linear Dense(n_outputs).

    Args:
        norm: first layer of the model (here the fitted Normalization layer). When it
            was created with ``input_shape``, it also fixes the model's input width.
        n_outputs: width of the linear output layer. It must equal the number of label
            columns; the default 752 is the previously hard-coded value.
        n_hidden: number of ReLU units in the hidden layer (default 200).
        learning_rate: Adam learning rate (default 0.001).

    Returns:
        A ``keras.Sequential`` model compiled with mean-squared-error loss.
    """
    model = keras.Sequential([
        norm,
        # No input_dim here: only the first layer defines the input shape, so the
        # former input_dim=1504 on this layer was ignored.
        layers.Dense(n_hidden, activation='relu'),
        layers.Dense(n_outputs),
    ])

    model.compile(loss="mean_squared_error",
                  optimizer=tf.keras.optimizers.Adam(learning_rate))
    return model

In [None]:
baseline_nn_model = build_and_compile_model(norm=baseline_nn_normalizer)  # NN-A on normalized inputs

In [None]:
# Print the model architecture (layers, output shapes, parameter counts)
baseline_nn_model.summary()

In [None]:
#Train the model with Keras Model.fit:
weights_dir = os.path.join ("models", model_name, "submodel_weights", submodel_name)

if os.path.isdir(weights_dir):
    printf(f"Loading {weights_dir}")
    train_history = False
    baseline_nn_model = tf.keras.models.load_model(weights_dir)
else:
    printf(f"Training and saving to {weights_dir}")
    train_history = True
    history = baseline_nn_model.fit(
        train_features,
        train_labels,
        validation_split=0.06,
        verbose=1, epochs=num_epochs)

In [None]:
def plot_loss(history):
    """Plot the per-epoch training and validation loss stored in a Keras History object."""
    for key in ('loss', 'val_loss'):
        plt.plot(history.history[key], label=key)
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    plt.legend()
    plt.grid(True)

In [None]:
# Plot loss curves only when the model was trained in this run;
# `history` is not defined when the weights were loaded from disk.
if train_history:
    plot_loss(history)

In [None]:
# Test-set value of the compiled loss for the NN-A model
test_results = {'dnn_model': baseline_nn_model.evaluate(test_features, test_labels, verbose=0)}

## Performance

In [None]:
# The model is compiled with loss="mean_squared_error" and no extra metrics, so
# evaluate() returns the test MSE (the old "[MPG]" label was a tutorial leftover).
pd.DataFrame(test_results, index=['Mean squared error']).T

### Make predictions

Make predictions on the test set using Keras `Model.predict` (the test loss is reported above).

In [None]:
# Predict on the test set and label outputs with the (variable, lat, lon) label columns
test_predictions = baseline_nn_model.predict(test_features)
df_test_predictions = pd.DataFrame(test_predictions, index=test_labels.index, columns=test_labels.columns)

# One frame per target variable; the remaining column levels are (lat, lon)
df_test_predictions_precip = df_test_predictions.xs(f'precip_shift{base_shift}_out', axis=1, level=0)
df_test_predictions_tmp2m = df_test_predictions.xs(f'tmp2m_shift{base_shift}_anom_out', axis=1, level=0)


In [None]:
# Recover temperature predictions from temperature anomalies predictions
gt = data_loaders.get_ground_truth(gt_id = "us_tmp2m_1.5x1.5", shift=base_shift)
measurement_variable = f"tmp2m_shift{base_shift}"
# Compute climatology based on post 2011 data
print("\nLoading anomalies"); tic()
sub_gt = gt[gt.start_date < "2011-01-01"]
clim = sub_gt.groupby(['lat','lon', sub_gt.start_date.dt.month, 
                       sub_gt.start_date.dt.day]).mean()
toc() 

# # Store rmses
# rmses_tmp2m = pd.Series(index=target_date_objs, dtype='float64')
# rmses_precip = pd.Series(index=target_date_objs, dtype='float64')

# target_date_objs
for target_date_obj in target_date_objs:
    target_date_str = datetime.strftime(target_date_obj, '%Y%m%d')
    if target_date_obj not in df_test_predictions.index:
        printf(f"Warning: missing prediction for {target_date_obj}")
    else:
        pred_precip = pd.DataFrame(df_test_predictions_precip.loc[target_date_obj].to_frame().reset_index().values, 
                                    columns=['lat','lon','pred'])
        pred_precip.insert(loc=2, column='start_date', value=target_date_obj)
        pred_tmp2m = pd.DataFrame(df_test_predictions_tmp2m.loc[target_date_obj].to_frame().reset_index().values, 
                                    columns=['lat','lon','pred'])
        pred_tmp2m.insert(loc=2, column='start_date', value=target_date_obj)

        print("\nComputing temperatures from anomalies"); tic()
        pred_tmp2m = pd.merge(pred_tmp2m, clim, left_on = ['lat','lon',
                                           pred_tmp2m.start_date.dt.month,
                                           pred_tmp2m.start_date.dt.day], 
                      right_index = True, how='left')
        pred_tmp2m['pred'] += pred_tmp2m[measurement_variable]
        pred_tmp2m.drop(columns=clim.columns,inplace=True)
        toc()

        # Save prediction to file in standard format
        save_forecasts(
            pred_tmp2m,
            model=model_name, submodel=submodel_name, 
            gt_id="us_tmp2m_1.5x1.5", horizon=horizon, 
            target_date_str=target_date_str)
        save_forecasts(
            pred_precip,#.loc[[target_date_obj],:].unstack().rename("pred").reset_index(),
            model=model_name, submodel=submodel_name, 
            gt_id="us_precip_1.5x1.5", horizon=horizon, 
            target_date_str=target_date_str)


#         # Evaluate and store error if we have ground truth data
#         tic()
#         rmse = np.sqrt(np.square(pred - gt[gt.start_date ==target_date_obj]).mean())
#         rmses.loc[target_date_obj] = rmse
#         print("-rmse: {}, score: {}".format(rmse, mean_rmse_to_score(rmse)))
#         mean_rmse = rmses.mean()
#         print("-mean rmse: {}, running score: {}".format(mean_rmse, mean_rmse_to_score(mean_rmse)))
#         toc()

# printf("Save rmses in standard format")
# rmses = rmses.sort_index().reset_index()
# rmses.columns = ['start_date','rmse']
# save_metric(rmses, model=model_name, submodel=submodel_name, gt_id=gt_id, horizon=horizon, target_dates=target_dates, metric="rmse")
# save_metric(rmses, model=f'{forecast}pp', submodel=submodel_name, gt_id=gt_id, horizon=horizon, target_dates=target_dates, metric="rmse")



Save the model for later use with `Model.save`:

In [None]:
# Persist the model to the same directory the training cell loads from.
# This also re-saves a model that was just loaded from weights_dir.
weights_dir = os.path.join("models", model_name, "submodel_weights", submodel_name)
baseline_nn_model.save(weights_dir)

Reloading the saved model should give identical output (this check is disabled by default with `if False:`):

In [None]:
# Optional consistency check, disabled by default: reload the saved model from
# weights_dir and add its test-set loss next to the in-memory model's result.
# Note: inside this `if`, the DataFrame below is not the cell's last expression,
# so a notebook would not display it automatically.
if False:
    reloaded = tf.keras.models.load_model(weights_dir)

    test_results['reloaded'] = reloaded.evaluate(
        test_features, test_labels, verbose=0)

    pd.DataFrame(test_results, index=['Mean squared error']).T