## Model Training

- [TODO] Transformer Timeseries Regression

In [3]:
# Import necessities
import os
import datetime
import pickle as pkl
import pandas as pd
import numpy as np

import seaborn as sns
from matplotlib import pyplot as plt

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Global seaborn look for every figure in this notebook:
# white background with grid lines, and the 'Set2' color cycle.
sns.set_style(style='whitegrid')
sns.set_palette(palette='Set2')

In [2]:
# Read both CSV files; in each, the first column is used as the row index
dataset_csv, dataset_rolling_csv = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('./data/dataset.csv', './data/dataset_rolling.csv')
)

In [3]:
class Dataset():
    ''' Normalize a dataframe, split it chronologically into train/test parts
    and expose both parts as time-series generators.
    '''
    def __init__(self, dataset, feature_cols, label_cols, timestamp=14,
                        batch_size=1, test_size=0.03):
        ''' Initialize the dataset
        :param dataset: Dataframe of the dataset (it is copied, the caller's
            dataframe is left untouched)
        :param feature_cols: list of feature names,
            e.g. ['dates', 'vaccinations'] means using dates and vaccinations as features
        :param label_cols: list of prediction targets,
            e.g. ['confirmed'] means using confirmed data as labels
        :param timestamp: timestamp in LSTM model, passed to the generators as
            the window length
        :param batch_size: batch size passed to the generators
        :param test_size: the ratio of test data in the dataset
        '''
        # Work on a private copy so that the caller's dataframe is not rescaled
        # in place (building a second Dataset from the same dataframe would
        # otherwise fit the scaler on already-scaled values).
        self.dataset = dataset.copy()
        self.feature_cols = feature_cols
        self.label_cols = label_cols
        self.timestamp = timestamp
        self.batch_size = batch_size
        self.test_size = test_size
        # Normalize every column except 'dates' to [0, 1] using MinMaxScaler.
        # This has to happen BEFORE features/labels are selected: selecting
        # columns with a list returns a copy, which would not see a later
        # rescaling of self.dataset.
        # NOTE(review): the scaler is fitted on all rows, including the rows
        # that end up in the test split - confirm this is intended.
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        value_cols = [col for col in self.dataset.columns if col != 'dates']
        # Replace the columns as a whole (instead of writing through .loc) so
        # that integer columns can become float without dtype conflicts.
        self.dataset[value_cols] = self.scaler.fit_transform(self.dataset[value_cols])
        # Split features and labels
        self.features = self.dataset[feature_cols]
        self.labels = self.dataset[label_cols]
        # Split the dataset; shuffle=False keeps the row order, so the test set
        # is the last `test_size` fraction of the rows
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
                self.features, self.labels, test_size=self.test_size, shuffle=False)

    def get_training_set(self):
        ''' Return the time-series generator used for training

        NOTE(review): TimeseriesGenerator is not imported in the visible part
        of this file; it must be in scope before this method is called.
        '''
        return TimeseriesGenerator(self.x_train.to_numpy(), self.y_train.to_numpy(),
                length=self.timestamp, batch_size=self.batch_size)

    def get_test_set(self):
        ''' Return the time-series generator used for testing

        NOTE(review): TimeseriesGenerator is not imported in the visible part
        of this file; it must be in scope before this method is called.
        '''
        return TimeseriesGenerator(self.x_test.to_numpy(), self.y_test.to_numpy(),
                length=self.timestamp, batch_size=self.batch_size)

In [4]:
class BaseModel():
    ''' Other models are derived by inheritance from BaseModel.

    BaseModel itself never creates `self.model`; it has to be set by a
    subclass or by calling read_model() before train/save_model/predict/plot.
    '''
    def __init__(self, dataset, model_path, output_dim, epochs,
                        verbose, loss, optimizer, dropout, **args):
        ''' Initialize the model parameters
        param dataset: dataset object; must provide get_training_set(),
            label_cols, y_train and timestamp (used by train() and plot())
        param model_path: path the model is loaded from / saved to
        param output_dim: output dimension
        param epochs: training epochs
        param verbose: verbose level of logging
        param loss: loss function for training
        param optimizer: optimizer for training
        param dropout: dropout setting, stored for use by derived models
        param args: extra keyword arguments; accepted but ignored here
        '''
        self.dataset = dataset
        self.model_path = model_path
        self.output_dim = output_dim
        self.epochs = epochs
        self.verbose = verbose
        self.loss = loss
        self.optimizer = optimizer
        self.dropout = dropout

    def read_model(self):
        ''' Load model from the given path

        NOTE(review): load_model is not imported in the visible part of this
        file; it must be in scope before this method is called.
        '''
        self.model = load_model(self.model_path)

    def train(self):
        ''' Fit self.model on the dataset's training generator
        '''
        training_set = self.dataset.get_training_set()
        self.model.fit(training_set, epochs=self.epochs, verbose=self.verbose)

    def save_model(self):
        ''' Save model to path
        '''
        self.model.save(self.model_path)

    def predict(self, input):
        ''' Return self.model's prediction for the given input
        '''
        return self.model.predict(input)

    def plot(self):
        ''' Plot the training labels together with the model's predictions
        on the training set.
        '''
        labels = self.dataset.y_train
        predictions = self.model.predict(self.dataset.get_training_set())
        # Assuming Keras TimeseriesGenerator semantics: the first sample uses
        # rows [0, timestamp) as input and row `timestamp` as target, so the
        # i-th prediction belongs to label row i + timestamp. Plot it at that
        # position instead of starting at x=0, otherwise the prediction curve
        # is shifted to the left of the labels it should be compared with.
        offset = self.dataset.timestamp
        prediction_index = labels.index[offset:offset + len(predictions)]
        plt.figure(figsize=(10, 8))
        plt.title('Labels and Predictions on ' + ', '.join(self.dataset.label_cols))
        plt.plot(labels, label='Labels')
        plt.plot(prediction_index, predictions, label='Predictions')
        plt.legend()
        