In [1]:
# !pip install -q flwr
# !pip install -U ipywidgets
# !pip install -U flwr["simulation"]

In [2]:
import medmnist
from medmnist import INFO, Evaluator
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda
import torch.utils.data as data
from torchvision import models
import torch
import tqdm
from collections import OrderedDict
# from typing import List, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np 
from torch.utils.data import DataLoader, random_split
import flwr as fl
from torch import nn
import torch.nn.functional as F
from flwr.common import Metrics


# Device on which every model in this notebook is placed (via `.to(DEVICE)`).
# The CUDA auto-detection variant below is disabled, so everything is pinned to the CPU.
#DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE= torch.device("cpu")
print(DEVICE)

cpu


## Dataset Loading

In [3]:


def _split_evenly(dataset, num_parts: int, seed: int = 42):
    """Randomly split ``dataset`` into ``num_parts`` subsets of (almost) equal size.

    Every subset receives ``len(dataset) // num_parts`` samples and the last
    one additionally receives the remainder. The split uses a freshly seeded
    generator so that it is reproducible across runs.
    """
    base, remainder = divmod(len(dataset), num_parts)
    lengths = [base] * (num_parts - 1) + [base + remainder]
    return random_split(dataset, lengths, torch.Generator().manual_seed(seed))


def load_datasets(num_clients: int, batch_size: int):
    """Load ChestMNIST and build one train/val DataLoader per simulated client.

    The official train and val splits are each partitioned into ``num_clients``
    random, disjoint subsets (client ``i`` gets partition ``i`` of both); the
    test split is kept whole for centralised evaluation.

    Args:
        num_clients: number of simulated clients; must be at least 1.
        batch_size: batch size used by every returned DataLoader.

    Returns:
        ``(trainloaders, valloaders, testloader)`` where the first two are
        lists of length ``num_clients`` and the last is a single DataLoader.

    Raises:
        ValueError: if ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])
    ])

    data_flag = 'chestmnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    trainset = DataClass(split='train', transform=transform, download=True)
    testset = DataClass(split='test', transform=transform, download=True)
    valset = DataClass(split='val', transform=transform, download=True)

    # Partition the train and val splits to simulate different local datasets.
    train_datasets = _split_evenly(trainset, num_clients)
    val_datasets = _split_evenly(valset, num_clients)

    # Only the training loaders shuffle; evaluation order is kept fixed.
    trainloaders = [
        DataLoader(t_ds, batch_size=batch_size, shuffle=True) for t_ds in train_datasets
    ]
    valloaders = [DataLoader(v_ds, batch_size=batch_size) for v_ds in val_datasets]
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloaders, valloaders, testloader



## Define Train and Test

In [4]:
from torchmetrics.classification import F1Score
# NOTE(review): `f1` is not referenced by any other cell visible in this notebook.
# It is configured as a 14-class "multiclass" metric — confirm that matches the label
# format before wiring it into train/test.
f1 = F1Score(task="multiclass", num_classes=14)

def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set.

    Args:
        net: model to optimise in place; it is switched to training mode.
        trainloader: iterable yielding ``(images, labels)`` batches.
        epochs: number of full passes over ``trainloader``.
        verbose: when true, print loss and accuracy after every epoch.

    Returns:
        ``(loss, accuracy)`` of the last epoch. ``loss`` is the mean per-sample
        loss as a plain Python float and ``accuracy`` is a percentage.
        ``(0.0, 0.0)`` is returned when no epoch ran or no sample was seen.
    """
    net.train()
    # NOTE(review): labels are cast to float and handed to CrossEntropyLoss as
    # per-class targets, and accuracy compares argmax positions. Confirm this is
    # the intended objective if a sample can carry several (or zero) positive labels.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    # Run each batch on whatever device the model lives on.
    device = next(net.parameters()).device

    # Defaults so that `epochs == 0` returns cleanly instead of raising.
    epoch_loss, epoch_acc = 0.0, 0.0
    for epoch in range(epochs):
        correct, total, running_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images = images.to(device)
            targets = labels.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Metrics. `.item()` detaches the value so the autograd graph of every
            # batch is not kept alive; weighting by batch size turns the batch
            # means into a true per-sample mean below.
            batch_size = targets.size(0)
            running_loss += loss.item() * batch_size
            total += batch_size
            predicted_classes = torch.argmax(outputs, dim=1)
            true_classes = torch.argmax(targets, dim=1)
            correct += torch.sum(predicted_classes == true_classes).item()

        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0

        if verbose:
            print(f"Epoch {epoch}: train loss {epoch_loss}, accuracy {epoch_acc}")
    return epoch_loss, epoch_acc * 100

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    # accuracies = []
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch
            # images = images.expand(-1, 3, -1, -1)
            # labels = labels.long()
            labels = labels.to(torch.float32)
            outputs = net(images)

            loss += criterion(outputs, labels).item()

            total += labels.size(0)

            predicted_classes = torch.argmax(outputs, dim=1)
            true_classes = torch.argmax(labels, dim=1)

            correct += torch.sum(predicted_classes == true_classes).item()



    loss /= len(testloader.dataset)
    accuracy = correct / total
    # accuracy = np.mean(accuracies)
    return loss, accuracy * 100          



