In [None]:
import google.colab.drive as drive

# Expose Google Drive at /content/drive so the BraTS folders under MyDrive/EECE608 are readable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Project folder on Drive that holds the BraTS 2020 data.
# Pure Python instead of `!mkdir -p`: a failed shell command only prints an error and the
# cell still "succeeds", whereas os.makedirs raises. exist_ok=True mirrors `-p`.
os.makedirs("/content/drive/MyDrive/EECE608", exist_ok=True)

In [None]:
# Dataset roots on Drive. Keep the trailing "/": code that follows concatenates case names onto these paths.
train_data = (
    "/content/drive/MyDrive/EECE608/"
    "BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"
)
valid_data = (
    "/content/drive/MyDrive/EECE608/"
    "BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/"
)

In [None]:
! pip install opacus



In [None]:
! pip install nilearn
! pip install keras_unet_collection
!pip install torch torchvision
!pip install -q monai



In [None]:
!pip uninstall -y numpy
!pip install numpy==1.24.4


Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Collecting numpy==1.24.4
  Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m99.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jaxlib 0.5.1 requires numpy>=1.25, but you have numpy 1.24.4 which is incompatible.
jax 0.5.2 requires numpy>=1.25, but you have numpy 1.24.4 which is incompatible.
tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 1.24.4 which is incompatible.
pymc 5.21.1 requires numpy>=1.25.0, but you have numpy 1.24.4 w

In [None]:
import os, cv2, gc, copy
import nibabel as nib
import numpy as np
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from opacus import PrivacyEngine

In [None]:
# Experiment configuration for the ResNetLite FL + DP run.
# NOTE(review): the intended privacy budget is ε = 8, but nothing in this cell computes ε.
# NOISE_MULTIPLIER is set directly, so the achieved ε is unknown until it is measured
# (e.g. with PrivacyEngine.get_epsilon(delta) after training).
IMG_SIZE = 64           # slices are resized to IMG_SIZE x IMG_SIZE
VOLUME_SLICES = 3       # consecutive slices (third volume axis) taken per case
VOLUME_START_AT = 22    # index of the first slice taken
BATCH_SIZE = 1
EPOCHS = 1              # intended local epochs per client per round
ROUNDS = 1              # federated rounds
LR = 5e-4               # Adam learning rate
NOISE_MULTIPLIER = 1.3  # passed to PrivacyEngine.make_private
MAX_GRAD_NORM = 1.0     # per-sample gradient clipping bound for make_private
SEED = 42               # seeds the global torch / numpy RNGs below

train_data = "/content/drive/MyDrive/EECE608/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so model initialisation and data shuffling repeat across reruns.
# Without this, two runs of identical code produced very different scores.
# NOTE: GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

class BraTSDataset(Dataset):
    """Per-slice dataset over BraTS cases for one MRI modality.

    Every case in ``list_ids`` contributes ``VOLUME_SLICES`` consecutive slices
    (third volume axis) starting at ``VOLUME_START_AT``; files are read from
    ``train_data/<case_id>/``. Each item is ``(image, mask)``:

    * image -- float32 tensor of shape (1, IMG_SIZE, IMG_SIZE), divided by the
      slice maximum (so in [0, 1] for non-negative intensities);
    * mask  -- int64 tensor of shape (IMG_SIZE, IMG_SIZE), nearest-neighbour
      resized, with label 4 remapped to 3.
    """

    def __init__(self, list_ids, modality):
        self.list_ids = list_ids  # case folder names
        self.modality = modality  # file-name suffix, e.g. 'flair', 't1ce', 't2'

    def __len__(self):
        return VOLUME_SLICES * len(self.list_ids)

    def __getitem__(self, idx):
        case_pos, offset = divmod(idx, VOLUME_SLICES)
        z = VOLUME_START_AT + offset
        case_id = self.list_ids[case_pos]
        case_dir = os.path.join(train_data, case_id)

        image_file = os.path.join(case_dir, f"{case_id}_{self.modality}.nii")
        mask_file = os.path.join(case_dir, f"{case_id}_seg.nii")
        image = np.asarray(nib.load(image_file).dataobj[:, :, z], dtype=np.float32)
        mask = np.asarray(nib.load(mask_file).dataobj[:, :, z], dtype=np.uint8)

        target_size = (IMG_SIZE, IMG_SIZE)
        image = cv2.resize(image, target_size)
        # Nearest-neighbour keeps mask values as valid labels (no blending between classes).
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        # The epsilon avoids division by zero on an all-zero slice.
        image = image / (np.max(image) + 1e-8)
        # Make labels contiguous (0..3, assuming the source labels are {0, 1, 2, 4}).
        mask[mask == 4] = 3

        image_t = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
        mask_t = torch.tensor(mask, dtype=torch.long)
        return image_t, mask_t

