# Import libraries

In [2]:
# General libraries
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import torch
from tqdm.notebook import tqdm
import warnings
import yaml

# Foundation models and helpers (TiRex, Chronos, TabPFN)
from chronos import ChronosPipeline
from tabpfn import TabPFNClassifier
from tirex import load_model, ForecastModel, TiRexZero
from tirex_util import load_tirex_from_checkpoint

warnings.filterwarnings("ignore")

In [3]:
# === Base class ===
class BaseForecastModel:
    """Common interface implemented by every forecasting wrapper in this notebook."""

    def forecast(self, context, horizon):
        """Forecast ``horizon`` steps following ``context``.

        Implementations return ``(quantiles, mean)`` where ``quantiles`` has
        shape ``[1, 9, horizon]`` and ``mean`` has shape ``[1, horizon]``.
        """
        raise NotImplementedError


# === TiRex ===
class TiRexModel(BaseForecastModel):
    """Wrapper around the TiRex zero-shot forecaster.

    Weights come from a local checkpoint when ``checkpoint`` names an existing
    file, otherwise from the ``NX-AI/TiRex`` model id.
    """

    def __init__(self, checkpoint=None):
        if checkpoint and os.path.isfile(checkpoint):
            self.model = load_tirex_from_checkpoint(checkpoint_path=checkpoint, model_id="TiRex")
        else:
            self.model = load_model("NX-AI/TiRex")

    def forecast(self, context, horizon):
        """Forecast ``horizon`` steps after ``context``.

        Returns:
            (quantiles, mean) with quantiles laid out as ``[1, 9, horizon]``
            (the BaseForecastModel contract) and mean as returned by TiRex.
        """
        quantiles, mean = self.model.forecast(context, prediction_length=horizon)
        # TiRex reports quantiles with the quantile axis last, i.e.
        # (batch, prediction_length, n_quantiles) — per its API docs; re-check
        # if the library changes. Swap to (batch, n_quantiles, prediction_length)
        # so this wrapper matches the Chronos/TabPFN wrappers. `swapaxes` works
        # for both torch tensors and numpy arrays.
        if quantiles.ndim == 3 and quantiles.shape[-2] == horizon and quantiles.shape[-1] == 9:
            quantiles = quantiles.swapaxes(-1, -2)
        return quantiles, mean


# === Chronos ===
class ChronosModel(BaseForecastModel):
    """Sample-based Chronos forecaster (``amazon/chronos-t5-base``).

    Draws ``num_samples`` sample paths and summarises them into a mean path
    and nine quantile paths.
    """

    #: Quantile levels reported by ``forecast`` (index 4 is the median).
    QUANTILE_LEVELS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

    def __init__(self, num_samples=20):
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        self.model = ChronosPipeline.from_pretrained("amazon/chronos-t5-base", device_map=device)
        self.num_samples = num_samples

    def forecast(self, context, horizon):
        """Forecast ``horizon`` steps after a single 1-D ``context`` series.

        Returns:
            (quantiles, mean): tensors of shape (1, 9, horizon) and (1, horizon).
        """
        ts = torch.tensor(context, dtype=torch.float32)
        if ts.ndim == 1:
            ts = ts.unsqueeze(0)  # add the series (batch) axis expected by the pipeline

        samples = self.model.predict(ts, prediction_length=horizon, num_samples=self.num_samples)
        samples = torch.as_tensor(samples).detach().cpu().float()
        # ChronosPipeline.predict returns (n_series, num_samples, prediction_length).
        # Keep the single series so that dim 0 indexes samples; the previous
        # squeeze(1) targeted the samples axis and left the series axis in place.
        if samples.ndim == 3:
            samples = samples[0]

        mean = samples.mean(dim=0, keepdim=True)  # (1, horizon)
        levels = torch.tensor(self.QUANTILE_LEVELS, dtype=samples.dtype)
        quantiles = torch.quantile(samples, levels, dim=0).unsqueeze(0)  # (1, 9, horizon)

        return quantiles, mean

