In [1]:
!pip install tab-transformer-pytorch

Collecting tab-transformer-pytorch
  Downloading tab_transformer_pytorch-0.4.2-py3-none-any.whl.metadata (914 bytes)
Collecting einops>=0.8 (from tab-transformer-pytorch)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting hyper-connections>=0.1.15 (from tab-transformer-pytorch)
  Downloading hyper_connections-0.1.15-py3-none-any.whl.metadata (5.2 kB)
Downloading tab_transformer_pytorch-0.4.2-py3-none-any.whl (7.2 kB)
Downloading einops-0.8.1-py3-none-any.whl (64 kB)
Downloading hyper_connections-0.1.15-py3-none-any.whl (15 kB)
Installing collected packages: einops, hyper-connections, tab-transformer-pytorch
Successfully installed einops-0.8.1 hyper-connections-0.1.15 tab-transformer-pytorch-0.4.2
[0m

In [4]:
import torch
from utils.data_loader_trans import load_data_trans

# Loader settings: the COMPAS dataset with no training-parameter overrides.
config = {"dataset": "compas", "train_params": {}}

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# load_data_trans returns the three split loaders plus the training DataFrame.
train_loader, valid_loader, test_loader, train_df = load_data_trans(config)

def dataloader_to_numpy_trans(dataloader):
    """Flatten a TabTransformer-style loader into four NumPy arrays.

    Every batch is unpacked as ``(x_categ, x_numer, y, group, *extras)``;
    anything after the fourth element is ignored. The per-batch tensors are
    concatenated along dimension 0 in iteration order.

    Args:
        dataloader: Iterable of batches whose first four items are tensors
            accepted by ``torch.cat`` and convertible with ``.numpy()``.

    Returns:
        Tuple ``(x_categ, x_numer, y, groups)`` of NumPy arrays.
    """
    # One bucket per returned array, in return order.
    buckets = ([], [], [], [])

    for batch in dataloader:
        categ_part, numer_part, label_part, group_part, *_ = batch
        for bucket, piece in zip(buckets, (categ_part, numer_part, label_part, group_part)):
            bucket.append(piece)

    return tuple(torch.cat(bucket).numpy() for bucket in buckets)

# Materialise the whole training split as arrays (the training loop is full-batch).
x_categ, x_cont, y, groups = dataloader_to_numpy_trans(train_loader)

# Integer category codes for the embedding lookup, float32 for everything else.
x_categ_tensor = torch.tensor(x_categ, dtype=torch.int64)
x_cont_tensor = torch.tensor(x_cont, dtype=torch.float32)
# Add a trailing axis at position 1 so the targets line up with the model's single-logit output.
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(dim=1)

In [8]:
# Show the shape of the categorical training tensor before configuring the model.
print(x_categ_tensor.shape)

torch.Size([8572, 7])


In [9]:
import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

# Number of unique values in each categorical column, in column order
# (config["model_params"]["categories"]).
CATEGORIES = (3, 13, 10, 1, 4, 1, 4)
# Number of continuous columns (config["model_params"]["num_continuous"]).
# Kept in one place because both `num_continuous` and `continuous_mean_std`
# below must agree on it.
NUM_CONTINUOUS = 7

# The model is trained on x_categ_tensor / x_cont_tensor, so fail early with a
# readable message if the hard-coded layout does not match the data.
assert x_categ_tensor.shape[1] == len(CATEGORIES), (
    f"expected {len(CATEGORIES)} categorical columns, got {x_categ_tensor.shape[1]}"
)
assert x_cont_tensor.shape[1] == NUM_CONTINUOUS, (
    f"expected {NUM_CONTINUOUS} continuous columns, got {x_cont_tensor.shape[1]}"
)

model = TabTransformer(
    categories = CATEGORIES,
    num_continuous = NUM_CONTINUOUS,
    dim = 64,                                # config["model_params"]["dim"]
    dim_out = 1,                             # config["model_params"]["num_classes"]
    depth = 4,                               # config["model_params"]["depth"]
    heads = 8,                               # config["model_params"]["heads"]
    dim_head = 16,                           # config["model_params"]["dim_head"]
    attn_dropout = 0.1,
    ff_dropout = 0.1536,                     # config["model_params"]["dropout"]
    mlp_hidden_mults = (4, 2),
    mlp_act = nn.ReLU(),
    # One (mean, std) row per continuous column: mean 0, std 1 for each of them
    # (pass None instead when no normalisation is wanted).
    continuous_mean_std = torch.tensor([[0.0, 1.0]] * NUM_CONTINUOUS),
)

In [10]:
# Optimisation hyperparameters.
LEARNING_RATE = 0.005
NUM_EPOCHS = 30

# Train on the device selected earlier; previously `device` was only used at
# evaluation time, so training always ran on the CPU even with a GPU present.
# The model is moved before the optimizer is built so the optimizer tracks the
# parameters on their final device.
model.to(device)
x_categ_train = x_categ_tensor.to(device)
x_cont_train = x_cont_tensor.to(device)
y_target = y_tensor.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits; applies the sigmoid internally

model.train()
# Full-batch training: every epoch is a single gradient step over the whole
# training split.
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    output = model(x_categ_train, x_cont_train)
    loss = criterion(output, y_target)
    loss.backward()
    optimizer.step()

In [19]:
from sklearn.metrics import (
    f1_score, recall_score, precision_score,
    roc_auc_score, brier_score_loss
)
import numpy as np
import torch
import torch.nn.functional as F

def evaluate_group_metrics(model, test_loader, device, threshold=0.5):
    """Print overall and per-group binary-classification metrics for a torch model.

    The model is switched to eval mode, moved to ``device`` and run over
    ``test_loader`` without gradients. Its output is passed through a sigmoid
    and turned into hard predictions with ``prob > threshold``.

    Args:
        model: Torch module. When its class name is ``TabTransformer``
            (case-insensitive) each batch is unpacked as
            ``(x_cat, x_num, y, g, ...)`` and the model is called as
            ``model(x_cat, x_num)``; otherwise the batch is unpacked as
            ``(x, y, g, ...)`` and the model is called as ``model(x)``.
        test_loader: Iterable of batches. ``y`` and ``g`` are converted with
            ``.numpy()``, so they are expected to be CPU tensors.
        device: Device the model and its inputs are moved to.
        threshold: Probability cut-off above which a sample is predicted
            positive.

    Returns:
        None. Overall AUC, F1, recall, precision and Brier score are printed,
        followed by one row per distinct group id with that group's share of
        the samples and its AUC, F1, recall and precision.
    """
    model.eval()
    model.to(device)

    all_probs = []
    all_preds = []
    all_labels = []
    all_groups = []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in test_loader:
            if is_tabtrans:
                x_cat, x_num, y, g, *_ = batch
                x_cat, x_num = x_cat.to(device), x_num.to(device)
                output = model(x_cat, x_num)
            else:
                x, y, g, *_ = batch
                x = x.to(device)
                output = model(x)

            # squeeze(-1) removes only the trailing logit dimension. atleast_1d
            # keeps a single-row batch from a model that already returns 1-D
            # logits from collapsing to a 0-d array, which np.concatenate
            # would reject.
            prob = np.atleast_1d(torch.sigmoid(output).squeeze(-1).cpu().numpy())
            pred = (prob > threshold).astype(int)

            all_probs.append(prob)
            all_preds.append(pred)
            all_labels.append(y.numpy())
            all_groups.append(g.numpy())

    y_prob = np.concatenate(all_probs)
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)
    group_ids = np.concatenate(all_groups)

    # Overall performance. zero_division=0 matches the per-group calls below:
    # an undefined score is reported as 0 without emitting a warning.
    auc = roc_auc_score(y_true, y_prob)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    brier = brier_score_loss(y_true, y_prob)

    print(f"전체 AUC:       {auc:.4f}")
    print(f"전체 F1-score:  {f1:.4f}")
    print(f"전체 Recall:    {recall:.4f}")
    print(f"전체 Precision: {precision:.4f}")
    print(f"전체 Brier Score: {brier:.4f}")

    # Per-group summary
    print("\n그룹별 성능 요약:")
    print(f"{'Group':>6} | {'Ratio (%)':>9} | {'AUC':>6} | {'F1':>6} | {'Recall':>7} | {'Precision':>9}")
    print("-" * 60)

    for g in np.unique(group_ids):
        idx = group_ids == g
        # Share of all evaluated samples that belong to this group.
        ratio = np.mean(idx) * 100

        y_true_g = y_true[idx]
        y_pred_g = y_pred[idx]
        y_prob_g = y_prob[idx]

        f1_g = f1_score(y_true_g, y_pred_g, zero_division=0)
        recall_g = recall_score(y_true_g, y_pred_g, zero_division=0)
        precision_g = precision_score(y_true_g, y_pred_g, zero_division=0)

        # roc_auc_score raises ValueError when it cannot score the group
        # (e.g. only one class present); report NaN instead of aborting.
        try:
            auc_g = roc_auc_score(y_true_g, y_prob_g)
        except ValueError:
            auc_g = float('nan')

        print(f"{g:>6} | {ratio:9.2f} | {auc_g:6.4f} | {f1_g:6.4f} | {recall_g:7.4f} | {precision_g:9.4f}")