class ResNetLite(nn.Module):
    """Lightweight convolutional encoder-decoder producing 4-class per-pixel logits.

    Despite the name, there are no residual or skip connections. The network is:
    two conv-GroupNorm-ReLU stages, each followed by 2x max-pooling; a 64-channel
    bottleneck; then two nearest-neighbour 2x upsamplings back to input resolution.
    GroupNorm is used instead of BatchNorm, which Opacus cannot train.

    Input (N, 1, H, W) gives logits (N, 4, H, W) when H and W are divisible by 4.
    Otherwise the output is smaller and callers resize it.

    All layers live in one ``nn.Sequential`` named ``net`` (state-dict keys
    ``net.<i>.*``). Keep that layout so client state dicts stay compatible for averaging.
    """

    def __init__(self):
        super().__init__()

        def conv_gn_relu(c_in, c_out):
            # 3x3 conv (same padding) -> GroupNorm with 4 groups -> ReLU
            return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.GroupNorm(4, c_out), nn.ReLU()]

        # Concatenation is evaluated left to right, so modules are created (and
        # initialised) in the same order as a literal layer list would create them.
        layers = (
            conv_gn_relu(1, 16) + [nn.MaxPool2d(2)]
            + conv_gn_relu(16, 32) + [nn.MaxPool2d(2)]
            + conv_gn_relu(32, 64)
            + [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()]
            + [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(32, 4, 1)]
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def dice_score(preds, targets, smooth=1e-6, include_background=False):
    """Mean per-class Dice coefficient between predicted scores and an integer label map.

    For each class c: Dice_c = (2 |P_c ∩ T_c| + smooth) / (|P_c| + |T_c| + smooth),
    where P_c and T_c are the pixels predicted / labelled as c. A class absent from
    both prediction and target therefore scores 1. The result is the mean over classes.
    Matching background pixels are *not* counted toward foreground overlap; counting
    every matching pixel would give pixel accuracy, not Dice.

    Args:
        preds: logits/scores with the class axis at dim 1, e.g. (N, C, H, W).
        targets: integer labels laid out like ``preds`` without the class axis, e.g. (N, H, W).
        smooth: additive smoothing for numerator and denominator.
        include_background: include class 0 in the mean (default False).

    Returns:
        0-dim float tensor in [0, 1].
    """
    num_classes = preds.shape[1]
    pred_labels = torch.argmax(preds, dim=1).reshape(-1)
    target_labels = targets.reshape(-1)

    # A single-class output has no foreground to score; fall back to class 0.
    first = 0 if include_background or num_classes == 1 else 1
    per_class = []
    for c in range(first, num_classes):
        pred_c = pred_labels == c
        target_c = target_labels == c
        intersection = (pred_c & target_c).sum().float()
        denom = pred_c.sum().float() + target_c.sum().float()
        per_class.append((2. * intersection + smooth) / (denom + smooth))
    return torch.stack(per_class).mean()

def evaluate_model(model, dataloader, device):
    """Mean per-batch ``dice_score`` of ``model`` over ``dataloader``.

    Puts the model in eval mode (and leaves it there) and runs without gradients.
    Logits whose spatial size differs from the mask are bilinearly resized first.

    Returns:
        Mean score over batches as a Python float. Returns NaN when the loader
        yields no batches, instead of raising ZeroDivisionError.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            if out.shape[-2:] != y.shape[-2:]:
                out = F.interpolate(out, size=y.shape[-2:], mode='bilinear', align_corners=False)
            total += dice_score(out, y).item()
            n_batches += 1
    # Counting batches (rather than len()) also supports iterables without __len__.
    return total / n_batches if n_batches else float('nan')

def train_dp(model, dataloader, optimizer, device):
    """Run a single training pass of ``model`` over ``dataloader``.

    Called here with the (model, optimizer, data loader) triple returned by
    Opacus' ``make_private``, but any optimizer works. The loss is pixel-wise
    cross-entropy; logits whose spatial size differs from the mask are bilinearly
    resized before the loss. Nothing is returned.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        logits = model(images)
        mask_hw = masks.shape[-2:]
        if logits.shape[-2:] != mask_hw:
            logits = F.interpolate(logits, size=mask_hw, mode='bilinear', align_corners=False)
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()

def create_clients(ids, modalities=('flair', 't1ce', 't2')):
    """Split case ids into one contiguous shard per modality (one simulated FL client each).

    Args:
        ids: sliceable sequence of case ids (e.g. a list); shards are slices of it.
        modalities: one modality per client; client i trains on ``modalities[i]``.

    Returns:
        ``{"client_<i>": (shard_ids, modality)}`` with i starting at 1, in modality order.
        Shard sizes differ by at most one and together cover every id. Earlier
        shards absorb the remainder, so no id is silently dropped.

    Raises:
        ValueError: if there are no modalities, or fewer ids than clients (a client would get no data).
    """
    n_clients = len(modalities)
    if n_clients == 0 or len(ids) < n_clients:
        raise ValueError(
            f"need at least one id per client: got {len(ids)} ids for {n_clients} clients"
        )
    base, extra = divmod(len(ids), n_clients)
    clients = {}
    start = 0
    for i, modality in enumerate(modalities):
        stop = start + base + (1 if i < extra else 0)
        clients[f"client_{i + 1}"] = (ids[start:stop], modality)
        start = stop
    return clients

def _average_state_dicts(state_dicts):
    """Equal-weight FedAvg: element-wise mean of state dicts that share the same keys.

    Floating-point tensors are summed into a fresh copy, so the client tensors are
    left untouched. Non-floating tensors cannot be divided in place; they are copied
    from the first state dict instead.
    """
    averaged = {}
    for key, first in state_dicts[0].items():
        if torch.is_floating_point(first):
            total = first.clone()
            for sd in state_dicts[1:]:
                total += sd[key]
            averaged[key] = total / len(state_dicts)
        else:
            averaged[key] = first.clone()
    return averaged


def _strip_opacus_prefix(state_dict):
    """Drop the leading '_module.' that the Opacus wrapper from make_private adds to state-dict keys."""
    prefix = "_module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


print("\n Begin FL + DP using ResNetLite")

# Case folders excluded from the experiment.
# NOTE(review): presumably case 355's files don't follow the naming BraTSDataset expects — confirm.
EXCLUDED_CASES = {'BraTS20_Training_355'}

# os.scandir order is filesystem-dependent; sort so the train_test_split below is reproducible.
with os.scandir(train_data) as entries:
    dirs = sorted(e.path for e in entries if e.is_dir() and e.name not in EXCLUDED_CASES)
ids = [os.path.basename(p) for p in dirs]
train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42)

