In [1]:
import sys
import os
import pandas as pd
import torch
from torch import nn

# Make the parent of the notebook's working directory importable; the
# `from src...` imports in the following cell presumably rely on this.
project_root = os.path.dirname(os.getcwd())
# Guard the append so that re-running this cell in a live kernel does not
# keep stacking duplicate entries onto sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from src.dataset import PretrainDataset
from src.model import PTSM
from src.trainer import Trainer

In [3]:
# Read the ETTh1 CSV (path is relative to the notebook's working directory)
# and preview the first rows.
raw_csv = "../data/ETTh1.csv"
data = pd.read_csv(filepath_or_buffer=raw_csv)
data.head()

Unnamed: 0,date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
0,2016-07-01 00:00:00,5.827,2.009,1.599,0.462,4.203,1.34,30.531
1,2016-07-01 01:00:00,5.693,2.076,1.492,0.426,4.142,1.371,27.787001
2,2016-07-01 02:00:00,5.157,1.741,1.279,0.355,3.777,1.218,27.787001
3,2016-07-01 03:00:00,5.09,1.942,1.279,0.391,3.807,1.279,25.044001
4,2016-07-01 04:00:00,5.358,1.942,1.492,0.462,3.868,1.279,21.948


In [4]:
# Give every row the same group label 'A' and convert the `date` column to
# pandas datetimes in a single assign() call, then preview the result.
data = data.assign(
    group_id='A',
    date=pd.to_datetime(data['date']),
)
data.head()

Unnamed: 0,date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT,group_id
0,2016-07-01 00:00:00,5.827,2.009,1.599,0.462,4.203,1.34,30.531,A
1,2016-07-01 01:00:00,5.693,2.076,1.492,0.426,4.142,1.371,27.787001,A
2,2016-07-01 02:00:00,5.157,1.741,1.279,0.355,3.777,1.218,27.787001,A
3,2016-07-01 03:00:00,5.09,1.942,1.279,0.391,3.807,1.279,25.044001,A
4,2016-07-01 04:00:00,5.358,1.942,1.492,0.462,3.868,1.279,21.948,A


In [5]:
# Whole hours elapsed since the earliest timestamp, used as an integer index.
# Floor-dividing the timedelta column by a one-hour Timedelta is vectorised and
# stays in integer arithmetic, replacing the previous row-by-row
# `.apply(lambda ...)` that converted through float seconds.
# Offsets are never negative (the minimum is subtracted), so flooring matches
# the truncation that `int(...)` performed before.
elapsed = data['date'] - data['date'].min()
data['time_index'] = elapsed // pd.Timedelta(hours=1)

In [6]:
# Build the pre-training dataset from the prepared frame.
# Windowing settings are gathered in one mapping; seq_len evaluates to 90
# (30 * 3). NOTE(review): the exact meaning of min_count_per_sample, stride
# and freq is defined by PretrainDataset -- confirm against src/dataset.
window_options = {
    'seq_len': 30 * 3,
    'min_count_per_sample': 20,
    'stride': 1,
    'freq': 'h',
}
trainset = PretrainDataset(
    data=data,
    group_id='group_id',
    time_col='date',
    time_index='time_index',
    target='OT',
    **window_options,
)

In [7]:
# Report how many samples the dataset exposes through len().
print(f"size of trainset: {len(trainset)}")

size of trainset: 17401


In [8]:
# Seed PyTorch's global RNG before the model is built so that whatever is drawn
# from it afterwards (presumably weight initialisation, masking and dropout --
# confirm against src/model) is repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# input_len uses the same 30 * 3 (= 90) expression that is passed as seq_len
# when the dataset is built; presumably the two must stay equal -- confirm.
# With patch_size=3 this length splits evenly into 30 patches.
model = PTSM(
    input_len=30 * 3,
    patch_size=3,
    in_channels=1,
    embed_dim=16,
    num_heads=4,
    mlp_ratio=4,
    depth=2,
    mask_ratio=0.4,
    dropout=0.1,
)

In [9]:
# Print the model's parameter count in millions, to two decimals.
# `num_parameters` is read as an attribute here, not called.
print(f"model size: {model.num_parameters/1e6:.2f}M")

model size: 0.01M


In [10]:
# Wrap the model in the project's Trainer with a learning rate of 1e-2 and a
# cap of 2 epochs.
trainer_options = dict(lr=1e-2, max_epochs=2)
trainer = Trainer(model=model, **trainer_options)

In [12]:
# Launch training on the pre-training dataset.
# save_path "./" is the current working directory. NOTE(review): save_every=1
# presumably means a checkpoint per epoch -- confirm against src/trainer.
train_options = {
    'batch_size': 200,
    'num_workers': 4,
    'save_path': "./",
    'save_every': 1,
}
trainer.train(train_dataset=trainset, **train_options)

Epoch: 1, Train Loss: 0.2472, Time: 20.75s
