## Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from ipywidgets import IntSlider, interact
from transformers import set_seed

In [2]:
from tsfm_public import TimeSeriesPreprocessor
from tsfm_public.models.tspulse import TSPulseForReconstruction
from tsfm_public.toolkit.time_series_imputation_pipeline import TimeSeriesImputationPipeline

## Preparing the Dataset

In [3]:
# Set seed for reproducibility
SEED = 42
set_seed(SEED)
# Passed to TimeSeriesPreprocessor as `context_length` further down.
CONTEXT_LENGTH = 512
# Passed to TimeSeriesPreprocessor as `prediction_length`; zero in this notebook.
PREDICTION_LENGTH = 0

# Dataset label and the URL of the CSV that pd.read_csv loads in the next cell.
TARGET_DATASET = "etth1"
dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"

In [4]:
# Column roles used throughout the notebook.
timestamp_column = "date"
id_columns = []  # mention the ids that uniquely identify a time-series.
target_columns = [
    "HUFL",
    "HULL",
    "MUFL",
    "MULL",
    "LUFL",
    "LULL",
    "OT",
]  # mention the target column names in the dataset that should be imputed by the model

# Load the CSV and parse the timestamp column into datetimes.
data = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
)

# Keep an untouched copy: `data` is truncated and masked in later cells, while this
# copy is used as the reference for plotting and evaluation.
fully_observed_ground_truth = data.copy()

print(data.shape)
print(data.head())

# Keyword arguments unpacked into TimeSeriesPreprocessor(**column_specifiers) later.
column_specifiers = {
    "timestamp_column": timestamp_column,
    "id_columns": id_columns,
    "target_columns": target_columns,
    "control_columns": [],
}

(17420, 8)
                 date   HUFL   HULL   MUFL   MULL   LUFL   LULL         OT
0 2016-07-01 00:00:00  5.827  2.009  1.599  0.462  4.203  1.340  30.531000
1 2016-07-01 01:00:00  5.693  2.076  1.492  0.426  4.142  1.371  27.787001
2 2016-07-01 02:00:00  5.157  1.741  1.279  0.355  3.777  1.218  27.787001
3 2016-07-01 03:00:00  5.090  1.942  1.279  0.391  3.807  1.279  25.044001
4 2016-07-01 04:00:00  5.358  1.942  1.492  0.462  3.868  1.279  21.948000


### Introducing NaN values at the end to mimic a forecasting horizon

In [5]:
def mask_last_context(df: pd.DataFrame, forecast_length: int, columns=None) -> pd.DataFrame:
    """
    Return a copy of `df` whose last `forecast_length` rows are replaced with NaN.

    Args:
        df (pd.DataFrame): Input DataFrame. It is never modified.
        forecast_length (int): Number of trailing rows to mask as NaN. Values <= 0
            return an unmasked copy.
        columns (str | list[str] | None): Columns to mask. ``None`` (the default)
            masks every column, which includes a datetime column if one is present
            (it then holds NaT in the masked rows). Pass the target columns to keep
            the timestamps of the horizon intact.

    Returns:
        pd.DataFrame: Copy of `df` with NaNs in the trailing rows of the selected columns.

    Raises:
        KeyError: If `columns` names a column that is not present in `df`.
    """
    df_copy = df.copy()
    if forecast_length <= 0:
        return df_copy

    if columns is None:
        # Original behaviour: blank out the trailing rows of every column.
        df_copy.iloc[-forecast_length:, :] = np.nan
        return df_copy

    # A bare string would otherwise be split into characters by list().
    selected = [columns] if isinstance(columns, str) else list(columns)
    positions = df_copy.columns.get_indexer(selected)
    not_found = [name for name, pos in zip(selected, positions) if pos < 0]
    if not_found:
        raise KeyError(f"Columns not found in DataFrame: {not_found}")

    # Positional assignment keeps the row selection correct even with a non-unique index.
    df_copy.iloc[-forecast_length:, positions] = np.nan
    return df_copy

In [6]:
# Number of trailing rows that get masked (passed as `forecast_length` to mask_last_context).
FORECAST_LEN = 96

In [7]:
# Keep only the first CONTEXT_LENGTH rows, matching the `context_length` the
# preprocessor is configured with, instead of repeating the literal 512.
data = data.iloc[:CONTEXT_LENGTH, :]
data

