# **TabPFN Relational Benchmark**
This notebook benchmarks the performance of TabPFN models on datasets from RelBench in two scenarios:
1. **Single Table** – Using only the target entity table.
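2. **Merged Table** – Using relational features that getML (FastProp) generates from the target table and the related tables joined to it. If no related table can be joined, the single table is used.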

It automates dataset loading, preprocessing (including date feature engineering), vectorization, model training, prediction, and evaluation for all compatible tasks within a chosen RelBench dataset. The results allow comparing model performance between single-table and merged-table configurations.


## Import Libraries

In [None]:
# --- Standard Library ---
import os
import time
import inspect
from typing import Any, Dict, List, Optional

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd

# --- PyTorch / TorchFrame (already used) ---
import torch
from torch import Tensor
from torch_frame.config.text_embedder import TextEmbedderConfig

# --- Skrub / Sentence Transformers (keep as-is for your vectorization) ---
from skrub import TableVectorizer
from sentence_transformers import SentenceTransformer

# --- RelBench (keep as-is) ---
from relbench.datasets import get_dataset
from relbench.tasks import get_task, get_task_names
from relbench.base import TaskType
import relbench.metrics
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph

# --- TabPFN (keep as-is) ---
from tabpfn import TabPFNClassifier, TabPFNRegressor

# --- getML (new) ---
from getml import data as gdata
from getml import engine as geng
from getml import pipeline as gpipeline
from getml import feature_learning as gfl
from getml import preprocessors as gprep
from getml import predictors as gpred


### Set Global Configuration

In [None]:
# Device preference. Defaults to CPU; define DEVICE (e.g. "cuda" or "mps")
# before running this cell to use an accelerator instead.
DEVICE = globals().get("DEVICE", "cpu")

print(f"Using device: {DEVICE}")

# Define global dataset variable (any available dataset from relbench.datasets).
# "rel-f1" is the default; an already-defined DATASET variable takes precedence.
DATASET = globals().get("DATASET", "rel-f1")

# Global configuration variables with defaults
SEED = globals().get("SEED", 42)
# Number of TabPFN estimators. NOTE: not currently passed to the TabPFN models.
N_ESTIMATORS = globals().get("N_ESTIMATORS", 16)
TABPFN_MAX = globals().get("TABPFN_MAX", 1000)  # hard ceiling for TabPFN

# Define project for getML engine
GETML_PROJECT = globals().get("GETML_PROJECT", "default_project")

# Global for getML engine state. The existing value is kept when this cell is
# re-run, so an engine that is already running is not reported as stopped.
_ENGINE_STARTED = globals().get("_ENGINE_STARTED", False)

## Notebook Configuration and Dataset Selection


Sets the dataset name (`DATASET`) and download flag (`DOWNLOAD`), then discovers all available tasks for the selected dataset using RelBench’s APIs. Filters tasks to only those compatible with TabPFN (classification and regression).


In [None]:
# Keep DATASET / DOWNLOAD if an earlier cell already defined them; otherwise
# fall back to the defaults below.
DATASET = globals()["DATASET"] if "DATASET" in globals() else "rel-f1"
DOWNLOAD = globals()["DOWNLOAD"] if "DOWNLOAD" in globals() else True

# Discover tasks and keep only entity-level cls/reg tasks TabPFN can handle
def _is_tabpfn_friendly(task):
    """Return True if *task* is a classification or regression task TabPFN can fit.

    TabPFNClassifier / TabPFNRegressor are single-output estimators (one target
    value per row). Multilabel classification, which needs several labels per
    row, is therefore excluded, as is every task type not listed below.
    """
    return task.task_type in (
        TaskType.BINARY_CLASSIFICATION,
        TaskType.MULTICLASS_CLASSIFICATION,
        TaskType.REGRESSION,
    )

# Enumerate every task RelBench registers for DATASET and keep the names of
# those TabPFN can handle; tasks that fail to load are reported and skipped.
_all = get_task_names(DATASET)
TASKS = []
for tname in _all:
    try:
        t = get_task(DATASET, tname, download=DOWNLOAD)
        if not _is_tabpfn_friendly(t):
            continue
        TASKS.append(tname)
    except Exception as e:
        print(f"[skip] {tname}: {e!s}")

print(f"{DATASET}: {len(TASKS)} TabPFN-friendly tasks -> {TASKS}")


### Patch RelBench Metrics (Optional)

In [None]:
# Patch relbench.metrics.skm.mean_squared_error to local mean_squared_error
from sklearn.metrics import mean_squared_error

relbench.metrics.skm.mean_squared_error = mean_squared_error

def patched_rmse(true, pred):
    if "squared" in inspect.signature(mean_squared_error).parameters:
        return mean_squared_error(true, pred, squared=False)
    else:
        return np.sqrt(mean_squared_error(true, pred))

relbench.metrics.rmse = patched_rmse

### Text Embedding Configuration

This part provides a callable class that embeds text with 300-dimensional GloVe word embeddings (`average_word_embeddings_glove.6B.300d`), loaded through Sentence Transformers.

In [None]:
# Define a text embedding class using GloVe embeddings from Sentence Transformers.
class GloveTextEmbedding:
    """Callable that turns a batch of strings into averaged GloVe embeddings.

    Wraps the Sentence Transformers model
    ``sentence-transformers/average_word_embeddings_glove.6B.300d``.
    """

    def __init__(self, device: Optional[torch.device] = None):
        model_id = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_id, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode *sentences* and return the embeddings as a torch tensor."""
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

### Fetch Dataset Splits

Utility functions to load a task’s splits (`train`, `val`, `test`), convert them to pandas DataFrames, and extract features (`X`) and targets (`y`). Includes functions to:
* Load train/val/test splits for a task.
* Extract features/targets.
* Infer primary keys, foreign keys, and time-stamp columns.
* Align join-key names between the target table and related tables (getML matches join keys by name).
* Build data frames for both the single-table and merged-table scenarios (the merged scenario generates relational features with getML FastProp).


In [None]:
# ---------------------------
# Dataset loaders / frame builders
# ---------------------------

def fetch_splits(dataset_name: str, task_name: str, download: bool = True):
    """Load a RelBench task and its train/val/test tables.

    Returns:
        ``(task, splits)`` where ``splits`` maps ``"train"``, ``"val"`` and
        ``"test"`` to the tables returned by ``task.get_table``.
    """
    task = get_task(dataset_name, task_name, download=download)
    splits = {}
    for split_name in ("train", "val", "test"):
        # mask_input_cols=False so the raw columns are returned unmasked
        splits[split_name] = task.get_table(split_name, mask_input_cols=False)
    return task, splits

def to_Xy(df: pd.DataFrame, target_col: str):
    """Split *df* into a feature frame (target dropped) and a target NumPy array.

    *df* itself is not modified.
    """
    target = df[target_col].to_numpy()
    features = df.drop(columns=[target_col])
    return features, target

def build_single_table_frames(task, splits):
    """
    Single-table mode: do NOT engineer features here.

    For each split, copy the table's ``.df`` and return ``(X, y, df)``, where
    ``X`` is the copy without ``task.target_col`` and ``y`` is that column as a
    NumPy array.
    """
    frames = {}
    for split_name, table in splits.items():
        raw = table.df.copy()
        features, target = to_Xy(raw, task.target_col)
        frames[split_name] = (features, target, raw)
    return frames


In [None]:
def _ensure_getml_engine():
    """Launch the getML engine and select ``GETML_PROJECT``, at most once.

    Returns immediately when the module-level flag ``_ENGINE_STARTED`` is set.
    The flag is set only after both ``geng.launch`` and ``geng.set_project``
    have succeeded.

    Raises:
        RuntimeError: if either call fails; the original exception is chained.
    """
    global _ENGINE_STARTED
    if not _ENGINE_STARTED:
        launch_options = {
            "allow_push_notifications": False,
            "allow_remote_ips": False,
            "in_memory": True,
            "launch_browser": False,
            "log": False,
            "quiet": True,
        }
        try:
            geng.launch(**launch_options)
            geng.set_project(GETML_PROJECT)
        except Exception as err:
            raise RuntimeError("Failed to launch the getML engine.") from err
        _ENGINE_STARTED = True

def _shutdown_getml_engine():
    """Best-effort shutdown of the getML engine when ``_ENGINE_STARTED`` is set.

    Any exception raised by ``geng.shutdown`` is deliberately ignored, and the
    flag is cleared either way.
    """
    global _ENGINE_STARTED
    if _ENGINE_STARTED:
        try:
            geng.shutdown()
        except Exception:
            pass  # shutdown is best-effort
        _ENGINE_STARTED = False

def _first_datetime_col(df: pd.DataFrame) -> Optional[str]:
    # Prefer already-typed datetime64 columns
    for col in df.columns:
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            return col
    # Try to parse common timestamp strings without mutating df
    for col in df.columns:
        if df[col].dtype == object:
            try:
                pd.to_datetime(df[col], errors="raise")
                return col
            except Exception:
                pass
    return None


def _guess_primary_key(table_name: str, df: pd.DataFrame) -> Optional[str]:
    candidates = [
        f"{table_name}_id",
        f"{table_name[:-1]}_id" if table_name.endswith("s") else None,
        "id",
        "ID",
        f"{table_name}Id",
        f"{table_name}_pk",
    ]
    candidates = [c for c in candidates if c and c in df.columns]
    # prefer unique columns
    for c in candidates:
        if df[c].is_unique:
            return c
    return candidates[0] if candidates else None


def _guess_foreign_key(pop_df: pd.DataFrame, per_name: str, per_pk: str) -> Optional[str]:
    if per_pk in pop_df.columns:
        return per_pk
    candidates = [
        f"{per_name}_id",
        f"{per_name[:-1]}_id" if per_name.endswith("s") else None,
        f"{per_name}{per_pk.capitalize()}" if per_pk != "id" else None,
    ]
    for c in [c for c in candidates if c]:
        if c in pop_df.columns:
            return c
    return None


def _roles_from_df(
    df: pd.DataFrame,
    target_col: Optional[str],
    join_keys: List[str],
    time_col: Optional[str],
):
    num_cols, cat_cols = [], []
    skip = set(join_keys)
    if time_col:
        skip.add(time_col)
    if target_col:
        skip.add(target_col)

    for c in df.columns:
        if c in skip:
            continue
        if pd.api.types.is_numeric_dtype(df[c]):
            num_cols.append(c)
        else:
            cat_cols.append(c)

    return {
        "target": target_col,
        "join_key": join_keys,
        "time_stamp": time_col,
        "numerical": num_cols,
        "categorical": cat_cols,
    }


def _df_set_roles(gdf: gdata.DataFrame, roles: Dict[str, Any]):
    """Apply a role assignment (as built by ``_roles_from_df``) to a getML DataFrame.

    Roles are set in the order target, join_key, time_stamp, numerical,
    categorical. Entries that are missing or empty in *roles* are skipped.
    """
    for role_name in ("target", "join_key", "time_stamp", "numerical", "categorical"):
        columns = roles.get(role_name)
        if columns:
            gdf.set_role(columns, getattr(gdata.roles, role_name))


def _harmonize_join_keys(pop_by_split: Dict[str, pd.DataFrame], per_spec: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Ensure the population has columns named like each peripheral's PK.
    If population FK name != peripheral PK name, mirror FK into a new column
    with the PK name (required because getML matches join_keys by name).
    Returns the filtered per_spec with only peripherals we can actually join.
    """
    valid_specs = []
    for spec in per_spec:
        pk = spec["pk"]
        fk = spec["fk"]
        can_use = True
        for split, df in pop_by_split.items():
            if pk in df.columns:
                continue  # already aligned
            if fk in df.columns:
                df[pk] = df[fk]
            else:
                can_use = False
                break
        if can_use:
            valid_specs.append(spec)
    return valid_specs


In [None]:
def build_merged_table_frames(dataset, task, splits):
    """
    Build per-split feature frames from relational features generated by getML.

    Steps:
    - Launches the getML engine if needed and loads every table of
      ``dataset.get_db()`` into pandas.
    - Guesses each candidate peripheral table's primary key and the matching
      foreign key in the TRAIN population (``_guess_primary_key`` /
      ``_guess_foreign_key``).
    - Mirrors foreign keys into population columns named like the peripheral
      primary keys (``_harmonize_join_keys``), because getML matches join keys
      by name.
    - Builds one getML Container per split, fits a pipeline (FastProp feature
      learner + Imputation + linear/logistic predictor) on the train container,
      and transforms train/val/test with it.

    Args:
        dataset: RelBench dataset; only ``get_db()`` is used.
        task: RelBench task; ``target_col`` and ``task_type`` are read.
        splits: mapping "train"/"val"/"test" -> table objects exposing a
            pandas ``.df``.

    Returns:
        Dict split -> ``(X, y, engineered_df)``, the same contract as
        ``build_single_table_frames``. When no peripheral can be joined, the
        raw population frames are used instead (no getML features).

    Raises:
        RuntimeError: if the getML engine cannot be launched. Errors raised by
            getML while building containers or running the pipeline are not
            caught here.
    """
    _ensure_getml_engine()

    # --- Load all DB tables into pandas ---
    db = dataset.get_db()
    all_tables: Dict[str, pd.DataFrame] = {
        name: (tbl.df.copy() if hasattr(tbl, "df") else tbl.to_pandas())
        for name, tbl in db.table_dict.items()
    }

    # --- Base (population) tables by split (the task table) ---
    # Copies, so the join-key columns mirrored below do not touch `splits`.
    pop: Dict[str, pd.DataFrame] = {split: tbl.df.copy() for split, tbl in splits.items()}
    # NOTE(review): if the split table has no `name` attribute this falls back to
    # "population", and then no DB table is excluded from the peripherals below.
    base_name = getattr(splits["train"], "name", "population")

    # --- Infer candidate peripherals and PK/FK mapping from TRAIN schema ---
    peripheral_names = [t for t in all_tables if t != base_name]
    pop_train = pop["train"]
    # Time-stamp column is detected on TRAIN only and reused for every split.
    time_col = _first_datetime_col(pop_train)

    per_spec = []
    for per_name in peripheral_names:
        per_df = all_tables[per_name]
        per_pk = _guess_primary_key(per_name, per_df)
        if not per_pk:
            continue
        pop_fk = _guess_foreign_key(pop_train, per_name, per_pk)
        if not pop_fk:
            continue
        per_spec.append({"name": per_name, "pk": per_pk, "fk": pop_fk})

    # Align population join key names to peripheral PKs (required by getML).
    # This adds columns to the frames in `pop` in place.
    per_spec = _harmonize_join_keys(pop, per_spec)

    # If nothing to join, fall back to single-table frames without changing downstream code paths
    if not per_spec:
        frames = {}
        for split, df in pop.items():
            X, y = to_Xy(df, task.target_col)
            frames[split] = (X, y, df)
        return frames

    # --- Build getML Containers per split using from_pandas (fixes zero-column error) ---
    # NOTE(review): peripherals are passed as a list via `peripherals=`, and the
    # Pipeline below receives no `data_model`. Confirm against the installed
    # getML version how (or whether) FastProp learns which population key joins
    # which peripheral.
    def _make_container(split: str) -> gdata.Container:
        # Population
        g_pop = gdata.DataFrame.from_pandas(pop[split], name=f"{base_name}_{split}")
        roles_pop = _roles_from_df(
            pop[split],
            target_col=task.target_col,                          # target ONLY on population
            join_keys=[spec["pk"] for spec in per_spec],         # (now) aligned names
            time_col=time_col,
        )
        _df_set_roles(g_pop, roles_pop)

        # Peripherals: every split receives the full, unfiltered peripheral table.
        # NOTE(review): peripherals get no time-stamp role (time_col=None), so
        # nothing here limits peripheral rows to those before each population
        # row's time stamp. This may allow temporal leakage; confirm.
        g_per = []
        for spec in per_spec:
            per_df = all_tables[spec["name"]]
            gdf = gdata.DataFrame.from_pandas(per_df, name=f"{spec['name']}_{split}")
            roles_per = _roles_from_df(
                per_df,
                target_col=None,                                  # NO target on peripherals
                join_keys=[spec["pk"]],
                time_col=None,
            )
            _df_set_roles(gdf, roles_per)
            g_per.append(gdf)
        return gdata.Container(population=g_pop, peripherals=g_per)

    train_c = _make_container("train")
    val_c   = _make_container("val")
    test_c  = _make_container("test")

    # --- Build and run the pipeline ---
    # The predictor is fitted as part of the pipeline, but only the transform()
    # output (the engineered features) is used below.
    predictor = (
        gpred.LinearRegression()
        if task.task_type == TaskType.REGRESSION
        else gpred.LogisticRegression()
    )

    pipe = gpipeline.Pipeline(
        feature_learners=[gfl.FastProp()],
        preprocessors=[gprep.Imputation()],
        predictors=[predictor],
    )

    # IMPORTANT: fit on a single Container (no subset kwargs),
    # then transform other splits separately to avoid the 'mutually exclusive' error.
    pipe.fit(train_c)

    # NOTE(review): confirm that Pipeline.transform returns an object with
    # `.to_pandas()` in the installed getML version (some versions return a
    # NumPy array unless a `df_name` is given).
    fe_train = pipe.transform(train_c).to_pandas()
    fe_val   = pipe.transform(val_c).to_pandas()
    fe_test  = pipe.transform(test_c).to_pandas()

    # --- Extract X, y keeping the notebook's original contract ---
    # If the target is not present in engineered features (common), fall back to original split y.
    # NOTE(review): the fallback assumes fe_df rows are in the same order as
    # raw_df rows; verify.
    def _to_Xy_safe(fe_df: pd.DataFrame, raw_df: pd.DataFrame):
        if task.target_col in fe_df.columns:
            return to_Xy(fe_df, task.target_col)
        else:
            y = raw_df[task.target_col].to_numpy()
            X = fe_df.copy()
            return X, y

    Xtr, ytr = _to_Xy_safe(fe_train, pop["train"])
    Xva, yva = _to_Xy_safe(fe_val,   pop["val"])
    Xte, yte = _to_Xy_safe(fe_test,  pop["test"])

    return {
        "train": (Xtr, ytr, fe_train),
        "val":   (Xva, yva, fe_val),
        "test":  (Xte, yte, fe_test),
    }


## Vectorization Wrapper (Version-Safe)

Initializes a `TableVectorizer` with only the keyword arguments supported by the installed `skrub` version, ensuring compatibility. Fits it on the `train` split only, transforms `train`, `val`, and `test` into numerical feature matrices, and converts sparse output to dense arrays when necessary.


### Helper Functions for Vectorization

In [None]:
# This function creates a `TableVectorizer` instance with version-specific arguments.
def _make_table_vectorizer():
    """Build a ``TableVectorizer`` from only the keyword arguments its installed version accepts.

    ``cardinality_threshold`` (default 1000, global ``CARDINALITY_THRESHOLD``) and
    ``high_cardinality_transformer`` (default "hashing", global
    ``HIGH_CARD_TRANSFORMER``) are always requested.
    ``text_separator``, ``numerical_transformer`` and ``categorical_transformer``
    are requested only when the matching global (``TEXT_SEPARATOR``,
    ``NUMERICAL_TRANSFORMER``, ``CATEGORICAL_TRANSFORMER``) is defined.
    Requested keywords the constructor does not accept are dropped.
    """
    accepted = set(inspect.signature(TableVectorizer.__init__).parameters) - {"self"}
    namespace = globals()

    requested = {
        "cardinality_threshold": namespace.get("CARDINALITY_THRESHOLD", 1000),
        "high_cardinality_transformer": namespace.get("HIGH_CARD_TRANSFORMER", "hashing"),
    }
    optional_knobs = (
        ("text_separator", "TEXT_SEPARATOR"),
        ("numerical_transformer", "NUMERICAL_TRANSFORMER"),
        ("categorical_transformer", "CATEGORICAL_TRANSFORMER"),
    )
    for param, global_name in optional_knobs:
        if global_name in namespace:
            requested[param] = namespace[global_name]

    supported = {name: value for name, value in requested.items() if name in accepted}
    return TableVectorizer(**supported)

# Converts a sparse matrix to a dense NumPy array, handling cases where the input is already dense or does not support `.toarray()`.
def _to_dense(X):
    try:
        # scipy sparse matrices have .toarray()
        return X.toarray() if hasattr(X, "toarray") else X
    except Exception:
        return X

# Vectorizes the training, validation, and test splits using a `TableVectorizer`. It initializes the vectorizer, fits it on the training data, and transforms all splits into dense NumPy arrays. Returns the vectorizer and the transformed matrices.
def vectorize_splits(X_train, X_val, X_test):
    """Fit a TableVectorizer on *X_train* only, then transform all three splits.

    Fitting on the training split alone keeps val/test statistics out of the
    encoders.

    Returns:
        ``(vectorizer, train_matrix, val_matrix, test_matrix)``, each matrix
        passed through ``_to_dense``.
    """
    vectorizer = _make_table_vectorizer()
    train_matrix = _to_dense(vectorizer.fit_transform(X_train))
    val_matrix, test_matrix = (
        _to_dense(vectorizer.transform(part)) for part in (X_val, X_test)
    )
    return vectorizer, train_matrix, val_matrix, test_matrix


### Training and Prediction Helpers

In [None]:
# This function subsamples the training data to a maximum size defined by `TABPFN_MAX`. If the dataset is smaller than this cap, it returns the full dataset; otherwise, it randomly samples without replacement.
def _subsample(X, y, cap=TABPFN_MAX, seed=SEED):
    """Randomly keep at most *cap* rows of (X, y) without replacement.

    Uses a seeded ``np.random.RandomState`` for reproducibility. Pandas inputs
    are indexed with ``.iloc``, array-likes with plain indexing.

    Returns:
        ``(X_sub, y_sub, idx)``. When ``len(X) <= cap`` the inputs are returned
        unchanged together with ``np.arange(len(X))``.
    """
    n_rows = len(X)
    if n_rows > cap:
        rng = np.random.RandomState(seed)
        idx = rng.choice(n_rows, size=cap, replace=False)
        picked = X.iloc[idx] if hasattr(X, "iloc") else X[idx]
        return picked, y[idx], idx
    return X, y, np.arange(n_rows)

# Fits a TabPFN model (either classifier or regressor) based on the task type. It initializes the model with the specified device and number of estimators, then fits it to the provided training data.
def _fit_tabpfn(task, Xt, yt):
    """Fit and return a TabPFN regressor (regression tasks) or classifier (otherwise).

    Both models are built with ``device=DEVICE`` and
    ``ignore_pretraining_limits=True``.
    """
    use_regressor = task.task_type == TaskType.REGRESSION and TabPFNRegressor is not None
    estimator_cls = TabPFNRegressor if use_regressor else TabPFNClassifier
    # n_estimators=int(N_ESTIMATORS) is intentionally not passed (left disabled).
    model = estimator_cls(
        device=DEVICE,
        ignore_pretraining_limits=True,
    )
    model.fit(Xt, yt)
    return model

# Helper function to make predictions for a given task using the fitted model. It handles different task types (regression, binary classification, multiclass/multilabel) and returns the appropriate prediction format.
def _predict_for_task(task, model, X):
    """Produce predictions in the format RelBench's evaluators expect.

    Regression: ``model.predict(X)``. Binary classification: the probability
    column at index 1 of ``predict_proba`` (the positive class). Any other
    classification task: the full ``predict_proba`` matrix.
    """
    if task.task_type == TaskType.REGRESSION:
        return model.predict(X)
    probabilities = model.predict_proba(X)
    is_binary = task.task_type == TaskType.BINARY_CLASSIFICATION
    return probabilities[:, 1] if is_binary else probabilities

### Run TabPFN on Selected Tasks

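Runs TabPFN on a specified dataset and task in either single-table or merged-table mode. Merged mode falls back to the single table if relational feature generation fails. The function vectorizes the data, subsamples the training set to at most `TABPFN_MAX` rows, fits the model, makes predictions, and evaluates performance using RelBench’s evaluators. It returns a dictionary with the results.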

In [None]:
def run_tabpfn_on_task(dataset_name: str, task_name: str, mode: str = "single") -> Dict[str, Any]:
    """Fit TabPFN on one RelBench task and evaluate it on the val and test splits.

    Args:
        dataset_name: RelBench dataset identifier, e.g. "rel-f1".
        task_name: task within that dataset.
        mode: "single" (target table only) or "merged" (getML relational
            features). If building merged features raises, the single-table
            frames are used instead and a message is printed.

    Returns:
        Dict with dataset, task, requested ``mode``, ``effective_mode`` (the
        mode whose features were actually used; "single" after a merged-mode
        fallback), RelBench val/test metric dicts, and split sizes.

    Raises:
        ValueError: if *mode* is not "single" or "merged".
    """
    # Validate before any (potentially slow) loading happens.
    if mode not in ("single", "merged"):
        raise ValueError("mode must be 'single' or 'merged'")

    task, splits = fetch_splits(dataset_name, task_name, download=DOWNLOAD)

    # Build per-split (X, y, df) frames for the requested mode
    effective_mode = mode
    if mode == "single":
        frames = build_single_table_frames(task, splits)
    else:
        # The full database is only needed to build relational features.
        dataset = get_dataset(dataset_name, download=DOWNLOAD)
        try:
            frames = build_merged_table_frames(dataset, task, splits)
        except Exception as e:
            print(f"merged mode failed: {e}, falling back to single-table mode.")
            frames = build_single_table_frames(task, splits)
            effective_mode = "single"

    # Only the training targets are needed; RelBench evaluates against the split tables.
    X_train, y_train, _ = frames["train"]
    X_val, _, _ = frames["val"]
    X_test, _, _ = frames["test"]

    # Vectorize (fit on train only)
    _, Xt, Xv, Xs = vectorize_splits(X_train, X_val, X_test)

    # Respect TabPFN's sample cap
    Xt_cap, yt_cap, _ = _subsample(Xt, y_train, cap=TABPFN_MAX, seed=SEED)

    model = _fit_tabpfn(task, Xt_cap, yt_cap)

    # Predictions are assumed row-aligned with splits["val"] / splits["test"].
    val_metrics = task.evaluate(_predict_for_task(task, model, Xv), splits["val"])
    test_metrics = task.evaluate(_predict_for_task(task, model, Xs), splits["test"])

    return {
        "dataset": dataset_name,
        "task": task_name,
        "mode": mode,
        "effective_mode": effective_mode,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "n_train_used": len(Xt_cap),
        "n_train_total": len(Xt),
        "n_val": len(Xv),
        "n_test": len(Xs),
    }

## Orchestrator for Benchmark Runs

Iterates over all discovered tasks and runs TabPFN in every mode listed in `MODES` (by default **single** and **merged**). Collects validation and test metrics into a results table with one row per metric. Failures are recorded instead of aborting the run. Results are sorted for easier comparison.


In [None]:
MODES = globals().get("MODES", ["single", "merged"])

records = []
failures = []

# Run TabPFN on all tasks in every requested mode; one row per (split, metric).
for task_name in TASKS:
    for mode in MODES:
        try:
            res = run_tabpfn_on_task(DATASET, task_name, mode=mode)
            # Mode whose features were actually used. It differs from the
            # requested mode when merged-table features fell back to the single
            # table. If the result does not report it, the requested mode is used.
            effective_mode = res.get("effective_mode", res.get("mode", mode))
            # Flatten metrics for val and test, one metric per row
            for split in ["val", "test"]:
                metrics = res.get(f"{split}_metrics") or {}
                for metric_name, metric_value in metrics.items():
                    # Skip missing metrics: None, or NaN of any float type
                    # (np.float32 is not a subclass of Python float).
                    if metric_value is None:
                        continue
                    if isinstance(metric_value, (float, np.floating)) and np.isnan(metric_value):
                        continue
                    records.append({
                        "dataset": res.get("dataset", DATASET),
                        "task": res.get("task", task_name),
                        "split": split,
                        "mode": res.get("mode", mode),
                        "effective_mode": effective_mode,
                        "method": "TabPFN_experimental_v1.0",
                        "metric": metric_name,
                        "score": metric_value,
                    })
        except Exception as e:
            msg = f"[{DATASET} | {task_name} | {mode}] failed: {e!s}"
            print(msg)
            failures.append(msg)

# Convert collected records into a DataFrame
results_df = pd.DataFrame.from_records(records)

# If no successful runs were recorded, display a message and create an empty DataFrame
if results_df.empty:
    print("No successful runs were recorded. Check the failure messages above.")
    results_df = pd.DataFrame(
        columns=["dataset", "task", "split", "mode", "effective_mode", "method", "metric", "score"]
    )
else:
    # Ensure required sort keys exist even if some rows missed them
    for col in ["task", "mode"]:
        if col not in results_df.columns:
            results_df[col] = pd.NA
    # Sort only by the columns that exist to avoid KeyError
    sort_keys = [c for c in ["task", "mode", "metric"] if c in results_df.columns]
    if sort_keys:
        results_df = results_df.sort_values(sort_keys)

display(results_df)


### Save Results to CSV

In [None]:
# Specify output directory
out_dir = globals().get("OUT_DIR", "outputs")
os.makedirs(out_dir, exist_ok=True)

# Change timestamp format to "dd.mm.yyyy-hh:mm"
timestamp = time.strftime("%d.%m.%Y-%H:%M")
csv_name = f"tabpfn_{DATASET}_{timestamp}.csv"
csv_path = os.path.join(out_dir, csv_name)

# Round all numerical results to 4 decimal places before saving
if "score" in results_df.columns:
    if pd.api.types.is_numeric_dtype(results_df["score"]):
        results_df["score"] = results_df["score"].round(4)
    else:
        # Optionally, try to convert to numeric first
        results_df["score"] = pd.to_numeric(results_df["score"], errors="coerce").round(4)

# Filter out rows with empty score values before saving
if "score" in results_df.columns:
    results_df = results_df[results_df["score"].notnull()]

# Save the results DataFrame to a CSV file
results_df.to_csv(csv_path, index=False)
print(f"Saved results to: {csv_path}")
