In [10]:
# Federated Learning MAE + ANN (TFF 없는 시뮬레이션 버전)
import openpyxl
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam, RMSprop
import time
import copy

# =================== 데이터 로딩 및 전처리 ===================
referrals = pd.read_excel(r"C:\Users\이희창\Downloads\opd.xlsx", engine='openpyxl')
df = referrals
print(referrals['transplanted'].value_counts())
print(referrals['transplanted'].unique())

# ⚠️ 중요: HospitalID를 유지 (Federated Learning을 위해)
outcome_columns = [col for col in df.columns if col.startswith('outcome_')]
columns_to_drop = ['PatientID'] + outcome_columns  # HospitalID 제거하지 않음!

print("제거할 변수들:", columns_to_drop)
df = df.drop(columns=columns_to_drop, axis=1)
print(f"변수 제거 후 원본 데이터프레임 크기: {df.shape}")

# =================== 범주형 변수 재분류 ===================
df_new = df.copy()

def total_values(df, col, list_features, label):
    for i in list_features:
        df[col].mask(df[col] == i, label, inplace=True)

# 기존 범주형 변수 재분류 코드 (동일)
infections = ['Sepsis', 'Septic Shock', 'Infectious Disease - Bacterial', 'Infectious Disease - Viral',
              'Infectious Disease - Other, specify', 'Pneumonia', 'HIV', 'Hepatitis', 'AIDS/HIV']
total_values(df_new, 'Cause_of_Death_OPO', infections, 'Infectious Disease')

cardio = ['CHF', 'CAR - CHF', 'AAA or thoracic AA', 'AAA - abdominal aortic aneurysm', 
          'CAR - cardiomegaly/cardiomyopathy/cardiovascular', 'Pulmonary embolism', 'PE--Pulmonary Embolism ',
          'Myocardial infarction', 'CAR - MI', 'CAR - probable MI', 'CAR - arrhythmia',
          'Arrhythmia', 'Cardiac - Other, specify']
total_values(df_new, 'Cause_of_Death_OPO', cardio, 'Circulatory Disease')

resp = ['Anoxia', 'COPD', 'RES - COPD', 'Respiratory - Other', 'Respiratory - Other, specify',
        'RES - other', 'RES - pneumonia', 'RES - lung disease', 'RES - asthma', 'RES - aspiration']
total_values(df_new, 'Cause_of_Death_OPO', resp, 'Respiratory Disease')

newborn = ['Fetal Demise', 'Prematurity', 'Sudden infant death syndrome', 'PED - abuse/shaken baby']
total_values(df_new, 'Cause_of_Death_OPO', newborn, 'Newborn Disease')

cancers = ['Leukemia / Lymphoma', 'Cancer', 'Cancer - Leukemia/Lymphoma', 'Cancer/Current or within five years']
total_values(df_new, 'Cause_of_Death_OPO', cancers, 'Cancer')

neuro = ['CVA/Stroke - Cerebro Accident', 'ICB / ICH', 'Cerebrovascular / Stroke', 'CNS Tumor', 'SAH',
         'Meningitis', 'Seizure/Seizure Disorder', 'Aneurysm']
total_values(df_new, 'Cause_of_Death_OPO', neuro, 'Nervous Disease')

digestive = ['GI - necrotic bowel', 'GI - bleed', 'GI - bowel perforation', 'GI - bowel obstruction']
total_values(df_new, 'Cause_of_Death_OPO', digestive, 'Digestive Disease')

liver = ['Liver Disease/Failure', 'ESLD']
total_values(df_new, 'Cause_of_Death_OPO', liver, 'Liver Disease')

kidney = ['ESRD', 'Kidney/Renal  Disease']
total_values(df_new, 'Cause_of_Death_OPO', kidney, 'Kidney Disease')

eye = ['PED - other', 'PED - premature']
total_values(df_new, 'Cause_of_Death_OPO', eye, 'Eye Disease')

