# Validating Cluster Model Performance

#### Install and import packages


In [1]:
#Imports

# multiprocessing
from joblib import Parallel, delayed

# data manipulation
import numpy as np
import pandas as pd
import heapq

# data visualization
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from tqdm import tqdm

#hmm needed? 
import torch
import torch.nn as nn
#from torch import nn
import torch.optim as optim
import shutil
import sklearn
from sklearn.preprocessing import MinMaxScaler
#from tqdm import tqdm_notebook as tqdm

#darts
# transformers and preprocessing
import darts
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.utils.utils import SeasonalityMode, TrendMode, ModelMode
from darts.utils.statistics import check_seasonality, plot_acf
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.models import * #everything

#loss metrics
from darts.metrics import mape
from darts.metrics import smape
from darts.utils.losses import SmapeLoss

# likelihood
from darts.utils.likelihood_models import GaussianLikelihood

# settings
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)

#hyperparameter tuning
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import (
    plot_optimization_history,
    plot_contour,
    plot_param_importances,
)
from pytorch_lightning.callbacks import Callback, EarlyStopping
from sklearn.preprocessing import MaxAbsScaler

#extras
import os
import time
import random
from datetime import datetime
from itertools import product
from typing import List, Tuple, Dict
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import Ridge
import glob

#### Load data
TODO: consider streamlining this step, e.g. use `glob.glob` to collect all cluster forecast files and load them in one pass.

In [None]:
# Load the per-cluster forecast tables. The files are numbered 0-3 while the
# data frames are named cl1-cl4; the first CSV column is used as the index.
forecast_paths = (
    '/work/Data-Science-Liv/cluster_forecasts/forecasts_cluster0.csv',
    '/work/Data-Science-Liv/cluster_forecasts/forecasts_cluster1.csv',
    '/work/Data-Science-Liv/cluster_forecasts/forecasts_cluster2.csv',
    '/work/Data-Science-Liv/cluster_forecasts/forecasts_cluster3.csv',
)
cl1, cl2, cl3, cl4 = [pd.read_csv(path, index_col=0) for path in forecast_paths]


#### Data inspection
- Removing NA rows from years included in the training set
- Making sure the year span is correct
- Checking the unique countries in each cluster

In [None]:
# Rows without a 'final_prediction' correspond to the first years of the data,
# which were used for training. Only the test-period forecasts are kept here.
cl1, cl2, cl3, cl4 = [
    frame.dropna(subset=['final_prediction'])
    for frame in (cl1, cl2, cl3, cl4)
]

In [None]:
# Are there any NAs left in the data frames?
data_frames = [cl1, cl2, cl3, cl4]

# Collect the names of every frame that still holds an NA. Frames are reported
# by name (cl1..cl4, matching their position in data_frames) instead of
# interpolating the whole DataFrame into the message, and all offending frames
# are listed rather than only the first one.
frames_with_na = [
    f"cl{position}"
    for position, frame in enumerate(data_frames, start=1)
    if frame.isna().any().any()
]

if frames_with_na:
    print(f"NA values found in: {', '.join(frames_with_na)}.")
else:
    print("No NA values found in any of the data frames.")

# Observation from the original run: there are no NAs in the dfs now


In [None]:
# Do any two data frames share a country?

data_frames = [cl1, cl2, cl3, cl4]

# Walk through the frames once, remembering every country seen so far.
# A country lands in common_countries as soon as it shows up in a frame after
# already having appeared in an earlier one, i.e. it occurs in at least two frames.
seen_countries = set()
common_countries = set()
for frame in data_frames:
    frame_countries = set(frame['Country'].unique())
    common_countries |= seen_countries & frame_countries
    seen_countries |= frame_countries

# Report whether any overlap was found
if not common_countries:
    print("No common countries found among the data frames.")
else:
    print("At least two data frames have common countries.")


# Observation from the original run: the clusters do not contain the same countries

In [85]:
# Collect the distinct values of the 'ds' (year) column for each cluster
year_unique1, year_unique2, year_unique3, year_unique4 = [
    set(frame['ds'].unique()) for frame in (cl1, cl2, cl3, cl4)
]

# All dfs have the right testing range.
# NOTE(review): the original note gave the span as "2004-2001", which looks like
# a typo for the end year - confirm the intended range.

# To inspect an entire data frame:
# with pd.option_context('display.max_rows', None):
#     print(cl4)



### This is where the validation goes down
- sMAPE

Symmetric mean absolute percentage error (SMAPE or sMAPE) is an accuracy measure based on percentage (or relative) errors.
It is a metric used to evaluate the accuracy of a forecasting model or compare the performance of different forecasting models.

In [None]:
# Eva's code
# Code for calculating sMAPE
# ID_list is the list of countries, extracted from the pandas dataframe before the
# data is converted to TimeSeries, e.g. ID_list = list(data_countries["Country"].unique())
def eval_local_model(
    ID_list: List[str],
    train_series: List[TimeSeries],
    test_series: List[TimeSeries],
    model_cls,
    **kwargs
) -> Tuple[List[float], float, List[Tuple[str, float]], List[Tuple[str, float]]]:
    """Fit one local model per training series and score the forecasts.

    A fresh ``model_cls(**kwargs)`` is fitted to each series in
    ``train_series`` and asked for a forecast of ``HORIZON`` steps (a
    module-level constant defined elsewhere in the notebook). The forecasts
    are passed to ``eval_forecasts`` (also defined elsewhere; presumably it
    returns one sMAPE per series and plots them - confirm there).

    Args:
        ID_list: Identifiers (countries) of the series, in the same order as
            ``train_series``; paired positionally with the sMAPEs.
        train_series: Series to fit on, one model per series.
        test_series: Series handed to ``eval_forecasts`` together with the
            forecasts.
        model_cls: Callable returning an object with ``fit(series)`` and
            ``predict(n=...)``.
        **kwargs: Forwarded unchanged to ``model_cls``.

    Returns:
        A 4-tuple ``(smapes, elapsed_time, smapes_pr_country, highest_smapes)``:
        the value returned by ``eval_forecasts``, the wall-clock seconds spent
        fitting and predicting, the ``(identifier, sMAPE)`` pairs, and the ten
        pairs with the largest sMAPE in descending order.
    """
    preds = []
    model = None  # stays None when train_series is empty
    start_time = time.time()

    # Fit model and predict for each series individually
    print("fitting models...")
    for series in tqdm(train_series):
        model = model_cls(**kwargs)
        model.fit(series)
        preds.append(model.predict(n=HORIZON))
    elapsed_time = time.time() - start_time

    # Name of the fitted model class; fall back to the factory's own name when
    # no model was ever built, so an empty input does not raise on `model`.
    if model is not None:
        name = model.__class__.__name__
    else:
        name = getattr(model_cls, "__name__", str(model_cls))

    # Apply eval_forecasts to the forecasts to obtain the sMAPEs
    print("extracting sMAPEs...")
    smapes = eval_forecasts(name, preds, test_series)

    # Pair each sMAPE with its identifier. The ID_list argument is used here
    # rather than a global `countries` list, so the function no longer depends
    # on hidden notebook state.
    smapes_pr_country = list(zip(ID_list, smapes))
    highest_smapes = heapq.nlargest(10, smapes_pr_country, key=lambda x: x[1])

    return smapes, elapsed_time, smapes_pr_country, highest_smapes