# === TabPFN  ===
class TabPFNModel(BaseForecastModel):
    """
    TabPFN-based zero-shot forecaster.
    Adapts TabPFNClassifier for time-series contexts with length and memory limits.

    The series is turned into a supervised problem (the last ``window`` values
    predict the next one). The target is discretised into equal-width bins and
    TabPFN predicts a class distribution, whose expectation over the bin
    centres is the point forecast. Multi-step forecasts are produced
    recursively, feeding each prediction back into the input window.
    """

    def __init__(self, device="cpu", max_samples=5000, window=24):
        # NOTE(review): the ``device`` argument is ignored — the device is
        # always auto-detected below. Kept for interface compatibility.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        self.model = TabPFNClassifier(device=device, ignore_pretraining_limits=True)
        self.max_samples = max_samples
        self.window = window  # receptive field: number of lagged values per sample

    def forecast(self, context, horizon):
        """Recursively forecast ``horizon`` steps after a 1-D ``context``.

        Returns:
            (quantiles, mean): tensors of shape (1, 9, horizon) and (1, horizon).
            All nine quantile levels currently equal the point forecast.

        Raises:
            ValueError: if ``context`` is too short to build one training sample.
        """
        history = np.asarray(context, dtype=float)
        window = min(self.window, len(history))
        X = np.array([history[i - window:i] for i in range(window, len(history))])
        y = history[window:]
        if len(y) == 0:
            raise ValueError(
                f"TabPFNModel needs more than {window} context points, got {len(history)}"
            )

        # Evenly thin the training set to respect TabPFN's sample budget.
        if len(X) > self.max_samples:
            keep = np.linspace(0, len(X) - 1, self.max_samples).astype(int)
            X, y = X[keep], y[keep]

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        y_min, y_max = float(y.min()), float(y.max())
        if y_min == y_max:
            # Constant target: a one-class problem cannot be fitted, and the
            # only consistent forecast is the constant itself.
            mean = torch.full((1, horizon), y_min, dtype=torch.float32)
            return mean.unsqueeze(1).repeat(1, 9, 1), mean

        n_bins = min(9, max(2, int(len(y) / 5)))
        edges = np.linspace(y_min, y_max, n_bins + 1)
        # np.digitize sends y == y_max past the last edge; clip so the maximum
        # joins the top bin instead of forming its own singleton class.
        labels = np.clip(np.digitize(y, edges) - 1, 0, n_bins - 1)
        centres = (edges[:-1] + edges[1:]) / 2

        self.model.fit(X_scaled, labels)
        # predict_proba columns follow classes_ (the sorted labels seen in fit),
        # which can be a subset of all bins; map each column to its bin centre.
        fitted_labels = np.asarray(getattr(self.model, "classes_", np.unique(labels)), dtype=int)
        class_values = centres[fitted_labels]

        buffer = list(history[-window:])
        preds = []
        for _ in range(horizon):
            x_next = scaler.transform(np.asarray(buffer[-window:])[None, :])
            proba = self.model.predict_proba(x_next).mean(axis=0)
            pred = float(np.dot(proba, class_values))
            preds.append(pred)
            # Feed the prediction back so later steps condition on all earlier
            # forecasts, not just the most recent one.
            buffer.append(pred)

        mean = torch.tensor(preds, dtype=torch.float32).unsqueeze(0)  # (1, horizon)
        # TODO(review): these quantiles are degenerate (all levels equal the
        # point forecast); the per-step class distribution could supply real
        # intervals.
        quantiles = mean.unsqueeze(1).repeat(1, 9, 1)  # (1, 9, horizon)
        return quantiles, mean