injury = ['GSW', 'TR - GSW', 'Drowning', 'Head Trauma', 'Trauma', 'Overdose',
          'Drug Overdose/Probable Drug Abuse', 'An - other', 'An - asphyixiation',
          'An - smoke inhalation', 'An -  hanging', 'An - drowning', 'TR - MVA', 'TR - other',
          'TR - CHI - Closed Head Injury', 'TR - burns', 'TR - stabbing', 'TR - electrocution',
          'Poisoning', 'Intracranial Hemorrhage', 'Exsanguination']
total_values(df_new, 'Cause_of_Death_OPO', injury, 'Injury_External Causes')

multi = ['Multi-system failure', 'MultiSystem Failure']
total_values(df_new, 'Cause_of_Death_OPO', multi, 'Multi-system failure')

other = ['Other', 'Other, specify']
total_values(df_new, 'Cause_of_Death_OPO', other, 'Other')

# UNOS 분류 (기존과 동일)
infections = ['Sepsis', 'Infectious Disease - Bacterial', 'Infectious Disease - Viral',
              'Infectious Disease - Other, specify', 'Pneumonia', 'HIV', 'Hepatitis']
total_values(df_new, 'Cause_of_Death_UNOS', infections, 'Infectious Disease')

cardio = ['CHF', 'AAA or thoracic AA', 'Pulmonary embolism', 'Myocardial infarction', 'Arrhythmia', 'Cardiac - Other, specify']
total_values(df_new, 'Cause_of_Death_UNOS', cardio, 'Circulatory Disease')

resp = ['Anoxia', 'COPD', 'Respiratory - Other', 'Respiratory - Other, specify']
total_values(df_new, 'Cause_of_Death_UNOS', resp, 'Respiratory Disease')

newborn = ['Fetal Demise', 'Prematurity', 'Sudden infant death syndrome']
total_values(df_new, 'Cause_of_Death_UNOS', newborn, 'Newborn Disease')

cancers = ['Leukemia / Lymphoma', 'Cancer']
total_values(df_new, 'Cause_of_Death_UNOS', cancers, 'Cancer')

neuro = ['CVA/Stroke', 'ICB / ICH', 'Cerebrovascular / Stroke', 'CNS Tumor', 'SAH']
total_values(df_new, 'Cause_of_Death_UNOS', neuro, 'Nervous Disease')

injury = ['GSW', 'Drowning', 'Head Trauma', 'Trauma', 'Overdose', 'Exsanguination']
total_values(df_new, 'Cause_of_Death_UNOS', injury, 'Injury_External Causes')

other = ['Other', 'Other, specify']
total_values(df_new, 'Cause_of_Death_UNOS', other, 'Other')

df_new['Cause_of_Death_UNOS'].replace('ESRD', 'Kidney Disease', inplace=True)
df_new['Cause_of_Death_UNOS'].replace('ESLD', 'Liver Disease', inplace=True)

# 죽음의 메커니즘 재분류
natural_causes = ['Natural Causes', 'Death from Natural Causes']
total_values(df_new, 'Mechanism_of_Death', natural_causes, 'Natural Causes')

injury_external = ['Blunt Injury', 'Drug Intoxication', 'Gun Shot Wound', 'Asphyxiation', 'Drug / Intoxication',
                   'Drowning', 'Gunshot Wound', 'Stab', 'Electrical']
total_values(df_new, 'Mechanism_of_Death', injury_external, 'Injury_External Causes')

nervous_diseases = ['ICH/Stroke', 'Intracranial Hemmorrhage / Stroke', 'Seizure']
total_values(df_new, 'Mechanism_of_Death', nervous_diseases, 'Nervous Disease')

nofa = ['None of the Above', 'None of the above']
total_values(df_new, 'Mechanism_of_Death', nofa, 'Other')

# 죽음의 환경 재분류
natural_causes = ['Natural Causes', 'Death from Natural Causes']
total_values(df_new, 'Circumstances_of_Death', natural_causes, 'Natural Causes')

