In [None]:
import sys
sys.path.append("../worldcup")
from hybrid import Hybrid
import pandas as pd

In [None]:
# Resampling bucket width in minutes (60 -> hourly buckets).
freq = 60
# The first CSV column is used as the index and parsed as dates, which
# `resample` below requires (it needs a datetime-like index).
df = pd.read_csv("../data/clean_minutes.csv", index_col=0, parse_dates=True)
# "min" is the supported pandas alias for minutes; the older "T" spelling is
# deprecated in recent pandas releases and emits a FutureWarning.
downsampled_df = df.resample(f"{freq}min").mean()
# Plain NumPy array of the per-bucket mean of the "view" column.
data = downsampled_df["view"].values

In [None]:
# Number of points in the resampled series (bare expression so the notebook displays it).
len(data)

In [None]:
# Build the hybrid model used by the evaluation loop.
# NOTE(review): the trailing 24 in the seasonal order is presumably the seasonal
# period in resampled steps (24 hourly buckets = one day) — confirm against Hybrid.
hybrid = Hybrid(
    sarima_order=(4, 1, 2),
    sarima_seasonal_order=(2, 1, 1, 24),
)

In [None]:
from sklearn.metrics import mean_squared_error

# Number of most recent observations the model is fitted on at each step.
window_size = 300
# NOTE(review): `refit` is not read anywhere in this cell — the model is fitted
# on every step regardless of its value. Kept so other cells that read it still work.
refit = False

# Result lists; each gets exactly one entry per evaluated time step, so they
# stay aligned by position.
actuals = []
hybrid_predictions = []
sarima_predictions = []
naive_predictions = []

lstm_residual_predictions = []
residual_actuals = []

# Walk-forward, one-step-ahead evaluation: for every index t past the first
# window, fit on the preceding `window_size` points and forecast data[t].
for t in range(window_size, len(data)):
  print(f"{t+1} / {len(data)}")

  train_data = data[t - window_size : t]
  actual = data[t]

  # Naive baseline: repeat the previous observation.
  naive_predictions.append(data[t-1])

  # Fit hybrid
  hybrid.fit(train_data)

  forecast = hybrid.forecast(horizon=1)

  hybrid_predictions.append(forecast["hybrid_forecast"])
  sarima_predictions.append(forecast["sarima_forecast"])
  lstm_residual_predictions.append(forecast["lstm_residual_forecast"])
  # Residual target = what SARIMA missed. Uses the same "sarima_forecast" key
  # read above (the previous "sarima" key did not match the other forecast keys).
  residual_actuals.append(actual - forecast["sarima_forecast"])
  actuals.append(actual)

In [None]:

from matplotlib import pyplot as plt
# mean_squared_error is imported here too so this cell does not depend on an
# import made in an earlier cell.
from sklearn.metrics import mean_absolute_error, mean_squared_error


# scikit-learn metrics are documented as metric(y_true, y_pred): ground truth
# first. MSE/MAE are symmetric so the values are unchanged, but the correct
# order keeps this safe if an asymmetric metric is substituted later.
hybrid_mse = mean_squared_error(actuals, hybrid_predictions)
naive_mse = mean_squared_error(actuals, naive_predictions)
sarima_mse = mean_squared_error(actuals, sarima_predictions)

hybrid_mae = mean_absolute_error(actuals, hybrid_predictions)
naive_mae = mean_absolute_error(actuals, naive_predictions)
sarima_mae = mean_absolute_error(actuals, sarima_predictions)

print("hybrid_mae", hybrid_mae)
print("naive_mae", naive_mae)
print("sarima_mae", sarima_mae)


# Explicit figure/axes instead of the implicit pyplot state machine.
fig, ax = plt.subplots()
ax.plot(actuals, label="Actual")
ax.plot(hybrid_predictions, label="Hybrid")
ax.plot(naive_predictions, label="Naive")
# ax.plot(sarima_predictions, label="SARIMA")
ax.set_xlabel("Time  (Hour)")
ax.set_ylabel("Requests")
ax.legend()

# MSE summary box pinned to the top-left corner in axes-fraction coordinates.
mse_text = (
    f"Hybrid MSE: {hybrid_mse:.2f}\n"
    f"Naive MSE: {naive_mse:.2f}\n"
    f"SARIMA MSE: {sarima_mse:.2f}"
)
ax.text(0.01, 0.95, mse_text, transform=ax.transAxes,
        fontsize=10, verticalalignment='top', bbox=dict(facecolor='white', alpha=0.7))

fig.tight_layout()
# NOTE(review): p, d, q, P, D, Q are not defined in the cells visible here;
# define them before re-enabling the save below.
# filename= f"figures/ARIMA({p},{d},{q})({P},{D},{Q})-norefit-onestep-scaled.png"
# fig.savefig(filename, dpi=300)
# print(f"Saved figure to {filename}")
plt.show()