<p align="center">
  <b>Divide and Learn: The Future of Distributed AI</b>
</p>



# 🩻 Federated Learning with Flower (FL) on Mammography Images


```
MyDrive/mammo_data/
 ├── Benign/
 │    ├── image_001.png
 │    └── ...
 └── Malignant/
      ├── image_101.png
      └── ...
```



# 1) Install dependencies



In [1]:
#!pip install flwr
#!pip install "flwr>=1.7,<2.0"
# Instala versiones compatibles para Colab
!pip -q install "flwr>=1.7,<2.0" "ray[default]>=2.9,<3.0"

import ray
# Inicialización básica; baja recursos para Colab si quieres
ray.shutdown()
ray.init(ignore_reinit_error=True, include_dashboard=False, num_cpus=2)
print("Ray OK:", ray.is_initialized())

!pip install -U "flwr[simulation]"


2025-09-30 11:30:11,806	INFO worker.py:1771 -- Started a local Ray instance.


Ray OK: True


## Cell 2 — Configuration and dataset utilities


In [2]:
import os, random, shutil, math
from pathlib import Path
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
from PIL import Image, ImageOps, ImageFilter

import kagglehub


In [3]:
# ==== MAIN CONFIGURATION ====
USE_GOOGLE_DRIVE = False          # Change to False if you don’t want to mount Drive
ROOT_RELATIVE = "mammo_data"     # Root dataset folder (inside Drive or local)
NUM_CLIENTS = 3                  # Number of federated clients
IMG_SIZE = (256, 256)            # Resize images to this size
BATCH_SIZE = 16
EPOCHS_LOCAL = 1                 # Local epochs per round
ROUNDS = 3                       # Federated rounds
VAL_SPLIT = 0.2                  # Validation split ratio per client
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)


# Download latest version
path = kagglehub.dataset_download("awsaf49/cbis-ddsm-breast-cancer-image-dataset")

print("Path to dataset files:", path)


if USE_GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    DATA_ROOT = Path("/content/drive/MyDrive") / ROOT_RELATIVE
else:
    DATA_ROOT = Path("/content") / ROOT_RELATIVE


print("DATA_ROOT:", DATA_ROOT)

Using Colab cache for faster access to the 'cbis-ddsm-breast-cancer-image-dataset' dataset.
Path to dataset files: /kaggle/input/cbis-ddsm-breast-cancer-image-dataset
DATA_ROOT: /content/mammo_data


In [4]:

# ==== TRANSFORMS ====
transform_train = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

transform_val = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# ==== DEMO DATASET CREATOR ====
def create_demo_mammo_dataset(root: Path, n_per_class=60):
    """
    Create a minimal synthetic dataset with 2 classes: Benign / Malignant.
    Only for testing Flower end-to-end without big downloads.
    """
    classes = ["Benign", "Malignant"]
    for cls in classes:
        (root/cls).mkdir(parents=True, exist_ok=True)

    W, H = IMG_SIZE
    for cls in classes:
        for i in range(n_per_class):
            img = Image.new("L", (W, H), color=0)
            if cls == "Benign":
                # Smooth circular patterns
                for _ in range(6):
                    cx, cy = random.randint(0,W-1), random.randint(0,H-1)
                    r = random.randint(10, 35)
                    for y in range(max(0,cy-r), min(H, cy+r)):
                        for x in range(max(0,cx-r), min(W, cx+r)):
                            if (x-cx)**2 + (y-cy)**2 <= r*r:
                                img.putpixel((x,y), min(255, img.getpixel((x,y)) + random.randint(15,25)))
                img = img.filter(ImageFilter.GaussianBlur(1.5))
            else:
                # Hard edges / bright masses
                for _ in range(5):
                    x0, y0 = random.randint(0,W-30), random.randint(0,H-30)
                    w, h = random.randint(15,45), random.randint(15,45)
                    for y in range(y0, min(H, y0+h)):
                        for x in range(x0, min(W, x0+w)):
                            img.putpixel((x,y), min(255, img.getpixel((x,y)) + random.randint(25,35)))
                img = ImageOps.autocontrast(img)
            img.save(root/cls/f"{cls}_{i:03d}.png")

