|              #              | TensorFlow (Keras)                                                                                                   | PyTorch Equivalent                                                                                                                                                                       |
| :-------------------------: | -------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|              1.             | **Config & seeds**<br>`config = yaml…`<br>`Config.set_sub_day(...)`<br>`tf.random.set_seed(…)`, `np.random.seed(…)`  | • Parse the same `config.yaml` (via `yaml.safe_load`)  <br>• Apply any “sub\_day” flags <br>• `torch.manual_seed(…)`, `np.random.seed(…)`                                                |
|              2.             | **Load TFRecords**<br>`load_tf_dataset(prefix+"/daily.tfrecords")` (×3)<br>`choose_from_datasets([...])`             | • Prepare `daily.pt`,`weekly.pt`,`monthly.pt`  <br>• `SeriesDataset(pt_paths)` interleaves them in `__iter__`                                                                   |
|              3.             | **Frame & task generation**<br>`.map(build_frames)` → sliding windows + single‐point Δ<br>`.map(gen_random_…)`, etc. | • Port `build_frames()` → a helper that takes a `[seq_len]` tensor and returns `(window, query_ts, label)`  <br>• Apply three generators in Python (e.g. pick one at random per example) |
|              4.             | **Filter & remove noise**<br>`.filter(filter_unusable_points)`<br>`.map(remove_noise)`                               | • In your Dataset iterator, drop invalid windows (e.g. too short)  <br>• Subtract out noise from the label before yielding                                                               |
|              5.             | **Batch, shuffle, prefetch**<br>`train_df.shuffle(...).batch(1024).prefetch(AUTO)`<br>`test_df.batch(1024)`          | • Wrap `SeriesDataset` in `DataLoader(..., batch_size=1024, shuffle=…)`  <br>• Use a custom `collate_fn` if you need to assemble `(window, query_ts)` into batched tensors               |
|              6.             | **Model instantiation & compile**<br>`model = TransformerModel(...)`<br>`model.compile(Adam, MSE, metrics…)`         | • `model = MyModel(...).to(device)`<br>• `loss_fn = nn.MSELoss()`<br>• `opt = torch.optim.Adam(model.parameters(), lr=…)`                                                                |
|              7.             | **Warm-up forward pass**<br>`_ = model(batch_X)`                                                                     | • (Optional) grab one batch from DataLoader and run `model(window, query_ts)` to ensure shapes align                                                                                |
|              8.             | **Callbacks**<br>`ModelCheckpoint`, `TensorBoard`, etc.                                                              | • After each epoch (or best-val), call `torch.save(model.state_dict(), path)`<br>• (Optional) hook up `torch.utils.tensorboard.SummaryWriter`                                            |
|              9.             | **Train loop**<br>`model.fit(..., epochs=700, steps_per_epoch=10)`                                                   | • Standard loop|
|             10.             | **Save final model**<br>`model.save(prefix+"models/"+name)`                            | • `torch.save(model.state_dict(), prefix+"/models/"+name+".pt")`                |


In [6]:
DATA_DIR = f"../synthetic-data/"

In [11]:
import sys
import os

# Go up one level to project root
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_root)

In [7]:
# data_pipeline.py

import random
import torch
from torch.utils.data import IterableDataset
from prepare_dataset import (
    build_frames,
    gen_random_single_point,    gen_mean_to_random_date,    gen_std_to_random_date,
    gen_random_single_point_no_noise, gen_mean_to_random_date_no_noise, gen_std_to_random_date_no_noise,
    filter_unusable_points,
    remove_noise,
)

class RawCombinedDataset(IterableDataset):
    """Interleave daily, weekly, monthly raw series forever."""
    def __init__(self, pt_paths):
        super().__init__()
        self.blocks = [torch.load(p) for p in pt_paths]
        self.num_blocks = len(self.blocks)

    def __iter__(self):
        block_idx = 0
        while True:
            block = self.blocks[block_idx]
            for i in range(block["y"].size(0)):
                yield {
                    "ts":    block["ts"][i],    # [seq_len]
                    "y":     block["y"][i],     # [seq_len]
                    "noise": block["noise"][i], # [seq_len]
                }
            block_idx = (block_idx + 1) % self.num_blocks


class HeadRawDataset(IterableDataset):
    """Take only the first `n` raw series, then stop."""
    def __init__(self, raw_ds, n):
        super().__init__()
        self.raw_ds = raw_ds
        self.n = n

    def __iter__(self):
        it = iter(self.raw_ds)
        for _ in range(self.n):
            yield next(it)


class TailRawDataset(IterableDataset):
    """Skip the first `n` raw series, then interleave the rest forever."""
    def __init__(self, raw_ds, skip):
        super().__init__()
        self.raw_ds = raw_ds
        self.skip = skip

    def __iter__(self):
        it = iter(self.raw_ds)
        # drop first `skip` examples
        for _ in range(self.skip):
            next(it)
        # then yield the remainder forever
        for rec in it:
            yield rec


