## Importing the libraries

In [None]:
import numpy as np
import nibabel as nib
import glob
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Mount Google Drive so the BraTS data under /content/drive/MyDrive is readable.
# When Drive is already mounted this only prints a notice (see the cell output);
# pass force_remount=True to remount.
from google.colab import drive

drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Generic min-max scaler mapping values into feature_range, here the default (0, 1).
# NOTE(review): the preprocessing cell normalises with per-modality scalers
# (scaler_t2 / scaler_t1ce / scaler_flair), not with this `scaler` — confirm
# whether it is still needed.
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(0, 1))


## Preprocessing

In [None]:
# Collect the NIfTI file of each modality for every training case. Each list is
# sorted on its own, so index i only refers to the same case in all four lists
# when every case folder contains all four files.
_TRAIN_ROOT = '/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'

t2_list, t1ce_list, flair_list, mask_list = (
    sorted(glob.glob(f'{_TRAIN_ROOT}/*/*{suffix}.nii'))
    for suffix in ('t2', 't1ce', 'flair', 'seg')
)

In [None]:
# Preprocess every BraTS training case into a cropped, min-max normalised
# (FLAIR, T1CE, T2) image volume plus a matching label volume, saved as .npy.
from sklearn.preprocessing import MinMaxScaler

data_dir = '/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
save_dir = '/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/processed_data'

# Spatial crop applied identically to images and masks: 128 voxels per axis.
# Assumes raw volumes are 240 x 240 x 155 (the usual BraTS size) — TODO confirm.
CROP = (slice(56, 184), slice(56, 184), slice(13, 141))

os.makedirs(f'{save_dir}/images', exist_ok=True)
os.makedirs(f'{save_dir}/masks', exist_ok=True)

# `glob` is imported as a module, so the function is `glob.glob`. Keep only
# directories so stray files in the data folder are not treated as cases.
cases = sorted(c for c in glob.glob(f'{data_dir}/*') if os.path.isdir(c))
t2_list = [f'{c}/{os.path.basename(c)}_t2.nii' for c in cases]
t1ce_list = [f'{c}/{os.path.basename(c)}_t1ce.nii' for c in cases]
flair_list = [f'{c}/{os.path.basename(c)}_flair.nii' for c in cases]
mask_list = [f'{c}/{os.path.basename(c)}_seg.nii' for c in cases]

# Fail fast, before the two long passes, if any expected file is missing
# (e.g. a case whose segmentation file is named differently).
missing = [
    path
    for paths in (t2_list, t1ce_list, flair_list, mask_list)
    for path in paths
    if not os.path.isfile(path)
]
if missing:
    raise FileNotFoundError(
        f"{len(missing)} expected NIfTI file(s) not found, e.g. {missing[:3]}"
    )

# First pass: fit one min-max scaler per modality on the voxels of ALL cases,
# so every case is normalised with the same dataset-wide min/max.
scaler_t2 = MinMaxScaler()
scaler_t1ce = MinMaxScaler()
scaler_flair = MinMaxScaler()

print("Calculating dataset statistics...")
for i in range(len(t2_list)):
    scaler_t2.partial_fit(nib.load(t2_list[i]).get_fdata().reshape(-1, 1))
    scaler_t1ce.partial_fit(nib.load(t1ce_list[i]).get_fdata().reshape(-1, 1))
    scaler_flair.partial_fit(nib.load(flair_list[i]).get_fdata().reshape(-1, 1))

