In [None]:
! pip install "granite-tsfm[notebooks] @ git+https://github.com/ibm-granite/granite-tsfm.git@v0.2.22"

In [None]:
import math
import os
import tempfile

import pandas as pd
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments,set_seed
from transformers.integrations import INTEGRATION_TO_CALLBACK

from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions

In [None]:
import warnings


# Suppress all warnings
warnings.filterwarnings("ignore")

In [None]:
# Set seed for reproducibility
# NOTE(review): global seeding is currently disabled (the two lines below are
# commented out). fewshot_finetune_eval seeds TrainingArguments from
# time.time(), so results are expected to differ from run to run.
#SEED = 45
#set_seed(SEED)
import time

# TTM Model path. The default model path is Granite-R2. Below, you can choose other TTM releases.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
# TTM_MODEL_PATH = "ibm-research/ttm-research-r2"

# Context length, i.e. the length of the history window given to the model.
# Currently supported values are: 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2, and 512/1024 for Granite-TTM-R1
# NOTE(review): 90 is not one of the lengths listed above. fewshot_finetune_eval
# requests the model with force_return="random_init_medium", which presumably
# yields a randomly initialised model rather than pretrained weights - confirm.
CONTEXT_LENGTH = 90

# Granite-TTM-R2 supports forecast lengths up to 720 and Granite-TTM-R1 supports forecast lengths up to 96
PREDICTION_LENGTH = 24

# Results dir (used as the default `save_dir` of fewshot_finetune_eval)
OUT_DIR = "../results/"

In [None]:
# --- Load the two raw inputs ------------------------------------------------
# Main series: raw binary file read with np.fromfile's default settings.
dataset_path= "../L1MAG.csv.bin"
dataset= np.fromfile(dataset_path)
# Supplementary "hint" values that are appended after the main series.
dataset_path2= "../L1MAG_part2_summary_statistics.csv"

# Pick the reader from the file extension. The original code referenced an
# undefined name (`data_path2`) here, which raised NameError.
# NOTE(review): np.loadtxt is called with its defaults, so the CSV is assumed
# to be a plain 1-D list of numbers - confirm against the actual file.
if dataset_path2.endswith(".csv"):
  hint_data= np.loadtxt(dataset_path2)
elif dataset_path2.endswith(".bin"):
  hint_data= np.fromfile(dataset_path2)
else:
  # Fail loudly instead of leaving `hint_data` undefined for the next line.
  raise ValueError(f"Unsupported hint-data file type: {dataset_path2!r}")

new_data1= np.concatenate((dataset,hint_data))

# --- Standardise ------------------------------------------------------------
# StandardScaler expects a 2-D (n_samples, n_features) array, hence the reshape.
# NOTE(review): the scaler is fitted on the whole concatenated series, i.e.
# before the train/valid/test split below, so its statistics include valid and
# test rows. Confirm this is intended.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
fit_data= scaler.fit(new_data1.reshape(-1,1))  # fit() returns the scaler itself
new_data= scaler.transform(new_data1.reshape(-1,1)).flatten()

# --- Build the dataframe consumed by the preprocessor -----------------------
timestamp_column = "date"

# Synthetic hourly timestamps, one per sample. "h" is the current pandas alias
# for hourly frequency; the uppercase "H" spelling is deprecated.
date_new= pd.date_range(start=pd.to_datetime("2018-07-01"), periods=len(new_data), freq="h")
df_raw= pd.DataFrame({timestamp_column:date_new,"Data":new_data})
print(df_raw)

id_columns = []  # mention the ids that uniquely identify a time-series.

target_columns = ["Data"]

# Row-index boundaries of each split; passed to get_datasets inside
# fewshot_finetune_eval.
split_config = {
    "train": [0, 3668],
    "valid": [3668,3767],
    "test": [
        3767,
        7535,
    ],
}

# Column roles, unpacked into TimeSeriesPreprocessor(**column_specifiers).
column_specifiers = {
    "timestamp_column": timestamp_column,
    "id_columns": id_columns,
    "target_columns": target_columns,
    "control_columns": [],
}

In [None]:
# Label for this experiment; fewshot_finetune_eval joins it onto the results
# directory to form the per-dataset output and plot folders.
dataset_name= "Hinted_dataset"

