 # Channel Independence Patch Time Series Transformer
 Fine tuning for forecasting

 Maybe add a picture of the PatchTST with forecasting head?

In [1]:
import pandas as pd

from tsfmservices.toolkit.dataset import ForecastDFDataset
from transformers import (
    PatchTSTConfig,
    PatchTSTForForecasting,
    Trainer,
    TrainingArguments,
)

 ## Load and prepare datasets

 Please adjust the following parameters to suit your application:
 - timestamp_column: column name containing timestamp information, use None if there is no such column
 - id_columns: List of column names specifying the IDs of different time series. If no ID column exists, use []
 - forecast_columns: List of columns to be modeled
 - context_length: Specifies how many historical time points are used by the model
 - prediction_length: Specifies how many timepoints should be forecasted

 Using the parameters above load the data, divide it into train and eval portions, and create torch datasets.

In [2]:
timestamp_column = "date"
id_columns = []
forecast_columns = ["OT"]

data = pd.read_csv(
    "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv",
    parse_dates=[timestamp_column],
)
print(data.head())

pretrained_model_path = "model/pretrained"
pretrained_config = PatchTSTConfig.from_pretrained(pretrained_model_path)

prediction_length = 20
context_length = 32  # use pretrained_config.context_length to match pretrained model

# to do: split data
# need utility here, group sensitive splitting should be done
train_data = data.iloc[: 12 * 30 * 24,].copy()
eval_data = data.iloc[
    12 * 30 * 24 - context_length : 12 * 30 * 24 + 4 * 30 * 24,
].copy()


train_dataset = ForecastDFDataset(
    train_data,
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    input_columns=forecast_columns,
    context_length=context_length,
    prediction_length=prediction_length,
)
eval_dataset = ForecastDFDataset(
    eval_data,
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    input_columns=forecast_columns,
    context_length=context_length,
    prediction_length=prediction_length,
)

                 date   HUFL   HULL   MUFL   MULL   LUFL   LULL         OT
0 2016-07-01 00:00:00  5.827  2.009  1.599  0.462  4.203  1.340  30.531000
1 2016-07-01 01:00:00  5.693  2.076  1.492  0.426  4.142  1.371  27.787001
2 2016-07-01 02:00:00  5.157  1.741  1.279  0.355  3.777  1.218  27.787001
3 2016-07-01 03:00:00  5.090  1.942  1.279  0.391  3.807  1.279  25.044001
4 2016-07-01 04:00:00  5.358  1.942  1.492  0.462  3.868  1.279  21.948000


 ## Configure the PatchTST model

 Describe only forecasting specific parameters that are configurable here.

In [3]:
pred_config = PatchTSTConfig.from_pretrained(
    pretrained_model_path,
    context_length=context_length,
    num_input_channels=len(forecast_columns),
    prediction_length=prediction_length,
)

 ## Load model and freeze base model parameters

In [4]:
forecasting_model = PatchTSTForForecasting.from_pretrained(
    "model/pretrained",
    config=pred_config,
    ignore_mismatched_sizes=True,
)
# to unfreeze the base model parameters, comment out the cell
for param in forecasting_model.base_model.parameters():
    param.requires_grad = False

Some weights of PatchTSTForForecasting were not initialized from the model checkpoint at model/pretrained and are newly initialized because the shapes did not match:
- head.linear.weight: found shape torch.Size([12, 16]) in the checkpoint and torch.Size([20, 16]) in the model instantiated
- head.linear.bias: found shape torch.Size([12]) in the checkpoint and torch.Size([20]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 ## Train model
 Provide description of important training parameters.

In [5]:
training_args = TrainingArguments(
    output_dir="./checkpoint/forecast",
    # logging_steps = 100,
    # per_device_train_batch_size = 64, #defaults to 8
    # per_device_eval_batch_size = 64, #defaults to 8
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # eval_steps = 100,
    save_total_limit=5,
    logging_strategy="epoch",
    load_best_model_at_end=True,
    max_steps=10,  # For a quick test
    label_names=["future_values"],
)


forecasting_trainer = Trainer(
    model=forecasting_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # compute_metrics=compute_metrics
)

forecasting_trainer.train()

  0%|          | 0/10 [00:00<?, ?it/s]

  ids_shuffle = torch.argsort(noise, dim=-1)  # ascend: small is keep, large is remove


{'loss': 12.4689, 'learning_rate': 0.0, 'epoch': 0.01}


  0%|          | 0/358 [00:00<?, ?it/s]

{'eval_loss': 6.40577507019043, 'eval_runtime': 9.021, 'eval_samples_per_second': 317.148, 'eval_steps_per_second': 39.685, 'epoch': 0.01}
{'train_runtime': 10.0074, 'train_samples_per_second': 7.994, 'train_steps_per_second': 0.999, 'train_loss': 12.468918609619141, 'epoch': 0.01}


TrainOutput(global_step=10, training_loss=12.468918609619141, metrics={'train_runtime': 10.0074, 'train_samples_per_second': 7.994, 'train_steps_per_second': 0.999, 'train_loss': 12.468918609619141, 'epoch': 0.01})

In [6]:
# ## Inference
#
# To do: use pipeline code to produce more friendly output
import torch, copy

device = forecasting_model.device


data_sample = copy.copy(eval_dataset[0])
data_sample["past_values"] = torch.unsqueeze(data_sample["past_values"], 0)
data_sample["future_values"] = torch.unsqueeze(data_sample["future_values"], 0)
forecasting_model(
    data_sample["past_values"].to(device),
    future_values=data_sample["future_values"].to(device),
)

  nonzero_finite_vals = torch.masked_select(


PatchTSTForForecastingOutput(loss=tensor(2.2564, device='mps:0', grad_fn=<MseLossBackward0>), forecast_outputs=tensor([[[21.2610],
         [21.4202],
         [21.3384],
         [21.3653],
         [21.4324],
         [21.4399],
         [21.4298],
         [21.2684],
         [21.3639],
         [21.4728],
         [21.3428],
         [21.2446],
         [21.3248],
         [21.3429],
         [21.3274],
         [21.3319],
         [21.2655],
         [21.4246],
         [21.2654],
         [21.3365]]], device='mps:0', grad_fn=<AddBackward0>), hidden_states=[], attentions=None)