In [4]:
def calculate_metrics(df, folder_path):
    """Compute MAPE and RMSE, save residual and forecast plots, and print the metrics.

    Args:
        df: DataFrame with columns ``date``, ``obs``, ``median_pred``,
            ``q10_pred`` and ``q90_pred``. The caller's frame is not modified.
        folder_path: directory (created if missing) that receives
            ``residuals.png`` and ``forecast_comparison.png``.

    Returns:
        tuple: ``(mape, rmse)``. ``mape`` is in percent and is computed over
        rows whose observation is non-zero. ``rmse`` is in the units of ``obs``.
    """
    os.makedirs(folder_path, exist_ok=True)

    # Work on a copy so the caller's frame keeps its columns and dtypes.
    df = df.copy()
    df["date"] = pd.to_datetime(df["date"])
    df["resid"] = df["obs"] - df["median_pred"]

    # Metrics. A zero observation makes the percentage error undefined
    # (division by zero -> inf), so those rows are excluded from MAPE.
    obs = df["obs"].to_numpy()
    resid = df["resid"].to_numpy()
    nonzero = obs != 0
    mape_mean = np.mean(np.abs(resid[nonzero] / obs[nonzero])) * 100
    rmse_mean = np.sqrt(np.mean(resid ** 2))

    # Plot residuals
    fig = plt.figure(figsize=(10, 5))
    sns.lineplot(data=df, x="date", y="resid", color="steelblue", linewidth=1.5)
    plt.axhline(0, color="gray", linestyle="--", linewidth=1)
    plt.title(f"Residuals over Time\nMAPE: {mape_mean:.2f}%  |  RMSE: {rmse_mean:.0f} MW", fontsize=13)
    plt.xlabel("Date")
    plt.ylabel("Residual (Observed - Predicted)")
    plt.tight_layout()
    plt.savefig(os.path.join(folder_path, "residuals.png"), dpi=200)
    plt.close(fig)

    # Observed vs predicted median, with the q10–q90 band
    fig = plt.figure(figsize=(10, 5))
    plt.plot(df["date"], df["obs"], label="Observed", color="black", linewidth=1.2)
    plt.plot(df["date"], df["median_pred"], label="Predicted (Median)", color="royalblue", linewidth=1.2)
    plt.fill_between(df["date"], df["q10_pred"], df["q90_pred"], color="lightblue", alpha=0.4)
    plt.legend()
    plt.title("Observed vs Predicted Consumption")
    plt.xlabel("Date")
    plt.ylabel("Consumption [MW]")
    plt.tight_layout()
    plt.savefig(os.path.join(folder_path, "forecast_comparison.png"), dpi=200)
    plt.close(fig)

    print('MAPE: {:.2f}%, RMSE: {:.2f} MW'.format(mape_mean, rmse_mean))
    return mape_mean, rmse_mean