Unnamed: 0,date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
0,2016-07-01 00:00:00,5.827,2.009,1.599,0.462,4.203,1.340,30.531000
1,2016-07-01 01:00:00,5.693,2.076,1.492,0.426,4.142,1.371,27.787001
2,2016-07-01 02:00:00,5.157,1.741,1.279,0.355,3.777,1.218,27.787001
3,2016-07-01 03:00:00,5.090,1.942,1.279,0.391,3.807,1.279,25.044001
4,2016-07-01 04:00:00,5.358,1.942,1.492,0.462,3.868,1.279,21.948000
...,...,...,...,...,...,...,...,...
507,2016-07-22 03:00:00,15.271,5.291,9.772,2.452,4.813,1.401,35.666000
508,2016-07-22 04:00:00,12.525,3.684,7.782,1.990,4.630,1.371,35.525002
509,2016-07-22 05:00:00,13.329,4.421,8.315,2.025,4.873,1.401,36.862000
510,2016-07-22 06:00:00,12.860,4.488,8.102,2.097,4.904,1.492,35.033001


In [8]:
# Mask the trailing FORECAST_LEN rows. All columns are masked, so the `date` column
# shows NaT in those rows as well (see the output below).
data = mask_last_context(data, forecast_length=FORECAST_LEN)
data

Unnamed: 0,date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
0,2016-07-01 00:00:00,5.827,2.009,1.599,0.462,4.203,1.340,30.531000
1,2016-07-01 01:00:00,5.693,2.076,1.492,0.426,4.142,1.371,27.787001
2,2016-07-01 02:00:00,5.157,1.741,1.279,0.355,3.777,1.218,27.787001
3,2016-07-01 03:00:00,5.090,1.942,1.279,0.391,3.807,1.279,25.044001
4,2016-07-01 04:00:00,5.358,1.942,1.492,0.462,3.868,1.279,21.948000
...,...,...,...,...,...,...,...,...
507,NaT,,,,,,,
508,NaT,,,,,,,
509,NaT,,,,,,,
510,NaT,,,,,,,


### Creating a TimeSeriesPreprocessor (`tsp`) to preprocess the input data and apply scaling if needed

In [9]:
# Build the preprocessor from the column roles and window lengths defined above.
# Scaling is enabled with the "standard" scaler type; categorical encoding is off.
# NOTE(review): presumably this standardizes each target channel — confirm in the
# tsfm_public documentation.
tsp = TimeSeriesPreprocessor(
    **column_specifiers,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    scaling=True,
    encode_categorical=False,
    scaler_type="standard",
)

## Getting the pre-trained TSPulse Model

In [10]:
# Load the pre-trained TSPulse reconstruction model at a pinned revision.
# The channel count is taken from the preprocessor so it matches the target columns.
# NOTE(review): mask_type="user" presumably makes the model use the caller-supplied
# missing-value positions instead of generating its own mask — confirm.
model = TSPulseForReconstruction.from_pretrained(
    "ibm-granite/granite-timeseries-tspulse-r1",
    revision="tspulse-hybrid-dualhead-512-p8-r1",
    num_input_channels=tsp.num_input_channels,
    mask_type="user",
)

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check;
# `torch.mps.is_available` does not exist in older torch releases, so the previous
# expression could raise AttributeError on machines without CUDA.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

## Using Imputation Pipeline for Zero-Shot Imputation

In [11]:
tsp.train(data)  # train the tsp
# Combine the model and the trained preprocessor into an imputation pipeline on `device`.
pipe = TimeSeriesImputationPipeline(model, feature_extractor=tsp, batch_size=1000, device=device)

Device set to use cuda


In [12]:
# Run the imputation pipeline on the masked frame.
out = pipe(data)

In [13]:
# (rows, columns) of the pipeline output.
out.shape

(512, 15)

In [14]:
# Each target column appears next to a matching `<name>_imputed` column (see header below).
out.head()

