# **TabPFN Relational Benchmark**
This notebook benchmarks the performance of TabPFN models on datasets from RelBench in two scenarios:
1. **Single Table** – Using only the target entity table.
2. **Merged Table** – Using a naively denormalized table obtained by joining related tables.

It automates dataset loading, preprocessing (including date feature engineering), vectorization, model training, prediction, and evaluation for all compatible tasks within a chosen RelBench dataset. The results allow comparing model performance between single-table and merged-table configurations.


## Import Libraries

In [22]:
# --- Standard Library ---
import os
import time
import inspect

# --- Third-Party Libraries ---
import pandas as pd
import re
import numpy as np
from typing import Dict, Optional, Any, List, Tuple

# --- Skrub / Sentence Transformers ---
from skrub import TableVectorizer

# --- RelBench ---
from relbench.datasets import get_dataset
from relbench.tasks import get_task, get_task_names
from relbench.base import TaskType
import relbench.metrics

# --- TabPFN ---
from tabpfn import TabPFNClassifier, TabPFNRegressor

# --- Featuretools ---
import featuretools as ft

In [23]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

### Set Global Configuration

In [24]:
# Device selection (CPU, CUDA, or MPS if available)
def get_device():
    # Uncomment for auto-detection
    # if torch.backends.mps.is_available():
    #     return "mps"
    # elif torch.cuda.is_available():
    #     return "cuda"
    # else:
    #     return "cpu"
    return "cpu"  # Default: CPU

DEVICE = get_device()
print(f"Using device: {DEVICE}")

# Dataset and experiment settings
DATASET        = globals().get("DATASET", "rel-f1")           # Default dataset
MAIN_TABLE     = globals().get("MAIN_TABLE", "qualifying")    # Main table for single-table mode
SEED           = globals().get("SEED", 42)                    # Random seed
N_ESTIMATORS   = globals().get("N_ESTIMATORS", 16)            # TabPFN estimators
TABPFN_MAX     = globals().get("TABPFN_MAX", 10000)            # Max TabPFN samples

# getML engine state
_ENGINE_STARTED = False

# Optional: quiet Featuretools logs
ft.config.log_print_threshold = 1000000

Using device: cpu


### Featuretools knobs (safe defaults)

In [25]:
# ---- Featuretools knobs (stable across datasets) ----
FT_MAX_DEPTH = globals().get("FT_MAX_DEPTH", 2)

# Desired primitives (resolved at runtime to what your FT version supports)
FT_AGG_PRIMITIVES_WISHLIST = ["mean", "sum", "count", "n_unique", "max", "min", "std", "mode"]
FT_TRANS_PRIMITIVES_WISHLIST = ["day", "month", "year", "weekday", "hour"]

# Drop raw join keys from model inputs (to avoid trivial leakage)
DROP_JOIN_KEYS_FROM_X = globals().get("DROP_JOIN_KEYS_FROM_X", True)

# ---- Generic PK/FK inference thresholds ----
PK_MIN_UNIQUE_RATIO = globals().get("PK_MIN_UNIQUE_RATIO", 0.98)        # exclude constant-ish columns in PK
PK_MAX_NULL_RATIO   = globals().get("PK_MAX_NULL_RATIO", 0.02)          # exclude columns with too many nulls
FK_COVERAGE_THRESHOLD = globals().get("FK_COVERAGE_THRESHOLD", 0.90)    # minimum referential coverage for FKs
FK_MIN_UNIQUE_RATIO = globals().get("FK_MIN_UNIQUE_RATIO", 0.01)        # exclude constant-ish columns in FK
FK_MAX_UNIQUE_RATIO = globals().get("FK_MAX_UNIQUE_RATIO", 0.95)        # exclude near-unique -> likely not an FK
KEY_NAME_BONUS = globals().get("KEY_NAME_BONUS", 0.05)

# Name hint boost: helps columns that *look* like keys to win ties
KEY_NAME_BONUS = globals().get("KEY_NAME_BONUS", 0.05)


## Notebook Configuration and Dataset Selection


Sets the dataset name (`DATASET`) and download flag (`DOWNLOAD`), then discovers all available tasks for the selected dataset using RelBench’s APIs. Filters tasks to only those compatible with TabPFN (classification and regression).


In [26]:
# Reuse existing config if present, otherwise set defaults
DATASET = globals().get("DATASET", "rel-f1")
DOWNLOAD = globals().get("DOWNLOAD", True)

# Discover tasks and keep only entity-level cls/reg tasks TabPFN can handle
def _is_tabpfn_friendly(task):
    """
    Check if a task is compatible with TabPFN.
    """
    return task.task_type in (
        TaskType.BINARY_CLASSIFICATION,
        TaskType.MULTICLASS_CLASSIFICATION,
        TaskType.MULTILABEL_CLASSIFICATION,
        TaskType.REGRESSION,
    )

_all = get_task_names(DATASET)  # shown in tutorials
TASKS = []
for tname in _all:
    try:
        t = get_task(DATASET, tname, download=DOWNLOAD)
        if _is_tabpfn_friendly(t):
            TASKS.append(tname)
    except Exception as e:
        print(f"[skip] {tname}: {e!s}")

print(f"{DATASET}: {len(TASKS)} TabPFN-friendly tasks -> {TASKS}")


rel-f1: 3 TabPFN-friendly tasks -> ['driver-position', 'driver-dnf', 'driver-top3']