mva = ['Motor Vehicle Accident', 'MVA']
total_values(df_new, 'Circumstances_of_Death', mva, 'Motor Accident')

non_mva = ['Non-Motor Vehicle Accident', 'Accident, Non-MVA']
total_values(df_new, 'Circumstances_of_Death', non_mva, 'Non-motor Accident')

suicide = ['Suicide', 'Alleged Suicide']
total_values(df_new, 'Circumstances_of_Death', suicide, 'Suicide')

homicide = ['Homicide', 'Alleged Homicide']
total_values(df_new, 'Circumstances_of_Death', homicide, 'Homicide')

child_abuse = ['Child Abuse', 'Alleged Child Abuse']
total_values(df_new, 'Circumstances_of_Death', child_abuse, 'Homicide')

other = ['Other', 'None of the Above']
total_values(df_new, 'Circumstances_of_Death', other, 'Other')

# =================== 시간 관련 특성 엔지니어링 ===================
def get_duration_between_dates(then, now, interval="default"):
    duration = now - then
    duration_in_s = duration.total_seconds()
    
    def years():
        return divmod(duration_in_s, 31536000)
    def days(seconds=None):
        return divmod(seconds if seconds != None else duration_in_s, 86400)
    def hours(seconds=None):
        return divmod(seconds if seconds != None else duration_in_s, 3600)
    def minutes(seconds=None):
        return divmod(seconds if seconds != None else duration_in_s, 60)
    def seconds(seconds=None):
        if seconds != None:
            return divmod(seconds, 1)
        return duration_in_s
    
    return {
        'years': float(years()[0]),
        'days': float(days()[0]),
        'hours': float(hours()[0]),
        'minutes': float(minutes()[0]),
        'seconds': float(seconds()),
    }

def create_time_column(df, col1, col2, new_col_name):
    def convert_datetime(str1, str2):
        return [pd.to_datetime(str1), pd.to_datetime(str2)]
    
    time_category = []
    for row in zip(df[col1], df[col2]):
        if pd.isnull(row[0]) == False and pd.isnull(row[1]) == False:
            date_row = convert_datetime(row[0], row[1])
            time_elapsed = abs(get_duration_between_dates(date_row[0], date_row[1])['hours'])
            
            if time_elapsed <= 24:
                time_category.append('Within 24 hours')
            else:
                time_category.append('Over 24 hours')
        else:
            time_category.append('Milestone not reached')
    
    df[new_col_name] = time_category
    return df

# 시간 변수들 생성
time_vars = ['time_asystole', 'time_brain_death', 'time_referred', 'time_approached', 'time_authorized', 'time_procured']

df_new = create_time_column(df_new, time_vars[0], time_vars[2], 'time_asystole_to_referred')
df_new = create_time_column(df_new, time_vars[1], time_vars[2], 'time_brain_death_to_referred')
df_new = create_time_column(df_new, time_vars[2], time_vars[3], 'time_referred_to_approached')
df_new = create_time_column(df_new, time_vars[3], time_vars[4], 'time_approached_to_authorized')
df_new = create_time_column(df_new, time_vars[4], time_vars[5], 'time_authorized_to_procured')

# =================== 결측치 처리 및 변수 제거 ===================
def get_missing_data(data):
    missing_data_prop = {}
    for x, y in enumerate(list(data.isnull().sum())):
        missing_data_prop[data.columns[x]] = (float(y / data.shape[0]) * 100)
    missing_data = pd.DataFrame(missing_data_prop.items(), columns=['column', 'percent_missing'])
    return missing_data

missing_data = get_missing_data(df_new)
cols_large_missing = list(missing_data[missing_data['percent_missing'] > 50]['column'])

# 중요한 변수들 유지
if 'Cause_of_Death_OPO' in cols_large_missing:
    cols_large_missing.remove('Cause_of_Death_OPO')
if 'time_brain_death' in cols_large_missing:
    cols_large_missing.remove('time_brain_death')
