# **TabPFN Relational Benchmark**
This notebook benchmarks the performance of TabPFN models on datasets from RelBench in two scenarios:
1. **Single Table** – Using only the target entity table.
2. **Merged Table** – Using a feature matrix that Featuretools Deep Feature Synthesis (DFS) derives from the target entity table together with the related database tables.

It automates dataset loading, preprocessing (including date feature engineering), vectorization, model training, prediction, and evaluation for all compatible tasks within a chosen RelBench dataset. The results allow comparing model performance between single-table and merged-table configurations.


## Import Libraries

In [1]:
# --- Standard Library ---
import os
import time
import re
import inspect
from typing import Any, Dict, List, Optional

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd

# --- Skrub / Sentence Transformers ---
from skrub import TableVectorizer

# --- RelBench ---
from relbench.datasets import get_dataset
from relbench.tasks import get_task, get_task_names
from relbench.base import TaskType
import relbench.metrics

# --- TabPFN ---
from tabpfn import TabPFNClassifier, TabPFNRegressor

# --- Featuretools ---
import featuretools as ft

In [2]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

### Set Global Configuration

In [3]:
# Device selection (CPU, CUDA, or MPS if available)
def get_device() -> str:
    """Return the device string used for the TabPFN estimators.

    Auto-detection is currently disabled, so this always returns "cpu".
    The commented-out branch sketches MPS/CUDA detection; enabling it
    requires importing `torch` first (the import cell does not import it).
    """
    # Uncomment for auto-detection
    # if torch.backends.mps.is_available():
    #     return "mps"
    # elif torch.cuda.is_available():
    #     return "cuda"
    # else:
    #     return "cpu"
    return "cpu"  # Default: CPU

# Resolved once at cell execution; `_fit_tabpfn` passes it as `device=`.
DEVICE = get_device()
print(f"Using device: {DEVICE}")

# Dataset and experiment settings.
# Every value is read with globals().get(NAME, default): a value that already exists
# in the kernel when this cell runs is kept, otherwise the default is used.
DATASET        = globals().get("DATASET", "rel-f1")           # RelBench dataset name
SEED           = globals().get("SEED", 42)                    # Default seed of `_subsample`
N_ESTIMATORS   = globals().get("N_ESTIMATORS", 16)            # NOTE(review): only referenced in commented-out lines of `_fit_tabpfn`
TABPFN_MAX     = globals().get("TABPFN_MAX", 10000)            # Maximum number of training rows kept by `_subsample`
GETML_PROJECT  = globals().get("GETML_PROJECT", "default_project")  # NOTE(review): not referenced by any cell visible in this notebook

# getML engine state
# NOTE(review): not referenced by any cell visible in this notebook; presumably a leftover - confirm before removing.
_ENGINE_STARTED = False

# Optional: quiet Featuretools logs
ft.config.log_print_threshold = 1000000

Using device: cpu


### Featuretools knobs (safe defaults)

In [4]:
# ---- Featuretools knobs (safe defaults) ----
# Passed as `max_depth` to ft.dfs in `_dfs_feature_matrices`.
FT_MAX_DEPTH = globals().get("FT_MAX_DEPTH", 2)   # defaults to 2; define FT_MAX_DEPTH before this cell to override

# Primitive wishlists resolved at runtime by `_resolve_ft_primitives`
# (names missing from the installed Featuretools are mapped to an alias or skipped)
FT_AGG_PRIMITIVES_WISHLIST = [
    "mean", "sum", "count", "n_unique", "max", "min", "std", "mode"
]
FT_TRANS_PRIMITIVES_WISHLIST = ["day", "month", "year", "weekday", "hour"]

# PK/FK inference sensitivity (fraction of child FK values found in parent PK)
FK_COVERAGE_THRESHOLD = globals().get("FK_COVERAGE_THRESHOLD", 0.95)

# When building model inputs, drop raw join keys to avoid trivial leakage
# (read by `build_merged_table_frames` only; single-table frames keep their key columns)
DROP_JOIN_KEYS_FROM_X = globals().get("DROP_JOIN_KEYS_FROM_X", True)


## Notebook Configuration and Dataset Selection


Sets the dataset name (`DATASET`) and download flag (`DOWNLOAD`), then discovers all available tasks for the selected dataset using RelBench’s APIs. Filters tasks to only those compatible with TabPFN (classification and regression).


In [5]:
# Reuse existing config if present, otherwise set defaults
# (DATASET repeats the lookup already done in the global-configuration cell.)
DATASET = globals().get("DATASET", "rel-f1")
DOWNLOAD = globals().get("DOWNLOAD", True)

# Discover tasks and keep only entity-level cls/reg tasks TabPFN can handle
def _is_tabpfn_friendly(task) -> bool:
    """Return True when `task.task_type` is one of the classification/regression types below."""
    # NOTE(review): multilabel tasks are admitted here and later routed to
    # TabPFNClassifier by `_fit_tabpfn` - confirm that estimator accepts multilabel targets.
    return task.task_type in (
        TaskType.BINARY_CLASSIFICATION,
        TaskType.MULTICLASS_CLASSIFICATION,
        TaskType.MULTILABEL_CLASSIFICATION,
        TaskType.REGRESSION,
    )

_all = get_task_names(DATASET)  # shown in tutorials
TASKS = []
for tname in _all:
    # Each task is instantiated once just to read its type; a task that cannot be
    # loaded is reported and left out instead of aborting the discovery.
    try:
        t = get_task(DATASET, tname, download=DOWNLOAD)
        if _is_tabpfn_friendly(t):
            TASKS.append(tname)
    except Exception as e:
        print(f"[skip] {tname}: {e!s}")

print(f"{DATASET}: {len(TASKS)} TabPFN-friendly tasks -> {TASKS}")


rel-f1: 3 TabPFN-friendly tasks -> ['driver-position', 'driver-dnf', 'driver-top3']


### Patch RelBench Metrics (Optional)

In [6]:
# Patch relbench.metrics so RMSE does not depend on the `squared` keyword of
# sklearn's mean_squared_error (deprecated and later removed in newer scikit-learn).
from sklearn.metrics import mean_squared_error

relbench.metrics.skm.mean_squared_error = mean_squared_error

def patched_rmse(true, pred):
    """Return the root mean squared error of `pred` against `true`.

    Always computed as sqrt(MSE), so the result is the same on every
    scikit-learn version and no deprecation warning is raised. For 1-D
    targets this equals the former `squared=False` result.

    NOTE(review): for multi-output targets `squared=False` presumably averaged
    per-output RMSEs, which differs from sqrt of the averaged MSE - confirm if
    multi-output regression tasks are ever evaluated with this metric.
    """
    return np.sqrt(mean_squared_error(true, pred))

# NOTE(review): this rebinds the module attribute only; objects that already hold a
# reference to the original `relbench.metrics.rmse` are presumably unaffected - verify.
relbench.metrics.rmse = patched_rmse

### Fetch Dataset Splits

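Utility functions that load a task’s splits (`train`, `val`, `test`) as RelBench tables and separate features (`X`) from targets (`y`). Together with the Featuretools helper cells that follow, the notebook provides functions to:
* Load train/val/test splits for a task.
* Extract features/targets.
* Infer primary keys and foreign keys from column names and value coverage.
* Relate the tables in a Featuretools EntitySet and derive features with Deep Feature Synthesis (this replaces a plain one-hop join).
* Build data frames for both single-table and merged-table scenarios.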


### Dataset loaders / frame builders

In [7]:
def fetch_splits(dataset_name: str, task_name: str, download: bool = True):
    """Load a RelBench task and its three split tables.

    Returns:
        (task, splits) where `splits` maps "train", "val" and "test" to the
        table objects returned by `task.get_table`.
    """
    task = get_task(dataset_name, task_name, download=download)
    splits = {}
    for split_name in ("train", "val", "test"):
        # mask_input_cols=False: keep the original columns so raw fields stay visible
        splits[split_name] = task.get_table(split_name, mask_input_cols=False)
    return task, splits

def to_Xy(df: pd.DataFrame, target_col: str):
    """Split `df` into a feature frame (without `target_col`) and a NumPy target array."""
    target_values = df[target_col].to_numpy()
    features = df.drop(target_col, axis=1)
    return features, target_values

def build_single_table_frames(task, splits):
    """
    Single-table mode: no feature engineering happens here.

    For every split, a copy of the table's DataFrame is separated into
    features and target and stored as (X, y, full_copy) under the split name.
    """
    target = task.target_col
    frames = {}
    for split_name in splits:
        raw = splits[split_name].df.copy()
        features, target_values = to_Xy(raw, target)
        frames[split_name] = (features, target_values, raw)
    return frames


### Featuretools helpers: dtype cleanup, PK/FK inference, ES builders

In [8]:
def _clean_for_ft(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` in which time-like, non-numeric columns are parsed to datetimes.

    A column is converted only when all of the following hold:
    - its name passes `_looks_like_time_name`,
    - it is not already a datetime column,
    - it is not numeric or boolean (pd.to_datetime would read such values as
      offsets from the Unix epoch, in nanoseconds by default, and produce
      meaningless 1970 timestamps),
    - parsing yields at least one valid value and more than one distinct value.

    All other columns are returned unchanged; the input frame is not modified.
    """
    x = df.copy()
    for c in x.columns:
        col = x[c]
        if not _looks_like_time_name(c):
            continue
        if pd.api.types.is_datetime64_any_dtype(col):
            continue
        # is_numeric_dtype also covers booleans
        if pd.api.types.is_numeric_dtype(col):
            continue
        try:
            parsed = pd.to_datetime(col, errors="coerce", utc=False)
        except Exception:
            # best effort: a column that cannot be parsed at all keeps its dtype
            continue
        # keep as datetime if we have any valid values and more than 1 unique
        if parsed.notna().sum() > 0 and parsed.nunique(dropna=True) > 1:
            x[c] = parsed
    return x


def _candidate_pk(df: pd.DataFrame, table_name: str) -> Optional[str]:
    """
    Deterministic PK selection:
      1) exact 'id', 'Id', 'ID', f'{table}_id', f'{table}_ID', or any '*_id', '*Id', '*_ID', '*ID' that is unique and non-null
      2) otherwise any column that is fully unique and non-null
      3) else None  (we will make an index)
    """
    cols = list(df.columns)
    priorities = (
        ["id", "Id", "ID", f"{table_name}_id", f"{table_name}_ID"]
        + [c for c in cols if c.lower().endswith("_id")]
        + [c for c in cols if c.endswith("Id")]
        + [c for c in cols if c.endswith("ID")]
    )
    def is_good_key(s: pd.Series) -> bool:
        return s.notna().all() and s.is_unique
    for c in priorities:
        if c in df.columns and is_good_key(df[c]):
            return c
    for c in df.columns:
        s = df[c]
        if is_good_key(s):
            return c
    return None

def _infer_pk_fk_graph(tables: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, Any]]:
    """
    Infer primary keys and parent/child foreign-key links between `tables`.

    A child column is a foreign-key candidate for a parent when its name ends
    in '_id' (case-insensitive) or equals the parent's primary-key name
    (case-insensitive). It is accepted when the fraction of its non-null values
    found among the parent's primary-key values is at least FK_COVERAGE_THRESHOLD.
    Columns that are empty or hold a single distinct value are ignored.

    Values are compared as strings. Whole-number float columns are converted to
    integers first: a nullable integer key is stored as float64 once it holds
    NaN, and "1.0" would otherwise never match the parent's "1".

    Returns:
        {"pkeys": {table: pk or None},
         "fkeys": {parent: [{"child", "fk", "coverage"}, ...]},
         "debug": {"reasons": [human-readable link descriptions]}}
    """
    def _key_strings(s: pd.Series) -> pd.Series:
        # Non-null key values rendered as strings, with integral floats shown as ints.
        s = s.dropna()
        if (
            len(s) > 0
            and pd.api.types.is_float_dtype(s)
            and np.isfinite(s).all()
            and (s % 1 == 0).all()
            and s.abs().max() < 2 ** 53  # beyond this floats no longer represent every integer
        ):
            s = s.astype("int64")
        return s.astype(str)

    pkeys: Dict[str, Optional[str]] = {
        tname: _candidate_pk(df, tname) for tname, df in tables.items()
    }

    parent_values = {
        parent: set(_key_strings(tables[parent][pk]).unique().tolist())
        for parent, pk in pkeys.items()
        if pk is not None
    }

    fkeys: Dict[str, List[Dict[str, Any]]] = {parent: [] for parent in tables.keys()}
    reasons: List[str] = []
    # The string form of a child column does not depend on the parent, so build it once.
    child_keys: Dict[tuple, pd.Series] = {}

    for parent, pk in pkeys.items():
        if pk is None:
            continue
        pset = parent_values.get(parent, set())
        if not pset:
            continue
        parent_keys = pd.Series(list(pset))
        for child, cdf in tables.items():
            if child == parent:
                continue
            # quick prefilter: columns that look like keys
            for col in cdf.columns:
                if not (col.lower().endswith("_id") or col.lower() == pk.lower()):
                    continue
                cache_key = (child, col)
                if cache_key not in child_keys:
                    child_keys[cache_key] = _key_strings(cdf[col])
                s = child_keys[cache_key]
                if len(s) == 0:
                    continue
                # a constant column cannot distinguish parents
                if s.nunique(dropna=True) == 1:
                    continue
                coverage = s.isin(parent_keys).mean()
                if coverage >= FK_COVERAGE_THRESHOLD:
                    fkeys[parent].append({"child": child, "fk": col, "coverage": float(coverage)})
                    reasons.append(f"{parent}.{pk} <- {child}.{col} (coverage={coverage:.3f})")
    return {"pkeys": pkeys, "fkeys": fkeys, "debug": {"reasons": reasons}}


_TIME_NAME_PATTERN = re.compile(
    r"(^|[_])("
    r"timestamp|datetime|event_time|eventtime|time|date|dt|"
    r"created_at|updated_at|inserted_at|occurred_at|recorded_at"
    r")([_]|$)",
    flags=re.IGNORECASE,
)

def _looks_like_time_name(colname: str) -> bool:
    """
    Strictly decide if a column name is time-like:
    - Matches whole-word tokens (e.g., 'order_date', 'event_time', 'created_at', 'datetime')
    - Accepts exact 'ts' or 'ts_*' / '*_ts' specifically, but not substrings inside other words.
    """
    name = str(colname)
    if name.lower() in {"ts"}:
        return True
    if name.lower().startswith("ts_") or name.lower().endswith("_ts"):
        return True
    return _TIME_NAME_PATTERN.search(name) is not None

def _detect_time_col(df: pd.DataFrame, pk: Optional[str] = None) -> Optional[str]:
    """
    Return the first column usable as a time column, or None.

    A column qualifies when it is not `pk`, its name passes
    `_looks_like_time_name`, and its values (parsed with pd.to_datetime when
    they are not datetimes yet) contain more than one distinct non-null value.
    The frame itself is never modified; parsing is done on a temporary.
    """
    for name in df.columns:
        if name == pk or not _looks_like_time_name(name):
            continue
        values = df[name]
        if not pd.api.types.is_datetime64_any_dtype(values):
            try:
                values = pd.to_datetime(values, errors="coerce", utc=False)
            except Exception:
                continue
        has_values = values.notna().sum() > 0
        if has_values and values.nunique(dropna=True) > 1:
            return name
    return None

def _make_es_for_split(base_name: str,
                       pop_df: pd.DataFrame,
                       all_tables: Dict[str, pd.DataFrame],
                       schema: Dict[str, Dict[str, Any]]) -> ft.EntitySet:
    """
    Build a Featuretools EntitySet for one split.

    Args:
        base_name: dataframe name under which the population table is registered.
        pop_df: population (task) table of the split.
        all_tables: database tables by name; an entry named `base_name` is skipped.
        schema: result of `_infer_pk_fk_graph` ("pkeys" and "fkeys" are read).

    Returns:
        The EntitySet holding the population, all other tables and one
        relationship per entry of schema["fkeys"] whose child dataframe and
        foreign-key column exist in the EntitySet.

    NOTE(review): relationships are created from schema["fkeys"] only. If the
    schema was inferred without the population table, no relationship will
    involve `base_name` - confirm this is intended.
    """
    es = ft.EntitySet(id=f"rb_es_{base_name}")

    # population
    pop_df = _clean_for_ft(pop_df)
    pop_pk = schema["pkeys"].get(base_name)
    # NOTE(review): `pop_time` is not read again below; `_add_df_safe` detects the time column itself.
    pop_time = _detect_time_col(pop_df, pk=pop_pk)

    def _add_df_safe(es_obj, name, df, pk_candidate):
        """Register `df` as `name`; returns (add_dataframe result, index column name)."""
        # choose index: use the candidate when it is a real column, otherwise let
        # Featuretools create a synthetic '<name>__ft_index' column
        if pk_candidate is None or pk_candidate not in df.columns:
            idx = f"{name}__ft_index"
            make_idx = True
        else:
            idx = pk_candidate
            make_idx = False
        # choose time_index (must be present, datetime64, != index)
        tcol = _detect_time_col(df, pk=idx)
        if tcol is not None and tcol == idx:
            tcol = None
        # pass only if truly datetime: `_detect_time_col` parses a temporary copy,
        # so the column stored in `df` may still have a non-datetime dtype
        if tcol is not None and not pd.api.types.is_datetime64_any_dtype(df[tcol]):
            tcol = None

        # add
        if make_idx:
            return es_obj.add_dataframe(
                dataframe_name=name, dataframe=df, index=idx, make_index=True, time_index=tcol
            ), idx
        else:
            return es_obj.add_dataframe(
                dataframe_name=name, dataframe=df, index=idx, time_index=tcol
            ), idx

    # NOTE(review): `pop_idx` is not used afterwards.
    es, pop_idx = _add_df_safe(es, base_name, pop_df, pop_pk)

    # peripherals
    for tname, df in all_tables.items():
        if tname == base_name:
            continue
        df2 = _clean_for_ft(df)
        pk = schema["pkeys"].get(tname)
        es, _ = _add_df_safe(es, tname, df2, pk)

    # relationships
    for parent, rels in schema["fkeys"].items():
        if parent not in es.dataframe_dict:
            continue
        # falls back to the synthetic index name when the schema has no primary key for the parent
        parent_pk = schema["pkeys"].get(parent) or f"{parent}__ft_index"
        for r in rels:
            child = r["child"]; fk = r["fk"]
            if child in es.dataframe_dict and fk in es[child].ww.columns:
                es = es.add_relationship(parent_dataframe_name=parent,
                                         parent_column_name=parent_pk,
                                         child_dataframe_name=child,
                                         child_column_name=fk)
    # minimal debug
    #print(f"[ES] Added dataframes={list(es.dataframe_dict.keys())} | relationships={len(es.relationships)}")
    return es


def _dfs_feature_matrices(es_train: ft.EntitySet,
                          es_val: ft.EntitySet,
                          es_test: ft.EntitySet,
                          target_df: str):
    """
    Run DFS on the training EntitySet and apply the resulting feature
    definitions to the validation and test EntitySets.

    The val/test matrices are re-indexed to the training matrix's columns
    (missing columns are filled with NaN), so all three share one layout.

    Returns:
        (fm_train, fm_val, fm_test, feature_definitions)
    """
    agg_prims, trans_prims = _resolve_ft_primitives(
        FT_AGG_PRIMITIVES_WISHLIST, FT_TRANS_PRIMITIVES_WISHLIST
    )

    fm_train, fdefs = ft.dfs(
        entityset=es_train,
        target_dataframe_name=target_df,
        agg_primitives=agg_prims,
        trans_primitives=trans_prims,
        max_depth=FT_MAX_DEPTH,
        features_only=False,
        verbose=False
    )

    train_columns = fm_train.columns
    matrices = [
        ft.calculate_feature_matrix(features=fdefs, entityset=entityset, verbose=False)
        for entityset in (es_val, es_test)
    ]
    fm_val, fm_test = (
        matrix.reindex(columns=train_columns, fill_value=np.nan) for matrix in matrices
    )
    return fm_train, fm_val, fm_test, fdefs


### Resolve FT primitive names to what's actually available

In [9]:
def _resolve_ft_primitives(
    agg_wishlist: list,
    trans_wishlist: list
):
    """
    Map the wished-for primitive names onto those the installed Featuretools offers.

    For each wishlist entry (case-insensitive) the name itself is used when
    available; otherwise the first available alias from its alias group is
    used, in the order the group lists them. Entries with no available name
    are skipped. Duplicates are removed while keeping the wishlist order.

    If no aggregation primitive can be resolved, the available ones among
    mean/sum/count/max/min are used instead.

    Returns:
        (agg_list, trans_list) of lowercase primitive names.
    """
    prim_df = ft.primitives.list_primitives()
    have_agg = set(prim_df[prim_df.type == "aggregation"].name.str.lower())
    have_trans = set(prim_df[prim_df.type == "transform"].name.str.lower())

    # Alias groups are ordered tuples, not sets: set iteration order for strings
    # depends on the hash seed, which made the chosen alias vary between kernels.
    alias_groups_agg = [
        ("n_unique", "nunique", "num_unique", "count_unique"),
        ("mode", "mode_agg"),  # sometimes exposed with a suffix
        ("std", "standard_deviation"),
        ("count",),
    ]
    alias_groups_trans = [
        ("weekday", "week_day"),
        ("year",),
        ("month",),
        ("day",),
        ("hour",),
    ]

    def pick_available(wishlist, have_set, alias_groups):
        resolved = []
        for wish in wishlist:
            name = wish.lower()
            if name in have_set:
                candidates = [name]
            else:
                # every available alias from the groups containing `name`, in listed order
                candidates = [
                    alias
                    for group in alias_groups
                    if name in group
                    for alias in group
                    if alias in have_set
                ]
            # not available under any name: skipped silently
            if candidates and candidates[0] not in resolved:
                resolved.append(candidates[0])
        return resolved

    agg = pick_available(agg_wishlist, have_agg, alias_groups_agg)
    trans = pick_available(trans_wishlist, have_trans, alias_groups_trans)

    if not agg:
        # minimally guarantee some aggregations
        agg = [name for name in ("mean", "sum", "count", "max", "min") if name in have_agg]

    return agg, trans


### Build Merged-Table Frames

In [10]:
def build_merged_table_frames(dataset, task, splits):
    """
    Build merged-mode frames with Featuretools.

    Steps:
    - load every database table as a pandas frame and parse time-like columns,
    - infer PK/FK links between the database tables (name rules plus value
      coverage, see `_infer_pk_fk_graph`),
    - build one EntitySet per split from that split's population table and
      the database tables,
    - run DFS on train and apply the same feature definitions to val/test,
    - drop the target and, when DROP_JOIN_KEYS_FROM_X is set, the inferred
      key columns from the model inputs.

    Returns:
        {"train" | "val" | "test": (X, y, feature_matrix)}, where `y` is taken
        from the raw population table of the split.

    Raises:
        ValueError: if a feature matrix and its population table differ in row
        count, because X and y could then no longer be paired by position.

    NOTE(review): the schema is inferred from the database tables only. Unless
    `base_name` is also a key of `db.table_dict`, no relationship involves the
    population dataframe, so DFS can only derive features from the population's
    own columns. The recorded output of the exploratory cell (7 columns: key,
    target and date parts) is consistent with that. Confirm, and if unintended
    link the population's key columns to their parent tables - together with
    cutoff times, otherwise aggregates would see rows from after the prediction time.
    """
    # --- Load DB tables (pandas) and clean them for FT typing ---
    db = dataset.get_db()
    all_tables: Dict[str, pd.DataFrame] = {
        name: _clean_for_ft(tbl.df.copy() if hasattr(tbl, "df") else tbl.to_pandas())
        for name, tbl in db.table_dict.items()
    }

    # Name under which the population table is registered in the EntitySet
    base_name = getattr(splits["train"], "name", "population")

    # --- Infer PK/FK graph; a table without a PK gets a synthetic index later ---
    schema = _infer_pk_fk_graph(all_tables)

    # --- Build EntitySets per split ---
    pop = {split: tbl.df.copy() for split, tbl in splits.items()}
    es_train = _make_es_for_split(base_name, pop["train"], all_tables, schema)
    es_val   = _make_es_for_split(base_name, pop["val"],   all_tables, schema)
    es_test  = _make_es_for_split(base_name, pop["test"],  all_tables, schema)

    # --- DFS -> aligned feature matrices ---
    fe_train, fe_val, fe_test, _feature_defs = _dfs_feature_matrices(
        es_train, es_val, es_test, target_df=base_name
    )

    # --- Columns to keep out of X: computed once, identical for every split ---
    target = task.target_col
    excluded = {target}
    if DROP_JOIN_KEYS_FROM_X:
        pop_pk = schema["pkeys"].get(base_name)
        if pop_pk:
            excluded.add(pop_pk)
        # any inferred FK name that also appears as a feature-matrix column
        for rels in schema["fkeys"].values():
            excluded.update(r["fk"] for r in rels)

    def _to_Xy_safe(fe_df: pd.DataFrame, raw_df: pd.DataFrame, split: str):
        # X rows and y values are paired by position, so the counts must agree
        if len(fe_df) != len(raw_df):
            raise ValueError(
                f"feature matrix for split '{split}' has {len(fe_df)} rows "
                f"but the population table has {len(raw_df)}"
            )
        X = fe_df.drop(columns=[c for c in fe_df.columns if c in excluded])
        # y from the raw population split
        y = raw_df[target].to_numpy()
        return X, y

    Xtr, ytr = _to_Xy_safe(fe_train, pop["train"], "train")
    Xva, yva = _to_Xy_safe(fe_val,   pop["val"],   "val")
    Xte, yte = _to_Xy_safe(fe_test,  pop["test"],  "test")

    return {
        "train": (Xtr, ytr, fe_train),
        "val":   (Xva, yva, fe_val),
        "test":  (Xte, yte, fe_test),
    }


### Merged-mode diagnostics

In [11]:
# -------- Diagnostics: single vs merged feature spaces --------
def compare_single_vs_merged(di_single, di_merged, target_name: str, label="train"):
    """
    Print a comparison of the single-table and merged feature spaces of one split.

    Reports the shapes, whether the two target arrays agree, how many columns
    exist in only one of the frames, the distinct-value counts of merged-only
    columns and, if scikit-learn can compute it, the mutual information of up
    to 200 merged-only columns with the target. `target_name` is accepted but
    not used. Nothing is returned.
    """
    X_single, y_single, _ = di_single[label]
    X_merged, y_merged, _ = di_merged[label]
    tag = f"[Diag:{label}]"

    print(f"{tag} X_single: {X_single.shape}, X_merged: {X_merged.shape}")
    print(f"{tag} y match: {np.allclose(y_single, y_merged)}")

    # columns present in only one of the two frames
    only_m = [col for col in X_merged.columns if col not in X_single.columns]
    only_s = [col for col in X_single.columns if col not in X_merged.columns]
    print(f"{tag} new cols in merged: {len(only_m)} ; dropped vs merged: {len(only_s)}")

    # low-variance check on merged-only columns
    if only_m:
        distinct_counts = X_merged[only_m].nunique(dropna=True)
        print(f"{tag} merged-only columns nunique (head):")
        print(distinct_counts.sort_values().head(10))

    try:
        from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

        # integer or boolean targets are scored as classes, everything else as regression
        is_discrete = pd.api.types.is_integer_dtype(y_single) or pd.api.types.is_bool_dtype(y_single)
        mi_fun = mutual_info_classif if is_discrete else mutual_info_regression

        # cap at 200 merged-only columns to keep it fast
        subset = only_m[:200]
        if not subset:
            print(f"{tag} no merged-only columns to test MI.")
        else:
            scores = mi_fun(
                pd.DataFrame(X_merged[subset]).fillna(0),
                y_single,
                discrete_features='auto',
                random_state=0,
            )
            ranked = pd.Series(scores, index=subset).sort_values(ascending=False)
            print(f"{tag} top merged-only MI features:")
            print(ranked.head(10))
    except Exception as e:
        print(f"{tag} MI calc skipped: {e}")


## Vectorization Wrapper (Version-Safe)

Initializes a `TableVectorizer` with only supported arguments for the installed `skrub` or `dirty_cat` version, ensuring compatibility. Transforms `train`, `val`, and `test` splits into numerical feature matrices, converting them to dense format if necessary.


### Helper Functions for Vectorization

In [12]:
def _make_table_vectorizer():
    """
    Create a `TableVectorizer` using only keyword arguments that the installed
    version's constructor accepts.

    `cardinality_threshold` and `high_cardinality_transformer` are set whenever
    supported (from the globals CARDINALITY_THRESHOLD / HIGH_CARD_TRANSFORMER,
    falling back to 1000 / "hashing"). The remaining options are set only when
    supported and the matching global is defined.
    """
    supported = set(inspect.signature(TableVectorizer.__init__).parameters.keys()) - {"self"}
    env = globals()
    tv_kwargs = {}

    # (constructor parameter, global override name, fallback value)
    with_fallback = (
        ("cardinality_threshold", "CARDINALITY_THRESHOLD", 1000),
        ("high_cardinality_transformer", "HIGH_CARD_TRANSFORMER", "hashing"),
    )
    for param, global_name, fallback in with_fallback:
        if param in supported:
            tv_kwargs[param] = env.get(global_name, fallback)

    # (constructor parameter, global name) - applied only when the global exists
    only_if_defined = (
        ("text_separator", "TEXT_SEPARATOR"),
        ("numerical_transformer", "NUMERICAL_TRANSFORMER"),
        ("categorical_transformer", "CATEGORICAL_TRANSFORMER"),
    )
    for param, global_name in only_if_defined:
        if param in supported and global_name in env:
            tv_kwargs[param] = env[global_name]

    return TableVectorizer(**tv_kwargs)

# Converts a sparse matrix to a dense NumPy array, handling cases where the input is already dense or does not support `.toarray()`.
def _to_dense(X):
    try:
        # scipy sparse matrices have .toarray()
        return X.toarray() if hasattr(X, "toarray") else X
    except Exception:
        return X

def vectorize_splits(X_train, X_val, X_test):
    """
    Vectorize the three splits with one `TableVectorizer`.

    The vectorizer is fitted on the training split only (no leakage from
    val/test) and then applied to the other two; every result goes through
    `_to_dense`. Returns (vectorizer, train, val, test).
    """
    tv = _make_table_vectorizer()
    Xt = _to_dense(tv.fit_transform(X_train))
    Xv, Xs = [_to_dense(tv.transform(part)) for part in (X_val, X_test)]
    return tv, Xt, Xv, Xs


### Training and Prediction Helpers

In [13]:
def _subsample(X, y, cap=TABPFN_MAX, seed=SEED):
    """
    Limit the training data to at most `cap` rows.

    Args:
        X: feature matrix (NumPy array or pandas object), indexed by row.
        y: targets with one entry per row of `X` (array, list or pandas Series).
        cap: maximum number of rows to keep.
        seed: seed of the random generator that draws the subsample.

    Returns:
        (X_sub, y_sub, idx) where `idx` holds the row positions that were kept.
        When `len(X) <= cap` the inputs are returned unchanged together with
        all positions; otherwise `cap` rows are drawn without replacement.

    Raises:
        ValueError: if `X` and `y` differ in length.
    """
    n_rows = len(X)
    if len(y) != n_rows:
        raise ValueError(f"X has {n_rows} rows but y has {len(y)} entries")
    if n_rows <= cap:
        return X, y, np.arange(n_rows)

    idx = np.random.RandomState(seed).choice(n_rows, size=cap, replace=False)

    def _take(values):
        # pandas objects: select by position (plain [] on a Series selects by label);
        # everything else goes through NumPy so lists work as well
        if hasattr(values, "iloc"):
            return values.iloc[idx]
        return np.asarray(values)[idx]

    return _take(X), _take(y), idx

def _fit_tabpfn(task, Xt, yt):
    """
    Fit and return a TabPFN estimator for `task`.

    Regression tasks use `TabPFNRegressor`, every other task type uses
    `TabPFNClassifier`. Both are created with `device=DEVICE` and
    `ignore_pretraining_limits=True`; N_ESTIMATORS is currently not passed,
    so the estimators use their own default for `n_estimators`.
    """
    wants_regressor = task.task_type == TaskType.REGRESSION and TabPFNRegressor is not None
    estimator_cls = TabPFNRegressor if wants_regressor else TabPFNClassifier
    model = estimator_cls(
        device=DEVICE,
        ignore_pretraining_limits=True,
    )
    model.fit(Xt, yt)
    return model

def _predict_for_task(task, model, X):
    """
    Predict in the form handed to the RelBench evaluators.

    - regression: `model.predict(X)`
    - binary classification: column 1 of `model.predict_proba(X)`
      (AUROC expects probabilities for the positive class)
    - any other task type: the full probability matrix
    """
    kind = task.task_type
    if kind == TaskType.REGRESSION:
        return model.predict(X)
    proba = model.predict_proba(X)
    return proba[:, 1] if kind == TaskType.BINARY_CLASSIFICATION else proba

### Run TabPFN on Selected Tasks

Runs TabPFN on a specified dataset and task, handling both single-table and merged-table modes. It vectorizes the data, fits the model, makes predictions, and evaluates performance using RelBench’s evaluators. Returns a dictionary with results.

In [14]:
# Exploratory check: build both frame variants for every task and list the
# columns of the merged training feature matrix. No model is trained here.
dataset = get_dataset(DATASET, download=DOWNLOAD)

for task_name in TASKS:
    task, splits = fetch_splits(DATASET, task_name, download=DOWNLOAD)

    frames_single = build_single_table_frames(task, splits)
    frames_merged = build_merged_table_frames(dataset, task, splits)

    # Element 2 of each frame tuple is the full feature matrix (before the
    # target and key columns are removed); show its column names and count.
    merged_train_df = frames_merged["train"][2]
    print("Merged columns:", list(merged_train_df.columns))
    print("Number of columns:", len(merged_train_df.columns))


    # Optional diagnostics (disabled):
    #compare_single_vs_merged(frames_single, frames_merged, target_name=task.target_col, label="train")


Loading Database object from /Users/michaelflppv/Library/Caches/relbench/rel-f1/db...
Done in 0.01 seconds.
Merged columns: ['driverId', 'position', 'DAY(date)', 'HOUR(date)', 'MONTH(date)', 'WEEKDAY(date)', 'YEAR(date)']
Number of columns: 7
Merged columns: ['driverId', 'did_not_finish', 'DAY(date)', 'HOUR(date)', 'MONTH(date)', 'WEEKDAY(date)', 'YEAR(date)']
Number of columns: 7
Merged columns: ['driverId', 'qualifying', 'DAY(date)', 'HOUR(date)', 'MONTH(date)', 'WEEKDAY(date)', 'YEAR(date)']
Number of columns: 7


In [15]:
def run_tabpfn_on_task(dataset_name: str, task_name: str, mode: str = "single") -> Dict[str, Any]:
    """
    Train and evaluate TabPFN on one RelBench task.

    Args:
        dataset_name: RelBench dataset name.
        task_name: task within that dataset.
        mode: "single" (target table only) or "merged" (Featuretools features).

    Returns:
        Dict with the dataset/task/mode, the metric dicts returned by
        `task.evaluate` for val and test, and the row counts used.

    Raises:
        ValueError: if `mode` is neither "single" nor "merged".
        Exception: any failure while building merged frames is re-raised after
        being reported; there is no fallback to single-table mode, so a result
        is never recorded under the wrong mode.
    """
    if mode not in ("single", "merged"):
        raise ValueError("mode must be 'single' or 'merged'")

    task, splits = fetch_splits(dataset_name, task_name, download=DOWNLOAD)

    # Build (X, y, frame) per split for the requested mode
    if mode == "single":
        frames = build_single_table_frames(task, splits)
    else:
        # the database is only needed for merged features
        dataset = get_dataset(dataset_name, download=DOWNLOAD)
        try:
            frames = build_merged_table_frames(dataset, task, splits)
        except Exception as e:
            print(f"merged mode failed: {e} (no fallback applied, re-raising)")
            raise

    Xtr, ytr, _ = frames["train"]
    Xva, _, _ = frames["val"]
    Xte, _, _ = frames["test"]

    # Vectorize (fitted on train only)
    _, Xt, Xv, Xs = vectorize_splits(Xtr, Xva, Xte)

    # Respect TabPFN's sample cap
    Xt_cap, yt_cap, _ = _subsample(Xt, ytr, cap=TABPFN_MAX, seed=SEED)

    # Fit
    model = _fit_tabpfn(task, Xt_cap, yt_cap)

    # Predict; predictions are expected in the row order of the split tables
    val_pred = _predict_for_task(task, model, Xv)
    test_pred = _predict_for_task(task, model, Xs)

    # Evaluate with RelBench evaluators
    val_metrics = task.evaluate(val_pred, splits["val"])
    test_metrics = task.evaluate(test_pred, splits["test"])

    return {
        "dataset": dataset_name,
        "task": task_name,
        "mode": mode,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "n_train_used": len(Xt_cap),
        "n_train_total": len(Xt),
        "n_val": len(Xv),
        "n_test": len(Xs),
    }

## Orchestrator for Benchmark Runs

Iterates over all discovered tasks and runs TabPFN in both **single** and **merged** modes. Collects performance metrics for validation and test splits into a results table, handling failures gracefully. Sorts results for easier comparison.


In [16]:
# Modes to benchmark; a MODES list defined before this cell takes precedence.
MODES = globals().get("MODES", ["single", "merged"])

records = []   # one dict per (task, mode, split, metric)
failures = []  # messages of runs that raised

# Run TabPFN on all tasks in both modes and collect results
for task_name in TASKS:
    for mode in MODES:
        # A failing run is reported and recorded in `failures`; the loop continues.
        try:
            res = run_tabpfn_on_task(DATASET, task_name, mode=mode)
            # Flatten metrics for val and test, one metric per row
            for split in ["val", "test"]:
                metrics = res.get(f"{split}_metrics") or {}
                for metric_name, metric_value in metrics.items():
                    # Only add rows with non-empty metric_value (skips None and float NaN)
                    if metric_value is not None and not (isinstance(metric_value, float) and np.isnan(metric_value)):
                        records.append({
                            "dataset": res.get("dataset", DATASET),
                            "task": res.get("task", task_name),
                            "split": split,
                            "mode": res.get("mode", mode),
                            "method": "TabPFN_experimental_v1.0",
                            "metric": metric_name,
                            "score": metric_value,
                        })
        except Exception as e:
            msg = f"[{DATASET} | {task_name} | {mode}] failed: {e!s}"
            print(msg)
            failures.append(msg)

# Convert collected records into a DataFrame
results_df = pd.DataFrame.from_records(records)

# If no successful runs were recorded, display a message and create an empty DataFrame
# with the expected columns so the following save cell still works
if results_df.empty:
    print("No successful runs were recorded. Check the failure messages above.")
    results_df = pd.DataFrame(
        columns=["dataset", "task", "split", "mode", "method", "metric", "score"]
    )
else:
    # Ensure required sort keys exist even if some rows missed them
    # (every record built above contains both keys, so this is purely defensive)
    for col in ["task", "mode"]:
        if col not in results_df.columns:
            results_df[col] = pd.NA
    # Sort only by the columns that exist to avoid KeyError
    sort_keys = [c for c in ["task", "mode", "metric"] if c in results_df.columns]
    if sort_keys:
        results_df = results_df.sort_values(sort_keys)

# Last expression of the cell: rendered as the cell output
results_df

KeyboardInterrupt: 

### Save Results to CSV

In [117]:
# Specify output directory
out_dir = globals().get("OUT_DIR", "outputs")
os.makedirs(out_dir, exist_ok=True)

# Timestamp for the file name: digits and "-" only. The previous format contained
# ":", which is not allowed in Windows file names. Seconds are included so that
# two saves within the same minute do not overwrite each other.
timestamp = time.strftime("%Y%m%d-%H%M%S")
csv_name = f"tabpfn_{DATASET}_{timestamp}.csv"
csv_path = os.path.join(out_dir, csv_name)

if "score" in results_df.columns:
    # Convert scores to numbers (non-numeric values become NaN) and round to 4 decimals.
    # `assign` builds a new frame, so re-running this cell never writes into a
    # filtered view of an earlier frame.
    results_df = results_df.assign(
        score=pd.to_numeric(results_df["score"], errors="coerce").round(4)
    )
    # Filter out rows with empty score values before saving
    results_df = results_df[results_df["score"].notnull()]

# Save the results DataFrame to a CSV file
results_df.to_csv(csv_path, index=False)
print(f"Saved results to: {csv_path}")


Saved results to: outputs/tabpfn_rel-f1_21.08.2025-00:21.csv