# Small subset: 6 training cases split across the modality clients, 3 validation cases.
clients = create_clients(train_ids[:6])
val_ds = BraTSDataset(val_ids[:3], 'flair')
val_loader = DataLoader(val_ds, batch_size=1)

global_model = ResNetLite().to(device)
global_weights = copy.deepcopy(global_model.state_dict())

for round_num in range(1, ROUNDS + 1):
    print(f"\n Round {round_num}")
    weights = []

    for name, (ids_subset, modality) in clients.items():
        print(f" {name} training on {modality}")

        # Each client starts from the current global model.
        model = ResNetLite().to(device)
        model.load_state_dict(global_weights)
        optimizer = optim.Adam(model.parameters(), lr=LR)

        ds = BraTSDataset(ids_subset, modality)
        dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
        # A fresh PrivacyEngine per client; its accountant is discarded with it, so ε is never reported.
        privacy_engine = PrivacyEngine()
        model, optimizer, dl = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=dl,
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM
        )

        for _ in range(EPOCHS):
            train_dp(model, dl, optimizer, device)

        weights.append(_strip_opacus_prefix(model.state_dict()))

        del model, optimizer
        torch.cuda.empty_cache()
        gc.collect()

    new_weights = _average_state_dicts(weights)
    global_model.load_state_dict(new_weights)
    # Hand the aggregated model to the next round; otherwise every round restarts from the initial weights.
    global_weights = copy.deepcopy(global_model.state_dict())

    dice = evaluate_model(global_model, val_loader, device)
    print(f" Dice Score: {dice:.4f}")

print("\n FL + DP ResNetLite Training Done!")



🚀 Begin FL + DP using ResNetLite

 Round 1
 client_1 training on flair


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


 client_2 training on t1ce
 client_3 training on t2
 Dice Score: 0.6056

 FL + DP ResNetLite Training Done!


In [None]:
# Experiment configuration for the ResNetLite FL + DP run.
# NOTE(review): the intended privacy budget is ε = 8, but nothing in this cell computes ε.
# NOISE_MULTIPLIER is set directly, so the achieved ε is unknown until it is measured
# (e.g. with PrivacyEngine.get_epsilon(delta) after training).
IMG_SIZE = 64           # slices are resized to IMG_SIZE x IMG_SIZE
VOLUME_SLICES = 3       # consecutive slices (third volume axis) taken per case
VOLUME_START_AT = 22    # index of the first slice taken
BATCH_SIZE = 1
EPOCHS = 1              # intended local epochs per client per round
ROUNDS = 1              # federated rounds
LR = 5e-4               # Adam learning rate
NOISE_MULTIPLIER = 1.3  # passed to PrivacyEngine.make_private
MAX_GRAD_NORM = 1.0     # per-sample gradient clipping bound for make_private
SEED = 42               # seeds the global torch / numpy RNGs below

train_data = "/content/drive/MyDrive/EECE608/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so model initialisation and data shuffling repeat across reruns.
# Without this, two runs of identical code produced very different scores.
# NOTE: GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

class BraTSDataset(Dataset):
    """Per-slice dataset over BraTS cases for one MRI modality.

    Every case in ``list_ids`` contributes ``VOLUME_SLICES`` consecutive slices
    (third volume axis) starting at ``VOLUME_START_AT``; files are read from
    ``train_data/<case_id>/``. Each item is ``(image, mask)``:

    * image -- float32 tensor of shape (1, IMG_SIZE, IMG_SIZE), divided by the
      slice maximum (so in [0, 1] for non-negative intensities);
    * mask  -- int64 tensor of shape (IMG_SIZE, IMG_SIZE), nearest-neighbour
      resized, with label 4 remapped to 3.
    """

    def __init__(self, list_ids, modality):
        self.list_ids = list_ids  # case folder names
        self.modality = modality  # file-name suffix, e.g. 'flair', 't1ce', 't2'

    def __len__(self):
        return VOLUME_SLICES * len(self.list_ids)

    def __getitem__(self, idx):
        case_pos, offset = divmod(idx, VOLUME_SLICES)
        z = VOLUME_START_AT + offset
        case_id = self.list_ids[case_pos]
        case_dir = os.path.join(train_data, case_id)

        image_file = os.path.join(case_dir, f"{case_id}_{self.modality}.nii")
        mask_file = os.path.join(case_dir, f"{case_id}_seg.nii")
        image = np.asarray(nib.load(image_file).dataobj[:, :, z], dtype=np.float32)
        mask = np.asarray(nib.load(mask_file).dataobj[:, :, z], dtype=np.uint8)

        target_size = (IMG_SIZE, IMG_SIZE)
        image = cv2.resize(image, target_size)
        # Nearest-neighbour keeps mask values as valid labels (no blending between classes).
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        # The epsilon avoids division by zero on an all-zero slice.
        image = image / (np.max(image) + 1e-8)
        # Make labels contiguous (0..3, assuming the source labels are {0, 1, 2, 4}).
        mask[mask == 4] = 3

        image_t = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
        mask_t = torch.tensor(mask, dtype=torch.long)
        return image_t, mask_t