In [16]:
def get_probs_and_labels_from_loader(model, loader, device):
    """Run ``model`` over ``loader`` and return sigmoid probabilities with labels.

    The model is switched to eval mode, moved to ``device`` and evaluated
    without gradients.

    Args:
        model: Torch module. When its class name is ``TabTransformer``
            (case-insensitive) each batch is unpacked as
            ``(x_cat, x_num, y, ...)`` and the model is called as
            ``model(x_cat, x_num)``; otherwise the batch is unpacked as
            ``(x, y, ...)`` and the model is called as ``model(x)``.
        loader: Iterable of batches. ``y`` is converted with ``.numpy()``, so
            it is expected to be a CPU tensor.
        device: Device the model and its inputs are moved to.

    Returns:
        Tuple ``(probs, labels)`` of NumPy arrays concatenated over all
        batches in iteration order.
    """
    model.eval()
    model.to(device)
    probs, labels = [], []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in loader:
            if is_tabtrans:
                x_cat, x_num, y, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                output = model(x_cat, x_num)
            else:
                x, y, *_ = batch
                x = x.to(device)
                output = model(x)

            # Drop only the trailing logit dimension. A bare squeeze() also
            # removes the batch dimension of a single-row batch, producing a
            # 0-d array that np.concatenate rejects. atleast_1d covers models
            # that already return 1-D logits.
            prob = np.atleast_1d(torch.sigmoid(output).squeeze(-1).cpu().numpy())
            probs.append(prob)
            labels.append(y.numpy())

    return np.concatenate(probs), np.concatenate(labels)

