In [2]:
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

def load_adni_data(csv_path: str):
    """
    从 CSV 文件中加载 ADNI 数据：
      - 将 "Group" 列当作标签 y（假设它已经是数值型或 0/1 二分类）。
      - 将以下列当作特征 X：
          ['AGE', 'PTGENDER', 'PTEDUCAT', 'PTETHCAT', 'PTRACCAT', 
           'PTMARRY', 'APOE4', 'CDRSB', 'ADAS11', 'ADAS13', 'ADASQ4', 'MMSE']
    返回：
      X: numpy 数组，形状为 (n_samples, 12)
      y: numpy 数组，形状为 (n_samples,)
    """
    df = pd.read_csv(csv_path)

    # 确保我们需要的列都存在
    feature_cols = [
        "AGE", "PTGENDER", "PTEDUCAT", "PTETHCAT", "PTRACCAT",
        "PTMARRY", "APOE4", "CDRSB", "ADAS11", "ADAS13", "ADASQ4", "MMSE"
    ]
    for col in feature_cols + ["Group"]:
        if col not in df.columns:
            raise ValueError(f"在 CSV 中找不到列：'{col}'，请检查列名是否拼写正确。")

    # 如果某些列存在空值（NaN），可以选择在这里进行填充或删除：
    # df = df.dropna(subset=feature_cols + ["Group"])        # 方法一：删除含 NaN 的行
    # df[feature_cols] = df[feature_cols].fillna(df[feature_cols].median())  # 方法二：用中位数填充

    # 提取标签
    y = df["Group"].values

    # 提取特征
    X = df[feature_cols].values

    return X, y

# ---------------------------
# 主流程示例
# ---------------------------

# 1. 读取 ADNI 数据
csv_file = "ADNI_Tabel.csv"   # 请根据实际路径修改
X, y = load_adni_data(csv_file)

# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42
)

# 3. 初始化并训练 TabPFNClassifier
clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# 4. 预测概率并计算 ROC AUC
prediction_probabilities = clf.predict_proba(X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))

# 5. 预测标签并计算 Accuracy
predictions = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, predictions))


  model, _, config_ = load_model_criterion_config(
  from .autonotebook import tqdm as notebook_tqdm
Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client


ValueError: multi_class must be in ('ovo', 'ovr')