In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
root_path = '/content/drive/MyDrive/M5-FITS/processed-nonparam/'
fits_dir = '/content/drive/MyDrive/FITS'

In [3]:
import numpy as np
import pandas as pd
import os, sys, gc, time, warnings, pickle, psutil, random

os.chdir(fits_dir)
from models.FITS import Model
# os.chdir(fits_dir + '/data_provider')
# from data_provider.data_factory import data_provider

In [6]:
import os
import numpy as np
import pandas as pd
import os
import torch
import random
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')

class Config(object):
  def __init__(self, args):
    #basic config
    self.is_training = args.get('is_training', 1)
    self.model_id = args.get('model_id', 'test')
    self.model = args.get('model', 'Autoformer')
    #dataloader
    self.data = args.get('data', 'test')
    self.root_path = args.get('root_path', '/content/drive/MyDrive/M5-FITS/processed-nonparam')
    self.data_path = args.get('data_path', 'm5.csv')
    self.features = args.get('features', 'M')
    self.target = args.get('target', 'sales')
    self.freq = args.get('freq', 'd')
    self.checkpoints = args.get('checkpoints', '/content/drive/MyDrive/M5-FITS/checkpoints')
    #forecasting
    self.seq_len = args.get('seq_len', 56)
    self.pred_len = args.get('pred_len', 28)
    self.label_len = args.get('label_len', 28)
    self.individual = args.get('individual', False)
    #optimization
    self.num_workers = args.get('num_workers', 10)
    self.itr = args.get('itr', 2)
    self.train_epochs = args.get('train_epochs', 100)
    self.batch_size = args.get('batch_size', 32)
    self.patience = args.get('patience', 3)
    self.learning_rate = args.get('learning_rate', 0.0001)
    self.des = args.get('des', 'test')
    self.loss = args.get('loss', 'mse')
    self.lradj = args.get('lradj', 'type3')
    self.use_amp = args.get('use_amp', False)
    #GPU
    self.use_gpu = args.get('use_gpu', True)
    self.gpu = args.get('gpu', 0)
    self.use_multi_gpu = args.get('use_multi_gpu', False)
    self.devices = args.get('devices', '0,1,2,3')
    self.test_flop = args.get('test_flop', False)
    #Augmentation
    self.aug_method = args.get('aug_method', 'NA')
    self.aug_rate = args.get('aug_rate', 0.5)
    self.in_batch_augmentation = args.get('in_batch_augmentation', False)
    self.in_dataset_augmentation = args.get('in_dataset_augmentation', False)
    self.data_size = args.get('data_size', 1)
    self.aug_data_size = args.get('aug_data_size', 1)
    self.seed = args.get('seed', 2021)
    #continue learning
    self.testset_div = args.get('testset_div', 2)
    self.test_time_train = args.get('test_time_train', False)
    #Formers
    self.embed = args.get('embed', 'timeF')
    self.enc_in = args.get('enc_in', 7)
    self.dec_in = args.get('dec_in', 7)
    self.c_out = args.get('c_out', 7)
    self.d_model = args.get('d_model', 512)
    self.n_heads = args.get('n_heads', 8)
    self.e_layers = args.get('e_layers', 2)
    self.d_layers = args.get('d_layers', 1)
    self.d_ff = args.get('d_ff', 2048)
    self.moving_avg = args.get('moving_avg', 25)
    self.factor = args.get('factor', 1)
    self.distil = args.get('distil', True)
    self.dropout = args.get('dropout', 0.1)
    self.activation = args.get('activation', 'relu')
    self.output_attention = args.get('output_attention', False)
    self.do_predict = args.get('do_predict', False)

    #Flinear
    self.train_mode = args.get('train_mode', 0)
    self.cut_freq = args.get('cut_freq', 0)
    self.base_T = args.get('base_T', 24)
    self.H_order = args.get('H_order', 2)

    self.use_gpu = True if torch.cuda.is_available() and self.use_gpu else False
    cfreq = args.get('cut_freq', 0)
    if cfreq == 0:
      self.cut_freq = int(self.seq_len // self.base_T + 1) * self.H_order + 10

    fix_seed = self.seed
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

In [7]:
def file_name(store_id, dept_id):
  return 'data_train_' + store_id + '_' + dept_id + '.csv'

In [8]:
args = {
    'root_path': '/content/drive/MyDrive/M5-FITS/processed/v1',
    'data': 'custom',
    # 'data_path': 'df_train_TX_1_HOUSEHOLD_1.csv',
    'features': 'S',
    'model': 'FITS'
}
config = Config(args)

In [9]:
store_id = 'CA_1'
dept_id = 'HOBBIES_1'

In [10]:
config.data_path = file_name(store_id, dept_id)

In [11]:
shuffle_flag = True
drop_last = True
batch_size = config.batch_size
freq = config.freq
timeenc = 0 if config.embed != 'timeF' else 1

In [12]:
from torch.utils.data import DataLoader
from utils.timefeatures import time_features

TRAIN_START = 0
TRAIN_END = 1941
TEST_END = 1969
TRAIN_LEN = TRAIN_END - TRAIN_START
HORIZON = 28

class Dataset_Custom(Dataset):
    def __init__(self, config, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        self.args = config
        # info
        if size == None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()
        self.collect_all_data()
        if self.args.in_dataset_augmentation and self.set_type==0:
            self.data_augmentation()

    def __read_data__(self):
        self.scaler = StandardScaler()
        #TODO: read bottom-level data. Reproduce similar functionality as in DLinear
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))
        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # print(cols)
        # num_train = int(len(df_raw) * 0.7)
        # num_test = int(len(df_raw) * 0.2)
        num_test = HORIZON  # Fixed to the last 28 days
        num_train = int((TRAIN_END - TRAIN_START) * 0.8)
        num_vali = TRAIN_END - TRAIN_START - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]

        if self.args.test_time_train:
            num_train = int(len(df_raw) * 0.9)
            border1s = [0, num_train - self.seq_len, len(df_raw)]
            border2s = [num_train, len(df_raw), len(df_raw)]

        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            # print(self.scaler.mean_)
            # exit()
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
            data_stamp = df_stamp.drop(['date'], 1).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        print(border1, border2)

    def regenerate_augmentation_data(self):
        self.collect_all_data()
        self.data_augmentation()

    def reload_data(self, x_data, y_data, x_time, y_time):
        self.x_data = x_data
        self.y_data = y_data
        self.x_time = x_time
        self.y_time = y_time

    def collect_all_data(self):
        self.x_data = []
        self.y_data = []
        self.x_time = []
        self.y_time = []
        data_len = len(self.data_x) - self.seq_len - self.pred_len + 1
        mask_data_len = int((1-self.args.data_size) * data_len) if self.args.data_size < 1 else 0
        for i in range(len(self.data_x) - self.seq_len - self.pred_len + 1):
            if (self.set_type == 0 and i >= mask_data_len) or self.set_type != 0:
                s_begin = i
                s_end = s_begin + self.seq_len
                r_begin = s_end - self.label_len
                r_end = r_begin + self.label_len + self.pred_len
                self.x_data.append(self.data_x[s_begin:s_end])
                self.y_data.append(self.data_y[r_begin:r_end])
                self.x_time.append(self.data_stamp[s_begin:s_end])
                self.y_time.append(self.data_stamp[r_begin:r_end])

    def data_augmentation(self):
        origin_len = len(self.x_data)
        if not self.args.closer_data_aug_more:
            aug_size = [self.args.aug_data_size for i in range(origin_len)]
        else:
            aug_size = [int(self.args.aug_data_size * i/origin_len) + 1 for i in range(origin_len)]

        for i in range(origin_len):
            for _ in range(aug_size[i]):
                aug = augmentation('dataset')
                if self.args.aug_method == 'f_mask':
                    x,y = aug.freq_dropout(self.x_data[i],self.y_data[i],dropout_rate=self.args.aug_rate)
                elif self.args.aug_method == 'f_mix':
                    rand = float(np.random.random(1))
                    i2 = int(rand*len(self.x_data))
                    x,y = aug.freq_mix(self.x_data[i],self.y_data[i],self.x_data[i2],self.y_data[i2],dropout_rate=self.args.aug_rate)
                else:
                    raise ValueError
                self.x_data.append(x)
                self.y_data.append(y)
                self.x_time.append(self.x_time[i])
                self.y_time.append(self.y_time[i])

    def __getitem__(self, index):
        seq_x = self.x_data[index]
        seq_y = self.y_data[index]
        return seq_x, seq_y, self.x_time[index], self.y_time[index]

    def __len__(self):
        return len(self.x_data)

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)

