In [1]:
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

## Downloading the FEMNIST dataset

I have uploaded the FEMNIST dataset to GitHub as a zip file. Download it and extract it to a folder of your choice.

Alternatively, to build the data yourself from the source, clone the LEAF repository (https://github.com/TalwalkarLab/leaf/tree/master) and run the following command from inside its `data/femnist` folder:
```bash
./preprocess.sh -s niid --iu 1.0 --sf 1.0 -k 0 -t sample --smplseed 42 --spltseed 42
```

The LEAF preprocessing script also fails on Pillow 10 and later, because that release removed `Image.ANTIALIAS`. In the file ``data/femnist/preprocess/data_to_json.py``, change line 64 from ``gray.thumbnail(size, Image.ANTIALIAS)`` to ``gray.thumbnail(size, Image.LANCZOS)``.
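If you would rather not edit the file by hand, the replacement can be scripted. The following is a minimal sketch using only the standard library; the helper names `patch_antialias` and `patch_file` are made up for this example and are not part of LEAF.

```python
from pathlib import Path

OLD = "Image.ANTIALIAS"  # constant removed in Pillow 10
NEW = "Image.LANCZOS"    # replacement constant


def patch_antialias(source: str) -> str:
    """Return the source text with every Image.ANTIALIAS replaced by Image.LANCZOS."""
    return source.replace(OLD, NEW)


def patch_file(path: str) -> bool:
    """Patch a file in place; return True if the file was changed."""
    file = Path(path)
    original = file.read_text()
    patched = patch_antialias(original)
    if patched != original:
        file.write_text(patched)
    return patched != original


# Demonstration on the quoted line itself (no file access needed)
print(patch_antialias("gray.thumbnail(size, Image.ANTIALIAS)"))
# prints: gray.thumbnail(size, Image.LANCZOS)
```

From the root of the LEAF clone, calling `patch_file("data/femnist/preprocess/data_to_json.py")` should apply the same change to the real script.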

## Using the FEMNIST dataset

Run the following code to load the data into one pair of dataloaders (train and validation) per client, for use in Flower. Set `path_to_data_folder` to the folder where you extracted the data.

In [4]:
path_to_data_folder = "femnist_data" # EDIT THIS WITH THE PATH TO THE DATA FOLDER

all_client_trainloaders = []
all_client_valloaders = []

for i in tqdm(range(0, 36)): # for each json file
    with open(f"{path_to_data_folder}/all_data_{i}.json") as file:

        # load the 100 clients in each json file
        data = json.load(file)
        all_clients = data["users"]
        
        for client in all_clients:
            # load the dataset from one client
            X_data = data["user_data"][client]["x"]
            num_samples = len(X_data)
            X_data = np.array(X_data, dtype=np.float32).reshape(num_samples, 1, 28, 28) # float32 to match default model weights; reshape into BxCxHxW
            y_data = np.array(data["user_data"][client]["y"], dtype=np.int64)
        
            # split into train (90%) and test (10%) data; the identical seeds keep X and y aligned
            X_train, X_test = random_split(X_data, (0.9, 0.1), torch.Generator().manual_seed(42))
            y_train, y_test = random_split(y_data, (0.9, 0.1), torch.Generator().manual_seed(42))

            # put the dataset into dataloaders
            torch.manual_seed(47)
            train_loader = DataLoader(dataset=list(zip(X_train, y_train)),
                                      batch_size=32,
                                      shuffle=True,
                                      pin_memory=True)
            torch.manual_seed(47)
            test_loader = DataLoader(dataset=list(zip(X_test, y_test)),
                                     batch_size=32,
                                     shuffle=True,
                                     pin_memory=True)
    
            # add the dataloader to the overall list
            all_client_trainloaders.append(train_loader)
            all_client_valloaders.append(test_loader)

100%|██████████████████████████████████████████████████████████████████████████████████| 36/36 [04:33<00:00,  7.60s/it]
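The loop above keeps images and labels aligned by calling `random_split` twice with identically seeded generators. An alternative that makes the alignment explicit is to draw one permutation of indices and apply it to both arrays. The sketch below uses NumPy rather than `random_split`, so it will not reproduce the exact same split; `split_client_data` and the synthetic demo arrays are hypothetical and only for illustration.

```python
import numpy as np


def split_client_data(X, y, val_fraction=0.1, seed=42):
    """Split one client's arrays into train and validation parts.

    A single shared permutation indexes both X and y, so every image
    stays paired with its own label.
    """
    if len(X) != len(y):
        raise ValueError("X and y must contain the same number of samples")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(X))
    num_val = int(round(len(X) * val_fraction))
    val_idx, train_idx = perm[:num_val], perm[num_val:]
    return (X[train_idx], y[train_idx]), (X[val_idx], y[val_idx])


# Synthetic client: image i is filled with the value i and carries label i
y_demo = np.arange(20, dtype=np.int64)
X_demo = np.ones((20, 1, 28, 28), dtype=np.float32) * y_demo.reshape(20, 1, 1, 1)

(train_X, train_y), (val_X, val_y) = split_client_data(X_demo, y_demo)
print(train_X.shape, val_X.shape)
```

The returned arrays can be zipped into a `DataLoader` in the same way as `X_train` and `y_train` in the loop.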


These dataloaders can then be passed to each Flower client using code along the lines of:

```python
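import flwr as fl
from flwr.server.strategy import FedAvg

# resnet18, DEVICE, FlowerClient and client_resources are assumed
# to be defined elsewhere in your own code.

def client_fn(cid):
    # cid is converted to an integer index into the loader lists
    model = resnet18().to(DEVICE)
    trainloader = all_client_trainloaders[int(cid)]
    valloader = all_client_valloaders[int(cid)]
    return FlowerClient(cid, model, trainloader, valloader)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=len(all_client_trainloaders),
    config=fl.server.ServerConfig(num_rounds=1000),
    strategy=FedAvg(),
    client_resources=client_resources,
)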
```
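The snippet above uses `client_resources` without defining it. In Flower's simulation engine this is a dictionary giving the CPU and GPU share reserved for each simulated client, with the keys `num_cpus` and `num_gpus`. One possible way to build it is sketched below; the helper `make_client_resources` and the chosen values are assumptions to adapt to your own machine.

```python
def make_client_resources(num_cpus=1, use_gpu=False, clients_per_gpu=1):
    """Build a per-client resource dictionary for a Flower simulation.

    A fractional num_gpus lets several simulated clients share one GPU.
    """
    if clients_per_gpu < 1:
        raise ValueError("clients_per_gpu must be at least 1")
    num_gpus = 1.0 / clients_per_gpu if use_gpu else 0.0
    return {"num_cpus": num_cpus, "num_gpus": num_gpus}


# Example: one CPU core per client, four clients sharing each GPU
client_resources = make_client_resources(num_cpus=1, use_gpu=True, clients_per_gpu=4)
print(client_resources)
# prints: {'num_cpus': 1, 'num_gpus': 0.25}
```

Larger shares per client mean fewer clients run in parallel, so these values trade memory headroom against simulation speed.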