In [None]:
#%%
## 1. LOAD PACKAGES
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
from sklearn.metrics import r2_score, mean_squared_error

from np_functions import load_data, format_input, setup_model



In [None]:
#%%
## 2. SET UP TRAINING AND TEST DATASETS
# Name of the flow dataset passed to load_data; swap in another location to
# work with a different dataset.
LOCATION = 'Adelphi-Weir'

df_full = load_data(LOCATION)

# Split the full record into three consecutive periods by label-slicing the
# index (assumes load_data returns a frame with a datetime-like index — confirm).
# Each slice is copied before it is handed to format_input.
df_train = format_input(df_full.loc['2001':'2014'].copy())  # training data
df_val = format_input(df_full.loc['2015':'2019'].copy())    # validation data
df_test = format_input(df_full.loc['2020':'2024'].copy())   # test data


In [None]:
#%%
## 3. a) LOAD PARAMETERS FROM OPTIMISATION

# Forecast horizon; also selects which optimisation result file is read.
MAX_HORIZON = 6  # set to 1, 3 or 6

result_path = f'np_hyperopt_result_{MAX_HORIZON}h.json'
with open(result_path, 'r') as file:
    data = json.load(file)

# Retrieve the stored objects
best, trials = data['best'], data['trials']

# params refers to the same dict as best, so the casts below update both
params = best

# Every parameter except the two regularisation strengths is cast to int
FLOAT_PARAMS = ('ar_reg', 'reg_reg')
for key in params:
    if key in FLOAT_PARAMS:
        continue
    params[key] = int(params[key])



In [None]:
#%%

## 3. b) CUSTOM PARAMETERS (overrides loaded parameters)
# Rebinding `params` here replaces whatever was assigned to it before this cell
# runs; skip this cell to keep previously assigned parameters.
params = dict(
    yearly_seasonality=False,
    # autoregression settings
    ar_n_lags=6,
    ar_reg=0.883,
    ar_layer_size=94,
    ar_layer_num=4,
    # lagged-regressor settings
    reg_n_lags=377,
    reg_reg=1.3975,
    reg_layer_size=53,
    reg_layer_num=3,
)


In [None]:
#%%
## 4. SET UP AND TRAIN MODEL

# Build the model for the selected horizon and parameter set.
model = setup_model(
    n_forecasts=MAX_HORIZON,
    params=params,
)

# Fit on the training split; whatever fit() returns is stored in metrics_train.
metrics_train = model.fit(df_train)
## DO NOT USE model.test() - broken for n_forecasts > 1

In [None]:
#%%
## 5. GENERATE FORECAST

# Predict on the validation split first, then on the test split.
# `forecast` feeds the metrics in section 6; both frames are plotted in section 7.
forecast, forecast_test = model.predict(df_val), model.predict(df_test)



In [None]:
#%%
## 6. CALCULATE METRICS
# Scores the validation forecast at one horizon: mean squared error,
# Nash-Sutcliffe efficiency, and the share of observations that fall inside
# the 5%-95% prediction band.

metrics_horizon = MAX_HORIZON
yhat_col = f'yhat{metrics_horizon}'

# keep only rows where both the observation and the prediction are present
# (the sklearn metrics below raise on NaN input)
forecast_metrics = forecast.dropna(subset=['y', yhat_col])

# retrieve true and predicted values for flow
y = forecast_metrics['y']
y_pred = forecast_metrics[yhat_col]

# sklearn metrics expect (y_true, y_pred) in that order
# mean squared error
mse = mean_squared_error(y, y_pred)
# Nash-Sutcliffe efficiency (= r^2). Argument order matters here: the score is
# normalised by the variance of the first argument, which must be the
# observations for the result to be the NSE.
nse = r2_score(y, y_pred)

print('Mean squared error: ', mse, '\nNash-Sutcliffe Efficiency: ', nse)

# calculate percentage of cases where observed values are within confidence interval of prediction
lower = forecast_metrics[f'{yhat_col} 5.0%']
upper = forecast_metrics[f'{yhat_col} 95.0%']
condition = (y >= lower) & (y <= upper)

percentage = (condition.sum() / len(forecast_metrics)) * 100
print(f'Observed data within 90% confidence interval in {round(percentage, 1)}% of cases')


In [None]:
#%%
## 7. PREDICTION PLOT
# Plots observed flow against the prediction at one horizon, with the 5%-95%
# prediction band, for a chosen time window.

# combine validation and test forecasts into one frame indexed by timestamp
forecast_plot = pd.concat([forecast, forecast_test]).set_index('ds')

# choose forecast window
#window = ('2015-01-01', '2024-05-14')
# event 1: 2015-12-26
window = ('2015-12-25','2015-12-27')
# event 2: 2020-02-09
#window = ('2020-02-08-18','2020-02-10')
# event 3: 2021-01-20
#window = ('2021-01-18','2021-01-22')

forecast_plot = forecast_plot.loc[window[0]:window[1]]

# choose horizon to plot: any value from 1 to MAX_HORIZON.
# It follows MAX_HORIZON by default so this cell keeps working when the model
# is trained with a horizon of 1 or 3 (presumably no 'yhat6' column exists then).
plot_horizon = MAX_HORIZON
cols_to_plot = ['y', f'yhat{plot_horizon}']
# cols_to_plot = ['y', 'yhat1', 'yhat3', 'yhat6']

# create plot
fig, ax = plt.subplots(figsize=(10, 5))
forecast_plot.plot(y=cols_to_plot, ax=ax)

# line labels (lines are in the order of cols_to_plot)
ax.lines[0].set_label('Observed')
ax.lines[1].set_label(f'Predicted ({plot_horizon}h)')

# plot confidence interval
ax.fill_between(forecast_plot.index,
                forecast_plot[f'yhat{plot_horizon} 5.0%'],
                forecast_plot[f'yhat{plot_horizon} 95.0%'],
                color='orange', alpha=0.3,
                label='90% conf. interval')

# plot legend (rebuilt here so it picks up the labels set above)
ax.legend(framealpha=1)

# add vertical lines at horizon interval
# NOTE(review): the spacing is plot_horizon in x-axis units, which is hours only
# if the axis is in hourly units — confirm for this data.
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
x_lines = np.arange(xmin, xmax, plot_horizon)
ax.vlines(x_lines, ymin=ymin, ymax=ymax, color='black', linestyle='dotted', linewidth=0.8)
# reset plot limits to the values captured before the lines were drawn
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# format spines and set axis labels
ax.spines[:].set_visible(False)
ax.set_xlabel('Time')
ax.set_ylabel('Flow (m$^3$/s)')



In [None]:
#%%
## 8. PLOT MODEL PARAMETERS
# plots trend, seasonality, autoregression, and lagged regressor contributions to predictions

# focus on the same horizon that was used for the prediction plot
forecast_in_focus = plot_horizon
model.plot_parameters(
    forecast_in_focus=forecast_in_focus,
    plotting_backend='matplotlib',
)