# Create demo data if missing
if not DATA_ROOT.exists() or not any((DATA_ROOT/"Benign").glob("*")) or not any((DATA_ROOT/"Malignant").glob("*")):
    print("No real dataset found → creating synthetic DEMO at:", DATA_ROOT)
    create_demo_mammo_dataset(DATA_ROOT, n_per_class=60)
else:
    print("✅ Found dataset folders:", [p.name for p in DATA_ROOT.iterdir()])

✅ Found dataset folders: ['Malignant', 'Benign']


## Cell 3 — Stratified partition per client and DataLoaders


In [5]:
# Load the whole dataset as ImageFolder
full_dataset = datasets.ImageFolder(str(DATA_ROOT), transform=transform_train)
class_to_idx = full_dataset.class_to_idx
print("Classes:", class_to_idx, "Total images:", len(full_dataset))

# Indices per class
from collections import defaultdict
indices_by_class = defaultdict(list)
for idx, (_, target) in enumerate(full_dataset.samples):
    indices_by_class[target].append(idx)

# Stratified partition across NUM_CLIENTS
def stratified_partition(indices_by_class, num_clients):
    parts = [[] for _ in range(num_clients)]
    for c, idxs in indices_by_class.items():
        random.shuffle(idxs)
        chunk = math.ceil(len(idxs)/num_clients)
        for i in range(num_clients):
            parts[i].extend(idxs[i*chunk:(i+1)*chunk])
    return parts

client_indices = stratified_partition(indices_by_class, NUM_CLIENTS)

# Build loaders per client (train/val)
def make_client_loaders(indices, val_split=VAL_SPLIT, batch_size=BATCH_SIZE):
    subset = Subset(full_dataset, indices)
    n = len(subset)
    n_val = int(n*val_split)
    n_train = n - n_val
    train_subset, val_subset = random_split(
        subset, [n_train, n_val],
        generator=torch.Generator().manual_seed(SEED)
    )
    # Different transforms for train/val
    train_subset.dataset.transform = transform_train
    val_subset.dataset.transform = transform_val

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader

client_loaders = [make_client_loaders(idxs) for idxs in client_indices]
for i,(tr,va) in enumerate(client_loaders):
    print(f"Client {i}: train={len(tr.dataset)} | val={len(va.dataset)}")


Classes: {'Benign': 0, 'Malignant': 1} Total images: 120
Client 0: train=32 | val=8
Client 1: train=32 | val=8
Client 2: train=32 | val=8


## Cell 4 — CNN model



In [6]:
import torch.nn as nn
import torch.nn.functional as F

class MammoCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 1-channel input (grayscale)
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(32)
        self.pool  = nn.MaxPool2d(2,2)
        self.drop  = nn.Dropout(0.3)
        # IMG_SIZE 256x256 → after 2 pools → 64x64 with 32 channels
        self.fc1   = nn.Linear(32*64*64, 128)
        self.fc2   = nn.Linear(128, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)
        x = self.drop(F.relu(self.fc1(x)))
        return self.fc2(x)


## Cell 5 — Flower client and train/eval functions


In [7]:
import flwr as fl
import torch
import torch.optim as optim
from copy import deepcopy

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