if 'time_approached' in cols_large_missing:
    cols_large_missing.remove('time_approached')
if 'time_authorized' in cols_large_missing:
    cols_large_missing.remove('time_authorized')

df_new = df_new.drop(cols_large_missing, axis=1)

# 공선성 변수 제거
cols_collinear = ['brain_death', 'time_referred', 'time_asystole', 'authorized', 'procured',
                  'time_approached_to_authorized', 'time_authorized_to_procured']
existing_cols = [col for col in cols_collinear if col in df_new.columns]
df_new = df_new.drop(existing_cols, axis=1)

# =================== 전처리 파이프라인 ===================
def categorize_columns(df):
    categorical_cols = []
    numerical_cols = []
    datetime_cols = []
    binary_cols = []
    
    for col in df.columns:
        if col in ['transplanted', 'HospitalID', 'OPO_Group']:  # OPO_Group도 제외
            continue
        
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            datetime_cols.append(col)
        elif pd.api.types.is_numeric_dtype(df[col]):
            if df[col].nunique() <= 2:
                binary_cols.append(col)
            else:
                numerical_cols.append(col)
        elif pd.api.types.is_object_dtype(df[col]) or pd.api.types.is_categorical_dtype(df[col]):
            categorical_cols.append(col)
    
    return categorical_cols, numerical_cols, datetime_cols, binary_cols

df_processed = df_new.copy()

# 불리언 변수를 정수형으로 변환
for col in df_processed.columns:
    if df_processed[col].dtype == 'bool':
        df_processed[col] = df_processed[col].astype(int)

# 시간 변수 처리
for col in df_processed.columns:
    if pd.api.types.is_datetime64_any_dtype(df_processed[col]):
        if not df_processed[col].isna().all():
            reference_date = df_processed[col].min()
            df_processed[col] = (df_processed[col] - reference_date).dt.total_seconds() / (24 * 3600)
            df_processed[col] = df_processed[col].fillna(0)

