In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as mlt
import seaborn as sp
from torch.autograd import Variable
from torch import autograd
from datetime import datetime
import matplotlib.pyplot as plt
import argparse
from datetime import timedelta
import torch.autograd.functional as F

In [19]:
class ModelWrapper():
    """Train and evaluate a forecasting model batch by batch.

    ``model_parameter`` is read for ``future_step``, ``input_size`` and
    ``batch_size``. ``col`` is the index on the last axis of the data that the
    extra squared-error penalty and the evaluation metrics are computed on.
    Batches are indexed as ``[:, :, col]``, so inputs/labels are assumed to be
    3-D, presumably (batch, future_step, input_size) -- confirm with callers.
    """

    def __init__(self, model_parameter, col):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.column = col
        self.model_parameter = model_parameter

    def _empty_dataset(self):
        """Return an empty float32 tensor of shape (0, future_step, input_size)."""
        return torch.empty(
            (0, self.model_parameter.future_step, self.model_parameter.input_size),
            dtype=torch.float32,
        ).to(self.device)

    def _stack(self, chunks):
        """Concatenate collected batches along dim 0 (empty dataset if none)."""
        if not chunks:
            return self._empty_dataset()
        # One concatenation at the end instead of re-copying the accumulator
        # every step; the result lives on whatever device the batches are on.
        return torch.cat(chunks, dim=0)

    def _column_sq_error(self, prediction, target):
        """Sum of squared errors restricted to the target column ``self.column``."""
        diff = target[:, :, self.column] - prediction[:, :, self.column]
        return torch.sum(diff ** 2)

    def train_model(self, model, optimizer_model, loss_function, real_train, real_test,
                    step_per_epoch, epochs=300):
        """Train ``model`` for ``epochs`` epochs of ``step_per_epoch`` batches.

        The per-batch loss is ``loss_function(pred, label)`` plus the summed
        squared error on column ``self.column``.

        Args:
            model: module mapping an input batch to a prediction batch.
            optimizer_model: optimizer stepping the model's parameters.
            loss_function: callable ``(prediction, label) -> scalar tensor``.
            real_train: sliceable inputs; batch ``k`` is rows
                ``[k*batch_size, (k+1)*batch_size)``.
            real_test: sliceable labels aligned row-for-row with ``real_train``.
            step_per_epoch: number of batches per epoch.
            epochs: number of epochs (default 300, the previous fixed value).

        Returns:
            ``(real_dataset, gen_dataset, loss_evol)``: the labels and detached
            predictions from the final epoch concatenated along dim 0, and the
            mean per-batch loss of each epoch (NaN if ``step_per_epoch`` is 0).
        """
        # test_model() leaves the model in eval mode; make sure dropout /
        # batch-norm behave as in training again.
        model.train()
        real_chunks = []
        gen_chunks = []
        loss_evol = []

        for i in range(epochs):
            epoch_loss = 0.0
            is_last_epoch = i == epochs - 1

            for step in range(step_per_epoch):
                real_data = self.gen_batch_forecasting(self.model_parameter.batch_size, step, real_train)
                real_label = self.gen_batch_forecasting(self.model_parameter.batch_size, step, real_test)

                model.zero_grad()
                generated_samples = model(real_data)

                g_loss = loss_function(generated_samples, real_label) \
                    + self._column_sq_error(generated_samples, real_label)
                g_loss.backward()
                optimizer_model.step()
                epoch_loss += g_loss.item()

                if is_last_epoch:
                    # Detach so the autograd graph of each step is not kept alive.
                    real_chunks.append(real_label.detach())
                    gen_chunks.append(generated_samples.detach())

            epoch_loss = epoch_loss / step_per_epoch if step_per_epoch else float("nan")
            loss_evol.append(epoch_loss)

            print(f"epoch: {i}, Train loss: {epoch_loss:.7f}")

        return self._stack(real_chunks), self._stack(gen_chunks), loss_evol

    def test_model(self, model, real_data_test, data_label_test, loss_function, step_per_epoch):
        """Evaluate ``model`` on ``step_per_epoch`` batches without gradients.

        Metrics (sMAPE in percent, MAE, MSE, RMSE, R^2) are computed per batch
        on column ``self.column`` and averaged over batches. sMAPE terms where
        both prediction and label are exactly 0 count as 0 error instead of
        producing NaN.

        Returns:
            ``(real_dataset, gen_dataset, test_loss, smape, mae, mse, rmse, r2)``
            where ``test_loss`` is the list of per-batch losses (same formula
            as training) and the metrics are batch-averaged floats.
        """
        losses_smape = []
        losses_mae = []
        losses_mse = []
        losses_rmse = []
        losses_r2 = []
        test_loss = []
        real_chunks = []
        gen_chunks = []

        model.eval()
        with torch.no_grad():
            for step in range(step_per_epoch):
                real_data = self.gen_batch_forecasting(self.model_parameter.batch_size, step, real_data_test)
                real_label = self.gen_batch_forecasting(self.model_parameter.batch_size, step, data_label_test)

                generated_samples = model(real_data)

                pred = generated_samples[:, :, self.column]
                target = real_label[:, :, self.column]
                err = pred - target

                # sMAPE: when |pred| + |target| == 0 the error is also 0, so
                # substituting 1 in the denominator yields a 0 term, not NaN.
                denom = torch.abs(pred) + torch.abs(target)
                safe_denom = torch.where(denom == 0, torch.ones_like(denom), denom)
                loss_smape = torch.mean(2 * torch.abs(err) / safe_denom) * 100
                loss_mae = torch.mean(torch.abs(err))
                loss_mse = torch.mean(err ** 2)
                loss_rmse = torch.sqrt(loss_mse)
                # R^2 per batch: 1 - SS_res / SS_tot
                loss_r2 = 1 - torch.sum(err ** 2) / torch.sum((target - torch.mean(target)) ** 2)

                losses_smape.append(loss_smape.item())
                losses_mae.append(loss_mae.item())
                losses_mse.append(loss_mse.item())
                losses_rmse.append(loss_rmse.item())
                losses_r2.append(loss_r2.item())

                loss = loss_function(generated_samples, real_label) \
                    + self._column_sq_error(generated_samples, real_label)
                test_loss.append(loss.item())

                real_chunks.append(real_label)
                gen_chunks.append(generated_samples)

        smape_loss = np.array(losses_smape).mean()
        mae_loss = np.array(losses_mae).mean()
        mse_loss = np.array(losses_mse).mean()
        rmse_loss = np.array(losses_rmse).mean()
        r2_loss = np.array(losses_r2).mean()

        print("RMSE: ", rmse_loss)
        print("MAE: ", mae_loss)

        return (self._stack(real_chunks), self._stack(gen_chunks), test_loss,
                smape_loss, mae_loss, mse_loss, rmse_loss, r2_loss)

    def gen_batch_forecasting(self, batch_size, step, dset):
        """Return rows ``[step*batch_size, (step+1)*batch_size)`` of ``dset``.

        The last batch may be shorter (or empty) if ``dset`` runs out of rows.
        """
        start = step * batch_size
        return dset[start: start + batch_size]