# In [None]:
import os
import pandas as pd
import tensorflow as tf
import pickle
import json
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle, resample


# Directory for saved model artifacts.
# NOTE(review): not referenced by the code visible in this file.
model_storage = "/content/i_models"
# Directory holding the per-client ``preprocessed_*.csv`` files and the shared
# preprocessing artifacts loaded below.
data_dir = "/content/preprocessed_data"


# Federated-training configuration.
iterations = 2  # intended number of federated rounds
epochs_per_iteration = 2  # local training epochs per client per round
batch_size = 32
global_model_type = "gru"  # "lstm" or "gru" (the types build_model accepts)

global_features_file = os.path.join(data_dir, "global_features.json")
scaler_and_encoders_file = os.path.join(data_dir, "scalers_and_encoders.pkl")

# Feature column names shared by all clients; used to select/order CSV columns.
# JSON is UTF-8 by spec, so decode explicitly rather than with the locale default.
with open(global_features_file, "r", encoding="utf-8") as f:
    global_features = json.load(f)

# SECURITY: pickle.load can execute arbitrary code. Only load this file if it
# was produced by a trusted preprocessing step.
with open(scaler_and_encoders_file, "rb") as f:
    preprocessors = pickle.load(f)

global_scaler = preprocessors["scaler"]  # fitted scaler applied to every client's features
label_encoders = preprocessors["encoders"]

def build_model(model_type, input_shape):
    """Construct an uncompiled two-layer recurrent binary classifier.

    Args:
        model_type: ``"lstm"`` or ``"gru"``; picks the recurrent layer class.
        input_shape: shape of a single sample, given to the first recurrent layer.

    Returns:
        A ``tf.keras.Sequential`` of
        RNN(64, return_sequences) -> Dropout(0.1) -> RNN(32) -> Dropout(0.1)
        -> Dense(1, sigmoid).

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Resolve the layer class first; the rest of the architecture is shared.
    if model_type == "lstm":
        recurrent = tf.keras.layers.LSTM
    elif model_type == "gru":
        recurrent = tf.keras.layers.GRU
    else:
        raise ValueError("Unsupported model type")

    keras_layers = tf.keras.layers
    stack = [
        recurrent(64, input_shape=input_shape, return_sequences=True),
        keras_layers.Dropout(0.1),
        recurrent(32, return_sequences=False),
        keras_layers.Dropout(0.1),
        keras_layers.Dense(1, activation="sigmoid"),
    ]
    return tf.keras.Sequential(stack)

def augment_data(X, y, noise_scale=0.01, rng=None):
    """Double a dataset by appending a Gaussian-jittered copy of it.

    Args:
        X: feature array; noise is drawn with exactly ``X.shape``.
        y: labels for X; repeated unchanged for the jittered copy.
        noise_scale: standard deviation of the additive noise (default 0.01).
        rng: optional ``numpy.random.Generator``/``RandomState`` for reproducible
            noise. When None, numpy's global RNG is used, as before.

    Returns:
        ``(X_out, y_out)`` with ``2 * len(X)`` rows: the originals first, then
        the noisy copies.
    """
    normal = np.random.normal if rng is None else rng.normal
    X_augmented = X + noise_scale * normal(size=X.shape)
    # Concatenate along axis 0 explicitly: vstack/hstack pick the wrong axis for
    # 1-D X and for column-vector y.
    return np.concatenate([X, X_augmented]), np.concatenate([y, y])

def resample_data(X, y, random_state=42):
    """Balance a binary dataset by bootstrap-oversampling the smaller class.

    Args:
        X: samples; every axis after the first is flattened.
        y: binary labels (0/1) matched to X by position. Pandas objects are
            converted to arrays, so a non-default index cannot misalign rows.
        random_state: seed for the oversampling draw (default 42).

    Returns:
        ``(X_out, y_out)`` where X_out is 2-D ``(n, features)``. Rows of the
        larger class come first in their original order, followed by rows of
        the smaller class drawn with replacement up to the larger class's size.
        Rows labelled neither 0 nor 1 are dropped. If one class is absent,
        the remaining rows are returned unbalanced instead of raising.
    """
    X = np.asarray(X)
    # np.prod(()) == 1, so 1-D and empty inputs reshape cleanly (reshape(n, -1) fails for n == 0).
    X_flat = X.reshape(X.shape[0], int(np.prod(X.shape[1:])))
    labels = np.asarray(y).ravel()

    neg, pos = labels == 0, labels == 1
    X_neg, y_neg = X_flat[neg], labels[neg]
    X_pos, y_pos = X_flat[pos], labels[pos]

    if len(y_neg) == 0 or len(y_pos) == 0:
        # Nothing to balance against; oversampling from an empty class would raise.
        return np.concatenate([X_neg, X_pos]), np.concatenate([y_neg, y_pos])

    # Oversample whichever class is smaller; on a tie, label 1 is resampled.
    if len(y_pos) <= len(y_neg):
        X_maj, y_maj, X_min, y_min = X_neg, y_neg, X_pos, y_pos
    else:
        X_maj, y_maj, X_min, y_min = X_pos, y_pos, X_neg, y_neg

    # Bootstrap index draw with replacement -- the same scheme
    # sklearn.utils.resample(replace=True) uses for an integer random_state.
    idx = np.random.RandomState(random_state).randint(0, len(y_min), size=len(y_maj))
    return np.vstack([X_maj, X_min[idx]]), np.concatenate([y_maj, y_min[idx]])

def lr_schedule(epoch, lr):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    The incoming rate is returned unchanged for epochs 0 and 1. From epoch 2
    onward it is multiplied by 0.8. The scheduler passes the current rate
    each epoch, so the decay compounds.
    """
    hold_epochs = 2
    decay_factor = 0.8
    return lr if epoch < hold_epochs else lr * decay_factor

