# Introduction
This notebook shows how to finetune TimesFM on your own time series.

To perform finetuning, you need to wrap your data in a PyTorch `Dataset` that yields samples in the format the finetuner expects; an example `Dataset` is provided below.
The finetuning code lives in the `finetuning.finetuning_torch` module (`finetuning/finetuning_torch.py`); this notebook only imports `FinetuningConfig` and `TimesFMFinetuner` from it.

### Dataset Creation

In [None]:
import os
from os import path
from typing import Callable, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
import yfinance as yf
from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner
from huggingface_hub import snapshot_download
from torch.utils.data import Dataset

from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder


class TimeSeriesDataset(Dataset):
  """Sliding-window dataset producing TimesFM-compatible training tuples.

  Every contiguous window of ``context_length + horizon_length`` points in
  ``series`` becomes one sample: the first ``context_length`` points are the
  model input and the next ``horizon_length`` points are the target.
  """

  def __init__(self,
               series: np.ndarray,
               context_length: int,
               horizon_length: int,
               freq_type: int = 0):
    """Store the configuration and precompute all windows.

    Args:
      series: Time series data; windows are taken by slicing its first axis.
      context_length: Number of past timesteps used as model input.
      horizon_length: Number of future timesteps used as the target.
      freq_type: Frequency type (0, 1, or 2).

    Raises:
      ValueError: If ``freq_type`` is not 0, 1, or 2.
    """
    allowed_freqs = (0, 1, 2)
    if freq_type not in allowed_freqs:
      raise ValueError("freq_type must be 0, 1, or 2")

    self.series = series
    self.context_length = context_length
    self.horizon_length = horizon_length
    self.freq_type = freq_type
    self._prepare_samples()

  def _prepare_samples(self) -> None:
    """Fill ``self.samples`` with every (context, future) slice pair."""
    ctx = self.context_length
    hor = self.horizon_length
    # A negative count yields an empty range, i.e. no samples for short series.
    n_windows = len(self.series) - (ctx + hor) + 1
    self.samples = [
        (self.series[i:i + ctx], self.series[i + ctx:i + ctx + hor])
        for i in range(n_windows)
    ]

  def __len__(self) -> int:
    return len(self.samples)

  def __getitem__(
      self, index: int
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(context, padding, freq, future)`` for one window.

    ``context`` and ``future`` are float32 tensors, ``padding`` is an
    all-zero tensor shaped like ``context``, and ``freq`` is a one-element
    int64 tensor holding ``freq_type``.
    """
    raw_context, raw_future = self.samples[index]
    context = torch.tensor(raw_context, dtype=torch.float32)
    return (
        context,
        torch.zeros_like(context),
        torch.tensor([self.freq_type], dtype=torch.long),
        torch.tensor(raw_future, dtype=torch.float32),
    )

def prepare_datasets(series: np.ndarray,
                     context_length: int,
                     horizon_length: int,
                     freq_type: int = 0,
                     train_split: float = 0.9) -> Tuple[Dataset, Dataset]:
  """Split a series at a single cut point and window each part.

  The first ``int(len(series) * train_split)`` points form the training part
  and the remainder the validation part. No shuffling is done, so the
  validation part is always the tail of ``series``.

  Args:
    series: Input time series data.
    context_length: Number of past timesteps to use.
    horizon_length: Number of future timesteps to predict.
    freq_type: Frequency type (0, 1, or 2).
    train_split: Fraction of data to use for training.

  Returns:
    Tuple of (train_dataset, val_dataset).
  """
  cut = int(len(series) * train_split)
  # Both halves share identical windowing settings.
  window_kwargs = dict(context_length=context_length,
                       horizon_length=horizon_length,
                       freq_type=freq_type)
  return (TimeSeriesDataset(series[:cut], **window_kwargs),
          TimeSeriesDataset(series[cut:], **window_kwargs))


### Model Creation

In [None]:
def get_model(load_weights: bool = False, context_len: int = 128 + 96):
  """Build a TimesFM 2.0 (500M, PyTorch) decoder for finetuning.

  Args:
    load_weights: If True, load ``torch_model.ckpt`` from the Hugging Face
      repo into the returned decoder.
    context_len: Context length passed to ``TimesFmHparams``. Per the original
      note it can be anything up to 2048 in multiples of 32; the default (224)
      keeps the previous behaviour.

  Returns:
    ``(model, hparams, model_config)``: a freshly built
    ``PatchedTimeSeriesDecoder``, the ``TimesFmHparams`` used, and
    ``tfm._model_config``.
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  repo_id = "google/timesfm-2.0-500m-pytorch"
  hparams = TimesFmHparams(
      backend=device,
      per_core_batch_size=32,
      horizon_len=32,
      num_layers=50,
      use_positional_embedding=False,
      context_len=context_len,
  )
  tfm = TimesFm(hparams=hparams,
                checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))

  model = PatchedTimeSeriesDecoder(tfm._model_config)
  if load_weights:
    checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt")
    # Load onto the CPU: a freshly constructed nn.Module's parameters are on
    # the CPU by default. This also keeps CPU-only machines working even if
    # the checkpoint's tensors were saved from a GPU.
    loaded_checkpoint = torch.load(checkpoint_path,
                                   weights_only=True,
                                   map_location="cpu")
    model.load_state_dict(loaded_checkpoint)
  return model, hparams, tfm._model_config