class FramedDataset(IterableDataset):
    """
    Takes a raw‐series IterableDataset, applies build_frames,
    randomly picks one of the 3 tasks (with or without noise),
    filters, then removes noise fields.
    """
    def __init__(self, raw_ds: IterableDataset, test_noise: bool):
        super().__init__()
        self.raw_ds = raw_ds
        # train always uses noise‐tasks; test_noise flag picks noise vs no‐noise
        self.train_tasks = [
            gen_random_single_point,
            gen_mean_to_random_date,
            gen_std_to_random_date,
        ]
        self.test_tasks = (
            self.train_tasks
            if test_noise
            else [
                gen_random_single_point_no_noise,
                gen_mean_to_random_date_no_noise,
                gen_std_to_random_date_no_noise,
            ]
        )

    def __iter__(self):
        # We assume that whoever instantiates this knows if it's for train or test,
        # so we'll just pick from test_tasks here.
        for rec in self.raw_ds:
            di, hist, noise, td, tv, tn = build_frames(rec)
            task_fn = random.choice(self.test_tasks)
            X, y = task_fn(di, hist, noise, td, tv, tn)
            if not filter_unusable_points(X, y):
                continue
            X_batch, y_batch = remove_noise(X, y)
            B = y_batch.shape[0]
            for j in range(B):
                Xj = {
                  "ts":        X_batch["ts"][j],       # [HISTORY_LEN,5]
                  "history":   X_batch["history"][j],  # [HISTORY_LEN,1]
                  "target_ts": X_batch["target_ts"][j],# [5]
                  "task":      X_batch["task"][j],     # scalar
                }
                yj = y_batch[j]                         # scalar
                yield Xj, yj


def get_combined_ds():
    """
    Equivalent to your TF get_combined_ds(config).
    """
    paths = [DATA_DIR + f + ".pt" for f in ("daily", "weekly", "monthly")]
    return RawCombinedDataset(paths)

def create_train_test_dataset(combined_ds, test_noise):
    test_ds  = FramedDataset(HeadRawDataset(combined_ds, 30), test_noise)
    train_ds = FramedDataset(TailRawDataset(combined_ds, 30), True)
    return train_ds, test_ds

In [8]:
from torch.utils.data import DataLoader

combined_ds = get_combined_ds()
train_ds, test_ds = create_train_test_dataset(combined_ds, False)

train_loader = DataLoader(train_ds, batch_size=1024, num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=1024, num_workers=0)

In [9]:
# sanity check:
batch_X, batch_y = next(iter(train_loader))
print({k: v.shape for k, v in batch_X.items()}, batch_y.shape)

{'ts': torch.Size([1024, 100, 5]), 'history': torch.Size([1024, 100, 1]), 'target_ts': torch.Size([1024, 5]), 'task': torch.Size([1024])} torch.Size([1024])


In [15]:
class ShuffleBuffer(IterableDataset):
    """
    Replays an IterableDataset through a rolling buffer of size `buf_sz`,
    yielding random items from it. On __iter__ it reshuffles.
    """
    def __init__(self, ds: IterableDataset, buf_sz: int):
        super().__init__()
        self.ds = ds
        self.buf_sz = buf_sz

    def __iter__(self):
        it = iter(self.ds)
        buf = []
        for item in it:
            buf.append(item)
            if len(buf) >= self.buf_sz:
                j = random.randrange(len(buf))
                yield buf.pop(j)
        # flush remainder
        while buf:
            j = random.randrange(len(buf))
            yield buf.pop(j)


In [16]:
train_shuffled = ShuffleBuffer(train_ds, buf_sz=5000)

train_loader = DataLoader(
    train_shuffled,
    batch_size=1024,
    num_workers=0,
)

In [None]:
from torch.utils.tensorboard import SummaryWriter
import datetime

# before training
run_id = datetime.datetime.now().strftime("%Y-%m-%d/%H:%M:%S")
writer = SummaryWriter(log_dir=os.path.join(DATA_DIR, "logs", run_id))
ckpt_dir = os.path.join(DATA_DIR, "models", run_id, "ckpts")
print(run_id)

2025-06-21/15:34:39


In [20]:
import numpy as np
import torch.nn as nn
import torch.optim as optim

from models.ForecastPFN import ForecastPFN

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

os.makedirs(ckpt_dir, exist_ok=True)

# Model, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ForecastPFN(scaler='robust').to(device)

loss_fn = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=1e-4)

# # Warm-up forward (avoids “first call” issues)
# batch_X, _ = next(iter(train_loader))
# _ = model(batch_X)

# Training loop
epochs = 700                                  # same as TF.fit(epochs=700)
steps_per_epoch = 10                          # same as TF.fit(steps_per_epoch=10)

