# Quick Start: Running Super-Linear on gift-eval benchmark

This notebook shows how to run Super-Linear on the gift-eval benchmark.

Make sure you download the gift-eval benchmark and set the `GIFT_EVAL` environment variable correctly (e.g. in a `.env` file, which the setup cell below loads with `load_dotenv()`) before running this notebook.

We will use the `Dataset` class to load the data and run the model. If you have not already, please check out the [dataset.ipynb](./dataset.ipynb) notebook to learn more about the `Dataset` class. By default, the cells below run the model on the full benchmark; for a quick test, uncomment the shorter `short_datasets` and `med_long_datasets` definitions in the setup cell, or edit those variables to run on any datasets you like.

### Installation

In [None]:
# %pip install transformers==4.40.1 # Use this version and Python 3.10 for stable compatibility

## Setting up the data and metrics

In [None]:
import json
import torch
from dotenv import load_dotenv

# Load environment variables with python-dotenv so that settings such as the
# gift-eval data location mentioned in the introduction are available to the
# rest of the notebook.
# NOTE(review): presumably these are read from a local `.env` file — confirm.
load_dotenv()


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whitespace-separated dataset names (optionally in `name/frequency` form).
# Uncomment the shorter variant below for a quick two-dataset run.
short_datasets = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
#short_datasets = "electricity/W restaurant"

# Datasets that are additionally evaluated on the "medium" and "long" terms
# (see the loop that builds `all_ds_tuples`).
med_long_datasets = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
#med_long_datasets = "bitbrains_fast_storage/5T"

# Every dataset named in either list, de-duplicated and in sorted order.
short_names = set(short_datasets.split())
med_long_names = set(med_long_datasets.split())
all_datasets = sorted(short_names | med_long_names)

# Per-dataset metadata keyed by dataset name; later cells read the "frequency",
# "domain" and "num_variates" fields. The file is opened with a context manager
# so the handle is closed as soon as the JSON has been parsed.
with open("dataset_properties.json") as properties_file:
    dataset_properties_map = json.load(properties_file)

In [None]:
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Quantile levels scored by the weighted quantile loss.
quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Metric objects handed to `evaluate_model`, in the order they are reported.
base_metrics = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
]
metrics = base_metrics + [
    MeanWeightedSumQuantileLoss(quantile_levels=quantile_levels),
]

## Super-Linear

For foundation models, we need to implement a wrapper containing the model and use the wrapper to generate predictions.

This is just meant to be a simple wrapper to get you started; feel free to use your own custom implementation to wrap any model.

In [None]:

from typing import List
import numpy as np
from tqdm.auto import tqdm
from gluonts.itertools import batcher
from gluonts.model import Forecast
from gluonts.model.forecast import QuantileForecast
from transformers import AutoModelForCausalLM
from gluonts.model.forecast import QuantileForecast, SampleForecast



