In [2]:
import torch
import torchvision
import numpy as np
from torch.utils.data import DataLoader
from fedlab.utils.dataset.partition import CIFAR10Partitioner

## Downloading the CIFAR dataset

Download the CIFAR dataset from torchvision. It will be downloaded to the folder specified by ``root``.

In [4]:
# Fetch both CIFAR-10 splits (training first, then test) into the same root folder.
train_data, test_data = (
    torchvision.datasets.CIFAR10(root="cifar_data", train=is_train_split, download=True)
    for is_train_split in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Convert the raw uint8 image arrays to float32 in [0, 1] and move the last
# axis to position 1 (presumably NHWC -> NCHW, as torch conv layers expect --
# confirm against the torchvision CIFAR10 `.data` layout).
#
# The cast to float32 happens *before* the division so that NumPy never
# materialises a float64 copy of the whole dataset (which is what
# `uint8_array / 255` would do before being narrowed again).
PIXEL_MAX = np.float32(255)

X_train = (train_data.data.astype(np.float32) / PIXEL_MAX).transpose(0, 3, 1, 2)
y_train = np.array(train_data.targets, dtype=np.int64)  # int64 labels

X_test = (test_data.data.astype(np.float32) / PIXEL_MAX).transpose(0, 3, 1, 2)
y_test = np.array(test_data.targets, dtype=np.int64)

## Using the CIFAR dataset

The partitioner from fedlab (https://fedlab.readthedocs.io/en/master/tutorials/cifar10_tutorial.html) is used to create an imbalanced, non-IID partition of the CIFAR-10 training set. The data is split across 750 clients using a Dirichlet distribution (`dir_alpha=0.3`), so that the number of samples on each client ranges from about 10 to 250.

The test data, however, is partitioned IID and balanced across the same 750 clients, so that the models can be tested fairly.

In [6]:
# Training split: Dirichlet-based partition over 750 clients. `balance=None`
# together with `partition="dirichlet"` yields the imbalanced, non-IID split
# described above; the seed keeps the assignment reproducible.
partitioned_train_data = CIFAR10Partitioner(
    train_data.targets,
    num_clients=750,
    partition="dirichlet",
    balance=None,
    dir_alpha=0.3,
    seed=42,
)

In [7]:
# Test split: balanced IID partition over the same 750 clients, so every
# client is evaluated on an evenly sized, evenly distributed share.
partitioned_test_data = CIFAR10Partitioner(
    test_data.targets,
    num_clients=750,
    partition="iid",
    balance=True,
    seed=42,
)

Putting the data into dataloaders:

In [8]:
NUM_CLIENTS = 750    # must match the client count given to both partitioners
BATCH_SIZE = 32
LOADER_SEED = 47
# Pinned host memory only helps when batches are copied to a CUDA device;
# requesting it on a CPU-only machine just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()


def _make_client_loader(features, labels, sample_indices, shuffle):
    """Build a DataLoader over the samples assigned to a single client.

    Args:
        features: Array of images, indexed along its first axis.
        labels: Array of labels aligned with ``features``.
        sample_indices: Indices (into ``features``/``labels``) owned by the client.
        shuffle: Whether batches are drawn in random order.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches of ``BATCH_SIZE``.
    """
    client_features = features[sample_indices]
    client_labels = labels[sample_indices]
    # A shuffling loader draws its permutation when it is iterated, not when it
    # is constructed, so seeding the global RNG here would not fix the order.
    # A dedicated seeded generator makes each loader's shuffling reproducible
    # without touching global RNG state.
    generator = torch.Generator().manual_seed(LOADER_SEED) if shuffle else None
    return DataLoader(dataset=list(zip(client_features, client_labels)),
                      batch_size=BATCH_SIZE,
                      shuffle=shuffle,
                      generator=generator,
                      pin_memory=PIN_MEMORY)


# Seed the global RNG once so that later cells relying on it (e.g. model
# initialisation) start from the same state as before.
torch.manual_seed(LOADER_SEED)

all_client_trainloaders = []
all_client_testloaders = []

for client in range(NUM_CLIENTS):
    all_client_trainloaders.append(
        _make_client_loader(X_train, y_train, partitioned_train_data[client], shuffle=True)
    )
    # Evaluation does not benefit from shuffling; a fixed order keeps test
    # metrics and any per-batch logging deterministic.
    all_client_testloaders.append(
        _make_client_loader(X_test, y_test, partitioned_test_data[client], shuffle=False)
    )

These dataloaders can then be imported into each client using code along the lines of:

```python
def client_fn(cid):
    """Create the FlowerClient for client ``cid`` with a fresh ResNet-18 and its own loaders."""
    client_index = int(cid)  # cid is converted to an int to index the loader lists
    return FlowerClient(
        cid,
        resnet18().to(DEVICE),
        all_client_trainloaders[client_index],
        all_client_testloaders[client_index],
    )

# One simulated client per prepared train loader.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=len(all_client_trainloaders),
    config=fl.server.ServerConfig(num_rounds=1000),
    strategy=FedAvg(),
    client_resources=client_resources,
)
```