### Patch RelBench Metrics (Optional)

In [27]:
# Patch relbench.metrics.skm.mean_squared_error to local mean_squared_error
from sklearn.metrics import mean_squared_error

relbench.metrics.skm.mean_squared_error = mean_squared_error

def patched_rmse(true, pred):
    """
    Compute RMSE using sklearn's mean_squared_error.
    """
    if "squared" in inspect.signature(mean_squared_error).parameters:
        return mean_squared_error(true, pred, squared=False)
    else:
        return np.sqrt(mean_squared_error(true, pred))

relbench.metrics.rmse = patched_rmse

### Fetch Dataset Splits

Utility functions to load a task’s splits (`train`, `val`, `test`), convert them to pandas DataFrames, and extract features (`X`) and targets (`y`). Includes functions to:
* Load train/val/test splits for a task.
* Extract features/targets.
* Infer primary keys.
* Denormalize tables (one-hop join).
* Build data frames for both single-table and merged-table scenarios.


### Dataset loaders / frame builders

In [28]:
def fetch_splits(dataset_name: str, task_name: str, download: bool = True):
    """
    Fetch train/val/test splits for a given dataset and task.
    """
    task = get_task(dataset_name, task_name, download=download)
    # keep original columns (mask_input_cols=False so we see raw fields)
    splits = {
        split: task.get_table(split, mask_input_cols=False)
        for split in ("train", "val", "test")
    }
    return task, splits

def to_Xy(df: pd.DataFrame, target_col: str):
    y = df[target_col].to_numpy()
    X = df.drop(columns=[target_col])
    return X, y

def build_single_table_frames(task, splits):
    """
    Single-table mode: do NOT engineer features here.
    Just return raw base table X, y per split (target dropped from X).
    """
    frames = {}

    for split, table in splits.items():
        df = table.df.copy()

        if task.target_col not in df.columns:
            raise ValueError(f"Target column '{task.target_col}' not found in table '{df.name}'")

        X, y = to_Xy(df, task.target_col)
        frames[split] = (X, y, df)
    return frames


### Featuretools helpers: dtype cleanup, PK/FK inference, ES builders

In [29]:
def _normalize_name(s: str) -> str:
    """
    Normalize a string to a lowercase alphanumeric representation.
    """
    return re.sub(r'[^a-z0-9]', '', str(s).lower())

def _name_matches_pk(child_col: str, parent_pk: str) -> bool:
    """
    Require that child column name matches parent PK name (normalized).
    """
    return _normalize_name(child_col) == _normalize_name(parent_pk)

def _dtype_category(dt) -> str:
    """
    Determine the category of a pandas dtype.
    """
    try:
        import pandas.api.types as pat
        if pat.is_datetime64_any_dtype(dt): return "dt"
        if pat.is_integer_dtype(dt) or pat.is_bool_dtype(dt): return "int"
        if pat.is_float_dtype(dt): return "float"
        if pat.is_string_dtype(dt) or dt == object: return "str"
    except Exception:
        pass
    return "other"

def _dtype_compatible(dt_parent, dt_child) -> bool:
    """
    Determine if two dtypes are compatible for PK/FK relationships.
    """
    a, b = _dtype_category(dt_parent), _dtype_category(dt_child)
    # keys should be int↔int or str↔str; allow bool as int
    if a == "int" and b == "int": return True
    if a == "str" and b == "str": return True
    return False

### Featuretools helpers: robust, dataset-agnostic

In [30]:
def _clean_for_ft(df: pd.DataFrame) -> pd.DataFrame:
    """
    dtype cleanup for Featuretools:
      - convert object/string columns to categorical,
      - parse datetime-like columns,
      - ensure no mixed dtypes in columns.
    """
    x = df.copy()
    # parse only strong time-like names
    for c in x.columns:
        if _looks_like_time_name(c) and not pd.api.types.is_datetime64_any_dtype(x[c]):
            try:
                parsed = pd.to_datetime(x[c], errors="coerce")
                if parsed.notna().sum() > 0 and parsed.nunique(dropna=True) > 1:
                    x[c] = parsed
            except Exception:
                pass
    return x

def _is_key_like(colname: str) -> bool:
    """
    key-like name detection:
      - ends with _id, id, _key, key (case-insensitive),
      - ends with ID or Id (case-sensitive).
    """
    s = str(colname)
    sl = s.lower()
    return (
        sl.endswith("_id") or sl.endswith("id") or
        sl.endswith("_key") or sl.endswith("key") or
        s.endswith("ID") or s.endswith("Id")
    )

def _score_pk(series: pd.Series) -> float:
    """
    PK score based on uniqueness and null ratio
    """
    n = len(series)
    if n == 0:
        return 0.0
    nunq = series.nunique(dropna=True)
    null_ratio = 1.0 - series.notna().mean()
    unique_ratio = nunq / max(1, n)
    score = unique_ratio - null_ratio  # prefer unique, penalize nulls
    if _is_key_like(series.name):
        score += KEY_NAME_BONUS
    return float(score)

