In [1]:
import os
import json
import torch
import numpy as np
from copy import deepcopy
from sklearn.metrics import roc_auc_score

from method.group_dro import run_group_dro
from method.group_test import run_group_test
from method.group_dro_focal import run_group_dro_focal
from models import node, tabnet, tabtrans

import numpy as np
import torch
import json
from utils.data_loader import load_data

from sklearn.metrics import (
    f1_score, recall_score, precision_score,
    roc_auc_score, brier_score_loss
)

MODEL_CLASSES = {
    'node': node.NODE,
    'tabnet': tabnet.TabNet,
    #'tabtrans' : tabtrans.TabTransformer
}

METHODS = {
    'group-dro' : run_group_dro,
    'group-test' : run_group_test,
    'group-dro-focal' : run_group_dro_focal
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def load_default_config(model_type, dataset, default_dir="experiments"):
    path = os.path.join(default_dir, f"{model_type}_{dataset}.json")
    with open(path, "r") as f:
        return json.load(f), path

def save_as_default(config, path):
    with open(path, "w") as f:
        json.dump(config, f, indent=2)

def collect_predictions(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    all_groups = []

    with torch.no_grad():
        for batch in dataloader:
            if model.__class__.__name__.lower() != "tabtransformer":
                x, y, g, *_ = batch
                x = x.to(device)
                outputs = model(x).squeeze()
            else:
                x_cat, x_num, y, g, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                outputs = model((x_cat, x_num)).squeeze()

            preds = torch.sigmoid(outputs).cpu().numpy()

            if preds.ndim == 0:
                all_preds.extend([preds])
            else:
                all_preds.extend(preds)

            all_labels.extend(y.numpy())
            all_groups.extend(g.numpy())

    return np.array(all_labels), np.array(all_preds), np.array(all_groups)

def generate_group_dro_optuna_runner(grid, model_type, dataset, load_data, MODEL_CLASSES, METHODS,
                                     default_dir="experiments", n_trials=30):
    import optuna

    config, default_path = load_default_config(model_type, dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_class = MODEL_CLASSES[model_type]

    def to_optuna_space(grid):
        space = {}
        for k, v in grid.items():
            if isinstance(v, list):
                space[k] = {"type": "categorical", "values": v}
            elif isinstance(v, dict) and "low" in v and "high" in v:
                space[k] = v
        return space

    param_space = to_optuna_space(grid)

    def optuna_objective(trial):
        trial_config = deepcopy(config)
        trial_config["method"] = 'group-dro-focal'

        for param_name, param_def in param_space.items():
            if param_def["type"] == "categorical":
                suggested = trial.suggest_categorical(param_name, param_def["values"])
            elif param_def["type"] == "int":
                suggested = trial.suggest_int(param_name, param_def["low"], param_def["high"])
            elif param_def["type"] == "float":
                suggested = trial.suggest_float(param_name, param_def["low"], param_def["high"],
                                                 log=param_def.get("log", False))
            else:
                raise ValueError(f"Unknown param type: {param_def['type']}")

            if param_name in trial_config["model_params"]:
                trial_config["model_params"][param_name] = suggested
            elif param_name in trial_config["train_params"]:
                trial_config["train_params"][param_name] = suggested

        if "group_dro_eta" in param_space:
            trial_config["train_params"]["group_dro_eta"] = trial.suggest_float(
                "group_dro_eta", param_space["group_dro_eta"]["low"], param_space["group_dro_eta"]["high"], log=True
            )
        if "gamma" in param_space:
            trial_config["train_params"]["gamma"] = trial.suggest_float(
                "gamma", param_space["gamma"]["low"], param_space["gamma"]["high"]
            )

        train_loader, valid_loader, test_loader, train_df = load_data(trial_config)
        model = model_class(**trial_config["model_params"]).to(device)
        method_fn = METHODS[trial_config["method"]]
        try:
            model = method_fn(model, train_loader, valid_loader, test_loader, train_df,
                               trial_config["train_params"], device, dataset, trial_config["method"])
        except ValueError as e:
            print(f"Optuna Trial failed during method execution with ValueError: {e}")
            raise
        except TypeError as e:
            print(f"Optuna Trial failed during method execution with TypeError: {e}")
            raise

        y_true, y_prob, group_ids = collect_predictions(model, valid_loader, device)

        try:
            overall_auc = roc_auc_score(y_true, y_prob)
        except ValueError:
            overall_auc = 0.0

        group_aucs = []
        for g in np.unique(group_ids):
            mask = group_ids == g
            if np.sum(mask) > 1:
                try:
                    group_auc = roc_auc_score(y_true[mask], y_prob[mask])
                    group_aucs.append(group_auc)
                except ValueError:
                    pass
        worst_group_auc = min(group_aucs) if group_aucs else 0.0

        trial.set_user_attr("config", deepcopy(trial_config))
        trial.set_user_attr("overall_auc", overall_auc)
        trial.set_user_attr("worst_group_auc", worst_group_auc)

        return worst_group_auc

    study = optuna.create_study(direction="maximize")
    study.optimize(optuna_objective, n_trials=n_trials)

    best_trial = study.best_trial
    best_config = best_trial.user_attrs["config"]
    save_as_default(best_config, default_path.replace(".json", "_optuna_best_groupdrofocal.json"))

    train_loader, valid_loader, test_loader, train_df = load_data(best_config)
    model_class = MODEL_CLASSES[best_config["model_type"]]
    best_model = model_class(**best_config["model_params"]).to(device)
    method_fn = METHODS[best_config["method"]]
    best_group_dro_model = method_fn(best_model, train_loader, valid_loader, test_loader, train_df,
                                      best_config["train_params"], device, dataset, best_config["method"])

    return best_group_dro_model, best_config, test_loader, valid_loader

def get_probs_and_labels_from_loader(model, loader, device):
    model.eval()
    probs, labels = [], []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in loader:
            if is_tabtrans:
                x_cat, x_num, y, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                x = (x_cat, x_num)
            else:
                x, y, *_ = batch
                x = x.to(device)

            output = model(x)
            
            prob = torch.sigmoid(output).flatten().cpu().numpy()
            
            if prob.ndim == 0:
                probs.extend([prob])
            else:
                probs.extend(prob)

            if y.ndim == 0:
                labels.extend([y.numpy()])
            else:
                labels.extend(y.numpy())

    return np.array(probs), np.array(labels)

def evaluate_group_metrics(model, test_loader, device, threshold):
    model.eval()
    model.to(device)

    all_logits = []
    all_preds = []
    all_labels = []
    all_groups = []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in test_loader:
            if is_tabtrans:
                x_cat, x_num, y, g, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                x = (x_cat, x_num)
            else:
                x, y, g, *_ = batch
                x = x.to(device)

            outputs = model(x)  # raw logits
            probs = torch.sigmoid(outputs).flatten().cpu().numpy()
            preds = (probs > threshold).astype(int)

            all_logits.append(probs)
            all_preds.append(preds)
            all_labels.append(y.numpy())
            all_groups.append(g.numpy())

    y_prob = np.concatenate(all_logits)
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)
    group_ids = np.concatenate(all_groups)
    total = len(y_true)

    # 전체 metric
    auc = roc_auc_score(y_true, y_prob)
    f1 = f1_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    brier = brier_score_loss(y_true, y_prob)

    print(f"전체 AUC:     {auc:.4f}")
    print(f"전체 F1-score: {f1:.4f}")
    print(f"전체 Recall:   {recall:.4f}")
    print(f"전체 Precision: {precision:.4f}")
    print(f"전체 Brier Score: {brier:.4f}")

    # 그룹별 요약
    print("\n그룹별 성능 요약:")
    print(f"{'Group':>6} | {'Ratio (%)':>9} | {'AUC':>6} | {'F1':>6} | {'Recall':>7} | {'Precision':>9}")
    print("-" * 60)

    for g in np.unique(group_ids):
        idx = group_ids == g
        group_size = np.sum(idx)
        ratio = group_size / total * 100

        y_true_g = y_true[idx]
        y_pred_g = y_pred[idx]
        y_prob_g = y_prob[idx]

        f1_g = f1_score(y_true_g, y_pred_g, zero_division=0)
        recall_g = recall_score(y_true_g, y_pred_g, zero_division=0)
        precision_g = precision_score(y_true_g, y_pred_g, zero_division=0)

        try:
            auc_g = roc_auc_score(y_true_g, y_prob_g)
        except ValueError:
            auc_g = float('nan')

        print(f"{g:>6} | {ratio:9.2f} | {auc_g:6.4f} | {f1_g:6.4f} | {recall_g:7.4f} | {precision_g:9.4f}")

def find_best_threshold_for_f1(y_prob, y_true, num_thresholds=100):
    thresholds = np.linspace(0.0, 1.0, num_thresholds)
    best_f1 = 0.0
    best_threshold = 0.5

    for t in thresholds:
        y_pred = (y_prob > t).astype(int)
        f1 = f1_score(y_true, y_pred)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = t

    return best_threshold, best_f1

def predict_with_threshold(model, x, threshold=0.5):
    probs = model.predict_proba(x)[:, 1]
    preds = (probs > threshold).astype(int)
    return preds

import logging
import os

def setup_logger(log_file="log.txt"):
    # 이전 핸들러 제거
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[
            logging.FileHandler(log_file, mode='w'),
            logging.StreamHandler()  # 콘솔 출력도 유지
        ]
    )

In [None]:
setup_logger("log.txt")  # 로그 파일 경로는 필요시 변경 가능


grid = {
    "lr": {"type": "float", "low": 5e-5, "high": 5e-2, "log": True},
    "hidden_dim": {"type": "categorical", "values": [32, 64, 128, 256, 512]},
    "batch_size": {"type": "categorical", "values": [64, 128, 256, 512]},
    "dropout": {"type": "float", "low": 0.0, "high": 0.6},
    "num_trees": {"type": "categorical", "values": [5, 10, 20, 50]},
    "depth": {"type": "categorical", "values": [2, 3, 4, 5, 6]},
    "group_dro_eta": {"type": "float", "low": 1e-4, "high": 1.0, "log": True},
    "gamma": {"type": "float", "low": 0.5, "high": 5.0}  # focal loss용 gamma
}

model_type = "node"
dataset = "compas"

best_group_dro_model, best_config_groupdro, test_loader, valid_loader = generate_group_dro_optuna_runner(
    grid, model_type, dataset, load_data, MODEL_CLASSES, METHODS, default_dir="experiments", n_trials=30
)

print("\n[Optuna Best Group-DRO-FOCAL 모델 설정]")
print(best_config_groupdro)

print("\n[Optuna Best Group-DRO-FOCAL 모델] 성능 요약")
v_prob, y_valid = get_probs_and_labels_from_loader(best_group_dro_model, valid_loader, device)
threshold, _ = find_best_threshold_for_f1(v_prob, y_valid)
print(f"threshold 균형점 : {threshold:.4f}")

print("[BEST OVERALL F1 MODEL]")
evaluate_group_metrics(best_group_dro_model, test_loader, device, threshold)

save_path = os.path.join("experiments", f"{dataset}_{model_type}_best_groupdro.pt")
torch.save(best_group_dro_model.state_dict(), save_path)
print(f"[모델 저장 완료] 경로: {save_path}")

  from .autonotebook import tqdm as notebook_tqdm
[I 2025-04-23 04:03:44,187] A new study created in memory with name: no-name-65dd7a49-e18f-465e-a3df-65ce80e3d8fe


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:08:26,857] Trial 0 finished with value: 0.75 and parameters: {'lr': 0.0002023193058420545, 'dropout': 0.1335399072929882, 'group_dro_eta': 0.0022980339693153425, 'gamma': 3.9357577390211262}. Best is trial 0 with value: 0.75.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:13:04,960] Trial 1 finished with value: 0.7184677848940203 and parameters: {'lr': 5.471743571196684e-05, 'dropout': 0.11030067016034806, 'group_dro_eta': 0.0013669896815474303, 'gamma': 0.8112643026771709}. Best is trial 0 with value: 0.75.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:17:43,705] Trial 2 finished with value: 0.75 and parameters: {'lr': 0.003493948350749193, 'dropout': 0.5644180721629695, 'group_dro_eta': 0.007691746179081383, 'gamma': 2.253885459688668}. Best is trial 0 with value: 0.75.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:22:13,428] Trial 3 finished with value: 0.7752039897757352 and parameters: {'lr': 0.0005360314164563458, 'dropout': 0.5914260451253025, 'group_dro_eta': 0.0005576478149393788, 'gamma': 3.7253418030064576}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:26:44,180] Trial 4 finished with value: 0.7716733975033964 and parameters: {'lr': 0.0002015361230452119, 'dropout': 0.3802126138373412, 'group_dro_eta': 0.0004276352612141488, 'gamma': 4.128186547028964}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:31:26,648] Trial 5 finished with value: 0.7142857142857143 and parameters: {'lr': 0.001107083894437871, 'dropout': 0.1762367812563273, 'group_dro_eta': 0.519150690230916, 'gamma': 2.337118180779862}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:36:16,694] Trial 6 finished with value: 0.6071428571428571 and parameters: {'lr': 0.0021611281529741663, 'dropout': 0.33122156155376153, 'group_dro_eta': 0.0003924459898764599, 'gamma': 1.280218624630705}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:41:02,273] Trial 7 finished with value: 0.6785714285714286 and parameters: {'lr': 0.0005109405221234593, 'dropout': 0.41556773400407787, 'group_dro_eta': 0.00021409938406979285, 'gamma': 3.8194885640468756}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:45:33,127] Trial 8 finished with value: 0.7142857142857143 and parameters: {'lr': 0.00013534320882105602, 'dropout': 0.4129238226747064, 'group_dro_eta': 0.008391672817975475, 'gamma': 0.9404861687539267}. Best is trial 3 with value: 0.7752039897757352.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:50:23,360] Trial 9 finished with value: 0.7766461644554009 and parameters: {'lr': 0.00046151077843629886, 'dropout': 0.18790338025921602, 'group_dro_eta': 0.030896889158341035, 'gamma': 1.752928352264792}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 04:55:23,110] Trial 10 finished with value: 0.75 and parameters: {'lr': 0.02959430418275461, 'dropout': 0.0035077203602027762, 'group_dro_eta': 0.18138627708277455, 'gamma': 1.9032705854307748}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:00:18,528] Trial 11 finished with value: 0.769567430183054 and parameters: {'lr': 0.006524770803657266, 'dropout': 0.57573683979985, 'group_dro_eta': 0.049232151610464195, 'gamma': 4.990763663690373}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:05:11,606] Trial 12 finished with value: 0.7747745376162928 and parameters: {'lr': 0.000728466871569426, 'dropout': 0.25009278062716206, 'group_dro_eta': 0.03551089963198225, 'gamma': 3.0523930497119047}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:10:05,774] Trial 13 finished with value: 0.7739522812581296 and parameters: {'lr': 0.0004919062281324476, 'dropout': 0.4835397599051927, 'group_dro_eta': 0.03328284942058121, 'gamma': 3.2799288859650773}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:15:01,155] Trial 14 finished with value: 0.6785714285714286 and parameters: {'lr': 0.006607468890422558, 'dropout': 0.2490091389300607, 'group_dro_eta': 0.002083045677526339, 'gamma': 1.550970416815159}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:19:51,295] Trial 15 finished with value: 0.712882842147426 and parameters: {'lr': 5.173998823245207e-05, 'dropout': 0.022328493954502232, 'group_dro_eta': 0.12711703187007933, 'gamma': 4.928563202073606}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:24:42,878] Trial 16 finished with value: 0.7752039897757352 and parameters: {'lr': 0.00030133745866941174, 'dropout': 0.4956743662113922, 'group_dro_eta': 0.014140321730536812, 'gamma': 2.707666757840456}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:29:30,069] Trial 17 finished with value: 0.7142857142857143 and parameters: {'lr': 0.0018025602748350967, 'dropout': 0.2261793993060413, 'group_dro_eta': 0.6959477966151384, 'gamma': 3.519768496834458}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:34:24,553] Trial 18 finished with value: 0.760686751813816 and parameters: {'lr': 0.03827826427568225, 'dropout': 0.31859173244938704, 'group_dro_eta': 0.000907505651889065, 'gamma': 2.660356335371092}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:39:17,606] Trial 19 finished with value: 0.7142857142857143 and parameters: {'lr': 0.00012180935934072672, 'dropout': 0.1145005221778721, 'group_dro_eta': 0.00010855229066248647, 'gamma': 4.350575177234067}. Best is trial 9 with value: 0.7766461644554009.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:44:06,816] Trial 20 finished with value: 0.7790318950815339 and parameters: {'lr': 0.001009607067434045, 'dropout': 0.485756501781085, 'group_dro_eta': 0.014520102298759453, 'gamma': 1.8654728024300191}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:48:44,349] Trial 21 finished with value: 0.7679363378467103 and parameters: {'lr': 0.0007749916682089353, 'dropout': 0.5070993039686049, 'group_dro_eta': 0.004998043115358829, 'gamma': 1.7501904397342378}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:53:31,232] Trial 22 finished with value: 0.5 and parameters: {'lr': 0.0003500780755183821, 'dropout': 0.5882363171849518, 'group_dro_eta': 0.01815062869615154, 'gamma': 2.101148724023526}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 05:58:19,476] Trial 23 finished with value: 0.75 and parameters: {'lr': 0.0011931282863755534, 'dropout': 0.5158308062205791, 'group_dro_eta': 0.0937403296862684, 'gamma': 0.51310947919051}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:03:08,182] Trial 24 finished with value: 0.75 and parameters: {'lr': 0.001817873783750137, 'dropout': 0.47027594037725995, 'group_dro_eta': 0.003927009638107821, 'gamma': 1.4396892210576837}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:07:55,675] Trial 25 finished with value: 0.75 and parameters: {'lr': 0.004918802256821143, 'dropout': 0.5412458276067147, 'group_dro_eta': 0.019845295579257224, 'gamma': 2.49693880282903}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:12:41,424] Trial 26 finished with value: 0.7593498755001673 and parameters: {'lr': 0.012497909214371767, 'dropout': 0.5969473104060409, 'group_dro_eta': 0.32413536016164846, 'gamma': 2.9571098547277903}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:17:28,774] Trial 27 finished with value: 0.7142857142857143 and parameters: {'lr': 0.0006816136832197541, 'dropout': 0.45544825466078803, 'group_dro_eta': 0.0716246412209951, 'gamma': 1.1545968718368682}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:22:17,643] Trial 28 finished with value: 0.6071428571428571 and parameters: {'lr': 0.0027015717150735198, 'dropout': 0.3808107597178971, 'group_dro_eta': 0.005256050098417118, 'gamma': 1.9013534075848488}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...


