In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
# Federated Learning Implementation for Fake Image Detection
# Based on: "Fake Image Detection Using Deep Learning"

import os
import numpy as np
import pickle
import datetime
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
import gc

# -------------------------------------------------------------
# CONFIG
# -------------------------------------------------------------
IMG_SIZE = (224, 224)
BATCH_SIZE = 128
LOCAL_EPOCHS = 3          # Epochs per client per round
FEDERATED_ROUNDS = 10     # Number of federated rounds
NUM_CLIENTS = 5           # Number of federated clients

BASE_PATH = "/kaggle/input/140k-real-and-fake-faces/real_vs_fake/real-vs-fake"
TRAIN_PATH = f"{BASE_PATH}/train"
VAL_PATH = f"{BASE_PATH}/valid"
TEST_PATH = f"{BASE_PATH}/test"

AUTOTUNE = tf.data.AUTOTUNE

# -------------------------------------------------------------
# GPU Setup
# -------------------------------------------------------------
def setup_gpu():
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print(f"✓ GPU enabled: {len(gpus)} GPU(s)")
        except RuntimeError as e:
            print(f"GPU error: {e}")
    else:
        print("⚠ No GPU - using CPU")

# -------------------------------------------------------------
# Create Model (same architecture as paper)
# -------------------------------------------------------------
def create_model():
    """Create EfficientNetB0 model with custom top layers"""
    base_model = EfficientNetB0(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
    )
    base_model.trainable = False

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(256, activation="relu"),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(2, activation="softmax", dtype='float32')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")]
    )

    return model

# -------------------------------------------------------------
# Load Full Dataset
# -------------------------------------------------------------
def load_full_dataset():
    """Load the complete training, validation, and test datasets"""
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        TRAIN_PATH,
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        label_mode="categorical",
        shuffle=True,
        seed=42
    )

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        VAL_PATH,
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        label_mode="categorical"
    )

    test_ds = tf.keras.preprocessing.image_dataset_from_directory(
        TEST_PATH,
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        shuffle=False,
        label_mode="categorical"
    )

    # Preprocessing
    train_ds = train_ds.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
    test_ds = test_ds.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)

    # Optimize
    train_ds = train_ds.prefetch(AUTOTUNE)
    val_ds = val_ds.prefetch(AUTOTUNE)
    test_ds = test_ds.prefetch(AUTOTUNE)

    return train_ds, val_ds, test_ds

# -------------------------------------------------------------
# Partition Data for Federated Clients (shard-based, no full materialization)
# -------------------------------------------------------------
def partition_data_for_clients(train_ds, num_clients):
    """
    Split training data among clients using tf.data.shard (no full RAM load)
    """
    # Estimate total samples in a streamed way (one pass)
    total_samples = 0
    for batch_x, _ in train_ds:
        total_samples += batch_x.shape[0]
    print(f"Total training samples: {total_samples}")

    client_datasets = []
    client_sample_counts = []

    # Create a sharded view for each client
    for client_id in range(num_clients):
        client_ds = train_ds.shard(num_shards=num_clients, index=client_id)
        # Approximate per-client sample count
        client_samples = total_samples // num_clients
        client_datasets.append(client_ds.prefetch(AUTOTUNE))
        client_sample_counts.append(client_samples)
        print(f"  Client {client_id + 1}: ~{client_samples} samples")

    return client_datasets, client_sample_counts

# -------------------------------------------------------------
# Federated Learning: Weight Aggregation
# -------------------------------------------------------------
def aggregate_weights(client_weights_list, client_sample_counts):
    """
    Federated Averaging (FedAvg) - WEIGHTED average based on client data size
    """
    total_samples = sum(client_sample_counts)

    # Initialize with zeros - use float32 for ALL weight arrays
    first_weights = client_weights_list[0]
    avg_weights = [np.zeros_like(w, dtype=np.float32) for w in first_weights]

    # Weighted sum of client weights
    for client_weights, num_samples in zip(client_weights_list, client_sample_counts):
        client_weight = num_samples / total_samples
        
        for i, w in enumerate(client_weights):
            # Ensure everything is float32
            avg_weights[i] += np.float32(client_weight) * w.astype(np.float32)

    return avg_weights
# -------------------------------------------------------------
# Client Training Function (with memory cleanup)
# -------------------------------------------------------------
def train_client(client_id, client_dataset, global_weights, local_epochs):
    """
    Train a single client on their local data
    """
    print(f"  Training Client {client_id}...")

    # Create fresh model with global weights
    K.clear_session()  # clear any previous model graph
    model = create_model()
    model.set_weights(global_weights)

    # Train on local data
    history = model.fit(
        client_dataset,
        epochs=local_epochs,
        verbose=0
    )

    # Get updated weights (this is what gets sent to server)
    updated_weights = model.get_weights()

    # Get training metrics
    final_loss = history.history['loss'][-1]
    final_acc = history.history['accuracy'][-1]

    print(f"    Client {client_id} - Loss: {final_loss:.4f}, Acc: {final_acc:.4f}")

    # Explicit cleanup to reduce memory usage on Kaggle
    del history
    gc.collect()
    K.clear_session()

    return updated_weights, final_loss, final_acc