from sklearn.metrics import f1_score
import numpy as np

def find_best_threshold_for_f1(y_prob, y_true, num_thresholds=100):
    """Scan evenly spaced thresholds on [0, 1] and keep the one with the best F1.

    Predictions are formed as ``y_prob > threshold``. A candidate replaces
    the current best only when its F1 is strictly higher, so ties resolve to
    the smallest threshold, and ``(0.5, 0.0)`` is returned when no candidate
    scores above zero.

    Args:
        y_prob: Array of positive-class probabilities.
        y_true: Ground-truth binary labels.
        num_thresholds: Number of grid points passed to ``np.linspace``.

    Returns:
        Tuple ``(best_threshold, best_f1)``.
    """
    # (threshold, f1) of the best candidate seen so far.
    winner = (0.5, 0.0)

    for candidate in np.linspace(0.0, 1.0, num_thresholds):
        score = f1_score(y_true, (y_prob > candidate).astype(int))
        if score > winner[1]:
            winner = (candidate, score)

    return winner

def predict_with_threshold(model, loader, device, threshold=0.5):
    """Return hard 0/1 predictions for every sample in ``loader``.

    The model is switched to eval mode, moved to ``device`` and evaluated
    without gradients; sigmoid probabilities are thresholded with
    ``prob > threshold``.

    Args:
        model: Torch module. When its class name is ``TabTransformer``
            (case-insensitive) each batch is unpacked as
            ``(x_cat, x_num, ...)`` and the model is called as
            ``model(x_cat, x_num)``; otherwise the batch is unpacked as
            ``(x, ...)`` and the model is called as ``model(x)``.
        loader: Iterable of batches.
        device: Device the model and its inputs are moved to.
        threshold: Probability cut-off above which a sample is predicted
            positive.

    Returns:
        1-D integer NumPy array of predictions concatenated over all batches.
    """
    model.eval()
    # The inputs are moved to `device` below, so the model has to live there
    # too; the sibling helpers in this notebook already do this.
    model.to(device)
    preds = []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in loader:
            if is_tabtrans:
                x_cat, x_num, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                output = model(x_cat, x_num)
            else:
                x, *_ = batch
                x = x.to(device)
                output = model(x)

            # Drop only the trailing logit dimension. A bare squeeze() also
            # removes the batch dimension of a single-row batch, producing a
            # 0-d array that np.concatenate rejects. atleast_1d covers models
            # that already return 1-D logits.
            prob = np.atleast_1d(torch.sigmoid(output).squeeze(-1).cpu().numpy())
            pred = (prob > threshold).astype(int)
            preds.append(pred)

    return np.concatenate(preds)

