In [None]:
import os

import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch

# Make sure the working directory is the main `toto` directory: the
# project-local packages imported below (`data`, `inference`, `model`) are
# resolved relative to it. Guarded so that re-running this cell does not keep
# climbing the directory tree the way an unconditional `%cd ..` would.
if not os.path.isdir("inference"):
    os.chdir("..")

from data.util.dataset import MaskedTimeseries
from inference.forecaster import TotoForecaster
from model.toto import Toto

# These lines make gpu execution in CUDA deterministic.
# CUBLAS_WORKSPACE_CONFIG has to be in place before the first cuBLAS call.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# Deterministic kernels do not seed anything: the forecast step draws random
# samples, so seed torch's RNG too for reproducible Restart & Run All results.
SEED = 42
torch.manual_seed(SEED)

# Time series forecasting with Toto

In this notebook, you'll learn how to perform inference with Toto for multivariate time series forecasting on a synthetic VPC flow-log dataset (per-minute `packets` and `bytes` counts) that the notebook generates itself. Toto is a foundation model used for *zero-shot* forecasting,
meaning no training is required. We simply provide the historical context of a time series to Toto as input, and Toto produces forecasts of the desired length.

## Prerequisites

You'll need to run this with a CUDA-capable device. In order to get the fastest inference performance, please use an Ampere or newer architecture, as these support the xFormers fused kernel implementations for SwiGLU and Memory-Efficient Attention.

Make sure you've cloned the repo and installed dependencies with `pip install -r requirements.txt`. When running this notebook, make sure the working directory is set to `<repository_root>/toto`.

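This notebook does not require any external dataset: it generates synthetic flow logs and writes them to `custom_flow_logs.csv` in the working directory. (The original version of this tutorial used the `ETT-small` dataset, which can be obtained from the [official repo](https://github.com/zhouhaoyi/ETDataset).)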

       

In [None]:
# ── 1. BASIC SETTINGS ────────────────────────────────────────────────────────
start_date   = "2025-01-01"       # inclusive
end_date     = "2025-01-07"       # exclusive: runs to 2025-01-06 23:59
freq         = "1min"             # aggregation window (≤10 min is valid for VPC logs)
rng          = np.random.default_rng(42)

src_public   = "198.51.100.1"     # static Internet sender
servers = [                       # our two targets inside the VPC
    {"dstaddr": "10.0.0.1", "eni": "eni-0123456789abcdef0"},
    {"dstaddr": "10.0.0.2", "eni": "eni-0123456789abcdef1"},
]

base_packets          = 120       # mean packets per minute at 100 % load
base_bytes_per_packet = 800       # mean TCP payload size in bytes
flow_type             = "14"      # constant for every row

# ── 2. CLEAR SINUSOIDAL PATTERN ──────────────────────────────────────────────
def traffic_factor(ts, origin=None, period_hours=24.0, amplitude=0.8):
    """Return a load multiplier that follows a clear sinusoidal cycle.

    Args:
        ts: timestamp to evaluate (anything ``pd.Timestamp`` accepts).
        origin: phase reference point. Defaults to ``start_date`` so the
            pattern stays aligned with the generated index when the dates change.
        period_hours: length of one full cycle in hours (24 = daily).
        amplitude: half of the peak-to-trough swing around the baseline 1.0.

    Returns:
        A float in ``[1 - amplitude, 1 + amplitude]`` (0.2 .. 1.8 by default).
        The trough falls on ``origin`` (midnight for the default dates) and
        the peak half a period later (noon).
    """
    origin_ts = pd.Timestamp(start_date if origin is None else origin)
    hours_since_origin = (pd.Timestamp(ts) - origin_ts).total_seconds() / 3600.0
    angle = 2 * np.pi * hours_since_origin / period_hours
    # sin(angle - π/2) == -cos(angle): -1 at the origin (trough), +1 half a period later (peak).
    return 1.0 + amplitude * np.sin(angle - np.pi / 2)

# ── 3. BUILD THE RECORDS ─────────────────────────────────────────────────────
index = pd.date_range(start=start_date, end=end_date,
                      freq=freq, inclusive="left")

rows = []
for ts in index:
    f = traffic_factor(ts)
    for srv in servers:
        # Almost no noise - just the pure sinusoidal pattern. The draw order
        # (packets, then bytes, per server) is fixed, so the seeded output is reproducible.
        pkts  = max(1, int(base_packets * f * (1 + rng.normal(0, 0.005))))  # Tiny noise: 0.5%
        bytes_ = pkts * int(base_bytes_per_packet *
                            (1 + rng.normal(0, 0.002)))  # Even tinier noise: 0.2%

        rows.append({
            "timestamp"    : ts.isoformat(),
            "interface-id" : srv["eni"],
            "srcaddr"      : src_public,
            "dstaddr"      : srv["dstaddr"],
            "type"         : flow_type,
            "packets"      : pkts,
            "bytes"        : bytes_,
        })

df = pd.DataFrame(rows)

# ── 4. SAVE & QUICK PEEK ────────────────────────────────────────────────────
df.to_csv("custom_flow_logs.csv", index=False)
df.head()


## Preprocess data

In the following section, we prepare the data in the expected input format of Toto.

Toto expects inputs to be multivariate time series data in the shape

$\text{Variate} \times \text{Time Steps}$

or, with optional batch dimension:

$\text{Batch} \times \text{Variate} \times \text{Time Steps}$

For illustration, we'll try to predict the last 336 steps of the flow-log series across its 2 variates (`packets` and `bytes`), using the preceding 4096 steps as context. Shorter contexts (e.g. 2048) run faster and often give a good balance of speed vs. performance; you may want to experiment with different context lengths depending on your dataset. Toto was trained with a max context length of 4096, but can extrapolate to even longer contexts.


In [None]:
# Forecast window: Toto sees `context_length` past steps and predicts the next `prediction_length`.
context_length, prediction_length = 4096, 336

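Load the generated flow logs and add the `date` and `timestamp_seconds` columns used below; the next cell slices them into context and target windows.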

In [None]:
# The CSV holds one row per (minute, interface), so the two servers' records are
# interleaved. The positional slicing in the next cell needs exactly one row per
# time step, so keep a single interface's series.
interface_id = servers[0]["eni"]  # forecast the first server's traffic

df = (
    pd.read_csv("custom_flow_logs.csv")
    .loc[lambda d: d["interface-id"] == interface_id]
    .assign(date=lambda d: pd.to_datetime(d["timestamp"]))
    # Integer seconds since the Unix epoch, the format Toto's timestamp features use.
    .assign(timestamp_seconds=lambda d: (d["date"] - pd.Timestamp("1970-01-01")) // pd.Timedelta("1s"))
    .drop(columns=["timestamp"])  # superseded by `date` / `timestamp_seconds`
    .sort_values("date")
    .reset_index(drop=True)
)
# Put `date` first for readability.
df = df[["date"] + [col for col in df.columns if col != "date"]]

assert df["date"].is_unique, "expected exactly one row per time step"

df.head()

In [None]:
feature_columns = ["packets", "bytes"]
n_variates = len(feature_columns)
# Sampling interval in seconds, derived from the generator's `freq` ("1min" -> 60)
# so it always matches the data's actual resolution.
interval = int(pd.Timedelta(freq).total_seconds())
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The last `prediction_length` rows are held out as ground truth; Toto sees the
# `context_length` rows immediately before them.
assert len(df) >= context_length + prediction_length, (
    f"need at least {context_length + prediction_length} rows, got {len(df)}"
)
input_df = df.iloc[-(context_length+prediction_length):-prediction_length]
target_df = df.iloc[-prediction_length:]

# Transpose (time, feature) -> (variate, time), the layout Toto expects.
input_series = torch.from_numpy(input_df[feature_columns].values.T).to(torch.float).to(DEVICE)
input_series.shape


Add timestamp features to the data. Note: the current version of Toto does not use these features; it handles series of different time resolutions implicitly. However, future versions may take this into account, so the API expects timestamps to be passed in.

In [None]:
# Every variate shares the same epoch-second timestamps, so the 1-D column is
# broadcast to (n_variates, context_length) with `expand`.
timestamp_seconds = (
    torch.from_numpy(input_df["timestamp_seconds"].to_numpy())
    .expand(n_variates, context_length)
    .to(input_series.device)
)
# One sampling interval (in seconds) per variate.
time_interval_seconds = torch.full((n_variates,), interval, device=input_series.device)
start_timestamp_seconds = timestamp_seconds[:, 0]

Toto expects its inputs in the form of a `MaskedTimeseries` dataclass.

In [None]:
# The padding mask has the series' shape: 1/True marks a valid value, 0/False
# marks padding. Nothing is padded here, so every position is True.
valid_mask = torch.ones_like(input_series, dtype=torch.bool)
# The ID mask packs unrelated series along the Variate dimension: variates with
# different IDs cannot attend to each other (channel-wise attention). Packing is
# used in training and large-scale batch inference; without packing, use all zeros.
shared_ids = torch.zeros_like(input_series)

inputs = MaskedTimeseries(
    series=input_series,
    padding_mask=valid_mask,
    id_mask=shared_ids,
    # Timestamp features are not used by the current model; they are reserved
    # for future releases.
    timestamp_seconds=timestamp_seconds,
    time_interval_seconds=time_interval_seconds,
)

Now our data is ready!

## Load Toto checkpoint

Load the pretrained Toto checkpoint (`Datadog/Toto-Open-Base-1.0`) from Hugging Face with `Toto.from_pretrained`.

In [None]:
# Set to False to skip compilation, e.g. for a single one-off forecast where the
# compile warm-up may cost more than it saves.
COMPILE_MODEL = True

toto = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
toto.to(DEVICE)

# Optionally enable Torch's JIT compilation to speed up inference. This is mainly
# helpful if you want to perform repeated inference, as the JIT compilation can
# take time to warm up.
if COMPILE_MODEL:
    toto.compile()

We generate multistep, autoregressive forecasts using the `TotoForecaster` class. 

In [None]:
# Build a batched input (batch=1, n_variates, time_steps) from the context window
# and sample forecasts. `DEVICE` and `interval` come from the slicing cell, so the
# inputs land on the same device the model was moved to.
seq_len = context_length                # length of the context slice

# ── series  (1, n_variates, seq_len) ───────────────────────────────────────
series_tensor = (
    torch.tensor(input_df[feature_columns].values.T, dtype=torch.float32)
    .unsqueeze(0)                       # add batch dim
    .to(DEVICE)
)

# ── timestamp_seconds  (1, n_variates, seq_len) ────────────────────────────
# Use the epoch-seconds column. `input_df.index` is an integer row index, not
# a DatetimeIndex, so converting it to nanoseconds and dividing by 1e9 gave 0
# for every step.
timestamp_seconds = (
    torch.from_numpy(input_df["timestamp_seconds"].to_numpy())
    .long()
    .unsqueeze(0).unsqueeze(0)          # batch, variate
    .repeat(1, n_variates, 1)
    .to(DEVICE)
)

# ── time_interval_seconds  (1, n_variates) ─────────────────────────────────
time_interval_seconds = torch.full(
    (1, n_variates),
    fill_value=interval,
    dtype=torch.long,
    device=DEVICE,
)

# ── Batch object in the input format Toto expects ──────────────────────────
inputs = MaskedTimeseries(
    series=series_tensor,
    # Every position is an observed value, so nothing is marked as padding.
    padding_mask=torch.ones_like(series_tensor, dtype=torch.bool),
    # No packing of unrelated series: all variates share ID 0.
    id_mask=torch.zeros_like(series_tensor),
    timestamp_seconds=timestamp_seconds,
    time_interval_seconds=time_interval_seconds,
)

# ── Forecast ───────────────────────────────────────────────────────────────
forecaster = TotoForecaster(toto.model)
forecast = forecaster.forecast(
    inputs,
    prediction_length=prediction_length,
    num_samples=256,
    samples_per_batch=256,
    use_kv_cache=True,
)

print("forecast mean shape:", forecast.mean.shape)  # expected (1, n_variates, prediction_length)


## Visualize the forecasts

We can plot our forecasts and confidence intervals against the ground truth:

In [None]:
DARK_GREY = "#1c2b34"
BLUE = "#3598ec"
PURPLE = "#7463e1"
LIGHT_PURPLE = "#d7c3ff"
PINK = "#ff0099"

# Number of context steps shown before the forecast start, clamped to what exists.
history_to_plot = min(960, len(input_df))

matplotlib.rc("axes", edgecolor=DARK_GREY)
fig = plt.figure(figsize=(12, 6), layout="tight", dpi=150)
plt.suptitle("Toto Forecasts (VPC flow logs)")

# The alpha / 1 - alpha sample quantiles (a central 90 % band) are the same for
# every subplot, so compute them once instead of once per loop iteration.
alpha = 0.05
qs = forecast.samples.quantile(q=torch.tensor([alpha, 1 - alpha], device=forecast.samples.device), dim=-1)

for i, feature in enumerate(feature_columns):
    # Configure axes: one stacked subplot per variate.
    plt.subplot(n_variates, 1, i + 1)
    if i != n_variates - 1:
        # only show x tick labels on the bottom subplot
        fig.gca().tick_params(axis="x", labelbottom=False)
    fig.gca().tick_params(axis="x", color=DARK_GREY, labelcolor=DARK_GREY)
    fig.gca().tick_params(axis="y", color=DARK_GREY, labelcolor=DARK_GREY)
    fig.gca().yaxis.set_label_position("right")
    plt.ylabel(feature)
    plt.xlim(input_df.date.iloc[-history_to_plot], target_df.date.iloc[-1])
    # Dotted line marks where the context ends and the forecast begins.
    plt.axvline(target_df.date.iloc[0], color=PINK, linestyle=":")

    # Plot ground truth (context, then the held-out target)
    plt.plot(input_df["date"], input_df[feature], color=BLUE)
    plt.plot(target_df["date"], target_df[feature], color=BLUE)

    # Plot point forecasts: per-step median over the samples (last dim)
    plt.plot(
        target_df["date"],
        np.median(forecast.samples.squeeze()[i].cpu(), axis=-1),
        color=PURPLE,
        linestyle="--",
    )

    # Plot quantiles
    plt.fill_between(
        target_df["date"],
        qs[0].squeeze()[i].cpu(),
        qs[1].squeeze()[i].cpu(),
        color=LIGHT_PURPLE,
        alpha=0.8,
    )