for eps in range(1, epochs + 1):
    model.train()
    sum_loss = 0.0
    sum_mape = 0.0
    sum_smape= 0.0
    it = iter(train_loader)
    for step in range(steps_per_epoch):
        Xb, yb = next(it)
        Xb = {k: v.to(device) for k,v in Xb.items()}
        yb = yb.to(device)

        Xb["history"]   = Xb["history"].squeeze(-1)      # [B, T]
        Xb["target_ts"] = Xb["target_ts"].unsqueeze(1)   # [B, 1, 5]

        opt.zero_grad()
        out = model(Xb)
        pred = out['result'].view(-1)
        loss   = loss_fn(pred, yb)
        loss.backward()
        opt.step()

        sum_loss += loss.item()

        # MAPE
        mape_batch = (torch.abs((pred - yb) / (yb.clamp(min=eps)))).mean().item() * 100
        sum_mape += mape_batch

        # SMAPE
        smape_batch = (2 * torch.abs(pred - yb) /
                       (pred.abs() + yb.abs() + eps)
                      ).mean().item() * 100
        sum_smape += smape_batch

    avg_mse   = sum_loss   / steps_per_epoch
    avg_mape  = sum_mape   / steps_per_epoch
    avg_smape = sum_smape  / steps_per_epoch

    # Validation
    model.eval()
    v_loss = v_mape = v_smape = 0.0
    val_batches = 0
    with torch.no_grad():
        for Xv, yv in test_loader:
            val_batches += 1
            Xv = {k: v.to(device) for k,v in Xv.items()}
            Xv["history"]   = Xv["history"].squeeze(-1)
            Xv["target_ts"] = Xv["target_ts"].unsqueeze(1)
            yv = yv.to(device)

            out = model(Xv)
            pred = out['result'].view(-1)

            l     = loss_fn(pred, yv).item()
            v_loss  += l
            v_mape += (torch.abs((pred - yv) / (yv.clamp(min=eps)))
                       ).mean().item() * 100
            v_smape+= (2 * torch.abs(pred - yv) /
                       (pred.abs() + yv.abs() + eps)
                       ).mean().item() * 100

    avg_val_mse   = v_loss   / val_batches
    avg_val_mape  = v_mape   / val_batches
    avg_val_smape = v_smape  / val_batches

    # log & print
    writer.add_scalar("MSE/train",   avg_mse,   eps)
    writer.add_scalar("MAPE/train",  avg_mape,  eps)
    writer.add_scalar("SMAPE/train", avg_smape, eps)
    writer.add_scalar("MSE/val",     avg_val_mse,   eps)
    writer.add_scalar("MAPE/val",    avg_val_mape,  eps)
    writer.add_scalar("SMAPE/val",   avg_val_smape, eps)

    print(f"{datetime.datetime.now():%Y-%m-%d %H:%M:%S}"
          f"  Epoch {eps:3d}/{epochs:3d}: "
          f"train MSE={avg_mse:.4f}, MAPE={avg_mape:.2f}%, SMAPE={avg_smape:.2f}%  "
          f"val MSE={avg_val_mse:.4f}, MAPE={avg_val_mape:.2f}%, SMAPE={avg_val_smape:.2f}%"
    )


# Save final model
out_dir = os.path.join(DATA_DIR, "models")
os.makedirs(out_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(out_dir, "model2.pt"))
print(f"Model saved to {out_dir}")

2025-06-21 15:36:10  Epoch   1/700: train MSE=4739.5533, MAPE=142.81%, SMAPE=74.83%  val MSE=257188.1128, MAPE=456.25%, SMAPE=59.67%
2025-06-21 15:36:22  Epoch   2/700: train MSE=3274.4896, MAPE=85.24%, SMAPE=44.24%  val MSE=324205.4708, MAPE=362.83%, SMAPE=55.61%
2025-06-21 15:36:35  Epoch   3/700: train MSE=2890.0188, MAPE=51.72%, SMAPE=33.84%  val MSE=185391.2318, MAPE=847.45%, SMAPE=43.30%
2025-06-21 15:36:46  Epoch   4/700: train MSE=2374.8565, MAPE=63.76%, SMAPE=32.33%  val MSE=192137.8758, MAPE=727.08%, SMAPE=39.81%
2025-06-21 15:36:58  Epoch   5/700: train MSE=3649.1847, MAPE=37.86%, SMAPE=24.74%  val MSE=305979.6522, MAPE=621.37%, SMAPE=36.18%
2025-06-21 15:37:10  Epoch   6/700: train MSE=1897.1645, MAPE=37.26%, SMAPE=20.92%  val MSE=176226.7023, MAPE=425.02%, SMAPE=33.27%
2025-06-21 15:37:22  Epoch   7/700: train MSE=653.5235, MAPE=26.72%, SMAPE=20.48%  val MSE=170534.7986, MAPE=433.93%, SMAPE=33.23%
2025-06-21 15:37:35  Epoch   8/700: train MSE=404.5187, MAPE=27.50%, SMAPE=1