In [None]:
def plot_predictions(
    model: torch.nn.Module,
    val_dataset: Dataset,
    save_path: Optional[str] = "predictions.png",
) -> None:
  """Plot the model's forecast for the first validation window.

  The model is called as ``model(context, padding, freq)``. Channel 0 of the
  last output axis at the last input patch is plotted as the prediction,
  against the history and the ground-truth future.

  Args:
    model: Trained decoder module (e.g. the ``PatchedTimeSeriesDecoder`` from
      ``get_model``). It is switched to eval mode.
    val_dataset: Dataset whose items are ``(context, padding, freq, future)``.
    save_path: Path to save the plot. If falsy, the figure is discarded.

  Raises:
    ValueError: If ``val_dataset`` has no samples.
  """
  try:
    sample = val_dataset[0]
  except IndexError as err:
    raise ValueError("val_dataset is empty; cannot plot predictions.") from err

  import matplotlib.pyplot as plt

  model.eval()

  device = next(model.parameters()).device
  # Add a batch dimension and move every input to the model's device.
  x_context, x_padding, freq, x_future = (
      t.unsqueeze(0).to(device) for t in sample)

  with torch.no_grad():
    predictions = model(x_context, x_padding.float(), freq)
    predictions_mean = predictions[..., 0]  # [B, N, horizon_len]
    last_patch_pred = predictions_mean[:, -1, :]  # [B, horizon_len]

  context_vals = x_context[0].cpu().numpy()
  future_vals = x_future[0].cpu().numpy()
  # The decoder's output horizon may differ from the dataset's horizon. Keep
  # only the steps that have ground truth so the plotted x-ranges line up.
  pred_vals = last_patch_pred[0].cpu().numpy()[:len(future_vals)]

  context_len = len(context_vals)
  horizon_len = len(future_vals)

  fig = plt.figure(figsize=(12, 6))
  try:
    plt.plot(range(context_len),
             context_vals,
             label="Historical Data",
             color="blue",
             linewidth=2)

    plt.plot(
        range(context_len, context_len + horizon_len),
        future_vals,
        label="Ground Truth",
        color="green",
        linestyle="--",
        linewidth=2,
    )

    plt.plot(range(context_len, context_len + len(pred_vals)),
             pred_vals,
             label="Prediction",
             color="red",
             linewidth=2)

    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.title("TimesFM Predictions vs Ground Truth")
    plt.legend()
    plt.grid(True)

    if save_path:
      plt.savefig(save_path)
      print(f"Plot saved to {save_path}")
  finally:
    # Always release the figure, even if plotting or saving fails.
    plt.close(fig)



