# **INSTALL AND IMPORT LIBRARIES**

In [19]:
!pip install -q tensorflow scikit-learn matplotlib
!pip install cryptography
!pip install -q flwr
!pip install -q ray



In [20]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.applications import MobileNetV2
from sklearn.model_selection import train_test_split
import flwr as fl

# **PREPROCESSING THE DATA**

In [21]:
# Load dataset
df = pd.read_csv("hmnist_64_64_L.csv")
print("Dataset shape:", df.shape)
print("Missing labels:", df['label'].isnull().sum())
print("Unique labels:", df['label'].unique())

# Preprocessing
labels = df['label'].astype(int).values - 1  # Convert 1–8 to 0–7
images = df.drop('label', axis=1).values
images = images.reshape(-1, 64, 64, 1).astype('float32') / 255.0  # Normalize

# Convert grayscale to RGB (MobileNetV2 expects 3 channels)
images_rgb = np.repeat(images, 3, axis=-1)  # Shape: (N, 64, 64, 3)


Dataset shape: (5000, 4097)
Missing labels: 0
Unique labels: [2 5 7 6 8 1 4 3]


In [22]:
# Distribute data across 3 clients
client_data = {0: {"x": [], "y": []}, 1: {"x": [], "y": []}, 2: {"x": [], "y": []}}

for class_label in np.unique(labels):
    indices = np.where(labels == class_label)[0]
    np.random.shuffle(indices)
    split_1 = int(len(indices) * 1/3)
    split_2 = int(len(indices) * 2/3)

    client_data[0]["x"].extend(images_rgb[indices[:split_1]])
    client_data[0]["y"].extend(labels[indices[:split_1]])
    client_data[1]["x"].extend(images_rgb[indices[split_1:split_2]])
    client_data[1]["y"].extend(labels[indices[split_1:split_2]])
    client_data[2]["x"].extend(images_rgb[indices[split_2:]])
    client_data[2]["y"].extend(labels[indices[split_2:]])

for i in range(3):
    client_data[i]["x"] = np.array(client_data[i]["x"])
    client_data[i]["y"] = np.array(client_data[i]["y"])
    print(f"Client {i} samples: {len(client_data[i]['x'])}")

Client 0 samples: 1664
Client 1 samples: 1664
Client 2 samples: 1672


In [29]:
# Focal Loss Function
import tensorflow.keras.backend as K
def focal_loss(gamma=2., alpha=0.25):
    def focal_loss_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        y_true = tf.one_hot(y_true, depth=8)
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        cross_entropy = -y_true * K.log(y_pred)
        weight = alpha * K.pow(1 - y_pred, gamma)
        loss = weight * cross_entropy
        return K.sum(loss, axis=1)
    return focal_loss_fixed

# **BUILDING THE MODEL - MOBILENETV2**

In [37]:
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.applications import MobileNetV2