def data_provider(args, flag):
  Data = Dataset_Custom
  timeenc = 0 if args.embed != 'timeF' else 1

  if flag == 'test':
      shuffle_flag = False
      drop_last = False # True
      batch_size = args.batch_size
      freq = args.freq
  else:
      shuffle_flag = True
      drop_last = True
      batch_size = args.batch_size
      freq = args.freq

  data_set = Data(
      config=args,
      root_path=args.root_path,
      data_path=args.data_path,
      flag=flag,
      size=[args.seq_len, args.label_len, args.pred_len],
      features=args.features,
      target=args.target,
      timeenc=timeenc,
      freq=freq
  )

  data_loader = DataLoader(
      data_set,
      batch_size=batch_size,
      shuffle=shuffle_flag,
      num_workers=args.num_workers,
      drop_last=drop_last)

  return data_set, data_loader


data_set = Dataset_Custom(
    config=config,
    root_path=config.root_path,
    data_path=config.data_path,
    flag='train',
    size=[config.seq_len, config.label_len, config.pred_len],
    features=config.features,
    target=config.target,
    timeenc=timeenc,
    freq=config.freq,
    scale=True
)

print('train', len(data_set))
data_loader = DataLoader(
    data_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    drop_last=drop_last)

0 1552
train 1469


