In [12]:
# %%
import os
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    f1_score,
    mean_absolute_error,
    mean_squared_error,
    average_precision_score
)

import torch

# RelBench
from relbench.datasets import get_dataset
from relbench.tasks import get_task

# TabPFN
from tabpfn import TabPFNClassifier, TabPFNRegressor

# Device selection with MPS preference
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

Using device: mps


In [13]:
# %%
@contextmanager
def elapsed_timer():
    start = time.perf_counter()
    yield lambda: time.perf_counter() - start


def classification_metrics(y_true, y_pred, y_prob=None) -> Dict[str, float]:
    out = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
    }
    if y_prob is not None:
        try:
            out["roc_auc"] = roc_auc_score(y_true, y_prob)
        except Exception:
            out["roc_auc"] = np.nan
    else:
        out["roc_auc"] = np.nan
    return out


def regression_metrics(y_true, y_pred) -> Dict[str, float]:
    return {
        "mae": mean_absolute_error(y_true, y_pred),
        "mse": mean_squared_error(y_true, y_pred),
    }


def to_pandas(table) -> pd.DataFrame:
    if hasattr(table, "to_pandas"):
        return table.to_pandas()
    if hasattr(table, "df"):
        return table.df if isinstance(table.df, pd.DataFrame) else pd.DataFrame(table.df)
    return pd.DataFrame(table)


def coerce_datetime(df: pd.DataFrame, col: str) -> pd.DataFrame:
    if col in df.columns:
        if not np.issubdtype(df[col].dtype, np.datetime64):
            df[col] = pd.to_datetime(df[col], errors="coerce")
    return df