Unnamed: 0,date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT,HUFL_imputed,HULL_imputed,MUFL_imputed,MULL_imputed,LUFL_imputed,LULL_imputed,OT_imputed
0,2016-07-01 00:00:00,5.827,2.009,1.599,0.462,4.203,1.34,30.531,5.827,2.009,1.599,0.462,4.203,1.34,30.531
1,2016-07-01 01:00:00,5.693,2.076,1.492,0.426,4.142,1.371,27.787001,5.693,2.076,1.492,0.426,4.142,1.371,27.787001
2,2016-07-01 02:00:00,5.157,1.741,1.279,0.355,3.777,1.218,27.787001,5.157,1.741,1.279,0.355,3.777,1.218,27.787001
3,2016-07-01 03:00:00,5.09,1.942,1.279,0.391,3.807,1.279,25.044001,5.09,1.942,1.279,0.391,3.807,1.279,25.044001
4,2016-07-01 04:00:00,5.358,1.942,1.492,0.462,3.868,1.279,21.948,5.358,1.942,1.492,0.462,3.868,1.279,21.948


## Plotting the Observed and Imputed values

In [15]:
def plot_interactive_imputation(df, fully_observed_ground_truth=None, window_size=512):
    """
    Show one stacked subplot per imputed channel for the first window of `df`.

    Channels are discovered from the columns ending in "_imputed". In every subplot
    the imputed values are drawn in red at the positions where the observed column
    is NaN, and a flat green line at the minimum of the reference series marks
    those same positions.

    Args:
        df (pd.DataFrame): Frame with a "date" column, the observed channels and a
            "<channel>_imputed" column for each of them.
        fully_observed_ground_truth (pd.DataFrame, optional): Unmasked data, sliced by
            position with the same window as `df`. When given it is drawn in blue as
            the reference; otherwise only the observed (non-NaN) values are drawn.
        window_size (int): Maximum number of rows shown per window.
    """
    df = df.drop("date", axis=1)
    suffix = "_imputed"
    observed_cols = [name.removesuffix(suffix) for name in df.columns if name.endswith(suffix)]
    num_points = len(df)

    def plot_window(start_idx):
        # Rows [start_idx, end_idx) form the window; the x axis uses row positions.
        end_idx = min(start_idx + window_size, num_points)
        window = slice(start_idx, end_idx)
        x_range = np.arange(start_idx, end_idx, dtype=float)
        n_channels = len(observed_cols)

        plt.figure(figsize=(15, 3 * n_channels))

        for panel, base_col in enumerate(observed_cols, start=1):
            observed_vals = df[base_col].iloc[window]
            imputed_vals = df[f"{base_col}_imputed"].iloc[window]
            pos_imputed = observed_vals.isna()

            plt.subplot(n_channels, 1, panel)
            plt.plot(x_range[pos_imputed], imputed_vals[pos_imputed], color="red", linewidth=2, label="Imputed")

            if fully_observed_ground_truth is None:
                # No reference available: draw only the observed points; matplotlib
                # joins them with straight segments across the missing positions.
                pos_observed = ~pos_imputed
                reference_vals = observed_vals
                plt.plot(
                    x_range[pos_observed], observed_vals[pos_observed], color="blue", linewidth=2, label="Observed"
                )
            else:
                # Reference available: draw the actual ground truth for the window.
                reference_vals = fully_observed_ground_truth[base_col].iloc[window]
                plt.plot(x_range, reference_vals, color="blue", linewidth=2, label="Ground_Truth")

            # Flat marker line at the reference minimum, spanning the imputed positions.
            y_mask = np.full(np.sum(pos_imputed), np.min(reference_vals))
            plt.plot(x_range[pos_imputed], y_mask, color="green", linewidth=2, label="mask")

            plt.title(f"Channel: {base_col}")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.show()

    print("Plotting static plot for first window")
    plot_window(0)  # plotting the first window as static plot

    # Optional slider-driven version (disabled):
    # interact(
    #     plot_window,
    #     start_idx=IntSlider(
    #         value=0,
    #         min=0,
    #         max=max(0, num_points - window_size),
    #         step=1,
    #         description="Start Index",
    #         continuous_update=False,
    #     ),
    # )

In [16]:
# plot_interactive_imputation(out, fully_observed_ground_truth)

In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