[I 2025-04-23 06:27:13,209] Trial 29 finished with value: 0.7705367945525645 and parameters: {'lr': 0.0003202904411882953, 'dropout': 0.17390656098556764, 'group_dro_eta': 0.0028554217530128716, 'gamma': 3.611350575475785}. Best is trial 20 with value: 0.7790318950815339.


🔥 Running GroupDRO Method on cuda...

[Optuna Best Group-DRO-FOCAL 모델 설정]
{'dataset': 'compas', 'model_type': 'node', 'method': 'group-dro-focal', 'model_params': {'input_dim': 34, 'hidden_dim': 64, 'num_trees': 7, 'depth': 2, 'num_classes': 1}, 'train_params': {'epochs': 30, 'lr': 0.001009607067434045, 'group_dro_eta': 0.014520102298759453, 'gamma': 1.8654728024300191}}

[Optuna Best Group-DRO-FOCAL 모델] 성능 요약
threshold 균형점 : 0.4545
[BEST OVERALL F1 MODEL]
전체 AUC:     0.7807
전체 F1-score: 0.7046
전체 Recall:   0.7952
전체 Precision: 0.6324
전체 Brier Score: 0.1991

그룹별 성능 요약:
 Group | Ratio (%) |    AUC |     F1 |  Recall | Precision