# =================== Federated Learning 시뮬레이션 ===================
class FederatedMAEANN:
    """TFF 없이 Federated Learning 시뮬레이션"""
    
    def __init__(self):
        self.global_mae_weights = None
        self.global_ann_weights = None
        self.clients_data = []
        self.preprocessor = None
        
    def prepare_federated_data(self, df):
        """병원별로 데이터를 분할하고 전처리"""
        
        # HospitalID에서 OPO 번호 추출하여 그룹핑
        def extract_opo_group(hospital_id):
            """HospitalID에서 OPO 번호만 추출 (예: OPO1, OPO2, ...)"""
            if pd.isna(hospital_id):
                return 'Unknown'
            hospital_str = str(hospital_id).upper()
            # OPO 뒤의 숫자 찾기
            import re
            match = re.search(r'OPO\s*(\d+)', hospital_str)
            if match:
                return f"OPO{match.group(1)}"
            else:
                # OPO 패턴이 없으면 첫 번째 문자 몇 개 사용
                return hospital_str[:5] if len(hospital_str) >= 5 else hospital_str
        
        # OPO 그룹별로 데이터 재분할
        df['OPO_Group'] = df['HospitalID'].apply(extract_opo_group)
        
        hospital_data = {}
        unique_opo_groups = df['OPO_Group'].unique()
        
        print(f"총 {len(unique_opo_groups)}개 OPO 그룹 발견")
        for opo_group in unique_opo_groups:
            opo_df = df[df['OPO_Group'] == opo_group].copy()
            sample_count = len(opo_df)
            print(f"{opo_group}: {sample_count} 샘플")
            
            # 최소 샘플 수 체크 (train/test split을 위해 최소 10개 필요)
            if sample_count >= 10:
                hospital_data[opo_group] = opo_df
            else:
                print(f"  ⚠️ {opo_group}는 샘플 수가 너무 적어서 제외됩니다.")
        
        print(f"\n실제 사용할 OPO 그룹: {len(hospital_data)}개")
        
        
        # 전체 데이터로 전처리 파이프라인 학습 (글로벌 통계량 사용)
        X_df = df.drop(['transplanted', 'HospitalID', 'OPO_Group'], axis=1)
        y = df['transplanted']
        
        categorical_cols, numerical_cols, datetime_cols, binary_cols = categorize_columns(df)
        numerical_cols.extend(datetime_cols)
        
        # 전처리 파이프라인 구성
        transformers = []
        if numerical_cols:
            transformers.append(('num', Pipeline([
                ('imputer', SimpleImputer(strategy='median')),
                ('scaler', StandardScaler())
            ]), numerical_cols))
        
        if categorical_cols:
            transformers.append(('cat', Pipeline([
                ('imputer', SimpleImputer(strategy='most_frequent')),
                ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False))
            ]), categorical_cols))
        
        if binary_cols:
            transformers.append(('bin', Pipeline([
                ('imputer', SimpleImputer(strategy='most_frequent')),
                ('scaler', StandardScaler())
            ]), binary_cols))
        
        self.preprocessor = ColumnTransformer(transformers=transformers, remainder='passthrough')
        
        # 전체 데이터로 전처리 파이프라인 학습
        self.preprocessor.fit(X_df)
        
        # 각 OPO 그룹별로 전처리 적용
        self.clients_data = []
        for opo_group, opo_df in hospital_data.items():
            X_opo = opo_df.drop(['transplanted', 'HospitalID', 'OPO_Group'], axis=1)
            y_opo = opo_df['transplanted'].values
            
            # 동일한 전처리 파이프라인 적용
            X_opo_processed = self.preprocessor.transform(X_opo)
            
            # sparse matrix를 dense로 변환
            if hasattr(X_opo_processed, 'toarray'):
                X_opo_processed = X_opo_processed.toarray()
            
            # train/test 분할 (이제 충분한 샘플이 있음)
            try:
                # stratify를 시도하되, 실패하면 일반 분할 사용
                X_train, X_test, y_train, y_test = train_test_split(
                    X_opo_processed, y_opo, test_size=0.2, random_state=42, stratify=y_opo
                )
            except ValueError:
                # stratify 실패 시 (클래스 불균형이 심한 경우)
                print(f"  ⚠️ {opo_group}: stratify 불가, 일반 분할 사용")
                X_train, X_test, y_train, y_test = train_test_split(
                    X_opo_processed, y_opo, test_size=0.2, random_state=42
                )
            
            self.clients_data.append({
                'hospital_id': opo_group,  # OPO 그룹명 사용
                'x_train': X_train.astype(np.float32),
                'y_train': y_train.astype(np.float32),
                'x_test': X_test.astype(np.float32),
                'y_test': y_test.astype(np.float32)
            })
            
            print(f"✅ {opo_group}: Train={len(X_train)}, Test={len(X_test)}")
        
        return self.clients_data
    
    def create_mae_model(self, input_shape):
        """MAE 모델 생성"""
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(input_shape,)),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu', name='latent'),  # latent space
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(input_shape, activation=None)  # reconstruction
        ])
        
        model.compile(optimizer=RMSprop(learning_rate=0.001), loss='mse', metrics=['mae'])
        return model
    
    def create_encoder_model(self, input_shape):
        """Encoder만 분리한 모델"""
        encoder = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(input_shape,)),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu')  # latent space
        ])
        return encoder
    
    def create_ann_model(self, input_shape, latent_dim=64):
        """ANN 모델 생성 (원본 + latent 특성)"""
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(input_shape + latent_dim,)),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dropout(0.15),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dropout(0.07),
            tf.keras.layers.Dense(16, activation='relu'),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])
        
        model.compile(optimizer=Adam(learning_rate=0.0007), 
                     loss='binary_crossentropy', 
                     metrics=['accuracy'])
        return model
    
    def federated_averaging(self, models):
        """Federated Averaging - 모델 가중치들의 평균 계산"""
        if not models:
            return None
            
        # 첫 번째 모델의 가중치를 기준으로 시작
        averaged_weights = [np.zeros_like(w) for w in models[0].get_weights()]
        
        # 모든 모델의 가중치를 더함
        for model in models:
            model_weights = model.get_weights()
            for i, w in enumerate(model_weights):
                averaged_weights[i] += w
        
        # 평균 계산
        num_models = len(models)
        averaged_weights = [w / num_models for w in averaged_weights]
        
        return averaged_weights
    
    def apply_masking(self, X, mask_ratio=0.5):
        """마스킹 적용"""
        mask = np.random.uniform(0, 1, X.shape) > mask_ratio
        return np.where(mask, X, 0)
    
    def train_federated_mae(self, rounds=50, mask_ratio=0.5):
        """Federated MAE 학습"""
        print("="*50)
        print("Federated MAE 학습 시작")
        print("="*50)
        
        input_shape = self.clients_data[0]['x_train'].shape[1]
        
        # 각 클라이언트별 MAE 모델 초기화
        client_mae_models = []
        for _ in self.clients_data:
            model = self.create_mae_model(input_shape)
            client_mae_models.append(model)
        
        # Federated 학습
        for round_num in range(rounds):
            print(f"\nRound {round_num + 1}/{rounds}")
            
            # 각 클라이언트에서 로컬 학습
            for i, (client_data, model) in enumerate(zip(self.clients_data, client_mae_models)):
                # 마스킹 적용
                X_masked = self.apply_masking(client_data['x_train'], mask_ratio)
                
                # 로컬 학습 (1 epoch)
                model.fit(X_masked, client_data['x_train'], 
                         epochs=1, batch_size=32, verbose=0)
            
            # Federated Averaging
            self.global_mae_weights = self.federated_averaging(client_mae_models)
            
            # 글로벌 가중치를 모든 클라이언트에 배포
            for model in client_mae_models:
                model.set_weights(self.global_mae_weights)
            
            # 성능 평가 (10 라운드마다)
            if (round_num + 1) % 10 == 0:
                total_loss = 0
                for i, (client_data, model) in enumerate(zip(self.clients_data, client_mae_models)):
                    loss = model.evaluate(client_data['x_train'], client_data['x_train'], verbose=0)[0]
                    total_loss += loss
                avg_loss = total_loss / len(self.clients_data)
                print(f"  평균 MAE Loss: {avg_loss:.4f}")
        
        print("Federated MAE 학습 완료!")
        return client_mae_models[0]  # 글로벌 모델 반환
    
    def train_federated_ann(self, mae_model, rounds=100):
        """Federated ANN 학습"""
        print("="*50)
        print("Federated ANN 학습 시작")
        print("="*50)
        
        input_shape = self.clients_data[0]['x_train'].shape[1]
        
        # MAE encoder 생성
        encoder = self.create_encoder_model(input_shape)
        # MAE에서 학습된 encoder 가중치 복사
        encoder_weights = mae_model.get_weights()[:8]  # encoder 부분만 (Dense 4개 * 2개씩)
        encoder.set_weights(encoder_weights)
        
        # 각 클라이언트의 데이터에 latent 특성 추가
        for client_data in self.clients_data:
            # Latent 특성 생성
            latent_train = encoder.predict(client_data['x_train'], verbose=0)
            latent_test = encoder.predict(client_data['x_test'], verbose=0)
            
            # 원본 + latent 특성 결합
            client_data['x_train_combined'] = np.concatenate([client_data['x_train'], latent_train], axis=1)
            client_data['x_test_combined'] = np.concatenate([client_data['x_test'], latent_test], axis=1)
        
        # 각 클라이언트별 ANN 모델 초기화
        combined_input_shape = self.clients_data[0]['x_train_combined'].shape[1]
        client_ann_models = []
        for _ in self.clients_data:
            model = self.create_ann_model(input_shape, latent_dim=64)
            client_ann_models.append(model)
        
        # Federated 학습
        for round_num in range(rounds):
            if (round_num + 1) % 20 == 0:
                print(f"\nRound {round_num + 1}/{rounds}")
            
            # 각 클라이언트에서 로컬 학습
            for i, (client_data, model) in enumerate(zip(self.clients_data, client_ann_models)):
                # 로컬 학습 (1 epoch)
                model.fit(client_data['x_train_combined'], client_data['y_train'], 
                         epochs=1, batch_size=32, verbose=0)
            
            # Federated Averaging
            self.global_ann_weights = self.federated_averaging(client_ann_models)
            
            # 글로벌 가중치를 모든 클라이언트에 배포
            for model in client_ann_models:
                model.set_weights(self.global_ann_weights)
            
            # 성능 평가 (20 라운드마다)
            if (round_num + 1) % 20 == 0:
                total_loss = 0
                total_acc = 0
                for i, (client_data, model) in enumerate(zip(self.clients_data, client_ann_models)):
                    metrics = model.evaluate(client_data['x_train_combined'], client_data['y_train'], verbose=0)
                    total_loss += metrics[0]
                    total_acc += metrics[1]
                avg_loss = total_loss / len(self.clients_data)
                avg_acc = total_acc / len(self.clients_data)
                print(f"  평균 Loss: {avg_loss:.4f}, 평균 Accuracy: {avg_acc:.4f}")
        
        print("Federated ANN 학습 완료!")
        return client_ann_models[0], encoder  # 글로벌 모델 반환
    
    def evaluate_model(self, ann_model, encoder):
        """모델 성능 평가"""
        print("="*50)
        print("Federated Learning 성능 평가")
        print("="*50)
        
        all_y_true = []
        all_y_pred = []
        all_y_prob = []
        
        for client_data in self.clients_data:
            # 예측
            y_prob = ann_model.predict(client_data['x_test_combined'], verbose=0)
            y_pred = (y_prob >= 0.5).astype(int)
            
            all_y_true.extend(client_data['y_test'])
            all_y_pred.extend(y_pred.flatten())
            all_y_prob.extend(y_prob.flatten())
        
        # 성능 지표 계산
        test_accuracy = accuracy_score(all_y_true, all_y_pred)
        test_precision = precision_score(all_y_true, all_y_pred)
        test_recall = recall_score(all_y_true, all_y_pred)
        test_f1 = f1_score(all_y_true, all_y_pred)
        test_auc = roc_auc_score(all_y_true, all_y_prob)
        
        print(f"\n📊 Federated Learning 최종 성능:")
        print(f"정확도: {test_accuracy:.4f}")
        print(f"정밀도: {test_precision:.4f}")
        print(f"민감도(재현율): {test_recall:.4f}")
        print(f"F1 점수: {test_f1:.4f}")
        print(f"AUC: {test_auc:.4f}")
        
        # 병원별 성능 분석
        print(f"\n🏥 OPO 그룹별 성능 분석:")
        for i, client_data in enumerate(self.clients_data):
            opo_group = client_data['hospital_id']  # 이제 OPO 그룹명
            y_prob = ann_model.predict(client_data['x_test_combined'], verbose=0)
            y_pred = (y_prob >= 0.5).astype(int)
            
            if len(np.unique(client_data['y_test'])) > 1:  # 클래스가 2개 이상일 때만 계산
                accuracy = accuracy_score(client_data['y_test'], y_pred)
                precision = precision_score(client_data['y_test'], y_pred, zero_division=0)
                recall = recall_score(client_data['y_test'], y_pred, zero_division=0)
                f1 = f1_score(client_data['y_test'], y_pred, zero_division=0)
                
                print(f"{opo_group}: Accuracy={accuracy:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
            else:
                print(f"{opo_group}: 테스트 데이터에 단일 클래스만 존재")
        
        return {
            'accuracy': test_accuracy,
            'precision': test_precision,
            'recall': test_recall,
            'f1': test_f1,
            'auc': test_auc
        }

# =================== 실행 ===================
print("HospitalID 및 OPO 그룹 확인:")
if 'HospitalID' in df_processed.columns:
    # OPO 그룹 미리 생성해서 확인
    def extract_opo_group(hospital_id):
        if pd.isna(hospital_id):
            return 'Unknown'
        hospital_str = str(hospital_id).upper()
        import re
        match = re.search(r'OPO\s*(\d+)', hospital_str)
        if match:
            return f"OPO{match.group(1)}"
        else:
            return hospital_str[:5] if len(hospital_str) >= 5 else hospital_str
    
    df_processed['OPO_Group'] = df_processed['HospitalID'].apply(extract_opo_group)
    
    print(f"OPO 그룹별 데이터 개수:")
    opo_counts = df_processed['OPO_Group'].value_counts()
    for opo_group, count in opo_counts.items():
        print(f"  {opo_group}: {count} 샘플")
    


# Federated Learning 실행
if 'HospitalID' in df_processed.columns:
    fl_system = FederatedMAEANN()
    
    # 데이터 준비
    clients_data = fl_system.prepare_federated_data(df_processed)
    
    if len(clients_data) > 0:  # 유효한 클라이언트가 있는 경우만 실행
        # MAE 학습
        global_mae_model = fl_system.train_federated_mae(rounds=50, mask_ratio=0.5)
        
        # ANN 학습
        global_ann_model, global_encoder = fl_system.train_federated_ann(global_mae_model, rounds=100)
        
        # 성능 평가
        results = fl_system.evaluate_model(global_ann_model, global_encoder)
        
        print("\n✅ Federated Learning MAE + ANN 학습 및 평가 완료!")

    
    # 기존 코드 실행
    # ... (기존 MAE + ANN 코드)

transplanted
False    124129
True       8972
Name: count, dtype: int64
[False  True]
제거할 변수들: ['PatientID', 'outcome_heart', 'outcome_liver', 'outcome_kidney_left', 'outcome_kidney_right', 'outcome_lung_left', 'outcome_lung_right', 'outcome_intestine', 'outcome_pancreas']
변수 제거 후 원본 데이터프레임 크기: (133101, 25)


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df[col].mask(df[col] == i, label, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df[col].mask(df[col] == i, label, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting valu

HospitalID 및 OPO 그룹 확인:
OPO 그룹별 데이터 개수:
  OPO4: 33641 샘플
  OPO1: 32148 샘플
  OPO6: 22915 샘플
  OPO2: 16145 샘플
  OPO5: 15738 샘플
  OPO3: 12514 샘플
총 6개 OPO 그룹 발견
OPO1: 32148 샘플
OPO2: 16145 샘플
OPO3: 12514 샘플
OPO4: 33641 샘플
OPO5: 15738 샘플
OPO6: 22915 샘플

실제 사용할 OPO 그룹: 6개
✅ OPO1: Train=25718, Test=6430
✅ OPO2: Train=12916, Test=3229
✅ OPO3: Train=10011, Test=2503
✅ OPO4: Train=26912, Test=6729
✅ OPO5: Train=12590, Test=3148
✅ OPO6: Train=18332, Test=4583
Federated MAE 학습 시작

Round 1/50

Round 2/50

Round 3/50

Round 4/50

Round 5/50

Round 6/50

Round 7/50

Round 8/50

Round 9/50

Round 10/50
  평균 MAE Loss: 0.0957

Round 11/50

Round 12/50

Round 13/50

Round 14/50

Round 15/50

Round 16/50

Round 17/50

Round 18/50

Round 19/50

Round 20/50
  평균 MAE Loss: 0.0768

Round 21/50

Round 22/50

Round 23/50

Round 24/50

Round 25/50

Round 26/50

Round 27/50

Round 28/50

Round 29/50

Round 30/50
  평균 MAE Loss: 0.0715

Round 31/50

Round 32/50

Round 33/50

Round 34/50

Round 35/50

Round 36/50

Ro