class ResNetLite(nn.Module):
    """Lightweight convolutional encoder-decoder producing 4-class per-pixel logits.

    Despite the name, there are no residual or skip connections. The network is:
    two conv-GroupNorm-ReLU stages, each followed by 2x max-pooling; a 64-channel
    bottleneck; then two nearest-neighbour 2x upsamplings back to input resolution.
    GroupNorm is used instead of BatchNorm, which Opacus cannot train.

    Input (N, 1, H, W) gives logits (N, 4, H, W) when H and W are divisible by 4.
    Otherwise the output is smaller and callers resize it.

    All layers live in one ``nn.Sequential`` named ``net`` (state-dict keys
    ``net.<i>.*``). Keep that layout so client state dicts stay compatible for averaging.
    """

    def __init__(self):
        super().__init__()

        def conv_gn_relu(c_in, c_out):
            # 3x3 conv (same padding) -> GroupNorm with 4 groups -> ReLU
            return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.GroupNorm(4, c_out), nn.ReLU()]

        # Concatenation is evaluated left to right, so modules are created (and
        # initialised) in the same order as a literal layer list would create them.
        layers = (
            conv_gn_relu(1, 16) + [nn.MaxPool2d(2)]
            + conv_gn_relu(16, 32) + [nn.MaxPool2d(2)]
            + conv_gn_relu(32, 64)
            + [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(64, 32, 3, padding=1), nn.ReLU()]
            + [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(32, 4, 1)]
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def dice_score(preds, targets, smooth=1e-6, include_background=False):
    """Mean per-class Dice coefficient between predicted scores and an integer label map.

    For each class c: Dice_c = (2 |P_c ∩ T_c| + smooth) / (|P_c| + |T_c| + smooth),
    where P_c and T_c are the pixels predicted / labelled as c. A class absent from
    both prediction and target therefore scores 1. The result is the mean over classes.
    Matching background pixels are *not* counted toward foreground overlap; counting
    every matching pixel would give pixel accuracy, not Dice.

    Args:
        preds: logits/scores with the class axis at dim 1, e.g. (N, C, H, W).
        targets: integer labels laid out like ``preds`` without the class axis, e.g. (N, H, W).
        smooth: additive smoothing for numerator and denominator.
        include_background: include class 0 in the mean (default False).

    Returns:
        0-dim float tensor in [0, 1].
    """
    num_classes = preds.shape[1]
    pred_labels = torch.argmax(preds, dim=1).reshape(-1)
    target_labels = targets.reshape(-1)

    # A single-class output has no foreground to score; fall back to class 0.
    first = 0 if include_background or num_classes == 1 else 1
    per_class = []
    for c in range(first, num_classes):
        pred_c = pred_labels == c
        target_c = target_labels == c
        intersection = (pred_c & target_c).sum().float()
        denom = pred_c.sum().float() + target_c.sum().float()
        per_class.append((2. * intersection + smooth) / (denom + smooth))
    return torch.stack(per_class).mean()

def evaluate_model(model, dataloader, device):
    """Mean per-batch ``dice_score`` of ``model`` over ``dataloader``.

    Puts the model in eval mode (and leaves it there) and runs without gradients.
    Logits whose spatial size differs from the mask are bilinearly resized first.

    Returns:
        Mean score over batches as a Python float. Returns NaN when the loader
        yields no batches, instead of raising ZeroDivisionError.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            if out.shape[-2:] != y.shape[-2:]:
                out = F.interpolate(out, size=y.shape[-2:], mode='bilinear', align_corners=False)
            total += dice_score(out, y).item()
            n_batches += 1
    # Counting batches (rather than len()) also supports iterables without __len__.
    return total / n_batches if n_batches else float('nan')

def train_dp(model, dataloader, optimizer, device):
    """Run a single training pass of ``model`` over ``dataloader``.

    Called here with the (model, optimizer, data loader) triple returned by
    Opacus' ``make_private``, but any optimizer works. The loss is pixel-wise
    cross-entropy; logits whose spatial size differs from the mask are bilinearly
    resized before the loss. Nothing is returned.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        logits = model(images)
        mask_hw = masks.shape[-2:]
        if logits.shape[-2:] != mask_hw:
            logits = F.interpolate(logits, size=mask_hw, mode='bilinear', align_corners=False)
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()

def create_clients(ids, modalities=('flair', 't1ce', 't2')):
    """Split case ids into one contiguous shard per modality (one simulated FL client each).

    Args:
        ids: sliceable sequence of case ids (e.g. a list); shards are slices of it.
        modalities: one modality per client; client i trains on ``modalities[i]``.

    Returns:
        ``{"client_<i>": (shard_ids, modality)}`` with i starting at 1, in modality order.
        Shard sizes differ by at most one and together cover every id. Earlier
        shards absorb the remainder, so no id is silently dropped.

    Raises:
        ValueError: if there are no modalities, or fewer ids than clients (a client would get no data).
    """
    n_clients = len(modalities)
    if n_clients == 0 or len(ids) < n_clients:
        raise ValueError(
            f"need at least one id per client: got {len(ids)} ids for {n_clients} clients"
        )
    base, extra = divmod(len(ids), n_clients)
    clients = {}
    start = 0
    for i, modality in enumerate(modalities):
        stop = start + base + (1 if i < extra else 0)
        clients[f"client_{i + 1}"] = (ids[start:stop], modality)
        start = stop
    return clients

def _average_state_dicts(state_dicts):
    """Equal-weight FedAvg: element-wise mean of state dicts that share the same keys.

    Floating-point tensors are summed into a fresh copy, so the client tensors are
    left untouched. Non-floating tensors cannot be divided in place; they are copied
    from the first state dict instead.
    """
    averaged = {}
    for key, first in state_dicts[0].items():
        if torch.is_floating_point(first):
            total = first.clone()
            for sd in state_dicts[1:]:
                total += sd[key]
            averaged[key] = total / len(state_dicts)
        else:
            averaged[key] = first.clone()
    return averaged


def _strip_opacus_prefix(state_dict):
    """Drop the leading '_module.' that the Opacus wrapper from make_private adds to state-dict keys."""
    prefix = "_module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


print("\n Begin FL + DP using ResNetLite")

# Case folders excluded from the experiment.
# NOTE(review): presumably case 355's files don't follow the naming BraTSDataset expects — confirm.
EXCLUDED_CASES = {'BraTS20_Training_355'}

# os.scandir order is filesystem-dependent; sort so the train_test_split below is reproducible.
with os.scandir(train_data) as entries:
    dirs = sorted(e.path for e in entries if e.is_dir() and e.name not in EXCLUDED_CASES)
ids = [os.path.basename(p) for p in dirs]
train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42)