def plot_interactive_imputation(df, fully_observed_ground_truth=None, window_size=512, plot_index=0, save_dir="imputation_plots"):
    """
    Plot the first window of an imputation result, one subplot per channel.

    This definition replaces the earlier on-screen version of the same name. To keep
    that behaviour reachable, pass ``save_dir=None`` to show the figure instead of
    writing it to disk.

    Args:
        df (pd.DataFrame): Frame with a "date" column, the observed channels and a
            "<channel>_imputed" column for each of them.
        fully_observed_ground_truth (pd.DataFrame, optional): Unmasked data, sliced by
            position with the same window as `df`. When given it is drawn in blue as
            the reference; otherwise only the observed (non-NaN) values are drawn.
        window_size (int): Maximum number of rows plotted.
        plot_index (int): Suffix of the output file name,
            ``imputation_plot_<plot_index>.pdf``.
        save_dir (str | None): Directory the PDF is written to (created if missing).
            ``None`` shows the figure on screen and writes nothing.
    """
    saving = save_dir is not None
    if saving:
        # Create save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

    df = df.drop("date", axis=1)
    observed_cols = [col.removesuffix("_imputed") for col in df.columns if col.endswith("_imputed")]

    # Only the first window is plotted, so the window always starts at row 0.
    end_idx = min(window_size, len(df))
    window = slice(0, end_idx)
    x_range = np.arange(0, end_idx, dtype=float)

    if saving:
        print(f"Saving static plot for window 0 as plot index {plot_index}")
    else:
        print("Plotting static plot for first window")

    fig = plt.figure(figsize=(15, 3 * len(observed_cols)))
    try:
        for i, base_col in enumerate(observed_cols):
            observed_vals = df[base_col].iloc[window]
            imputed_vals = df[f"{base_col}_imputed"].iloc[window]
            pos_imputed = observed_vals.isna()

            plt.subplot(len(observed_cols), 1, i + 1)
            plt.plot(x_range[pos_imputed], imputed_vals[pos_imputed], color="red", linewidth=2, label="Imputed")

            if fully_observed_ground_truth is not None:
                reference_vals = fully_observed_ground_truth[base_col].iloc[window]
                plt.plot(x_range, reference_vals, color="blue", linewidth=2, label="Ground_Truth")
            else:
                pos_observed = ~pos_imputed
                reference_vals = observed_vals
                plt.plot(x_range[pos_observed], observed_vals[pos_observed], color="blue", linewidth=2, label="Observed")

            # Flat marker line at the reference minimum, spanning the imputed positions.
            y_mask = np.full(np.sum(pos_imputed), np.min(reference_vals))
            plt.plot(x_range[pos_imputed], y_mask, color="green", linewidth=2, label="Mask")

            plt.title(f"Channel: {base_col}")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        if saving:
            plot_filename = os.path.join(save_dir, f"imputation_plot_{plot_index}.pdf")
            plt.savefig(plot_filename)
            print(f"Saved plot to: {plot_filename}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails; this function
        # is called in a loop and open figures accumulate memory.
        plt.close(fig)


In [18]:
import random

# Qualitative check on 20 randomly placed windows of CONTEXT_LENGTH rows.
# Python's `random` module is seeded by the earlier set_seed(SEED) call, so the
# window starts are reproducible on a fresh Restart & Run All.
# Loop-local names (`window_pipe`, `window_out`) are used on purpose: reusing `pipe`
# and `out` here would overwrite the first-window results that the evaluation
# cells below read.
# NOTE(review): `tsp` is re-trained on every window, so after this cell the
# preprocessor reflects the last window.
for i in range(20):
    r = random.randint(0, 1000)
    actual_data = fully_observed_ground_truth.iloc[r : r + CONTEXT_LENGTH, :]
    temp_data = mask_last_context(actual_data, forecast_length=FORECAST_LEN)
    tsp.train(temp_data)
    window_pipe = TimeSeriesImputationPipeline(model, feature_extractor=tsp, batch_size=1000, device=device)
    window_out = window_pipe(temp_data)
    plot_interactive_imputation(window_out, actual_data, plot_index=r)

Device set to use cuda


Saving static plot for window 0 as plot index 654


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_654.pdf
Saving static plot for window 0 as plot index 114


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_114.pdf
Saving static plot for window 0 as plot index 25


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_25.pdf
Saving static plot for window 0 as plot index 759


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_759.pdf
Saving static plot for window 0 as plot index 281


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_281.pdf
Saving static plot for window 0 as plot index 250


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_250.pdf
Saving static plot for window 0 as plot index 228


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_228.pdf
Saving static plot for window 0 as plot index 142


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_142.pdf
Saving static plot for window 0 as plot index 754


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_754.pdf
Saving static plot for window 0 as plot index 104


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_104.pdf
Saving static plot for window 0 as plot index 692


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_692.pdf
Saving static plot for window 0 as plot index 758


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_758.pdf
Saving static plot for window 0 as plot index 913


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_913.pdf
Saving static plot for window 0 as plot index 558


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_558.pdf
Saving static plot for window 0 as plot index 89


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_89.pdf
Saving static plot for window 0 as plot index 604


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_604.pdf
Saving static plot for window 0 as plot index 432


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_432.pdf
Saving static plot for window 0 as plot index 32


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_32.pdf
Saving static plot for window 0 as plot index 30


Device set to use cuda


Saved plot to: imputation_plots/imputation_plot_30.pdf
Saving static plot for window 0 as plot index 95
Saved plot to: imputation_plots/imputation_plot_95.pdf


## Evaluate the Model

Evaluate the zero-shot performance of the model on the dataset.

In [15]:
def custom_metric(actual, missing_df, prediction, column_header="results"):
    """Compute MSE, RMSE and MAE over the positions that are NaN in `missing_df`.

    Args:
        actual (pd.DataFrame): Fully observed reference values.
        missing_df (pd.DataFrame | np.ndarray): Data containing NaN at the positions
            that were imputed; only its NaN pattern is used.
        prediction (pd.DataFrame): Imputed values, aligned by position with `actual`
            (column names may differ).
        column_header (str): Name of the single column in the returned frame.

    Returns:
        pd.DataFrame: One column named `column_header` with the rows
        "mean_squared_error", "root_mean_squared_error" and "mean_absolute_error".
        All three are NaN when `missing_df` contains no NaN.

    Raises:
        ValueError: If the three inputs do not have the same shape.
    """
    a = actual.to_numpy(dtype=float)
    p = prediction.to_numpy(dtype=float)
    # Use a plain boolean ndarray as the mask rather than a DataFrame index.
    missing_positions = np.isnan(np.asarray(missing_df, dtype=float))

    if not (a.shape == p.shape == missing_positions.shape):
        raise ValueError(
            "actual, missing_df and prediction must have the same shape; got "
            f"{a.shape}, {missing_positions.shape} and {p.shape}"
        )

    errors = a[missing_positions] - p[missing_positions]
    if errors.size == 0:
        # Nothing was imputed: report NaN without numpy's empty-mean warning.
        mse = mae = np.nan
    else:
        mse = np.mean(np.square(errors))
        mae = np.mean(np.abs(errors))

    return pd.DataFrame(
        {
            column_header: {
                "mean_squared_error": mse,
                "root_mean_squared_error": np.sqrt(mse),
                "mean_absolute_error": mae,
            }
        }
    )

In [16]:
# `out` covers a single window, while `fully_observed_ground_truth` holds the whole
# series. Locate the window through the timestamp of its first (unmasked) row so the
# metric compares matching rows.
# NOTE(review): assumes the pipeline returns the timestamp column unchanged — confirm.
window_start_matches = np.flatnonzero(
    (fully_observed_ground_truth[timestamp_column] == out[timestamp_column].iloc[0]).to_numpy()
)
if window_start_matches.size == 0:
    raise ValueError("Could not locate the first timestamp of `out` in the ground-truth data.")
window_start = int(window_start_matches[0])

ground_truth = fully_observed_ground_truth[target_columns].iloc[
    window_start : window_start + len(out)
]  # original df having no missing values, restricted to the rows covered by `out`
ground_truth_with_missing_data = out[target_columns]  # df having missing values

imputed_columns = [f"{col}_imputed" for col in target_columns]
imputed_df = out[imputed_columns]  # df having imputed values at the missing data positions

custom_metric(ground_truth, ground_truth_with_missing_data, imputed_df, "zero-shot imputation")

Unnamed: 0,zero-shot imputation
mean_squared_error,1.670301
root_mean_squared_error,1.292401
mean_absolute_error,0.681345