In [17]:
# Tune the decision threshold on the validation split: collect the model's
# validation probabilities and labels, then keep the threshold with the best F1
# (the F1 value itself is discarded).
v_prob, y_valid = get_probs_and_labels_from_loader(model, valid_loader, device)
threshold, _ = find_best_threshold_for_f1(v_prob, y_valid)

In [20]:
# Report overall and per-group metrics on the test split using the tuned threshold.
evaluate_group_metrics(model, test_loader, device, threshold)

전체 AUC:       0.7263
전체 F1-score:  0.6853
전체 Recall:    0.7968
전체 Precision: 0.6011
전체 Brier Score: 0.2089

그룹별 성능 요약:
 Group | Ratio (%) |    AUC |     F1 |  Recall | Precision
------------------------------------------------------------
     0 |     53.53 | 0.6943 | 0.7149 |  0.8585 |    0.6124
     1 |      0.44 | 0.8750 | 0.6667 |  0.7500 |    0.6000
     2 |     33.23 | 0.7380 | 0.6519 |  0.7158 |    0.5985
     3 |      8.00 | 0.7459 | 0.6095 |  0.7111 |    0.5333
     4 |      0.21 | 0.9667 | 0.8000 |  1.0000 |    0.6667
     5 |      4.59 | 0.6726 | 0.5263 |  0.5556 |    0.5000


In [21]:
!pip install tabpfn

Collecting tabpfn
  Downloading tabpfn-2.0.8-py3-none-any.whl.metadata (25 kB)
Collecting huggingface-hub<1,>=0.0.1 (from tabpfn)
  Using cached huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Downloading tabpfn-2.0.8-py3-none-any.whl (128 kB)