In [None]:
# Load solana dataset
def load_solana(csv_path: str = "../datasets/solana_data.csv") -> np.ndarray:
  """Load the Solana price series from a CSV, sorted by date.

  Args:
    csv_path: Path to a CSV with ``Date`` and ``Price`` columns. Anything
      ``pandas.read_csv`` accepts (e.g. a file-like object) also works. The
      default is the previously hard-coded relative path.

  Returns:
    The ``Price`` column as a NumPy array, ordered by ``Date`` ascending.
  """
  df = pd.read_csv(csv_path)
  df["Date"] = pd.to_datetime(df["Date"])
  df = df.set_index("Date").sort_index()
  return df["Price"].to_numpy()

def load_APPL() -> np.ndarray:
  """Download daily AAPL closing prices (2010-01-01 .. 2019-01-01) via yfinance.

  The function name keeps its historical "APPL" spelling so existing callers
  keep working; the ticker downloaded is "AAPL".

  Returns:
    A 1-D NumPy array of closing prices.
  """
  df = yf.download("AAPL", start="2010-01-01", end="2019-01-01")
  # Depending on the yfinance version, df["Close"] is either a Series or a
  # one-column DataFrame (multi-level ticker columns), whose .values has
  # shape (N, 1). Downstream code treats each window as 1-D, so flatten.
  return np.asarray(df["Close"]).reshape(-1)

In [None]:
def save_video(wandb_project: str,
               exp_num: int,
               image_dir: str = "predictions_plts",
               fps: int = 2) -> Optional[str]:
  """Stitch the per-epoch prediction plots into an MP4 video.

  Frames are the files matching ``<image_dir>/predictions_epoch_*.png``,
  ordered numerically by the trailing epoch number.

  Args:
    wandb_project: Used only to name the output file.
    exp_num: Experiment index, used only to name the output file.
    image_dir: Directory holding the per-epoch plots. The default is the
      previously hard-coded location.
    fps: Frames per second of the output video.

  Returns:
    The path of the written video, or None if no frame could be written.
  """
  import glob

  image_files = sorted(
      glob.glob(os.path.join(image_dir, "predictions_epoch_*.png")),
      key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split("_")[-1]),
  )
  if not image_files:
    print("No prediction images found to create video.")
    return None

  import cv2  # Imported lazily so the no-frames path does not need OpenCV.

  video_path = f"timesfm_{wandb_project}_exp{exp_num}_predictions.mp4"
  writer = None
  try:
    for img_file in image_files:
      img = cv2.imread(img_file)
      if img is None:  # cv2.imread returns None for unreadable files.
        print(f"Skipping unreadable image: {img_file}")
        continue
      if writer is None:
        # The first readable frame fixes the video's frame size.
        height, width = img.shape[:2]
        writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                 fps, (width, height))
        if not writer.isOpened():
          print(f"Could not open video writer for {video_path}.")
          return None
      elif img.shape[:2] != (height, width):
        # OpenCV requires every frame to match the size the writer was
        # opened with, so resize mismatched frames instead of losing them.
        img = cv2.resize(img, (width, height))
      writer.write(img)
  finally:
    if writer is not None:
      writer.release()

  if writer is None:
    print("No prediction images found to create video.")
    return None
  print(f"Video saved to {video_path}")
  return video_path


In [None]:
import cv2
import glob
import os