def _candidate_pk(df: pd.DataFrame) -> Optional[str]:
    """
    Candidate primary key detection
    """
    best_col, best_score = None, -1.0
    for c in df.columns:
        s = df[c]
        # quick skip for obvious non-keys
        if pd.api.types.is_float_dtype(s) and not pd.api.types.is_integer_dtype(s):
            # allow floats only if they look like ints after dropna
            if not np.allclose(s.dropna() % 1, 0):
                continue
        sc = _score_pk(s)
        if sc > best_score:
            best_col, best_score = c, sc
    # require thresholds
    if best_col is None:
        return None
    s = df[best_col]
    nunq = s.nunique(dropna=True)
    n = len(s)
    null_ratio = 1.0 - s.notna().mean()
    if (nunq / max(1, n) >= PK_MIN_UNIQUE_RATIO) and (null_ratio <= PK_MAX_NULL_RATIO):
        return best_col
    return None

# ---------- Time detection (strict; never equals PK) ----------
_TIME_NAME_PATTERN = re.compile(
    r"(^|[_])("
    r"timestamp|datetime|event_time|eventtime|time|date|dt|"
    r"created_at|updated_at|inserted_at|occurred_at|recorded_at"
    r")([_]|$)", re.IGNORECASE,
)

def _looks_like_time_name(colname: str) -> bool:
    """
    Time-like name detection
    """
    name = str(colname)
    if name.lower() == "ts" or name.lower().startswith("ts_") or name.lower().endswith("_ts"):
        return True
    return _TIME_NAME_PATTERN.search(name) is not None

def _detect_time_col(df: pd.DataFrame, pk: Optional[str] = None) -> Optional[str]:
    """
    Try to detect a time-like column in the DataFrame.
    """
    for c in df.columns:
        if c == pk:
            continue
        if not _looks_like_time_name(c):
            continue
        s = df[c]
        if not pd.api.types.is_datetime64_any_dtype(s):
            try:
                s = pd.to_datetime(s, errors="coerce")
            except Exception:
                continue
        if s.notna().sum() > 0 and s.nunique(dropna=True) > 1:
            return c
    return None

# ---------- PK/FK inference across all tables (incl. MAIN_TABLE) ----------
def _infer_pk_fk_graph_auto(all_tables: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, Any]]:
    """
    Schema-agnostic PK/FK inference with strict name+dtype gating:
      - pick single-column PK by uniqueness/nulls (+name hint),
      - propose FKs only when child column NAME == parent PK NAME (normalized),
      - require dtype compatibility, not-near-unique child, and high referential coverage,
      - forbid child's own PK as FK; forbid self-relationships.
    """
    # 1) PKs
    pkeys: Dict[str, Optional[str]] = {}
    for tname, df in all_tables.items():
        pkeys[tname] = _candidate_pk(df)

    # 2) parent PK value sets + dtypes
    parent_values: Dict[str, set] = {}
    parent_pk_dtype: Dict[str, Any] = {}
    for parent, pk in pkeys.items():
        if pk and pk in all_tables[parent].columns:
            parent_pk_dtype[parent] = all_tables[parent][pk].dtype
            vals = pd.Series(all_tables[parent][pk]).dropna().astype(str).unique()
            if len(vals) > 0:
                parent_values[parent] = set(vals.tolist())

    # 3) scan child tables for key-like columns that *name-match* a parent PK
    fkeys: Dict[str, List[Dict[str, Any]]] = {parent: [] for parent in all_tables.keys()}
    reasons: List[str] = []

    for child, cdf in all_tables.items():
        child_pk = pkeys.get(child)
        for parent, pk in pkeys.items():
            if not pk:
                continue
            if parent == child:
                continue
            if parent not in parent_values:
                continue

            # candidate child columns whose NAME matches the parent's PK
            for col in cdf.columns:
                if col == child_pk:
                    continue
                if not _name_matches_pk(col, pk):
                    continue
                # dtype compatibility
                if not _dtype_compatible(parent_pk_dtype.get(parent, None), cdf[col].dtype):
                    continue
                # child column cardinality constraints (many-to-one)
                s = cdf[col].dropna()
                if s.empty:
                    continue
                nunq = s.nunique(dropna=True)
                uniq_ratio = nunq / max(1, len(s))
                if uniq_ratio < FK_MIN_UNIQUE_RATIO or uniq_ratio > FK_MAX_UNIQUE_RATIO:
                    continue
                # referential coverage check
                coverage = (s.astype(str).isin(pd.Series(list(parent_values[parent])))).mean()
                if coverage >= FK_COVERAGE_THRESHOLD:
                    lst = fkeys.setdefault(parent, [])
                    if not any(r["child"] == child and r["fk"] == col for r in lst):
                        lst.append({"child": child, "fk": col, "coverage": float(coverage)})
                        reasons.append(f"{parent}.{pk} <- {child}.{col} (coverage={coverage:.3f})")

    return {"pkeys": pkeys, "fkeys": fkeys, "debug": {"reasons": reasons}}