def train_local_model(local_model, X_train, y_train, X_val, y_val, model_type):
    """Compile ``local_model`` and fit it on one client's training data.

    The model is (re)compiled with a fresh Adam optimizer (lr 0.001), binary
    cross-entropy, and accuracy/AUC metrics. It is then trained for
    ``epochs_per_iteration`` epochs with batch size ``batch_size``, validating
    on ``(X_val, y_val)``.

    Returns:
        The same ``local_model`` object, trained in place.
    """
    # NOTE(review): model_type is accepted but not used in this function.
    local_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )

    training_callbacks = [
        tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1),
        # NOTE(review): patience=3 is larger than epochs_per_iteration (2 in this
        # file), so this callback should never stop a fit early as configured.
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=3, restore_best_weights=True, verbose=1
        ),
    ]

    local_model.fit(
        X_train,
        y_train,
        epochs=epochs_per_iteration,
        batch_size=batch_size,
        validation_data=(X_val, y_val),
        callbacks=training_callbacks,
        verbose=1,
    )
    return local_model

def federated_averaging(models, sample_counts):
    """Average the models' weights tensor-by-tensor, weighted by ``sample_counts``.

    ``models[i]`` is weighted by ``sample_counts[i]``. All models must expose
    ``get_weights()`` lists with matching structure. Returns a list with one
    averaged array per weight tensor (an empty list for no models).
    """
    per_model = [model.get_weights() for model in models]
    return [
        np.average(matching_tensors, axis=0, weights=sample_counts)
        for matching_tensors in zip(*per_model)
    ]