# -------------------------------------------------------------
# Federated Learning Main Loop
# -------------------------------------------------------------
def federated_learning(client_datasets, client_sample_counts, val_ds, test_ds):
    """
    Main federated learning loop
    """
    print("\n" + "="*60)
    print("FEDERATED LEARNING - FAKE IMAGE DETECTION")
    print("="*60)
    print(f"Clients: {NUM_CLIENTS}")
    print(f"Federated Rounds: {FEDERATED_ROUNDS}")
    print(f"Local Epochs per Round: {LOCAL_EPOCHS}")
    print(f"\nClient Data Distribution:")
    total = sum(client_sample_counts)
    for i, count in enumerate(client_sample_counts):
        percentage = (count / total) * 100
        print(f"  Client {i+1}: {count} samples (~{percentage:.1f}%)")
    print("="*60 + "\n")

    # Initialize global model
    global_model = create_model()
    global_weights = global_model.get_weights()

    # Track metrics
    history = {
        'round': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_auc': []
    }

    # Federated training rounds
    for round_num in range(1, FEDERATED_ROUNDS + 1):
        print(f"\n{'='*60}")
        print(f"FEDERATED ROUND {round_num}/{FEDERATED_ROUNDS}")
        print(f"{'='*60}")

        # Store weights from all clients
        client_weights_list = []
        round_losses = []
        round_accs = []

        # Train each client
        for client_id in range(1, NUM_CLIENTS + 1):
            client_weights, loss, acc = train_client(
                client_id=client_id,
                client_dataset=client_datasets[client_id - 1],
                global_weights=global_weights,
                local_epochs=LOCAL_EPOCHS
            )
            client_weights_list.append(client_weights)
            round_losses.append(loss)
            round_accs.append(acc)

        # Aggregate weights (WEIGHTED Federated Averaging)
        print(f"\n  Aggregating weights (weighted by dataset size)...")
        global_weights = aggregate_weights(client_weights_list, client_sample_counts)

        # Update global model
        global_model.set_weights(global_weights)

        # Evaluate on validation set
        print(f"  Evaluating global model on validation set...")
        val_results = global_model.evaluate(val_ds, verbose=0)
        val_loss, val_acc, val_auc = val_results

        # Average client metrics
        avg_train_loss = np.mean(round_losses)
        avg_train_acc = np.mean(round_accs)

        # Store metrics
        history['round'].append(round_num)
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc'].append(val_auc)

        print(f"\n  Round {round_num} Summary:")
        print(f"    Avg Client Train Loss: {avg_train_loss:.4f}")
        print(f"    Avg Client Train Acc:  {avg_train_acc:.4f}")
        print(f"    Global Val Loss:       {val_loss:.4f}")
        print(f"    Global Val Acc:        {val_acc:.4f}")
        print(f"    Global Val AUC:        {val_auc:.4f}")

        # Save checkpoint
        if round_num % 2 == 0:  # Save every 2 rounds
            global_model.save(f"federated_model_round_{round_num}.keras")
            print(f"    ✓ Checkpoint saved")

    return global_model, history

# -------------------------------------------------------------
# Main Function
# -------------------------------------------------------------
def main():
    setup_gpu()

    print("\n=== Loading Dataset ===")
    train_ds, val_ds, test_ds = load_full_dataset()

    print("\n=== Partitioning Data for Federated Clients ===")
    client_datasets, client_sample_counts = partition_data_for_clients(train_ds, NUM_CLIENTS)

    print("\n=== Starting Federated Learning ===")
    global_model, history = federated_learning(client_datasets, client_sample_counts, val_ds, test_ds)

    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)
    test_results = global_model.evaluate(test_ds, verbose=1)
    test_loss, test_acc, test_auc = test_results

    print(f"\nFinal Test Results:")
    print(f"  Test Loss:     {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc*100:.2f}%")
    print(f"  Test AUC:      {test_auc:.4f}")

    # Save final model and history
    global_model.save("federated_final_model.h5")
    with open("federated_history.pkl", "wb") as f:
        pickle.dump(history, f)

    print("\n✓ Training Complete!")
    print("✓ Model saved as 'federated_final_model.h5'")

if __name__ == "__main__":
    main()


✓ GPU enabled: 1 GPU(s)

=== Loading Dataset ===
Found 100000 files belonging to 2 classes.
Found 20000 files belonging to 2 classes.
Found 20000 files belonging to 2 classes.

=== Partitioning Data for Federated Clients ===
Total training samples: 100000
  Client 1: ~20000 samples
  Client 2: ~20000 samples
  Client 3: ~20000 samples
  Client 4: ~20000 samples
  Client 5: ~20000 samples

=== Starting Federated Learning ===

FEDERATED LEARNING - FAKE IMAGE DETECTION
Clients: 5
Federated Rounds: 10
Local Epochs per Round: 3

Client Data Distribution:
  Client 1: 20000 samples (~20.0%)
  Client 2: 20000 samples (~20.0%)
  Client 3: 20000 samples (~20.0%)
  Client 4: 20000 samples (~20.0%)
  Client 5: 20000 samples (~20.0%)


FEDERATED ROUND 1/10
  Training Client 1...
    Client 1 - Loss: 0.3790, Acc: 0.8286
  Training Client 2...
    Client 2 - Loss: 0.3789, Acc: 0.8306
  Training Client 3...
    Client 3 - Loss: 0.3769, Acc: 0.8303
  Training Client 4...
    Client 4 - Loss: 0.3796, Ac




Final Test Results:
  Test Loss:     0.2287
  Test Accuracy: 90.32%
  Test AUC:      0.9686

✓ Training Complete!
✓ Model saved as 'federated_final_model.h5'