def train_one_epoch(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(DEVICE, non_blocking=True), labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss/len(loader) if len(loader) > 0 else 0.0

@torch.no_grad()
def evaluate(model, loader, loss_fn):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(DEVICE, non_blocking=True), labels.to(DEVICE, non_blocking=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    acc = correct/total if total > 0 else 0.0
    return total_loss/len(loader) if len(loader) > 0 else 0.0, acc

# Registry of loaders for each client id
CLIENT_REGISTRY = {}
for cid,(tr,va) in enumerate(client_loaders):
    CLIENT_REGISTRY[str(cid)] = {"train": tr, "val": va}

# Flower client implementation
class MammoClient(fl.client.NumPyClient):
    def __init__(self, cid: str):
        self.cid = cid
        self.model = MammoCNN().to(DEVICE)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.train_loader = CLIENT_REGISTRY[cid]["train"]
        self.val_loader   = CLIENT_REGISTRY[cid]["val"]

    def get_parameters(self, config):
        return [t.detach().cpu().numpy() for _, t in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        state_dict = deepcopy(self.model.state_dict())
        for k, name in enumerate(state_dict.keys()):
            state_dict[name] = torch.tensor(parameters[k])
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        last_loss = 0.0
        for _ in range(EPOCHS_LOCAL):
            last_loss = train_one_epoch(self.model, self.train_loader, self.optimizer, self.loss_fn)
        return self.get_parameters({}), len(self.train_loader.dataset), {"train_loss": float(last_loss)}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        val_loss, val_acc = evaluate(self.model, self.val_loader, self.loss_fn)
        # You can report multiple metrics; we’ll aggregate them
        return float(val_loss), len(self.val_loader.dataset), {"val_accuracy": float(val_acc), "val_loss": float(val_loss)}

def client_fn(cid: str):
    # cid: "0" .. "NUM_CLIENTS-1"
    return MammoClient(cid)


DEVICE: cpu


  return datetime.utcnow().replace(tzinfo=utc)


## Cell 6 — Initialize Ray, FedAvg strategy (with aggregation), and run simulation


In [9]:
import ray
import flwr as fl
import torch

# Init Ray for Flower Simulation (Colab friendly)
ray.shutdown()
ray.init(ignore_reinit_error=True, include_dashboard=False, num_cpus=2)
print("Ray OK:", ray.is_initialized())

# Weighted average aggregator for client metrics
def weighted_average(metrics: list[tuple[int, dict[str, float]]]) -> dict[str, float]:
    # metrics: [(num_examples, {"val_accuracy": ..., "val_loss": ...}), ...]
    if not metrics:
        return {}
    total_examples = sum(n for n, _ in metrics)
    keys = set().union(*(m.keys() for _, m in metrics))
    out: dict[str, float] = {}
    for k in keys:
        out[k] = sum(m.get(k, 0.0) * n for n, m in metrics) / max(total_examples, 1)
    return out

# Define a single, modern FedAvg strategy (NO legacy EvaluateMetricsAggregator)
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,  # ✅ modern hook
)

# Run Flower simulation
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=ROUNDS),
    strategy=strategy,
    client_resources={
        "num_cpus": 1,
        "num_gpus": 1.0 if torch.cuda.is_available() else 0.0
    },
)

print("Done. Check `history` for results.")


2025-09-30 11:31:04,103	INFO worker.py:1771 -- Started a local Ray instance.
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=3, no round_timeout


Ray OK: True


  return datetime.utcnow().replace(tzinfo=utc)
2025-09-30 11:31:11,918	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 8.0, 'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'object_store_memory': 16319712460.0, 'memory': 32639424923.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 8 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=11863)[0m 2025-09-30 11:31:19.585756: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=11863)[0m E0000 00:00:175

Done. Check `history` for results.


  return datetime.utcnow().replace(tzinfo=utc)


## Cell 7 — Summarize metrics per round

In [10]:
try:
    # Depending on Flower version, these fields may vary
    if hasattr(history, "metrics_distributed"):
        print("Distributed metrics:", history.metrics_distributed)
    if hasattr(history, "metrics_centralized"):
        print("Centralized metrics:", history.metrics_centralized)
    if hasattr(history, "losses_distributed"):
        print("Distributed losses:", history.losses_distributed)
except Exception as e:
    print("Could not summarize history:", repr(e))


Distributed metrics: {'val_loss': [(1, 1.4473384221394856), (2, 5.1335428555806475), (3, 1.0319629907608032)], 'val_accuracy': [(1, 0.75), (2, 0.25), (3, 0.75)]}
Centralized metrics: {}
Distributed losses: [(1, 1.4473384221394856), (2, 5.1335428555806475), (3, 1.0319629907608032)]