def visualize_existing_results(config_files, results_dir="results_tirex", show_n=0, max_workers=4):
    """
    Fast visualization of existing results.
    - Generates forecast and residual plots for all experiments.
    - Parallelized to speed up large batches.

    Args:
        config_files (list[str]): list of YAML config paths.
        results_dir (str): root folder where results are stored.
        show_n (int): number of forecast plots to display inline after saving
            (0 = none, just save).
        max_workers (int): number of threads for parallel plotting.

    Prints one status line per config, in the order of ``config_files``.
    Failures for an individual config are reported in that output, not raised.
    """
    # pyplot's global figure registry is not thread-safe, so worker threads draw
    # on standalone Figure objects with their own Agg canvas. pyplot is used
    # only on the calling thread, for the optional inline display.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    def new_axes(figsize):
        """Create an off-screen figure with a single Axes."""
        fig = Figure(figsize=figsize)
        FigureCanvasAgg(fig)
        return fig, fig.subplots()

    def process_config(cfg):
        """Plot one experiment; return (status message, forecast PNG path or None)."""
        try:
            with open(cfg, "r") as file:
                config = yaml.safe_load(file)
            expe_name = config["expe_name"]
            out_dir = os.path.join(results_dir, expe_name)
            result_path = os.path.join(out_dir, "sequence.csv")

            if not os.path.exists(result_path):
                return f"No results for {expe_name}", None

            df = pd.read_csv(result_path)
            required = {"date", "obs", "median_pred", "q10_pred", "q90_pred"}
            if df.empty or not required.issubset(df.columns):
                return f"Invalid file for {expe_name}", None

            dates = pd.to_datetime(df["date"], errors="coerce")
            obs, pred = df["obs"].to_numpy(), df["median_pred"].to_numpy()
            resid = obs - pred

            # Metrics (MAPE skips zero observations, where it is undefined)
            nonzero_mask = obs != 0
            mape = np.mean(np.abs(resid[nonzero_mask] / obs[nonzero_mask])) * 100
            rmse = np.sqrt(np.mean(resid ** 2))

            # === Forecast plot ===
            forecast_png = os.path.join(out_dir, "forecast_comparison.png")
            fig, ax = new_axes((10, 4))
            ax.plot(dates, obs, label="Observation", color="black", linewidth=1)
            ax.plot(dates, pred, label="Median forecast", color="royalblue", linewidth=1)
            ax.fill_between(dates, df["q10_pred"], df["q90_pred"], color="lightblue", alpha=0.4)
            ax.legend()
            ax.set_title(f"{expe_name} — MAPE: {mape:.2f}% | RMSE: {rmse:.0f} MW", fontsize=11)
            ax.set_xlabel("Date")
            ax.set_ylabel("Consumption [MW]")
            fig.tight_layout()
            fig.savefig(forecast_png, dpi=150)

            # === Residual plot ===
            fig, ax = new_axes((10, 3))
            ax.plot(dates, resid, color="steelblue", linewidth=1)
            ax.axhline(0, color="gray", linestyle="--", linewidth=1)
            ax.set_title(f"Residuals — {expe_name}")
            ax.set_xlabel("Date")
            ax.set_ylabel("Residual (Obs - Pred)")
            fig.tight_layout()
            fig.savefig(os.path.join(out_dir, "residuals.png"), dpi=150)

            return f"{expe_name} done  (MAPE={mape:.2f}%, RMSE={rmse:.0f}MW)", forecast_png

        except Exception as e:  # best effort: report this config and continue
            return f"{cfg}: {e}", None

    # === Parallel execution ===
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map() yields results in submission order, so the report is deterministic.
        outcomes = list(executor.map(process_config, config_files))

    print("\n".join(message for message, _ in outcomes))

    # Inline display happens on the calling thread, for the first show_n plots only.
    to_show = [png for _, png in outcomes if png is not None][:max(0, show_n)]
    for png in to_show:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.imshow(plt.imread(png))
        ax.axis("off")
        plt.show()
        plt.close(fig)

# Number of quantile levels every wrapper reports (0.1, 0.2, ..., 0.9).
_N_QUANTILES = 9


def _quantiles_by_level(quantiles, horizon):
    """Return one forecast's quantiles as a (9, horizon) CPU float tensor, or None.

    Accepts the BaseForecastModel layout ``[1, 9, horizon]`` and also a
    quantile-last ``[1, horizon, 9]`` layout. When ``horizon == 9`` the two
    cannot be told apart, and the BaseForecastModel layout is assumed. Any
    other shape yields None so the caller can fall back.
    """
    q = torch.as_tensor(quantiles).detach().cpu().float()
    if q.ndim == 3:
        q = q[0]  # single series
    if q.ndim != 2:
        return None
    if tuple(q.shape) == (_N_QUANTILES, horizon):
        return q
    if tuple(q.shape) == (horizon, _N_QUANTILES):
        return q.T
    return None