# Small subset: 6 training cases split across the modality clients, 3 validation cases.
clients = create_clients(train_ids[:6])
val_ds = BraTSDataset(val_ids[:3], 'flair')
val_loader = DataLoader(val_ds, batch_size=1)

global_model = ResNetLite().to(device)
global_weights = copy.deepcopy(global_model.state_dict())

for round_num in range(1, ROUNDS + 1):
    print(f"\n Round {round_num}")
    weights = []

    for name, (ids_subset, modality) in clients.items():
        print(f" {name} training on {modality}")

        # Each client starts from the current global model.
        model = ResNetLite().to(device)
        model.load_state_dict(global_weights)
        optimizer = optim.Adam(model.parameters(), lr=LR)

        ds = BraTSDataset(ids_subset, modality)
        dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
        # A fresh PrivacyEngine per client; its accountant is discarded with it, so ε is never reported.
        privacy_engine = PrivacyEngine()
        model, optimizer, dl = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=dl,
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM
        )

        for _ in range(EPOCHS):
            train_dp(model, dl, optimizer, device)

        weights.append(_strip_opacus_prefix(model.state_dict()))

        del model, optimizer
        torch.cuda.empty_cache()
        gc.collect()

    new_weights = _average_state_dicts(weights)
    global_model.load_state_dict(new_weights)
    # Hand the aggregated model to the next round; otherwise every round restarts from the initial weights.
    global_weights = copy.deepcopy(global_model.state_dict())

    dice = evaluate_model(global_model, val_loader, device)
    print(f" Dice Score: {dice:.4f}")

print("\n FL + DP ResNetLite Training Done!")



 Begin FL + DP using ResNetLite

 Round 1
 client_1 training on flair
 client_2 training on t1ce
 client_3 training on t2
 Dice Score: 0.0497

 FL + DP ResNetLite Training Done!


In [None]:
import os, gc, copy, cv2
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

# Experiment configuration for the DeepLabV3 FL + DP run.
IMG_SIZE = 128          # slices are resized to IMG_SIZE x IMG_SIZE
VOLUME_SLICES = 3       # consecutive slices (third volume axis) taken per case
VOLUME_START_AT = 22    # index of the first slice taken
BATCH_SIZE = 1
EPOCHS = 1              # intended local epochs per client per round
ROUNDS = 1              # federated rounds
LR = 5e-4               # Adam learning rate
NOISE_MULTIPLIER = 1.3  # passed to PrivacyEngine.make_private; the resulting ε is never computed here
MAX_GRAD_NORM = 1.0     # per-sample gradient clipping bound for make_private
NUM_CLASSES = 4         # segmentation head size; must cover mask labels 0..3 after the 4 -> 3 remap
SEED = 42               # seeds the global torch / numpy RNGs below
train_data = "/content/drive/MyDrive/EECE608/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so model initialisation and data shuffling repeat across reruns.
# NOTE: GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