def _prune_fk_cycles(schema: Dict[str, Dict[str, Any]], base_name: str) -> Dict[str, Dict[str, Any]]:
    """
    Remove edges that introduce cycles. Preference:
      - keep edges incident to the MAIN_TABLE (base_name),
      - keep edges with higher coverage,
      - drop the weakest edge per detected cycle.
    """
    pkeys = schema["pkeys"]
    fkeys = {k: list(v) for k, v in schema["fkeys"].items()}

    def edges():
        for parent, rels in fkeys.items():
            for r in rels:
                yield (parent, r["child"], r)

    # build adjacency
    def build_adj():
        adj = {}
        for u, v, r in edges():
            adj.setdefault(u, []).append((v, r))
        return adj

    # cycle detection via DFS
    def find_cycle():
        adj = build_adj()
        visited, stack = {}, []
        def dfs(u):
            visited[u] = 1
            stack.append(u)
            for v, r in adj.get(u, []):
                if visited.get(v, 0) == 0:
                    cyc = dfs(v)
                    if cyc: return cyc
                elif visited.get(v, 0) == 1:
                    # cycle found: collect path u->...->v
                    if v in stack:
                        i = stack.index(v)
                        cyc_nodes = stack[i:] + [v]
                        cyc_edges = []
                        # collect edges on this cycle
                        for a, b in zip(cyc_nodes, cyc_nodes[1:]):
                            # find the relationship object
                            rel = None
                            for nb, r2 in adj.get(a, []):
                                if nb == b:
                                    rel = r2; break
                            if rel: cyc_edges.append((a, b, rel))
                        return cyc_edges
            stack.pop()
            visited[u] = 2
            return None

        for u in set(list(pkeys.keys()) + [base_name]):
            if visited.get(u, 0) == 0:
                cyc = dfs(u)
                if cyc: return cyc
        return None

    # iteratively prune one weakest edge per cycle
    removed = 0
    while True:
        cyc = find_cycle()
        if not cyc:
            break
        # pick edge to drop: avoid dropping edges touching base_name if possible; lowest coverage wins
        cand = []
        for a, b, r in cyc:
            score = r.get("coverage", 0.0)
            touches_base = int(a == base_name or b == base_name)
            cand.append((touches_base, score, a, b, r))
        # sort: prefer removing edges NOT touching base_name, and with lowest coverage
        cand.sort(key=lambda x: (x[0], x[1]))
        _, _, a, b, r = cand[0]
        # remove
        fkeys[a] = [x for x in fkeys.get(a, []) if not (x["child"] == b and x["fk"] == r["fk"])]
        removed += 1
    if removed:
        print(f"[Schema] Pruned {removed} cyclic relationship(s)")
    return {"pkeys": pkeys, "fkeys": fkeys, "debug": schema.get("debug", {})}



# ---------- EntitySet builder (safe: never pk==time_index) ----------
def _make_es_for_split(base_name: str,
                       pop_df: pd.DataFrame,
                       all_tables: Dict[str, pd.DataFrame],
                       schema: Dict[str, Dict[str, Any]]) -> ft.EntitySet:
    """
    Extract a Featuretools EntitySet from the given base table and all related tables.
    """
    es = ft.EntitySet(id=f"rb_es_{base_name}")

    def _add_df_safe(es_obj, name, df, pk_candidate):
        df = _clean_for_ft(df)
        # index
        if pk_candidate is None or pk_candidate not in df.columns:
            idx = f"{name}__ft_index"; make_idx = True
        else:
            idx = pk_candidate; make_idx = False
        # time index
        tcol = _detect_time_col(df, pk=idx)
        if tcol is not None and tcol == idx:
            tcol = None
        if tcol is not None and not pd.api.types.is_datetime64_any_dtype(df[tcol]):
            tcol = None

        if make_idx:
            return es_obj.add_dataframe(dataframe_name=name, dataframe=df, index=idx, make_index=True, time_index=tcol), idx
        else:
            return es_obj.add_dataframe(dataframe_name=name, dataframe=df, index=idx, time_index=tcol), idx

    # add MAIN_TABLE
    pop_pk = schema["pkeys"].get(base_name)
    es, pop_idx = _add_df_safe(es, base_name, pop_df, pop_pk)

    # add other tables
    for tname, df in all_tables.items():
        if tname == base_name:
            continue
        pk = schema["pkeys"].get(tname)
        es, _ = _add_df_safe(es, tname, df, pk)

    # relationships (parent: one, child: many; FK lives on child)
    rel_added, rel_skipped = 0, 0
    for parent, rels in schema["fkeys"].items():
        if parent not in es.dataframe_dict:
            continue
        parent_pk = schema["pkeys"].get(parent) or f"{parent}__ft_index"
        for r in rels:
            child = r["child"]; fk = r["fk"]
            if child not in es.dataframe_dict:
                rel_skipped += 1; continue
            child_idx = es[child].ww.index
            # Skip if FK column is the child's index (illegal in FT) or missing
            if fk == child_idx or fk not in es[child].ww.columns:
                rel_skipped += 1; continue
            es = es.add_relationship(parent_dataframe_name=parent,
                                     parent_column_name=parent_pk,
                                     child_dataframe_name=child,
                                     child_column_name=fk)
            rel_added += 1

    return es

