Official implementation of the paper "Deep TPC: Temporal-Prior Conditioning for Time Series Forecasting".
- Python 3.8+
- PyTorch 2.0+
- CUDA (for GPU training)
- Install PyTorch and the necessary dependencies.

  ```
  pip install -r requirements.txt
  ```
- Put the datasets [Google Drive] [Tsinghua Cloud] under the folder `./dataset/`.
- Download the large language model GPT-2 from Hugging Face. If you download and place the `gpt2` directory successfully, the directory structure is as follows:

  ```
  - data_provider
  - dataset
  - gpt2
    - config.json
    - pytorch_model-00001-of-00002.bin
    - pytorch_model-00002-of-00002.bin
    - ...
  - ...
  - run.py
  ```
- Use the position embeddings generated from textual timestamps. Note that we provide the embeddings for the given datasets in the download links; they are generated by GPT-2 and named `{dataset_name}.pt`. If you want to generate embeddings from your own datasets, refer to the following command:
  ```
  # preprocess timestamps to generate text embeddings
  python ./preprocess.py --gpu 0 --dataset ETTh1
  ```
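  The preprocessing step above turns raw timestamps into text before embedding them with GPT-2. As a rough illustration of that idea (the function name and prompt template below are hypothetical, not the repository's actual code in `preprocess.py`):

  ```python
  from datetime import datetime

  def timestamp_to_text(ts: str) -> str:
      """Render a raw timestamp as a natural-language sentence.

      Hypothetical sketch of the kind of textual prompt that could be
      fed to GPT-2 to obtain a timestamp embedding; the actual template
      used by the repository may differ.
      """
      dt = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S")
      return (
          f"This is {dt.strftime('%A')}, "
          f"{dt.strftime('%B')} {dt.day}, {dt.year}, "
          f"at {dt.hour:02d}:{dt.minute:02d}."
      )

  # ETTh1 timestamps are hourly, starting on 2016-07-01
  print(timestamp_to_text("2016-07-01 00:00:00"))
  # → This is Friday, July 1, 2016, at 00:00.
  ```

  Each such sentence would then be tokenized and passed through GPT-2, and the resulting embedding saved into the per-dataset `{dataset_name}.pt` file.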
- Train and evaluate the model. We provide scripts for all the above tasks under the folder `./scripts/`.
  Training:

  ```
  python -u run.py \
    --task_name long_term_forecast \
    --is_training 1 \
    --root_path ./dataset/ETT-small/ \
    --data_path ETTh1.csv \
    --model_id ETTh1_672_96 \
    --model GPT2WithTPC \
    --data ETTh1 \
    --seq_len 672 --label_len 576 --token_len 96 \
    --test_seq_len 672 --test_label_len 576 --test_pred_len 96 \
    --batch_size 256 --learning_rate 0.0005 \
    --train_epochs 10 --use_amp --cosine --tmax 10 \
    --mix_embeds --drop_last \
    --tpc_layers 0 2 4 6 8 10 \
    --num_fusion_tokens 20 \
    --llm_ckp_dir gpt2
  ```

  Testing:
  ```
  python -u run.py \
    --task_name long_term_forecast \
    --is_training 0 \
    --model GPT2WithTPC \
    --data ETTh1 \
    --llm_ckp_dir gpt2 \
    --test_dir <path_to_checkpoint_folder>
  ```

  Key arguments:

  - `--llm_ckp_dir`: path to the GPT-2 checkpoint (e.g. `gpt2` for auto-download, or a local path)
  - `--tpc_layers`: indices of the GPT-2 layers replaced with TPC blocks (e.g. `0 2 4 6 8 10`)
  - `--num_fusion_tokens`: number of learnable fusion tokens
  - `--mix_embeds`: fuse time-series and mark (timestamp) embeddings
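The `--tpc_layers` flag lists which of GPT-2's transformer blocks are swapped for TPC blocks; the rest keep their pretrained weights. A minimal, framework-free sketch of that mapping (the class names `TPCBlock` and `GPT2Block` are illustrative placeholders, not the repository's actual classes):

```python
def plan_backbone(tpc_layers, num_layers=12):
    """Return a per-layer plan for a GPT-2 backbone (12 blocks by default).

    Layers whose index appears in `tpc_layers` are replaced with a TPC
    block; all others keep the pretrained GPT-2 block. The names here
    are placeholders only, to illustrate the flag's semantics.
    """
    chosen = set(tpc_layers)
    return ["TPCBlock" if i in chosen else "GPT2Block"
            for i in range(num_layers)]

# With the example flags, every even-indexed block becomes a TPC block:
plan = plan_backbone([0, 2, 4, 6, 8, 10])
```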
See LICENSE file.