def run_experiment(config_path, model_name="tirex", checkpoint=None):
    """
    Run forecasting for any supported foundation model (TiRex, Chronos, TabPFN).
    Compatible with zero-shot evaluation across models.

    The test period [begin_test, end_test) is forecast in consecutive blocks of
    ``horizon`` steps. Each block is conditioned on every observation from
    begin_train up to, but excluding, the block's first step.

    Args:
        config_path: YAML config with data_path, begin_train, begin_test,
            end_test, horizon and expe_name.
        model_name: "tirex", "chronos" or "tabpfn".
        checkpoint: optional TiRex checkpoint file (ignored by other models).

    Returns:
        DataFrame with columns date, obs, q10_pred, q90_pred, median_pred. It
        is also written to ``results_<model_name>/<expe_name>/sequence.csv``.

    Raises:
        ValueError: unknown model, an empty test period, or a model whose mean
            forecast does not have ``horizon`` values.
    """
    print(f"\nRunning {model_name.upper()} on config: {config_path}")
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)

    data = pd.read_csv(config["data_path"])
    data["Date_local"] = pd.to_datetime(data["Date_local"]).dt.tz_localize(None)
    begin_train = pd.to_datetime(config["begin_train"])
    begin_test = pd.to_datetime(config["begin_test"])
    end_test = pd.to_datetime(config["end_test"])
    horizon = config["horizon"]

    folder_path = os.path.join(f"results_{model_name}", config["expe_name"])
    os.makedirs(folder_path, exist_ok=True)
    print(f"Saving results to: {folder_path}")

    if model_name == "tirex":
        model = TiRexModel(checkpoint)
    elif model_name == "chronos":
        model = ChronosModel()
    elif model_name == "tabpfn":
        model = TabPFNModel(device='mps', max_samples=500)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # --- Slice data ---
    X = data[(data["Date_local"] >= begin_train) & (data["Date_local"] < end_test)].reset_index(drop=True)
    in_test = X["Date_local"] >= begin_test  # X already stops before end_test
    # History is strictly before begin_test, so no block sees its own first
    # target and forecasts line up with y_test step for step.
    hist = X.loc[~in_test, "Consommation"].to_numpy()
    y_test = X.loc[in_test, "Consommation"].to_numpy()
    dates_test = X.loc[in_test, "Date_local"].reset_index(drop=True)
    if len(y_test) == 0:
        raise ValueError(f"No test data between {begin_test} and {end_test} in {config['data_path']}")

    # --- Forecast loop ---
    mean_chunks, quantile_chunks = [], []
    for start in tqdm(range(0, len(y_test), horizon), desc=f"{model_name.upper()} forecasting"):
        ctx = np.concatenate((hist, y_test[:start]))
        quantiles, mean = model.forecast(ctx, horizon)
        mean = torch.as_tensor(mean).detach().cpu().reshape(-1).float()
        if mean.numel() != horizon:
            raise ValueError(f"{model_name} returned {mean.numel()} mean values, expected {horizon}")
        mean_chunks.append(mean)
        quantile_chunks.append(_quantiles_by_level(quantiles, horizon))

    # The blocks cover at least len(y_test) steps; trim the last one.
    n = len(y_test)
    m = torch.cat(mean_chunks)[:n]
    if all(chunk is not None for chunk in quantile_chunks):
        # Concatenate along time (dim 1), keeping the 9 quantile rows.
        q = torch.cat(quantile_chunks, dim=1)[:, :n]
    else:
        # Model quantiles have an unrecognised layout: fall back to a fixed
        # heuristic band around the mean (row 4 is the mean itself).
        factors = torch.tensor([0.9, 0.95, 0.97, 0.99, 1.0, 1.01, 1.03, 1.05, 1.1], dtype=m.dtype)
        q = factors[:, None] * m[None, :]

    # pandas raises if any column length differs, so no extra length check is needed.
    final = pd.DataFrame({
        "date": dates_test,
        "obs": y_test,
        "q10_pred": q[0].numpy(),
        "q90_pred": q[8].numpy(),
        "median_pred": q[4].numpy(),
    })

    output_file = os.path.join(folder_path, "sequence.csv")
    final.to_csv(output_file, index=False)
    print(f"Saved results: {output_file}")
    return final


# Import data

In [5]:
# Raw train/test splits as shipped with each dataset.
train_weave = pd.read_csv('data/weave/train_weave.csv')
test_weave = pd.read_csv('data/weave/test_weave.csv')
train_rfrance = pd.read_csv('data/rfrance/train2.csv')
test_rfrance = pd.read_csv('data/rfrance/test2.csv')

# Stack train above test so each dataset becomes one continuous frame with a fresh 0..n-1 index.
df_weave = pd.concat([train_weave, test_weave], axis=0, ignore_index=True)
df_rfrance = pd.concat([train_rfrance, test_rfrance], axis=0, ignore_index=True)

In [6]:
# Export Weave by site 
os.makedirs('data/weave', exist_ok=True)
nb_ids = df_weave['id_unique'].nunique()
pad = max(3, len(str(nb_ids)))

weave_out = (
    df_weave[['date', 'id_unique', 'consumption']]
    .rename(columns={'date': 'Date_local', 'consumption': 'Consommation'})
    .assign(Date_local=lambda d: pd.to_datetime(d['Date_local']))
)