# ---------- Primitive resolver (version-agnostic) ----------
def _resolve_ft_primitives(agg_wishlist: list, trans_wishlist: list) -> Tuple[List[str], List[str]]:
    """
    Resolve Featuretools primitives based on wishlist and available primitives.
    """
    prim_df = ft.primitives.list_primitives()
    have_agg = set(prim_df[prim_df.type == "aggregation"].name.str.lower())
    have_trans = set(prim_df[prim_df.type == "transform"].name.str.lower())

    alias_groups_agg = [
        {"n_unique", "nunique", "num_unique", "count_unique"},
        {"mode", "mode_agg"},
        {"std", "standard_deviation"},
        {"count"}
    ]
    alias_groups_trans = [
        {"weekday", "week_day"},
        {"year"}, {"month"}, {"day"}, {"hour"}
    ]
    def pick(wish, have, groups):
        out = []
        for w in [w.lower() for w in wish]:
            if w in have:
                out.append(w); continue
            matched = False
            for g in groups:
                if w in g:
                    for cand in g:
                        if cand in have:
                            out.append(cand); matched = True; break
                if matched: break
        # ensure some basics
        if not out and have:
            for fb in ["mean", "sum", "count", "max", "min"]:
                if fb in have: out.append(fb)
        # dedupe
        seen, uniq = set(), []
        for p in out:
            if p not in seen: seen.add(p); uniq.append(p)
        return uniq
    return pick(agg_wishlist, have_agg, alias_groups_agg), pick(trans_wishlist, have_trans, alias_groups_trans)

# ---------- DFS (train -> reuse on val/test) ----------
def _dfs_feature_matrices(es_train: ft.EntitySet,
                          es_val: ft.EntitySet,
                          es_test: ft.EntitySet,
                          target_df: str):
    """
    Run Featuretools DFS on the training EntitySet and reuse the feature definitions on validation and test sets.
    """
    agg_prims, trans_prims = _resolve_ft_primitives(FT_AGG_PRIMITIVES_WISHLIST, FT_TRANS_PRIMITIVES_WISHLIST)

    fm_train, fdefs = ft.dfs(
        entityset=es_train,
        target_dataframe_name=target_df,
        agg_primitives=agg_prims,
        trans_primitives=trans_prims,
        max_depth=FT_MAX_DEPTH,
        features_only=False,
        verbose=False
    )
    fm_val = ft.calculate_feature_matrix(features=fdefs, entityset=es_val, verbose=False)
    fm_test = ft.calculate_feature_matrix(features=fdefs, entityset=es_test, verbose=False)

    # align columns
    fm_val = fm_val.reindex(columns=fm_train.columns, fill_value=np.nan)
    fm_test = fm_test.reindex(columns=fm_train.columns, fill_value=np.nan)

    return fm_train, fm_val, fm_test, fdefs


### Build Merged-Table Frames

In [31]:
def build_merged_table_frames(dataset, task, splits):
    """
    Generic merged features with Featuretools:
    - load ALL tables as pandas,
    - include TRAIN MAIN_TABLE table in PK/FK inference,
    - infer PK/FK graph automatically (unique ratio + coverage),
    - build EntitySet per split,
    - DFS on train, reuse features on val/test,
    - drop join keys from X and never include target in X,
    - return dict with (X, y, engineered_df) per split.
    """
    # --- Load DB tables (pandas) ---
    db = dataset.get_db()
    all_tables: Dict[str, pd.DataFrame] = {
        name: (tbl.df.copy() if hasattr(tbl, "df") else tbl.to_pandas())
        for name, tbl in db.table_dict.items()
    }
    all_tables = {k: _clean_for_ft(v) for k, v in all_tables.items()}

    base_name = getattr(splits["train"], "name", MAIN_TABLE)
    # include MAIN_TABLE (TRAIN) into inference so relations can touch the target df
    pop = {split: tbl.df.copy() for split, tbl in splits.items()}
    combined_for_schema = dict(all_tables)
    combined_for_schema[base_name] = _clean_for_ft(pop["train"].copy())

    # --- Automatic PK/FK inference (no manual schema) ---
    schema = _infer_pk_fk_graph_auto(combined_for_schema)
    reasons = schema.get("debug", {}).get("reasons", [])
    print(f"[Schema] inferred={len(reasons)} examples:", reasons[:8])

    # --- Prune cycles to prevent DFS recursion ---
    schema = _prune_fk_cycles(schema, base_name=base_name)

    if schema.get("debug", {}).get("reasons"):
        print("[Schema] inferred relationships (sample):", schema["debug"]["reasons"][:10])
    else:
        print("[Schema] WARNING: no relationships inferred. Consider lowering FK_COVERAGE_THRESHOLD.")

    # --- Build ES per split with the same schema ---
    es_train = _make_es_for_split(base_name, pop["train"], all_tables, schema)
    es_val   = _make_es_for_split(base_name, pop["val"],   all_tables, schema)
    es_test  = _make_es_for_split(base_name, pop["test"],  all_tables, schema)

    # --- DFS and aligned feature matrices ---
    fe_train, fe_val, fe_test, feature_defs = _dfs_feature_matrices(es_train, es_val, es_test, target_df=base_name)

    # --- Prepare (X, y) and drop keys/target from X ---
    def _drop_keys(dfX: pd.DataFrame) -> pd.DataFrame:
        if not DROP_JOIN_KEYS_FROM_X:
            return dfX
        drop_cols = set()
        pop_pk = schema["pkeys"].get(base_name)
        if pop_pk and pop_pk in dfX.columns:
            drop_cols.add(pop_pk)
        for parent, rels in schema["fkeys"].items():
            for r in rels:
                fk = r["fk"]
                if fk in dfX.columns:
                    drop_cols.add(fk)
        return dfX.drop(columns=list(drop_cols), errors="ignore")

    def _to_Xy(fe_df: pd.DataFrame, raw_df: pd.DataFrame):
        X = fe_df.copy()
        if task.target_col in X.columns:
            X = X.drop(columns=[task.target_col], errors="ignore")
        X = _drop_keys(X)
        y = raw_df[task.target_col].to_numpy()
        return X, y

    Xtr, ytr = _to_Xy(fe_train, pop["train"])
    Xva, yva = _to_Xy(fe_val,   pop["val"])
    Xte, yte = _to_Xy(fe_test,  pop["test"])

    return {
        "train": (Xtr, ytr, fe_train),
        "val":   (Xva, yva, fe_val),
        "test":  (Xte, yte, fe_test),
    }


