Dataset : https://www.kaggle.com/datasets/ssarkar445/covid-19-xray-and-ct-scan-image-dataset

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, random_split, ConcatDataset
import copy
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
import torchvision.datasets


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


dataset_paths = {
    "CT": "/home/ealam/MDCL/EXP1/archive (2)/COVID-19 Dataset/CT",
    "X-ray": "/home/ealam/MDCL/EXP1/archive (2)/COVID-19 Dataset/X-ray"
}


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229])
])


datasets = {}
for dataset_type, dataset_path in dataset_paths.items():
    datasets[dataset_type] = torchvision.datasets.ImageFolder(root=dataset_path, transform=transform)


ct_dataset = datasets["CT"]
ct_class0_indices = [idx for idx, (img, label) in enumerate(ct_dataset.imgs) if label == 0]
ct_class1_indices = [idx for idx, (img, label) in enumerate(ct_dataset.imgs) if label == 1]
selected_ct_indices = ct_class0_indices[:399] + ct_class1_indices[:146]


ct_train_size = int(0.8 * len(selected_ct_indices))
ct_test_size = len(selected_ct_indices) - ct_train_size
ct_train_dataset, ct_test_dataset = random_split(Subset(ct_dataset, selected_ct_indices), [ct_train_size, ct_test_size])


xray_dataset = datasets["X-ray"]
xray_class0_indices = [idx for idx, (img, label) in enumerate(xray_dataset.imgs) if label == 0]
xray_class1_indices = [idx for idx, (img, label) in enumerate(xray_dataset.imgs) if label == 1]
selected_xray_indices = xray_class0_indices[:223] + xray_class1_indices[:1341]


xray_train_size = int(0.8 * len(selected_xray_indices))
xray_test_size = len(selected_xray_indices) - xray_train_size
xray_train_dataset, xray_test_dataset = random_split(Subset(xray_dataset, selected_xray_indices), [xray_train_size, xray_test_size])


client_clusters_ct = {0: list(range(len(ct_train_dataset))), 1: list(range(len(ct_train_dataset)))}

client_clusters_xray = {2: list(range(len(xray_train_dataset))), 3: list(range(len(xray_train_dataset)))}


class OriginalVGG16(nn.Module):
    def __init__(self, num_classes=2):
        super(OriginalVGG16, self).__init__()
        self.features = models.vgg16(pretrained=True).features
        self.classifier = models.vgg16(pretrained=True).classifier


        self.classifier._modules['6'] = nn.Linear(4096, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=True, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss


def local_train(model, train_loader, epochs, criterion, optimizer):
    model.train()
    for epoch in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)


            target_onehot = torch.zeros_like(output)
            target_onehot.scatter_(1, target.view(-1, 1), 1)

            loss = criterion(output, target_onehot)
            loss.backward()
            optimizer.step()


def evaluate_model_per_class(model, test_loader):
    model.eval()
    y_true = []
    y_pred = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)
        y_true.extend(target.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())
    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)
    return precision, recall, f1


lr = 0.0001
local_epochs = 17
global_epochs = 25
batch_size = 32


combined_train_dataset = ConcatDataset([Subset(ct_train_dataset, indices) for indices in client_clusters_ct.values()] +
                                       [Subset(xray_train_dataset, indices) for indices in client_clusters_xray.values()])


combined_client_clusters = {key: value for key, value in client_clusters_ct.items()}
for key, value in client_clusters_xray.items():
    combined_client_clusters[key + len(client_clusters_ct)] = value


shared_model_combined = OriginalVGG16(num_classes=2).to(device)


criterion_combined = FocalLoss(alpha=1, gamma=2)
optimizer_combined = torch.optim.Adam(shared_model_combined.parameters(), lr=lr)


for global_epoch in range(global_epochs):
    for cluster_id, cluster_indices in combined_client_clusters.items():
        local_model_combined = copy.deepcopy(shared_model_combined)
        train_loader_combined = DataLoader(Subset(combined_train_dataset, cluster_indices), batch_size=batch_size, shuffle=True)
        local_train(local_model_combined, train_loader_combined, local_epochs, criterion_combined, optimizer_combined)
        shared_model_combined.load_state_dict(local_model_combined.state_dict())


        for shared_param, local_param in zip(shared_model_combined.parameters(), local_model_combined.parameters()):
            shared_param.data.copy_((shared_param.data + local_param.data) / 2.0)


test_loader_ct = DataLoader(ct_test_dataset, batch_size=batch_size, shuffle=False)
precision_ct, recall_ct, f1_ct = evaluate_model_per_class(shared_model_combined, test_loader_ct)


print("Evaluation Metrics for CT Dataset:")
print(f'Class 0:')
print(f'Precision: {precision_ct[0]:.2f}')
print(f'Recall: {recall_ct[0]:.2f}')
print(f'F1 Score: {f1_ct[0]:.2f}')
print(f'Class 1:')
print(f'Precision: {precision_ct[1]:.2f}')
print(f'Recall: {recall_ct[1]:.2f}')
print(f'F1 Score: {f1_ct[1]:.2f}')


test_loader_xray = DataLoader(xray_test_dataset, batch_size=batch_size, shuffle=False)
precision_xray, recall_xray, f1_xray = evaluate_model_per_class(shared_model_combined, test_loader_xray)


print("\nEvaluation Metrics for X-ray Dataset:")
print(f'Class 0:')
print(f'Precision: {precision_xray[0]:.2f}')
print(f'Recall: {recall_xray[0]:.2f}')
print(f'F1 Score: {f1_xray[0]:.2f}')
print(f'Class 1:')
print(f'Precision: {precision_xray[1]:.2f}')
print(f'Recall: {recall_xray[1]:.2f}')
print(f'F1 Score: {f1_xray[1]:.2f}')


Evaluation Metrics for CT Dataset:
Class 0:
Precision: 0.96
Recall: 0.92
F1 Score: 0.95
Class 1:
Precision: 0.97
Recall: 0.77
F1 Score: 0.86

Evaluation Metrics for X-ray Dataset:
Class 0:
Precision: 1.0
Recall: 0.93
F1 Score: 0.96
Class 1:
Precision: 0.93
Recall: 0.89
F1 Score: 0.91