## Poisoning Attacks 

In [5]:
def poison_labels(poisoned_dataset, poison_rate=0.25):
    """Flip one label entry, in place, in a random fraction of the rows.

    ``int(len(poisoned_dataset) * poison_rate)`` distinct rows are drawn. For
    each drawn row:

    * if the row sums to zero, one randomly chosen position is set to 1;
    * otherwise a coin flip decides whether a random 0 entry becomes 1 or a
      random 1 entry becomes 0 (nothing happens if no such entry exists).

    Args:
        poisoned_dataset: indexable 2-D label container that is modified in place.
        poison_rate: fraction of rows to alter.
    """
    row_count = len(poisoned_dataset)
    # Rows are drawn without replacement, so each one is altered at most once.
    victims = np.random.choice(row_count, int(row_count * poison_rate), replace=False)

    for row_idx in victims:
        row = poisoned_dataset[row_idx]

        # Row without any positive label: switch on one random position.
        if torch.sum(row) == 0:
            row[torch.randint(0, poisoned_dataset.shape[1], (1,))] = 1
            continue

        # Row with at least one positive label: randomly add or remove one.
        add_label = np.random.rand() > 0.5
        # Adding looks for positions holding 0, removing for positions holding 1.
        wanted = 0 if add_label else 1
        candidates = np.where(row == wanted)[0]
        if candidates.size > 0:
            row[np.random.choice(candidates)] = 1 - wanted

def single_pixel_perturbations(poisoned_dataset, x, y, new_value, poison_rate=0.25):
    """Overwrite one pixel, in place, in a random fraction of the samples.

    ``int(len(poisoned_dataset) * poison_rate)`` distinct samples are drawn and,
    in each of them, the entry at index ``[0, y, x]`` is set to ``new_value``.

    Args:
        poisoned_dataset: indexable collection of samples, modified in place.
        x: last index of the entry to overwrite.
        y: second index of the entry to overwrite.
        new_value: value written to the selected entry.
        poison_rate: fraction of samples to alter.
    """
    sample_count = len(poisoned_dataset)
    # Samples are drawn without replacement.
    chosen = np.random.choice(sample_count, int(sample_count * poison_rate), replace=False)

    for sample_idx in chosen:
        sample = poisoned_dataset[sample_idx]
        sample[0, y, x] = new_value

## Model 

In [6]:

