In [1]:
# -- Cell 1 -----------------------------------------------------------
# Environment setup: locate the Toto checkout, expose its inner package
# directory on sys.path, and force Hugging Face libraries into offline mode.
import os, sys, pathlib

# ✏️ adjust these two paths if your layout differs
REPO_ROOT = pathlib.Path("~/toto").expanduser()    # outer repo
INNER_DIR = REPO_ROOT.joinpath("toto")             # inner package

# Prepend the inner directory (only once, so re-running the cell is
# idempotent) so that top-level names such as 'model.' and 'data.' resolve.
inner_str = str(INNER_DIR)
if inner_str not in sys.path:
    sys.path.insert(0, inner_str)

# Set before any Hugging Face library is imported: they may read these
# variables at import time, and we must never reach the network.
os.environ.update({"HF_HUB_OFFLINE": "1", "HF_DATASETS_OFFLINE": "1"})

print("✓ sys.path patched, offline mode activated")


✓ sys.path patched, offline mode activated


In [2]:
# -- Cell 2 -----------------------------------------------------------
# Load the pretrained Toto checkpoint from a local folder onto the GPU.
from toto.model.toto import Toto

CKPT_DIR = pathlib.Path("~/Toto-Open-Base-1.0").expanduser()  # ↙ your folder

# Fail fast with a clear message: with HF_HUB_OFFLINE=1 a non-existent
# local path would otherwise presumably be treated as a hub repo id and
# fail with a far less obvious error.
if not CKPT_DIR.is_dir():
    raise FileNotFoundError(f"Toto checkpoint folder not found: {CKPT_DIR}")

toto = Toto.from_pretrained(str(CKPT_DIR)).to("cuda")
# Make sure dropout/etc. are in inference mode. from_pretrained may already
# do this (not visible here); calling eval() again is harmless.
toto.eval()
toto.compile()          # optional speed-up on PyTorch ≥2.0

n_params_m = sum(p.numel() for p in toto.parameters()) / 1e6
print("✓ Toto loaded — #params:", n_params_m, "M")




Loading weights from local directory
✓ Toto loaded — #params: 151.30608 M


In [3]:
# -- Cell 3 -----------------------------------------------------------
# Smoke-test the forecaster on a random multivariate series.
import torch
# Import via the `toto` package (like Cell 2 and the forecaster) rather than the
# bare `data.` path: otherwise the same file is loaded twice under two module
# names, yielding two distinct MaskedTimeseries classes.
from toto.data.util.dataset import MaskedTimeseries
from toto.inference.forecaster import TotoForecaster

DEVICE = "cuda"
SEED = 42                    # reproducible dummy input (and, presumably, sampling)
N_VARS, N_STEPS = 7, 4096    # variables × context length of the dummy input
PREDICTION_LENGTH = 336
NUM_SAMPLES = 128

torch.manual_seed(SEED)      # seeds the RNG on all devices, CUDA included

forecaster = TotoForecaster(toto.model)

# dummy N_VARS × N_STEPS input ─ replace with your real series
input_series = torch.randn(N_VARS, N_STEPS, device=DEVICE)
timestamp_seconds = torch.zeros_like(input_series)
# one sampling interval per variable: 15-min grid
time_interval_seconds = torch.full((N_VARS,), 60 * 15, device=DEVICE)

inputs = MaskedTimeseries(
    series=input_series,
    # every step flagged True — presumably "real (non-padded) data"; confirm
    padding_mask=torch.ones_like(input_series, dtype=torch.bool),
    # identical ids for all variables — assumption: they form one group; verify
    id_mask=torch.zeros_like(input_series),
    timestamp_seconds=timestamp_seconds,
    time_interval_seconds=time_interval_seconds,
)

forecast = forecaster.forecast(
    inputs,
    prediction_length=PREDICTION_LENGTH,
    num_samples=NUM_SAMPLES,
    # presumably the chunk size for drawing samples (memory/speed trade-off);
    # using num_samples draws them all at once — check TotoForecaster docs
    samples_per_batch=NUM_SAMPLES,
)

# Output carries a leading batch dim of 1 even though the input is unbatched.
print("Forecast median shape:", forecast.median.shape)   # (1, 7, 336)


Forecast median shape: torch.Size([1, 7, 336])


In [1]:
import pandas as pd

# Per-sample embeddings saved by the SeisLM/Toto evaluation run
# (columns shown below: station, period_start, period_end, label, embedding).
# Path is relative to the notebook's working directory.
# NOTE: read_pickle unpickles arbitrary objects — only load files you produced.
EMBEDDINGS_PKL = '01_Seismic_Wave_Data_Prediction/03_Results/seislm_toto_20250610_191031/evaluation/SeisLMForMagnitudePrediction_sample_embeddings.pkl'
data = pd.read_pickle(EMBEDDINGS_PKL)
data.head()

Unnamed: 0,station,period_start,period_end,label,embedding
0,BRE,2023-09-23T00:00:00.019539Z,2023-10-02T23:59:59.994539Z,2,"[-0.15884927, -0.88898444, -3.1738896, 1.50697..."
1,BRE,2024-12-11T00:00:00.019538Z,2024-12-20T23:59:59.994538Z,1,"[-0.17559603, -0.8910637, -3.1793027, 1.484997..."
2,MAN,2024-12-06T00:00:00.004200Z,2024-12-15T23:59:59.979200Z,2,"[-0.0833036, -0.7336906, -2.6597593, 1.167067,..."
3,PASC,2022-09-13T00:00:00.019538Z,2022-09-22T23:59:59.994538Z,2,"[0.29382408, -1.1622548, -3.103171, 1.6138947,..."
4,MAN,2024-05-15T00:00:00.000000Z,2024-05-24T23:59:59.975000Z,0,"[0.017347943, -0.45573452, -2.0281405, 0.61933..."


In [4]:
# Dimensionality of one embedding vector. `.iloc[0]` is positional, so this keeps
# working after filtering/re-indexing (plain `[0]` is a label lookup).
data['embedding'].iloc[0].shape

(768,)