In [28]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from time import time
from types import SimpleNamespace
from statsmodels.tsa.stattools import adfuller, kpss

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset

from utils.metrics import metric
from data_provider.data_factory import data_provider
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from data_provider.data_loader import Dataset_Custom, Dataset_ETT_hour, Dataset_ETT_minute
from models import DLinear, iTransformer, Autoformer, FEDformer, Linear, LSTM
from models.trend import exp_1, exp_2, exp_3, exp_4
from normalizers import SAN, DDN, TP, RevIN, NoNorm
from layers.Autoformer_EncDec import series_decomp

configs = SimpleNamespace(
    task_name = "long_term_forecast",
    model_name = "exp_4", 
    
    # Task
    seq_len=336,
    label_len=0,
    pred_len=336,
    enc_in=7,
    dec_in=7,
    c_out=7,
    features='M',
    freq='h',
    
    # Data
    root_path='datasets/ETT-small',
    data='ETTh1',
    data_path='ETTh1.csv',
    target='OT',
    batch_size=32,
    
    # Basic Model Settings
    checkpoints='./trend_exp/checkpoints/',
    dropout=0.1,
    embed='timeF',
    output_attention=False,
    activation='gelu',
    moving_avg=25,
    num_kernels=6,
    individual=False,
    learning_rate=0.001,
    n_heads=8,
    patience=3,
    train_epochs=10,
    use_amp=False,
    lradj='type1',
    # d_model=512,
    # d_ff=2048,
    
    # iTransformer
    # d_model=128,
    # d_ff=128,
    # factor=1,
    # e_layers=2,
    # class_strategy='projection',
    
    # Autoformer
    # factor=3,
    # e_layers=2,
    # d_layers=1,
    
    # FEDformer
    # factor=3,
    # e_layers=2,
    # d_layers=1,
    
    # Normalizer
    use_norm='none',
    station_type='adaptive',
    affine=True,
    period_len=24,
    station_lr=0.0005,
    pre_epoch=5,
    twice_epoch=0,
    j=1,
    learnable=False,
    wavelet='coif3',
    dr=0.05,
    kernel_len=7,
    hkernel_len=5,
    pd_ff=128,
    pd_model=128,
    pe_layers=0,
    kernel_size=25,
    s_norm=True,
    t_norm=True,
)
model_dict={
    'DLinear': DLinear,
    'iTransformer': iTransformer,
    'Autoformer': Autoformer,
    'FEDformer': FEDformer,
    'Linear': Linear,
    'LSTM': LSTM,
    'exp_1': exp_1,
    'exp_2': exp_2,
    'exp_3': exp_3,
    'exp_4': exp_4,
}
linear_models = ['DLinear', 'Linear', 'LSTM', 'exp_1', 'exp_2', 'exp_3', 'exp_4']
norm_dict = {
    'none': NoNorm,
    'revin': RevIN,
    'san': SAN,
    'ddn': DDN,
    'tp': TP,
}
decomp = series_decomp(25).cuda()

# SAN
def san_loss(y, statistics_pred):
    bs, len, dim = y.shape
    y = y.reshape(bs, -1, configs.period_len, dim)
    mean = torch.mean(y, dim=2)
    std = torch.std(y, dim=2)
    station_ture = torch.cat([mean, std], dim=-1)
    loss = criterion(statistics_pred, station_ture)
    return loss

# DDN
def ddn_loss(y, statistics_pred):
    _, (mean, std) = norm.norm(y.transpose(-1, -2), False)
    station_ture = torch.cat([mean, std], dim=1).transpose(-1, -2)
    loss = criterion(statistics_pred, station_ture)
    return loss

# TREAD
def tread_loss(y, statistics_pred):
    trend_pred = statistics_pred[-1]
    _, trend_true = decomp(y)
    loss = criterion(trend_pred, trend_true)
    return loss

station_loss_dict = {
    'none': None,
    'revin': None,
    'san': san_loss,
    'ddn': ddn_loss,
    'tp': tread_loss,
}
station_loss = station_loss_dict[configs.use_norm]

# [pre train, pre epoch, joint train, join epoch]
station_setting_dict = {
    'none': [0, 0, 0, 0],
    'revin': [0, 0, 0, 0],
    'san': [1, configs.pre_epoch, 0, 0],
    # 'san': [0, 0, 1, configs.twice_epoch],
    'ddn': [1, configs.pre_epoch, 1, configs.twice_epoch],
    'tp': [1, configs.pre_epoch, 1, configs.twice_epoch],
}
station_setting = station_setting_dict[configs.use_norm]