------------------------------------------------------------
     0 |     53.53 | 0.7624 | 0.7329 |  0.8460 |    0.6465
     1 |      0.44 | 0.9750 | 0.7143 |  0.6250 |    0.8333
     2 |     33.23 | 0.7739 | 0.6683 |  0.7263 |    0.6188
     3 |      8.00 | 0.7790 | 0.6026 |  0.6741 |    0.5449
     4 |      0.21 | 0.9000 | 0.7500 |  1.0000 |    0.6000
     5 | 

In [None]:
def find_groupwise_thresholds_from_loader(model, valid_loader, device):
    model.eval()
    model.to(device)

    all_probs = []
    all_labels = []
    all_groups = []

    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in valid_loader:
            if is_tabtrans:
                x_cat, x_num, y, g, *_ = batch
                x = (x_cat.to(device), x_num.to(device))
            else:
                x, y, g, *_ = batch
                x = x.to(device)

            logits = model(x)
            probs = torch.sigmoid(logits).flatten().cpu().numpy()
            all_probs.append(probs)
            all_labels.append(y.numpy())
            all_groups.append(g.numpy())

    # concat all data
    y_prob = np.concatenate(all_probs)
    y_true = np.concatenate(all_labels)
    group_ids = np.concatenate(all_groups)

    # calculate per-group threshold
    thresholds_by_group = {}
    for g in np.unique(group_ids):
        idx = group_ids == g
        best_t, _ = find_best_threshold_for_f1(y_prob[idx], y_true[idx])
        thresholds_by_group[g] = best_t

    return thresholds_by_group