for i, (_, g) in enumerate(weave_out.groupby('id_unique', sort=False), start=1):
    g = g.sort_values('Date_local')
    g[['Date_local', 'id_unique', 'Consommation']].to_csv(
        f"data/weave/{i:0{pad}d}.csv", index=False
    )

uk = df_weave[['date','consumption']].groupby('date').sum().reset_index().sort_values('date')
uk.columns = ['Date_local','Consommation']
uk.to_csv('data/weave/uk.csv')

# Export RFrance by region
os.makedirs('data/rfrance', exist_ok=True)

rfr_out = (
    df_rfrance[['date', 'Region', 'load']]
    .rename(columns={'date': 'Date_local', 'load': 'Consommation'})
    .assign(Date_local=lambda d: pd.to_datetime(d['Date_local']))
)

for region, g in rfr_out.groupby('Region', sort=False):
    g = g.sort_values('Date_local')
    g[['Date_local', 'Region', 'Consommation']].to_csv(
        f"data/rfrance/{region[:5]}.csv", index=False
    )

france = df_rfrance[['date','load']].groupby('date').sum().reset_index().sort_values('date')
france.columns = ['Date_local','Consommation']
france.to_csv('data/rfrance/france.csv')

# Generate configs

In [7]:
# Dataset rfrance
data_dir = "data/rfrance"
config_template = {
    "begin_train": "2015-01-01",
    "begin_test": "2019-01-01",
    "end_test": "2020-01-01",
    "horizon": 48
}

output_dir = "configs/rfrance"
os.makedirs(output_dir, exist_ok=True)

csv_files = [
    f for f in os.listdir(data_dir)
    if f.endswith(".csv") and "train" not in f and "test" not in f
]

for csv_file in csv_files:
    region_name = os.path.splitext(csv_file)[0]  # e.g. "Auver"
    config = config_template.copy()
    config["expe_name"] = f"fm_region_{region_name}"
    config["data_path"] = os.path.join(data_dir, csv_file)
    output_path = os.path.join(output_dir, f"config_rfrance_{region_name}.yaml")

    with open(output_path, "w") as f:
        yaml.dump(config, f, sort_keys=False)

# Dataset weave
data_dir = "data/weave"
config_template = {
    "begin_train": "2024-02-13",
    "begin_test": "2024-02-23",
    "end_test": "2024-02-26",
    "horizon": 48
}

output_dir = "configs/weave"
os.makedirs(output_dir, exist_ok=True)

csv_files = [
    f for f in os.listdir(data_dir)
    if f.endswith(".csv") and "train" not in f and "test" not in f
]

for idx, csv_file in enumerate(sorted(csv_files), start=1):
    region_name = os.path.splitext(csv_file)[0]  # e.g. "001"
    config = config_template.copy()
    config["expe_name"] = f"fm_uk_{idx}"
    config["data_path"] = os.path.join("data/weave", csv_file)
    output_path = os.path.join(output_dir, f"config_uk_{region_name}.yaml")
    with open(output_path, "w") as f:
        yaml.dump(config, f, sort_keys=False)


# Foundation models

In [8]:
DATASET = "weave"  # or "rfrance"
CONFIG_DIR = f"configs/{DATASET}"

# Every YAML config of the chosen dataset, in sorted (stable) order.
config_files = sorted(
    os.path.join(CONFIG_DIR, name)
    for name in os.listdir(CONFIG_DIR)
    if name.endswith(".yaml")
)
print(f"{len(config_files)} config files found in {CONFIG_DIR}")


30 config files found in configs/weave


In [10]:
# Models to benchmark; uncomment entries to include them.
foundation_models = [
    {"name": "tirex",   "checkpoint": None, "model_id": "TiRex"},
    # {"name": "chronos", "checkpoint": None, "model_id": "Chronos"},
    # {"name": "tabpfn",  "checkpoint": None, "model_id": "TabPFN"},
]