class BraTSDataset(Dataset):
    """Per-slice dataset over BraTS cases for one MRI modality.

    Every case in ``list_ids`` contributes ``VOLUME_SLICES`` consecutive slices
    (third volume axis) starting at ``VOLUME_START_AT``; files are read from
    ``train_data/<case_id>/``. Each item is ``(image, mask)``:

    * image -- float32 tensor of shape (1, IMG_SIZE, IMG_SIZE), divided by the
      slice maximum (so in [0, 1] for non-negative intensities);
    * mask  -- int64 tensor of shape (IMG_SIZE, IMG_SIZE), nearest-neighbour
      resized, with label 4 remapped to 3.
    """

    def __init__(self, list_ids, modality):
        self.list_ids = list_ids  # case folder names
        self.modality = modality  # file-name suffix, e.g. 'flair', 't1ce', 't2'

    def __len__(self):
        return VOLUME_SLICES * len(self.list_ids)

    def __getitem__(self, idx):
        case_pos, offset = divmod(idx, VOLUME_SLICES)
        z = VOLUME_START_AT + offset
        case_id = self.list_ids[case_pos]
        case_dir = os.path.join(train_data, case_id)

        image_file = os.path.join(case_dir, f"{case_id}_{self.modality}.nii")
        mask_file = os.path.join(case_dir, f"{case_id}_seg.nii")
        image = np.asarray(nib.load(image_file).dataobj[:, :, z], dtype=np.float32)
        mask = np.asarray(nib.load(mask_file).dataobj[:, :, z], dtype=np.uint8)

        target_size = (IMG_SIZE, IMG_SIZE)
        image = cv2.resize(image, target_size)
        # Nearest-neighbour keeps mask values as valid labels (no blending between classes).
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        # The epsilon avoids division by zero on an all-zero slice.
        image = image / (np.max(image) + 1e-8)
        # Make labels contiguous (0..3, assuming the source labels are {0, 1, 2, 4}).
        mask[mask == 4] = 3

        image_t = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
        mask_t = torch.tensor(mask, dtype=torch.long)
        return image_t, mask_t

def convert_batchnorm_to_groupnorm(model):
    """Make ``model`` trainable with Opacus and return the result.

    Delegates to Opacus' ``ModuleValidator.fix``, which (per the Opacus
    documentation) replaces BatchNorm layers with GroupNorm by default. Other
    incompatible modules may also be rewritten. Always use the returned module.
    """
    return ModuleValidator.fix(model)

def get_deeplab():
    """Build a single-channel, Opacus-compatible DeepLabV3-ResNet50 with NUM_CLASSES outputs.

    The segmentation model is created with ``weights=None``. The RGB stem conv
    (``backbone.conv1``) is replaced by a freshly initialised 1-input-channel conv
    with the same output channels, kernel size, stride, padding and bias setting.
    The model is then passed through ``convert_batchnorm_to_groupnorm``.
    The forward pass returns a dict; callers read ``['out']``.

    NOTE(review): torchvision's ``deeplabv3_resnet50`` has its own default for
    ``weights_backbone`` (ImageNet weights in recent releases). Confirm whether a
    pretrained backbone is intended here.
    """
    net = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=NUM_CLASSES)
    stem = net.backbone.conv1
    # Single-channel input: same geometry as the original stem, new (random) weights.
    net.backbone.conv1 = nn.Conv2d(
        1,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        bias=stem.bias is not None,
    )
    return convert_batchnorm_to_groupnorm(net)

def dice_score(preds, targets, smooth=1e-6, include_background=False):
    """Mean per-class Dice coefficient between predicted scores and an integer label map.

    For each class c: Dice_c = (2 |P_c ∩ T_c| + smooth) / (|P_c| + |T_c| + smooth),
    where P_c and T_c are the pixels predicted / labelled as c. A class absent from
    both prediction and target therefore scores 1. The result is the mean over classes.
    Matching background pixels are *not* counted toward foreground overlap; counting
    every matching pixel would give pixel accuracy, not Dice.

    Args:
        preds: logits/scores with the class axis at dim 1, e.g. (N, C, H, W).
        targets: integer labels laid out like ``preds`` without the class axis, e.g. (N, H, W).
        smooth: additive smoothing for numerator and denominator.
        include_background: include class 0 in the mean (default False).

    Returns:
        0-dim float tensor in [0, 1].
    """
    num_classes = preds.shape[1]
    pred_labels = torch.argmax(preds, dim=1).reshape(-1)
    target_labels = targets.reshape(-1)

    # A single-class output has no foreground to score; fall back to class 0.
    first = 0 if include_background or num_classes == 1 else 1
    per_class = []
    for c in range(first, num_classes):
        pred_c = pred_labels == c
        target_c = target_labels == c
        intersection = (pred_c & target_c).sum().float()
        denom = pred_c.sum().float() + target_c.sum().float()
        per_class.append((2. * intersection + smooth) / (denom + smooth))
    return torch.stack(per_class).mean()


def evaluate_model(model, dataloader, device):
    """Mean per-batch ``dice_score`` of a dict-returning segmentation model over ``dataloader``.

    Puts the model in eval mode (and leaves it there) and runs without gradients.
    Reads logits from ``model(x)['out']`` and bilinearly resizes them when their
    spatial size differs from the mask.

    Returns:
        Mean score over batches as a Python float. Returns NaN when the loader
        yields no batches, instead of raising ZeroDivisionError.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)['out']
            if out.shape[-2:] != y.shape[-2:]:
                out = F.interpolate(out, size=y.shape[-2:], mode='bilinear', align_corners=False)
            total += dice_score(out, y).item()
            n_batches += 1
    # Counting batches (rather than len()) also supports iterables without __len__.
    return total / n_batches if n_batches else float('nan')

def train_dp(model, dataloader, optimizer, device):
    """Run a single training pass of a dict-returning segmentation model.

    Logits are read from ``model(images)['out']``. The loss is pixel-wise
    cross-entropy; logits whose spatial size differs from the mask are bilinearly
    resized before the loss. Called here with the triple returned by Opacus'
    ``make_private``, but any optimizer works. Nothing is returned.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        logits = model(images)['out']
        mask_hw = masks.shape[-2:]
        if logits.shape[-2:] != mask_hw:
            logits = F.interpolate(logits, size=mask_hw, mode='bilinear', align_corners=False)
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()


