In [None]:
import torch
from data.util.dataset import MaskedTimeseries
from inference.forecaster import TotoForecaster
from model.toto import Toto

# Device that the model and all input tensors are placed on.
# Switch this to 'cuda' to run on a GPU instead.
DEVICE = 'cpu'

# Fetch the pre-trained Toto checkpoint, then move it onto DEVICE.
toto = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
toto = toto.to(DEVICE)

# Optional step: compile the model for enhanced speed.
toto.compile()

# The forecaster wraps the inner `model` attribute, not the Toto wrapper itself.
forecaster = TotoForecaster(toto.model)

# Seed the global torch RNG so the synthetic series below is identical on
# every run; without it each Restart & Run All yields a different input.
# NOTE(review): whether this also makes the sampled forecasts reproducible
# depends on how the forecaster draws its samples - confirm.
torch.manual_seed(0)

# Example input series (7 variables, 4096 timesteps).
# Drawn without a device argument and then moved, so the seeded values are
# the same whichever DEVICE is selected (assumes torch's default device is
# the CPU, i.e. it has not been overridden elsewhere).
input_series = torch.randn(7, 4096).to(DEVICE)

# All-zero timestamps with the same shape, dtype and device as the series.
timestamp_seconds = torch.zeros_like(input_series)

# One sampling interval per variable: 15 minutes expressed in seconds.
# Allocated directly on DEVICE to avoid an extra host-to-device copy.
time_interval_seconds = torch.full(
    (input_series.shape[0],), 60 * 15, device=DEVICE
)

inputs = MaskedTimeseries(
    series=input_series,
    # Boolean mask that is True everywhere - presumably marks every timestep
    # as real (non-padding) data; confirm against MaskedTimeseries.
    padding_mask=torch.ones_like(input_series, dtype=torch.bool),
    # All zeros - presumably assigns every variable to the same series id;
    # TODO confirm against MaskedTimeseries.
    id_mask=torch.zeros_like(input_series),
    timestamp_seconds=timestamp_seconds,
    time_interval_seconds=time_interval_seconds,
)

# Forecast the next 336 timesteps, drawing 256 samples in a single batch
# (samples_per_batch equals num_samples).
forecast = forecaster.forecast(
    inputs, prediction_length=336, num_samples=256, samples_per_batch=256
)

# Pull the results off the forecast object: the median, the raw samples,
# and the 0.1 / 0.9 quantiles (evaluated in that order).
median_prediction, prediction_samples = forecast.median, forecast.samples
lower_quantile, upper_quantile = (forecast.quantile(q) for q in (0.1, 0.9))


In [8]:
# Inspect the shape of the median forecast (recorded output: torch.Size([1, 7, 336])).
median_prediction.shape

torch.Size([1, 7, 336])

In [5]:
# Display the loaded model's repr to inspect its module tree.
toto

Toto(
  (model): TotoBackbone(
    (patch_embed): PatchEmbedding(
      (projection): Linear(in_features=64, out_features=768, bias=True)
    )
    (transformer): Transformer(
      (rotary_emb): TimeAwareRotaryEmbedding()
      (layers): ModuleList(
        (0-10): 11 x TransformerLayer(
          (norm1): RMSNorm()
          (norm2): RMSNorm()
          (attention): TimeWiseMultiheadAttention(
            (rotary_emb): TimeAwareRotaryEmbedding()
            (wQKV): Linear(in_features=768, out_features=2304, bias=True)
            (wO): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp): Sequential(
            (0): Linear(in_features=768, out_features=6144, bias=True)
            (1): SwiGLU()
            (2): Linear(in_features=3072, out_features=768, bias=True)
            (3): Dropout(p=0.1, inplace=False)
          )
        )
        (11): TransformerLayer(
          (norm1): RMSNorm()
          (norm2): RMSNorm()
          (attention): SpaceWiseMul

Now on 