# Merged-mode diagnostics

In [32]:
# -------- Diagnostics: single vs merged feature spaces --------
def compare_single_vs_merged(di_single, di_merged, target_name: str, label="train"):
    Xs, ys, df_s = di_single[label]
    Xm, ym, df_m = di_merged[label]

    print(f"[Diag:{label}] X_single: {Xs.shape}, X_merged: {Xm.shape}")
    print(f"[Diag:{label}] y match: {np.allclose(ys, ym)}")

    # columns present only in merged
    only_m = [c for c in Xm.columns if c not in Xs.columns]
    only_s = [c for c in Xs.columns if c not in Xm.columns]
    print(f"[Diag:{label}] new cols in merged: {len(only_m)} ; dropped vs merged: {len(only_s)}")

    # low-variance check on merged-only columns
    if only_m:
        lv = Xm[only_m].nunique(dropna=True)
        print(f"[Diag:{label}] merged-only columns nunique (head):")
        print(lv.sort_values().head(10))

    try:
        from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
        y = ys
        # pick MI type by y dtype
        if pd.api.types.is_integer_dtype(y) or pd.api.types.is_bool_dtype(y):
            mi_fun = mutual_info_classif
        else:
            mi_fun = mutual_info_regression

        # compute MI for up to 200 merged-only cols to keep it fast
        subset = only_m[:200]
        if subset:
            mi = mi_fun(pd.DataFrame(Xm[subset]).fillna(0), y, discrete_features='auto', random_state=0)
            mi_s = pd.Series(mi, index=subset).sort_values(ascending=False)
            print(f"[Diag:{label}] top merged-only MI features:")
            print(mi_s.head(10))
        else:
            print(f"[Diag:{label}] no merged-only columns to test MI.")
    except Exception as e:
        print(f"[Diag:{label}] MI calc skipped: {e}")


## Vectorization Wrapper (Version-Safe)

Initializes a `TableVectorizer` with only supported arguments for the installed `skrub` or `dirty_cat` version, ensuring compatibility. Transforms `train`, `val`, and `test` splits into numerical feature matrices, converting them to dense format if necessary.


### Helper Functions for Vectorization

In [33]:
def _make_table_vectorizer():
    """
    Create a TableVectorizer with version-safe arguments.
    """
    sig = inspect.signature(TableVectorizer.__init__)
    allowed = set(sig.parameters.keys()) - {"self"}

    tv_kwargs = {}

    # Only set kwargs that actually exist in the installed version
    if "cardinality_threshold" in allowed:
        tv_kwargs["cardinality_threshold"] = globals().get("CARDINALITY_THRESHOLD", 1000)

    # Some versions expose this; others don't, guard it
    if "high_cardinality_transformer" in allowed:
        tv_kwargs["high_cardinality_transformer"] = globals().get("HIGH_CARD_TRANSFORMER", "hashing")

    # Optional knobs if you define them globally and the version supports them
    if "text_separator" in allowed and "TEXT_SEPARATOR" in globals():
        tv_kwargs["text_separator"] = globals()["TEXT_SEPARATOR"]
    if "numerical_transformer" in allowed and "NUMERICAL_TRANSFORMER" in globals():
        tv_kwargs["numerical_transformer"] = globals()["NUMERICAL_TRANSFORMER"]
    if "categorical_transformer" in allowed and "CATEGORICAL_TRANSFORMER" in globals():
        tv_kwargs["categorical_transformer"] = globals()["CATEGORICAL_TRANSFORMER"]

    return TableVectorizer(**tv_kwargs)

def _to_dense(X):
    """
    Convert a sparse matrix to a dense NumPy array.
    """
    try:
        # scipy sparse matrices have .toarray()
        return X.toarray() if hasattr(X, "toarray") else X
    except Exception:
        return X

def vectorize_splits(X_train, X_val, X_test):
    """
    Vectorize train, val, and test splits using a TableVectorizer.
    """
    # Fit only on training data to prevent data leakage
    tv = _make_table_vectorizer()
    Xt = _to_dense(tv.fit_transform(X_train))
    Xv = _to_dense(tv.transform(X_val))
    Xs = _to_dense(tv.transform(X_test))
    return tv, Xt, Xv, Xs


### Training and Prediction Helpers

In [34]:
def _subsample(X, y, cap=TABPFN_MAX, seed=SEED):
    """
    Subsample the dataset to a maximum size defined by `cap`.
    """
    if len(X) <= cap:
        return X, y, np.arange(len(X))
    idx = np.random.RandomState(seed).choice(len(X), size=cap, replace=False)
    if hasattr(X, "iloc"):
        Xs = X.iloc[idx]
    else:
        Xs = X[idx]
    ys = y[idx]
    return Xs, ys, idx