import logging

def evaluate_group_metrics_per_threshold(model, test_loader, device, thresholds_by_group):
    model.eval()
    model.to(device)

    all_logits, all_labels, all_groups = [], [], []
    is_tabtrans = model.__class__.__name__.lower() == "tabtransformer"

    with torch.no_grad():
        for batch in test_loader:
            if is_tabtrans:
                x_cat, x_num, y, g, *_ = batch
                x_cat = x_cat.to(device)
                x_num = x_num.to(device)
                x = (x_cat, x_num)
            else:
                x, y, g, *_ = batch
                x = x.to(device)

            probs = torch.sigmoid(model(x)).flatten().cpu().numpy()
            all_logits.append(probs)
            all_labels.append(y.numpy())
            all_groups.append(g.numpy())

    y_prob = np.concatenate(all_logits)
    y_true = np.concatenate(all_labels)
    group_ids = np.concatenate(all_groups)

    logging.info("\n그룹별 성능 (개별 threshold):")
    logging.info(f"{'Group':>6} | {'Thresh.':>7} | {'AUC':>6} | {'F1':>6} | {'Recall':>7} | {'Precision':>9}")
    logging.info("-" * 60)

    for g in np.unique(group_ids):
        idx = group_ids == g
        t = thresholds_by_group[g]
        y_pred = (y_prob[idx] > t).astype(int)

        try:
            auc_g = roc_auc_score(y_true[idx], y_prob[idx])
        except ValueError:
            auc_g = float('nan')

        f1_g = f1_score(y_true[idx], y_pred, zero_division=0)
        recall_g = recall_score(y_true[idx], y_pred, zero_division=0)
        precision_g = precision_score(y_true[idx], y_pred, zero_division=0)

        logging.info(f"{g:>6} | {t:7.4f} | {auc_g:6.4f} | {f1_g:6.4f} | {recall_g:7.4f} | {precision_g:9.4f}")


In [5]:
threshold = find_groupwise_thresholds_from_loader(best_group_dro_model, valid_loader,device)

print("[Optuna Best Group-DRO-FOCAL 모델] 성능 요약")
evaluate_group_metrics_per_threshold(best_group_dro_model, test_loader, device, threshold)

[Optuna Best Group-DRO-FOCAL 모델] 성능 요약

그룹별 성능 (개별 threshold):
 Group | Thresh. |    AUC |     F1 |  Recall | Precision
------------------------------------------------------------
     0 |  0.4444 | 0.7624 | 0.7454 |  0.8834 |    0.6447
     1 |  0.3838 | 0.9750 | 0.7273 |  1.0000 |    0.5714
     2 |  0.4646 | 0.7739 | 0.6662 |  0.6974 |    0.6378
     3 |  0.5152 | 0.7790 | 0.6215 |  0.5778 |    0.6724
     4 |  0.4242 | 0.9000 | 0.7500 |  1.0000 |    0.6000
     5 |  0.4242 | 0.7888 | 0.6306 |  0.7778 |    0.5303