# Second pass: normalise, remap labels, crop and save each case.
print("\nProcessing and saving data...")
for idx in range(len(t2_list)):
    print(f"Processing case {idx+1}/{len(t2_list)}")

    t2 = nib.load(t2_list[idx]).get_fdata()
    t1ce = nib.load(t1ce_list[idx]).get_fdata()
    flair = nib.load(flair_list[idx]).get_fdata()
    mask = nib.load(mask_list[idx]).get_fdata()

    # Normalise with the dataset-wide statistics fitted above.
    t2 = scaler_t2.transform(t2.reshape(-1, 1)).reshape(t2.shape)
    t1ce = scaler_t1ce.transform(t1ce.reshape(-1, 1)).reshape(t1ce.shape)
    flair = scaler_flair.transform(flair.reshape(-1, 1)).reshape(flair.shape)

    # Remap label 4 -> 3 so labels are contiguous 0..3 for the 4-class models.
    mask = mask.astype(np.uint8)
    mask[mask == 4] = 3

    # Channels-last image in the order FLAIR, T1CE, T2; float32 halves disk use
    # (the dataset class converts to float32 anyway).
    combined = np.stack([flair, t1ce, t2], axis=-1)[CROP].astype(np.float32)
    mask_cropped = mask[CROP]

    np.save(f'{save_dir}/images/case_{idx}.npy', combined)
    np.save(f'{save_dir}/masks/case_{idx}.npy', mask_cropped)

Calculating dataset statistics...

Processing and saving data...
Processing case 1/369
Processing case 2/369
Processing case 3/369
Processing case 4/369
Processing case 5/369
Processing case 6/369
Processing case 7/369
Processing case 8/369
Processing case 9/369
Processing case 10/369
Processing case 11/369
Processing case 12/369
Processing case 13/369
Processing case 14/369
Processing case 15/369
Processing case 16/369
Processing case 17/369
Processing case 18/369
Processing case 19/369
Processing case 20/369
Processing case 21/369
Processing case 22/369
Processing case 23/369
Processing case 24/369
Processing case 25/369
Processing case 26/369
Processing case 27/369
Processing case 28/369
Processing case 29/369
Processing case 30/369
Processing case 31/369
Processing case 32/369
Processing case 33/369
Processing case 34/369
Processing case 35/369
Processing case 36/369
Processing case 37/369
Processing case 38/369
Processing case 39/369
Processing case 40/369
Processing case 41/369
P

## Code for Simple U-Net

