# VAR Baseline for Cambridge UK Weather Forecasting

Gradient boosting models for time series analysis of Cambridge UK temperature measurements taken at the [University computer lab weather station](https://www.cl.cam.ac.uk/research/dtg/weather/).

This notebook is being developed on [Google Colab](https://colab.research.google.com), using [LightGBM](https://lightgbm.readthedocs.io/) and the [Darts](https://unit8co.github.io/darts/) time series package. Initially I was most interested in short-term temperature forecasts (less than 2 hours ahead), but I now mostly produce forecasts up to 24 hours ahead for comparison with earlier [baselines](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/cammet_baselines_2021.ipynb).

See my previous notebooks, web apps etc:
 * [Cambridge UK temperature forecast python notebooks](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks)
 * [Cambridge UK temperature forecast R models](https://github.com/makeyourownmaker/CambridgeTemperatureModel)
 * [Bayesian optimisation of prophet temperature model](https://github.com/makeyourownmaker/BayesianProphet)
 * [Cambridge University Computer Laboratory weather station R shiny web app](https://github.com/makeyourownmaker/ComLabWeatherShiny)

The linked notebooks, web apps etc contain further details including:
 * data description
 * data cleaning and preparation
 * data exploration

In particular, see the notebooks:
 * [cammet_baselines_2021](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/cammet_baselines_2021.ipynb) including persistence, simple exponential smoothing, Holt-Winters exponential smoothing and vector autoregression
 * [keras_mlp_fcn_resnet_time_series](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/keras_mlp_fcn_resnet_time_series.ipynb), which uses a streamlined version of data preparation from [Tensorflow time series forecasting tutorial](https://www.tensorflow.org/tutorials/structured_data/time_series)
 * [lstm_time_series](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/lstm_time_series.ipynb) with stacked LSTMs, bidirectional LSTMs and ConvLSTM1D networks
 * [cnn_time_series](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/cnn_time_series.ipynb) with Conv1D, multi-head Conv1D, Conv2D and Inception-style models
 * [encoder_decoder](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/encoder_decoder.ipynb) which includes autoencoder with attention, encoder decoder with teacher forcing, transformer with teacher forcing and padding, encoder only with MultiHeadAttention
 * [feature_engineering](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/feature_engineering.ipynb) with solar-based and meteorology-based calculated features, rolling stats, tsfeatures, catch22, bivariate features and more
 * [tsfresh_feature_engineering](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/tsfresh_feature_engineering.ipynb) with automated feature engineering and selection for time series analysis of Cambridge UK weather measurements

Most of the above repositories, notebooks, web apps etc were built on both less data and less thoroughly cleaned data.

---

## Table of Contents


**TODO** Add internal links before "final" commits

Some sections may get added/deleted during development.

Don't want any broken links, so finish later.


Gradient Boosted Trees Introduction

Code Setup
 * darts Installation
 * Library Imports
 * Environment Variables
 * Custom Functions

Data Setup
 * Load pre-calculated features
 * See [feature_engineering.ipynb](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/feature_engineering.ipynb)

LightGBM Models
 * Variable Selection
 * Lag Selection
 * Hyperparameter Tuning

Comparison with Baselines

Conclusion
 * What Worked
 * What Failed
 * Rejected Ideas
 * Future Work

Metadata

---

## VAR Baseline

...

### Load Libraries

Load most of the required packages.

In [None]:
import re
import sys
import math
import timeit
import datetime
import itertools
import subprocess
import pkg_resources

import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from tqdm import tqdm
from itertools import product
import statsmodels.api as sm
from statsmodels.tsa.stattools import acf, pacf, lagmat, coint
from statsmodels.tsa.stattools import adfuller, kpss, grangercausalitytests
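from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy import stats               # stats.describe is used by the summary helpers
from sklearn.metrics import r2_score  # R^2 annotations in the diagnostic plots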


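# Reduces variance in results but won't eliminate it :-(
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
#       only affects processes launched from this notebook
%env PYTHONHASHSEED=0
import random

# Seed the Python and NumPy global random number generators
seed = 0
random.seed(seed)
np.random.seed(seed)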


%matplotlib inline

# Prevent lightgbm 'Converting column-vector to 1d array' warning
import warnings
warnings.filterwarnings(
    action   = 'ignore',
    category = UserWarning,
    module   = r'.*lightgbm'
)
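
As a quick sanity check of the seeding above, the sketch below reseeds both global generators and confirms that the same draws come back. The helper names `set_seeds` and `draw` are illustrative assumptions and are not used elsewhere in this notebook. Seeding only covers code that draws from these two global generators; libraries with their own random state generally need their own seed parameter.

```python
import random

import numpy as np


def set_seeds(seed=0):
    """Hypothetical helper: reseed Python's and NumPy's global RNGs."""
    random.seed(seed)
    np.random.seed(seed)


def draw():
    """Draw one value from Python's RNG and three from NumPy's global RNG."""
    return random.random(), np.random.rand(3)


set_seeds(0)
first_py, first_np = draw()

set_seeds(0)
second_py, second_np = draw()

# Identical seeds should reproduce identical draws
repeatable = (first_py == second_py) and np.array_equal(first_np, second_np)
print("repeatable:", repeatable)
```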




### Custom Functions

Next, define some utility functions:
 * `rmse_`
 * `mse_`
 * `mae_`
 * `summarise_backtest`
 * `print_rmse_mae`
 * `drop_cols_correlated_with_feat_cols`
 * `drop_problem_cols`
 * `summarise_historic_comparison`
 * `plot_lagged_feat_imp_subplot`
 * `get_pastcov_features`
 * `get_pastcov_lags`
 * `plot_lagged_feature_importances`
 * `plot_feature_importances`
 * `get_feature_importances`
 * `expand_grid`
 * `keep_key`
 * `get_historic_comparison`
 * `_plot_xy_for_label`
 * `plot_multistep_obs_vs_preds`
 * `plot_multistep_obs_vs_mean_preds_by_step`
 * `plot_multistep_obs_preds_dists`
 * `plot_multistep_residuals`
 * `plot_multistep_residuals_dist`
 * `plot_multistep_residuals_vs_predicted`
 * `se_`
 * `metric_ci_vals`
 * `plot_horizon_metrics`
 * `plot_horizon_metrics_boxplots`
 * `plot_multistep_diagnostics`
 * `_filter_out_missing`
 * `plot_multistep_forecast_examples`
 * `get_rmse_mae_from_backtest`
 * `plot_catboost_learning_curve`
 * `plot_lgb_learning_curve`
 * `drop_correlated_cols`
 * `get_feature_selection_scores`
 * `plot_observation_examples`
 * `sanity_check_df_rows_cols_labels`
 * `sanity_check_before_after_dfs`
 * `sanity_check_train_valid_test`
 * `print_train_valid_test_shapes`
 * `plot_feature_history`
 * `plot_feature_history_separately`
 * `check_high_low_thresholds`
 * `get_features_filename`
 * `merge_data_and_aggs`
 * `get_rolling_features`
 * `finalise_rolling_features`
 * `print_null_columns`
 * `print_na_locations`
 * `get_features`
 * `get_darts_series`
 * `plot_short_term_acf`
 * `plot_long_term_acf`
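
To make the role of the metric helpers more concrete, here is a minimal, self-contained sketch of how RMSE and MAE can be computed overall and per forecast step. The toy values and the names `toy`, `rmse` and `mae` are assumptions for illustration only. The `step` and `id` columns are built with `np.tile` and `np.repeat`, in the same spirit as the comparison data frame assembled by `get_historic_comparison`.

```python
import numpy as np
import pandas as pd


def rmse(obs, preds):
    """Root mean squared error."""
    return float(np.sqrt(np.mean((obs - preds) ** 2)))


def mae(obs, preds):
    """Mean absolute error."""
    return float(np.mean(np.abs(obs - preds)))


# Two toy forecasts, each covering a horizon of 2 steps (made-up values)
horizon = 2
toy = pd.DataFrame({'y':    [1.0, 2.0, 3.0, 4.0],
                    'pred': [1.0, 2.0, 3.0, 6.0]})

# Label each row with its step within the forecast and its forecast id
n_forecasts = len(toy) // horizon
toy['step'] = np.tile(np.arange(1, horizon + 1), n_forecasts)
toy['id']   = np.repeat(np.arange(n_forecasts), horizon)

# Metrics over all rows
overall_rmse = rmse(toy['y'], toy['pred'])
overall_mae  = mae(toy['y'], toy['pred'])

# RMSE for each forecast step separately
rmse_by_step = {int(step): rmse(grp['y'], grp['pred'])
                for step, grp in toy.groupby('step')}

print("overall RMSE:", overall_rmse)
print("overall MAE: ", overall_mae)
print("RMSE by step:", rmse_by_step)
```

Splitting metrics by step is what lets the later plots show how error grows with forecast horizon, rather than reporting a single averaged figure.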


In [None]:
def _check_obs_preds_lens_eq(obs, preds):
    "return 1 if obs and preds have equal lengths, otherwise print both lengths and return 0"
    obs_preds_lens_eq = 1

    if len(obs) != len(preds):
        print("obs:  ", len(obs))
        print("preds:", len(preds))
        obs_preds_lens_eq = 0

    return obs_preds_lens_eq


def rmse_(obs, preds):
    "root mean squared error"
    if _check_obs_preds_lens_eq(obs, preds) == 0:
        raise ValueError("obs and preds must have the same length")

    return np.sqrt(np.mean((obs - preds) ** 2))


def mse_(obs, preds):
    "mean squared error"
    if _check_obs_preds_lens_eq(obs, preds) == 0:
        raise ValueError("obs and preds must have the same length")

    return np.mean((obs - preds) ** 2)


def mae_(obs, preds):
    "mean absolute error - equivalent to the keras loss function"
    if _check_obs_preds_lens_eq(obs, preds) == 0:
        raise ValueError("obs and preds must have the same length")

    return np.mean(np.abs(obs - preds))      # keras loss
    # return np.median(np.abs(obs - preds))  # earlier baselines


# TODO Remove me?
def summarise_backtest(backtest, df, horizon = HORIZON, digits = 6, y_col = Y_COL):

    if len(backtest[0]) == 1:
        print("\n# Backtest RMSE:", round(rmse_(val_ser[-len(backtest):].values(), backtest.values()), digits))
        print("# Backtest MAE: ",   round( mae_(val_ser[-len(backtest):].values(), backtest.values()), digits))

        print("\nbacktest[", y_col, "]:\n", sep='')
        backtest_stats = stats.describe(backtest[y_col].values())
        print("count\t", backtest_stats[0])
        print("mean\t",  round(backtest_stats[2][0], digits))
        print("std\t",   round(np.sqrt(backtest_stats[3][0]), digits))
        print("min\t",   round(np.min(backtest_stats[1]), digits))
        print("25%\t",   round(np.percentile(backtest[y_col].values(), 25), digits))
        print("50%\t",   round(np.median(backtest[y_col].values()), digits))
        print("75%\t",   round(np.percentile(backtest[y_col].values(), 75), digits))
        print("max\t",   round(np.max(backtest_stats[1]), digits))
    elif len(backtest[0]) == horizon:
        preds_df = pd.concat([backtest[i].pd_dataframe() for i in range(len(backtest))], axis=0)
        trues_df = df.loc[preds_df.index, [y_col]]
        hist_comp = pd.concat([trues_df, preds_df[y_col]], axis = 1)
        hist_comp.columns = [y_col, 'pred']
        list_int = [i for i in range(1, horizon + 1)]
        reps = len(hist_comp) // len(list_int)
        hist_comp['step'] = np.tile(list_int, reps)

        print("\nBacktest RMSE all:", round(rmse_(hist_comp[y_col], hist_comp['pred']), digits))
        print("Backtest MAE all: ",    round(mae_(hist_comp[y_col], hist_comp['pred']), digits))

        # Metrics for the final forecast step only (label follows horizon)
        is_last  = hist_comp['step'] == horizon
        last_str = str(horizon) + 'th'
        print("\n# Backtest RMSE ", last_str, ": ",
              round(rmse_(hist_comp.loc[is_last, y_col],
                          hist_comp.loc[is_last, 'pred']), digits), sep='')
        print("# Backtest MAE ", last_str, ":  ",
              round(mae_(hist_comp.loc[is_last, y_col],
                         hist_comp.loc[is_last, 'pred']), digits), sep='')

        lasttest_stats = stats.describe(hist_comp['pred'])
        print("\nbacktest[", y_col, "]:\n", sep='')
        print("count\t", len(hist_comp['pred']))
        print("mean\t",  round(lasttest_stats[2], digits))
        print("std\t",   round(np.sqrt(lasttest_stats[3]), digits))
        print("min\t",   round(np.min(lasttest_stats[1]), digits))
        print("25%\t",   round(np.percentile(hist_comp['pred'], 25), digits))
        print("50%\t",   round(np.median(hist_comp['pred']), digits))
        print("75%\t",   round(np.percentile(hist_comp['pred'], 75), digits))
        print("max\t",   round(np.max(lasttest_stats[1]), digits))


def print_rmse_mae(obs, preds, postfix_str, prefix_str = '', digits = 6):
    print(prefix_str, "Backtest RMSE ", postfix_str, ": ",
          round(rmse_(obs, preds), digits),
          sep='')
    print(prefix_str, "Backtest MAE ",  postfix_str, ":  ",
          round( mae_(obs, preds), digits),
          sep='')
    print()


def drop_cols_correlated_with_feat_cols(df, feats_df, threshold=0.95):

  for feat_col in feats_df.columns:
    corrs = df.corrwith(feats_df[feat_col])
    drop_cols = corrs[(corrs > threshold) & (corrs != 1.0)]

    for i in range(len(drop_cols)):
      drop_col = drop_cols.index[i]
      if drop_col in df.columns:  # and drop_col not in feats_df.columns:
        del df[drop_col]

  return df


def drop_problem_cols(df, lag, drop_cor=True,
                      var_cutoff=0.05, cor_cutoff=0.95, na_cutoff=0.05,
                      verbose = False):

  if verbose:
    print('drop_problem_cols - start:', df.shape)


  # drop all NA columns
  df = df.dropna(axis = 1, how = 'all')

  if verbose:
    print('drop_problem_cols - after dropna:', df.shape)


  # drop single value columns
  df = df.loc[:, (df != df.iloc[lag]).any()]

  if verbose:
    print('drop_problem_cols - after drop single value cols:', df.shape)


  # drop low variance columns
  if 'ds' in df.columns:
    df = df.drop(['ds'], axis=1)
    df = df.loc[:, df.std() > var_cutoff]
    df['ds'] = df.index
  else:
    df = df.loc[:, df.std() > var_cutoff]

  if verbose:
    print('drop_problem_cols - after drop low var cols:', df.shape)


  # drop highly correlated columns
  if drop_cor:
    df = drop_correlated_cols(df, cor_cutoff)

  if verbose:
    print('drop_problem_cols - after drop correlated cols:', df.shape)


  # drop cols with high % of NA values
  pc_thresh = int(na_cutoff * df.shape[0])
  if verbose:
    print('columns with null values:')
    display(df.isnull().sum())
  df = df.loc[:, df.isnull().sum() < pc_thresh]

  if verbose:
    print('drop_problem_cols - after drop high % of NAs:', df.shape)

  return df


def summarise_historic_comparison(hc, df, horizon = HORIZON,
                                  digits = 6,
                                  y_col = Y_COL,
                                  df_name = 'valid_df'):

    print('\n')
    print_rmse_mae(hc[y_col], hc['pred'], 'all')

    obs   = hc.loc[hc['step'] == horizon, y_col]
    preds = hc.loc[hc['step'] == horizon, 'pred']
    if horizon == 1:
      post_str = '1st'
    elif horizon == 2:
      post_str = '2nd'
    elif horizon == 3:
      post_str = '3rd'
    else:
      post_str = str(horizon) + 'th'
    print_rmse_mae(obs, preds, post_str, '# ')

    obs   = hc.loc[hc['missing'] == 0.0, y_col]
    preds = hc.loc[hc['missing'] == 0.0, 'pred']
    print_rmse_mae(obs, preds, 'miss==0')

    obs   = hc.loc[hc['missing'] == 1.0, y_col]
    preds = hc.loc[hc['missing'] == 1.0, 'pred']
    print_rmse_mae(obs, preds, 'miss==1')

    if y_col == 'y_des':
      # preds = hc['pred'] - hc['y_seasonal']
      preds = hc['pred'] - hc['y_yearly'] - hc['y_daily'] - hc['y_trend']
    elif y_col == 'y_des_fft':
      preds = hc['pred'] - hc['y_fft']
    elif y_col == 'y':
      preds = hc['pred']
    elif y_col == 'y_res':
      preds = hc['pred'] - hc['y_yearly'] - hc['y_daily']

    preds = preds.dropna()  # avoid in-place modification of a slice of hc
    lasttest_stats = stats.describe(preds)
    print("\nbacktest['", y_col, "']:", sep='')
    print("count\t", len(preds))
    print("mean\t",  round(lasttest_stats[2], digits))
    print("std\t",   round(np.sqrt(lasttest_stats[3]), digits))
    print("min\t",   round(np.min(lasttest_stats[1]), digits))
    print("25%\t",   round(np.percentile(preds, 25), digits))
    print("50%\t",   round(np.median(preds), digits))
    print("75%\t",   round(np.percentile(preds, 75), digits))
    print("max\t",   round(np.max(lasttest_stats[1]), digits))

    print("\n", df_name, "['", y_col, "']:\n", df[y_col].describe(), '\n', sep='')


def plot_lagged_feat_imp_subplot(fi_df, subset):
    bar_height = 0.25
    title = 'Feature importance'

    fi_max = fi_df['importance'].max()
    if fi_max > 100:
      xl_max = int(np.ceil(fi_max / 100.0)) * 100
    elif fi_max > 10:
      xl_max = int(np.ceil(fi_max / 10.0)) * 10
    else:
      xl_max = fi_max

    if subset is not None:
      fi_df = fi_df.loc[fi_df['feature'].str.contains(subset, regex=True), :]
      title += ' - ' + subset + ' features'
      plt.figure(figsize=(7, 3))
    else:
      title += ' - all features'
      plt.figure(figsize=(20, 10))

    plt.xlim(0, xl_max)
    plt.barh(width  = fi_df['importance'],
             y      = fi_df['feature'],
             height = bar_height)


    def plot_highlighted_lagged_feat_imp_subset(data, subset_str, hl_col, bar_height = 0.25):
        data_subset = data.loc[data['feature'].str.contains(subset_str, regex=True), :]

        plt.barh(width  = data_subset['importance'],
                 y      = data_subset['feature'],
                 height = bar_height,
                 color  = hl_col)


    plot_highlighted_lagged_feat_imp_subset(fi_df, 'shadow', 'red')

    yregex = '^' + Y_COL
    ytarg_str = yregex + '_target_'
    plot_highlighted_lagged_feat_imp_subset(fi_df, ytarg_str, 'green')

    ypcov_str = yregex + '_pastcov_'
    plot_highlighted_lagged_feat_imp_subset(fi_df, ypcov_str, 'blue')

    plt.title(title)
    plt.show()


def get_pastcov_features(fi_df):
    pcov_feats_long = fi_df.loc[fi_df['feature'].str.contains('_pastcov_'), 'feature'].to_list()

    r = re.compile('_pastcov_.*$')

    pcov_feats_dups = [r.sub('', pcov_feat_long) for pcov_feat_long in pcov_feats_long]
    pcov_feats = list(set(pcov_feats_dups))

    if Y_COL in pcov_feats:
      pcov_feats.remove(Y_COL)

    return pcov_feats


def get_pastcov_lags(fi_df):
    pcov_feats_long = fi_df.loc[fi_df['feature'].str.contains('_pastcov_'), 'feature'].to_list()

    r = re.compile('^.*_pastcov_')

    pcov_lags_dups = [r.sub('', pcov_feat_long) for pcov_feat_long in pcov_feats_long]
    pcov_lags = list(set(pcov_lags_dups))

    return pcov_lags


# TODO: Also, consider combining plot_feature_importances and
#       plot_lagged_feature_importances into a single function
def plot_lagged_feature_importances(model):
    '''Plot feature importance for models with multiple lags

    Should be easier to compare importance across features, lags, targets
    and past covariates

    No support for future covariates

    Use plot_feature_importances function for lags = 1, past_cov_lags = 1 models

    '''

    imp_thresh = 0
    if isinstance(model, CatBoostModel):
      imp_thresh = 0  # some catboost importance values below 1
    elif isinstance(model, LightGBMModel):
      imp_thresh = 1

    imp_df = pd.DataFrame({'feature':    model.lagged_feature_names,
                           'importance': model.model.feature_importances_})
    imp_df = imp_df.sort_values('importance')

    plot_lagged_feat_imp_subplot(imp_df, None)
    plot_lagged_feat_imp_subplot(imp_df, '^' + Y_COL)

    pcov_feats = get_pastcov_features(imp_df)
    for pcov_feat in pcov_feats:
        plot_lagged_feat_imp_subplot(imp_df, '^' + pcov_feat)

    pcov_lags = get_pastcov_lags(imp_df)
    for pcov_lag in pcov_lags:
        plot_lagged_feat_imp_subplot(imp_df, '_pastcov_' + pcov_lag + '$')


# TODO: Consider combining plot_feature_importances and
#       plot_lagged_feature_importances into a single function
def plot_feature_importances(model, \
                             y_col        = Y_COL, \
                             include_cols = None,  \
                             exclude_cols = None):
    '''Plot feature importances from lightGBM models

    WARNING: Only works with lags = 1 and lags_past_cov = 1
             Use plot_lagged_feature_importances for models with additional lags
    '''

    imp_thresh = 0
    if isinstance(model, CatBoostModel):
      imp_thresh = 0  # some catboost importance values below 1
    elif isinstance(model, LightGBMModel):
      imp_thresh = 1

    if include_cols is not None:
        col_names = include_cols
    else:
        col_names = model.lagged_feature_names

    cols_df = pd.DataFrame({'feature': col_names,
                            'importance': model.model.feature_importances_})
    cols_df = cols_df.sort_values('importance')
    cols_df = cols_df[cols_df.importance >= imp_thresh]

    if exclude_cols is not None:
      cols_df = cols_df[~cols_df['feature'].isin(exclude_cols)]

    # print("cols_df:", cols_df, sep='\n')

    plt.figure(figsize=(20, 10))
    plt.barh(width  = cols_df['importance'],
             y      = cols_df['feature'],
             height = 0.25);
    plt.barh(width  = cols_df.loc[cols_df['feature'].str.contains('shadow'),
                                  'importance'],
             y      = cols_df.loc[cols_df['feature'].str.contains('shadow'),
                                  'feature'],
             height = 0.25,
             color  = 'red')
    plt.title('Feature importance\nimportance threshold = ' + str(imp_thresh))
    plt.show()


def get_feature_importances(model,
                            y_col = Y_COL,
                            imp_thresh   = None,
                            include_cols = None,
                            exclude_cols = None,
                            verbose      = True):

    # imp_thresh = 0
    if imp_thresh is None and isinstance(model, CatBoostModel):
      imp_thresh = 0  # some catboost importance values below 1
    elif imp_thresh is None and isinstance(model, LightGBMModel):
      imp_thresh = 1

    if include_cols is not None:
        # col_names = pd.Series(include_cols)
        col_names = include_cols
    else:
        col_names = model.lagged_feature_names

    cols_df = pd.DataFrame({'feature': col_names,
                            'importance': model.model.feature_importances_})
    cols_df = cols_df.sort_values('importance')
    cols_df = cols_df[cols_df.importance >= imp_thresh]

    if exclude_cols is not None:
      cols_df = cols_df[~cols_df['feature'].isin(exclude_cols)]

    # print("cols_df:", cols_df, sep='\n')

    shadow_str = '_shadow_'
    if shadow_str in ''.join(cols_df['feature'].values):
        shad_thresh = cols_df.loc[cols_df['feature'].str.contains(shadow_str),
                                  'importance'].tail(1).values[0]
    else:
        shad_thresh = 0.0
    # print('shad_thresh:', shad_thresh)

    cols_df = cols_df[cols_df['importance'] > shad_thresh]

    if verbose:
      print(cols_df.to_string(index=False), sep='\n')

    inc_cols = []
    for feature in cols_df['feature'].values:
        inc_cols.append(re.sub('_(pastcov|target|futcov)_lag.*', '', feature))

    # remove duplicates from a list, while preserving order
    seen = set()
    inc_cols = [col for col in inc_cols if col not in seen and not seen.add(col)]

    if verbose:
      print('\ninc_cols:', inc_cols)

    return inc_cols


def expand_grid(dictionary):
   return pd.DataFrame([row for row in product(*dictionary.values())],
                       columns = dictionary.keys())


def keep_key(d, k):
  """ models = keep_key(models, 'datasets') """
  return {k: d[k]}


def get_historic_comparison(backtest, df, y_col = Y_COL, horizon = HORIZON):
    if horizon > 1:
      assert len(backtest[0]) > 1

    if y_col == 'y_des':
      # cols = ['y_des', 'y_seasonal']
      cols = ['y_des', 'y_yearly', 'y_daily', 'y_trend']
    elif y_col == 'y_des_fft':
      cols = ['y_des_fft', 'y_fft']
    elif y_col == 'y_res':
      cols = ['y_res', 'y_yearly', 'y_daily']
    elif y_col == 'y':
      cols = ['y']
    else:
      raise ValueError('Unrecognised y_col: ' + str(y_col))

    # cols.extend(['missing', 'mi_filled', 'isd_outlier', 'hist_average'])
    cols.extend(['missing', 'isd_outlier'])

    preds_df = pd.concat([backtest[i].pd_dataframe() for i in range(len(backtest))], axis=0)
    trues_df = df.loc[preds_df.index, cols]

    hist_comp = pd.concat([trues_df, preds_df[y_col]], axis = 1)
    cols.append('pred')
    hist_comp.columns = cols

    # re-seasonalise
    if y_col == 'y_des':
      hist_comp['y_des'] += hist_comp['y_yearly'] + hist_comp['y_daily'] + hist_comp['y_trend']
      hist_comp['pred']  += hist_comp['y_yearly'] + hist_comp['y_daily'] + hist_comp['y_trend']
    elif y_col == 'y_des_fft':
      hist_comp['y_des_fft'] += hist_comp['y_fft']
      hist_comp['pred']      += hist_comp['y_fft']
    elif y_col == 'y_res':
      hist_comp['y_res'] += hist_comp['y_yearly'] + hist_comp['y_daily']
      hist_comp['pred']  += hist_comp['y_yearly'] + hist_comp['y_daily']


    hist_comp['res']    = hist_comp[y_col] - hist_comp['pred']
    hist_comp['res^2']  = hist_comp['res'] * hist_comp['res']
    hist_comp['res_sign']  = np.sign(hist_comp['res'])
    hist_comp['missing']   = hist_comp['missing']#.astype(int)
    # hist_comp['mi_filled'] = hist_comp['mi_filled']#.astype(int)
    # hist_comp['hist_average'] = hist_comp['hist_average']#.astype(int)

    list_int = [i for i in range(1, horizon+1)]
    reps = len(hist_comp) // len(list_int)
    hist_comp['step'] = np.tile(list_int, reps)
    hist_comp['id']   = np.repeat([i for i in range(reps)], horizon)
    hist_comp['date'] = hist_comp.index.values

    return hist_comp


def plot_one_step_abs_err_boxplot(one_step, title):
  one_step['abs_err'] = np.abs(one_step['res'])
  one_step[['abs_err']].boxplot(meanline  = False,
                                showmeans = True,
                                showcaps  = True,
                                showbox   = True,
                                # showfliers = False,
                                )
  plt.title(title + '\nboxplot with mean and median')
  plt.suptitle('')
  plt.ylabel('absolute error')
  plt.show()


def plot_one_step_residuals_dist(one_step, title):
  plt.figure(figsize = (12, 16))
  plt.subplot(5, 1, 5)
  pd.Series(one_step['res']).plot(kind = 'density', label='residuals')
  plt.xlim(-10, 10)
  plt.title(title)
  plt.show()


def plot_one_step_residuals(one_step, title):
  x_miss = one_step.loc[one_step['missing'] == 1.0, 'obs'].index
  y_miss = one_step.loc[one_step['missing'] == 1.0, 'res']

  plt.figure(figsize = (12, 16))
  plt.subplot(5, 1, 4)
  plt.scatter(x = one_step.index, y = one_step['res'])
  plt.scatter(x_miss, y_miss, color='red', label='missing')
  plt.axhline(y = 0, color = 'grey')
  plt.xlabel('Index position')
  plt.ylabel('Residuals')
  plt.legend(loc='lower right')
  plt.title(title)
  plt.show()


def plot_one_step_obs_preds_dists(one_step, title):
  obs   = one_step['obs']
  preds = one_step['preds']
  r2score = r2_score(obs, preds)

  plt.figure(figsize = (12, 16))
  plt.subplot(5, 1, 3)
  pd.Series(obs).plot(kind = 'density', label='observations')
  pd.Series(preds).plot(kind = 'density', label='predictions')
  plt.xlim(-10, 40)
  plt.title(title)
  plt.legend()
  plt.annotate("$R^2$ = {:.3f}".format(r2score), (-7.5, 0.055))
  # plt.tight_layout()
  plt.show()


def plot_one_step_obs_vs_preds(one_step, title):

  obs   = one_step['obs']
  preds = one_step['preds']
  x_miss = one_step.loc[one_step['missing'] == 1.0, 'obs']
  y_miss = one_step.loc[one_step['missing'] == 1.0, 'preds']

  r2score = r2_score(obs, preds)

  plt.figure(figsize = (12, 16))
  plt.subplot(5, 1, 1)
  plt.scatter(x = obs, y = preds)
  plt.scatter(x_miss, y_miss, color='red', label='missing')
  plt.axline((0, 0), slope=1.0, color="grey")
  plt.xlabel('Observations')
  plt.ylabel('Predictions')
  plt.legend(loc='lower right')
  plt.annotate("$R^2$ = {:.3f}".format(r2score), (-9, 31))
  plt.title(title)
  plt.xlim((-10, 35))
  plt.ylim((-10, 35))
  plt.show()


def plot_one_step_diagnostics(model, data, val_series, val_pastcov_series, title, val_fut_cov=None):
  plot_feature_importances(model)

  # re-seasonalise observations
  if Y_COL == 'y_des':
    obs = data['y_des'] + data['y_yearly'] + data['y_daily'] + data['y_trend']
  elif Y_COL == 'y_des_fft':
    obs = data['y_des_fft'] + data['y_fft']
  else:
    obs = data[Y_COL]

  # print('data:', data.shape)
  # display(data[['y_des_fft', 'y_fft']])
  # print('obs:', obs.shape)
  # display(obs)

  if val_fut_cov is None:
    res = model.residuals(series = val_series,
                          past_covariates = val_pastcov_series,
                          retrain = False).pd_series()
  else:
    res = model.residuals(series = val_series,
                          past_covariates = val_pastcov_series,
                          future_covariates = val_fut_cov,
                          retrain = False).pd_series()

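  # assumes darts residuals are observed minus predicted,
  # matching hist_comp['res'] in get_historic_comparison
  preds = obs - res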
  preds = preds.dropna()
  obs  = obs[preds.index]
  res  = res[preds.index]
  miss = data.loc[preds.index, 'missing']

  print_rmse_mae(obs, preds, '1st', '# ')

  one_step = pd.concat([obs, preds, res, miss], axis=1)
  one_step.columns = ['obs', 'preds', 'res', 'missing']

  title = 'step = 1 ' + title
  plot_one_step_obs_vs_preds(one_step, title)
  # plot_obs_vs_mean_preds_by_step(hist, title)
  plot_one_step_obs_preds_dists(one_step, title)
  plot_one_step_residuals(one_step, title + ' residuals')
  plot_one_step_residuals_dist(one_step, title + ' residuals density')
  plot_one_step_residuals_acf(one_step, title + ' residuals acf')
  plot_one_step_residuals_qq(one_step, title + ' residuals qq-plot')
  plot_one_step_abs_err_boxplot(one_step, title)


def plot_one_step_residuals_qq(one_step, title_):
  fig, axs = plt.subplots(figsize=(6, 6))
  sm.qqplot(one_step['res'], line='q', ax=axs)
  axs.set_title(title_)
  plt.show()


def plot_one_step_residuals_acf(one_step, title_, max_lags = 300):
  plt.figure(figsize = (6, 6))

  acf = pd.DataFrame()
  acf_feat = 'res'

  acf[acf_feat] = [one_step[acf_feat].autocorr(l) for l in range(1, max_lags)]
  plt.plot(acf[acf_feat], label='residual')

  plt.axhline(0, linestyle='--', c='black')
  plt.ylabel('autocorrelation')
  plt.xlabel('time lags')
  plt.title(title_)
  plt.show()


def _plot_xy_for_label(data, label, x_feat, y_feat, color):
    x = data.loc[data[label] == 1.0, x_feat]
    y = data.loc[data[label] == 1.0, y_feat]

    if len(x) > 0:
        plt.scatter(x = x, y = y, color=color, alpha=0.5, label=label)


def plot_multistep_obs_vs_preds(hist, title, y_col=Y_COL):
    plt.figure(figsize = (12, 16))
    plt.subplot(5, 1, 1)
    plt.scatter(x = hist[y_col], y = hist['pred'])
    _plot_xy_for_label(hist, 'missing',      y_col, 'pred', 'red')
    # _plot_xy_for_label(hist, 'hist_average', y_col, 'pred', 'yellow')
    # _plot_xy_for_label(hist, 'mi_filled',    y_col, 'pred', 'purple')
    plt.axline((0, 0), slope=1.0, color="grey")
    plt.xlabel('Observations')
    plt.ylabel('Predictions')
    plt.legend(loc='lower right')
    obs   = hist.loc[hist[[y_col, 'pred']].notnull().all(1), y_col]
    preds = hist.loc[hist[[y_col, 'pred']].notnull().all(1), 'pred']
    r2score = r2_score(obs, preds)
    plt.annotate("$R^2$ = {:.3f}".format(r2score), (-9, 31))
    plt.title(title)
    plt.xlim((-10, 35))
    plt.ylim((-10, 35))
    plt.show()


def plot_multistep_obs_vs_mean_preds_by_step(hist, title, y_col = Y_COL,
                                             step_ = HORIZON, ci = False):
    '''For specific step, plot mean prediction for each observation

    A 95 % confidence interval is plotted when ci = True (disabled by default)
    '''

    mean_preds = hist.loc[hist['step'] == step_, [y_col, 'pred']].groupby(y_col)['pred'].mean()
    obs   = mean_preds.index.values
    preds = mean_preds.values

    plt.figure(figsize = (12, 16))
    ax = plt.subplot(5, 1, 2)
    plt.plot(obs, preds)

    if ci is True:
      # half-width of the interval; separate name avoids overwriting the ci flag
      ci_hw = 1.96 * np.std(preds) / np.sqrt(len(obs))
      ax.fill_between(obs, (preds - ci_hw), (preds + ci_hw), color='b', alpha=.1)

    plt.axline((0, 0), slope=1.0, color="grey")
    r2score = r2_score(obs, preds)
    plt.annotate("$R^2$ = {:.3f} - step = {}".format(r2score, step_), (-9, 31))
    plt.title(title + ' step = ' + str(step_))
    plt.xlabel('Temperature')
    plt.ylabel('Mean prediction')
    plt.xlim((-10, 35))
    plt.ylim((-10, 35))
    plt.show()


def plot_multistep_obs_preds_dists(hist, title, y_col=Y_COL):
    obs   = hist.loc[hist[[y_col, 'pred']].notnull().all(1), y_col]
    preds = hist.loc[hist[[y_col, 'pred']].notnull().all(1), 'pred']
    r2score = r2_score(obs, preds)
    plt.figure(figsize = (12, 16))
    plt.subplot(5, 1, 3)
    pd.Series(obs).plot(kind = 'density', label='observations')
    pd.Series(preds).plot(kind = 'density', label='predictions')
    plt.xlim(-10, 40)
    plt.title(title)
    plt.legend()
    plt.annotate("$R^2$ = {:.3f}".format(r2score), (-7.5, 0.055))
    #plt.tight_layout()
    plt.show()


def plot_multistep_residuals(hist, title):
    plt.figure(figsize = (12, 16))
    plt.subplot(5, 1, 4)
    plt.scatter(x = range(len(hist)), y = hist['res'])
    hist['id.2'] = range(len(hist))
    _plot_xy_for_label(hist, 'missing',      'id.2', 'res', 'red')
    # _plot_xy_for_label(hist, 'hist_average', 'id.2', 'res', 'yellow')
    # _plot_xy_for_label(hist, 'mi_filled',    'id.2', 'res', 'purple')
    plt.axhline(y = 0, color = 'grey')
    plt.xlabel('Index position')
    plt.ylabel('Residuals')
    plt.legend(loc='lower right')
    plt.title(title)
    plt.show()


def plot_multistep_residuals_dist(hist, title):
    plt.figure(figsize = (12, 16))
    plt.subplot(5, 1, 5)
    pd.Series(hist['res']).plot(kind = 'density', label='residuals')
    plt.xlim(-10, 10)
    plt.title(title)
    plt.show()


# Unused?
# TODO Diagonal structure of these plots might need further consideration
#      Add lowess fit to check for problems
def plot_multistep_residuals_vs_predicted(hist, title):
    plt.subplot(5, 1, 5)
    plt.scatter(x = hist['pred'], y = hist['res'])
    _plot_xy_for_label(hist, 'missing',      'pred', 'res', 'red')
    # _plot_xy_for_label(hist, 'hist_average', 'pred', 'res', 'yellow')
    # _plot_xy_for_label(hist, 'mi_filled',    'pred', 'res', 'purple')
    plt.axhline(y = 0, color = 'grey')

    n = 24  # slow to run all points :-(
            # 12 takes approx 2 mins to run
            #  8 takes approx 4 mins to run
    xy = hist.iloc[::n, :]
    # x = hist.iloc[::n, :]
    y_l = lowess(xy['res'], xy['pred'])
    plt.plot(y_l[:, 0], y_l[:, 1], 'green', label='lowess fit')

    plt.xlabel('Predictions')
    plt.ylabel('Residuals')
    plt.legend(loc='upper right')
    plt.title(title)


def se_(obs, preds, metric):
    '''Spread estimate used for the rmse and mae confidence bands

    NOTE: returns sqrt(mean squared residual) for 'rmse' and
          sqrt(mean absolute residual) for 'mae'
          These are not divided by sqrt(n) so are not standard errors of a mean
    '''

    if _check_obs_preds_lens_eq(obs, preds) == 0:
        raise ValueError('obs and preds failed the length check')

    if metric == 'rmse':
        se = np.sqrt(np.sum((obs - preds) ** 2) / len(obs))
    elif metric == 'mae':
        se = np.sqrt(np.sum(np.abs(obs - preds)) / len(obs))
    else:
        raise ValueError(f"Unrecognised metric: {metric} - metric should be 'rmse' or 'mae'")

    return se


def metric_ci_vals(test_val, se, z_val = 1.95996):
    cil = z_val * se
    # print('cil:', cil)
    metric_cil = test_val - cil
    metric_ciu = test_val + cil

    return metric_cil, metric_ciu
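

# Illustrative sketch, not part of the original notebook: se_ above returns
# the RMSE and sqrt(MAE) themselves.  One conventional alternative for the
# MAE is the standard deviation of the absolute errors divided by sqrt(n).
# ASSUMPTION: errors are treated as independent, which is optimistic for
# autocorrelated forecast errors.  mae_standard_error_sketch is a
# hypothetical name.
import numpy as np


def mae_standard_error_sketch(obs, preds):
    '''Standard error of the mean absolute error under an i.i.d. assumption'''
    abs_err = np.abs(np.asarray(obs, dtype=float) - np.asarray(preds, dtype=float))

    # ddof=1 gives the sample standard deviation
    return np.std(abs_err, ddof=1) / np.sqrt(len(abs_err))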


# TODO: Remove unused confidence intervals
# NOTE: VAR baseline metrics cvar_rmse and cvar_mae hardcoded to 48 steps
def plot_horizon_metrics(hist, title, y_col=Y_COL, horizon = HORIZON, ci=False):
    steps = list(range(1, horizon + 1))

    # calculate metrics
    z_val_95 = 1.95996
    z_val_50 = 0.674
    rmse_h,   mae_h    = np.zeros(horizon), np.zeros(horizon)
    res_se_h, abs_se_h = np.zeros(horizon), np.zeros(horizon)
    rmse_ciu, rmse_cil = np.zeros(horizon), np.zeros(horizon)
    mae_ciu,  mae_cil  = np.zeros(horizon), np.zeros(horizon)

    for i in range(1, horizon+1):
      obs   = hist.loc[hist['step'] == i, y_col]
      preds = hist.loc[hist['step'] == i, 'pred']
      rmse_h[i-1] = rmse_(obs, preds)
      mae_h[i-1]  =  mae_(obs, preds)
      res_se_h[i-1] = se_(obs, preds, 'rmse')
      abs_se_h[i-1] = se_(obs, preds, 'mae')
      # mae_h[i]  = np.median(np.abs(obs - preds))  # for comparison with baselines
      rmse_cil[i-1], rmse_ciu[i-1] = metric_ci_vals(rmse_h[i-1], res_se_h[i-1], z_val_50)
      mae_cil[i-1],  mae_ciu[i-1]  = metric_ci_vals(mae_h[i-1],  abs_se_h[i-1], z_val_50)

    # print('rmse_h:', rmse_h)
    # print('mae_h:',  mae_h)

    # plot metrics for horizons
    fig, axs = plt.subplots(1, 2, figsize = (14, 7))
    fig.suptitle(title + ' forecast horizon errors')
    axs = axs.ravel()


    mean_val_lab = title + ' mean value'
    axs[0].plot(steps, rmse_h, color='green', label=title)

    if ci is True:
      axs[0].fill_between(steps, rmse_cil, rmse_ciu, color='green', alpha=0.25)

    # i - initial, u - updated, c - corrected
    #ivar_rmse = np.array([0.39, 0.52, 0.64, 0.75, 0.86, 0.96, 1.06, 1.15, 1.23,
    #                     1.31, 1.38, 1.45, 1.51, 1.57, 1.63, 1.68, 1.73, 1.77,
    #                     1.81, 1.85, 1.89, 1.92, 1.96, 1.99, 2.02, 2.05, 2.08,
    #                     2.1 , 2.13, 2.15, 2.18, 2.2 , 2.22, 2.24, 2.26, 2.28,
    #                     2.3 , 2.31, 2.33, 2.35, 2.36, 2.38, 2.39, 2.4 , 2.42,
    #                     2.43, 2.44, 2.45])
    # NOTE: uvar_rmse tested on test_df
    #uvar_rmse = np.array([0.36, 0.49, 0.6, 0.7, 0.8, 0.89, 0.98, 1.06, 1.14,
    #                      1.21, 1.28, 1.35, 1.41, 1.47, 1.52, 1.57, 1.62, 1.66,
    #                      1.7, 1.74, 1.78, 1.81, 1.84, 1.87, 1.9, 1.93, 1.96,
    #                      1.99, 2.01, 2.03, 2.06, 2.08, 2.1, 2.12, 2.14, 2.16,
    #                      2.18,  2.19, 2.21, 2.23, 2.24, 2.26, 2.27, 2.29, 2.3,
    #                      2.31, 2.33, 2.34])
    cvar_rmse = np.array([0.49318888, 0.70222546, 0.88570688, 1.05495349,
    1.21081157, 1.34945832, 1.46844034, 1.57779714, 1.67754323, 1.7665827,
    1.84567039, 1.91561743, 1.97899766, 2.03616174, 2.08661944, 2.13396441,
    2.17809725, 2.21946156, 2.25780078, 2.29370568, 2.3272055,  2.35760153,
    2.38520845, 2.41076185, 2.43404716, 2.45466806, 2.47361784, 2.49117761,
    2.50625606, 2.52023589, 2.53319205, 2.54566125, 2.55764924, 2.56870554,
    2.57976955, 2.59102429, 2.6018822, 2.61242356, 2.62280045, 2.63353767,
    2.64410312, 2.65458709, 2.66532837, 2.67609086, 2.68675178, 2.69745108,
    2.71002892, 2.72445726])
    #axs[0].plot(steps, ivar_rmse, color='black', label='Initial VAR')
    axs[0].plot(steps, cvar_rmse, color='blue', label='Updated VAR')
    axs[0].hlines(np.mean(rmse_h), xmin=1, xmax=horizon,
                  color='green', linestyles='dotted', label=mean_val_lab)
    axs[0].hlines(np.mean(cvar_rmse), xmin=1, xmax=horizon,
                  color='blue', linestyles='dotted', label='Updated VAR mean value')
    axs[0].set_xlabel("horizon - half hour steps")
    axs[0].set_ylabel("rmse")


    axs[1].plot(steps, mae_h, color='green', label=title)

    if ci is True:
      axs[1].fill_between(steps, mae_cil, mae_ciu, color='green', alpha=0.25)

    # NOTE: ivar_mae tested on test_df
    #ivar_mae = np.array([0.39, 0.49, 0.57, 0.66, 0.74, 0.83, 0.91, 0.98, 1.05,
    #                    1.12, 1.18, 1.24, 1.29, 1.34, 1.39, 1.43, 1.47, 1.5 ,
    #                    1.53, 1.56, 1.59, 1.62, 1.64, 1.66, 1.68, 1.7 , 1.72,
    #                    1.73, 1.75, 1.76, 1.77, 1.78, 1.8 , 1.81, 1.82, 1.83,
    #                    1.83, 1.84, 1.85, 1.85, 1.86, 1.86, 1.87, 1.87, 1.88,
    #                    1.88, 1.89, 1.89])
    #uvar_mae = np.array([0.36, 0.45, 0.53, 0.61, 0.69, 0.76, 0.83, 0.9, 0.97,
    #                     1.03, 1.09, 1.14, 1.19, 1.24, 1.28, 1.32, 1.36, 1.4,
    #                     1.43, 1.46, 1.49, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62,
    #                     1.63, 1.65, 1.66, 1.68, 1.69, 1.7, 1.71, 1.72, 1.73,
    #                     1.74, 1.74, 1.75, 1.75, 1.76, 1.76, 1.77, 1.77, 1.78,
    #                     1.78, 1.78, 1.78])
    cvar_mae = np.array([0.34694645, 0.50765333, 0.65132003, 0.78584432,
    0.9077075,  1.01705088, 1.11113622, 1.19759807, 1.27696634, 1.34941444,
    1.4134705,  1.47180058, 1.52304802, 1.56961154, 1.60903759, 1.64763418,
    1.68391297, 1.71690735, 1.74787094, 1.77721642, 1.80442554, 1.82951782,
    1.85358226, 1.87488643, 1.89346337, 1.91069565, 1.92613218, 1.94071845,
    1.95245349, 1.96323923, 1.9736734,  1.98370815, 1.99367508, 2.00204077,
    2.00992601, 2.01796976, 2.02747736, 2.03477489, 2.04173317, 2.04985428,
    2.05843847, 2.06731348, 2.07606609, 2.08533656, 2.09560914, 2.10668272,
    2.1183637,  2.13164371])
    #axs[1].plot(steps, ivar_mae, color='black', label='Initial VAR')
    axs[1].plot(steps, cvar_mae, color='blue', label='Updated VAR')
    axs[1].hlines(np.mean(mae_h), xmin=1, xmax=horizon,
                  color='green', linestyles='dotted', label=mean_val_lab)
    axs[1].hlines(np.mean(cvar_mae), xmin=1, xmax=horizon,
                  color='blue', linestyles='dotted', label='Updated VAR mean value')
    axs[1].set_xlabel("horizon - half hour steps")
    axs[1].set_ylabel("mae")

    plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
    plt.show()
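

# Illustrative sketch, not part of the original notebook: the per-step loop in
# plot_horizon_metrics can also be written as a groupby over the step column.
# ASSUMPTION: rmse_ and mae_ (defined elsewhere) are the usual root mean
# squared error and mean absolute error.  horizon_metrics_sketch and its
# default column names are hypothetical.
import numpy as np
import pandas as pd


def horizon_metrics_sketch(hist, y_col='y', pred_col='pred', step_col='step'):
    '''Return a dataframe of rmse and mae indexed by forecast step'''
    err = hist[y_col] - hist[pred_col]

    # Group residuals by forecast step, then aggregate within each step
    rmse = np.sqrt((err ** 2).groupby(hist[step_col]).mean())
    mae  = err.abs().groupby(hist[step_col]).mean()

    return pd.DataFrame({'rmse': rmse, 'mae': mae})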


def plot_horizon_metrics_boxplots(hist, title):

  hist['abs_err'] = np.abs(hist['res'])
  hist[['abs_err', 'step']].boxplot(by='step',
                                    meanline=False,
                                    showmeans=True,
                                    showcaps=True,
                                    showbox=True,
                                    showfliers=False,
                                    )
  plt.title(title + '\nboxplots with mean and median')
  plt.suptitle('')
  plt.xlabel("horizon - half hour steps")
  plt.ylabel("absolute error")
  x_step = 10.0
  x_max  = np.ceil(np.max(hist.step) / x_step) * int(x_step)
  plt.xticks(np.arange(0, x_max, int(x_step)))
  plt.show()


def plot_multistep_diagnostics(hist, title, y_col=Y_COL):
  title = 'Multi-step ' + title
  plot_multistep_obs_vs_preds(hist, title, y_col)
  plot_multistep_obs_vs_mean_preds_by_step(hist, title, y_col)
  plot_multistep_obs_preds_dists(hist, title, y_col)
  plot_multistep_residuals(hist, title + ' residuals')
  plot_multistep_residuals_dist(hist, title + ' residuals density')
  plot_horizon_metrics(hist, title, y_col)
  plot_horizon_metrics_boxplots(hist, title)
  # plot_multistep_forecast_examples(hist, title + ' forecast examples')


# TODO Refactor this
#      miss, preds, obs, res, err, dates etc "family" of variables
#      is a warning sign
#      try-except around lagged_miss is clear indication of upstream issues
#      Consider using a better data structure
#      See also: plot_forecast_examples immediately below
def _filter_out_missing(pos_neg_rmse_all, miss, lags, subplots):
    '''Select the first `subplots` indices whose lagged and horizon
    observations have no missing == 1.0 flags
    and which are at least `lags` apart (avoid contiguous indices)'''

    # print("pos_neg_rmse_all:", pos_neg_rmse_all)

    pos_neg_rmse = pd.Series(subplots)
    subplot_count = j = 0

    while subplot_count < subplots:
      restart = False
      idx = pos_neg_rmse_all.index[j]
      # print(j, idx, pos_neg_rmse_all.loc[pos_neg_rmse_all.index[j]])

      # Avoid indices in the first few observations
      # Would be incomplete
      if idx < lags:
        # print('idx < lags:', idx)
        j += 1
        continue

      # Avoid contiguous indices - don't want 877, 878, 879
      if subplot_count > 0:
        for i in range(subplot_count):
          if abs(idx - pos_neg_rmse[i]) < lags:
            # print('contiguous indices - idx, pos_neg_rmse[i]:', idx, pos_neg_rmse[i])
            restart = True
            break

      if not restart:
        try:
            lagged_miss = (miss.loc[idx - lags, :] == 1.0).any()
        except KeyError:
            lagged_miss = True

        horizon_miss = (miss.loc[idx, :] == 1.0).any()
        missing = lagged_miss or horizon_miss
        # print("\nlagged_miss:", lagged_miss)
        # print("horizon_miss:",  horizon_miss)
        # print("missing:", missing)

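        # Truthiness test also works when missing is a numpy bool
        # (an identity comparison with False would not)
        if not missing: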
          pos_neg_rmse[subplot_count] = idx
          subplot_count += 1
        #else:
        #  print('missing')

      j += 1

    return pos_neg_rmse
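

# Illustrative sketch, not part of the original notebook: the "avoid
# contiguous indices" rule from _filter_out_missing on its own, using a plain
# list.  It omits the missing-data checks and the idx < lags check.
# pick_spaced_indices_sketch is a hypothetical name.
def pick_spaced_indices_sketch(candidates, min_gap, count):
    '''Return up to count candidates, in order, each at least min_gap apart'''
    chosen = []

    for idx in candidates:
        # Keep idx only if it is far enough from every index already chosen
        if all(abs(idx - c) >= min_gap for c in chosen):
            chosen.append(idx)

        if len(chosen) == count:
            break

    return chosen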


# TODO Refactor this
#      miss, preds, obs, res, err, dates etc "family" of variables
#      is a warning sign
#      Consider using a better data structure
#      See also: _filter_out_missing immediately above
def plot_multistep_forecast_examples(hist, title, subplots = 3, horizon = HORIZON, lags = 48):
    """Plot example forecasts with observations and lagged temperatures.
       Ensure examples are non-contiguous.

       First row shows near zero rmse forecasts.
       Second row shows most positive rmse forecasts.
       Third row shows most negative rmse forecasts.

       missing == 0 - ie no imputation for missing data
    """

    assert subplots in [3, 4, 5]

    # hist = hist.dropna()

    col = 'step'
    id_col = 'id'
    miss  = hist.pivot_table(index=id_col, columns=col, values='missing')
    preds = hist.pivot_table(index=id_col, columns=col, values='pred')
    obs   = hist.pivot_table(index=id_col, columns=col, values='y_des')
    res   = hist.pivot_table(index=id_col, columns=col, values='res')
    err   = hist.pivot_table(index=id_col, columns=col, values='res^2')
    dates = hist.pivot_table(index=id_col, columns=col, values='date')

    miss.dropna(inplace=True)
    # print("miss:", miss.shape)
    preds.dropna(inplace=True)
    # print("preds:", preds.shape)
    # obs.dropna(inplace=True)
    # print("obs:", obs.shape)
    res.dropna(inplace=True)
    # print("res:", res.shape)
    err.dropna(inplace=True)
    # print("err:", err.shape)
    dates.dropna(inplace=True)
    # print("dates:", dates.shape)
    dates = dates.iloc[err.index, :]
    # print("dates indexed:", dates.shape)

    # res_sign = np.sign(-res.mean(axis = 1))
    # err_row_means = err.mean(axis = 1)
    # rmse_rows = res_sign * np.sqrt(err_row_means)
    err_row_means = np.sum(err, axis = 1) / horizon
    res_sum = np.sum(res, axis = 1)
    # print("res_sum:",  len(res_sum))
    # print(res_sum[0:5])
    res_sign  = np.sign(np.sum(res, axis = 1))
    rmse_rows = res_sign * np.sqrt(err_row_means)
    # print("rmse_rows:", len(rmse_rows))
    # print("res_sign:",  len(res_sign))
    # print(res_sign[0:5])

    # choose forecasts - check for missing == 0
    # neg_rmse_all = np.argsort(rmse_rows)
    ##pos_rmse_all = np.flip(np.argsort(rmse_rows))
    # pos_rmse_all = np.argsort(-rmse_rows)
    neg_rmse_all = rmse_rows.sort_values()
    # print(rmse_rows.loc[neg_rmse_all.index])
    pos_rmse_all = neg_rmse_all[::-1]
    # print(rmse_rows.loc[pos_rmse_all.index])
    nz_rmse_all  = rmse_rows.abs().sort_values()
    # print(rmse_rows.loc[nz_rmse_all.index])
    # nz_rmse_all  = np.argsort(np.abs(rmse_rows))  # nz near zero
    # print("\nneg_rmse_all:", len(neg_rmse_all))
    # print(rmse_rows[neg_rmse_all[0:5]])
    # print("pos_rmse_all:", len(pos_rmse_all))
    # print(rmse_rows[pos_rmse_all[0:5]])
    # print("nz_rmse_all: ", len(nz_rmse_all))
    # print(rmse_rows[nz_rmse_all[0:5]])

    nz_rmse  = _filter_out_missing(nz_rmse_all,  miss, lags, subplots)
    pos_rmse = _filter_out_missing(pos_rmse_all, miss, lags, subplots)
    neg_rmse = _filter_out_missing(neg_rmse_all, miss, lags, subplots)

    plot_idx = np.concatenate((nz_rmse, pos_rmse, neg_rmse))
    # print("\nplot_idx:", len(plot_idx))
    # print("\nplot_idx:", plot_idx)

    # plot forecasts
    fig, axs = plt.subplots(3, subplots, sharey = True, figsize = (15, 10))
    fig.tight_layout()
    fig.subplots_adjust(hspace = 0.3, top = 0.87)
    axs = axs.ravel()

    myFmt = mdates.DateFormatter('%H:%M')

    for i in range(3 * subplots):
      # print("plot_idx[i] - lags:",plot_idx[i],  plot_idx[i] - lags)
      axs[i].plot(dates.iloc[plot_idx[i] - lags, :],
                  obs.loc[plot_idx[i] - lags, :],
                  'blue',
                  label='lagged observations')

      axs[i].plot(dates.iloc[plot_idx[i], :],
                  obs.loc[plot_idx[i], :],
                  'green',
                  label='observations')

      axs[i].plot(dates.iloc[plot_idx[i], :],
                  preds.loc[plot_idx[i], :],
                  'orange',
                  label='forecast')

      axs[i].xaxis.set_major_formatter(myFmt)
      obs_dates = dates.iloc[plot_idx[i] - lags, :]
      sub_title = "{0}, {1:d}, {2:.3f}".format(obs_dates.iloc[0],
                                               plot_idx[i],
                                               rmse_rows.loc[plot_idx[i]])
      axs[i].title.set_text(sub_title)

    fig.suptitle(title + "\ninit date, period idx, signed rmse")
    fig.text(0.5, 0.04, 'hour', ha='center')
    fig.text(0.04, 0.5, r'Temperature - $^\circ$C', va='center', rotation='vertical')
    plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
    plt.show()


# WARN This function probably has too many arguments - consider refactoring
def get_rmse_mae_from_backtest(model, param_df, i, series, past_cov, data, prefix, horizon=HORIZON, digits=6):
    backtest = model.historical_forecasts(series = series,
                                          past_covariates = past_cov,
                                          start   = 0.01,
                                          retrain = False,
                                          verbose = True,
                                          forecast_horizon = horizon,
                                          last_points_only = False)
    hc = get_historic_comparison(backtest, data)
    obs   = hc.loc[hc['step'] == horizon, Y_COL]
    preds = hc.loc[hc['step'] == horizon, 'pred']
    param_df.at[i, prefix + '_rmse'] = round(rmse_(obs, preds), digits)
    param_df.at[i, prefix + '_mae']  = round(mae_(obs,  preds), digits)

    return param_df


def plot_lgb_learning_curve(models, title = None, metric = 'l2', margin = None):
    '''Plot learning curve for lightgbm models using the lightgbm plot_metric function

    evals_result_ for validation data missing in action
    So, build 2 models - first with training data for validation
                       - second with validation data for validation
                       - pass both models in as a list
                       - order of models is important

    Training and validation curves are plotted when model.fit is called with
    both training and validation data:
    model.fit(series,
              past_covariates = past_cov,
              val_series = val_ser,
              val_past_covariates = val_past_cov)

    Primarily tested with catboost
    '''

    assert len(models) == 2

    final_rmse = []

    for model in models:
      assert hasattr(model, 'model')
      assert hasattr(model.model, 'evals_result_')

      final_rmse.append(model.model.evals_result_['valid_0'][metric][-1])

    if margin is None:
      lgb.plot_metric(models[0].model.evals_result_)
    else:
      assert margin > 0.0
      y_lim_min = min(final_rmse) - margin
      y_lim_max = max(final_rmse) + margin

      if y_lim_min < 0.0:
        y_lim_min = 0.0

      y_lim = (y_lim_min, y_lim_max)

      lgb.plot_metric(models[0].model.evals_result_, ylim = y_lim)

    plt.plot(models[1].model.evals_result_['valid_0'][metric])
    plt.gca().get_lines()[0].set_color('blue')

    labels_ = ['train']
    if len(plt.gca().get_lines()) == 1 and plt.gca().get_label() == 'valid_0':
      labels_ = ['valid']

    if len(plt.gca().get_lines()) > 1:
      plt.gca().get_lines()[1].set_color('orange')
      labels_.append('valid')

    if title is not None:
      plt.title(title)

    plt.legend(labels = labels_)
    plt.show()


def plot_learning_curves(models, title, margin=0.05):
  print("\n")

  if isinstance(models, list) and len(models) == 2:
    plot_lgb_learning_curve(models, title)
    plot_lgb_learning_curve(models, title, margin = margin)
  else:
    print('Unsupported models argument of type:', type(models))
    print('models should be a list of length 2!')

  print("\n")

  return None


def get_main_plot_title(pre_str, lag_params, mod_params):
  lag_params_str = ', '.join([f"{' '.join(map(str, v))}" for v in lag_params.items()])
  mod_params_str = ', '.join([f"{' '.join(map(str, v))}" for v in mod_params.items()])
  plot_title = pre_str + lag_params_str + '\n' + mod_params_str

  return plot_title


def build_two_lgbm_models(mod_params, data_params, train, valid):
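  '''Build two LightGBM models for plotting learning curves

  The first model is validated on the training data and
  the second model is validated on the validation data

  lgbm only for now ...'''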

  series, past_cov, fut_cov = get_darts_series(train.loc['2016-01-12':,], data_params)
  val_ser, val_past_cov, val_fut_cov = get_darts_series(valid, data_params)

  # add_encoders3 = {'cyclic': {'future': ['minute', 'hour', 'dayofyear']}}
  #                           #'past':   ['minute', 'hour', 'dayofyear']}}
  # model_tr1 = LightGBMModel(**mod_params, add_encoders=add_encoders3)
  # model1    = LightGBMModel(**mod_params, add_encoders=add_encoders3)
  model_tr1 = LightGBMModel(**mod_params)
  model1    = LightGBMModel(**mod_params)

  if data_params['fut_cov_cols'] is not None:
    model_tr1.fit(series,
                  past_covariates = past_cov,
                  future_covariates = fut_cov,
                  val_series = series,
                  val_past_covariates = past_cov,
                  val_future_covariates = fut_cov,
                  callbacks = [lgb.log_evaluation(0)]
                 )
  else:
    model_tr1.fit(series,
                  past_covariates = past_cov,
                  val_series = series,
                  val_past_covariates = past_cov,
                  callbacks = [lgb.log_evaluation(0)]
                 )

  if data_params['fut_cov_cols'] is not None:
    model1.fit(series,
               past_covariates = past_cov,
               future_covariates = fut_cov,
               val_series = val_ser,
               val_past_covariates = val_past_cov,
               val_future_covariates = val_fut_cov,
               callbacks = [lgb.log_evaluation(0)]
              )
  else:
    model1.fit(series,
               past_covariates = past_cov,
               val_series = val_ser,
               val_past_covariates = val_past_cov,
               callbacks = [lgb.log_evaluation(0)]
              )

  return model_tr1, model1


def drop_correlated_cols(dataset, threshold=0.95):
  '''Adapted from https://stackoverflow.com/a/44674459/100129'''

  col_corr = set()  # Set of all the names of deleted columns
  corr_matrix = dataset.corr(numeric_only=True).abs()

  for i in range(len(corr_matrix.columns)):
    for j in range(i):
      if (corr_matrix.iloc[i, j] >= threshold) and (corr_matrix.columns[j] not in col_corr):
        colname = corr_matrix.columns[i]
        col_corr.add(colname)
        if colname in dataset.columns:
          del dataset[colname]

  return dataset
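

# Illustrative sketch, not part of the original notebook: a non-mutating
# variant of drop_correlated_cols which only lists candidate columns.
# NOTE: unlike the loop above, it also compares against columns which would
#       already have been dropped, so results can differ.
# correlated_cols_sketch is a hypothetical name.
import numpy as np
import pandas as pd


def correlated_cols_sketch(dataset, threshold=0.95):
    '''Return names of columns correlated >= threshold with an earlier column'''
    corr = dataset.corr(numeric_only=True).abs()

    # Keep only the strict upper triangle so each pair is considered once
    # and each column is compared with the columns before it
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))

    return [col for col in upper.columns if (upper[col] >= threshold).any()]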


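def plot_observation_examples(df, cols, num_plots = 9):
    """Plot num_plots randomly sampled days of observations in a square grid

    num_plots must be a perfect square - default is 9 plots in 3 * 3 matrix"""

    num_plots_sqrt = int(np.sqrt(num_plots))
    assert num_plots_sqrt ** 2 == num_plots

    # Use .iloc because sample() keeps the original index labels
    days = df.ds.dt.date.sample(n = num_plots).sort_values()
    p_data = [df[df.ds.dt.date.eq(days.iloc[i])] for i in range(num_plots)]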

    fig, axs = plt.subplots(num_plots_sqrt, num_plots_sqrt, figsize = (15, 10))
    axs = axs.ravel()  # apl for the win :-)

    for i in range(num_plots):
        for col in cols:
            axs[i].plot(p_data[i]['ds'], p_data[i][col])
            axs[i].xaxis.set_tick_params(rotation = 20, labelsize = 10)

    fig.suptitle('Observation examples')
    fig.legend(cols, loc = 'lower center',  ncol = len(cols))

    return None


# TODO Change to operate on single dataframe - More useful function :-)
#      Change as far as possible - merge(), common_cols etc
#      Then write a wrapper to operate on before and after dataframes
#      combine results and calculate differences
def sanity_check_df_rows_cols_labels(before, after,
                                     row_var_cutoff=0.005, col_var_cutoff=0.05,
                                     col_corr_cutoff=0.,
                                     fast=True, verbose=False):
  '''Sanity check dataframes before and after modifications

  WARN: default row_var_cutoff, col_var_cutoff, col_corr_cutoff are fairly arbitrary
        there is some redundancy between these tests

  '''

  print_v = print if verbose else lambda *a, **k: None

  df = pd.DataFrame(columns = ['before', 'after', 'diff'])
  df_labels = []

  label = 'rows'
  # start_time = timeit.default_timer()
  i = 0
  df.loc[len(df), df.columns] = before.shape[i], after.shape[i], 0
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'cols'
  # start_time = timeit.default_timer()
  i = 1
  df.loc[len(df), df.columns] = before.shape[i], after.shape[i], 0
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'missing_rows'
  # start_time = timeit.default_timer()
  i = 0
  before_after = pd.merge(before, after, left_index=True, right_index=True, how='outer', indicator=True)
  missing_rows = before_after.loc[before_after['_merge'] == 'left_only', :]
  df.loc[len(df), df.columns] = 0, missing_rows.shape[i], 0
  if missing_rows.shape[i] > 0:
    print_v('\n', label, ':')
    print_v(missing_rows)
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'missing_cols'
  # start_time = timeit.default_timer()
  i = 1
  common_cols = before.columns.intersection(after.columns)
  missing_cols = before.shape[i] - len(common_cols)
  df.loc[len(df), df.columns] = 0, missing_cols, 0
  if missing_cols > 0:
    print_v('\n', label, ':')
    print_v(set(before.columns) - set(common_cols))
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'total_nas'
  # start_time = timeit.default_timer()
  df.loc[len(df), df.columns] = before.isna().sum().sum(), \
                                after.isna().sum().sum(), 0
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'rows_with_nas'
  # start_time = timeit.default_timer()
  before_rows_nas = before.isnull().any(axis=1).sum()
  after_rows_nas  = after.isnull().any(axis=1).sum()
  df.loc[len(df), df.columns] = before_rows_nas, after_rows_nas, 0
  if before_rows_nas != after_rows_nas:
    print_v('\n', label, ':')
    print_v(before[before.isnull().any(axis=1)])
    print_v(after[after.isnull().any(axis=1)])
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'cols_with_nas'
  # start_time = timeit.default_timer()
  before_cols_nas = before.isnull().any().sum()
  after_cols_nas  = after.isnull().any().sum()
  df.loc[len(df), df.columns] = before_cols_nas, after_cols_nas, 0
  if before_cols_nas != after_cols_nas:
    print_v('\n', label, ':')
    print_v(before.columns[before.isnull().any()].values)
    print_v(after.columns[after.isnull().any()].values)
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'single_value_rows'
  if not fast:
    # start_time = timeit.default_timer()
    before_single_value_rows = np.sum(before.nunique(axis=1) <= 1)
    after_single_value_rows  = np.sum(after.nunique(axis=1) <= 1)
    df.loc[len(df), df.columns] = before_single_value_rows, \
                                  after_single_value_rows, 0
    if before_single_value_rows != after_single_value_rows:
      print_v('\n', label, ':')
      print_v(before[before.nunique(axis=1) <= 1])
      print_v(after[after.nunique(axis=1) <= 1])
    df_labels.append(label)
    # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'single_value_cols'
  # start_time = timeit.default_timer()
  before_single_value_cols = np.sum(before.nunique() <= 1)
  after_single_value_cols  = np.sum(after.nunique() <= 1)
  df.loc[len(df), df.columns] = before_single_value_cols, \
                                after_single_value_cols, 0
  if before_single_value_cols != after_single_value_cols:
    print_v('\n', label, ':')
    print_v(before.columns[before.nunique() <= 1].values)
    print_v(after.columns[after.nunique() <= 1].values)
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  # warnings.resetwarnings()

  with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    label = 'low_var_rows'
    # start_time = timeit.default_timer()
    before_low_var_rows = (before.select_dtypes(include=[np.number]).std(axis=1) <= row_var_cutoff).sum()
    after_low_var_rows  = (after.select_dtypes(include=[np.number]).std(axis=1) <= row_var_cutoff).sum()
    df.loc[len(df), df.columns] = before_low_var_rows, after_low_var_rows, 0
    if before_low_var_rows != after_low_var_rows:
      print_v('\n', label, ':')
      print_v(before.select_dtypes(include=[np.number]).std(axis=1) <= row_var_cutoff)
      print_v(after.select_dtypes(include=[np.number]).std(axis=1)  <= row_var_cutoff)
    df_labels.append(label)
    # print('\t', label, round(timeit.default_timer() - start_time, 2))

    label = 'low_var_cols'
    # start_time = timeit.default_timer()
    before_low_var_cols = (before.select_dtypes(include=[np.number]).std() <= col_var_cutoff).sum()
    after_low_var_cols  = (after.select_dtypes(include=[np.number]).std() <= col_var_cutoff).sum()
    df.loc[len(df), df.columns] = before_low_var_cols, after_low_var_cols, 0
    if before_low_var_cols != after_low_var_cols:
      print_v('\n', label, ':')
      s = before.select_dtypes(include=[np.number]).std() <= col_var_cutoff
      t = after.select_dtypes(include=[np.number]).std()  <= col_var_cutoff
      print_v(s[s].index.values)
      print_v(t[t].index.values)
    df_labels.append(label)
    # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'duplicate_rows'
  # start_time = timeit.default_timer()
  before_dup_rows = before.shape[0] - before.drop_duplicates().shape[0]
  after_dup_rows  = after.shape[0]  - after.drop_duplicates().shape[0]
  df.loc[len(df), df.columns] = before_dup_rows, after_dup_rows, 0
  if before_dup_rows != after_dup_rows:
    print_v('\n', label, ':')
    print_v(before[before.duplicated(keep=False)])
    print_v(after[after.duplicated(keep=False)])
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'highly_correlated_cols'
  # .copy() so we don't modify the original dataframe
  if not fast:
    # start_time = timeit.default_timer()
    before_high_corr_cols = before.shape[1] - drop_correlated_cols(before.copy(), col_corr_cutoff).shape[1]
    after_high_corr_cols  = after.shape[1]  - drop_correlated_cols(after.copy(), col_corr_cutoff).shape[1]
    df.loc[len(df), df.columns] = before_high_corr_cols, after_high_corr_cols, 0
    if before_high_corr_cols != after_high_corr_cols:
      print_v('\n', label, ':')
      print_v(set(before.columns) - set(drop_correlated_cols(before.copy(), col_corr_cutoff).columns))
      print_v(set(after.columns) - set(drop_correlated_cols(after.copy(), col_corr_cutoff).columns))
    df_labels.append(label)
    # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'duplicate_index_labels'
  # start_time = timeit.default_timer()
  before_idx_labels = before.index.duplicated().sum()
  after_idx_labels  = after.index.duplicated().sum()
  df.loc[len(df), df.columns] = before_idx_labels, after_idx_labels, 0
  if before_idx_labels != after_idx_labels:
    print_v('\n', label, ':')
    print_v(before.index.duplicated())
    print_v(after.index.duplicated())
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  label = 'duplicate_col_labels'
  # start_time = timeit.default_timer()
  before_dup_col_labels = before.columns.duplicated().sum()
  after_dup_col_labels  = after.columns.duplicated().sum()
  df.loc[len(df), df.columns] = before_dup_col_labels, after_dup_col_labels, 0
  if before_dup_col_labels != after_dup_col_labels:
    print_v('\n', label, ':')
    print_v(before.columns.duplicated())
    print_v(after.columns.duplicated())
  df_labels.append(label)
  # print('\t', label, round(timeit.default_timer() - start_time, 2))

  # TODO Find renamed columns from before in after?


  df['diff'] = df['after'] - df['before']
  df.index = df_labels

  return df


def sanity_check_before_after_dfs(before_, after_, ds_name, fast=True, verbose=False):

  print('\n', ds_name, sep='')

  # Reasons I HATE pandas number Inf a neverending series:
  # PerformanceWarning: DataFrame is highly fragmented.  This is usually the
  # result of calling `frame.insert` many times, which has poor performance.
  # Consider joining all columns at once using pd.concat(axis=1) instead. To
  # get a de-fragmented frame, use `newframe = frame.copy()`
  before = before_.copy()
  after  = after_.copy()

  # start_time = timeit.default_timer()
  sanity_df = sanity_check_df_rows_cols_labels(before, after, fast=fast, verbose=verbose)
  # print('\t sanity_check_df_rows_cols_labels', round(timeit.default_timer() - start_time, 2))


  # start_time = timeit.default_timer()
  print('before.index.equals(after.index):', before.index.equals(after.index))

  # check index freq is set and are equal
  print('before.index.freq == after.index.freq:', before.index.freq == after.index.freq)
  if verbose:
    print('before.index.freq:', before.index.freq)
    print('after.index.freq:',  after.index.freq)


  # check if common column dtypes have changed
  common_cols = before.columns.intersection(after.columns)
  print('before[common_cols].dtypes == after[common_cols].dtypes:',
        (before[common_cols].dtypes == after[common_cols].dtypes).all())
  if verbose:
    print('before[common_cols].dtypes:', before[common_cols].dtypes)
    print('after[common_cols].dtypes:',  after[common_cols].dtypes)

  # check if describe() summaries are equal
  print('before[common_cols].describe() == after[common_cols].describe():',
        (before[common_cols].describe() == after[common_cols].describe()).all().all())
  if verbose:
    print(before[common_cols].describe() == after[common_cols].describe())

  # check after subsetted by before equals before
  print('\nbefore[common_cols].equals(after[common_cols]):',
  before[common_cols].dropna().drop_duplicates().equals(after[common_cols].dropna().drop_duplicates())
  )
  if verbose:
    print('before.isin(after):',
    before[common_cols].dropna().drop_duplicates().isin(after[common_cols].dropna().drop_duplicates()).all().all()
    )
    print(before.dropna().drop_duplicates().isin(after.dropna().drop_duplicates()).all())
    print(before.dropna().drop_duplicates().isin(after.dropna().drop_duplicates()))


  # Reasons I HATE pandas number Inf a neverending series:
  # PerformanceWarning: DataFrame is highly fragmented.  This is usually the
  # result of calling `frame.insert` many times, which has poor performance.
  # Consider joining all columns at once using pd.concat(axis=1) instead. To
  # get a de-fragmented frame, use `newframe = frame.copy()`
  # calculate duplicate row counts then find mean duplicate count
  # for each column and finally find mean of means aka redundancy
  # warnings.resetwarnings()
  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    before_red = before.dropna().groupby(before.select_dtypes(include=np.number).columns.tolist(), as_index=False).size().mean().mean()
    after_red  = after.dropna().groupby(after.select_dtypes(include=np.number).columns.tolist(), as_index=False).size().mean().mean()
    print('redundancy before > after:', before_red > after_red)
    print('mean before feature redundancy:', round(before_red, 3))
    print('mean after feature redundancy: ', round(after_red,  3))

  # Check all data is numeric, finite (but allow NAs) and reasonably shaped
  # If any problems then this will error out
  # Only checking 'after' dataframe
  # https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_X_y.html
  if Y_COL in after.columns:
    _, _ = check_X_y(after.drop(columns=[Y_COL, 'ds']),
                     after[Y_COL],
                     y_numeric = True,
                     force_all_finite = 'allow-nan')

  print()
  # print('\t end sanity_check_before_after_dfs', round(timeit.default_timer() - start_time, 2))

  display(sanity_df)

  return sanity_df


def compare_train_valid_test_sanity_dfs(train_sanity, valid_sanity, test_sanity, ex_labels=None):
  '''...'''

  if ex_labels is None:
    ex_labels = ['rows']

  train_sanity = train_sanity.loc[~train_sanity.index.isin(ex_labels)]
  valid_sanity = valid_sanity.loc[~valid_sanity.index.isin(ex_labels)]
  test_sanity  =  test_sanity.loc[~test_sanity.index.isin(ex_labels)]

  if not train_sanity.equals(valid_sanity):
    print('WARN: train_sanity != valid_sanity')
    display(pd.concat([train_sanity, valid_sanity]).drop_duplicates(keep=False))

  if not train_sanity.equals(test_sanity):
    print('WARN: train_sanity != test_sanity')
    display(pd.concat([train_sanity, test_sanity]).drop_duplicates(keep=False))

  if not test_sanity.equals(valid_sanity):
    print('WARN: test_sanity != valid_sanity')
    display(pd.concat([test_sanity, valid_sanity]).drop_duplicates(keep=False))

  return None


# TODO Remove some of the code duplication
def sanity_check_train_valid_test(train_df, valid_df, test_df,
                                  over_cols = ['y_des', 'dew.point_des', 'humidity', 'pressure'],
                                  dp = 2):

  # Check number of columns is equal
  if (train_df.shape[1] != valid_df.shape[1]) or \
     (train_df.shape[1] != test_df.shape[1])  or \
     (valid_df.shape[1] != test_df.shape[1]):
    print('ERROR: Inconsistent number of columns!')
    print('train_df.shape[1]:', train_df.shape[1])
    print('valid_df.shape[1]:', valid_df.shape[1])
    print('test_df.shape[1]:',  test_df.shape[1])


  # Check column names are equal
  if not train_df.columns.equals(valid_df.columns):
    print('ERROR: Inconsistent train_df, valid_df column names!')
    print('train_df.columns:', train_df.columns)
    print('valid_df.columns:', valid_df.columns)

  if not train_df.columns.equals(test_df.columns):
    print('ERROR: Inconsistent train_df, test_df column names!')
    print('train_df.columns:', train_df.columns)
    print('test_df.columns:',  test_df.columns)

  if not valid_df.columns.equals(test_df.columns):
    print('ERROR: Inconsistent valid_df, test_df column names!')
    print('valid_df.columns:', valid_df.columns)
    print('test_df.columns:',  test_df.columns)


  # Check column dtypes are equal
  if not (train_df.dtypes == valid_df.dtypes).all():
    print('ERROR: Inconsistent train_df, valid_df dtypes!')
    print('train_df.dtypes:', train_df.dtypes)
    print('valid_df.dtypes:', valid_df.dtypes)

  if not (train_df.dtypes == test_df.dtypes).all():
    print('ERROR: Inconsistent train_df, test_df dtypes!')
    print('train_df.dtypes:', train_df.dtypes)
    print('test_df.dtypes:',  test_df.dtypes)

  if not (valid_df.dtypes == test_df.dtypes).all():
    print('ERROR: Inconsistent valid_df, test_df dtypes!')
    print('valid_df.dtypes:', valid_df.dtypes)
    print('test_df.dtypes:',  test_df.dtypes)


  # Check index freqs are equal
  if train_df.index.freq != valid_df.index.freq:
    print('ERROR: Inconsistent train_df, valid_df index frequencies!')
    print('train_df.index.freq:', train_df.index.freq)
    print('valid_df.index.freq:', valid_df.index.freq)

  if train_df.index.freq != test_df.index.freq:
    print('ERROR: Inconsistent train_df, test_df index frequencies!')
    print('train_df.index.freq:', train_df.index.freq)
    print('test_df.index.freq:',   test_df.index.freq)

  if valid_df.index.freq != test_df.index.freq:
    print('ERROR: Inconsistent valid_df, test_df index frequencies!')
    print('valid_df.index.freq:', valid_df.index.freq)
    print('test_df.index.freq:',   test_df.index.freq)


  # Verify dataframes are different!
  if train_df.equals(valid_df):
    print('ERROR: train_df == valid_df!')

  if train_df.equals(test_df):
    print('ERROR: train_df == test_df!')

  if valid_df.equals(test_df):
    print('ERROR: valid_df == test_df!')


  # Check no overlap between train_df.index and valid_df.index
  # train_df.index strictly before valid_df.index and test_df.index
  if max(train_df.index) >= min(valid_df.index):
    print('ERROR: Overlap between train_df, valid_df indices!')
    print('max(train_df.index):', max(train_df.index))
    print('min(valid_df.index):', min(valid_df.index))

  # Check no overlap between train_df.index and test_df.index
  # train_df.index strictly before valid_df.index and test_df.index
  if max(train_df.index) >= min(test_df.index):
    print('ERROR: Overlap between train_df, test_df indices!')
    print('max(train_df.index):', max(train_df.index))
    print('min(test_df.index):',  min(test_df.index))


  # Check no overlap between valid_df.index and test_df.index
  # valid_df.index can be before or after test_df.index
  if (max(valid_df.index) >= min(test_df.index)) and \
     (max(valid_df.index) <= max(test_df.index)):
    print('ERROR: Overlap between valid_df, test_df indices!')
    print('valid_df.index:', min(valid_df.index), '-', max(valid_df.index))
    print('test_df.index:',  min(test_df.index),  '-', max(test_df.index))

  if (min(valid_df.index) >= min(test_df.index)) and \
     (min(valid_df.index) <= max(test_df.index)):
    print('ERROR: Overlap between valid_df, test_df indices!')
    print('valid_df.index:', min(valid_df.index), '-', max(valid_df.index))
    print('test_df.index:',  min(test_df.index),  '-', max(test_df.index))


  # TODO: Consider enforcing a gap of 1 day to 1 week between
  #       train_df.index and {valid_df,test_df}.index to avoid data leakage?


  # Check train_df has more observations than valid_df and test_df
  if valid_df.shape[0] > train_df.shape[0]:
    print('ERROR: valid_df more observations than train_df!')
    print('train_df observations:', train_df.shape[0])
    print('valid_df observations:', valid_df.shape[0])

  if test_df.shape[0] > train_df.shape[0]:
    print('ERROR: test_df more observations than train_df!')
    print('train_df observations:', train_df.shape[0])
    print('test_df observations:',  test_df.shape[0])


  # Check valid_df and test_df have equal number of observations
  # valid_df and test_df may be different sizes but
  # large size difference may indicate an issue
  # TODO: Use calendar.isleap() to check if leap year
  if valid_df.shape[0] != test_df.shape[0]:
    print('WARN: Inconsistent number of valid_df, test_df rows.  Leap year?')


  # Check valid_df and test_df are each 1 year long
  YEAR_OBS_MIN = 48 * 365
  YEAR_OBS_MAX = 48 * 366
  if (valid_df.shape[0] < YEAR_OBS_MIN) or \
     (valid_df.shape[0] > YEAR_OBS_MAX):
    print('ERROR: valid_df should be 1 year long [',
          YEAR_OBS_MIN, ',', YEAR_OBS_MAX, ']!')
    print('valid_df observations:', valid_df.shape[0])

  if (test_df.shape[0] < YEAR_OBS_MIN) or \
     (test_df.shape[0] > YEAR_OBS_MAX):
    print('ERROR: test_df should be 1 year long [',
          YEAR_OBS_MIN, ',', YEAR_OBS_MAX, ']!')
    print('test_df observations:', test_df.shape[0])

  # Check approx number of overlapping rows between train_df and valid_df
  dups_pc_lim = 15.0
  n_dups, dups_pc = get_approx_overlap(train_df, valid_df, over_cols, decs=dp)
  if dups_pc > dups_pc_lim:
    print('WARN: high overlap between train_df and valid_df rows!')
    print(f"Number of shared rows: {n_dups}")
    print(f'Approximate overlap: {dups_pc} %\n')
    # print(f'Decimal places: {dp}')
    # print('Overlap features:', over_cols)

  # Check approx number of overlapping rows between train_df and test_df
  n_dups, dups_pc = get_approx_overlap(train_df, test_df, over_cols, decs=dp)
  if dups_pc > dups_pc_lim:
    print('WARN: high overlap between train_df and test_df rows!')
    print(f"Number of shared rows: {n_dups}")
    print(f'Approximate overlap: {dups_pc} %\n')
    # print(f'Decimal places: {dp}')
    # print('Overlap features:', over_cols)

  return None


def print_train_valid_test_shapes(df, train_df, valid_df, test_df):
  print("df shape: ",            df.shape)
  print("train shape:   ", train_df.shape)
  print("valid shape:   ", valid_df.shape)
  print("test shape:    ",  test_df.shape)

  return None


def plot_feature_history_single_df(data, var, missing=False):
    plt.figure(figsize = (12, 6))
    plt.scatter(data.index, data[var],
                label='train', color='black', s=3)
    if missing:
      label = 'missing'
      x_lab = data.loc[data[label] == 1.0, 'ds']
      y_lab = data.loc[data[label] == 1.0, var]
      plt.scatter(x_lab, y_lab, color='red', label=label, s=3)

    plt.title(var)
    plt.show()


def plot_feature_history(train, valid, test, var, missing=False):
    label = 'missing'

    plt.figure(figsize = (12, 6))
    plt.scatter(train.index, train[var],
                label='train', color='black', s=3)
    if missing:
      x_lab = train.loc[train[label] == 1.0, 'ds']
      y_lab = train.loc[train[label] == 1.0, var]
      plt.scatter(x_lab, y_lab, color='red', label=label, s=3)

    plt.scatter(valid.index, valid[var],
                label='valid', color='blue', s=3)
    if missing:
      x_lab = valid.loc[valid[label] == 1.0, 'ds']
      y_lab = valid.loc[valid[label] == 1.0, var]
      plt.scatter(x_lab, y_lab, color='red', label=label, s=3)

    plt.scatter(test.index,  test[var],
                label='test', color='purple', s=3)
    if missing:
      x_lab = test.loc[test[label] == 1.0, 'ds']
      y_lab = test.loc[test[label] == 1.0, var]
      plt.scatter(x_lab, y_lab, color='red', label=label, s=3)

    plt.title(var)
    #ax = plt.gca()
    #leg = ax.get_legend()
    #leg.legendHandles[0].set_color('black')
    #leg.legendHandles[1].set_color('red')
    #leg.legendHandles[2].set_color('blue')
    #leg.legendHandles[3].set_color('red')
    #leg.legendHandles[4].set_color('purple')
    #leg.legendHandles[5].set_color('red')
    #hl_dict = {handle.get_label(): handle for handle in leg.legendHandles}
    #hl_dict['train'].set_color('black')
    #hl_dict['valid'].set_color('blue')
    #hl_dict['test'].set_color('purple')
    #hl_dict[label].set_color('red')
    #plt.legend(['train', 'valid', 'test', label])
    plt.show()


def plot_feature_history_separately(train, valid, test, var):
    fig, axs = plt.subplots(1, 3, figsize = (14, 7))

    axs[0].plot(train.index, train[var])
    axs[0].set_title('train')

    axs[1].plot(valid.index, valid[var])
    axs[1].set_title('valid')
    axs[1].set_xticks(axs[1].get_xticks(), axs[1].get_xticklabels(), rotation=45, ha='right')

    axs[2].plot(test.index,  test[var])
    axs[2].set_title('test')
    axs[2].set_xticks(axs[2].get_xticks(), axs[2].get_xticklabels(), rotation=45, ha='right')

    fig.suptitle(var)
    plt.show()


def check_high_low_thresholds(df, ds=None):
  '''Check main features from dataframe are within reasonable thresholds'''

  all_ok = True
  feats = ['y', 'dew.point', 'humidity', 'pressure',
           'wind.speed.mean', 'wind.speed.max']
  highs = [ 45,  25, 100, 1060, 35, 70]
  lows  = [-20, -20,   5,  950,  0,  0]

  thresh = pd.DataFrame({'feat': feats,
                         'high': highs,
                         'low':  lows,})
  thresh.index = feats

  for feat in feats:
    feat_high = thresh.loc[feat, 'high']
    feat_low  = thresh.loc[feat, 'low']

    if not df[feat].between(feat_low, feat_high).all():
      all_ok = False
      print('%15s [%3d, %3d] - % 7.3f, % 7.3f' %
            (feat, feat_low, feat_high,
            round(min(df[feat]), 3), round(max(df[feat]), 3)))

  # check if dew.point ever greater than temperature
  if df.loc[df['dew.point'] > df['y'], ['y', 'dew.point']].shape[0] != 0:
    all_ok = False
    print('dew.point > y:')
    display(df.loc[df['dew.point'] > df['y'], ['y', 'dew.point']])

  if all_ok is False:
    print(' ... from', ds)

  return None


def get_features_filename(feat_name, data_name, date_str, file_ext='.csv.xz'):
    return feat_name + data_name + date_str + file_ext


def merge_data_and_aggs(data, aggs):
  data = pd.concat((data, aggs), axis=1)
  # data = data.join(aggs)

  # data.set_index('ds', drop = False, inplace = True)
  data['ds'] = data.index
  data = data[~data.index.duplicated(keep = 'first')]
  data = data.asfreq(freq = '30min')

  # Reasons I HATE pandas number Inf a neverending series:
  # PerformanceWarning: DataFrame is highly fragmented.  This is usually the
  # result of calling `frame.insert` many times, which has poor performance.
  # Consider joining all columns at once using pd.concat(axis=1) instead. To
  # get a de-fragmented frame, use `newframe = frame.copy()`
  # data_ = data.copy()

  return data


def print_null_columns(df, df_name):
  print('\n', df_name, 'null columns:')
  display(df[df.columns[df.isnull().any()]].isnull().sum())


def print_na_locations(df):
  '''Print index row and column labels for NA in dataframe'''

  # iterrows() yields (index label, row Series); items() would iterate over columns
  for index, row in df[df.isna().any(axis=1)].iterrows():
    for col_name, row_item in row.items():
      if pd.isnull(row_item):
        print(index, col_name)

  return None


def get_darts_series(data, data_params):
  series   = TimeSeries.from_dataframe(data, value_cols=data_params['y_col'])
  past_cov = TimeSeries.from_dataframe(data, value_cols=data_params['past_cov_cols'])

  if data_params['fut_cov_cols'] is not None:
    fut_cov = TimeSeries.from_dataframe(data, value_cols=data_params['fut_cov_cols'])
  else:
    fut_cov = None

  return series, past_cov, fut_cov


def plot_short_term_acf(data, acf_feats, acf_cols, title_,
                        mean_feat = False, max_lags = 300):
  plt.figure(figsize = (12, 6))

  # index by lag so the x axis shows the actual lag (starting at 1)
  acf = pd.DataFrame(index=range(1, max_lags))

  for acf_feat, acf_col in zip(acf_feats, acf_cols):
    acf[acf_feat] = [data[acf_feat].autocorr(l) for l in range(1, max_lags)]
    plt.plot(acf[acf_feat], label=acf_feat, c=acf_col)

  if mean_feat:
    acf['mean_acf'] = acf.mean(axis=1)
    plt.plot(acf['mean_acf'], label='mean_acf', c='black')

  plt.axhline(0, linestyle='--', c='black')
  plt.axhline(0.875, linestyle=':', c='lightgrey')
  plt.ylabel('autocorrelation')
  plt.xlabel('time lags')
  plt.title(title_)
  plt.legend()
  plt.show()


def plot_long_term_acf(data, var, num_years=3):
  # WARN: Slow function :-(
  #       Results are more useful when displaying more years of data
  pd.plotting.autocorrelation_plot(data[var].head(17532 * num_years))
  plt.title(var)
  plt.show()


def print_df_summary(df):
  print("Shape:")
  display(df.shape)

  total_nas = df.isna().sum().sum()
  rows_nas  = df.isnull().any(axis=1).sum()
  cols_nas  = df.isnull().any().sum()
  print('\nTotal NAs:', total_nas)
  print('Rows with NAs:', rows_nas)
  print('Cols with NAs:', cols_nas)

  print("\nInfo:")
  display(df.info())

  print("\nSummary stats:")
  display(df.describe())

  print("\nRaw data:")
  display(df)
  print("\n")


def get_approx_overlap(X1, X2, over_cols, decs=2, verbose=False):
  '''Calculate approximate overlap between 2 dataframes of different sizes.

  If exact values are used then overlap is probably too low,
  so use np.round() to reduce precision.
  Use MinMaxScaler so single decimals parameter is applicable to all columns.
  Assumes X1 is train and X2 is valid/test.
  Duplicates dropped from X1 & X2 before calculating overlap.
  Percent overlap can be greater than 100 if decs is too low.

  Based on https://stackoverflow.com/a/71002234/100129
  '''

  assert X1.shape[0] >= X2.shape[0]

  X1 = X1[over_cols].drop_duplicates()
  X2 = X2[over_cols].drop_duplicates()

  Xcomb = pd.concat((X1, X2), axis=0, ignore_index=True)

  # scale
  scaler = MinMaxScaler()
  Xscl = scaler.fit_transform(Xcomb)

  # round
  # df_scl = pd.DataFrame(np.round(Xcomb, decimals=decs), columns=over_cols)
  df_scl = pd.DataFrame(np.round(Xscl, decimals=decs), columns=over_cols)

  # count overlaps
  n_uniq = df_scl.drop_duplicates().shape[0]
  n_dup = X1.shape[0] + X2.shape[0] - n_uniq
  dup_pc = round(n_dup * 100 / X2.shape[0], 2)

  if verbose:
    print(f"Number of shared rows: {n_dup}")
    print(f'Approximate overlap: {dup_pc} %\n')

  if dup_pc > 100.0:
    print('Approx. overlap over 100 %!')
    print('Increase decs argument')
    print(f"decs = {decs}\n")

  return n_dup, dup_pc


def add_transform_column(mod_vi_df, us_feats):
  mod_vi_df['transform'] = 'None'

  ne_mask = mod_vi_df['feature_transform'] != mod_vi_df['feature']
  ne_fs = mod_vi_df.loc[ne_mask, 'feature'].to_list()
  ne_fts = mod_vi_df.loc[ne_mask, 'feature_transform'].to_list()
  # display(mod_vi_df.loc[ne_mask, ['feature', 'feature_transform']])
  # print('ne_fs:', ne_fs)
  # print('ne_fts:', ne_fts)

  ne_ts = [ne_ft.replace(ne_fs[i] + '_', '') for i, ne_ft in enumerate(ne_fts)]

  mod_vi_df.loc[ne_mask, 'transform'] = ne_ts

  for us_feat in us_feats:
    if us_feat.endswith('_des'):
      # 'y_des', 'dew.point_des'
      mod_vi_df.loc[mod_vi_df['feature_transform'] == us_feat, 'transform'] = 'des'
    else:
      # 't_pot', 'vp_def', 'air_density', 'ground_hf', 'za_rad', ...
      mod_vi_df.loc[mod_vi_df['feature'] == us_feat, 'transform'] = 'None'

  wind_mask = mod_vi_df['feature_transform_lag'].str.contains('_window_')
  # (?:a|b|c) is a non-capturing alternation; [a|b|c] would be a character class
  mod_vi_df.loc[wind_mask, 'transform'] = mod_vi_df.loc[wind_mask, 'feature_transform_lag'].str.extract(r'_window_\d+_(.*)_(?:target|pastcov|futcov)', expand=False)

  return mod_vi_df


def add_feature_column(mod_vi_df, us_feats):
  mod_vi_df['feature'] = ''

  # extract features which DO NOT contain underscores
  # eg. irradiance, pressure ...
  useq0_mask = mod_vi_df['feature_transform'].str.count('_') == 0
  mod_vi_df.loc[useq0_mask, 'feature'] = mod_vi_df.loc[useq0_mask, 'feature_transform']
  # print('useq0:')
  # display(mod_vi_df.loc[useq0_mask, 'feature'])


  # extract features which DO contain underscores
  # cannot distinguish between y_des and pressure_grad
  useq1_mask = mod_vi_df['feature_transform'].str.count('_') == 1
  # NOTE: expand=False makes str.extract return a Series rather than a
  #       one-column DataFrame (at least an hour wasted finding this out)
  mod_vi_df.loc[useq1_mask, 'feature'] = mod_vi_df.loc[useq1_mask, 'feature_transform'].str.extract(r'^(.*)_', expand=False)
  # print('useq1:')
  # display(mod_vi_df.loc[useq1_mask, 'feature'])

  # cannot distinguish between y_des_hist_mode and pressure_intervals_intervals_mean
  usgt1_mask = mod_vi_df['feature_transform'].str.count('_') > 1
  # NOTE: expand=False again, so str.extract returns a Series
  mod_vi_df.loc[usgt1_mask, 'feature'] = mod_vi_df.loc[usgt1_mask, 'feature_transform'].str.extract(r'^([a-zA-Z0-9\.]+_[a-zA-Z0-9\.]+)_', expand=False)
  # print('usgt1:')
  # display(mod_vi_df.loc[usgt1_mask, 'feature'])

  # try and fix cannot distinguish problems mentioned above
  for us_feat in us_feats:
    mod_vi_df.loc[mod_vi_df['feature_transform'] == us_feat, 'feature'] = us_feat

  wind_mask = mod_vi_df['feature_transform_lag'].str.contains('_window_')
  mod_vi_df.loc[wind_mask, 'feature'] = mod_vi_df.loc[wind_mask, 'feature_transform_lag'].str.extract(r'^(.*?)_window_\d+', expand=False)

  return mod_vi_df


def add_shadow_imp_column(mod_vi_df):
  shad_imp = mod_vi_df.loc[mod_vi_df['feature'] == 'y_des_shadow', ['feature', 'model', 'lag', 'imp']]
  shad_imp = shad_imp.rename(columns={'imp': 'shadow_imp'})

  mod_vi_si = pd.merge(mod_vi_df, shad_imp, how='outer', left_on=['model', 'lag'], right_on=['model', 'lag'])
  mod_vi_si.drop('feature_y', axis=1, inplace=True)
  mod_vi_si.rename(columns={'feature_x': 'feature'}, inplace=True)
  mod_vi_si['shadow_imp'] = mod_vi_si['shadow_imp'].fillna(0)
  mod_vi_si['shadow_imp'] = mod_vi_si['shadow_imp'].astype('int')

  return mod_vi_si


def add_feature_window_transform_column(mod_vi_df):
  mod_vi_df['feature_window_transform'] = mod_vi_df['feature_transform_lag']
  with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    mod_vi_df['feature_window_transform'] = mod_vi_df['feature_window_transform'].str.replace('_lag.*', '', regex=True)
    mod_vi_df['feature_window_transform'] = mod_vi_df['feature_window_transform'].str.replace('_(target|pastcov|futcov)', '', regex=True)
    mod_vi_df['feature_window_transform'] = mod_vi_df['feature_window_transform'].str.replace(r'_shift_\d+', '', regex=True)

  return mod_vi_df


def add_feature_transform_column(mod_vi_df):
  mod_vi_df['feature_transform'] = mod_vi_df['feature_window_transform']
  with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    mod_vi_df['feature_transform'] = mod_vi_df['feature_transform'].str.replace(r'_window_\d+', '', regex=True)

  # print('feature_transform:')
  # display(mod_vi_df.loc[mod_vi_df['feature_transform'].isna(), 'feature_transform'])

  return mod_vi_df


def get_multi_model_feat_imps(mmodel, horizon=HORIZON, verbose=False):
  mod_vi_list = []

  for i in range(horizon):
    mod_vi = pd.DataFrame({
        'model': i,
        'feature_transform_lag': mmodel.lagged_feature_names,
        'imp': mmodel.get_multioutput_estimator(i, 0).feature_importances_})
    mod_vi_list.append(mod_vi)

  mod_vi_df = pd.concat(mod_vi_list)
  mod_vi_df.reset_index(drop=True, inplace=True)

  # should have been fixed upstream - my bad
  with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    mod_vi_df['feature_transform_lag'] = mod_vi_df['feature_transform_lag'].str.replace(r'^_(.*?)_\1', r'_\1', regex=True)
    # mod_vi_df['feature_transform_lag'] = mod_vi_df['feature_transform_lag'].str.replace('intervals_intervals', 'intervals')

  mod_vi_df = add_feature_window_transform_column(mod_vi_df)
  mod_vi_df = add_feature_transform_column(mod_vi_df)

  # us_feats - feature names which contain underscores
  us_feats = ['y_des', 'dew.point_des', 't_pot', 'vp_def', 'air_density',
              'ground_hf', 'za_rad', 'azimuth_cos', 'azimuth_sin',
              'rain_prev_6_hours', 'rain_prev_12_hours', 'rain_prev_24_hours',
              'rain_prev_24_hours_binary', 'rain_prev_48_hours',
              'rain_prev_48_hours_binary', 'mixing_ratio', 'specific_humidity',
              'vapour_pressure',]
  mod_vi_df = add_feature_column(mod_vi_df, us_feats)
  mod_vi_df = add_transform_column(mod_vi_df, us_feats)

  # raw strings avoid invalid escape sequence warnings for \d
  mod_vi_df['lag']  = mod_vi_df['feature_transform_lag'].str.extract(r'_lag(-?\d+)$').fillna(0).astype(int)
  mod_vi_df['type'] = mod_vi_df['feature_transform_lag'].str.extract(r'_(target|pastcov|futcov)_').fillna('')
  mod_vi_df['shift']  = mod_vi_df['feature_transform_lag'].str.extract(r'_shift_(\d+)_').fillna(0).astype(int)
  mod_vi_df['window'] = mod_vi_df['feature_transform_lag'].str.extract(r'_window_(\d+)_').fillna(0).astype(int)

  # very slow (2 or 3 mins) :-(
  mod_vi_df = get_above_between_below_feats_with_model_lag(mod_vi_df, ['dew.point_des', 'pressure', 'humidity'], 'cf')
  # mod_vi_df = get_above_between_below_feats_with_model_lag(mod_vi_df, ['irradiance', 'za_rad'], 'sf')
  # mod_vi_df = get_above_between_below_feats_with_model_lag(mod_vi_df, ['y_des_shadow'], 'shadow')

  mod_vi_si = add_shadow_imp_column(mod_vi_df)

  mod_vi_si['shadow_geq'] = 0
  mod_vi_si.loc[mod_vi_si['shadow_imp'] >= mod_vi_si['imp'], 'shadow_geq'] = 1

  if verbose:
    display(mod_vi_si)

  return mod_vi_si


# WARN Very slow - 2 or 3 mins :-(
# TODO Speed it up!
def get_above_between_below_feats_with_model_lag(fs, feats, feat_str, imp='imp'):
  # cf - core features
  # sf - solar features

  fs['above_' + feat_str] = 0
  fs['below_' + feat_str] = 0
  if len(feats) > 1:
      fs['between_' + feat_str] = 0

  # groupby model and lag only: the thresholds below depend on (model, lag),
  # so also grouping by feature just repeats identical assignments
  fs_gb = fs[['model', 'lag']].groupby(['model', 'lag'])

  for (model, lag), _ in fs_gb:
    # rows belonging to this model and lag
    ml_mask = (fs['model'] == model) & (fs['lag'] == lag)

    # importance range spanned by the reference features (feats)
    ref_imps = fs.loc[ml_mask & fs['feature'].isin(feats), imp]
    min_score = ref_imps.min()
    max_score = ref_imps.max()

    fs.loc[ml_mask & (fs[imp] > max_score), 'above_' + feat_str] = 1
    fs.loc[ml_mask & (fs[imp] < min_score), 'below_' + feat_str] = 1

    if len(feats) > 1:
      fs.loc[ml_mask & (fs[imp] >= min_score) & (fs[imp] <= max_score), 'between_' + feat_str] = 1
      # fs.loc[ml_mask & (fs['feature'].isin(feats)), 'between_' + feat_str] = 1

  return fs


# TODO Fix code duplication with plot_multi_model_single_feature_imp
def plot_multi_model_feat_imps(model_vi, title, imp_col='imp', ft_mean_lim=10):
  if title is None:
    title = ''

  n = 10
  display(model_vi.sort_values(imp_col).tail(n))
  display(model_vi.sort_values(imp_col).head(n))

  model_vi.boxplot(imp_col)
  plt.title(title + ' - Variable importance')
  plt.show()

  model_vi.boxplot(imp_col, by='type')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Feature types')
  plt.suptitle('')
  plt.show()

  bxpm = model_vi.boxplot(imp_col, by='model')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Models')
  plt.suptitle('')
  n = 2
  plt.setp(bxpm.get_xticklabels()[::n], visible=False)
  plt.show()

  if np.min(model_vi['lag']) < 0:
    bxpl = model_vi.loc[model_vi['lag'] < 0,].boxplot(imp_col, by='lag')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Past Lags')
    plt.suptitle('')
    plt.show()

  if np.max(model_vi['lag']) > 0:
    bxpl = model_vi.loc[model_vi['lag'] >= 0,].boxplot(imp_col, by='lag')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Future Lags')
    plt.suptitle('')
    n = 2
    plt.setp(bxpl.get_xticklabels()[::n], visible=False)
    plt.show()

  if model_vi['window'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='window')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Windows')
    plt.suptitle('')
    plt.show()

  if model_vi['shift'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='shift')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Shifts')
    plt.suptitle('')
    plt.show()

  # fi_ft = groupby_multi_model_feat_imps(model_vi, 'feature_transform')
  # print('fi_ft:')
  # display(fi_ft)
  # fts = fi_ft.loc[fi_ft['mean'] >= ft_mean_lim,].index.values.tolist()
  # print(f'{fts = }')
  #
  # model_vi.loc[model_vi['feature_transform'].isin(fts),].boxplot(imp_col, by='feature_transform')
  # plt.xticks(rotation=45, ha='right')
  # plt.ylabel(str(imp_col))
  # plt.title(title + ' - Feature transform mean >= ' + str(ft_mean_lim))
  # plt.suptitle('')
  # plt.show()

  model_vi.boxplot(imp_col, by='feature_transform')
  plt.ylabel(str(imp_col))
  plt.xticks(rotation=45, ha='right')
  plt.title(title + ' - Feature Transform')
  plt.suptitle('')
  plt.show()

  model_vi.loc[model_vi['above_cf'] == 1,].boxplot(imp_col, by='feature')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Features (above core features)')
  plt.suptitle('')
  plt.show()

  model_vi.loc[model_vi['between_cf'] == 1,].boxplot(imp_col, by='feature')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Features (between core features)')
  plt.suptitle('')
  plt.show()

  model_vi.loc[model_vi['above_cf'] == 1,].boxplot(imp_col, by='transform')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Transforms (above core features)')
  plt.suptitle('')
  plt.show()

  model_vi.loc[model_vi['between_cf'] == 1,].boxplot(imp_col, by='transform')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Transforms (between core features)')
  plt.suptitle('')
  plt.show()

  model_vi.boxplot(imp_col, by='feature')
  plt.ylabel(str(imp_col))
  plt.xticks(rotation=45, ha='right')
  plt.title(title + ' - Base Features')
  plt.suptitle('')
  plt.show()

  model_vi.boxplot(imp_col, by='transform')
  plt.ylabel(str(imp_col))
  plt.xticks(rotation=45, ha='right')
  plt.title(title + ' - Transform')
  plt.suptitle('')
  plt.show()


def plot_x_y_importance_interaction(model_vi, xvar, yvar, imp_var, title):
  x = sorted(model_vi[xvar].unique())
  y = sorted(model_vi[yvar].unique())
  X, Y = np.meshgrid(x, y)

  # the first positional argument of groupby min/max/mean is numeric_only,
  # not a column name; the frame is already limited to imp_var
  model_vi_grp_min  = model_vi[[yvar, xvar, imp_var]].groupby([yvar, xvar]).min()
  model_vi_grp_max  = model_vi[[yvar, xvar, imp_var]].groupby([yvar, xvar]).max()
  model_vi_grp_mean = model_vi[[yvar, xvar, imp_var]].groupby([yvar, xvar]).mean()

  zs_min  = np.array(model_vi_grp_min[imp_var])
  zs_max  = np.array(model_vi_grp_max[imp_var])
  zs_mean = np.array(model_vi_grp_mean[imp_var])
  Zmin  = zs_min.reshape(X.shape)
  Zmax  = zs_max.reshape(X.shape)
  Zmean = zs_mean.reshape(X.shape)

  fig = plt.figure(figsize=(14, 6))
  ax1 = fig.add_subplot(131, projection='3d')
  ax1.plot_surface(X, Y, Zmin, cmap='viridis')
  ax1.set_xlabel(xvar.title())
  ax1.set_ylabel(yvar.title())
  ax1.set_title(title + ' - min')
  # ax1.set_zlabel(str(imp_var).title())

  ax2 = fig.add_subplot(132, projection='3d')
  ax2.plot_surface(X, Y, Zmean, cmap='viridis')
  ax2.set_xlabel(xvar.title())
  ax2.set_ylabel(yvar.title())
  ax2.set_title(title + ' - mean')
  # ax2.set_zlabel(str(imp_var).title())

  ax3 = fig.add_subplot(133, projection='3d')
  ax3.plot_surface(X, Y, Zmax, cmap='viridis')
  ax3.set_xlabel(xvar.title())
  ax3.set_ylabel(yvar.title())
  ax3.set_title(title + ' - max')
  ax3.set_zlabel(str(imp_var).title())
  plt.show()


def plot_multi_model_importance_interactions(model_vi, title=None):

  plot_x_y_importance_interaction(model_vi, 'model', 'lag', 'imp', title)

  if model_vi['window'].nunique() >= 2:
    plot_x_y_importance_interaction(model_vi, 'model', 'window', 'imp', title)
    plot_x_y_importance_interaction(model_vi.loc[model_vi['lag'] < 0,],
                                    'lag',
                                    'window',
                                    'imp',
                                    title)

  if model_vi['shift'].nunique() >= 2:
    plot_x_y_importance_interaction(model_vi, 'model', 'shift', 'imp', title)
    plot_x_y_importance_interaction(model_vi.loc[model_vi['lag'] < 0,],
                                    'lag',
                                    'shift',
                                    'imp',
                                    title)


def groupby_multi_model_feat_imps(mod_imps, group, verbose=False):
  if verbose:
    print('y_des_shadow:')
    display(mod_imps.loc[mod_imps['feature'] == 'y_des_shadow', 'imp'].describe())

  # mi_gb - model importance group by
  mi_gb_desc = mod_imps[[group, 'imp']].groupby(group).describe()
  mi_gb_desc = mi_gb_desc.droplevel(axis=1, level=0)
  mi_gb_desc = mi_gb_desc.astype({'count': 'int', 'min': 'int', '25%': 'int',
                                  '50%': 'int', '75%': 'int', 'max': 'int'})
  if verbose:
    display(mi_gb_desc)


  mi_gb_sum = mod_imps[[group, 'imp']].groupby(group).sum()
  mi_gb_sum = mi_gb_sum.rename(columns={'imp': 'sum'})
  if verbose:
    print('\nsum:')
    display(mi_gb_sum)

  mi_gb_desc_sum = pd.merge(mi_gb_desc, mi_gb_sum, left_index=True, right_index=True)


  mi_gb_0imp = mod_imps.loc[mod_imps['imp'] == 0].groupby(group).size().to_frame()
  mi_gb_0imp = mi_gb_0imp.rename(columns={0: 'num_0'})
  if verbose:
    print('\nnum 0 imp:')
    display(mi_gb_0imp.sort_values('num_0'))


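  # Left merge keeps groups which have no zero importance features (num_0 = 0)
  mi_gb_desc_sum_0imp = pd.merge(mi_gb_desc_sum, mi_gb_0imp, how='left', left_index=True, right_index=True)
  mi_gb_desc_sum_0imp['num_0'] = mi_gb_desc_sum_0imp['num_0'].fillna(0).astype(int)
  mi_gb_desc_sum_0imp['pc_0'] = mi_gb_desc_sum_0imp['num_0'] * 100 / mi_gb_desc_sum_0imp['count']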
  # display(mi_gb_desc_sum_0imp)


  mi_gb_shad_geq = mod_imps[[group, 'shadow_geq']].groupby(group).sum()
  mi_gb_shad_geq = mi_gb_shad_geq.rename(columns={'shadow_geq': 'num_shad_geq'})
  if verbose:
    print('\nnum_shad_geq:')
    display(mi_gb_shad_geq)

  mi_gb_all = pd.merge(mi_gb_desc_sum_0imp, mi_gb_shad_geq, left_index=True, right_index=True)
  mi_gb_all['pc_shad_geq'] = mi_gb_all['num_shad_geq'] * 100 / mi_gb_all['count']
  # display(mi_gb_all)

  return mi_gb_all


def plot_feat_imp_cumsum(mod_vi, title):
  mod_impcs = mod_vi[['model', 'imp']].groupby('model').apply(lambda grp: grp.imp.sort_values(ascending=False).cumsum().reset_index(drop=True))
  mod_impcs = mod_impcs.melt(ignore_index=False).reset_index().rename(columns={'imp': 'order', 'value': 'imp_cumsum'})

  xvar = 'order'
  yvar = 'model'
  fig = plt.figure(figsize=(14, 6))
  ax1 = fig.add_subplot(121, projection='3d')
  ax1.plot_trisurf(mod_impcs[xvar], mod_impcs[yvar], mod_impcs['imp_cumsum'],
                   cmap='viridis', linewidth=0)

  ax1.set_xlabel(xvar.title())
  ax1.set_ylabel(yvar.title())
  ax1.set_title(title)

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.plot_trisurf(mod_impcs[yvar], mod_impcs[xvar], mod_impcs['imp_cumsum'],
                   cmap='viridis', linewidth=0)

  ax2.set_xlabel(yvar.title())
  ax2.set_ylabel(xvar.title())
  ax2.set_zlabel('imp_cumsum')
  ax2.set_title(title)
  plt.show()


# TODO Fix code duplication with plot_multi_model_feat_imps
def plot_multi_model_single_feature_imp(model_vi, col_name, col_value, title, imp_col='imp'):
  if title is None:
    title = col_value
  else:
    title = title + ' ' + col_value

  model_vi = model_vi.loc[model_vi[col_name] == col_value,]

  model_vi.boxplot(imp_col)
  plt.title(title + ' - Variable importance')
  plt.show()

  if model_vi['type'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='type')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Feature types')
    plt.suptitle('')
    plt.show()

  bxpm = model_vi.boxplot(imp_col, by='model')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Models')
  plt.suptitle('')
  n = 2
  plt.setp(bxpm.get_xticklabels()[::n], visible=False)
  plt.show()

  # palette = ['red', 'green']
  # box_cols = ['y_des_window_24_hist_mode', 'y_des_shadow']
  # model3_vi_bp = model3_vi.loc[model3_vi['feature'].isin(box_cols)]
  # sns.boxplot(data=model3_vi_bp, x='lag', y='imp', hue='feature',
  #             gap=.2, palette=palette, fill=False, linewidth=.75)
  # plt.show()

  if np.min(model_vi['lag']) < 0:
    bxpl = model_vi.loc[model_vi['lag'] < 0,].boxplot(imp_col, by='lag')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Past Lags')
    plt.suptitle('')
    plt.show()

  if np.max(model_vi['lag']) > 0:
    bxpl = model_vi.loc[model_vi['lag'] >= 0,].boxplot(imp_col, by='lag')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Future Lags')
    plt.suptitle('')
    n = 2
    plt.setp(bxpl.get_xticklabels()[::n], visible=False)
    plt.show()

  if model_vi['window'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='window')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Windows')
    plt.suptitle('')
    plt.show()

  if model_vi['shift'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='shift')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Shifts')
    plt.suptitle('')
    plt.show()

  if model_vi['transform'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='transform')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Transforms')
    plt.suptitle('')
    plt.show()


  xvar = 'model'
  yvar = 'lag'

  fig = plt.figure(figsize=(14, 7))
  ax1 = fig.add_subplot(121, projection='3d')
  ax1.plot_trisurf(model_vi[xvar], model_vi[yvar], model_vi[imp_col],
                   cmap='viridis', linewidth=0)
  # ax1.plot_trisurf(model_vi[xvar], model_vi[yvar], model_vi['shadow_imp'],
  #                  color='pink', linewidth=0)
  ax1.set_xlabel(xvar.title())
  ax1.set_ylabel(yvar.title())
  ax1.set_title(title)

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.plot_trisurf(model_vi[yvar], model_vi[xvar], model_vi[imp_col],
                   cmap='viridis', linewidth=0)
  # ax2.plot_trisurf(model_vi[yvar], model_vi[xvar], model_vi['shadow_imp'],
  #                  color='pink', linewidth=0)
  ax2.set_xlabel(yvar.title())
  ax2.set_ylabel(xvar.title())
  # ax2.zaxis.set_rotate_label(False)  # disable automatic rotation
  ax2.set_zlabel(str(imp_col))  #, rotation=-90)
  ax2.set_title(title)
  plt.show()
  print('\n\n')


def load_features_file(feature_set,
                       data_set,
                       location = 'gdrive',
                       date_str = '.2022.09.20',
                       filex    = '.parquet'):

  if location == 'github':
    base_url = 'https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/data/features/'
    filex += '?raw=true'
  elif location == 'gdrive':
    base_url = '/content/drive/MyDrive/data/CambridgeTemperatureNotebooks/features/'
  else:
    # Fail early: base_url would otherwise be undefined below
    raise ValueError(f"Unsupported 'location' in load_features_file function: {location!r}")

  file_str = feature_set + '_' + data_set + date_str + filex
  data_url = base_url + file_str

  df = pd.read_parquet(data_url)

  df.set_index('ds', drop=False, inplace=True)
  df = df[~df.index.duplicated(keep='first')]
  df = df.asfreq(freq='30min')

  return df


def load_train_valid_test_features(feature_set, location='gdrive'):
  train = load_features_file(feature_set, 'train', location)
  valid = load_features_file(feature_set, 'valid', location)
  test  = load_features_file(feature_set, 'test',  location)

  sanity_check_train_valid_test(train, valid, test)

  check_high_low_thresholds(train, 'train '+feature_set)
  check_high_low_thresholds(valid, 'valid '+feature_set)
  check_high_low_thresholds(test,   'test '+feature_set)

  return train, valid, test


def get_feature_selection_data(df, sel_cols, y_col, fs_lags, pred_step):
  feat_cols = df.columns.to_list()

  excludes = ['_ucm_', '_yhat', '_diff_', '_yearly', '_daily', '_trend',
              'humidity_des', 'pressure_des', 'ds', 'spike', 'tsclean',
              'day.', 'year', 'rain', 'wind.', '.log', 'dy_dh', 'dy_dp',
              'missing', 'dT_dH', 'dT_dP', 'dT_dTdp', 'known_inaccuracy',
              'isd_', 'long_run', 'cooksd_out', 'tau'
              ]
  feat_cols = [feat_col for feat_col in feat_cols if all([x not in feat_col for x in excludes])]

  feat_cols.extend(sel_cols)
  feat_cols = list(set(feat_cols))

  feat_cols.remove(y_col)
  feat_cols.remove('y')
  feat_cols.remove('dew.point')

  all_cols = [*feat_cols, y_col]
  df_nona = df[all_cols].dropna()

  # Add lagged features
  if fs_lags is not None:
    # PerformanceWarning: DataFrame is highly fragmented ...
    # df_nona_lagged = df_nona.assign(**{
    #   f'{col}_lag_{lag}': df_nona[col].shift(lag)
    #   for lag in fs_lags
    #   for col in feat_cols})
    df_nona_lagged_only = pd.DataFrame({
       f'{col}_lag_{lag}': df_nona[col].shift(lag)
       for lag in fs_lags
       for col in feat_cols})
    df_nona_lagged_plus = pd.concat([df_nona, df_nona_lagged_only], axis=1)
    df_nona_lagged_plus = df_nona_lagged_plus.dropna()
    X_df = df_nona_lagged_plus.drop(y_col, axis=1)
    y_df = df_nona_lagged_plus[[y_col]]
  else:
    X_df = df_nona[feat_cols]
    y_df = df_nona[[y_col]]

  y_df = y_df[y_col].shift(-pred_step).dropna()
  X_df = X_df.head(y_df.shape[0])

  return X_df, y_df


def get_feature_selection_scores(df, sel_cols, lags=None, pred_step=0,
                                 y_col=Y_COL, sort_col='f_test', mi=False):
  '''WARNING: These tests assume a linear model.  This may not be optimal.

     Don't draw any hasty conclusions from these scores.
  '''

  X_df, y_df = get_feature_selection_data(df, sel_cols, y_col, lags, pred_step)
  feat_cols = X_df.columns

  f_tests, _ = f_regression(X_df, y_df)
  f_tests /= np.sum(f_tests)

  r_tests = r_regression(X_df, y_df)
  r_tests /= np.sum(r_tests)

  # Correlations with Y_COL
  # corrs = []
  # for feat in feat_cols:
  #   corrs.append(X_df[feat].corr(y_df))

  fs_df = pd.DataFrame({#'correlation': corrs,
                        'r_test': r_tests.round(6),
                        'f_test': f_tests.round(6),
                       })
  fs_df.index = feat_cols

  # Slow :-(
  if mi:
    mi_feats = mutual_info_regression(X_df, y_df)
    mi_feats /= np.sum(mi_feats)
    fs_df['mi'] = mi_feats

  if sort_col == 'f_test':
    fs_df = fs_df.sort_values('f_test', ascending=False)
  elif sort_col == 'r_test':
    fs_df = fs_df.sort_values('r_test', ascending=False)
  elif mi is True and sort_col == 'mi':
    fs_df = fs_df.sort_values('mi', ascending=False)

  return fs_df


def get_above_between_below_features(fs, feats, feat_str, imp='f_test'):
  # cf - core features
  # sf - solar features

  fs['above_' + feat_str] = 0
  fs['below_' + feat_str] = 0

  min_score = np.min(fs.loc[feats, imp])
  max_score = np.max(fs.loc[feats, imp])

  fs.loc[fs[imp] > max_score, 'above_' + feat_str] = 1
  fs.loc[fs[imp] < min_score, 'below_' + feat_str] = 1

  if len(feats) > 1:
    fs['between_' + feat_str] = 0
    fs.loc[(fs[imp] >= min_score) & (fs[imp] <= max_score), 'between_' + feat_str] = 1

  return fs


def get_multi_step_feat_sel_scores(train, sel_cols, lags=None, horizon=48):
  fs_steps_list = []

  for i in range(horizon):
    fs_df = get_feature_selection_scores(train, sel_cols, lags, pred_step=i)
    fs_df['rank'] = [i for i in range(fs_df.shape[0])]
    fs_df['step'] = i

    if lags is None:
      fs_df['feature_transform'] = fs_df.index
    else:
      fs_df['feature_transform_lag'] = fs_df.index

    fs_df.drop('r_test', axis=1, inplace=True)
    fs_df = get_above_between_below_features(fs_df, ['dew.point_des', 'pressure', 'humidity'], 'cf')
    fs_df = get_above_between_below_features(fs_df, ['irradiance', 'za_rad'], 'sf')
    fs_df = get_above_between_below_features(fs_df, ['y_des_shadow'], 'shadow')
    fs_steps_list.append(fs_df)

  fs_steps = pd.concat(fs_steps_list)

  # add window, shift, transform, feature ...

  if lags is not None:
    fs_steps['feature_transform'] = fs_steps['feature_transform_lag']

    with warnings.catch_warnings():
      warnings.simplefilter(action='ignore', category=FutureWarning)
      fs_steps['feature_transform'] = fs_steps['feature_transform'].str.replace('_lag.*', '', regex=True)

    # Lagged columns are named '<col>_lag_<lag>' in get_feature_selection_data
    fs_steps['lag'] = fs_steps['feature_transform_lag'].str.extract(r'_lag_(-?\d+)$').fillna(0).astype(int)
    fs_feats = ['feature_transform_lag', 'feature_transform', 'feature',
                'transform', 'window', 'shift', 'lag', 'step', 'f_test',
                'rank', 'above_cf', 'between_cf', 'below_cf', 'above_sf',
                'between_sf', 'below_sf', 'above_shadow', 'below_shadow']
  else:
    fs_steps['lag'] = 0
    fs_feats = ['feature_transform', 'feature', 'transform', 'window', 'shift',
                'lag', 'step', 'f_test', 'rank', 'above_cf', 'between_cf',
                'below_cf', 'above_sf', 'between_sf', 'below_sf',
                'above_shadow', 'below_shadow']

  fs_steps['shift']  = fs_steps['feature_transform'].str.extract(r'_shift_(\d+)_').fillna(0).astype(int)
  fs_steps['window'] = fs_steps['feature_transform'].str.extract(r'_window_(\d+)_').fillna(0).astype(int)

  fs_steps['transform'] = fs_steps['feature_transform'].str.extract(r'_window_\d+_([a-zA-Z0-9\._]+)$').fillna('None')
  # fs_steps['transform'] = fs_steps['transform'].str.replace('_lag_[0-9]+', '')

  fs_steps.loc[fs_steps['transform'].str.match('intervals_intervals_mean'), 'transform'] = 'intervals_mean'

  fs_steps.loc[fs_steps['feature_transform'].str.endswith('_shadow'), 'transform'] = 'shadow'
  fs_steps.loc[fs_steps['feature_transform'].str.endswith('_grad'), 'transform'] = 'grad'
  fs_steps.loc[fs_steps['feature_transform'].str.endswith('_des'), 'transform'] = 'des'

  fs_steps['feature'] = fs_steps['feature_transform'].str.extract(r'([a-zA-Z0-9\.]+)_').fillna('None')
  fs_steps.loc[fs_steps['transform'] == 'None', 'feature'] = fs_steps.loc[fs_steps['transform'] == 'None', 'feature_transform']

  fs_steps = fs_steps[fs_feats]

  return fs_steps


def groupby_multi_step_feat_sel_scores(mod_imps, group, verbose=False):
  if verbose:
    print('y_des_shadow:')
    display(mod_imps.loc[mod_imps['feature_transform'] == 'y_des_shadow', 'f_test'].describe())

  # mi_gb - model importance group by
  mi_gb_desc = mod_imps[[group, 'f_test']].groupby(group).describe()
  mi_gb_desc = mi_gb_desc.droplevel(axis=1, level=0)
  mi_gb_desc = mi_gb_desc.astype({'count': 'int'})#, 'min': 'int', '25%': 'int',
                                  #'50%': 'int', '75%': 'int', 'max': 'int'})
  if verbose:
    display(mi_gb_desc)


  mi_gb_sum = mod_imps[[group, 'f_test']].groupby(group).sum()
  mi_gb_sum = mi_gb_sum.rename(columns={'f_test': 'sum'})
  if verbose:
    print('\nsum:')
    display(mi_gb_sum)

  mi_gb_desc_sum = pd.merge(mi_gb_desc, mi_gb_sum, left_index=True, right_index=True)


  mi_gb_0imp = mod_imps.loc[mod_imps['f_test'] == 0.0].groupby(group).size().to_frame()
  mi_gb_0imp = mi_gb_0imp.rename(columns={0: 'num_0'})
  if verbose:
    print('\nnum 0 f_test:')
    display(mi_gb_0imp.sort_values('num_0'))


  mi_gb_desc_sum_0imp = pd.merge(mi_gb_desc_sum, mi_gb_0imp, how='outer', left_index=True, right_index=True)
  mi_gb_desc_sum_0imp['num_0'] = mi_gb_desc_sum_0imp['num_0'].fillna(0)
  mi_gb_desc_sum_0imp['pc_0'] = mi_gb_desc_sum_0imp['num_0'] * 100 / mi_gb_desc_sum_0imp['count']
  # display(mi_gb_desc_sum_0imp)


  #mi_gb_shad_geq = mod_imps[[group, 'shadow_geq']].groupby(group).sum()
  #mi_gb_shad_geq = mi_gb_shad_geq.rename(columns={'shadow_geq': 'num_shad_geq'})
  #if verbose:
  #  print('\nnum_shad_geq:')
  #  display(mi_gb_shad_geq)

  #mi_gb_all = pd.merge(mi_gb_desc_sum_0imp, mi_gb_shad_geq, left_index=True, right_index=True)
  #mi_gb_all['pc_shad_geq'] = mi_gb_all['num_shad_geq'] * 100 / mi_gb_all['count']
  # display(mi_gb_all)

  #return mi_gb_all
  return mi_gb_desc_sum_0imp


def plot_multi_step_feat_sel_scores(model_vi, title, imp_col='f_test'):
  if title is None:
    title = ''

  model_vi.boxplot(imp_col)
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Variable importance')
  plt.suptitle('')
  plt.show()

  bxpm = model_vi.boxplot(imp_col, by='step')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - forecast step')
  plt.suptitle('')
  n = 2
  plt.setp(bxpm.get_xticklabels()[::n], visible=False)
  plt.show()

  if model_vi['window'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='window')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Windows')
    plt.suptitle('')
    plt.show()

  if model_vi['shift'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='shift')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Shifts')
    plt.suptitle('')
    plt.show()

  model_vi.loc[model_vi['above_cf'] == 1,].boxplot(imp_col, by='feature_transform')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Features (above core features)')
  plt.suptitle('')
  plt.show()

  # model_vi.boxplot(imp_col, by='feature_transform')
  model_vi.loc[model_vi['between_cf'] == 1,].boxplot(imp_col, by='feature_transform')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Features (between core features)')
  plt.suptitle('')
  plt.show()

  model_vi.boxplot(imp_col, by='feature')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Base Feature')
  plt.suptitle('')
  plt.show()

  model_vi.boxplot(imp_col, by='transform')
  plt.xticks(rotation=45, ha='right')
  plt.ylabel(str(imp_col))
  plt.title(title + ' - Transform')
  plt.suptitle('')
  plt.show()

  if model_vi['lag'].nunique() >= 2:
    model_vi.boxplot(imp_col, by='lag')
    plt.ylabel(str(imp_col))
    plt.title(title + ' - Lag')
    plt.suptitle('')
    plt.show()

  if model_vi['step'].nunique() >= 2 and model_vi['lag'].nunique() >= 2:
    plot_x_y_importance_interaction(model_vi, 'step', 'lag', 'f_test', title)

  if model_vi['window'].nunique() >= 2 and model_vi['lag'].nunique() >= 2:
    plot_x_y_importance_interaction(model_vi, 'window', 'lag', 'f_test', title)

  if model_vi['step'].nunique() >= 2 and model_vi['window'].nunique() >= 2:
    plot_x_y_importance_interaction(model_vi, 'step', 'window', 'f_test', title)


def summarise_multi_step_feat_sel_scores(train, title, sel_cols, fs_lags=None, ft_mean_lim=0.01, verbose=False):
  fs = get_multi_step_feat_sel_scores(train, sel_cols, lags=fs_lags)

  if verbose:
    display(fs)

  if fs_lags is not None:
    print('feature_transform_lag - ft_mean_lim = ' + str(ft_mean_lim) + ':')
    fs_ftl = groupby_multi_step_feat_sel_scores(fs, 'feature_transform_lag')
    display(fs_ftl.loc[fs_ftl['mean'] >= ft_mean_lim,].sort_values('mean', ascending=False))
    print('lag:')
    fs_lag = groupby_multi_step_feat_sel_scores(fs, 'lag')
    display(fs_lag.sort_values('mean', ascending=False))
  else:
    print('feature_transform - ft_mean_lim = ' + str(ft_mean_lim) + ':')
    fs_ft = groupby_multi_step_feat_sel_scores(fs, 'feature_transform')
    display(fs_ft.loc[fs_ft['mean'] >= ft_mean_lim,].sort_values('mean', ascending=False))

  if verbose:
    fs_feature = groupby_multi_step_feat_sel_scores(fs, 'feature')
    print('feature:')
    display(fs_feature)

    fs_transform = groupby_multi_step_feat_sel_scores(fs, 'transform')
    print('transform:')
    display(fs_transform)

    fs_window = groupby_multi_step_feat_sel_scores(fs, 'window')
    print('window:')
    display(fs_window)

  plot_multi_step_feat_sel_scores(fs, title)


def summarise_multi_model_feat_imps(model_vi, title, verbose=False):
  if verbose:
    display(model_vi)

    for gb_col in ['model', 'lag', 'type', 'shift', 'window', 'feature', 'transform']:
      if model_vi[gb_col].nunique() >= 2:
        vi_by_gb = groupby_multi_model_feat_imps(model_vi, gb_col)
        display(vi_by_gb); print()

  plot_multi_model_feat_imps(model_vi, title)
  plot_multi_model_importance_interactions(model_vi, title)

  if verbose:
    plot_feat_imp_cumsum(model_vi, title)

## Data Setup

<a name='import'></a>

### Import Pre-calculated Features


See the [feature_engineering.ipynb](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/feature_engineering.ipynb) notebook for further details.

Load default features:

In [None]:
train_df, valid_df, test_df = load_train_valid_test_features('default', location='github')

print('train_df:')
print_df_summary(train_df)
plot_cols = ['y', 'humidity', 'dew.point', 'wind.speed.mean.x',
             'wind.speed.mean.y']  # 'pressure',
plot_observation_examples(train_df, plot_cols)
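
The `load_features_file` function sets the `ds` column as the index, drops duplicated timestamps and then regularises the series to 30 minute steps.  A minimal sketch of those three steps on a toy data frame is shown below.  The timestamps and temperature values are made up for illustration and are not weather station data.

```python
import pandas as pd

# Toy half-hourly observations with one duplicated timestamp (00:30)
# and one missing timestamp (01:00).  Values are hypothetical.
toy = pd.DataFrame({
    'ds': pd.to_datetime(['2021-01-01 00:00', '2021-01-01 00:30',
                          '2021-01-01 00:30', '2021-01-01 01:30']),
    'y':  [5.0, 5.5, 9.9, 6.0],
})

toy = toy.set_index('ds', drop=False)
toy = toy[~toy.index.duplicated(keep='first')]  # keep the first 00:30 row
toy = toy.asfreq(freq='30min')                  # adds an all-NaN row at 01:00

print(toy['y'])
```

Note that `asfreq` exposes gaps as missing rows rather than filling them, which is why the VAR cell later calls `dropna` on the training data.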


...

TODO Summarise results


---


## Comparison with Baselines

Finally, we can compare the best performing gradient boosted and other models against the best baseline method.  The VAR (vector autoregression) model from the [baselines notebook](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/cammet_baselines_2021.ipynb) was the best performing baseline.

The best encoder-decoder model, after 5 training epochs, was conv2dk2d_28l_48s_16bs_448fm_64f_1ksf_7kst.  Here I train the same model for 20 epochs.

Some points to note regarding the `plot_forecasts` diagnostic plot:
 * forecasts are made on validation data, not test data
 * example forecasts are plotted with observations and lagged temperatures
    * first row shows examples of the best forecasts, with rmse near zero
    * second row shows examples of the worst positive rmse forecasts
    * third row shows examples of the worst negative rmse forecasts
    * lagged observations are negative
    * the day of the year in which the forecast begins and the rmse value are displayed above each sub-plot

### Updated VAR model

...

In [None]:
from statsmodels.tsa.api import VAR
from statsmodels.tools.eval_measures import rmse, medianabs


def plot_baseline_metrics(metrics, main_title):
  fig, axs = plt.subplots(1, 2, figsize = (14, 7))
  fig.suptitle(main_title)
  axs = axs.ravel()  # APL ftw!

  methods = metrics.method.unique()

  for method in methods:
    met_df = metrics.query('metric == "rmse" & method == "%s"' % method)
    axs[0].plot(met_df.horizon, met_df.value, color='blue', label='Updated VAR')

  ivar_rmse = np.array([0.39, 0.52, 0.64, 0.75, 0.86, 0.96, 1.06, 1.15, 1.23,
                        1.31, 1.38, 1.45, 1.51, 1.57, 1.63, 1.68, 1.73, 1.77,
                        1.81, 1.85, 1.89, 1.92, 1.96, 1.99, 2.02, 2.05, 2.08,
                        2.1 , 2.13, 2.15, 2.18, 2.2 , 2.22, 2.24, 2.26, 2.28,
                        2.3 , 2.31, 2.33, 2.35, 2.36, 2.38, 2.39, 2.4 , 2.42,
                        2.43, 2.44, 2.45])
  steps = [i for i in range(1, len(ivar_rmse)+1)]
  axs[0].plot(steps, ivar_rmse, color='black', label='Initial VAR')

  axs[0].set_xlabel("horizon - half hour steps")
  axs[0].set_ylabel("rmse")
  # axs[0].legend(methods)


  for method in methods:
    met_df = metrics.query('metric == "mae" & method == "%s"' % method)
    axs[1].plot(met_df.horizon, met_df.value, color='blue', label='Updated VAR')

  ivar_mae = np.array([0.39, 0.49, 0.57, 0.66, 0.74, 0.83, 0.91, 0.98, 1.05,
                       1.12, 1.18, 1.24, 1.29, 1.34, 1.39, 1.43, 1.47, 1.5 ,
                       1.53, 1.56, 1.59, 1.62, 1.64, 1.66, 1.68, 1.7 , 1.72,
                       1.73, 1.75, 1.76, 1.77, 1.78, 1.8 , 1.81, 1.82, 1.83,
                       1.83, 1.84, 1.85, 1.85, 1.86, 1.86, 1.87, 1.87, 1.88,
                       1.88, 1.89, 1.89])
  axs[1].plot(steps, ivar_mae, color='black', label='Initial VAR')

  axs[1].set_xlabel("horizon - half hour steps")
  axs[1].set_ylabel("mae")
  # axs[1].legend(methods)

  plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
  plt.show()


def update_metrics(metrics, test_data, method, get_metrics,
                   model = None,
                   met_cols = ['type', 'method', 'metric', 'horizon', 'value']):
  metrics_h = []

  if method in ['SES', 'HWES']:
    horizons = [i for i in range(4, 49, 4)]
    horizons.insert(0, 1)
  else:
    # horizons = [1, 48]
    horizons = range(1, 49)

  if method in ['VAR']:
    variates = 'multivariate'
  else:
    variates = 'univariate'

  print("h\trmse\tmae")
  for h in horizons:
    if method in ['VAR']:
      rmse_h, mae_h = get_metrics(test_data, h, method, model)
    else:
      rmse_h, mae_h = get_metrics(test_data, h, method)

    metrics_h.append(dict(zip(met_cols, [variates, method, 'rmse', h, rmse_h])))
    metrics_h.append(dict(zip(met_cols, [variates, method,  'mae', h,  mae_h])))

  print("\n")

  metrics_method = pd.DataFrame(metrics_h, columns = met_cols)
  # DataFrame.append was removed in pandas 2.0
  metrics = pd.concat([metrics, metrics_method])

  return metrics


# rolling_cv with pre-trained model
def var_rolling_cv(data, horizon, method, model):
    lags = model.k_ar  # lag order
    i = lags
    h = horizon
    rmse_roll, mae_roll = [], []
    endo_vars = ['y', 'dew.point', 'humidity', 'pressure']
    exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
                 'irradiance', 'azimuth_cos', 'za_rad'
                ]

    while (i + h) < len(data):
        obs_df  = data[endo_vars].iloc[i:(i + h)]
        endo_df = data[endo_vars].iloc[(i - lags):i].values
        exog_df = data[exog_vars].iloc[i:(i + h)]

        # y_hat = model.forecast(endo_df, steps = h)
        y_hat = model.forecast(endo_df, exog_future = exog_df, steps = h)
        preds = pd.DataFrame(y_hat, columns = endo_vars)

        rmse_i = rmse(obs_df.y,      preds.y)
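        # NOTE: medianabs is the median absolute error, so the 'mae' values
        # reported here are medians; statsmodels' meanabs gives the mean
        mae_i  = medianabs(obs_df.y, preds.y)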
        rmse_roll.append(rmse_i)
        mae_roll.append(mae_i)

        i = i + 1

    print(h, '\t', np.nanmean(rmse_roll).round(3), '\t', np.nanmean(mae_roll).round(3))

    return [np.nanmean(rmse_roll).round(2), np.nanmean(mae_roll).round(2)]

...

In [None]:
# approx. 5 mins

train_df = train_df.asfreq(freq='30min')
valid_df = valid_df.asfreq(freq='30min')
test_df  = test_df.asfreq(freq='30min')

train_df.dropna(inplace=True)

endo_vars = ['y', 'dew.point', 'humidity', 'pressure']
exog_vars = [
            'day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
            'irradiance', 'azimuth_cos', 'za_rad'
            ]
endo_df = train_df[endo_vars]
exog_df = train_df[exog_vars]

var_model = VAR(endog = endo_df, exog = exog_df)
# var_model = VAR(endog = endo_df)
MAX_LAGS = 96
lag_order_res = var_model.select_order(MAX_LAGS)
display(lag_order_res.summary())
display(lag_order_res.selected_orders)
print(lag_order_res.selected_orders['bic'])

lag_order_table = lag_order_res.summary().data
headers = lag_order_table.pop(0)
lag_order_df = pd.DataFrame(lag_order_table, columns=headers)
lag_order_df.drop('', axis=1, inplace=True)

with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    # regex=False treats '*' as a literal character in all pandas versions
    lag_order_df = pd.concat([lag_order_df[col].str.replace('*', '', regex=False).astype(float)
                             for col in lag_order_df], axis=1)

lag_order_df.loc[1:, ['AIC','BIC','HQIC']].plot()
plt.xlabel('lag')
plt.ylabel('IC')
plt.show()

The lowest BIC value occurs at 51 lags.  I'm going to use `maxlags = 51` because that is where diminishing returns set in.

In [None]:


def get_var_backtest(model, data, endo_vars, exog_vars, y_col=Y_COL, horizon=HORIZON):
  lags = model.k_ar  # lag order
  i = lags
  h = horizon
  preds = []

  while (i + h) < len(data):
    if i % 1000 == 0:
      print(i)

    obs_df  = data[endo_vars].iloc[i:(i + h)]
    endo_vals = data[endo_vars].iloc[(i - lags):i].values

    if exog_vars is not None:
      exog_df = data[exog_vars].iloc[i:(i + h)]
      y_hat_lol = model.forecast(endo_vals, exog_future = exog_df, steps = h)
    else:
      y_hat_lol = model.forecast(endo_vals, steps = h)

    y_col_pos = 0  # hardcoding is bad mkay - make function param?
    y_hat_series = pd.Series(data  = [y_hat_l[y_col_pos] for y_hat_l in y_hat_lol],
                             index = obs_df.index,
                             name  = y_col)
    y_hat_ts = TimeSeries.from_series(y_hat_series)
    # y_hat_ts = TimeSeries.from_values(np.array([y_hat_l[y_col_pos] for y_hat_l in y_hat_lol]))
    # y_hat = [y_hat_l[y_col_pos] for y_hat_l in y_hat_lol]

    preds.append(y_hat_ts)
    i = i + 1

  return preds


var_fit = var_model.fit(maxlags = 51, ic = 'bic')
print(var_fit.summary())

main_var_col = 'y'
backtest_var = get_var_backtest(var_fit, valid_df, endo_vars, exog_vars, y_col = main_var_col)
# display(len(backtest_var))
# display(backtest_var[0])
hist_comp_var = get_historic_comparison(backtest_var, valid_df, y_col = main_var_col)
# display(hist_comp_var)
summarise_historic_comparison(hist_comp_var, valid_df, y_col = main_var_col)

title_var = 'VAR ' + main_var_col + '...'
plot_multistep_diagnostics(hist_comp_var, title_var, y_col = main_var_col)


# metric_cols = ['type', 'method', 'metric', 'horizon', 'value']
# metrics = pd.DataFrame([], columns = metric_cols)
# metrics = update_metrics(metrics, valid_df, 'VAR', var_rolling_cv, var_fit)
## metrics = update_metrics(metrics, test_df, 'VAR', var_rolling_cv, var_fit)
# plot_baseline_metrics(metrics, 'Multivariate Baseline Comparison - 2021 valid data')


# 2019 data
# maxlags = 5
# ...
# h	   rmse	   mae
# 1 	 0.39 	 0.39
# 48 	 2.45 	 1.89

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# maxlags = 52
# h	   rmse	   mae
# 1 	 0.37 	 0.37
# 48 	 2.253 	 1.784
# maxlags = 52 substantially better than maxlags = 9

# endo_vars = ['y', 'dew.point', 'humidity',]
# maxlags = 52
# h	   rmse	   mae
# 1 	 0.37 	 0.37
# 48 	 2.293 	 1.814
# including pressure is beneficial

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['za_rad', 'irradiance', 'azimuth_cos',]
# maxlags = 51
# h	   rmse	   mae
# 1    0.369 	 0.369
# 48   2.163 	 1.729
# exog_vars is beneficial
# 1 hr 28 mins :-(

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',]
# maxlags = 52
# h	   rmse	   mae
# 1    0.37 	 0.37
# 48   2.133 	 1.68
# Sinusoidal terms better than irradiance etc!

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1', 'irradiance']
# maxlags = 51
# h	   rmse	   mae
# 1    0.369 	 0.369
# 48   2.105 	 1.667
# irradiance worth adding to sinusoidal terms

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1', 'za_rad']
# maxlags = 51
# h	   rmse	   mae
# 1    0.37 	 0.37
# 48   2.134 	 1.679
# za_rad not as beneficial as irradiance

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1', 'azimuth_cos']
# maxlags = 51
# h	   rmse	   mae
# 1    0.37 	 0.37
# 48   2.131 	 1.675
# azimuth_cos more beneficial than za_rad

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos']
# maxlags = 51
# h	   rmse	   mae
# 1 	 0.368 	 0.368
# 48 	 2.098 	 1.658
# Best model so far

# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos', 'za_rad']
# maxlags = 51
# h	   rmse	   mae
# 1 	 0.368 	 0.368
# 48 	 2.098 	 1.657
# Marginally better with za_rad

# valid_df
# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos', 'za_rad']
# maxlags = 51
# h	   rmse	   mae
# 1 	 0.347 	 0.347
# 48 	 2.012 	 1.581
#

# valid_df
# endo_vars = ['y_des', 'dew.point_des', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos', 'za_rad']
# maxlags = 53
# h	   rmse	   mae
# 1 	 0.347 	 0.347
# 48 	 2.724   2.132

# valid_df
# endo_vars = ['y_des', 'dew.point_des', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos', 'za_rad']
# maxlags = 53
# h	   rmse	   mae
# 1 	 0.357 	 0.357
# 48 	 2.712   2.121

# valid_df
# train_df.loc['2016-01-01':,]
# endo_vars = ['y', 'dew.point', 'pressure', 'humidity',]
# exog_vars = ['day.cos.1', 'day.sin.1', 'year.cos.1', 'year.sin.1',
#              'irradiance', 'azimuth_cos', 'za_rad']
# maxlags = 22
# h	   rmse	   mae
# 1 	 0.352 	 0.352
# 48 	 2.926   2.305
# Backtest RMSE 48th: 2.92592
# Backtest MAE 48th:  2.304481
# Radical decrease in maxlags!
# Not a great model


rmse
```
[0.39, 0.52, 0.64, 0.75, 0.86, 0.96, 1.06, 1.15, 1.23,
 1.31, 1.38, 1.45, 1.51, 1.57, 1.63, 1.68, 1.73, 1.77,
 1.81, 1.85, 1.89, 1.92, 1.96, 1.99, 2.02, 2.05, 2.08,
 2.1 , 2.13, 2.15, 2.18, 2.2 , 2.22, 2.24, 2.26, 2.28,
 2.3 , 2.31, 2.33, 2.35, 2.36, 2.38, 2.39, 2.4 , 2.42,
 2.43, 2.44, 2.45]
```

mae
```
[0.39, 0.49, 0.57, 0.66, 0.74, 0.83, 0.91, 0.98, 1.05,
 1.12, 1.18, 1.24, 1.29, 1.34, 1.39, 1.43, 1.47, 1.5 ,
 1.53, 1.56, 1.59, 1.62, 1.64, 1.66, 1.68, 1.7 , 1.72,
 1.73, 1.75, 1.76, 1.77, 1.78, 1.8 , 1.81, 1.82, 1.83,
 1.83, 1.84, 1.85, 1.85, 1.86, 1.86, 1.87, 1.87, 1.88,
 1.88, 1.89, 1.89]
```

In [None]:
var_fit.plot()
plt.show()

# var_fit.plot_acorr()
# plt.show()

var_fit.fevd(48).plot()
plt.show()

var_fit.mse(48)

The updated VAR model shows substantial improvement.  It would benefit from further validation, including residual plots, QQ plots, residual autocorrelation plots and residual boxplots across the forecast horizon steps.
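As a rough illustration of the residual autocorrelation check mentioned above, here is a minimal standard-library sketch. It assumes the residuals are available as a plain sequence of numbers (if `var_fit` is a statsmodels `VARResults` object, one column of `var_fit.resid` would presumably serve). The `autocorr` helper and the synthetic white-noise residuals are hypothetical and are not part of this notebook.

```python
import math
import random


def autocorr(x, lag):
    """Sample autocorrelation of the sequence x at the given lag."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x)
    cov = sum((x[t] - mean) * (x[t + lag] - mean) for t in range(n - lag))
    return cov / var


random.seed(0)
# Stand-in for one column of model residuals (hypothetical white noise)
resid = [random.gauss(0.0, 1.0) for _ in range(2000)]

# Approximate 95% band for the autocorrelation of white noise
bound = 1.96 / math.sqrt(len(resid))

# Lags 1..48 cover the full 24 hour forecast horizon
flagged = [lag for lag in range(1, 49) if abs(autocorr(resid, lag)) > bound]
print(f"lags outside +/-{bound:.3f}: {flagged}")
```

For well-behaved residuals only a few lags (roughly 5%) should fall outside the band by chance; many flagged lags, or a regular pattern such as every 48th lag, would suggest structure the model has not captured.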

NOTE: Updated VAR validated on 2021 data; initial VAR validated on 2019 data.

TODO Move VAR baseline to separate notebook

---

**TODO** Plot model diagnostics.


Next, I plot the rmse and mae values of the best model and the VAR model for forecast horizons up to 48 steps (24 hours; each horizon step is equivalent to 30 minutes).  This plot and the two others use forecasts on the previously unused 2019 "test" data, not the 2018 "validation" data used elsewhere in this notebook.
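The per-step metrics behind such a plot can be sketched as follows. This is a simplified standard-library illustration, not the notebook's `plot_horizon_metrics` implementation: the `horizon_metrics` helper and the toy data are hypothetical, and mae is taken to be the median absolute error, following the definition given later in this section.

```python
import math
import statistics

STEP_MINUTES = 30  # each horizon step is 30 minutes, as stated above


def horizon_metrics(forecasts, observations):
    """Return (rmse, mae) lists with one entry per step-ahead.

    forecasts and observations are equal-length lists of windows; each
    window is a list holding one value per horizon step.
    """
    n_steps = len(forecasts[0])
    rmse, mae = [], []
    for h in range(n_steps):
        # Errors at step h, collected across all forecast windows
        errs = [f[h] - o[h] for f, o in zip(forecasts, observations)]
        rmse.append(math.sqrt(sum(e * e for e in errs) / len(errs)))
        mae.append(statistics.median(abs(e) for e in errs))
    return rmse, mae


# Hypothetical toy data: 3 forecast windows, 4 steps each
obs = [[10.0, 10.5, 11.0, 11.5], [8.0, 8.5, 9.0, 9.5], [12.0, 12.0, 12.0, 12.0]]
fc = [[10.0, 10.0, 10.0, 10.0], [8.0, 8.0, 8.0, 8.0], [12.0, 12.5, 13.0, 13.5]]

rmse, mae = horizon_metrics(fc, obs)
for h, (r, m) in enumerate(zip(rmse, mae), start=1):
    print(f"h={h:2d} ({h * STEP_MINUTES:3d} min)  rmse={r:.2f}  mae={m:.2f}")
```

With 48 steps of 30 minutes each, the final step corresponds to 1440 minutes, i.e. the 24 hour horizon.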

Some points to note regarding diagnostic plots:
 * once again, on test data not validation data
 * `plot_horizon_metrics`
   * plot rmse and mae values for each individual step-ahead
 * `check_residuals`
   * observations against predictions
   * residuals over time
   * residual distribution
 * `plot_forecasts`
   * see sub-section immediately above for notable points

Broadly speaking, these results are very similar to the results from the VAR model.


Diagnostic plots summary:
 * once again, these plots use test data not validation data
 * `plot_horizon_metrics`
   * initially, these results look quite contradictory
   * the rmse plot indicates better forecasts for the VAR method (in orange)
   * the mae plot indicates better forecasts for the Conv2D kernel 2D method (in blue, mis-labelled as LSTM)
 * `check_residuals`
   * the observations against predictions plot indicates
     * predictions are too high at cold temperatures (below 0 C)
     * predictions are too low at hot temperatures (above 25 C)
   * residuals over time
     * no obvious heteroscedasticity
     * no obvious periodicity
       * surprising given observations against predictions plot
   * residual distribution appears to be approximately normal (slightly right-skewed)
     * no obvious sign of fat tails
 * `plot_forecasts`
   * notable lack of noisy observations for the large positive and negative rmse examples

The median absolute error (mae) is less sensitive to outliers than the root mean squared error (rmse).

Therefore, the difference between the rmse and mae plots may be due to the presence of outliers. I have maintained from the start that this data set is quite noisy, and attempts to correct these problems may have unintentionally introduced new issues.
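A small worked example may make the outlier argument concrete. It is a sketch using made-up residuals (not values from this notebook), with hypothetical helper names, and it uses the median absolute error definition of mae given above.

```python
import math
import statistics


def rmse(errors):
    """Root mean squared error of a sequence of residuals."""
    return math.sqrt(sum(e * e for e in errors) / len(errors))


def median_abs_error(errors):
    """Median absolute error, the mae definition used in the text above."""
    return statistics.median(abs(e) for e in errors)


# Hypothetical residuals in degrees C, not taken from the notebook
clean = [0.5, -0.5, 1.0, -1.0, 0.5, -0.5, 1.0, -1.0]
# Same residuals with a single large outlier replacing the last value
noisy = clean[:-1] + [10.0]

for name, errs in [("clean", clean), ("noisy", noisy)]:
    print(f"{name}: rmse={rmse(errs):.3f}  mae={median_abs_error(errs):.3f}")
```

In this toy case the single outlier raises the rmse from about 0.79 to about 3.61, while the median absolute error stays at 0.75. A model that occasionally produces large errors can therefore rank worse on rmse yet better on mae.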

Transformed mean values across the 48 step horizon:
 * rmse of 2.05796
 * mae of 1.17986

---


## Conclusion

The best results from the gradient boosted trees are similar/different to results from the [best LSTM model](https://github.com/makeyourownmaker/CambridgeTemperatureNotebooks/blob/main/notebooks/lstm_time_series.ipynb).

How and why are they similar/different?

...

The conclusion is separated into the following sections:
 1. What worked
 2. What didn't work
 3. Rejected ideas
 4. Future work
