In [1]:
import pandas as pd
import xgboost as xgb
import pickle
import json
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder

# -----------------------------
# 1. Load Dataset
# -----------------------------
data = pd.read_csv(r"E:\semm 8\federated_learning\federated_3\data_preprosses\preprocessed_diabetic_data.csv") 

features = [
    'race', 'gender', 'age', 'time_in_hospital', 'num_lab_procedures',
    'num_procedures', 'num_medications', 'number_outpatient',
    'number_emergency', 'number_inpatient', 'diag_1', 'diag_2', 'diag_3',
    'number_diagnoses', 'metformin', 'repaglinide', 'nateglinide',
    'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide',
    'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose',
    'miglitol', 'troglitazone', 'tolazamide', 'examide', 'citoglipton',
    'insulin', 'glyburide-metformin', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'metformin-rosiglitazone',
    'metformin-pioglitazone', 'change', 'diabetesMed', 'A1Cresult', 
    'max_glu_serum', 'weight'
]
target_col = 'readmitted'

data = data[features + [target_col]]

# -----------------------------
# 2. Encoding & Preprocessing
# -----------------------------

gender_map = {'Male': 0, 'Female': 1, 'Unknown/Invalid': 2}
age_map = {
    '[0-10)': 5, '[10-20)': 15, '[20-30)': 25, '[30-40)': 35, 
    '[40-50)': 45, '[50-60)': 55, '[60-70)': 65, '[70-80)': 75, 
    '[80-90)': 85, '[90-100)': 95
}

data['gender'] = data['gender'].map(gender_map).fillna(2)
data['age'] = data['age'].map(age_map).fillna(45)

for col in features:
    if data[col].dtype == 'object':
        le = LabelEncoder()
        data[col] = le.fit_transform(data[col].astype(str))

# -----------------------------
# 3. Multiclass Target Transformation
# -----------------------------
# NO  -> 0
# >30 -> 1
# <30 -> 2
target_map = {"NO": 0, ">30": 1, "<30": 2}
data["Readmitted_Multiclass"] = data[target_col].map(target_map)

X = data[features]
y = data["Readmitted_Multiclass"]

# -----------------------------
# 4. Train Multiclass Model
# -----------------------------
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"Training Multiclass Model on {X_train.shape[1]} features...")

# Changes: multi:softprob objective and num_class=3
model = xgb.XGBClassifier(
    objective="multi:softprob",
    num_class=3,
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    eval_metric="mlogloss",      # Multiclass logloss
    tree_method="hist",          # CPU Stability Fix
    device="cpu"                 # CPU Stability Fix
)

model.fit(X_train, y_train)

# Evaluation
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Initial Multiclass Model Accuracy: {accuracy*100:.2f}%")
print("\nDetailed Report:\n", classification_report(y_test, y_pred, target_names=["NO", ">30", "<30"]))

# -----------------------------
# 5. Save for Federated Learning
# -----------------------------
ensemble_data = {
    "models": [model],
    "weights": [1.0]
}

os.makedirs("server", exist_ok=True)

with open("server/global_model.pkl", "wb") as f:
    pickle.dump(ensemble_data, f)

history = {
    "round": 0,
    "accuracy": round(accuracy * 100, 2),
    "logs": [{"round": 0, "clients": 1, "accuracy": round(accuracy * 100, 2)}]
}

with open("server/history.json", "w") as f:
    json.dump(history, f, indent=4)

print("✅ Multiclass global_model.pkl and history.json initialized.")

Training Multiclass Model on 42 features...
Initial Multiclass Model Accuracy: 58.20%

Detailed Report:
               precision    recall  f1-score   support

          NO       0.61      0.85      0.71     10952
         >30       0.50      0.34      0.41      7117
         <30       0.41      0.02      0.03      2285

    accuracy                           0.58     20354
   macro avg       0.51      0.41      0.38     20354
weighted avg       0.55      0.58      0.53     20354

✅ Multiclass global_model.pkl and history.json initialized.