def first_existing(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    for c in candidates:
        if c in df.columns:
            return c
    return None


def merge_asof_by_group(left_idx: pd.DataFrame,
                        right_feat: pd.DataFrame,
                        key_col: str,
                        date_col: str = "date") -> pd.DataFrame:
    """
    A leak-safe asof join: for each group by key_col, join the last known
    feature snapshot at or before 'date' (shift should be applied in the feature builder).
    """
    if right_feat.empty:
        return pd.DataFrame(index=left_idx.index)

    parts = []
    for k, g_left in left_idx.groupby(key_col, sort=False):
        g_right = right_feat[right_feat[key_col] == k]
        if g_right.empty:
            parts.append(pd.DataFrame(index=g_left.index))
            continue
        merged = pd.merge_asof(
            g_left.sort_values(date_col),
            g_right.sort_values(date_col),
            on=date_col, direction="backward"
        )
        merged.index = g_left.sort_values(date_col).index  # restore original row order
        # drop keys/date from the right side (they duplicate)
        drop_cols = [c for c in [key_col, date_col] if c in merged.columns]
        merged = merged.drop(columns=[c for c in drop_cols if c in merged.columns], errors="ignore")
        parts.append(merged.sort_index())

    return pd.concat(parts, axis=0).sort_index()


In [14]:
# %%
dataset = get_dataset("rel-avito")
db = dataset.get_db()

def load_all_tables(db) -> Dict[str, pd.DataFrame]:
    names = db.table_dict
    out = {}
    for n in names:
        out[n] = to_pandas(db.table_dict[n])
        # Limit each table to 1000 samples
        if len(out[n]) > 1000:
            out[n] = out[n].sample(n=1000, random_state=42).reset_index(drop=True)
    return out

tables = load_all_tables(db)
list(tables.keys())

['Location',
 'PhoneRequestsStream',
 'SearchInfo',
 'VisitStream',
 'Category',
 'SearchStream',
 'UserInfo',
 'AdsInfo']

In [15]:
# %%
# Normalize date columns for relevant tables (Avito schema)
# We will standardize to 'date' column names inside our engineered frames.
date_mappings = {
    "SearchInfo": ["SearchDate"],
    "SearchStream": ["SearchDate"],
    "VisitStream": ["ViewDate"],
    "PhoneRequestsStream": ["PhoneRequestDate"]
}

for tname, df in tables.items():
    # coerce known date columns if exist
    if tname in date_mappings:
        for c in date_mappings[tname]:
            if c in df.columns:
                df[c] = pd.to_datetime(df[c], errors="coerce")
    # also coerce any column containing 'date' substring (robustness)
    for c in df.columns:
        if "date" in c.lower():
            df[c] = pd.to_datetime(df[c], errors="coerce")


In [16]:
# %%
TASK_SPECS = {
    "user-visits":    {"kind": "clf",  "join": "user"},
    "user-clicks":    {"kind": "clf",  "join": "user"},
    "ad-ctr":         {"kind": "reg",  "join": "ad"},
    "user-ad-visit":  {"kind": "link", "join": "user-ad"}  # probabilities → MAP@k via task.evaluate
}

TASKS = list(TASK_SPECS.keys())
TASKS


['user-visits', 'user-clicks', 'ad-ctr', 'user-ad-visit']

In [17]:
# %%
def run_single_table_experiment(task_name: str):
    spec = TASK_SPECS[task_name]
    task = get_task("rel-avito", task_name)

    # Create a training, validation, and test splits
    train_table = task.get_table("train")
    val_table = task.get_table("val")
    test_table = task.get_table("test", mask_input_cols=False)

    # Input features need to be numeric, otherwise encoded
    df = train_table.df
    if len(df) > 1000:
        df = df.sample(n=1000, random_state=42).reset_index(drop=True)
    X_train = df.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_train = df[task.target_col]

    df = val_table.df
    if len(df) > 1000:
        df = df.sample(n=1000, random_state=42).reset_index(drop=True)
    X_val = df.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_val = df[task.target_col]

    df = test_table.df
    if len(df) > 1000:
        df = df.sample(n=1000, random_state=42).reset_index(drop=True)
    X_test = df.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_test = df[task.target_col]

    # Choose model
    if spec["kind"] == "reg":
        model = TabPFNRegressor(device=DEVICE)
        is_reg = True
    else:
        model = TabPFNClassifier(device=DEVICE)
        is_reg = False

    # Fit
    with elapsed_timer() as t:
        model.fit(X_train, y_train)
    fit_time = t()

    # Predict val/test
    with elapsed_timer() as t:
        y_val_pred = model.predict(X_val)
    pred_time_val = t()

    with elapsed_timer() as t:
        y_test_pred = model.predict(X_test)
    pred_time_test = t()

    # Probabilities for clf/link
    prob_val = prob_test = None
    if not is_reg:
        try:
            proba_val  = model.predict_proba(X_val)
            proba_test = model.predict_proba(X_test)
            # binary -> keep positive class prob
            if proba_val.ndim == 2 and proba_val.shape[1] == 2:
                prob_val = proba_val[:, 1]
                prob_test = proba_test[:, 1]
            else:
                # multiclass or already 1D
                prob_val = proba_val
                prob_test = proba_test
        except Exception:
            prob_val = prob_test = None

    # Secondary metrics
    if is_reg:
        val_sec  = {**regression_metrics(y_val, y_val_pred)}
        test_sec = {**regression_metrics(y_test, y_test_pred)}
    else:
        val_sec  = {**classification_metrics(y_val, y_val_pred, prob_val)}
        test_sec = {**classification_metrics(y_test, y_test_pred, prob_test)}

    # Primary metric via RelBench
    if spec["kind"] in ["clf", "link"]:
        # RelBench expects probabilities/scores
        val_primary  = task.evaluate(prob_val if prob_val is not None else y_val_pred, "val")
        test_primary = task.evaluate(prob_test if prob_test is not None else y_test_pred)
    else:
        # regression expects numeric predictions
        val_primary  = task.evaluate(y_val_pred, "val")
        test_primary = task.evaluate(y_test_pred)

    res = {
        "val":  {**val_sec,  "fit_time": fit_time, "predict_time": pred_time_val,  "primary_metric_relbench": val_primary},
        "test": {**test_sec, "fit_time": fit_time, "predict_time": pred_time_test, "primary_metric_relbench": test_primary},
    }
    return res


In [18]:
# %%
def build_user_timeseries(tables: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """
    Returns per-user historical features keyed by ['UserID','date'].
    Features are cumulative/expanding and SHIFTED by 1 to avoid look-ahead leakage.
    """
    # --- Source 1: VisitStream (views)
    vs = tables.get("VisitStream", pd.DataFrame()).copy()
    if not vs.empty:
        vs = vs.dropna(subset=["UserID"])
        date_col = first_existing(vs, ["ViewDate", "date", "Date", "viewDate"])
        if date_col is None:
            vs = pd.DataFrame(columns=["UserID", "AdID", "date"])
        else:
            vs = coerce_datetime(vs, date_col)
            vs = vs.rename(columns={date_col: "date"})
            required_vs_cols = ["UserID", "AdID", "date"]
            existing_vs_cols = [c for c in required_vs_cols if c in vs.columns]
            vs = vs[existing_vs_cols].dropna(subset=["date"]) if "date" in existing_vs_cols else pd.DataFrame(columns=required_vs_cols)
    else:
        vs = pd.DataFrame(columns=["UserID", "AdID", "date"])

    # --- Source 2: SearchStream (clicks)
    ss = tables.get("SearchStream", pd.DataFrame()).copy()
    if not ss.empty:
        if "UserID" in ss.columns:
            ss = ss.dropna(subset=["UserID"])
        date_col = first_existing(ss, ["SearchDate", "date", "Date"])
        if date_col is None:
            ss = pd.DataFrame(columns=["UserID", "AdID", "IsClick", "date"])
        else:
            ss = coerce_datetime(ss, date_col)
            ss = ss.rename(columns={date_col: "date"})
            if "IsClick" not in ss.columns:
                ss["IsClick"] = np.nan
            required_ss_cols = ["UserID", "AdID", "IsClick", "date"]
            existing_ss_cols = [c for c in required_ss_cols if c in ss.columns]
            ss = ss[existing_ss_cols].dropna(subset=["date"]) if "date" in existing_ss_cols else pd.DataFrame(columns=required_ss_cols)
    else:
        ss = pd.DataFrame(columns=["UserID", "AdID", "IsClick", "date"])

    # --- Source 3: PhoneRequestsStream (optional activity proxy)
    prs = tables.get("PhoneRequestsStream", pd.DataFrame()).copy()
    if not prs.empty and "UserID" in prs.columns:
        date_col = first_existing(prs, ["PhoneRequestDate", "date", "Date"])
        if date_col is not None:
            prs = coerce_datetime(prs, date_col)
            prs = prs.rename(columns={date_col: "date"})
            required_prs_cols = ["UserID", "AdID", "date"]
            existing_prs_cols = [c for c in required_prs_cols if c in prs.columns]
            prs = prs[existing_prs_cols].dropna(subset=["date"]) if "date" in existing_prs_cols else pd.DataFrame(columns=required_prs_cols)
        else:
            prs = pd.DataFrame(columns=["UserID", "AdID", "date"])
    else:
        prs = pd.DataFrame(columns=["UserID", "AdID", "date"])

    # Build a combined event log per user (per-day granularity)
    # Views
    if all(c in vs.columns for c in ["UserID", "date", "AdID"]):
        vs["view_cnt"] = 1
        vs_agg = vs.groupby(["UserID", "date"]).agg(view_cnt=("view_cnt", "sum"),
                                                    ad_visited_n=("AdID", "nunique")).reset_index()
    else:
        vs_agg = pd.DataFrame(columns=["UserID", "date", "view_cnt", "ad_visited_n"])
    # Clicks
    if all(c in ss.columns for c in ["UserID", "date", "AdID"]):
        if "IsClick" in ss.columns:
            ss["click_cnt"] = (ss["IsClick"] == 1).astype(int)
        else:
            ss["click_cnt"] = 0
        ss_agg = ss.groupby(["UserID", "date"]).agg(click_cnt=("click_cnt", "sum"),
                                                    ad_clicked_n=("AdID", "nunique")).reset_index()
    else:
        ss_agg = pd.DataFrame(columns=["UserID", "date", "click_cnt", "ad_clicked_n"])
    # Phone reqs
    if all(c in prs.columns for c in ["UserID", "date", "AdID"]):
        prs["phone_cnt"] = 1
        prs_agg = prs.groupby(["UserID", "date"]).agg(phone_cnt=("phone_cnt", "sum"),
                                                      ad_phone_n=("AdID", "nunique")).reset_index()
    else:
        prs_agg = pd.DataFrame(columns=["UserID", "date", "phone_cnt", "ad_phone_n"])

    # Merge daily aggregates
    daily = (vs_agg.merge(ss_agg, on=["UserID", "date"], how="outer")
                  .merge(prs_agg, on=["UserID", "date"], how="outer"))
    daily = daily.sort_values(["UserID", "date"]).fillna(0)

    # Expanding historical features, shifted by 1
    def expanding_shifted(g: pd.DataFrame) -> pd.DataFrame:
        out = pd.DataFrame(index=g.index)
        for c in ["view_cnt", "ad_visited_n", "click_cnt", "ad_clicked_n", "phone_cnt", "ad_phone_n"]:
            if c in g:
                out[f"usr_cum_{c}"] = g[c].cumsum().shift(1)
                out[f"usr_avg_{c}"] = g[c].expanding().mean().shift(1)
        # Activity rates
        if "click_cnt" in g and "view_cnt" in g:
            rate = (g["click_cnt"].replace(0, np.nan) / g["view_cnt"].replace(0, np.nan))
            out["usr_rate_click_per_view"] = rate.expanding().mean().shift(1)
        return out

    feats = daily.groupby("UserID", group_keys=False).apply(expanding_shifted)
    feats = pd.concat([daily[["UserID", "date"]], feats], axis=1)
    return feats


def build_ad_timeseries(tables: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """
    Returns per-ad historical features keyed by ['AdID','date'].
    """
    # Views
    vs = tables.get("VisitStream", pd.DataFrame()).copy()
    if not vs.empty and "AdID" in vs.columns:
        date_col = first_existing(vs, ["ViewDate", "date", "Date"])
        if date_col is not None:
            vs = coerce_datetime(vs, date_col).rename(columns={date_col: "date"})
            required_vs_cols = ["AdID", "UserID", "date"]
            existing_vs_cols = [c for c in required_vs_cols if c in vs.columns]
            vs = vs[existing_vs_cols].dropna(subset=["date"]) if "date" in existing_vs_cols else pd.DataFrame(columns=required_vs_cols)
        else:
            vs = pd.DataFrame(columns=["AdID", "UserID", "date"])
    else:
        vs = pd.DataFrame(columns=["AdID", "UserID", "date"])

    # Clicks (SearchStream)
    ss = tables.get("SearchStream", pd.DataFrame()).copy()
    if not ss.empty and "AdID" in ss.columns:
        date_col = first_existing(ss, ["SearchDate", "date", "Date"])
        if date_col is not None:
            ss = coerce_datetime(ss, date_col).rename(columns={date_col: "date"})
            if "IsClick" not in ss.columns:
                ss["IsClick"] = 0
            required_ss_cols = ["AdID", "UserID", "IsClick", "date"]
            existing_ss_cols = [c for c in required_ss_cols if c in ss.columns]
            ss = ss[existing_ss_cols].dropna(subset=["date"]) if "date" in existing_ss_cols else pd.DataFrame(columns=required_ss_cols)
        else:
            ss = pd.DataFrame(columns=["AdID", "UserID", "IsClick", "date"])
    else:
        ss = pd.DataFrame(columns=["AdID", "UserID", "IsClick", "date"])

    # Aggregate per day per Ad
    if all(c in vs.columns for c in ["AdID", "date", "UserID"]):
        vs["view_cnt"] = 1
        vs_agg = vs.groupby(["AdID", "date"]).agg(view_cnt=("view_cnt", "sum"),
                                                  user_view_n=("UserID", "nunique")).reset_index()
    else:
        vs_agg = pd.DataFrame(columns=["AdID", "date", "view_cnt", "user_view_n"])

    if all(c in ss.columns for c in ["AdID", "date", "UserID", "IsClick"]):
        ss["click_cnt"] = (ss["IsClick"] == 1).astype(int)
        ss_agg = ss.groupby(["AdID", "date"]).agg(click_cnt=("click_cnt", "sum"),
                                                  user_click_n=("UserID", "nunique")).reset_index()
    else:
        ss_agg = pd.DataFrame(columns=["AdID", "date", "click_cnt", "user_click_n"])

    daily = vs_agg.merge(ss_agg, on=["AdID", "date"], how="outer").sort_values(["AdID", "date"]).fillna(0)

    # Expanding and CTR
    def expanding_shifted(g: pd.DataFrame) -> pd.DataFrame:
        out = pd.DataFrame(index=g.index)
        for c in ["view_cnt", "user_view_n", "click_cnt", "user_click_n"]:
            if c in g:
                out[f"ad_cum_{c}"] = g[c].cumsum().shift(1)
                out[f"ad_avg_{c}"] = g[c].expanding().mean().shift(1)
        # Historical CTR
        if "click_cnt" in g and "view_cnt" in g:
            cum_click = g["click_cnt"].cumsum().shift(1)
            cum_view  = g["view_cnt"].cumsum().shift(1)
            out["ad_hist_ctr"] = (cum_click / cum_view).replace([np.inf, -np.inf], np.nan)
        return out

    feats = daily.groupby("AdID", group_keys=False).apply(expanding_shifted)
    feats = pd.concat([daily[["AdID", "date"]], feats], axis=1)

    # Attach static ad attributes (no time leakage)
    ads = tables.get("AdsInfo", pd.DataFrame()).copy()
    if not ads.empty and "AdID" in ads.columns:
        ads_static = ads.drop_duplicates("AdID")
        # Optional: enrich with Category / Location
        cat = tables.get("Category", pd.DataFrame()).copy()
        loc = tables.get("Location", pd.DataFrame()).copy()
        if not cat.empty and "CategoryID" in ads_static.columns:
            ads_static = ads_static.merge(cat.add_prefix("cat_"), left_on="CategoryID", right_on="cat_CategoryID", how="left")
        if not loc.empty and "LocationID" in ads_static.columns:
            ads_static = ads_static.merge(loc.add_prefix("loc_"), left_on="LocationID", right_on="loc_LocationID", how="left")
        feats = feats.merge(ads_static, on="AdID", how="left")
    return feats


In [19]:
# Compute cumulative user and ad features and shift them to avoid leakage.
# This will be used to enrich task indices later.
user_feats = build_user_timeseries(tables)
ad_feats   = build_ad_timeseries(tables)

user_feats.head(3), ad_feats.head(3)


  daily = daily.sort_values(["UserID", "date"]).fillna(0)
  feats = daily.groupby("UserID", group_keys=False).apply(expanding_shifted)
  daily = vs_agg.merge(ss_agg, on=["AdID", "date"], how="outer").sort_values(["AdID", "date"]).fillna(0)
  feats = daily.groupby("AdID", group_keys=False).apply(expanding_shifted)


(   UserID                date  usr_cum_view_cnt  usr_avg_view_cnt  \
 0      24 2015-05-13 22:53:35               NaN               NaN   
 1      31 2015-05-02 21:13:01               NaN               NaN   
 2      50 2015-04-29 14:17:19               NaN               NaN   
 
    usr_cum_ad_visited_n  usr_avg_ad_visited_n  usr_cum_click_cnt  \
 0                   NaN                   NaN                NaN   
 1                   NaN                   NaN                NaN   
 2                   NaN                   NaN                NaN   
 
    usr_avg_click_cnt  usr_cum_ad_clicked_n  usr_avg_ad_clicked_n  \
 0                NaN                   NaN                   NaN   
 1                NaN                   NaN                   NaN   
 2                NaN                   NaN                   NaN   
 
    usr_cum_phone_cnt  usr_avg_phone_cnt  usr_cum_ad_phone_n  \
 0                NaN                NaN                 NaN   
 1                NaN             

In [20]:
# %%
def index_to_df(idx: np.ndarray, mode: str) -> pd.DataFrame:
    """
    Convert task index array to a DataFrame with standardized column names.
    mode:
      - 'user'     -> expects (UserID, date)
      - 'ad'       -> expects (AdID, date)
      - 'user-ad'  -> expects (UserID, AdID, date)
    """
    arr = np.array(idx)
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)
    if mode == "user":
        assert arr.shape[1] == 2, f"Expected 2 columns in index for 'user', got {arr.shape[1]}"
        df = pd.DataFrame(arr, columns=["UserID", "date"])
    elif mode == "ad":
        assert arr.shape[1] == 2, f"Expected 2 columns in index for 'ad', got {arr.shape[1]}"
        df = pd.DataFrame(arr, columns=["AdID", "date"])
    elif mode == "user-ad":
        assert arr.shape[1] == 3, f"Expected 3 columns in index for 'user-ad', got {arr.shape[1]}"
        df = pd.DataFrame(arr, columns=["UserID", "AdID", "date"])
    else:
        raise ValueError(mode)
    df["date"] = pd.to_datetime(df["date"], errors="coerce")
    return df


def enrich_split(X: pd.DataFrame, idx: np.ndarray, join_mode: str) -> pd.DataFrame:
    """
    Given original task features X and its index, attach engineered features.
    """
    idx_df = index_to_df(idx, join_mode)
    base = X.reset_index(drop=True).copy()

    if join_mode == "user":
        # join user features by (UserID, date) asof within user
        left = idx_df[["UserID", "date"]].copy()
        uf = merge_asof_by_group(left, user_feats, key_col="UserID", date_col="date")
        out = pd.concat([base, uf.reset_index(drop=True)], axis=1)

    elif join_mode == "ad":
        left = idx_df[["AdID", "date"]].copy()
        af = merge_asof_by_group(left, ad_feats, key_col="AdID", date_col="date")
        out = pd.concat([base, af.reset_index(drop=True)], axis=1)

    elif join_mode == "user-ad":
        # Merge user and ad streams separately, then concat
        left_u = idx_df[["UserID", "date"]].copy()
        left_a = idx_df[["AdID", "date"]].copy()
        uf = merge_asof_by_group(left_u, user_feats, key_col="UserID", date_col="date")
        af = merge_asof_by_group(left_a, ad_feats,   key_col="AdID",   date_col="date")
        # Prefix to avoid collisions
        uf = uf.add_prefix("u_")
        af = af.add_prefix("a_")
        out = pd.concat([base, uf.reset_index(drop=True), af.reset_index(drop=True)], axis=1)
    else:
        raise ValueError(join_mode)

    return out


In [21]:
# %%
def run_merged_table_experiment(task_name: str):
    spec = TASK_SPECS[task_name]
    task = get_task("rel-avito", task_name)

    train_table = task.get_table("train")
    val_table = task.get_table("val")
    test_table = task.get_table("test", mask_input_cols=False)

    df_train = train_table.df
    if len(df_train) > 1000:
        df_train = df_train.sample(n=1000, random_state=42).reset_index(drop=True)
    df_val = val_table.df
    if len(df_val) > 1000:
        df_val = df_val.sample(n=1000, random_state=42).reset_index(drop=True)
    df_test = test_table.df
    if len(df_test) > 1000:
        df_test = df_test.sample(n=1000, random_state=42).reset_index(drop=True)

    # Determine index columns based on join mode
    if spec["join"] == "user":
        idx_train = list(zip(df_train["UserID"], df_train["date"]))
        idx_val   = list(zip(df_val["UserID"], df_val["date"]))
        idx_test  = list(zip(df_test["UserID"], df_test["date"]))
    elif spec["join"] == "ad":
        idx_train = list(zip(df_train["AdID"], df_train["date"]))
        idx_val   = list(zip(df_val["AdID"], df_val["date"]))
        idx_test  = list(zip(df_test["AdID"], df_test["date"]))
    elif spec["join"] == "user-ad":
        idx_train = list(zip(df_train["UserID"], df_train["AdID"], df_train["date"]))
        idx_val   = list(zip(df_val["UserID"], df_val["AdID"], df_val["date"]))
        idx_test  = list(zip(df_test["UserID"], df_test["AdID"], df_test["date"]))
    else:
        raise ValueError(f"Unknown join mode: {spec['join']}")

    X_train = df_train.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_train = df_train[task.target_col]

    X_val = df_val.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_val = df_val[task.target_col]

    X_test = df_test.drop(columns=[task.target_col]).select_dtypes(include=[np.number])
    y_test = df_test[task.target_col]

    # Choose model
    if spec["kind"] == "reg":
        model = TabPFNRegressor(device=DEVICE)
        is_reg = True
    else:
        model = TabPFNClassifier(device=DEVICE)
        is_reg = False

    # Train
    with elapsed_timer() as t:
        model.fit(X_train, y_train)
    fit_time = t()

    # Predict val/test
    with elapsed_timer() as t:
        y_val_pred = model.predict(X_val)
    pred_time_val = t()

    with elapsed_timer() as t:
        y_test_pred = model.predict(X_test)
    pred_time_test = t()

    # Probabilities for clf/link
    prob_val = prob_test = None
    if not is_reg:
        try:
            proba_val  = model.predict_proba(X_val)
            proba_test = model.predict_proba(X_test)
            if proba_val.ndim == 2 and proba_val.shape[1] == 2:
                prob_val  = proba_val[:, 1]
                prob_test = proba_test[:, 1]
            else:
                prob_val, prob_test = proba_val, proba_test
        except Exception:
            prob_val = prob_test = None

    # Secondary metrics
    if is_reg:
        val_sec  = {**regression_metrics(y_val, y_val_pred)}
        test_sec = {**regression_metrics(y_test, y_test_pred)}
    else:
        val_sec  = {**classification_metrics(y_val, y_val_pred, prob_val)}
        test_sec = {**classification_metrics(y_test, y_test_pred, prob_test)}

    # Primary via RelBench
    if spec["kind"] in ["clf", "link"]:
        val_primary  = task.evaluate(prob_val if prob_val is not None else y_val_pred, "val")
        test_primary = task.evaluate(prob_test if prob_test is not None else y_test_pred)
    else:
        val_primary  = task.evaluate(y_val_pred, "val")
        test_primary = task.evaluate(y_test_pred)

    res = {
        "val":  {**val_sec,  "fit_time": fit_time, "predict_time": pred_time_val,  "primary_metric_relbench": val_primary},
        "test": {**test_sec, "fit_time": fit_time, "predict_time": pred_time_test, "primary_metric_relbench": test_primary},
    }
    return res


In [22]:
# %%
# Directory to save tables and results
OUTPUT_DIR = "avito_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

all_rows = []

for task_name in TASKS:
    print(f"\n=== {task_name} | Single-Table ===")
    single = run_single_table_experiment(task_name)
    for split, metrics in single.items():
        # Print and save single-table split
        print(f"Task: {task_name}, Setting: single, Split: {split}")
        # Save the corresponding table
        table = get_task("rel-avito", task_name).get_table(split).df
        table_path = os.path.join(OUTPUT_DIR, f"{task_name}_single_{split}.csv")
        print(table.head())
        table.to_csv(table_path, index=False)
        all_rows.append({"task": task_name, "setting": "single", "split": split, **metrics})

    print(f"=== {task_name} | Merged-Table ===")
    merged = run_merged_table_experiment(task_name)
    for split, metrics in merged.items():
        # Print and save merged-table split
        print(f"Task: {task_name}, Setting: merged, Split: {split}")
        # Save the corresponding table
        table = get_task("rel-avito", task_name).get_table(split).df
        table_path = os.path.join(OUTPUT_DIR, f"{task_name}_merged_{split}.csv")
        print(table.head())
        table.to_csv(table_path, index=False)
        all_rows.append({"task": task_name, "setting": "merged", "split": split, **metrics})

results_df = pd.DataFrame(all_rows).sort_values(["task", "setting", "split"]).reset_index(drop=True)
results_csv_path = os.path.join(OUTPUT_DIR, "all_results.csv")
results_df.to_csv(results_csv_path, index=False)
results_df



=== user-visits | Single-Table ===


AttributeError: 'str' object has no attribute 'df'

In [None]:
# %%
def plot_metric(metric: str, title: Optional[str] = None):
    sub = results_df[(results_df["split"] == "test") & results_df[metric].notna()]
    if sub.empty:
        print(f"No data to plot for {metric}")
        return
    pivot = sub.pivot(index="task", columns="setting", values=metric)
    ax = pivot.plot(kind="bar", figsize=(9, 4))
    ax.set_ylabel(metric)
    ax.set_title(title or metric)
    ax.grid(True, axis="y")
    plt.tight_layout()
    plt.show()

# Classification/Link relevant (AUROC approximated via classification_metrics; primary in 'primary_metric_relbench')
plot_metric("roc_auc", "Test AUROC (classification/link)")
plot_metric("f1_macro", "Test F1 Macro (classification/link)")

# Regression
plot_metric("mae", "Test MAE (regression)")
plot_metric("mse", "Test MSE (regression)")

# Timing
plot_metric("fit_time", "Fit Time (s)")
plot_metric("predict_time", "Predict Time (s)")