# Momentum-based SGD optimizer
def create_mobilenet_model():
    base_model = MobileNetV2(
        input_shape=(64, 64, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg'
    )
    base_model.trainable = False  # Freeze base layers

    optimizer = SGD(learning_rate=0.01, momentum=0.9, nesterov=True) #using momentum to avoid overfitting

    model = Sequential([
        base_model,
        Dense(64, activation='relu'),   # Reduced from 128 → 64
        Dropout(0.5),                   # Increased dropout for regularization
        Dense(8, activation='softmax')  # CH-MNIST has 8 classes
    ])

    model.compile(
        optimizer=optimizer,
        loss=focal_loss(gamma=2., alpha=0.25), # using focal loss right here
        metrics=['accuracy']
    )
    return model


# **TESTING THE ACCURACY BEFORE FEDERATED LEARNING**

In [38]:
# Test locally on 1 client
model = create_mobilenet_model()
history = model.fit(
    client_data[0]["x"], client_data[0]["y"],
    epochs=10,
    batch_size=16,
    validation_split=0.1
)


  base_model = MobileNetV2(


Epoch 1/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 91ms/step - accuracy: 0.3042 - loss: 0.4326 - val_accuracy: 0.8982 - val_loss: 0.1171
Epoch 2/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 53ms/step - accuracy: 0.5404 - loss: 0.2020 - val_accuracy: 0.9222 - val_loss: 0.0596
Epoch 3/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 70ms/step - accuracy: 0.6329 - loss: 0.1649 - val_accuracy: 0.9401 - val_loss: 0.0362
Epoch 4/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 52ms/step - accuracy: 0.6461 - loss: 0.1503 - val_accuracy: 0.9401 - val_loss: 0.0432
Epoch 5/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 71ms/step - accuracy: 0.6574 - loss: 0.1296 - val_accuracy: 0.9581 - val_loss: 0.0299
Epoch 6/10
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 54ms/step - accuracy: 0.7221 - loss: 0.1092 - val_accuracy: 0.9222 - val_loss: 0.0452
Epoch 7/10
[1m94/94[0m [32m━━━

# **FEDERATED LEARNING PLAN**

In [39]:
# Federated Client Class
class HistologyClient(fl.client.NumPyClient):
    def __init__(self, x_train, y_train, x_test, y_test):
        self.model = create_mobilenet_model()
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test

    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        early_stop = tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True, monitor="val_loss")
        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=2, factor=0.5, min_lr=1e-6)

        self.model.fit(
            self.x_train, self.y_train,
            validation_data=(self.x_test, self.y_test),
            epochs=10,
            batch_size=16,
            verbose=0,
            callbacks=[early_stop, reduce_lr]
        )
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        return loss, len(self.x_test), {"accuracy": accuracy}

# --- Train/Test Split ---
client_train_test = {}
for i in range(3):
    x = client_data[i]["x"]
    y = client_data[i]["y"]
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, stratify=y, random_state=42
    )
    client_train_test[i] = (x_train, y_train, x_test, y_test)

In [40]:
# Global Evaluation
def get_evaluate_fn():
    all_x = np.concatenate([client_train_test[i][2] for i in range(3)])
    all_y = np.concatenate([client_train_test[i][3] for i in range(3)])
    def evaluate(server_round, parameters, config):
        model = create_mobilenet_model()
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(all_x, all_y, verbose=0)
        print(f"[Round {server_round}] Global Accuracy: {accuracy:.4f}")
        return loss, {"accuracy": accuracy}
    return evaluate

# Client Function Factory
def client_fn(cid: str):
    cid = int(cid)
    x_train, y_train, x_test, y_test = client_train_test[cid]
    return HistologyClient(x_train, y_train, x_test, y_test)

# Start Simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=3,
    config=fl.server.ServerConfig(num_rounds=10),
    strategy=fl.server.strategy.FedAvg(evaluate_fn=get_evaluate_fn()),
)

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2025-04-22 21:42:56,070	INFO worker.py:1852 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'memory': 9238184346.0, 'node:172.28.0.12': 1.0, 'object_store_memory': 3959221862.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower V

[Round 0] Global Accuracy: 0.1429


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m 2025-04-22 21:43:22.043148: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future vers

[Round 1] Global Accuracy: 0.7413


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 2] Global Accuracy: 0.7602


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 3] Global Accuracy: 0.7582


[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 4] Global Accuracy: 0.7762


[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 5] Global Accuracy: 0.7712


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 6] Global Accuracy: 0.7702


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 7] Global Accuracy: 0.7742


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 8] Global Accuracy: 0.7762


[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=4

[Round 9] Global Accuracy: 0.7742


[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=

[Round 10] Global Accuracy: 0.7832


[36m(ClientAppActor pid=40926)[0m 
[36m(ClientAppActor pid=40926)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=40926)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=40926)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40925)[0m 
[36m(ClientAppActor pid=40925)[0m         
[36m(ClientAppActor pid=40925)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=40925)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 1162.44s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.10993292352387479
[92mINFO [0m:      

History (loss, distributed):
	round 1: 0.10993292352387479
	round 2: 0.08998673939174943
	round 3: 0.08946926496573142
	round 4: 0.08464046662950611
	round 5: 0.08405106542231915
	round 6: 0.08273015281447878
	round 7: 0.08185380185906822
	round 8: 0.0810640930504232
	round 9: 0.07986407390573284
	round 10: 0.08001448740283926
History (loss, centralized):
	round 0: 0.6864148378372192
	round 1: 0.10993294417858124
	round 2: 0.08998674899339676
	round 3: 0.08946926146745682
	round 4: 0.08464048057794571
	round 5: 0.08405108004808426
	round 6: 0.08273015916347504
	round 7: 0.08185380697250366
	round 8: 0.08106409013271332
	round 9: 0.07986409962177277
	round 10: 0.08001448959112167
History (metrics, centralized):
{'accuracy': [(0, 0.1428571492433548),
              (1, 0.7412587404251099),
              (2, 0.7602397799491882),
              (3, 0.7582417726516724),
              (4, 0.7762237787246704),
              (5, 0.7712287902832031),
              (6, 0.7702297568321228),
       