# Getting started with TinyTimeMixer (TTM)

This notebook demonstrates how to use a pre-trained `TinyTimeMixer` model for several multivariate time series forecasting tasks. For details on the model architecture, refer to the [TTM paper](https://arxiv.org/pdf/2401.03955.pdf).

In this example, we use a pre-trained TTM-512-96 model. It takes 512 past time points as input (`context_length`) and forecasts up to 96 time points into the future (`forecast_length`). We use the pre-trained TTM in two settings:
1. **Zero-shot**: The pre-trained TTM is evaluated directly on the `test` split of the target data. Note that TTM was NOT pre-trained on the target data.
2. **Few-shot**: The pre-trained TTM is quickly fine-tuned on only 5% of the `train` split of the target data and then evaluated on the `test` split.
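To make the 512/96 split concrete, here is a minimal sketch of how fixed-length (past, future) windows can be cut from a single series, together with a hypothetical helper showing what "keep only a fraction of the train split" could look like. `sliding_windows` and `fewshot_slice` are illustrative names, not `tsfm_public` functions. The stride of 1, the "keep the last rows" rule, and the one-window minimum are all assumptions. `load_dataset()` performs the windowing and few-shot sampling internally and may differ in detail.

```python
import math

CONTEXT_LENGTH = 512   # past time points fed to the model
FORECAST_LENGTH = 96   # future time points the model predicts


def sliding_windows(series, context_length=CONTEXT_LENGTH, forecast_length=FORECAST_LENGTH, stride=1):
    """Yield (past, future) pairs cut from one univariate series (stride is an assumption)."""
    last_start = len(series) - context_length - forecast_length
    for start in range(0, last_start + 1, stride):
        past = series[start : start + context_length]
        future = series[start + context_length : start + context_length + forecast_length]
        yield past, future


def fewshot_slice(train_rows, fraction, min_rows=CONTEXT_LENGTH + FORECAST_LENGTH):
    """Hypothetical helper: keep the last `fraction` of the training rows,
    but never fewer than enough rows for one full (past, future) window."""
    n_keep = max(math.ceil(len(train_rows) * fraction), min_rows)
    return train_rows[-n_keep:]


series = [float(t) for t in range(1000)]
windows = list(sliding_windows(series))
print(len(windows))  # 393 = 1000 - 512 - 96 + 1
```

The minimum-size guard matters for small training splits: 5% of a short series can be shorter than a single 608-point window, in which case no training sample could be formed at all.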

Note: This notebook can also be modified to use the TTM-1024-96 or TTM-1536-96 models.

Pre-trained TTM models are fetched from the [Hugging Face TTM Model Repository](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2).

1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)
    1. For 512-96 model set `TTM_MODEL_REVISION="main"`
    2. For 1024-96 model set `TTM_MODEL_REVISION="1024_96_v1"`
2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)
    1. For 512-96 model set `TTM_MODEL_REVISION="main"`
    2. For 1024-96 model set `TTM_MODEL_REVISION="1024-96-r2"`
    3. For 1536-96 model set `TTM_MODEL_REVISION="1536-96-r2"`

Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2).

## Install `tsfm` 
**[Optional for Local Run / Mandatory for Google Colab]**  
Run the cell below to install `tsfm`. Skip it if the library is already installed.

In [None]:
# Install the tsfm library
! pip install "tsfm_public[notebooks] @ git+https://github.com/ibm-granite/granite-tsfm.git@v0.2.12"

## Imports

In [None]:
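import os

# Set these before torch/transformers are imported, to be safe
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # download models through a Hugging Face mirror
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1 (seen as cuda:0 in this process)

import torch

print(torch.cuda.device_count())  # number of visible GPUs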

In [2]:

import math
import os
import tempfile

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn as nn
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.integrations import INTEGRATION_TO_CALLBACK

from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions
import numpy as np

In [3]:
import warnings
# Suppress all warnings
warnings.filterwarnings("ignore")

### Important arguments

In [4]:
# Set seed for reproducibility
SEED = 42
set_seed(SEED)

# TTM Revision (1 or 2)
TTM_REVISION = 2

# Context length, i.e., the length of the history.
# Currently supported values are: 512/1024/1536 for TTM-R2, and 512/1024 for TTM-R1
CONTEXT_LENGTH = 512

# Forecast length, i.e., the forecast horizon of the pre-trained models
FORECAST_LENGTH = 96

# Dataset
# The dataloaders use the easy-to-use YAML configurations linked below.
# Dataset configuration YAMLs: https://github.com/ibm-granite/granite-tsfm/tree/main/tsfm_public/resources/data_config
# Note that `dataset_root_path` can also be provided instead of `dataset_path` to the `load_dataset()` function to
# run this notebook on already downloaded dataset.
# Check the `load_dataset()` function to see more functionalities.
# TARGET_DATASET = "etth1"
TARGET_DATASET = "electricity"
# DATASET_PATH = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
DATASET_PATH = "/home/zhupengtian/zhangqingliang/granite-tsfm/datasets/electricity/electricity.csv"

# Results dir
# OUT_DIR = "ttm_finetuned_models/"
OUT_DIR = "/home/zhupengtian/zhangqingliang/granite-tsfm/ttm_finetuned_models/"

#### Automatically set TTM_MODEL_PATH and TTM_MODEL_REVISION

In [5]:
# ----- TTM model path -----
if TTM_REVISION == 1:
    TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
    # ----- TTM model branch -----
    # For R1 models
    if CONTEXT_LENGTH == 512:
        TTM_MODEL_REVISION = "main"
    elif CONTEXT_LENGTH == 1024:
        TTM_MODEL_REVISION = "1024_96_v1"
    else:
        raise ValueError(f"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}")
elif TTM_REVISION == 2:
    TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
    # ----- TTM model branch -----
    # For R2 models
    if CONTEXT_LENGTH == 512:
        TTM_MODEL_REVISION = "main"
    elif CONTEXT_LENGTH == 1024:
        TTM_MODEL_REVISION = "1024-96-r2"
    elif CONTEXT_LENGTH == 1536:
        TTM_MODEL_REVISION = "1536-96-r2"
    else:
        raise ValueError(f"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}")
else:
    raise ValueError("Wrong TTM_REVISION. Stay tuned for future models.")
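The if/elif chain above can equivalently be written as a lookup table keyed on `(TTM_REVISION, CONTEXT_LENGTH)`, which keeps the supported combinations in one place. This is a sketch of that alternative. `TTM_BRANCHES` and `resolve_ttm_model` are hypothetical names. Unlike the original cell, it raises a single combined error for any unsupported pair instead of distinguishing a bad revision from a bad context length.

```python
# (TTM_REVISION, CONTEXT_LENGTH) -> (TTM_MODEL_PATH, TTM_MODEL_REVISION), copied from the cell above
TTM_BRANCHES = {
    (1, 512): ("ibm-granite/granite-timeseries-ttm-r1", "main"),
    (1, 1024): ("ibm-granite/granite-timeseries-ttm-r1", "1024_96_v1"),
    (2, 512): ("ibm-granite/granite-timeseries-ttm-r2", "main"),
    (2, 1024): ("ibm-granite/granite-timeseries-ttm-r2", "1024-96-r2"),
    (2, 1536): ("ibm-granite/granite-timeseries-ttm-r2", "1536-96-r2"),
}


def resolve_ttm_model(revision, context_length):
    """Return (model_path, branch) for a supported combination, else raise ValueError."""
    try:
        return TTM_BRANCHES[(revision, context_length)]
    except KeyError:
        raise ValueError(
            f"Unsupported combination: TTM_REVISION={revision}, CONTEXT_LENGTH={context_length}"
        ) from None


print(resolve_ttm_model(2, 1024))  # ('ibm-granite/granite-timeseries-ttm-r2', '1024-96-r2')
```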

In [None]:
print("Chosen TTM model:")
print(f"{TTM_MODEL_PATH}, revision = {TTM_MODEL_REVISION}")

In [None]:
# Load the dataset and print some information about it
dataset = load_dataset(
    TARGET_DATASET, context_length=CONTEXT_LENGTH, forecast_length=FORECAST_LENGTH, dataset_path=DATASET_PATH
)
# Print the type and contents of the loaded object
print(type(dataset))
print(dataset)
# load_dataset() returns a tuple of (train, validation, test) datasets
train_dataset, val_dataset, test_dataset = dataset

# Take one example from the test set
test_index = 0  # change to any index you want to inspect
test_sample = test_dataset[test_index]

# Print some fields of the test example
print("Test sample:")
print("Past values:", test_sample['past_values'])
print("Future values:", test_sample['future_values'])
print("Past observed mask:", test_sample['past_observed_mask'])
print("Timestamp:", test_sample['timestamp'])
print("ID:", test_sample['id'])
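Each test example is a dictionary. `past_values` should hold `context_length` rows with one column per channel, and `past_observed_mask` flags which of those entries were actually observed. The following rough sketch shows the idea using plain Python lists instead of tensors. It assumes missing readings arrive as NaN and are zero-filled, which may not match how `tsfm_public` handles them. `observed_mask` and `fill_missing` are hypothetical helpers.

```python
import math


def observed_mask(values):
    """1.0 where an entry is present, 0.0 where it is missing (NaN)."""
    return [[0.0 if math.isnan(v) else 1.0 for v in row] for row in values]


def fill_missing(values, fill_value=0.0):
    """Copy of `values` with every NaN replaced by `fill_value` (zero-fill is an assumption)."""
    return [[fill_value if math.isnan(v) else v for v in row] for row in values]


# Three time steps x two channels, with one missing reading in the second channel
past = [[1.0, 2.0], [1.5, float("nan")], [2.0, 3.0]]
mask = observed_mask(past)
filled = fill_missing(past)
print(mask)    # [[1.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
print(filled)  # [[1.0, 2.0], [1.5, 0.0], [2.0, 3.0]]
```

Keeping the mask next to the filled values lets a model or loss function ignore the placeholder entries instead of treating them as real zeros.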

## Zero-shot evaluation method

In [10]:
def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=96, prediction_filter_length=None):
    if prediction_filter_length is not None:
        if prediction_filter_length >= forecast_length:
            raise ValueError(
                "`prediction_filter_length` should be less than the original `forecast_length` of the pre-trained TTM model."
            )
        # The model's forecast is truncated to `prediction_filter_length` steps,
        # so the test data must use the same (shorter) horizon.
        forecast_length = prediction_filter_length

    # Get data
    _, _, dset_test = load_dataset(
        dataset_name=dataset_name,
        context_length=context_length,
        forecast_length=forecast_length,
        fewshot_fraction=1.0,
        dataset_path=DATASET_PATH,
    )

    # Load model, optionally truncating its forecast horizon
    model_kwargs = {}
    if prediction_filter_length is not None:
        model_kwargs["prediction_filter_length"] = prediction_filter_length
    zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(
        TTM_MODEL_PATH, revision=TTM_MODEL_REVISION, **model_kwargs
    )
    temp_dir = tempfile.mkdtemp()
    # zeroshot_trainer
    zeroshot_trainer = Trainer(
        model=zeroshot_model,
        args=TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=batch_size,
            seed=SEED,
            report_to="none",
        ),
    )
    # evaluate = zero-shot performance
    print("+" * 20, "Test MSE zero-shot", "+" * 20)
    zeroshot_output = zeroshot_trainer.evaluate(dset_test)
    print(zeroshot_output)
    # plot
    plot_predictions(
        model=zeroshot_trainer.model,
        dset=dset_test,
        plot_dir=os.path.join(OUT_DIR, dataset_name),
        plot_prefix="test_zeroshot",
        indices=[685, 118, 902, 1984, 894, 967, 304, 57, 265, 1015],
        channel=0,
    )
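The zero-shot cell reports the `eval_loss` returned by `Trainer.evaluate()` and labels it as the test MSE. Here is a minimal sketch of that metric. It assumes nested lists shaped `[num_samples][forecast_length][num_channels]` rather than tensors, and averages the squared error over every sample, time step, and channel. The Trainer accumulates the loss batch by batch on whatever values the dataset yields, so its number should closely match this single-pass computation but is not guaranteed to be bit-identical. `mse` is a hypothetical helper.

```python
def mse(predictions, targets):
    """Mean squared error over all (sample, time step, channel) entries.

    Both inputs are nested lists shaped [num_samples][forecast_length][num_channels];
    they are flattened and must contain the same number of entries.
    """
    pred_flat = [v for sample in predictions for step in sample for v in step]
    true_flat = [v for sample in targets for step in sample for v in step]
    if not pred_flat or len(pred_flat) != len(true_flat):
        raise ValueError("predictions and targets must be non-empty and the same size")
    return sum((p - t) ** 2 for p, t in zip(pred_flat, true_flat)) / len(pred_flat)


# One sample, two forecast steps, one channel: errors 1 and 2 -> (1 + 4) / 2
print(mse([[[1.0], [2.0]]], [[[0.0], [0.0]]]))  # 2.5
```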

## Example: downstream target dataset - electricity

### Zero-shot

In [None]:
zeroshot_eval(dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64)

## Few-shot fine-tuning and evaluation method

In [None]:
def fewshot_finetune_eval(
    dataset_name,
    batch_size,
    learning_rate=None,
    context_length=512,
    forecast_length=96,
    fewshot_percent=5,
    freeze_backbone=True,
    num_epochs=50,
    save_dir=OUT_DIR,
    prediction_filter_length=None,
):
    out_dir = os.path.join(save_dir, dataset_name)

    print("-" * 20, f"Running few-shot {fewshot_percent}%", "-" * 20)

    if prediction_filter_length is not None:
        if prediction_filter_length >= forecast_length:
            raise ValueError(
                "`prediction_filter_length` should be less than the original `forecast_length` of the pre-trained TTM model."
            )
        # The model's forecast is truncated to `prediction_filter_length` steps,
        # so the data must use the same (shorter) horizon.
        forecast_length = prediction_filter_length

    # Data prep: Get dataset
    dset_train, dset_val, dset_test = load_dataset(
        dataset_name,
        context_length,
        forecast_length,
        fewshot_fraction=fewshot_percent / 100,
        dataset_path=DATASET_PATH,
    )

    # Extra keyword arguments for from_pretrained()
    model_kwargs = {}
    # Change head dropout to 0.7 for ETT datasets
    if "ett" in dataset_name:
        model_kwargs["head_dropout"] = 0.7
    # Optionally truncate the forecast horizon
    if prediction_filter_length is not None:
        model_kwargs["prediction_filter_length"] = prediction_filter_length
    finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(
        TTM_MODEL_PATH, revision=TTM_MODEL_REVISION, **model_kwargs
    )
    if freeze_backbone:
        print(
            "Number of params before freezing backbone",
            count_parameters(finetune_forecast_model),
        )

        # Freeze the backbone of the model
        for param in finetune_forecast_model.backbone.parameters():
            param.requires_grad = False

        # Count params
        print(
            "Number of params after freezing the backbone",
            count_parameters(finetune_forecast_model),
        )

    # Find optimal learning rate
    # Use with caution: Set it manually if the suggested learning rate is not suitable
    if learning_rate is None:
        learning_rate, finetune_forecast_model = optimal_lr_finder(
            finetune_forecast_model,
            dset_train,
            batch_size=batch_size,
        )
        print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

    print(f"Using learning rate = {learning_rate}")
    finetune_forecast_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=8,
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),  # Make sure to specify a logging directory
        load_best_model_at_end=True,  # Load the best model when training ends
        metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
        greater_is_better=False,  # For loss
        seed=SEED,
    )

    # Create the early stopping callback
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop
        early_stopping_threshold=1e-5,  # Minimum improvement required to consider as improvement
    )
    tracking_callback = TrackingCallback()

    # Optimizer and scheduler
    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(dset_train) / (batch_size)),
    )

    finetune_forecast_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_forecast_args,
        train_dataset=dset_train,
        eval_dataset=dset_val,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])

    # Fine tune
    finetune_forecast_trainer.train()

    # Evaluation
    print("+" * 20, f"Test MSE after few-shot {fewshot_percent}% fine-tuning", "+" * 20)
    fewshot_output = finetune_forecast_trainer.evaluate(dset_test)
    print(fewshot_output)
    print("+" * 60)

    # plot
    plot_predictions(
        model=finetune_forecast_trainer.model,
        dset=dset_test,
        plot_dir=out_dir,  # respects the `save_dir` argument
        plot_prefix="test_fewshot",
        indices=[685, 118, 902, 1984, 894, 967, 304, 57, 265, 1015],
        channel=0,
    )
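The optimizer/scheduler pair passed to `Trainer` uses PyTorch's `OneCycleLR` with only `max_lr`, `epochs`, and `steps_per_epoch` set. The remaining arguments therefore keep their defaults: `pct_start=0.3`, cosine annealing, `div_factor=25`, `final_div_factor=1e4`. `Trainer` steps the scheduler once per optimizer step, and `OneCycleLR` raises an error if it is stepped more than `epochs * steps_per_epoch` times. That is why `steps_per_epoch` is computed from the number of training batches. Below is a pure-Python re-derivation of that default two-phase schedule, intended only for intuition. `one_cycle_lr` is a hypothetical helper, not PyTorch code, and the 1000-window training set is an assumed size.

```python
import math


def one_cycle_lr(step, max_lr, epochs, steps_per_epoch,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at optimizer step `step` (0-based) under a two-phase cosine one-cycle schedule."""
    total_steps = epochs * steps_per_epoch
    if not 0 <= step < total_steps:
        raise ValueError(f"step must be in [0, {total_steps})")
    initial_lr = max_lr / div_factor         # starting learning rate
    min_lr = initial_lr / final_div_factor   # learning rate at the very last step
    warmup_end = pct_start * total_steps - 1  # phase 1 (warm-up) ends here
    last_step = total_steps - 1               # phase 2 (annealing) ends here

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


num_train_windows = 1000  # hypothetical size of the few-shot training set
batch_size = 64
steps_per_epoch = math.ceil(num_train_windows / batch_size)  # same formula as in the notebook
lrs = [one_cycle_lr(s, 1e-3, 50, steps_per_epoch) for s in range(50 * steps_per_epoch)]
print(steps_per_epoch, len(lrs))  # 16 800
```

With `learning_rate=0.001`, the rate starts at 4e-5, peaks at 1e-3 after about 30% of the steps, and then decays to roughly 4e-9. The value passed as `learning_rate` is therefore the peak of the schedule, not a constant rate.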

### Few-shot 5%

In [None]:
fewshot_finetune_eval(
    dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, fewshot_percent=5, learning_rate=0.001
)
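Training is capped at `num_epochs=50`, but `EarlyStoppingCallback` can end it sooner. The sketch below simplifies the rule as we understand it for a loss metric (`greater_is_better=False`). An epoch counts as an improvement only if `eval_loss` falls below the best value so far by more than `early_stopping_threshold`. Training stops once `early_stopping_patience` consecutive epochs fail that test. The best value is still updated on tiny improvements, which do not reset the counter. `early_stop_epoch` is a hypothetical helper that works on a precomputed list of losses and is not part of `transformers`.

```python
def early_stop_epoch(eval_losses, patience=10, threshold=1e-5):
    """Return the 1-based epoch at which training would stop, or None if it never stops early."""
    best = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(eval_losses, start=1):
        # Only a drop larger than `threshold` counts as an improvement
        if best is None or (loss < best and best - loss > threshold):
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # The best value itself tracks any strict decrease
        if best is None or loss < best:
            best = loss
        if epochs_without_improvement >= patience:
            return epoch
    return None


# Loss plateaus after epoch 3, so the run stops 10 epochs later
print(early_stop_epoch([0.50, 0.40, 0.35] + [0.36] * 10))  # 13
```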

## Example: Automatically truncating the forecast horizon

Here, we demonstrate that a pre-trained 512-96 TTM model (i.e., context length = 512, forecast horizon = 96) can be used for a task whose forecast horizon is shorter than 96 time points. We only need to specify the argument `prediction_filter_length` when loading the model. That's it!

Note that truncating the model forecast may cost some accuracy. We recommend trying this feature on your validation data first to check that the model performance stays within an acceptable threshold. Otherwise, a new TTM model can be pre-trained with the required forecast horizon.

In this example, we use a 512-96 TTM on the electricity data (`TARGET_DATASET`) to forecast 48 points in both the zero-shot and the 5% few-shot settings.
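Conceptually, truncation keeps the nearest part of the model's 96-step output and discards the rest. The evaluation data must then use the same shorter horizon. The following minimal sketch works on plain lists and assumes, as the argument name suggests, that the first `prediction_filter_length` steps are kept. `truncate_forecast` is a hypothetical helper, not the `TinyTimeMixerForPrediction` code path.

```python
FULL_HORIZON = 96        # horizon the pre-trained model produces
PREDICTION_FILTER = 48   # horizon actually needed for the task


def truncate_forecast(forecast, prediction_filter_length):
    """Keep the first `prediction_filter_length` steps of a [horizon][num_channels] forecast."""
    if not 0 < prediction_filter_length < len(forecast):
        raise ValueError("prediction_filter_length must be positive and shorter than the model horizon")
    return forecast[:prediction_filter_length]


# A dummy 96-step, 2-channel forecast whose values encode the step index
full_forecast = [[float(step), float(step)] for step in range(FULL_HORIZON)]
short_forecast = truncate_forecast(full_forecast, PREDICTION_FILTER)
print(len(short_forecast), short_forecast[-1])  # 48 [47.0, 47.0]
```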

### Zero-shot

In [None]:
zeroshot_eval(dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, prediction_filter_length=48)

### Few-shot 5%

In [None]:
fewshot_finetune_eval(
    dataset_name=TARGET_DATASET,
    context_length=CONTEXT_LENGTH,
    batch_size=64,
    prediction_filter_length=48,
    fewshot_percent=5,
    learning_rate=None,
)