criterion = nn.MSELoss()
def criterion_2(true, pred, seasonal_pred, trend_pred):
    seasonal_true, trend_true = decomp(true)
    # loss1 = criterion(pred, true)
    loss2 = criterion(seasonal_pred, seasonal_true.cuda())
    loss3 = criterion(trend_pred, trend_true.cuda())
    return loss2 + loss3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model_dict[configs.model_name].Model(configs).float().cuda()
norm = norm_dict[configs.use_norm].Model(configs).float().cuda()
model_optim = optim.Adam(model.parameters(), lr=configs.learning_rate)
norm_optim = optim.Adam(norm.parameters(), lr=configs.station_lr)

In [29]:
def vali(vali_data, vali_loader, criterion, epoch):
    total_loss = []
    model.eval()
    norm.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            
            # Seasonal or Trend
            seasonal_x, trend_x = decomp(batch_x)
            seasonal_y, trend_y = decomp(batch_y)
            # batch_x = trend_x
            # batch_y = trend_y
            # batch_x = seasonal_x
            # batch_y = seasonal_y

            batch_x_mark = batch_x_mark.float().to(device)
            batch_y_mark = batch_y_mark.float().to(device)

            # normalize
            if configs.use_norm == 'ddn':
                if epoch + 1 <= station_setting[1]:
                    batch_x, statistics_pred, statistics_seq = norm.normalize(batch_x, p_value=False)
                else:
                    batch_x, statistics_pred, statistics_seq = norm.normalize(batch_x)
            else:
                batch_x, statistics_pred = norm.normalize(batch_x)

            # station pretrain
            if epoch + 1 <= station_setting[1]:
                f_dim = -1 if configs.features == 'MS' else 0
                batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)
                if configs.features == 'MS':
                    statistics_pred = statistics_pred[:, :, [configs.enc_in - 1, -1]]
                loss = station_loss(batch_y, statistics_pred)
            
            # model train
            else:
                # decoder x
                dec_inp = torch.zeros_like(batch_y[:, -configs.pred_len:, :]).float()
                dec_label = batch_x[:, -configs.label_len:, :]
                dec_inp = torch.cat([dec_label, dec_inp], dim=1).float()
                # encoder - decoder
                if configs.use_amp:
                    with torch.cuda.amp.autocast():
                        if configs.model_name in linear_models:
                            outputs, seasonal_pred, trend_pred = model(batch_x)
                        else:
                            if configs.output_attention:
                                outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if configs.model_name in linear_models:
                        outputs, seasonal_pred, trend_pred = model(batch_x)
                    else:
                        if configs.output_attention:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                f_dim = -1 if configs.features == 'MS' else 0
                if configs.features == 'MS':
                    statistics_pred = statistics_pred[:, :, [configs.enc_in - 1, -1]]
                outputs = outputs[:, -configs.pred_len:, f_dim:]
                
                # de-normalize
                outputs = norm.de_normalize(outputs, statistics_pred)
                
                batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion_2(pred, true, seasonal_pred, trend_pred)

            total_loss.append(loss.cpu().item())
    total_loss = np.average(total_loss)
    model.train()
    norm.train()
    return total_loss

def _get_data(flag):
    data_set, data_loader = data_provider(configs, flag)
    return data_set, data_loader