In [None]:
class UNet(nn.Module):
    """3-D U-Net with three pooling stages and skip connections.

    Args:
        in_channels: number of input channels (image modalities).
        out_classes: number of output channels (one logit map per class).
        init_channels: width of the first encoder stage. Widths double at each
            level up to ``8 * init_channels`` in the bottleneck. Every width must
            be divisible by 8, because each block uses ``GroupNorm(8, ...)``.

    Spatial input sizes must be divisible by 8 (three 2x poolings) so that the
    up-sampled features line up with the skip connections.
    """

    def __init__(self, in_channels=3, out_classes=4, init_channels=32):
        super().__init__()
        c1, c2, c3, c4 = (init_channels * k for k in (1, 2, 4, 8))

        # Contracting path: conv block, then 2x max-pool.
        self.enc1 = self._block(in_channels, c1)
        self.pool1 = nn.MaxPool3d(2, 2)
        self.enc2 = self._block(c1, c2)
        self.pool2 = nn.MaxPool3d(2, 2)
        self.enc3 = self._block(c2, c3)
        self.pool3 = nn.MaxPool3d(2, 2)

        self.bottleneck = self._block(c3, c4)

        # Expanding path: 2x transposed conv, then a conv block. Each block sees
        # [upsampled, skip] concatenated, hence twice its output width as input.
        self.up3 = nn.ConvTranspose3d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = self._block(2 * c3, c3)
        self.up2 = nn.ConvTranspose3d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = self._block(2 * c2, c2)
        self.up1 = nn.ConvTranspose3d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = self._block(2 * c1, c1)

        # 1x1x1 projection to per-class logits.
        self.final_conv = nn.Conv3d(c1, out_classes, kernel_size=1)

    def _block(self, in_channels, features):
        """Two (3x3x3 conv -> GroupNorm(8) -> ReLU) stages; spatial size is preserved."""
        layers = []
        for stage_in in (in_channels, features):
            layers += [
                nn.Conv3d(stage_in, features, 3, padding=1),
                nn.GroupNorm(8, features),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(N, in_channels, D, H, W)`` to logits ``(N, out_classes, D, H, W)``."""
        skips = []
        out = x
        for encode, pool in ((self.enc1, self.pool1),
                             (self.enc2, self.pool2),
                             (self.enc3, self.pool3)):
            out = encode(out)
            skips.append(out)
            out = pool(out)

        out = self.bottleneck(out)

        # Deepest skip first: up3 pairs with enc3, ..., up1 with enc1.
        decoder = zip((self.up3, self.up2, self.up1),
                      (self.dec3, self.dec2, self.dec1),
                      reversed(skips))
        for upsample, decode, skip in decoder:
            out = decode(torch.cat([upsample(out), skip], 1))

        return self.final_conv(out)

## Code for Attention U-Net

In [None]:
class AttentionUNet(UNet):
    """U-Net whose skip connections are re-weighted by additive attention gates.

    Encoder, bottleneck and decoder blocks are inherited from ``UNet``. Before
    each decoder stage, the skip features are multiplied by a voxel-wise map in
    [0, 1]. The map is computed from the up-sampled decoder features (the gating
    signal) and the skip features themselves.
    """

    class AttentionGate(nn.Module):
        """Additive attention gate: ``x * sigmoid(psi(relu(W_g(g) + W_x(x))))``.

        Args:
            in_channels: channels of both the gating signal ``g`` and the skip
                features ``x``. In this model they always match, as does their
                spatial size.
            inter_channels: width of the intermediate projection.
        """

        def __init__(self, in_channels, inter_channels):
            super().__init__()
            # 1x1x1 projections of gate and skip into a shared intermediate space.
            self.W_g = nn.Conv3d(in_channels, inter_channels, kernel_size=1)
            self.W_x = nn.Conv3d(in_channels, inter_channels, kernel_size=1)
            # Collapse to one attention logit per voxel.
            self.psi = nn.Conv3d(inter_channels, 1, kernel_size=1)
            self.sigmoid = nn.Sigmoid()
            self.relu = nn.ReLU(inplace=True)

        def forward(self, g, x):
            psi = self.relu(self.W_g(g) + self.W_x(x))
            psi = self.sigmoid(self.psi(psi))
            # The (N, 1, D, H, W) map broadcasts over the channels of x.
            return x * psi

    def __init__(self, in_channels=3, out_classes=4, init_channels=32):
        super().__init__(in_channels, out_classes, init_channels)
        # up3/up2/up1 and dec3/dec2/dec1 come from UNet with the right widths:
        # each decoder block receives [upsampled, gated skip] = 2x channels.
        self.att3 = self.AttentionGate(init_channels * 4, init_channels * 4 // 2)
        self.att2 = self.AttentionGate(init_channels * 2, init_channels * 2 // 2)
        self.att1 = self.AttentionGate(init_channels, init_channels // 2)

    def forward(self, x):
        """Map ``(N, in_channels, D, H, W)`` to logits ``(N, out_classes, D, H, W)``."""
        # Encoder
        enc1 = self.enc1(x)
        pool1 = self.pool1(enc1)
        enc2 = self.enc2(pool1)
        pool2 = self.pool2(enc2)
        enc3 = self.enc3(pool2)
        pool3 = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(pool3)

        # Decoder: the up-sampled features gate the skip before concatenation.
        up3 = self.up3(bottleneck)
        att3 = self.att3(up3, enc3)
        dec3 = self.dec3(torch.cat([up3, att3], 1))

        up2 = self.up2(dec3)
        att2 = self.att2(up2, enc2)
        dec2 = self.dec2(torch.cat([up2, att2], 1))

        up1 = self.up1(dec2)
        att1 = self.att1(up1, enc1)
        dec1 = self.dec1(torch.cat([up1, att1], 1))

        return self.final_conv(dec1)

In [None]:
class BraTSDataset(Dataset):
    """Paired preprocessed ``.npy`` image/mask volumes for 3-D segmentation.

    Images and masks are paired by position after sorting the filenames in each
    directory, so both directories must hold the same number of files.

    Each image file is expected channels-last, ``(X, Y, Z, C)``. Each mask is
    ``(X, Y, Z)`` with integer labels, as written by the preprocessing cell.

    ``__getitem__`` returns ``(image, mask)``:
    - the image as float32 ``(C, X, Y, Z)``;
    - the mask as int64 ``(X, Y, Z)``.
    Both use the same spatial axis order, which ``nn.CrossEntropyLoss`` requires
    for voxel-wise targets.
    """

    def __init__(self, image_dir, mask_dir):
        # `glob` is imported as a module in this notebook, so call glob.glob.
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.npy")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.npy")))
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_dir!r} but "
                f"{len(self.mask_paths)} masks in {mask_dir!r}"
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = torch.from_numpy(np.load(self.image_paths[idx])).float()
        mask = torch.from_numpy(np.load(self.mask_paths[idx])).long()

        # Channels-first for Conv3d: (X, Y, Z, C) -> (C, X, Y, Z). The spatial
        # axes keep their order so image voxels stay aligned with mask voxels;
        # reordering only the image would silently misalign them whenever the
        # crop is a cube.
        image = image.permute(3, 0, 1, 2).contiguous()
        return image, mask

## Centralised Learning

In [None]:
def train_centralized_model(num_epochs=2,
                            base_dir="/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/fed",
                            hospital_ids=range(1, 6),
                            lr=1e-4):
    """Train one ``UNet`` on the pooled data of all hospitals (centralised baseline).

    Args:
        num_epochs: passes over the combined dataset.
        base_dir: root of the per-hospital split. Hospital ``i`` is read from
            ``{base_dir}/hosp{i}/images`` and ``{base_dir}/hosp{i}/masks``.
        hospital_ids: which hospitals to pool (default 1..5).
        lr: Adam learning rate.

    Returns:
        The trained model's state dict.
    """
    # Not part of the top-level imports cell.
    from torch.utils.data import ConcatDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pool every hospital's data into a single training set.
    combined_dataset = ConcatDataset([
        BraTSDataset(
            os.path.join(base_dir, f"hosp{i}/images"),
            os.path.join(base_dir, f"hosp{i}/masks")
        ) for i in hospital_ids
    ])
    train_loader = DataLoader(combined_dataset, batch_size=1, shuffle=True)

    model = UNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        mean_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}/{num_epochs} - mean training loss: {mean_loss:.4f}")

    return model.state_dict()


## Federated Learning

In [None]:
def train_client_model(client_id, global_weights, num_epochs=2,
                       base_dir="/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/fed",
                       lr=1e-4):
    """Local update for one federated client.

    Starts from the current global weights and trains a ``UNet`` on this
    client's data only.

    Args:
        client_id: hospital index. Data is read from
            ``{base_dir}/hosp{client_id}/images`` and ``.../masks``.
        global_weights: state dict of the current global ``UNet``.
        num_epochs: local epochs over the client's data.
        base_dir: root of the per-hospital split.
        lr: Adam learning rate.

    Returns:
        The locally trained model's state dict. When the client has no data,
        the global weights come back unchanged (as a fresh state dict).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = BraTSDataset(
        os.path.join(base_dir, f"hosp{client_id}/images"),
        os.path.join(base_dir, f"hosp{client_id}/masks")
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    model = UNet().to(device)
    model.load_state_dict(global_weights)

    if len(train_dataset) == 0:
        print(f"Warning: client {client_id} has no training data; returning global weights.")
        return model.state_dict()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for _ in range(num_epochs):
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            loss = criterion(model(images), masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Release the optimizer state (Adam moments live on `device`) before
    # returning cached GPU memory. Loop tensors are freed on return; deleting
    # them by name would fail if no batch ever ran.
    del optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model.state_dict()

def federated_average(client_weights_list, client_sizes=None):
    """Aggregate client state dicts key by key (FedAvg).

    Args:
        client_weights_list: non-empty list of state dicts with identical keys
            and tensor shapes.
        client_sizes: optional per-client weights, e.g. local sample counts.
            When given, floating-point entries are averaged with weights
            proportional to these values. By default every client counts
            equally (plain mean).

    Returns:
        Dict mapping each key to the aggregated tensor. Non-floating-point
        entries (e.g. integer counters in normalisation buffers) cannot be
        averaged with ``torch.mean``; they are copied from the first client.

    Raises:
        ValueError: if the list is empty, or ``client_sizes`` has the wrong
            length or a non-positive total.
    """
    if not client_weights_list:
        raise ValueError("client_weights_list must contain at least one state dict")

    coeffs = None
    if client_sizes is not None:
        if len(client_sizes) != len(client_weights_list):
            raise ValueError("client_sizes must have one entry per client")
        total = float(sum(client_sizes))
        if total <= 0:
            raise ValueError("client_sizes must sum to a positive value")
        coeffs = [size / total for size in client_sizes]

    avg_weights = {}
    for key, first in client_weights_list[0].items():
        if not torch.is_floating_point(first):
            avg_weights[key] = first.clone()
            continue
        stacked = torch.stack([client_weights[key] for client_weights in client_weights_list])
        if coeffs is None:
            avg_weights[key] = torch.mean(stacked, dim=0)
        else:
            # One coefficient per client, broadcast over the tensor's own dims.
            scale = torch.tensor(coeffs, dtype=stacked.dtype, device=stacked.device)
            scale = scale.view(-1, *([1] * first.dim()))
            avg_weights[key] = (stacked * scale).sum(dim=0)
    return avg_weights

In [None]:
def dice_coeff(pred, target, smooth=1e-6, num_classes=None):
    """Mean multi-class Dice score between a prediction and an integer label map.

    Args:
        pred: either class labels with the same shape as ``target``, or class
            scores/logits with one extra axis at dim 1 (``(N, C, ...)``). In the
            second case the argmax over dim 1 is used.
        target: class labels, e.g. ``(N, D, H, W)``. Float tensors holding
            whole numbers are accepted and cast to int64.
        smooth: added to numerator and denominator, so a class absent from
            both pred and target scores 1.
        num_classes: number of classes to score. Defaults to
            ``max(pred.max(), target.max()) + 1``.

    Returns:
        Python float. For each class 0..num_classes-1 (background included),
        compute ``(2|P∩T| + s) / (|P| + |T| + s)`` pooled over all samples and
        voxels; return the unweighted mean of those per-class scores.
    """
    if pred.dim() == target.dim() + 1:
        pred = torch.argmax(pred, dim=1)
    pred = pred.long()
    target = target.long()

    if num_classes is None:
        num_classes = int(max(pred.max().item(), target.max().item())) + 1

    # one_hot appends the class axis last; flattening everything else gives one
    # row per voxel, so the sums below are per class for any input rank.
    pred_one_hot = torch.nn.functional.one_hot(pred, num_classes).reshape(-1, num_classes).float()
    target_one_hot = torch.nn.functional.one_hot(target, num_classes).reshape(-1, num_classes).float()

    intersection = (pred_one_hot * target_one_hot).sum(dim=0)
    total = pred_one_hot.sum(dim=0) + target_one_hot.sum(dim=0)

    dice = (2. * intersection + smooth) / (total + smooth)
    return dice.mean().item()

def jaccard_similarity(pred, target, smooth=1e-6, num_classes=None):
    """Mean multi-class Jaccard index (IoU) between a prediction and a label map.

    Args:
        pred: either class labels with the same shape as ``target``, or class
            scores/logits with one extra axis at dim 1 (``(N, C, ...)``). In the
            second case the argmax over dim 1 is used.
        target: class labels, e.g. ``(N, D, H, W)``. Float tensors holding
            whole numbers are accepted and cast to int64.
        smooth: added to numerator and denominator, so a class absent from
            both pred and target scores 1.
        num_classes: number of classes to score. Defaults to
            ``max(pred.max(), target.max()) + 1``.

    Returns:
        Python float. For each class 0..num_classes-1 (background included),
        compute ``(|P∩T| + s) / (|P∪T| + s)`` pooled over all samples and
        voxels; return the unweighted mean of those per-class scores.
    """
    if pred.dim() == target.dim() + 1:
        pred = torch.argmax(pred, dim=1)
    pred = pred.long()
    target = target.long()

    if num_classes is None:
        num_classes = int(max(pred.max().item(), target.max().item())) + 1

    # Class axis last, everything else flattened: per-class sums for any rank.
    pred_one_hot = torch.nn.functional.one_hot(pred, num_classes).reshape(-1, num_classes).float()
    target_one_hot = torch.nn.functional.one_hot(target, num_classes).reshape(-1, num_classes).float()

    intersection = (pred_one_hot * target_one_hot).sum(dim=0)
    union = pred_one_hot.sum(dim=0) + target_one_hot.sum(dim=0) - intersection

    jaccard = (intersection + smooth) / (union + smooth)
    return jaccard.mean().item()


In [None]:
def _binary_dice(pred, target, smooth=1e-6):
    """Dice of two boolean masks: (2|P∩T| + s) / (|P| + |T| + s), as a float."""
    pred = pred.float()
    target = target.float()
    intersection = (pred * target).sum()
    return ((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)).item()


def _binary_jaccard(pred, target, smooth=1e-6):
    """Jaccard index of two boolean masks: (|P∩T| + s) / (|P∪T| + s), as a float."""
    pred = pred.float()
    target = target.float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return ((intersection + smooth) / (union + smooth)).item()


def _label_region(labels, region_labels):
    """Boolean mask of voxels whose label is any of ``region_labels``."""
    region = torch.zeros_like(labels, dtype=torch.bool)
    for label in region_labels:
        region |= labels == label
    return region


# Evaluation regions over the remapped labels (label 4 was mapped to 3 in preprocessing).
_EVAL_REGIONS = {
    'whole_tumor': (1, 2, 3),
    'tumor_core': (1, 3),
    'enhancing_tumor': (3,),
}


def evaluate_global_model(model, val_dir):
    """Evaluate a segmentation model on the cases under ``val_dir``.

    Data is read from ``{val_dir}/images`` and ``{val_dir}/masks``, one case per
    batch. Inference runs on the device the model's parameters live on. The
    model's train/eval mode is restored afterwards.

    Returns:
        dict with
        - ``'mean_dice'``: whole-tumour (labels 1+2+3) Dice, averaged over cases;
        - ``'mean_jaccard'``: tumour-core (labels 1+3) Jaccard, averaged over cases;
        - ``'class_dice'``: {label: Dice} for labels 1, 2, 3, each averaged over
          the cases where that label appears in the ground truth (0.0 if none);
        - ``'region_dice'``: {'whole_tumor', 'tumor_core', 'enhancing_tumor': Dice},
          averaged over cases.

    Raises:
        ValueError: if no validation cases are found.
    """
    val_dataset = BraTSDataset(
        os.path.join(val_dir, "images"),
        os.path.join(val_dir, "masks")
    )
    if len(val_dataset) == 0:
        raise ValueError(f"No validation cases found under {val_dir!r}")
    val_loader = DataLoader(val_dataset, batch_size=1)

    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    total_dice = 0.0
    total_jaccard = 0.0
    class_dice = {1: 0.0, 2: 0.0, 3: 0.0}
    class_counts = {1: 0, 2: 0, 3: 0}
    region_dice = {name: 0.0 for name in _EVAL_REGIONS}

    try:
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                # Convert logits to class predictions.
                preds = torch.argmax(model(images), dim=1)

                # Per-label Dice, only for cases where the label is present.
                for cls in class_dice:
                    mask_cls = masks == cls
                    if mask_cls.any():
                        class_dice[cls] += _binary_dice(preds == cls, mask_cls)
                        class_counts[cls] += 1

                for name, labels in _EVAL_REGIONS.items():
                    region_dice[name] += _binary_dice(
                        _label_region(preds, labels), _label_region(masks, labels)
                    )

                total_dice += _binary_dice(
                    _label_region(preds, _EVAL_REGIONS['whole_tumor']),
                    _label_region(masks, _EVAL_REGIONS['whole_tumor']),
                )
                total_jaccard += _binary_jaccard(
                    _label_region(preds, _EVAL_REGIONS['tumor_core']),
                    _label_region(masks, _EVAL_REGIONS['tumor_core']),
                )
    finally:
        model.train(was_training)

    n_cases = len(val_loader)
    return {
        'mean_dice': total_dice / n_cases,
        'mean_jaccard': total_jaccard / n_cases,
        'class_dice': {
            cls: (class_dice[cls] / class_counts[cls] if class_counts[cls] else 0.0)
            for cls in class_dice
        },
        'region_dice': {name: value / n_cases for name, value in region_dice.items()},
    }


In [None]:
# Federated training (FedAvg). Each round:
# 1. every client fine-tunes the current global weights on its own data;
# 2. the server averages the client weights;
# 3. the aggregated model is validated and checkpointed.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and data shuffling

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_model = UNet().to(device)
global_weights = global_model.state_dict()

# Federated parameters
num_clients = 5
num_rounds = 2
val_dir = "/content/drive/MyDrive/BraTS2020_Extracted/BraTS2020_TrainingData/val"
# Checkpoints go to the working directory; point this at Drive to keep them
# after the Colab session ends.
checkpoint_dir = "."

for round_idx in range(num_rounds):
    print(f"\n--- Round {round_idx+1}/{num_rounds} ---")

    # Client training
    client_weights = []
    for client_id in range(1, num_clients + 1):
        print(f"Training client {client_id}...")
        client_weights.append(train_client_model(client_id, global_weights))

    # Aggregating weights
    global_weights = federated_average(client_weights)
    global_model.load_state_dict(global_weights)

    # Validation: evaluate_global_model returns a dict of metrics;
    # 'mean_dice' is the whole-tumour Dice averaged over cases.
    val_metrics = evaluate_global_model(global_model, val_dir)
    print(f"Validation Dice (whole tumour): {val_metrics['mean_dice']:.4f}")

    # Saving model
    torch.save(global_weights, os.path.join(checkpoint_dir, f"global_model_round_{round_idx}.pth"))


=== Round 1/2 ===
Training Client 1...
Training Client 2...
Training Client 3...
Training Client 4...
Training Client 5...

=== Round 2/2 ===
Training Client 1...
Training Client 2...
Training Client 3...
Training Client 4...
Training Client 5...


In [None]:
# Final evaluation of the aggregated global model on the validation split.
# Labels after preprocessing: 1 = necrotic/non-enhancing core, 2 = edema,
# 3 = enhancing tumour (originally label 4).
metrics = evaluate_global_model(global_model, val_dir)
print(f"Whole-tumour Dice (labels 1+2+3): {metrics['mean_dice']:.4f}")
print(f"Tumour-core Jaccard (labels 1+3): {metrics['mean_jaccard']:.4f}")
print("Per-label Dice (averaged over cases containing the label):")
print(f"  Label 1 (necrotic / non-enhancing core): {metrics['class_dice'][1]:.4f}")
print(f"  Label 2 (edema): {metrics['class_dice'][2]:.4f}")
print(f"  Label 3 (enhancing tumour): {metrics['class_dice'][3]:.4f}")

Mean Dice: 0.3407
Mean Jaccard: 0.2698
Class-wise Dice:
  Whole Tumor (1+2+3): 0.4622
  Tumor Core (1+3): 0.3345
  Enhancing Tumor (3): 0.2254