Using cached huggingface_hub-0.30.2-py3-none-any.whl (481 kB)
Installing collected packages: huggingface-hub, tabpfn
Successfully installed huggingface-hub-0.30.2 tabpfn-2.0.8
[0m

In [27]:
def dataloader_to_numpy(dataloader):
    """Collapse a ``(x, y, group, ...)`` loader into three NumPy arrays.

    Only the first three items of every batch are used; any further items
    are ignored. Batches are concatenated along dimension 0 in iteration
    order.

    Args:
        dataloader: Iterable of batches whose first three items are tensors
            accepted by ``torch.cat`` and convertible with ``.numpy()``.

    Returns:
        Tuple ``(X, y, groups)`` of NumPy arrays.
    """
    feature_chunks, label_chunks, group_chunks = [], [], []

    for features, labels, group_ids, *_ in dataloader:
        feature_chunks.append(features)
        label_chunks.append(labels)
        group_chunks.append(group_ids)

    return (
        torch.cat(feature_chunks).numpy(),
        torch.cat(label_chunks).numpy(),
        torch.cat(group_chunks).numpy(),
    )

In [28]:
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

# Project-local data loading helper
from utils.data_loader import load_data

# Loader settings: the COMPAS dataset with no training-parameter overrides.
config = {
  "dataset": "compas",
  "train_params": {
  }
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, valid_loader, test_loader, train_df = load_data(config)

# Flatten each split into (features, labels, group ids) arrays for TabPFN.
x_train, y_train, group_train = dataloader_to_numpy(train_loader)
x_valid, y_valid, group_valid = dataloader_to_numpy(valid_loader)
# Backward-compatible alias for the original misspelled variable name.
group_vlaid = group_valid
x_test, y_test, group_test = dataloader_to_numpy(test_loader)

# Initialise and fit the TabPFN model, reusing the device selected above
# (device.type is the plain "cuda" / "cpu" string).
clf = TabPFNClassifier(device=device.type)
clf.fit(x_train, y_train)

# Predict and evaluate on the test split
pred_probs = clf.predict_proba(x_test)
preds = clf.predict(x_test)

print("ROC AUC:", roc_auc_score(y_test, pred_probs[:, 1]))
print("Accuracy:", accuracy_score(y_test, preds))

ROC AUC: 0.8419099024439889
Accuracy: 0.7637645265764907


In [29]:
def get_probs_and_labels_tabpfn(model, X_valid, y_valid):
    """Return column 1 of ``model.predict_proba(X_valid)`` together with the labels.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        X_valid: Feature matrix to score.
        y_valid: Labels for ``X_valid``; returned unchanged.

    Returns:
        Tuple ``(prob, y_valid)``.
    """
    positive_column = 1
    class_probabilities = model.predict_proba(X_valid)
    return class_probabilities[:, positive_column], y_valid

def evaluate_group_metrics_tabpfn(model, X_test, y_test, group_ids, threshold=0.5):
    """Print overall and per-group metrics for a ``predict_proba``-style classifier.

    Column 1 of ``model.predict_proba(X_test)`` is used as the positive-class
    probability and thresholded with ``prob > threshold``.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        X_test: Feature matrix to score.
        y_test: Ground-truth labels, indexable with a boolean mask.
        group_ids: Group id of every row, aligned with ``y_test``.
        threshold: Probability cut-off above which a sample is predicted
            positive.

    Returns:
        None. Overall AUC, F1, recall, precision and Brier score are printed,
        followed by one row per distinct group id.
    """

    def _auc_or_nan(labels, scores):
        # roc_auc_score raises ValueError when it cannot score the subset;
        # report NaN for that group instead of aborting the report.
        try:
            return roc_auc_score(labels, scores)
        except ValueError:
            return float('nan')

    probs = model.predict_proba(X_test)[:, 1]
    preds = (probs > threshold).astype(int)

    # Overall performance (all values are computed before anything is printed).
    auc = roc_auc_score(y_test, probs)
    f1 = f1_score(y_test, preds)
    recall = recall_score(y_test, preds)
    precision = precision_score(y_test, preds)
    brier = brier_score_loss(y_test, probs)

    print(f"전체 AUC:       {auc:.4f}")
    print(f"전체 F1-score:  {f1:.4f}")
    print(f"전체 Recall:    {recall:.4f}")
    print(f"전체 Precision: {precision:.4f}")
    print(f"전체 Brier Score: {brier:.4f}")

    # Per-group performance
    print("\n그룹별 성능 요약:")
    print(f"{'Group':>6} | {'Ratio (%)':>9} | {'AUC':>6} | {'F1':>6} | {'Recall':>7} | {'Precision':>9}")
    print("-" * 60)

    total = len(y_test)
    for g in np.unique(group_ids):
        member_mask = group_ids == g
        labels_in_group = y_test[member_mask]
        preds_in_group = preds[member_mask]
        probs_in_group = probs[member_mask]
        # Share of all evaluated samples that belong to this group.
        ratio = 100 * len(labels_in_group) / total

        f1_g, recall_g, precision_g = (
            metric(labels_in_group, preds_in_group, zero_division=0)
            for metric in (f1_score, recall_score, precision_score)
        )
        auc_g = _auc_or_nan(labels_in_group, probs_in_group)

        print(f"{g:>6} | {ratio:9.2f} | {auc_g:6.4f} | {f1_g:6.4f} | {recall_g:7.4f} | {precision_g:9.4f}")


In [30]:

# Tune the decision threshold on the validation split: take TabPFN's
# validation probabilities, then keep the threshold with the best F1
# (the F1 value itself is discarded).
v_prob, y_valid = get_probs_and_labels_tabpfn(clf, x_valid, y_valid)
threshold, _ = find_best_threshold_for_f1(v_prob, y_valid)


# 4. Evaluation: overall and per-group metrics on the test split with the tuned threshold
evaluate_group_metrics_tabpfn(clf, x_test, y_test, group_test, threshold)

전체 AUC:       0.8419
전체 F1-score:  0.7458
전체 Recall:    0.7576
전체 Precision: 0.7344
전체 Brier Score: 0.1602

그룹별 성능 요약:
 Group | Ratio (%) |    AUC |     F1 |  Recall | Precision
------------------------------------------------------------
     0 |     53.53 | 0.8415 | 0.7769 |  0.8283 |    0.7315
     1 |      0.44 | 0.9500 | 0.7692 |  0.6250 |    1.0000
     2 |     33.23 | 0.8342 | 0.7079 |  0.6632 |    0.7590
     3 |      8.00 | 0.8211 | 0.6466 |  0.6370 |    0.6565
     4 |      0.21 | 1.0000 | 0.8571 |  1.0000 |    0.7500
     5 |      4.59 | 0.8052 | 0.6038 |  0.5333 |    0.6957
