In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from src.utils import set_seed
from src.model import GPT2TS
from src.trainer import Trainer
from src.dataset import Dataset_ETT_minute, Dataset_ETT_hour

In [None]:
# Fixed seed for reproducible runs. `set_seed` comes from src.utils; presumably
# it seeds Python/NumPy/torch RNGs — TODO confirm which libraries it covers.
SEED = 3407
set_seed(SEED)

In [None]:
# Build the train / val / test splits of the ETT hourly data from the same
# root directory; only the `flag` differs between them.
trainset, valset, testset = (
    Dataset_ETT_hour(root_path='./data', flag=split)
    for split in ('train', 'val', 'test')
)

In [None]:
# Load GPT2TS from pretrained weights. Input/prediction window lengths are read
# from the training split so the model and the data agree.
# NOTE(review): num_series=1 presumably pairs with features="S" passed to the
# Trainer in a later cell — keep the two in sync.
model = GPT2TS.from_pretrained(
    config=dict(
        input_len=trainset.seq_len,
        pred_len=trainset.pred_len,
        n_layer=6,          # presumably number of GPT-2 layers used — TODO confirm
        model_type='gpt2',
        num_series=1,
        patch_size=16,
        patch_stride=8,
    )
)

In [None]:
# Report parameter counts in millions. `num_params` is assumed to map 'total'
# and 'grad' (trainable) to integer counts — TODO confirm against GPT2TS.
# Format spec is ':.2f' (not ': .2f'): the space sign flag would prepend an
# extra blank to positive numbers, producing a double space after the colon.
print(f"num of total parameters: {model.num_params['total'] / 1e6:.2f}M")
print(f"num of trainable parameters: {model.num_params['grad'] / 1e6:.2f}M")

In [None]:
# Wrap the model in the project Trainer.
# use_amp=True presumably enables mixed-precision training (likely needs CUDA);
# features="S" presumably selects the univariate setting — TODO confirm.
tra = Trainer(
    model,
    use_amp=True,
    features="S",
    num_workers=6,
)

In [None]:
# Train on the train split; the val split is passed as the second argument
# (presumably for validation / model selection — TODO confirm in src.trainer).
# NOTE(review): `testset` is constructed earlier but not evaluated in this view.
tra.train(
    trainset,
    valset,
    batch_size=200,
    max_epochs=200,
    lr=0.001,
)