def train(setting):
    train_data, train_loader = _get_data(flag='train')
    vali_data, vali_loader = _get_data(flag='val')
    test_data, test_loader = _get_data(flag='test')

    path = os.path.join(configs.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    path_station = './station/' + '{}_s{}_p{}'.format(configs.use_norm, configs.data,
                                                        configs.seq_len, configs.pred_len)
    if not os.path.exists(path_station):
        os.makedirs(path_station)

    time_now = time()

    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=configs.patience, verbose=True)
    early_stopping_station_model = EarlyStopping(patience=configs.patience, verbose=True)

    if configs.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(configs.train_epochs + station_setting[1]):
        iter_count = 0
        train_loss = []
        
        # Load best station model after pretraining
        if station_setting[0] > 0 and epoch == station_setting[1]:
            best_model_path = path_station + '/' + 'checkpoint.pth'
            norm.load_state_dict(torch.load(best_model_path))
            print('loading pretrained adaptive station model')
        
        # Add station parameters to model optim after pretraining and delay epochs for joint training
        if station_setting[2] > 0 and station_setting[3] == epoch - station_setting[1]:
            lr = model_optim.param_groups[0]['lr']
            model_optim.add_param_group({'params': norm.parameters(), 'lr': lr})
        
        model.train()
        norm.train()
        epoch_time = time()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            
            # Seasonal or Trend
            seasonal_x, trend_x = decomp(batch_x)
            seasonal_y, trend_y = decomp(batch_y)
            # batch_x = trend_x
            # batch_y = trend_y
            # batch_x = seasonal_x
            # batch_y = seasonal_y
            
            # normalize
            if configs.use_norm == 'ddn':
                if epoch + 1 <= station_setting[1]:
                    batch_x, statistics_pred, statistics_seq = norm.normalize(batch_x, p_value=False)
                else:
                    batch_x, statistics_pred, statistics_seq = norm.normalize(batch_x)
            else:
                batch_x, statistics_pred = norm.normalize(batch_x)
            
            # station pretrain
            if epoch + 1 <= station_setting[1]:
                f_dim = -1 if configs.features == 'MS' else 0
                batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)
                if configs.features == 'MS':
                    statistics_pred = statistics_pred[:, :, [configs.enc_in - 1, -1]]
                loss = station_loss(batch_y, statistics_pred)
                train_loss.append(loss.item())
            
            # model train
            else:
                batch_x_mark = batch_x_mark.float().to(device)
                batch_y_mark = batch_y_mark.float().to(device)

                # decoder x
                dec_inp = torch.zeros_like(batch_y[:, -configs.pred_len:, :]).float()
                dec_label = batch_x[:, -configs.label_len:, :]
                dec_inp = torch.cat([dec_label, dec_inp], dim=1).float().to(device)

                # encoder - decoder
                if configs.use_amp:
                    with torch.cuda.amp.autocast():
                        if configs.model_name in linear_models:
                            outputs, seasonal_pred, trend_pred = model(batch_x)
                        else:
                            if configs.output_attention:
                                outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        f_dim = -1 if configs.features == 'MS' else 0
                        outputs = outputs[:, -configs.pred_len:, f_dim:]
                        batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)
                        loss = criterion_2(outputs, batch_y, seasonal_pred, trend_pred)
                        train_loss.append(loss.item())
                else:
                    if configs.model_name in linear_models:
                        outputs, seasonal_pred, trend_pred = model(batch_x)
                    else:
                        if configs.output_attention:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                    f_dim = -1 if configs.features == 'MS' else 0
                    outputs = outputs[:, -configs.pred_len:, f_dim:]
                    if configs.features == 'MS':
                        statistics_pred = statistics_pred[:, :, [configs.enc_in - 1, -1]]
                        
                # de-normalize
                outputs = norm.de_normalize(outputs, statistics_pred)
                
                batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)
                loss = criterion_2(outputs, batch_y, seasonal_pred, trend_pred)
                train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                speed = (time() - time_now) / iter_count
                left_time = speed * (
                        (configs.train_epochs + station_setting[1] - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time()
            if configs.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                # two-stage training schema
                if epoch + 1 <= station_setting[1]:
                    norm_optim.step()
                else:
                    model_optim.step()
                model_optim.zero_grad()
                norm_optim.zero_grad()

        print("Epoch: {} cost time: {}".format(epoch + 1, time() - epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = vali(vali_data, vali_loader, criterion_2, epoch)
        test_loss = vali(test_data, test_loader, criterion_2, epoch)

        if epoch + 1 <= station_setting[1]:
            print(
                "Station Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping_station_model(vali_loss, norm, path_station)
            adjust_learning_rate(norm_optim, epoch + 1, configs, configs.station_lr)
        else:
            print(
                "Backbone Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1 - station_setting[1], train_steps, train_loss, vali_loss, test_loss))
            # if: joint training, else: only model training
            if station_setting[2] > 0 and station_setting[3] <= epoch - station_setting[1]:
                early_stopping(vali_loss, model, path, norm, path_station)
            else:
                early_stopping(vali_loss, model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1 - station_setting[1], configs,
                                    configs.learning_rate)
            adjust_learning_rate(norm_optim, epoch + 1 - station_setting[1], configs,
                                    configs.station_lr)

    best_model_path = path + '/' + 'checkpoint.pth'
    model.load_state_dict(torch.load(best_model_path))
    if station_setting[2] > 0:
        norm.load_state_dict(torch.load(path_station + '/' + 'checkpoint.pth'))
    return model

def check_stationarity(batch_x, test='adf'):
    ts = batch_x[0, :, -1]
    if test == 'adf':
        stat, pvalue, *_ = adfuller(ts)
        # 낮은 p-value일수록 정상성 ↑ → 점수는 1 - pvalue
        score = 1 - min(pvalue, 1.0)
    elif test == 'kpss':
        stat, pvalue, *_ = kpss(ts, nlags="auto")
        # 높은 p-value일수록 정상성 ↑
        score = min(pvalue, 1.0)
    return score

def test(setting, test=0):
    test_data, test_loader = _get_data(flag='test')

    if test:
        print('loading model')
        model.load_state_dict(torch.load(os.path.join('./trend_exp/checkpoints/' + setting, 'checkpoint.pth')))

    preds = []
    trues = []
    inputx = []
    folder_path = './trend_exp/' + setting + '/'
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    model.eval()
    norm.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            input_x = batch_x
            
            # Seasonal or Trend
            # batch_x = trend_x
            # batch_y = trend_y
            # batch_x = seasonal_x
            # batch_y = seasonal_y

            # normalize
            input_noNorm = batch_x
            if configs.use_norm == 'ddn':
                batch_x, statistics_pred, statistics_seq = norm.normalize(batch_x)
            else:
                batch_x, statistics_pred = norm.normalize(batch_x)
            input_withNorm = batch_x

            batch_x_mark = batch_x_mark.float().to(device)
            batch_y_mark = batch_y_mark.float().to(device)

            # decoder x
            dec_inp = torch.zeros_like(batch_y[:, -configs.pred_len:, :]).float()
            dec_label = batch_x[:, -configs.label_len:, :]
            dec_inp = torch.cat([dec_label, dec_inp], dim=1).float().to(device)
            # encoder - decoder
            if configs.use_amp:
                with torch.cuda.amp.autocast():
                    if configs.model_name in linear_models:
                        outputs, seasonal_pred, trend_pred = model(batch_x)
                    else:
                        if configs.output_attention:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            else:
                if configs.model_name in linear_models:
                    outputs, seasonal_pred, trend_pred = model(batch_x)
                else:
                    if configs.output_attention:
                        outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs, seasonal_pred, trend_pred = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            f_dim = -1 if configs.features == 'MS' else 0
            outputs = outputs[:, -configs.pred_len:, f_dim:]
            if configs.features == 'MS':
                statistics_pred = statistics_pred[:, :, [configs.enc_in - 1, -1]]
                
            # de-normalize
            output_noDenorm = outputs
            outputs = norm.de_normalize(outputs, statistics_pred)
            output_withDenorm = outputs
            
            batch_y = batch_y[:, -configs.pred_len:, f_dim:].to(device)
            outputs = outputs.detach().cpu().numpy()
            batch_y = batch_y.detach().cpu().numpy()

            pred = outputs  # outputs.detach().cpu().numpy()  # .squeeze()
            true = batch_y  # batch_y.detach().cpu().numpy()  # .squeeze()

            preds.append(pred)
            trues.append(true)
            inputx.append(batch_x.detach().cpu().numpy())
            if i % 10 == 0:
                x = input_x.detach().cpu().numpy()
                gt = np.concatenate((x[0, :, -1], true[0, :, -1]), axis=0)
                pd = np.concatenate((x[0, :, -1], pred[0, :, -1]), axis=0)
                visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))
            
            # visualize input and output with/without normalization/denormalization
            if i == 0:
                i_n = input_noNorm.detach().cpu().numpy()
                i_w = input_withNorm.detach().cpu().numpy()
                o_n = output_noDenorm.detach().cpu().numpy()
                o_w = output_withDenorm.detach().cpu().numpy()
                
                plt.close('all')
                plt.figure(figsize=(6, 4))
                # plt.subplot(3, 1, 1)
                # plt.plot(i_w[0, :, -1], label='Normalized', color='darkorange')
                # plt.plot(i_n[0, :, -1], label='Raw')
                # plt.legend()
                # plt.grid()
                
                plt.subplot(2, 1, 1)
                plt.plot(i_n[0, :, -1], label='Raw')
                plt.legend()
                plt.grid()
                
                plt.subplot(2, 1, 2)
                plt.plot(i_w[0, :, -1], label='Normalized', color='darkorange')
                plt.legend()
                plt.grid()
                plt.savefig(os.path.join(folder_path, 'normalized_input.pdf'))
                
                plt.close('all')
                plt.figure(figsize=(8, 4))
                plt.subplot(2, 1, 1)
                plt.plot(true[0, :, -1], label='Ground Truth', color='black', linewidth=2)
                plt.plot(o_w[0, :, -1], label='After Norm(Final Output)')
                plt.plot(0, label='Before Norm(Model Output)')
                # legend outside of plot
                plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
                
                plt.grid()
                plt.subplot(2, 1, 2)
                plt.plot(o_w[0, :, -1], label='After Norm(Final Output)')
                plt.plot(o_n[0, :, -1], label='Before Norm(Model Output)')
                # plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncol=2)
                plt.grid()
                plt.savefig(os.path.join(folder_path, 'normalized_output.pdf'))

    # if configs.test_flop:
    #     test_params_flop((batch_x.shape[1], batch_x.shape[2]))
    #     exit()
    preds = np.array(preds, dtype=object)
    trues = np.array(trues, dtype=object)
    # inputx = np.array(inputx)

    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    # inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])

    # result save
    folder_path = './trend_exp/' + setting + '/'
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
    print('mse:{}, mae:{}'.format(mse, mae))
    f = open("result.txt", 'a')
    f.write(setting + "  \n")
    f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
    f.write('\n')
    f.write('\n')
    f.close()

    # np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe,rse, corr]))
    # np.save(folder_path + 'pred.npy', preds)
    # np.save(folder_path + 'true.npy', trues)
    # np.save(folder_path + 'x.npy', inputx)
    return mse, mae

In [30]:
setting = 'ETTh1_s336_p336_exp_4_none'
train(setting)
test(setting)

train 7969
val 2545
test 2545
	iters: 100, epoch: 1 | loss: 0.0006470
	speed: 0.0387s/iter; left time: 92.4931s
	iters: 200, epoch: 1 | loss: 0.0002946
	speed: 0.0364s/iter; left time: 83.3421s
Epoch: 1 cost time: 9.351967811584473
Backbone Epoch: 1, Steps: 249 | Train Loss: 0.0010211 Vali Loss: 0.0002904 Test Loss: 0.0002885
Validation loss decreased (inf --> 0.000290).  Saving model ...
Updating learning rate to 0.001
Updating learning rate to 0.0005
	iters: 100, epoch: 2 | loss: 0.0001655
	speed: 0.0801s/iter; left time: 171.6096s
	iters: 200, epoch: 2 | loss: 0.0001339
	speed: 0.0362s/iter; left time: 73.9659s
Epoch: 2 cost time: 9.102003574371338
Backbone Epoch: 2, Steps: 249 | Train Loss: 0.0001668 Vali Loss: 0.0002148 Test Loss: 0.0001962
Validation loss decreased (0.000290 --> 0.000215).  Saving model ...
Updating learning rate to 0.0005
Updating learning rate to 0.00025
	iters: 100, epoch: 3 | loss: 0.0000967
	speed: 0.0781s/iter; left time: 147.8137s
	iters: 200, epoch: 3 | l

(0.99950784, 0.68023956)

In [31]:
test(setting, test=1, stationarity_test=False)

test 2689
loading model
mse:0.4212721288204193, mae:0.4345196485519409


In [77]:
plt.figure(figsize=(15, 3))
plt.subplot(1, 5, 1)
plt.plot(raw, label='Raw', color='black')
# plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=5)
plt.legend(loc='upper right')
plt.yticks([])
plt.tight_layout()
plt.subplot(1, 5, 2)
plt.plot(0, color='black')
plt.plot(revin, label='RevIN')
plt.legend(loc='upper right')
plt.yticks([])
plt.tight_layout()
plt.subplot(1, 5, 3)
plt.plot(0, color='black')
plt.plot(0)
plt.plot(san, label='SAN')
plt.legend(loc='upper right')
plt.yticks([])
plt.tight_layout()
plt.subplot(1, 5, 4)
plt.plot(0, color='black')
plt.plot(0)
plt.plot(0)
plt.plot(ddn, label='DDN')
plt.legend(loc='upper right')
plt.yticks([])
plt.tight_layout()
plt.subplot(1, 5, 5)
plt.plot(0, color='black')
plt.plot(0)
plt.plot(0)
plt.plot(0)
plt.plot(tp, label='TREAD')
plt.yticks([])
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig("compare.png")