class SuperLinearPredictor:
    """Minimal predictor wrapper that lets gluonts evaluate a Super-Linear model.

    Each series is cut down to its last ``max_lookback`` observations. Series
    whose retained context has the same length are stacked (no padding needed)
    and sent through the model in mini-batches. One ``SampleForecast`` is
    returned per input series, in the original input order.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        prediction_length: int,
        device: str = "cuda",
    ):
        """Put the model in eval mode on ``device`` and derive the lookback limit.

        Args:
            model: Loaded Super-Linear model (in this notebook it comes from
                ``AutoModelForCausalLM.from_pretrained``). It must expose
                ``backbone.train_seq_len``, ``backbone.lookback_resampling``
                and ``backbone.scale_list``.
            prediction_length: Number of future steps requested from the model.
            device: Torch device string the model and inputs are moved to.
        """
        self.model = model
        self.model.eval()
        self.device = device
        self.model.to(self.device)

        self.prediction_length = prediction_length
        backbone = self.model.backbone
        self.train_seq_len = backbone.train_seq_len
        self.lookback_resampling = backbone.lookback_resampling
        self.scale_list = np.array(backbone.scale_list)

        # With lookback resampling enabled the usable context can be as long as
        # the largest scale times the training length (never shorter than the
        # training length itself); otherwise it is the training length.
        if self.lookback_resampling:
            self.max_lookback = np.max(
                np.append(self.scale_list * self.train_seq_len, self.train_seq_len)
            )
        else:
            self.max_lookback = self.train_seq_len
        self.max_lookback = int(self.max_lookback)

    def predict(self, test_data_input, batch_size: int = 1024) -> List["Forecast"]:
        """Forecast every entry of ``test_data_input``.

        Args:
            test_data_input: Iterable of entries providing a ``"target"``
                sequence and a ``"start"`` value to which an integer offset can
                be added (a plain iterator is accepted as well).
            batch_size: Maximum number of series per model forward pass.

        Returns:
            A list with one ``SampleForecast`` per entry, in input order.
        """
        # Materialise once so that generators / one-shot iterables work too.
        entries = list(test_data_input)

        # Group by the length of the context that is actually fed to the model
        # (i.e. after truncation to ``max_lookback``). All series longer than
        # the limit end up in the same group instead of one group per raw length.
        length_groups = {}
        for position, entry in enumerate(entries):
            context_len = min(len(entry["target"]), self.max_lookback)
            length_groups.setdefault(context_len, []).append((position, entry))

        all_forecasts = [None] * len(entries)

        # Inference only: never build autograd graphs, even if the caller did
        # not wrap this call in ``torch.no_grad()``.
        with torch.no_grad():
            for group in length_groups.values():
                for offset in range(0, len(group), batch_size):
                    mini_batch = group[offset:offset + batch_size]

                    # Prepare the (truncated, NaN-free) context of each series.
                    context = []
                    for _, entry in mini_batch:
                        arr = torch.tensor(entry["target"])
                        arr = arr[-self.max_lookback:]
                        if torch.isnan(arr).any():
                            # Fill missing values by interpolation.
                            arr = (
                                interpolate_missing_values(arr.unsqueeze(0).unsqueeze(-1))
                                .squeeze(0)
                                .squeeze(-1)
                            )
                        context.append(arr)

                    # Same length within the group -> stack directly; a
                    # singleton axis is inserted at dim 1 before the model call.
                    input_x = torch.stack(context, dim=0).unsqueeze(1).to(self.device)

                    all_output = self.model(
                        input_x, pred_len=self.prediction_length, get_prob=False
                    )
                    # `.logits` carries the predicted values.
                    batch_forecasts = all_output.logits.detach().cpu().numpy()

                    # Write each forecast back to the slot of its input series.
                    for row, (position, entry) in enumerate(mini_batch):
                        # The forecast starts right after the last observation.
                        forecast_start_date = entry["start"] + len(entry["target"])
                        all_forecasts[position] = SampleForecast(
                            samples=batch_forecasts[row],
                            start_date=forecast_start_date,
                        )

        return all_forecasts



def interpolate_missing_values(data_batch):
    """
    Interpolates missing (NaN) values in a batch of time series data.

    Interior gaps are filled by linear interpolation between the nearest valid
    neighbours. Leading NaNs take the first valid value and trailing NaNs take
    the last valid value. A series without any valid value is filled with zeros.
    The input tensor is not modified.

    Args:
        data_batch: Tensor of shape (batch, sequence, channel)

    Returns:
        Tensor of same shape with NaN values filled as described above
    """
    batch_size, seq_len, channels = data_batch.shape
    result = torch.zeros_like(data_batch)
    positions = torch.arange(seq_len, device=data_batch.device)

    # Process each batch and channel independently
    for b in range(batch_size):
        for c in range(channels):
            ts = data_batch[b, :, c]
            mask = ~torch.isnan(ts)

            # All values missing: keep the zeros `result` was initialised with.
            if not torch.any(mask):
                continue

            # Nothing missing: plain copy.
            if torch.all(mask):
                result[b, :, c] = ts
                continue

            valid_indices = positions[mask]  # ascending by construction
            valid_values = ts[mask]
            last = valid_indices.numel() - 1

            # For every position: slot of the nearest valid point at/after it
            # and at/before it, clamped so that leading/trailing gaps resolve
            # to the first/last valid point respectively.
            nxt = torch.searchsorted(valid_indices, positions).clamp(max=last)
            prv = (torch.searchsorted(valid_indices, positions, right=True) - 1).clamp(min=0)

            i_before, i_after = valid_indices[prv], valid_indices[nxt]
            v_before, v_after = valid_values[prv], valid_values[nxt]

            # gap == 0 means the position is itself valid or lies in a leading /
            # trailing gap; there the neighbour value is used directly.
            gap = i_after - i_before
            w_after = (positions - i_before).to(ts.dtype) / gap.clamp(min=1).to(ts.dtype)
            w_before = 1 - w_after
            interpolated = w_before * v_before + w_after * v_after

            result[b, :, c] = torch.where(gap > 0, interpolated, v_before)

    return result

In [None]:
import logging


class WarningFilter(logging.Filter):
    """Logging filter that drops every record whose message contains a given text."""

    def __init__(self, text_to_filter):
        """Remember the substring that marks a record for suppression."""
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record):
        """Return False (drop) when the record's message contains the text, else True."""
        message = record.getMessage()
        if self.text_to_filter in message:
            return False
        return True