for model_info in foundation_models:
    model_name, checkpoint = model_info["name"], model_info["checkpoint"]
    tag = model_name.upper()
    print(f"\nBenchmarking model: {tag}")
    results_dir = f"results_{model_name}"

    for cfg in tqdm(config_files, desc=f"{tag}"):
        # Each config is independent: report failures and keep going.
        try:
            with open(cfg, "r") as f:
                expe_name = yaml.safe_load(f)["expe_name"]

            if os.path.exists(os.path.join(results_dir, expe_name, "sequence.csv")):
                print(f"{tag} already has results for {expe_name} → skip.")
                continue

            run_experiment(cfg, model_name=model_name, checkpoint=checkpoint)

        except Exception as e:
            print(f"Error for {tag} on {os.path.basename(cfg)}: {e}")



Benchmarking model: TIREX


TIREX:   0%|          | 0/30 [00:00<?, ?it/s]


Running TIREX on config: configs/weave/config_uk_000.yaml
Saving results to: results_tirex/fm_uk_0


TIREX forecasting:   0%|          | 0/3 [00:00<?, ?it/s]

Saved results: results_tirex/fm_uk_0/sequence.csv
TIREX already has results for fm_uk_1 → skip.
TIREX already has results for fm_uk_2 → skip.
TIREX already has results for fm_uk_3 → skip.
TIREX already has results for fm_uk_4 → skip.
TIREX already has results for fm_uk_5 → skip.
TIREX already has results for fm_uk_6 → skip.
TIREX already has results for fm_uk_7 → skip.
TIREX already has results for fm_uk_8 → skip.
TIREX already has results for fm_uk_9 → skip.
TIREX already has results for fm_uk_10 → skip.
TIREX already has results for fm_uk_11 → skip.
TIREX already has results for fm_uk_12 → skip.
TIREX already has results for fm_uk_13 → skip.
TIREX already has results for fm_uk_14 → skip.
TIREX already has results for fm_uk_15 → skip.
TIREX already has results for fm_uk_16 → skip.
TIREX already has results for fm_uk_17 → skip.
TIREX already has results for fm_uk_18 → skip.
TIREX already has results for fm_uk_19 → skip.
TIREX already has results for fm_uk_20 → skip.
TIREX already has r

# Visualization and results

In [9]:
# Re-render plots and print metrics for every config from the saved Chronos results.
visualize_existing_results(
    config_files,
    results_dir="results_chronos",
    show_n=0,
    max_workers=6,
)

fm_uk_3 done  (MAPE=100.83%, RMSE=596MW)
fm_uk_4 done  (MAPE=13.98%, RMSE=1935MW)
fm_uk_0 done  (MAPE=8.04%, RMSE=14062MW)
fm_uk_2 done  (MAPE=40.39%, RMSE=1703MW)
fm_uk_5 done  (MAPE=11.66%, RMSE=1212MW)
fm_uk_1 done  (MAPE=15.21%, RMSE=1032MW)
fm_uk_6 done  (MAPE=25.65%, RMSE=1516MW)
fm_uk_7 done  (MAPE=28.98%, RMSE=496MW)
fm_uk_8 done  (MAPE=32.73%, RMSE=1575MW)
fm_uk_11 done  (MAPE=19.20%, RMSE=1433MW)
fm_uk_9 done  (MAPE=24.94%, RMSE=1184MW)
fm_uk_10 done  (MAPE=24.46%, RMSE=754MW)
fm_uk_12 done  (MAPE=15.13%, RMSE=2021MW)
fm_uk_13 done  (MAPE=18.15%, RMSE=2179MW)
fm_uk_14 done  (MAPE=16.52%, RMSE=1467MW)
fm_uk_15 done  (MAPE=28.93%, RMSE=1098MW)
fm_uk_16 done  (MAPE=23.30%, RMSE=2435MW)
fm_uk_17 done  (MAPE=19.40%, RMSE=1170MW)
fm_uk_18 done  (MAPE=13.78%, RMSE=2101MW)
fm_uk_19 done  (MAPE=27.77%, RMSE=1865MW)
fm_uk_20 done  (MAPE=39.51%, RMSE=2230MW)
fm_uk_21 done  (MAPE=39.56%, RMSE=1761MW)
fm_uk_22 done  (MAPE=20.45%, RMSE=622MW)
fm_uk_23 done  (MAPE=27.22%, RMSE=1252MW)
No re