def _fit_tabpfn(task, Xt, yt):
    """
    Fit a TabPFN model for the given task type.
    """
    if task.task_type == TaskType.REGRESSION and TabPFNRegressor is not None:
        model = TabPFNRegressor(
            device=DEVICE,
            #n_estimators=int(N_ESTIMATORS),
            ignore_pretraining_limits=True,
        )
    else:
        model = TabPFNClassifier(
            device=DEVICE,
            #n_estimators=int(N_ESTIMATORS),
            ignore_pretraining_limits=True,
        )
    model.fit(Xt, yt)
    return model

def _predict_for_task(task, model, X):
    """
    Make predictions using the fitted model for the given task type.
    """
    # align with RelBench evaluators: AUROC expects probabilities for the positive class
    if task.task_type == TaskType.REGRESSION:
        return model.predict(X)
    proba = model.predict_proba(X)
    if task.task_type == TaskType.BINARY_CLASSIFICATION:
        return proba[:, 1]
    else:
        # multiclass/multilabel: pass full probability matrix
        return proba

### Run TabPFN on Selected Tasks

Runs TabPFN on a specified dataset and task, handling both single-table and merged-table modes. It vectorizes the data, fits the model, makes predictions, and evaluates performance using RelBench’s evaluators. Returns a dictionary with results.

In [35]:
dataset = get_dataset(DATASET, download=DOWNLOAD)

for task_name in TASKS:
    task, splits = fetch_splits(DATASET, task_name, download=DOWNLOAD)

    frames_single = build_single_table_frames(task, splits)
    frames_merged = build_merged_table_frames(dataset, task, splits)

    # Display head of train, val and test of frames_merged
    # Show all column names and count for merged train split
    merged_train_df = frames_merged["train"][2]
    print("Single-table (train) head:")
    print(frames_single["train"][2].head(5))

    print("\nMerged-table (train) head:")
    print(frames_merged["train"][2].head(5))


    #compare_single_vs_merged(frames_single, frames_merged, target_name=task.target_col, label="train")

    # ---- Diagnostics: which columns are dropped vs kept? ----

    def _cols_dropped_between(engineered_df: pd.DataFrame, X_after_drop: pd.DataFrame):
        eng_cols = list(engineered_df.columns)
        X_cols = list(X_after_drop.columns)
        dropped = [c for c in eng_cols if c not in X_cols]
        return dropped

    def show_drop_report(frames_single, frames_merged):
        print("=== SINGLE (train) ===")
        Xs, ys, df_s = frames_single["train"]
        dropped_s = _cols_dropped_between(df_s, Xs)
        print(f"Kept: {Xs.shape[1]} | Dropped: {len(dropped_s)}")
        if dropped_s:
            print("Dropped columns (single):", dropped_s[:50])  # cap for readability
        else:
            print("Dropped columns (single): []")

        print("\n=== MERGED (train) ===")
        Xm, ym, df_m = frames_merged["train"]
        dropped_m = _cols_dropped_between(df_m, Xm)
        print(f"Kept: {Xm.shape[1]} | Dropped: {len(dropped_m)}")
        if dropped_m:
            print("Dropped columns (merged):", dropped_m[:50])
        else:
            print("Dropped columns (merged): []")

    # Call it once after both frames are built:
    show_drop_report(frames_single, frames_merged)