In [None]:
# Module-level placeholders.
# NOTE(review): fewshot_finetune_eval contains no `global` statement, so it
# never writes to these names and both remain 0 after the call.
test= 0
pred= 0
def fewshot_finetune_eval(
    dataset_name,
    batch_size,
    learning_rate=None,
    context_length= CONTEXT_LENGTH,
    forecast_length= PREDICTION_LENGTH,
    fewshot_percent=100,
    freeze_backbone=True,
    num_epochs=50,
    save_dir=OUT_DIR,
    loss="mse",
    quantile=0.5,
    seed=None,
):
    """Fine-tune a TTM forecaster on `df_raw`, evaluate it and save predictions.

    The function reads the module-level `df_raw`, `split_config`,
    `column_specifiers` and `TTM_MODEL_PATH`. It trains with early stopping,
    prints the test metrics, writes the flattened test predictions to
    ``<save_dir>/ttm_pred_data.csv`` and saves prediction plots under
    ``<save_dir>/<dataset_name>``.

    Args:
        dataset_name: Sub-directory name (under `save_dir`) for checkpoints,
            logs and plots.
        batch_size: Per-device batch size for training and evaluation.
        learning_rate: Learning rate for AdamW / OneCycleLR. When None, a
            value is suggested by `optimal_lr_finder`.
        context_length: History length passed to the preprocessor and model.
        forecast_length: Prediction length passed to the preprocessor and model.
        fewshot_percent: Percentage of the training split to use
            (converted to a fraction for `get_datasets`).
        freeze_backbone: When True, gradients are disabled for all
            `model.backbone` parameters.
        num_epochs: Maximum number of training epochs.
        save_dir: Root directory for all outputs.
        loss: Training loss name forwarded to `get_model`. Evaluation always
            uses "mse" because the model's loss is reset before `evaluate`.
        quantile: Forwarded to `get_model`.
        seed: Seed given to `TrainingArguments`. When None (the default) the
            current wall-clock time is used, as before; the value is printed
            so a run can be repeated by passing it back in.

    Returns:
        None. Results are printed and written to disk.
    """
    out_dir = os.path.join(save_dir, dataset_name)
    # Also guarantees `save_dir` exists before the predictions CSV is written.
    os.makedirs(out_dir, exist_ok=True)

    if seed is None:
        seed = int(time.time())
    print(f"Using training seed = {seed}")

    # Data prep: Get dataset

    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=False,
        encode_categorical=False,
        scaler_type="standard",
    )

    finetune_forecast_model = get_model(
        TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=forecast_length,
        freq_prefix_tuning=False,
        freq=None,
        force_return= "random_init_medium",
        prefer_l1_loss=False,
        prefer_longer_context=True,
        # Can also provide TTM Config args
        loss=loss,
        quantile=quantile,
    )

    dset_train, dset_val, dset_test = get_datasets(
        tsp,
        df_raw,
        split_config,
        fewshot_fraction=fewshot_percent / 100,
        fewshot_location="first",
        use_frequency_token=finetune_forecast_model.config.resolution_prefix_tuning,
    )

    if freeze_backbone:
        print(
            "Number of params before freezing backbone",
            count_parameters(finetune_forecast_model),
        )

        # Freeze the backbone of the model
        for param in finetune_forecast_model.backbone.parameters():
            param.requires_grad = False

        # Count params
        print(
            "Number of params after freezing the backbone",
            count_parameters(finetune_forecast_model),
        )

    # Find optimal learning rate
    # Use with caution: Set it manually if the suggested learning rate is not suitable
    if learning_rate is None:
        learning_rate, finetune_forecast_model = optimal_lr_finder(
            finetune_forecast_model,
            dset_train,
            batch_size=batch_size,
        )
        print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

    print(f"Using learning rate = {learning_rate}")
    finetune_forecast_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        eval_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=8,
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),  # Make sure to specify a logging directory
        load_best_model_at_end=True,  # Load the best model when training ends
        metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
        greater_is_better=False,  # For loss
        seed=seed,
    )

    # Create the early stopping callback
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop
        early_stopping_threshold=1e-5,  # Minimum improvement required to consider as improvement
    )
    tracking_callback = TrackingCallback()

    # Optimizer and scheduler
    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(dset_train) / (batch_size)),
    )

    finetune_forecast_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_forecast_args,
        train_dataset=dset_train,
        eval_dataset=dset_val,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])

    # Fine tune
    finetune_forecast_trainer.train()

    # Evaluation
    print("+" * 20, f"Test MSE after few-shot {fewshot_percent}% fine-tuning", "+" * 20)

    finetune_forecast_trainer.model.loss = "mse"  # fixing metric to mse for evaluation

    fewshot_output = finetune_forecast_trainer.evaluate(dset_test)
    print(fewshot_output)
    print("+" * 60)

    # get predictions

    predictions_dict = finetune_forecast_trainer.predict(dset_test)

    predictions_np = predictions_dict.predictions[0]
    print("predictions_np",predictions_np)
    print("testing data", dset_test[0])
    print(predictions_np.shape)

    # Persist the flattened predictions under `save_dir` (previously a
    # hard-coded "../results/" path that ignored the `save_dir` argument).
    pred_flat = np.asarray(predictions_np).flatten()
    np.savetxt(os.path.join(save_dir, "ttm_pred_data.csv"), pred_flat)

    # get backbone embeddings (if needed for further analysis)

    backbone_embedding = predictions_dict.predictions[1]

    print(backbone_embedding.shape)

    # plot
    # Keep only the sample indices that exist in the test set so that a
    # shorter dataset does not fail with an out-of-range index.
    requested_indices = [685, 118, 902, 1984, 894, 967, 304, 57, 265, 1015]
    plot_indices = [idx for idx in requested_indices if idx < len(dset_test)]
    if plot_indices:
        plot_predictions(
            model=finetune_forecast_trainer.model,
            dset=dset_test,
            plot_dir=out_dir,
            plot_prefix="test_fewshot",
            indices=plot_indices,
            channel=0,
        )
    else:
        print("Skipping plots: test set is smaller than every requested plot index")

In [None]:
# Run fine-tuning on the full training split with a fixed learning rate.
# Passing `learning_rate` explicitly bypasses the automatic LR finder.
# (The keyword was previously misspelled `earning_rate`, which raised TypeError.)
fewshot_finetune_eval(
  dataset_name=dataset_name,
  context_length= CONTEXT_LENGTH,
  forecast_length= PREDICTION_LENGTH,
  batch_size=64,
  fewshot_percent=100,
  learning_rate=0.001,
  )