def create_clients(ids, modalities=('flair', 't1ce', 't2')):
    """Split case ids into one contiguous shard per modality (one simulated FL client each).

    Args:
        ids: sliceable sequence of case ids (e.g. a list); shards are slices of it.
        modalities: one modality per client; client i trains on ``modalities[i]``.

    Returns:
        ``{"client_<i>": (shard_ids, modality)}`` with i starting at 1, in modality order.
        Shard sizes differ by at most one and together cover every id. Earlier
        shards absorb the remainder, so no id is silently dropped.

    Raises:
        ValueError: if there are no modalities, or fewer ids than clients (a client would get no data).
    """
    n_clients = len(modalities)
    if n_clients == 0 or len(ids) < n_clients:
        raise ValueError(
            f"need at least one id per client: got {len(ids)} ids for {n_clients} clients"
        )
    base, extra = divmod(len(ids), n_clients)
    clients = {}
    start = 0
    for i, modality in enumerate(modalities):
        stop = start + base + (1 if i < extra else 0)
        clients[f"client_{i + 1}"] = (ids[start:stop], modality)
        start = stop
    return clients

def _average_state_dicts(state_dicts):
    """Equal-weight FedAvg: element-wise mean of state dicts that share the same keys.

    Floating-point tensors are summed into a fresh copy, so the client tensors are
    left untouched. Non-floating tensors cannot be divided in place; they are copied
    from the first state dict instead.
    """
    averaged = {}
    for key, first in state_dicts[0].items():
        if torch.is_floating_point(first):
            total = first.clone()
            for sd in state_dicts[1:]:
                total += sd[key]
            averaged[key] = total / len(state_dicts)
        else:
            averaged[key] = first.clone()
    return averaged


def _strip_opacus_prefix(state_dict):
    """Drop the leading '_module.' that the Opacus wrapper from make_private adds to state-dict keys."""
    prefix = "_module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


print("\n FL + DP with DeepLabV3")

# Case folders excluded from the experiment.
# NOTE(review): presumably case 355's files don't follow the naming BraTSDataset expects — confirm.
EXCLUDED_CASES = {'BraTS20_Training_355'}

# os.scandir order is filesystem-dependent; sort so the train_test_split below is reproducible.
with os.scandir(train_data) as entries:
    dirs = sorted(e.path for e in entries if e.is_dir() and e.name not in EXCLUDED_CASES)
ids = [os.path.basename(p) for p in dirs]
train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42)

# Small subset: 6 training cases split across the modality clients, 3 validation cases.
clients = create_clients(train_ids[:6])
val_ds = BraTSDataset(val_ids[:3], 'flair')
val_loader = DataLoader(val_ds, batch_size=1)

global_model = get_deeplab().to(device)
global_weights = copy.deepcopy(global_model.state_dict())

for round_num in range(1, ROUNDS + 1):
    print(f"\n Round {round_num}")
    weights = []

    for name, (ids_subset, modality) in clients.items():
        print(f" {name} training on {modality}")
        # Each client starts from the current global model.
        model = get_deeplab().to(device)
        model.load_state_dict(global_weights)
        optimizer = optim.Adam(model.parameters(), lr=LR)

        ds = BraTSDataset(ids_subset, modality)
        dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

        # A fresh PrivacyEngine per client; its accountant is discarded with it, so ε is never reported.
        privacy_engine = PrivacyEngine()
        model, optimizer, dl = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=dl,
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM
        )

        for _ in range(EPOCHS):
            train_dp(model, dl, optimizer, device)

        weights.append(_strip_opacus_prefix(model.state_dict()))

        del model, optimizer
        torch.cuda.empty_cache()
        gc.collect()

    new_weights = _average_state_dicts(weights)
    global_model.load_state_dict(new_weights)
    # Hand the aggregated model to the next round; otherwise every round restarts from the initial weights.
    global_weights = copy.deepcopy(global_model.state_dict())

    dice = evaluate_model(global_model, val_loader, device)
    print(f" Dice Score: {dice:.4f}")

print("\n Training Complete — DeepLabV3 + DP + FL")


 FL + DP with DeepLabV3

 Round 1
 client_1 training on flair
 client_2 training on t1ce
 client_3 training on t2
 Dice Score: 0.4710

 Training Complete — DeepLabV3 + DP + FL


In [None]:
# IMG_SIZE = 128
# VOLUME_SLICES = 3
# VOLUME_START_AT = 22
# BATCH_SIZE = 1
# EPOCHS = 1
# ROUNDS = 1
# LR = 5e-4
# NOISE_MULTIPLIER = 1.3
# MAX_GRAD_NORM = 1.0
# NUM_CLASSES = 4
# train_data = "/content/drive/MyDrive/EECE608/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# class BraTSDataset(Dataset):
#     def __init__(self, list_ids, modality):
#         self.list_ids = list_ids
#         self.modality = modality