def _load_client_splits(data_dir):
    """Read each ``preprocessed_*.csv`` client file and return its scaled train/val split.

    Files are visited in sorted name order so client numbering is reproducible
    (``os.listdir`` order is arbitrary). Returns a list of
    ``(X_train, X_val, y_train, y_val)`` tuples, one per client, where the X
    arrays have shape ``(samples, n_features, 1)``.
    """
    client_splits = []
    for filename in sorted(os.listdir(data_dir)):
        if not (filename.startswith("preprocessed_") and filename.endswith(".csv")):
            continue
        data = pd.read_csv(os.path.join(data_dir, filename))

        # A client may lack some global features. Zero-fill them so every client
        # hands the shared scaler the same columns in the same order.
        for col in global_features:
            if col not in data.columns:
                data[col] = 0.0

        X_scaled = global_scaler.transform(data[global_features])
        # Trailing axis of size 1: the recurrent models (input_shape
        # (len(global_features), 1)) read each feature as one timestep.
        X_scaled = X_scaled.reshape((X_scaled.shape[0], X_scaled.shape[1], 1))

        X_shuffled, y_shuffled = shuffle(X_scaled, data["label"], random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(
            X_shuffled, y_shuffled, test_size=0.2, random_state=42
        )
        client_splits.append((X_train, X_val, y_train, y_val))
    return client_splits


def federated_training(data_dir, num_rounds=None):
    """Run federated averaging (FedAvg) over the client CSVs in ``data_dir``.

    Each round, every client model:
      1. is reset to the current global weights,
      2. is trained locally on augmented, class-balanced data.

    The global model then takes the sample-count-weighted average of the
    client weights and is evaluated on the union of all clients' validation
    splits.

    Args:
        data_dir: directory containing ``preprocessed_*.csv`` client files.
        num_rounds: number of federated rounds; defaults to the module-level
            ``iterations`` setting.

    Returns:
        List of ``(loss, accuracy, auc)`` tuples, one per round.

    Raises:
        ValueError: if ``data_dir`` contains no client files.
    """
    if num_rounds is None:
        num_rounds = iterations

    input_shape = (len(global_features), 1)
    global_model = build_model(global_model_type, input_shape)
    global_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )

    local_data = _load_client_splits(data_dir)
    if not local_data:
        raise ValueError(f"No 'preprocessed_*.csv' client files found in {data_dir!r}")
    local_models = [
        build_model(global_model_type, X_train.shape[1:]) for X_train, _, _, _ in local_data
    ]

    # The pooled validation set is the same every round, so build it once.
    X_val_all = np.vstack([X_val for _, X_val, _, _ in local_data])
    y_val_all = np.hstack([y_val for _, _, _, y_val in local_data])

    global_results = []
    for round_num in range(num_rounds):
        print(f"\n--- Federated Learning Round {round_num + 1}/{num_rounds} ---")
        # Every client starts this round from the same global weights.
        round_start_weights = global_model.get_weights()
        sample_counts = []

        for client, (X_train, X_val, y_train, y_val) in enumerate(local_data):
            print(f"\nTraining for client {client + 1}/{len(local_data)}")

            X_train_aug, y_train_aug = augment_data(X_train, y_train)
            X_train_res, y_train_res = resample_data(X_train_aug, y_train_aug)
            X_train_res = X_train_res.reshape(X_train_res.shape[0], X_train_res.shape[1], 1)

            local_model = local_models[client]
            # FedAvg: broadcast the global model before local training. Without
            # this, each client kept its own independently initialised weights,
            # and the averaged weights were never fed back to the clients.
            local_model.set_weights(round_start_weights)
            train_local_model(local_model, X_train_res, y_train_res, X_val, y_val, global_model_type)

            # NOTE(review): aggregation weights use the post-augmentation /
            # oversampling size, not the client's original training-set size.
            sample_counts.append(len(y_train_res))

        global_model.set_weights(federated_averaging(local_models, sample_counts))

        global_loss, global_accuracy, global_auc = global_model.evaluate(X_val_all, y_val_all, verbose=0)
        print(f"Global Validation - Round {round_num + 1}: Loss = {global_loss:.4f}, "
              f"Accuracy = {global_accuracy:.4f}, AUC = {global_auc:.4f}")
        global_results.append((global_loss, global_accuracy, global_auc))

    print("\nFederated Learning Training Completed!")
    print("\n--- Global Model Results Over Rounds ---")
    for round_num, (loss, accuracy, auc) in enumerate(global_results, start=1):
        print(f"Round {round_num}: Loss = {loss:.4f}, Accuracy = {accuracy:.4f}, AUC = {auc:.4f}")

    return global_results




if __name__ == "__main__":
    # Script entry point: train over the client CSVs in the configured data directory.
    federated_training(data_dir=data_dir)