# Suppress the gluonts log message about the mean prediction not being stored
# in the forecast data.
gts_logger = logging.getLogger("gluonts.model.forecast")
mean_warning_filter = WarningFilter(
    "The mean prediction is not stored in the forecast data"
)
gts_logger.addFilter(mean_warning_filter)

In [None]:
# Name used for the results folder and for the "model" column of the CSV.
model_name = "super_linear"

# Prefer the GPU when PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# TODO: replace this placeholder with the actual location of the model.
model_path = "path_to_model"

# NOTE(review): per the Transformers docs, `trust_remote_code=True` allows
# custom model code shipped with the checkpoint to run, and
# `force_download=True` re-downloads the files instead of using the cache —
# confirm both are intended.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    force_download=True,
)

In [None]:
from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality
import csv
import os
import sys
from gift_eval.data import Dataset


# One (prediction_length, ds_config, ds_name, to_univariate) tuple per
# dataset/term combination; the evaluation cell iterates over this list.
all_ds_tuples = []

# Maps raw dataset keys to the shorter keys used in `dataset_properties_map`
# and in the `ds_config` strings written to the results file.
pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Loop-invariant values, computed once instead of on every iteration.
terms = ["short", "medium", "long"]
med_long_names = set(med_long_datasets.split())

for ds_num, ds_name in enumerate(all_datasets):
    print(f"Processing dataset: {ds_name} ({ds_num + 1} of {len(all_datasets)})")

    # Key and frequency do not depend on the term, so resolve them once per
    # dataset: names of the form "key/freq" carry the frequency themselves,
    # all others look it up in the dataset properties.
    if "/" in ds_name:
        ds_key, ds_freq = ds_name.split("/")[:2]
        ds_key = pretty_names.get(ds_key.lower(), ds_key.lower())
    else:
        ds_key = pretty_names.get(ds_name.lower(), ds_name.lower())
        ds_freq = dataset_properties_map[ds_key]["frequency"]

    for term in terms:
        if term in ("medium", "long") and ds_name not in med_long_names:
            print(f"Skipping {ds_name} for term {term} as it is not in the medium/long datasets list.")
            continue

        ds_config = f"{ds_key}/{ds_freq}/{term}"

        # Probe the target dimension first: multivariate datasets are reloaded
        # with `to_univariate=True`; univariate ones reuse the probe instead
        # of constructing an identical Dataset a second time.
        probe = Dataset(name=ds_name, term=term, to_univariate=False)
        to_univariate = probe.target_dim != 1
        if to_univariate:
            dataset = Dataset(name=ds_name, term=term, to_univariate=True)
        else:
            dataset = probe

        all_ds_tuples.append(
            (dataset.prediction_length, ds_config, ds_name, to_univariate))

In [None]:
import csv
import os
import sys

from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality
from gift_eval.data import Dataset

# Make sure the per-model results folder is present before writing to it.
output_dir = f"../results/{model_name}"
os.makedirs(output_dir, exist_ok=True)

# All results are collected in one CSV file inside that folder.
csv_file_path = os.path.join(output_dir, "all_results.csv")

# Keys of the `evaluate_model` result, in the order they appear in the CSV.
metric_names = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]

# Start a fresh file that contains only the header row.
header_row = (
    ["dataset", "model"]
    + [f"eval_metrics/{metric_name}" for metric_name in metric_names]
    + ["domain", "num_variates"]
)
with open(csv_file_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(header_row)

# Evaluate every dataset/term combination collected in `all_ds_tuples`.
for entry in all_ds_tuples:
    prediction_length, ds_config, ds_name, to_univariate = entry
    ds_key, ds_freq, term = ds_config.split("/")

    dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
    season_length = get_seasonality(dataset.freq)

    print(f"Processing entry: {entry}")
    print(f"Dataset size: {len(dataset.test_data)}")

    predictor = SuperLinearPredictor(
        model=model,
        prediction_length=dataset.prediction_length,
        device=device,
    )

    # Run the evaluation without tracking gradients.
    with torch.no_grad():
        res = evaluate_model(
            predictor,
            test_data=dataset.test_data,
            metrics=metrics,
            batch_size=1024,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=season_length,
        )

    # Append this configuration's row right away so that finished results are
    # on disk even if a later dataset fails.
    result_row = [ds_config, model_name]
    result_row.extend(res[metric_name][0] for metric_name in metric_names)
    result_row.append(dataset_properties_map[ds_key]["domain"])
    result_row.append(dataset_properties_map[ds_key]["num_variates"])
    with open(csv_file_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(result_row)

    print(f"Results for {ds_name} have been written to {csv_file_path}")


In [None]:
import pandas as pd

# Read the collected results back and show them as a table.
results_csv = f"../results/{model_name}/all_results.csv"
df = pd.read_csv(results_csv)
df