#     def __len__(self):
#         return len(self.list_ids) * VOLUME_SLICES

#     def __getitem__(self, idx):
#         case_idx = idx // VOLUME_SLICES
#         slice_idx = VOLUME_START_AT + (idx % VOLUME_SLICES)
#         case_id = self.list_ids[case_idx]
#         case_path = os.path.join(train_data, case_id)

#         img = np.asarray(nib.load(os.path.join(case_path, f"{case_id}_{self.modality}.nii")).dataobj[:, :, slice_idx], dtype=np.float32)
#         seg = np.asarray(nib.load(os.path.join(case_path, f"{case_id}_seg.nii")).dataobj[:, :, slice_idx], dtype=np.uint8)

#         img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
#         seg = cv2.resize(seg, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

#         img = img / (np.max(img) + 1e-8)
#         seg[seg == 4] = 3

#         return torch.tensor(img, dtype=torch.float32).unsqueeze(0), torch.tensor(seg, dtype=torch.long)

# def convert_batchnorm_to_groupnorm(model):
#     model = ModuleValidator.fix(model)
#     return model

# def get_deeplab():
#     model = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=NUM_CLASSES)
#     # Modify first conv layer: 3 ➜ 1 input channel
#     old_conv = model.backbone.conv1
#     model.backbone.conv1 = nn.Conv2d(
#         in_channels=1,
#         out_channels=old_conv.out_channels,
#         kernel_size=old_conv.kernel_size,
#         stride=old_conv.stride,
#         padding=old_conv.padding,
#         bias=old_conv.bias is not None
#     )
#     model = convert_batchnorm_to_groupnorm(model)
#     return model

# def dice_score(preds, targets, smooth=1e-6):
#     preds = torch.argmax(preds, dim=1).view(-1)
#     targets = targets.view(-1)
#     intersection = (preds == targets).float().sum()
#     return (2. * intersection + smooth) / (preds.numel() + targets.numel() + smooth)


# def evaluate_model(model, dataloader, device):
#     model.eval()
#     score = 0
#     with torch.no_grad():
#         for x, y in dataloader:
#             x, y = x.to(device), y.to(device)
#             out = model(x)['out']
#             if out.shape[-2:] != y.shape[-2:]:
#                 out = F.interpolate(out, size=y.shape[-2:], mode='bilinear', align_corners=False)
#             score += dice_score(out, y).item()
#     return score / len(dataloader)

# def train_dp(model, dataloader, optimizer, device):
#     model.train()
#     criterion = nn.CrossEntropyLoss()
#     for x, y in dataloader:
#         x, y = x.to(device), y.to(device)
#         optimizer.zero_grad()
#         out = model(x)['out']
#         if out.shape[-2:] != y.shape[-2:]:
#             out = F.interpolate(out, size=y.shape[-2:], mode='bilinear', align_corners=False)
#         loss = criterion(out, y)
#         loss.backward()
#         optimizer.step()


# def create_clients(ids, modalities=['flair', 't1ce', 't2']):
#     size = len(ids) // len(modalities)
#     shards = [ids[i * size:(i + 1) * size] for i in range(len(modalities))]
#     return {f"client_{i+1}": (shards[i], modalities[i]) for i in range(len(modalities))}

# print("\n FL + DP with DeepLabV3")

# dirs = [f.path for f in os.scandir(train_data) if f.is_dir()]
# dirs.remove(train_data + 'BraTS20_Training_355')
# ids = [os.path.basename(p) for p in dirs]
# train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42)

# clients = create_clients(train_ids[:6])
# val_ds = BraTSDataset(val_ids[:3], 'flair')
# val_loader = DataLoader(val_ds, batch_size=1)

# global_model = get_deeplab().to(device)
# global_weights = copy.deepcopy(global_model.state_dict())

# for round_num in range(1, ROUNDS + 1):
#     print(f"\n Round {round_num}")
#     weights = []

#     for name, (ids_subset, modality) in clients.items():
#         print(f" {name} training on {modality}")
#         model = get_deeplab().to(device)
#         model.load_state_dict(global_weights)
#         optimizer = optim.Adam(model.parameters(), lr=LR)

#         ds = BraTSDataset(ids_subset, modality)
#         dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)


#         privacy_engine = PrivacyEngine()
#         model, optimizer, dl = privacy_engine.make_private(
#             module=model,
#             optimizer=optimizer,
#             data_loader=dl,
#             noise_multiplier=NOISE_MULTIPLIER,
#             max_grad_norm=MAX_GRAD_NORM
#         )

#         train_dp(model, dl, optimizer, device)

#         clean_weights = {k.replace("_module.", ""): v for k, v in model.state_dict().items()}
#         weights.append(clean_weights)

#         del model, optimizer
#         torch.cuda.empty_cache()
#         gc.collect()

#     new_weights = weights[0]
#     for k in new_weights:
#         for i in range(1, len(weights)):
#             new_weights[k] += weights[i][k]
#         new_weights[k] /= len(weights)
#     global_model.load_state_dict(new_weights)

#     dice = evaluate_model(global_model, val_loader, device)
#     print(f" Dice Score: {dice:.4f}")

# print("\n Training Complete — DeepLabV3 + DP + FL")