In [None]:
import torch

import matplotlib.pyplot as plt
import torchmetrics as tm

from config import *
from stages import *
from train import *
from data.ssa import SSA

from data.util import crop_q_between, split_weekdays_and_weekends

%reload_ext autoreload
%autoreload 2

In [None]:
# Regression metrics used to score the forecasts; each one is moved onto the
# configured device so it can consume tensors produced there.
mse, mape = (
    metric.to(CONFIG.device)
    for metric in (tm.MeanSquaredError(), tm.MeanAbsolutePercentageError())
)
# Echo the device in use as the cell's output.
CONFIG.device

##### Compare MSE and MAPE losses across parameter configurations

For each spectral threshold $P$ and DBN hidden-layer width $N$, pre-train the DBN, then train it with a KELM head attached at each epoch.

In [None]:
# Sweep the SSA spectral threshold P and the DBN hidden-layer width N. For every
# (P, N) pair a fresh DBN+KELM is trained on the weekend (WE) split and scored on
# the WE test set.
# NOTE(review): only the weekend split is trained/evaluated; `results_wd` stays empty.
SPECTRAL_THRESHOLDS = range(3125, 28125, 3125)  # P values: 3125, 6250, ..., 25000
HIDDEN_LAYER_WIDTHS = range(4, 360, 40)         # N values: 4, 44, ..., 324
DBN_TRAINING_EPOCHS = 100
SEED = 0  # re-applied before each DBN is built so every (P, N) starts from a reproducible init

results_wd = {}
results_we = {}       # results_we[P][N] -> WE test MSE (float)
results_we_mape = {}  # results_we_mape[P][N] -> WE test MAPE (float)

ssa = SSA(5, 2, CONFIG.device)
mat_q = CONFIG.load('mat_q.pt')
CONFIG.alpha = 0.03


def _evaluate(dbn, kelm, trend_loader, resid_loader):
    """Return (MSE, MAPE) of trend + residual forecasts over paired test loaders.

    The trend part is forecast one step ahead with `ssa`. The residual part goes
    through `dbn` followed by the `kelm` head. The prediction is their sum, and the
    target is `y_trend + y_resid`. Uses the `mse` / `mape` metrics from the metric
    cell and resets them first, so the result covers only this call's batches
    (a dataset-level mean, not a sum of per-batch losses).
    """
    mse.reset()
    mape.reset()
    with torch.no_grad():  # inference only; no autograd graph needed
        # Assumes both loaders yield the same number of batches in matching order.
        for (X_trend, y_trend), (X_resid, y_resid) in zip(trend_loader, resid_loader):
            pred_resid = kelm(dbn(X_resid).squeeze()).T
            pred_trend = ssa.forecast(X_trend.squeeze().T, 1).sum(0)[-1][None]
            pred = pred_trend + pred_resid
            y = y_trend + y_resid
            mse.update(pred, y)
            mape.update(pred, y)
    return mse.compute().item(), mape.compute().item()


def _plot_sweep(losses_by_n, title):
    """Plot loss against hidden-layer width N for a single value of P."""
    fig, ax = plt.subplots()
    ax.plot(list(losses_by_n.keys()), list(losses_by_n.values()), marker='o')
    ax.set(title=title, xlabel='N', ylabel='Loss')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the sweep


for P in SPECTRAL_THRESHOLDS:
    CONFIG.spectral_threshold = P

    # Data preparation depends only on P (and CONFIG.alpha), so it runs once per P
    # instead of once per (P, N).
    # NOTE(review): assumes these helpers do not read CONFIG.dbn_hidden_layer_sizes
    # and that train_with_config does not consume/mutate the split datasets.
    mat_q_trend, mat_q_resid = preprocess_data(P, mat_q)
    mat_c, mat_x, representatives = compress_data(
        mat_q_resid.abs(), CONFIG.read_period, CONFIG.train_period, CONFIG.alpha)
    mat_q_trend = mat_q_trend[:, representatives].real
    mat_q_resid = mat_q_resid[:, representatives].real

    (_, _, test_trend_wd_dataloader), (_, _, test_trend_we_dataloader) = crop_and_split_mat(mat_q_trend, CONFIG)
    (_, _, test_resid_wd_dataloader), (_, _, test_resid_we_dataloader) = crop_and_split_mat(mat_q_resid, CONFIG)
    mat_c_wd_datasets, mat_c_we_datasets = split_mat(mat_c, CONFIG)

    del mat_c

    for N in HIDDEN_LAYER_WIDTHS:
        CONFIG.dbn_hidden_layer_sizes = [N, N, N]

        # Build the DBN only *after* the hidden sizes are set, so N actually changes
        # the architecture. A fresh model per N also avoids warm-starting from the
        # previous run: state_dict() tensors share storage with the live parameters,
        # so restoring a saved state_dict would not have reset the weights.
        torch.manual_seed(SEED)
        dbn = DBN(CONFIG.time_window_length, CONFIG.dbn_hidden_layer_sizes,
                  CONFIG.gibbs_sampling_steps).to(CONFIG.device)

        dbn, kelm = train_with_config(CONFIG, dbn, mat_c_we_datasets,
                                      dbn_training_epochs=DBN_TRAINING_EPOCHS, stride=1)

        mse_we, mape_we = _evaluate(dbn, kelm, test_trend_we_dataloader, test_resid_we_dataloader)
        print(f'WE P={P}, N={N}, MSE={mse_we}, MAPE={mape_we}')
        results_we.setdefault(P, {})[N] = mse_we
        results_we_mape.setdefault(P, {})[N] = mape_we

    _plot_sweep(results_we[P], f'WE P={P}')