class Net(nn.Module):
    """Small LeNet-style CNN producing 14 raw class scores (logits).

    Two conv + ReLU + max-pool stages feed three fully connected layers.
    ``fc1`` expects ``16 * 4 * 4`` features, which is what the two stages
    produce for single-channel 28x28 inputs.

    NOTE(review): a later cell defines another class that is also named
    ``Net``; once that cell has run, this definition is no longer reachable
    under this name.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 14)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, 14)``.

        No softmax is applied here: the train/test helpers in this notebook
        pass the output to ``CrossEntropyLoss``, which expects unnormalised
        scores. Argmax-based predictions are unaffected by this.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten per sample; keeping the batch dimension explicit makes a
        # wrongly sized input fail in fc1 instead of silently changing the batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# class Net2(nn.Module):
#     def __init__(self) -> None:
#         super(Net2, self).__init__()
#         self.conv1 = nn.Conv2d(1, 6, 3)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(6 * 13 * 13, 500)
#         self.fc2 = nn.Linear(500, 250)
#         self.fc3 = nn.Linear(250, 120)
#         self.fc4 = nn.Linear(120, 84)
#         # self.fc5 = nn.Linear(120, 84)
#         self.fc5 = nn.Linear(84, 2)


#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = self.pool(F.relu(self.conv1(x)))
#         # x = self.pool(F.relu(self.conv2(x)))

#         # print(x.shape)
#         x = x.view(-1, 6 * 13 * 13)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))
#         x = F.relu(self.fc4(x))
#         x = self.fc5(x)
#         # x = F.softmax(x)
#         return x
# # Load model and data
# # net = Net().to(DEVICE)

In [7]:
class Net(nn.Module):
    """Five-stage convolutional classifier with a three-layer dense head.

    Every stage is a 3x3 convolution followed by batch normalisation and ReLU;
    stages 2 and 5 end with a 2x2 max-pool and stage 5 pads its convolution.
    The head expects ``64 * 4 * 4`` flattened features and emits
    ``num_classes`` raw scores.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output layer.
    """

    def __init__(self, in_channels, num_classes)->None:
        super().__init__()

        def conv_stage(c_in, c_out, padding=0, pool=False):
            # conv -> batch norm -> ReLU, optionally followed by a 2x2 max-pool
            modules = [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=padding),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
            if pool:
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            return nn.Sequential(*modules)

        # Attribute names and creation order define the state_dict layout.
        self.layer1 = conv_stage(in_channels, 16)
        self.layer2 = conv_stage(16, 16, pool=True)
        self.layer3 = conv_stage(16, 64)
        self.layer4 = conv_stage(64, 64)
        self.layer5 = conv_stage(64, 64, padding=1, pool=True)

        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Run the five stages in order, flatten per sample and apply the head."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        flat = x.view(x.size(0), -1)
        return self.fc(flat)



In [8]:

    
def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in state_dict order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, matching them to ``state_dict`` keys by position.

    Each array is converted with its own value and dtype, including 0-d
    entries such as BatchNorm's ``num_batches_tracked`` (these used to be
    replaced by a constant zero).

    Args:
        net: module whose state is overwritten.
        parameters: one array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        ValueError: if the number of arrays differs from the number of
            ``state_dict`` entries (``zip`` would otherwise silently drop the surplus).
    """
    keys = list(net.state_dict().keys())
    if len(keys) != len(parameters):
        raise ValueError(
            f"expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    state_dict = OrderedDict(
        (key, torch.tensor(np.asarray(value))) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates one model on one data partition.

    Args:
        cid: client identifier, used only in log messages.
        net: the model owned by this client.
        trainloader: loader used by ``fit``.
        valloader: loader used by ``evaluate``.
    """

    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model state as a list of NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given parameters, train for one epoch and report the result.

        Returns the updated parameters, the number of training examples and a
        metrics dict with ``loss`` and ``accuracy``.
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = train(self.net, self.trainloader, epochs=1)
        # Report the number of samples, not the number of batches: the server
        # uses this count to weight each client's contribution.
        num_examples = len(self.trainloader.dataset)
        return get_parameters(self.net), num_examples, {"loss": float(loss), "accuracy": float(accuracy)}

    def evaluate(self, parameters, config):
        """Load the given parameters and evaluate them on the validation loader.

        Returns the loss, the number of validation examples and a metrics dict
        with ``loss`` and ``accuracy``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Sample count (not batch count) so weighted aggregation is exact.
        num_examples = len(self.valloader.dataset)
        return float(loss), num_examples, {"loss": float(loss), "accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    """Build the client for partition ``cid`` with a freshly initialised model.

    ``cid`` is converted to ``int`` and used to pick the matching entries of
    the module-level ``trainloaders`` and ``valloaders`` lists.
    """
    # net = Net().to(DEVICE)
    model = Net(in_channels=1, num_classes=14).to(DEVICE)
    partition = int(cid)
    return FlowerClient(cid, model, trainloaders[partition], valloaders[partition])

Using downloaded and verified file: C:\Users\tugca\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\tugca\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\tugca\.medmnist\chestmnist.npz


In [9]:

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracies into an example-weighted average.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs; every
            ``client_metrics`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when there is nothing
        to average (no entries, or all example counts are zero) instead of
        raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    # Weight each client's accuracy by the number of examples it used.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

In [10]:
# Flower calls the `evaluate` function on the server after every round.
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Score the given global parameters on the module-level ``testloader``.

    A fresh model is built, loaded with ``parameters`` and evaluated; the
    result is printed and returned as ``(loss, {"accuracy": accuracy})``.
    ``server_round`` and ``config`` are accepted but not used.
    """
    # net = Net().to(DEVICE)
    model = Net(in_channels=1, num_classes=14).to(DEVICE)
    # Load the latest parameters before scoring them.
    set_parameters(model, parameters)
    loss, accuracy = test(model, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    metrics = {"accuracy": accuracy}
    return loss, metrics

In [None]:
# Create an instance of the model and get the parameters
NUM_CLIENTS = 3
BATCH_SIZE = 32
NUM_ROUNDS = 3
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS, batch_size=BATCH_SIZE)

CLIENT = 1  #0, 1, 2
Poison_type = poison_labels # poison_labels / single_pixel_perturbations
Poison_rate = 1

def modify_traindata(trainloaders, poison_type, poison_rate=1):
    """Apply one poisoning attack to every batch yielded by ``trainloaders``.

    Despite the plural name, ``trainloaders`` is iterated directly as a
    sequence of ``(images, labels)`` batches. If ``poison_type`` is
    ``single_pixel_perturbations``, entry ``[0, 10, 10]`` of the selected images
    in each batch is set to that batch's maximum value; for any other
    ``poison_type`` the batch labels go through ``poison_labels``.

    NOTE(review): the edits are made in place on the batch objects produced
    while iterating. If the loader builds fresh batch tensors on every pass
    (presumably the case for a DataLoader — confirm), these edits are not seen
    by later training iterations.
    """
    pixel_attack = poison_type == single_pixel_perturbations

    for batch in trainloaders:
        if pixel_attack:
            images = batch[0]
            single_pixel_perturbations(images, 10, 10, torch.max(images), poison_rate=poison_rate)
        else:
            poison_labels(batch[1], poison_rate)
        
#modify_traindata(trainloaders=trainloaders[CLIENT], poison_type=Poison_type, poison_rate=Poison_rate)

# FedAvg strategy: every client takes part in both fit and evaluate each round.
# No initial parameters are passed here (the option is left commented out).
strategy = fl.server.strategy.FedAvg(
    evaluate_fn=evaluate,
    fit_metrics_aggregation_fn=weighted_average,
    evaluate_metrics_aggregation_fn=weighted_average,
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    # initial_parameters=fl.common.ndarrays_to_parameters(params),
)

server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)

# Request a GPU per client only when the notebook is configured to run on CUDA;
# otherwise leave the resources unset.
client_resources = {"num_gpus": 1} if DEVICE.type == "cuda" else None

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=server_config,
    strategy=strategy,
    client_resources=client_resources,
)