def get_data(context_len: int,
             horizon_len: int,
             loading_func: Optional[Callable[[], np.ndarray]] = None,
             freq_type: int = 0,
             train_split: float = 0.8,
             ) -> Tuple[Dataset, Dataset]:
  """Load a series with ``loading_func`` and split it into train/val datasets.

  Args:
    context_len: Number of past timesteps per sample.
    horizon_len: Number of future timesteps per sample.
    loading_func: Zero-argument callable returning the series (e.g.
      ``load_solana``). It is required; the ``None`` default is kept only
      for signature compatibility.
    freq_type: TimesFM frequency type (0, 1, or 2).
    train_split: Fraction of the series used for training.

  Returns:
    ``(train_dataset, val_dataset)``.

  Raises:
    ValueError: If ``loading_func`` is not provided.
  """
  if loading_func is None:
    # Fail fast: without a loader there is no series to split.
    raise ValueError(
        "No loading function provided. Please provide a function to load data."
    )
  time_series = loading_func()

  train_dataset, val_dataset = prepare_datasets(
      series=time_series,
      context_length=context_len,
      horizon_length=horizon_len,
      freq_type=freq_type,
      train_split=train_split,
  )

  print("Created datasets:")
  print(f"- Training samples: {len(train_dataset)}")
  print(f"- Validation samples: {len(val_dataset)}")
  print(f"- Using frequency type: {freq_type}")
  return train_dataset, val_dataset


def single_gpu_example(experiment_config,
                       data_func: Optional[Callable[[], np.ndarray]] = None,
                       wandb_project: Optional[str] = None,
                       exp_num: int = 0):
  """Finetune TimesFM on a single device with one config, then plot results.

  Args:
    experiment_config: ``FinetuningConfig`` for this run. Its ``freq_type``
      is also used when building the datasets.
    data_func: Zero-argument loader returning the time series.
    wandb_project: Used to name the output plot and video files.
    exp_num: Experiment index, used to name the output plot and video files.
  """
  model, hparams, tfm_config = get_model(load_weights=True)
  config = experiment_config

  # Use the context length the model was configured with rather than
  # repeating the magic number, so the two cannot drift apart.
  train_dataset, val_dataset = get_data(hparams.context_len,
                                        tfm_config.horizon_len,
                                        loading_func=data_func,
                                        freq_type=config.freq_type)
  finetuner = TimesFMFinetuner(model, config)

  print("\nStarting finetuning...")
  results = finetuner.finetune(train_dataset=train_dataset,
                               val_dataset=val_dataset, vizualize=True)

  print("\nFinetuning completed!")
  print(f"Training history: {len(results['history']['train_loss'])} epochs")

  plot_predictions(
      model=model,
      val_dataset=val_dataset,
      save_path=f"timesfm_{wandb_project}_predictions_{exp_num}.png",
  )

  # save_video reports on its own whether a video was written.
  save_video(wandb_project, exp_num)



In [None]:
# Dataset to finetune on; the next cell handles "APPL" and "solana".
dataset = "solana"

In [None]:
# Choose the W&B project name and the loader for the selected dataset.
if dataset == 'APPL':
  wandb_project = "timesfm-finetuning"
  data_func = load_APPL
elif dataset == 'solana':
  wandb_project = "timesfm-finetuning-solana"
  data_func = load_solana
else:
  # Fail here instead of with a NameError on wandb_project/data_func later.
  raise ValueError(
      f"Unknown dataset {dataset!r}; expected 'APPL' or 'solana'.")

In [None]:
# Hyper-parameter sweep: run every (batch_size, learning_rate) pair exactly
# once, with all other settings shared. The order is batch size 64 first,
# then 32, each with learning rates 1e-5, 1e-6, 5e-7.
experiment_configs = [
    FinetuningConfig(
        batch_size=batch_size,
        num_epochs=20,
        learning_rate=learning_rate,
        use_wandb=True,
        freq_type=0,
        log_every_n_steps=10,
        val_check_interval=0.75,
        wandb_project=wandb_project,
        use_quantile_loss=True,
    )
    for batch_size in (64, 32)
    for learning_rate in (1e-5, 1e-6, 5e-7)
]

In [None]:
for exp_num, config in enumerate(experiment_configs):
  # One full finetuning run per config; exp_num tags that run's output files.
  single_gpu_example(
      experiment_config=config,
      data_func=data_func,
      wandb_project=wandb_project,
      exp_num=exp_num,
  )