In [13]:
!pip3 install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


In [14]:
from models.FITS import Model
from utils.tools import EarlyStopping, adjust_learning_rate, test_params_flop, visual
from utils.metrics import metric
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from utils.augmentations import augmentation
import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

from thop import profile

class FITS(object):
    def __init__(self, args):
        self.args = args
        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        raise NotImplementedError
        return None

    def _acquire_device(self):
        if self.args.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        pass

    def vali(self):
        pass

    def train(self):
        pass

    def test(self):
        pass

class M5FITS(FITS):
    def __init__(self, args):
        super(M5FITS, self).__init__(args)

    def _build_model(self):
        model = Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)

        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        print('!!!!!!!!!!!!!!learning rate!!!!!!!!!!!!!!!')
        print(self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def _get_profile(self, model):
        _input=torch.randn(self.args.batch_size, self.args.seq_len, self.args.enc_in).to(self.device)
        macs, params = profile(model, inputs=(_input,))
        print('FLOPs: ', macs)
        print('params: ', params)
        return macs, params

    def vali(self, vali_data, vali_loader, criterion):
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)[:,-self.args.pred_len:,:]
                batch_xy = torch.cat([batch_x, batch_y], dim=1)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if 'FITS' in self.args.model:
                    outputs, low = self.model(batch_x)
                elif 'SCINet' in self.args.model:
                    outputs = self.model(batch_x)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:]

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting, ft=False):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')
        print(self.model)
        self._get_profile(self.model)
        print('Trainable parameters: ', sum(p.numel() for p in self.model.parameters() if p.requires_grad))

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            if self.args.in_dataset_augmentation:
                train_loader.dataset.regenerate_augmentation_data()

            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)[:,-self.args.pred_len:,:]
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                # print(batch_x.shape, batch_y.shape)
                batch_xy = torch.cat([batch_x, batch_y], dim=1)

                # if self.args.in_batch_augmentation:
                #     aug = augmentation('batch')
                #     methods = {'f_mask':aug.freq_mask, 'f_mix': aug.freq_mix, 'noise':aug.noise,'noise_input':aug.noise_input}
                #     for step in range(self.args.aug_data_size):
                #         xy = methods[self.args.aug_method](batch_x, batch_y[:, -self.args.pred_len:, :], rate=self.args.aug_rate, dim=1)
                #         batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
                #         if 'noise' not in self.args.aug_method:
                #             batch_x = torch.cat([batch_x,batch_x2],dim=0)
                #             batch_y = torch.cat([batch_y,batch_y2],dim=0)
                #             batch_x_mark = torch.cat([batch_x_mark,batch_x_mark],dim=0)
                #             batch_y_mark = torch.cat([batch_y_mark,batch_y_mark],dim=0)
                #         else:
                #             print('noise')
                #             batch_x = batch_x2
                #             batch_y = batch_y2

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # encoder - decoder
                if 'FITS' in self.args.model:
                        outputs, low = self.model(batch_x)
                elif 'SCINet' in self.args.model:
                        outputs = self.model(batch_x)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)

                # print(outputs.shape,batch_y.shape)
                f_dim = -1 if self.args.features == 'MS' else 0
                if ft:
                    outputs = outputs[:, -self.args.pred_len:, f_dim:]
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    # print(outputs.shape,batch_xy.shape)
                    #loss = criterion(outputs, batch_xy)
                    loss = criterion(outputs, batch_y)
                else:
                    outputs = outputs[:, :, f_dim:]
                    # batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) #???
                    loss = criterion(outputs, batch_xy)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        test_data, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        inputx = []
        reconx = []
        inputxy = []
        reconxy = []
        lows = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)[:,-self.args.pred_len:,:]
                batch_xy = torch.cat([batch_x, batch_y], dim=1).float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder

                if 'FITS' in self.args.model:
                        outputs, low = self.model(batch_x)
                elif 'SCINet' in self.args.model:
                        outputs = self.model(batch_x)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]

                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                f_dim = -1 if self.args.features == 'MS' else 0
                # print(outputs.shape,batch_y.shape)
                outputs_ = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                outputs_ = outputs_.detach().cpu().numpy()
                batch_y = batch_y.detach().cpu().numpy()


                pred = outputs_  # outputs.detach().cpu().numpy()  # .squeeze()
                true = batch_y  # batch_y.detach().cpu().numpy()  # .squeeze()

                preds.append(pred)
                trues.append(true)
                inputx.append(batch_x.detach().cpu().numpy())
                inputxy.append(batch_xy.detach().cpu().numpy())
                reconx.append(outputs[:, :-self.args.pred_len, f_dim:].detach().cpu().numpy())
                reconxy.append(outputs.detach().cpu().numpy())
                lows.append(low.detach().cpu().numpy())
                if i % 20 == 0:
                    input = batch_x.detach().cpu().numpy()
                    gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

        if self.args.test_flop:
            test_params_flop((batch_x.shape[1],batch_x.shape[2]))
            exit()
        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        # inputx = np.array(inputx)
        # reconx = np.array(reconx)
        # reconxy = np.array(reconxy)
        # inputxy = np.array(inputxy)
        # lows = np.array(lows)


        # preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        # trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        # inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])
        # reconx = reconx.reshape(-1, reconx.shape[-2], reconx.shape[-1])
        # reconxy = reconxy.reshape(-1, reconxy.shape[-2], reconxy.shape[-1])
        # inputxy = inputxy.reshape(-1, inputxy.shape[-2], inputxy.shape[-1])
        # lows = lows.reshape(-1, lows.shape[-2], lows.shape[-1])

        # try:
        #     for i in range(0,2800,300):

        #         # create a figure with 3 subplots
        #         fig, axs = plt.subplots(3, 1, figsize=(10, 10))
        #         # plot pred and true in the first subplot
        #         axs[0].plot(trues[i, :, -1], label='true')
        #         axs[0].plot(preds[i, :, -1], label='pred')
        #         axs[0].set_title('pred and true')
        #         # plot inputx and reconx in the second subplot
        #         axs[1].plot(inputx[i, :, -1], label='inputx')
        #         axs[1].plot(reconx[i, :, -1], label='reconx')
        #         axs[1].set_title('inputx and reconx')
        #         # plot inputxy and reconxy in the third subplot
        #         axs[2].plot(inputxy[i, :, -1], label='inputxy')
        #         axs[2].plot(reconxy[i, :, -1], label='reconxy')
        #         axs[2].plot(lows[i, :, -1])
        #         axs[2].set_title('inputxy and reconxy')
        #         # show the legend
        #         plt.legend()
        #         # save the figure to file
        #         fig.savefig(os.path.join(folder_path, str(i) + '_F.png'))
        #         # print('plottting')
        # except:
        #     pass

        # result save
        # folder_path = './results/' + setting + '/'
        # if not os.path.exists(folder_path):
        #     os.makedirs(folder_path)

        # mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
        # print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
        # f = open("result.txt", 'a')
        # f.write(setting + "  \n")
        # f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
        # f.write('\n')
        # f.write('\n')
        # f.close()

        # np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe,rse, corr]))
        # np.save(folder_path + 'pred.npy', preds)
        # np.save(folder_path + 'true.npy', trues)
        # np.save(folder_path + 'x.npy', inputx)
        return

    def predict(self, setting, load=False):
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))

        preds = []

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to(batch_y.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if 'Linear' in self.args.model:
                    outputs = self.model(batch_x)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                pred = outputs.detach().cpu().numpy()  # .squeeze()
                preds.append(pred)

        preds = np.array(preds)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        np.save(folder_path + 'real_prediction.npy', preds)

        return

In [19]:
setting = '{}_{}_{}_ft{}_sl{}_ll{}_pl{}_H{}_{}'.format(
            config.model_id,
            config.model,
            config.data,
            config.features,
            config.seq_len,
            config.label_len,
            config.pred_len,
            config.H_order, 1)

config.features = 'MS'
main = M5FITS(config)

## S
main.train(setting, ft=True)

Use CPU
0 1552
1496 1913
1885 1969
Model(
  (freq_upsampler): Linear(in_features=16, out_features=24, bias=True)
)
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs:  86016.0
params:  408.0
Trainable parameters:  408
!!!!!!!!!!!!!!learning rate!!!!!!!!!!!!!!!
0.0001
Epoch: 1 cost time: 0.9846093654632568
Epoch: 1, Steps: 45 | Train Loss: 1.1337448 Vali Loss: 0.9233958 Test Loss: 20.3183861
Validation loss decreased (inf --> 0.923396).  Saving model ...
Updating learning rate to 0.0001
Epoch: 2 cost time: 1.1247379779815674
Epoch: 2, Steps: 45 | Train Loss: 1.1074700 Vali Loss: 0.9074556 Test Loss: 20.2344322
Validation loss decreased (0.923396 --> 0.907456).  Saving model ...
Updating learning rate to 9.5e-05
Epoch: 3 cost time: 1.435006856918335
Epoch: 3, Steps: 45 | Train Loss: 1.0765099 Vali Loss: 0.8863791 Test Loss: 20.1549664
Validation loss decreased (0.907456 --> 0.886379).  Saving model ...
Updating learning rate to 9.025e-05
Epoch: 4 cost time

Model(
  (freq_upsampler): Linear(in_features=16, out_features=24, bias=True)
)

56

1378


4