Loading Database object from /Users/michaelflppv/Library/Caches/relbench/rel-f1/db...
Done in 0.01 seconds.
[Schema] inferred=10 examples: ['drivers.driverId <- qualifying.driverId (coverage=1.000)', 'drivers.driverId <- results.driverId (coverage=1.000)', 'races.raceId <- results.raceId (coverage=1.000)', 'drivers.driverId <- standings.driverId (coverage=1.000)', 'races.raceId <- standings.raceId (coverage=1.000)', 'circuits.circuitId <- races.circuitId (coverage=1.000)', 'races.raceId <- constructor_results.raceId (coverage=1.000)', 'constructors.constructorId <- constructor_results.constructorId (coverage=1.000)']
[Schema] inferred relationships (sample): ['drivers.driverId <- qualifying.driverId (coverage=1.000)', 'drivers.driverId <- results.driverId (coverage=1.000)', 'races.raceId <- results.raceId (coverage=1.000)', 'drivers.driverId <- standings.driverId (coverage=1.000)', 'races.raceId <- standings.raceId (coverage=1.000)', 'circuits.circuitId <- races.circuitId (coverage=1.0

In [36]:
def run_tabpfn_on_task(dataset_name: str, task_name: str, mode: str = "single") -> Dict[str, Any]:
    """
    Run TabPFN on a specified dataset and task, handling both single-table and merged-table modes.
    """
    # Load dataset and task splits
    dataset = get_dataset(dataset_name, download=DOWNLOAD)
    task, splits = fetch_splits(dataset_name, task_name, download=DOWNLOAD)

    # Ensure the task is compatible with TabPFN
    if mode == "single":
        frames = build_single_table_frames(task, splits)
    elif mode == "merged":
        try:
            frames = build_merged_table_frames(dataset, task, splits)
        except Exception as e:
            print(f"merged mode failed: {e}, falling back to single-table mode.")
            raise
    else:
        raise ValueError("mode must be 'single' or 'merged'")

    # Extract features and targets for each split
    (Xtr, ytr, _dftr) = frames["train"]
    (Xva, yva, dfva)  = frames["val"]
    (Xte, yte, dfte)  = frames["test"]

    # Vectorize
    tv, Xt, Xv, Xs = vectorize_splits(Xtr, Xva, Xte)

    # Respect TabPFN's sample cap
    Xt_cap, yt_cap, _ = _subsample(Xt, ytr, cap=TABPFN_MAX, seed=SEED)

    # Fit
    model = _fit_tabpfn(task, Xt_cap, yt_cap)

    # Predict & Evaluate with RelBench evaluators
    val_pred  = _predict_for_task(task, model, Xv)
    test_pred = _predict_for_task(task, model, Xs)

    # Align predictions with original DataFrame indices for evaluation
    val_metrics  = task.evaluate(val_pred,  splits["val"])
    test_metrics = task.evaluate(test_pred, splits["test"])

    # Convert metrics to a dictionary, ensuring all values are floats
    out = {
        "dataset": dataset_name,
        "task": task_name,
        "mode": mode,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "n_train_used": len(Xt_cap),
        "n_train_total": len(Xt),
        "n_val": len(Xv),
        "n_test": len(Xs),
    }
    return out

## Orchestrator for Benchmark Runs

Iterates over all discovered tasks and runs TabPFN in both **single** and **merged** modes. Collects performance metrics for validation and test splits into a results table, handling failures gracefully. Sorts results for easier comparison.


In [None]:
MODES = globals().get("MODES", ["single", "merged"])

records = []
failures = []

# Run TabPFN on all tasks in both modes and collect results
for task_name in TASKS:
    for mode in MODES:
        try:
            res = run_tabpfn_on_task(DATASET, task_name, mode=mode)
            # Flatten metrics for val and test, one metric per row
            for split in ["val", "test"]:
                metrics = res.get(f"{split}_metrics") or {}
                for metric_name, metric_value in metrics.items():
                    # Only add rows with non-empty metric_value
                    if metric_value is not None and not (isinstance(metric_value, float) and np.isnan(metric_value)):
                        records.append({
                            "dataset": res.get("dataset", DATASET),
                            "task": res.get("task", task_name),
                            "split": split,
                            "mode": res.get("mode", mode),
                            "method": "TabPFN_experimental_v1.0",
                            "metric": metric_name,
                            "score": metric_value,
                        })
        except Exception as e:
            msg = f"[{DATASET} | {task_name} | {mode}] failed: {e!s}"
            print(msg)
            failures.append(msg)

# Convert collected records into a DataFrame
results_df = pd.DataFrame.from_records(records)

# If no successful runs were recorded, display a message and create an empty DataFrame
if results_df.empty:
    print("No successful runs were recorded. Check the failure messages above.")
    results_df = pd.DataFrame(
        columns=["dataset", "task", "split", "mode", "method", "metric", "score"]
    )
else:
    # Ensure required sort keys exist even if some rows missed them
    for col in ["task", "mode"]:
        if col not in results_df.columns:
            results_df[col] = pd.NA
    # Sort only by the columns that exist to avoid KeyError
    sort_keys = [c for c in ["task", "mode", "metric"] if c in results_df.columns]
    if sort_keys:
        results_df = results_df.sort_values(sort_keys)

results_df

[Schema] inferred=10 examples: ['drivers.driverId <- qualifying.driverId (coverage=1.000)', 'drivers.driverId <- results.driverId (coverage=1.000)', 'races.raceId <- results.raceId (coverage=1.000)', 'drivers.driverId <- standings.driverId (coverage=1.000)', 'races.raceId <- standings.raceId (coverage=1.000)', 'circuits.circuitId <- races.circuitId (coverage=1.000)', 'races.raceId <- constructor_results.raceId (coverage=1.000)', 'constructors.constructorId <- constructor_results.constructorId (coverage=1.000)']
[Schema] inferred relationships (sample): ['drivers.driverId <- qualifying.driverId (coverage=1.000)', 'drivers.driverId <- results.driverId (coverage=1.000)', 'races.raceId <- results.raceId (coverage=1.000)', 'drivers.driverId <- standings.driverId (coverage=1.000)', 'races.raceId <- standings.raceId (coverage=1.000)', 'circuits.circuitId <- races.circuitId (coverage=1.000)', 'races.raceId <- constructor_results.raceId (coverage=1.000)', 'constructors.constructorId <- construc

### Save Results to CSV

In [None]:
# Specify output directory
out_dir = globals().get("OUT_DIR", "outputs")
os.makedirs(out_dir, exist_ok=True)

# Change timestamp format to "dd.mm.yyyy-hh:mm"
timestamp = time.strftime("%d.%m.%Y-%H:%M")
csv_name = f"tabpfn_{DATASET}_{timestamp}.csv"
csv_path = os.path.join(out_dir, csv_name)

# Round all numerical results to 4 decimal places before saving
if "score" in results_df.columns:
    if pd.api.types.is_numeric_dtype(results_df["score"]):
        results_df["score"] = results_df["score"].round(4)
    else:
        # Optionally, try to convert to numeric first
        results_df["score"] = pd.to_numeric(results_df["score"], errors="coerce").round(4)

# Filter out rows with empty score values before saving
if "score" in results_df.columns:
    results_df = results_df[results_df["score"].notnull()]

# Save the results DataFrame to a CSV file
results_df.to_csv(csv_path, index=False)
print(f"Saved results to: {csv_path}")
