In [1]:
import os
# Restrict TensorFlow to GPU 0 for this cell; the next cell re-assigns
# CUDA_VISIBLE_DEVICES from the -gpu_id option before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
# get drug features using Deepchem library
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import argparse
import random, sys
import numpy as np
import csv
from scipy import stats
import time
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn import metrics
from sklearn.metrics import roc_auc_score
from sklearn import preprocessing
import pandas as pd

# Command-line options are parsed before TensorFlow and Keras are imported, so that
# CUDA_VISIBLE_DEVICES can be set from -gpu_id first.
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` calls ``bool(string)``, which is True for every
    non-empty string (including 'False'), so boolean options need this parser.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the command-line options for the CarcGC grid search.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with the recognised options. Unrecognised arguments
        (for example those a notebook kernel adds) are ignored.
    """
    parser = argparse.ArgumentParser(description='Chemical_Genotoxicity_pre')
    parser.add_argument('-gpu_id', dest='gpu_id', type=str, default='0', help='GPU devices')
    parser.add_argument('-israndom', dest='israndom', type=_str2bool, default=False, help='randomlize X and A')
    # Hyperparameters for GCN
    parser.add_argument('-unit_list', dest='unit_list', nargs='+', type=int, default=[256, 256, 256],
                        help='unit list for GCN')
    parser.add_argument('-use_bn', dest='use_bn', type=_str2bool, default=True, help='use batch normalization for GCN')
    parser.add_argument('-use_relu', dest='use_relu', type=_str2bool, default=True, help='use relu for GCN')
    parser.add_argument('-use_GMP', dest='use_GMP', type=_str2bool, default=True, help='use GlobalMaxPooling for GCN')
    # Number of cross-validation folds (default 20).
    parser.add_argument('-n_splits', dest='n_splits', type=int, default=20, help='Number of cross-validation folds')
    parser.add_argument('-random_state', dest='random_state', type=int, default=1, help='Random state for reproducibility')

    # parse_known_args ignores arguments this parser does not define.
    args, _unknown = parser.parse_known_args(argv)
    return args

# Parse options at import time: CUDA_VISIBLE_DEVICES below must be set from
# them before TensorFlow/Keras are imported further down in this cell.
args = parse_arguments()

# Expose only the GPU(s) selected with -gpu_id (default '0') to TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

# Now import TensorFlow and Keras after setting the environment variable
import tensorflow.compat.v1 as tf
from keras import backend as K
from keras.models import Model, Sequential
from keras.models import load_model
from keras.layers import Input, InputLayer, Multiply, ZeroPadding2D
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Activation, Dropout, Flatten, Concatenate
from keras.layers import BatchNormalization
from keras.layers import Lambda
from keras import optimizers, utils
from keras.constraints import max_norm
from keras import regularizers
from keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, History, CSVLogger, ReduceLROnPlateau
from keras.utils import plot_model
from keras.optimizers import Adam, SGD
from keras.models import model_from_json
from sklearn.metrics import average_precision_score
from scipy.stats import pearsonr
from Car_model import KerasMultiSourceGCNModel
import hickle as hkl
import scipy.sparse as sp
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
import matplotlib.pyplot as plt

# Set seeds and TensorFlow configurations (Python `random`, TF graph-level seed, NumPy).
# main() re-seeds before each hyperparameter combination.
random.seed(0)
tf.set_random_seed(0)
# NOTE(review): TensorFlow is already imported at this point; confirm this
# variable is still honoured when set after import.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
np.random.seed(0)
# NOTE(review): these are also set after TensorFlow was imported and may have no
# effect; this assignment also replaces the '3' log level set near the top of the cell.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# Module-level flag read by CalculateGraphFeat to use random features / padding adjacency.
israndom = args.israndom
# Tag describing the GCN configuration (e.g. '256_256_256_bn_relu_GMP' with the
# default options); used as a suffix in the saved best-model file names.
GCN_deploy = '_'.join(map(str, args.unit_list)) + '_' + ('bn' if args.use_bn else 'no_bn') + '_' + ('relu' if args.use_relu else 'tanh') + '_' + ('GMP' if args.use_GMP else 'GAP')
model_suffix = GCN_deploy

####################################Constants Settings###########################
# Directory of per-drug hickle files; DataGenerate uses the part of each file
# name before the first '.' as the drug id (matched against pert_id).
Drug_feature_file = r'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/CarcGC_data/drug_graph_feat/'
# Every molecular graph is truncated / zero-padded to this many atoms (see CalculateGraphFeat).
Max_atoms = 100

def DataGenerate(Drug_feature_file):
    """Load every per-drug hickle file in a directory.

    Each regular file must unpack into ``(feat_mat, adj_list, degree_list)``.
    Sub-directories are skipped. The drug id is the file name up to the first
    '.'; an AssertionError is raised if two files map to the same id.

    Returns:
        dict mapping drug id -> [feat_mat, adj_list, degree_list].
    """
    feature_by_id = {}
    loaded_ids = []
    for entry in os.listdir(Drug_feature_file):
        entry_path = os.path.join(Drug_feature_file, entry)
        if not os.path.isfile(entry_path):
            continue  # only plain files hold drug features
        drug_id = entry.split('.')[0]
        loaded_ids.append(drug_id)
        feat_mat, adj_list, degree_list = hkl.load(entry_path)
        feature_by_id[drug_id] = [feat_mat, adj_list, degree_list]
    # Duplicate ids would silently overwrite entries; fail loudly instead.
    assert len(loaded_ids) == len(feature_by_id)
    return feature_by_id

def _load_labeled_split(all_drug_feature, csv_path):
    """Read one labelled CSV split and keep only drugs that have graph features.

    The CSV must have a ``pert_id`` column (matched against the keys of
    ``all_drug_feature``) and a ``Carcinogenicity`` column in which ``'+'``
    becomes label 1 and any other value label 0.

    Returns:
        (drug_feature, data_idx): ``drug_feature`` maps each kept pert_id to
        its ``all_drug_feature`` entry. ``data_idx`` lists ``(pert_id, label)``
        tuples in CSV row order, without rows whose drug has no features.
    """
    label = pd.read_csv(csv_path, index_col=None, header=0)
    drugnames = label['pert_id'].tolist()
    carc_labels = (label['Carcinogenicity'] == '+').astype(int)
    data_idx = list(zip(label['pert_id'], carc_labels))

    # Ensure that all drugnames exist in all_drug_feature
    valid_drugnames = [drug for drug in drugnames if drug in all_drug_feature]
    if len(valid_drugnames) < len(drugnames):
        missing_drugs = set(drugnames) - set(valid_drugnames)
        print(f"Warning: The following pert_id are missing in drug_graph_feat and will be skipped: {missing_drugs}")
    drug_feature = {key: all_drug_feature[key] for key in valid_drugnames}

    # Set lookup keeps the filter linear instead of scanning the list per row.
    valid_set = set(valid_drugnames)
    data_idx = [item for item in data_idx if item[0] in valid_set]
    return drug_feature, data_idx


def MetadataGenerate(all_drug_feature, train_csv_path):
    """Load the training split; see _load_labeled_split for the contract."""
    return _load_labeled_split(all_drug_feature, train_csv_path)


def ValidationGenerate(all_drug_feature, val_csv_path):
    """Load the validation split; see _load_labeled_split for the contract."""
    return _load_labeled_split(all_drug_feature, val_csv_path)


def ExtraGenerate(all_drug_feature, test_csv_path):
    """Load the test split; see _load_labeled_split for the contract."""
    return _load_labeled_split(all_drug_feature, test_csv_path)

def NormalizeAdj(adj):
    """Symmetrically normalise an adjacency matrix after adding self-loops.

    Computes ``(A + I) D`` transposed, then times ``D``, where ``D`` is the
    diagonal matrix of ``deg^-1/2`` and ``deg`` are the row sums of ``A + I``.
    For a symmetric ``A`` this equals ``D (A + I) D``.
    """
    with_loops = adj + np.eye(adj.shape[0])
    inv_sqrt_deg = np.power(np.array(with_loops.sum(1)), -0.5).flatten()
    scale = np.diag(inv_sqrt_deg)
    return with_loops.dot(scale).transpose().dot(scale)

def random_adjacency_matrix(n):
    """Return a random symmetric 0/1 adjacency matrix as a list of lists.

    Draws ``n * n`` values with ``random.randint(0, 1)`` in row-major order,
    zeroes the diagonal (no self-loops) and mirrors the upper triangle into
    the lower one so the graph is undirected.
    """
    grid = []
    for _row in range(n):
        grid.append([random.randint(0, 1) for _col in range(n)])
    for r in range(n):
        grid[r][r] = 0
        for c in range(r + 1, n):
            grid[c][r] = grid[r][c]
    return grid

def CalculateGraphFeat(feat_mat, adj_list):
    """Pad or truncate one molecular graph to Max_atoms nodes and normalise it.

    Reads the module-level globals ``Max_atoms`` and ``israndom``.

    Args:
        feat_mat: 2-D per-atom feature array (one row per atom); presumably
            (n_atoms, feat_dim) — TODO confirm against the hickle files.
        adj_list: sequence with one entry per atom, each listing neighbour
            atom indices.

    Returns:
        [feat, adj_mat]:
        - ``feat`` is (Max_atoms, feat_dim), with the real atoms in the first
          rows. It is float32 with zero padding, or, when ``israndom`` is set,
          float64 from ``np.random.rand`` with random padding rows.
        - ``adj_mat`` is (Max_atoms, Max_atoms) float32. The real-atom block
          and the padding block are each normalised by NormalizeAdj; the
          off-diagonal blocks stay zero.

    Raises:
        AssertionError: if the neighbour lists are not mutual (the adjacency
        matrix is not symmetric).
    """
    # Truncate feat_mat and adj_list if they exceed Max_atoms
    if feat_mat.shape[0] > Max_atoms:
        feat_mat = feat_mat[:Max_atoms]
        adj_list = adj_list[:Max_atoms]
    
    # Ensure the number of nodes in feat_mat matches the length of adj_list
    assert feat_mat.shape[0] == len(adj_list)
    
    # Initialize feature and adjacency matrices
    feat = np.zeros((Max_atoms, feat_mat.shape[-1]), dtype='float32')
    adj_mat = np.zeros((Max_atoms, Max_atoms), dtype='float32')
    
    # If israndom is set, replace the features with random values (the real rows
    # are written back below) and give the padding nodes a random symmetric graph.
    if israndom:
        feat = np.random.rand(Max_atoms, feat_mat.shape[-1])
        adj_mat[feat_mat.shape[0]:, feat_mat.shape[0]:] = random_adjacency_matrix(Max_atoms - feat_mat.shape[0])
    
    # Fill the first rows of the feature matrix with the real atom features
    feat[:feat_mat.shape[0], :] = feat_mat
    
    # Construct the adjacency matrix; neighbour indices >= Max_atoms (possible
    # after truncation) are dropped.
    for i in range(len(adj_list)):
        nodes = adj_list[i]
        for each in nodes:
            if each < Max_atoms:  # Ensure node index is within bounds
                adj_mat[i, int(each)] = 1
    # Ensure the adjacency matrix is symmetric
    assert np.allclose(adj_mat, adj_mat.T)
    # Normalise the real-atom block and the padding block separately.
    # NormalizeAdj adds self-loops, so an all-zero padding block
    # (israndom False) becomes an identity block.
    adj_ = adj_mat[:len(adj_list), :len(adj_list)]
    adj_2 = adj_mat[len(adj_list):, len(adj_list):]
    norm_adj_ = NormalizeAdj(adj_)
    norm_adj_2 = NormalizeAdj(adj_2)
    adj_mat[:len(adj_list), :len(adj_list)] = norm_adj_
    adj_mat[len(adj_list):, len(adj_list):] = norm_adj_2
    return [feat, adj_mat]

def FeatureExtract(data_idx, drug_feature):
    """Build padded graph inputs and a label vector for a list of samples.

    Args:
        data_idx: sequence of ``(drugname, label)`` pairs.
        drug_feature: dict mapping ``str(drugname)`` to
            ``[feat_mat, adj_list, degree_list]``.

    Returns:
        (drug_data, target): ``drug_data[k]`` is the ``[feat, adj_mat]`` pair
        from CalculateGraphFeat for sample k; ``target`` is an int16 array of
        the labels.
    """
    target = np.zeros(len(data_idx), dtype='int16')
    drug_data = []
    for pos, (drugname, clabel) in enumerate(data_idx):
        feat_mat, adj_list, _degrees = drug_feature[str(drugname)]
        # Pad to Max_atoms so all samples stack into fixed-size arrays.
        drug_data.append(CalculateGraphFeat(feat_mat, adj_list))
        target[pos] = clabel
    return drug_data, target

def _count_true_positives(y_true, y_pred):
    """Batch count of true positives after clipping and rounding to 0/1."""
    return K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))


def precision_metric(y_true, y_pred):
    """Batch-wise precision: rounded true positives / rounded predicted positives."""
    tp = _count_true_positives(y_true, y_pred)
    predicted = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return tp / (predicted + K.epsilon())


def recall_metric(y_true, y_pred):
    """Batch-wise recall: rounded true positives / actual positives."""
    tp = _count_true_positives(y_true, y_pred)
    actual = K.sum(K.round(K.clip(y_true, 0, 1)))
    return tp / (actual + K.epsilon())


def f1_score_metric(y_true, y_pred):
    """Batch-wise F1, the harmonic mean of precision_metric and recall_metric."""
    p = precision_metric(y_true, y_pred)
    r = recall_metric(y_true, y_pred)
    return 2.0 * p * r / (p + r + K.epsilon())


def average_precision_metric(y_true, y_pred):
    """sklearn average precision of the batch, evaluated eagerly through tf.py_function."""
    return tf.py_function(func=average_precision_score, inp=[y_true, y_pred], Tout=tf.double)

class MyCallback(Callback):
    """Keras callback that tracks ROC-AUC per epoch and keeps the best weights.

    After each epoch it scores the full training and validation sets with
    ROC-AUC. Whenever validation ROC-AUC improves it stores the weights and
    saves the model. It stops training after ``patience`` consecutive epochs
    without improvement. When training ends, the best weights are restored
    and saved again.

    Args:
        training_data: ``[x_train, y_train]``, with ``x_train`` in the model's
            input format (here ``[features, adjacency]``).
        validation_data: ``(x_val, y_val)`` in the same format.
        patience: consecutive non-improving epochs tolerated before stopping.
        fold: fold number, used only in output file names.
    """

    def __init__(self, training_data, validation_data, patience, fold):
        # Initialise the Keras base class so its attributes (model, params, ...) exist.
        super().__init__()
        self.x_train = training_data[0]
        self.y_train = training_data[1]
        self.x_val = validation_data[0]
        self.y_val = validation_data[1]
        self.best_weight = None
        self.patience = patience
        self.fold = fold  # To differentiate between folds

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.stopped_epoch = 0
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.best = -np.inf
        self.losses = {'batch': [], 'epoch': []}
        self.auct = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.aucl = {'batch': [], 'epoch': []}
        self.H = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses['batch'].append(logs.get('loss'))
        self.val_loss['batch'].append(logs.get('val_loss'))

    def on_train_end(self, logs=None):
        # Restore and persist the weights from the best validation epoch.
        if self.best_weight is not None:
            self.model.set_weights(self.best_weight)
            self.model.save(f'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/CarcGC_data/bestmodel/BestCarcGC_Cartoxicity_classify_fold{self.fold}_{model_suffix}.h5')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        y_pred_val = self.model.predict(self.x_val)
        roc_val = roc_auc_score(self.y_val, y_pred_val)
        y_pred_train = self.model.predict(self.x_train)
        roc_train = roc_auc_score(self.y_train, y_pred_train)

        self.losses['epoch'].append(logs.get('loss'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.auct['epoch'].append(roc_train)
        self.aucl['epoch'].append(roc_val)

        if roc_val > self.best:
            self.best = roc_val
            self.wait = 0
            self.best_weight = self.model.get_weights()
            self.model.save(f'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/CarcGC_data/bestmodel/BestCarcGC_Cartoxicity_highestAUCROC_fold{self.fold}_{model_suffix}.h5')
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def savedata(self, lr, batchsize, wd, dd, fold):
        """Write the per-epoch AUC / loss history of this run to a CSV file.

        The hyperparameters and ``fold`` are used only in the file name. An
        empty history (no completed epochs) produces a header-only CSV
        instead of a ZeroDivisionError.
        """
        rows = [
            [epoch + 1, auc_train, auc_val, train_loss, val_loss]
            for epoch, (auc_train, auc_val, train_loss, val_loss) in enumerate(zip(
                self.auct['epoch'], self.aucl['epoch'],
                self.losses['epoch'], self.val_loss['epoch']))
        ]
        df = pd.DataFrame(rows, columns=['epoch', 'auc_train', 'auc_val', 'train_loss', 'validation_loss'])
        df.to_csv(f'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/CarcGC_data/gridsearch_loss/lr{lr}_batch{batchsize}_dropout{dd}_{wd}_fold{fold}_loss.csv', index=False, header=True)

def ModelTraining(model, lr, batchsize, wd, dd, X_drug_data_train, Y_train, validation_data, patience, fold, nb_epoch=500):
    """Compile ``model`` with Adam + binary cross-entropy and fit it.

    Early stopping and best-weight restoration are handled by MyCallback
    (validation ROC-AUC with ``patience``). A ModelCheckpoint also writes the
    model to disk (every epoch with its default settings). Afterwards the
    per-epoch history is written to CSV by MyCallback.savedata; ``wd`` and
    ``dd`` are used only in that file name.

    Args:
        X_drug_data_train: sequence of ``(feat, adj)`` pairs.
        Y_train: training labels.
        validation_data: ``(pairs, labels)`` in the same pair format.

    Returns:
        The trained model (the same object that was passed in).
    """
    model.compile(
        optimizer=Adam(lr=lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False),
        loss='binary_crossentropy',
        metrics=['accuracy', precision_metric, recall_metric, f1_score_metric, average_precision_metric],
    )

    # Stack the (feat, adj) pairs into the two model inputs.
    train_inputs = [
        np.array([pair[0] for pair in X_drug_data_train]),  # nb_instance * Max_atoms * feat_dim
        np.array([pair[1] for pair in X_drug_data_train]),  # nb_instance * Max_atoms * Max_atoms
    ]
    val_pairs, Y_val = validation_data
    val_feat, val_adj = zip(*val_pairs)
    val_inputs = [np.array(val_feat), np.array(val_adj)]

    tracker = MyCallback(training_data=[list(train_inputs), Y_train],
                         validation_data=(list(val_inputs), Y_val),
                         patience=patience,
                         fold=fold)
    checkpoint = ModelCheckpoint(filepath=f'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/CarcGC_data/checkpoint_weight/Cartoxicity_weights_fold{fold}_{{epoch:04d}}.h5', verbose=0)

    model.fit(x=list(train_inputs),
              validation_data=(list(val_inputs), Y_val),
              y=Y_train,
              batch_size=batchsize,
              epochs=nb_epoch,
              callbacks=[checkpoint, tracker],
              verbose=0)
    tracker.savedata(lr, batchsize, wd, dd, fold)
    return model

def ModelEvaluate(model, X_drug_data_test, Y_test, log_file_path):
    """Score a trained model on a held-out set and log ROC-AUC and average precision.

    Args:
        model: trained model whose ``predict`` takes ``[features, adjacency]``.
        X_drug_data_test: sequence of ``(feat, adj)`` pairs.
        Y_test: binary labels aligned with ``X_drug_data_test``.
        log_file_path: text file that is (over)written with the two scores.
            Its parent directory is created if it does not exist.

    Returns:
        (auROC, auPR, Y_pred): ROC-AUC, average precision and the raw predictions.
    """
    X_drug_feat_data_test = np.array([item[0] for item in X_drug_data_test])  # nb_instance * Max_atoms * feat_dim
    X_drug_adj_data_test = np.array([item[1] for item in X_drug_data_test])   # nb_instance * Max_atoms * Max_atoms
    Y_pred = model.predict([X_drug_feat_data_test, X_drug_adj_data_test])
    auROC_all = metrics.roc_auc_score(Y_test, Y_pred)
    auPR_all = metrics.average_precision_score(Y_test, Y_pred)

    # Callers may pass a relative path whose directory does not exist yet.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file_path, 'w') as f:
        f.write(f"AUC-ROC: {auROC_all:.4f}\n")
        f.write(f"AUC-PR: {auPR_all:.4f}\n")

    return auROC_all, auPR_all, Y_pred

def _load_pooled_dataset():
    """Load graph features for every labelled drug and pool the train/val/test splits.

    The three CSV splits are concatenated (train, val, test order) because the
    grid search re-splits the pooled data with stratified K-fold CV.

    Returns:
        (X_feat, X_adj, Y): stacked padded feature matrices
        (nb_instance * Max_atoms * feat_dim), normalised adjacency matrices
        (nb_instance * Max_atoms * Max_atoms) and the label array.
    """
    all_drug_feature = DataGenerate(Drug_feature_file)

    train_csv_path = r'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/Dataset/mgl_train_data.csv'
    val_csv_path = r'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/Dataset/mgl_val_data.csv'
    test_csv_path = r'/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/Dataset/mgl_test_data.csv'
    drug_feature_train, data_idx_train = MetadataGenerate(all_drug_feature, train_csv_path)
    drug_feature_val, data_idx_val = ValidationGenerate(all_drug_feature, val_csv_path)
    drug_feature_test, data_idx_test = ExtraGenerate(all_drug_feature, test_csv_path)

    combined_drug_feature = {**drug_feature_train, **drug_feature_val, **drug_feature_test}
    combined_data_idx = data_idx_train + data_idx_val + data_idx_test

    X_drug_data, Y = FeatureExtract(combined_data_idx, combined_drug_feature)
    X_feat = np.array([item[0] for item in X_drug_data])  # nb_instance * Max_atoms * feat_dim
    X_adj = np.array([item[1] for item in X_drug_data])   # nb_instance * Max_atoms * Max_atoms
    return X_feat, X_adj, np.array(Y)


def _cross_validate(lr, batchsize, dd, wd, X_feat, X_adj, Y):
    """Run stratified K-fold CV (args.n_splits folds) for one hyperparameter setting.

    Within each outer training split, a stratified 10% is held out as the
    early-stopping validation set. The trained model is scored on the outer
    test fold.

    Returns:
        (auROC_folds, auPR_folds): per-fold test ROC-AUC and average precision.
    """
    n_splits = args.n_splits
    random_state = args.random_state
    drug_dim = X_feat.shape[-1]
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    auROC_folds = []
    auPR_folds = []
    for fold_num, (outer_train_idx, outer_test_idx) in enumerate(skf.split(X_feat, Y), start=1):
        print(f"\n===== Fold {fold_num} / {n_splits} for Hyperparameters (lr={lr}, batchsize={batchsize}, dd={dd}, wd={wd}) =====")

        # Split into outer train and outer test
        X_outer_train_feat, X_outer_test_feat = X_feat[outer_train_idx], X_feat[outer_test_idx]
        X_outer_train_adj, X_outer_test_adj = X_adj[outer_train_idx], X_adj[outer_test_idx]
        Y_outer_train, Y_outer_test = Y[outer_train_idx], Y[outer_test_idx]

        # A single stratified shuffle split carves the validation set out of outer train.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=random_state)
        inner_train_idx, inner_val_idx = next(sss.split(X_outer_train_feat, Y_outer_train))
        X_inner_train_data = list(zip(X_outer_train_feat[inner_train_idx], X_outer_train_adj[inner_train_idx]))
        X_val_data = list(zip(X_outer_train_feat[inner_val_idx], X_outer_train_adj[inner_val_idx]))
        Y_inner_train, Y_val = Y_outer_train[inner_train_idx], Y_outer_train[inner_val_idx]

        try:
            model = KerasMultiSourceGCNModel(regr=False).createMaster(
                drug_dim=drug_dim,
                units_list=args.unit_list,
                wd=wd,
                dd=dd,
                use_relu=args.use_relu,
                use_bn=args.use_bn,
                use_GMP=args.use_GMP
            )
        except TypeError as e:
            print(f"Error creating model in fold {fold_num}: {e}")
            print("Please check the 'createMaster' method in 'Car_model.py' and ensure that the arguments match.")
            sys.exit(1)

        model = ModelTraining(
            model=model,
            lr=lr,
            batchsize=batchsize,
            wd=wd,
            dd=dd,
            X_drug_data_train=X_inner_train_data,
            Y_train=Y_inner_train,
            validation_data=(X_val_data, Y_val),
            patience=10,
            fold=fold_num,
            nb_epoch=500
        )

        # Evaluate on the outer test fold
        X_test_data = list(zip(X_outer_test_feat, X_outer_test_adj))
        log_file_path = f'CarcGC_data/CarcGC_test_fold{fold_num}_lr{lr}_bs{batchsize}_dd{dd}_wd{wd}.log'
        auROC_test, auPR_test, _Y_pred_test = ModelEvaluate(
            model=model,
            X_drug_data_test=X_test_data,
            Y_test=Y_outer_test,
            log_file_path=log_file_path
        )
        auROC_folds.append(auROC_test)
        auPR_folds.append(auPR_test)

    return auROC_folds, auPR_folds


def main():
    """Grid-search lr / batch size / dropout / weight decay with K-fold CV.

    Writes one row per combination (mean and std of fold ROC-AUC and average
    precision) to gridsearch_results_mgl.csv and prints the resulting table.
    """
    import itertools

    # Define hyperparameter grid
    lr_list = [0.001, 0.0001]
    batchsize_list = [8]
    dd_list = [0.2]
    wd_list = [1e-06, 1e-05, 1e-04]
    # Same iteration order as nested loops over lr, batch size, dropout, weight decay.
    combinations = list(itertools.product(lr_list, batchsize_list, dd_list, wd_list))
    total_combinations = len(combinations)

    # Loading and featurising the data does not depend on the hyperparameters,
    # so do it once. Seeds are set first and the RNG states are captured
    # afterwards. Restoring them for every combination yields exactly the RNG
    # state the model would have seen if the data were reloaded inside the loop
    # (this matters when -israndom consumes random numbers during featurisation).
    random.seed(1)
    tf.set_random_seed(2)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    np.random.seed(2)
    X_feat, X_adj, Y = _load_pooled_dataset()
    py_rng_state = random.getstate()
    np_rng_state = np.random.get_state()

    grid_results = []
    for current_combination, (lr, batchsize, dd, wd) in enumerate(combinations, start=1):
        print(f"\n===== Grid Search Combination {current_combination}/{total_combinations} =====")
        print(f"Learning Rate: {lr}, Batch Size: {batchsize}, Dropout: {dd}, Weight Decay: {wd}")

        # Re-seed for reproducibility: every combination starts from the same RNG state.
        random.setstate(py_rng_state)
        np.random.set_state(np_rng_state)
        tf.set_random_seed(2)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'

        auROC_folds, auPR_folds = _cross_validate(lr, batchsize, dd, wd, X_feat, X_adj, Y)

        grid_results.append([
            lr, batchsize, dd, wd,
            np.mean(auROC_folds), np.std(auROC_folds),
            np.mean(auPR_folds), np.std(auPR_folds),
        ])

    grid_df = pd.DataFrame(grid_results, columns=[
        'Learning_Rate', 'Batch_Size', 'Dropout', 'Weight_Decay',
        'Mean_AUC_ROC', 'Std_AUC_ROC', 'Mean_AUC_PR', 'Std_AUC_PR'
    ])

    # Save grid search results to CSV
    grid_df.to_csv('/data/home/dbswn0814/2025JCM/model/CarcGC/mgl/gridsearch_results_mgl.csv', index=False)
    print("\n===== Grid Search Completed =====")
    print(grid_df)

# Entry point: run the full hyperparameter grid search.
if __name__ == '__main__':
    main()



===== Grid Search Combination 1/6 =====
Learning Rate: 0.001, Batch Size: 8, Dropout: 0.2, Weight Decay: 1e-06

===== Fold 1 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 2 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 3 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 4 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 5 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 6 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 7 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 8 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 9 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 10 / 20 for Hyperparameters (lr=0.001, batchsize=8, dd=0.2, wd=1e-06) =====

===== Fold 11 / 


===== Fold 8 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 9 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 10 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 11 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 12 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 13 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 14 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 15 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 16 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 17 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 18 / 20 for Hyperparameters (lr=0.0001, batchsize=8, dd=0.2, wd=1e-05) =====

===== Fold 19 / 20 for