# Quickstart PyTorch Tutorial

In this tutorial, you will see how to use flox to run federated learning (FL) experiments with PyTorch, first using a local executor and then using real physical endpoints. We will train a model to classify images from the CIFAR10 dataset.

In [2]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch import Tensor

from flox.clients.PyTorchClient import PyTorchClient
from flox.controllers.PyTorchController import PyTorchController
from flox.model_trainers.PyTorchTrainer import PyTorchTrainer

logger = logging.getLogger(__name__)

### Getting Data

First, let's get some test data so we can evaluate our model later on. The function below takes a dictionary with variables that specify the dataset, batch size, etc., and returns a train and a test ``torch.utils.data.DataLoader``. In this notebook we keep only the test loader, which we will use to evaluate our models.

In [4]:
def get_test_data(config):
    import torch
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    batch_size = config.get("batch_size", 32)
    dataset_name = config["dataset_name"]
    num_workers = config.get("num_workers", 8)
    root = config.get("data_root", "./data")

    # create train DataLoader
    trainset = dataset_name(root=root, train=True, download=True, transform=transform)

    train_split_len = (
        len(trainset) if "num_samples" not in config.keys() else config["num_samples"]
    )

    train_subpart = torch.utils.data.random_split(
        trainset, [train_split_len, len(trainset) - train_split_len]
    )[0]
    trainloader = torch.utils.data.DataLoader(
        train_subpart, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    # create test DataLoader
    testset = dataset_name(root=root, train=False, download=True, transform=transform)
    # default to the full test set when num_samples is not given
    test_split_len = (
        len(testset) if "num_samples" not in config.keys() else config["num_samples"]
    )

    test_subpart = torch.utils.data.random_split(
        testset, [test_split_len, len(testset) - test_split_len]
    )[0]
    testloader = torch.utils.data.DataLoader(
        test_subpart, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return trainloader, testloader

data_config = {
    "num_samples": 1000,
    "batch_size": 32,
    "dataset_name": torchvision.datasets.CIFAR10,
    "num_workers": 4,
}

_, testloader = get_test_data(data_config)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:47<00:00, 3586962.11it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
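
When called with integer lengths, ``torch.utils.data.random_split`` expects them to sum to the size of the dataset being split, so the subset size has to be derived from that same dataset. The helper below is a torch-free sketch of that bookkeeping; ``split_lengths`` is a hypothetical name used only for illustration and is not part of flox or PyTorch. It uses the CIFAR10 sizes of 50,000 training and 10,000 test images.

```python
def split_lengths(dataset_len, num_samples=None):
    """Return [kept, discarded] lengths that sum to dataset_len.

    Hypothetical helper for illustration; not part of flox or PyTorch.
    """
    kept = dataset_len if num_samples is None else num_samples
    if not 0 <= kept <= dataset_len:
        raise ValueError(
            f"num_samples={kept} does not fit a dataset of {dataset_len} items"
        )
    return [kept, dataset_len - kept]


CIFAR10_TRAIN_LEN = 50_000
CIFAR10_TEST_LEN = 10_000

print(split_lengths(CIFAR10_TEST_LEN, 1000))  # [1000, 9000]
print(split_lengths(CIFAR10_TEST_LEN))        # [10000, 0]
```

Asking for the training-set length on the test set, for example ``split_lengths(CIFAR10_TEST_LEN, CIFAR10_TRAIN_LEN)``, raises a ``ValueError`` in this sketch instead of producing a negative remainder.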


### Defining the model

Now, let's define our PyTorch model architecture.

In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        import torch

        x = self.pool(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool(torch.nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
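
The ``16 * 5 * 5`` input size of ``fc1`` follows from how the convolution and pooling layers shrink a 32x32 CIFAR10 image. The arithmetic can be checked with a few lines of plain Python; ``conv_out`` is a hypothetical helper written for this check, applying the standard output-size formula for unpadded convolutions and pooling.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR10 images are 32x32
size = conv_out(size, kernel=5)            # conv1: 28
size = conv_out(size, kernel=2, stride=2)  # pool:  14
size = conv_out(size, kernel=5)            # conv2: 10
size = conv_out(size, kernel=2, stride=2)  # pool:  5

flat_features = 16 * size * size           # 16 channels after conv2
print(size, flat_features)  # 5 400
```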

### Instantiating Model Trainer and Client instances

Next, we will initialize an instance of a PyTorch Model Trainer and Client. You can check out their implementations under ``flox/model_trainers`` and ``flox/clients``, respectively. You can also extend or modify these classes to fit your needs.

In [6]:
torch_trainer = PyTorchTrainer(net)
torch_client = PyTorchClient()

### Instantiating the Controller (Local Execution)

Let's also define our endpoints and initialize the PyTorch *Controller*, which does the heavy lifting of deploying tasks to the endpoints. We will run three rounds of FL, with 100 samples and 1 training epoch on each device. Note that we set ``executor_type`` to "local", which uses ``concurrent.futures.ThreadPoolExecutor`` to execute the tasks locally. We also provide the dataset class and the test data. Finally, we launch the experiment.
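
As a rough picture of what running tasks through ``concurrent.futures.ThreadPoolExecutor`` looks like, here is a minimal standalone sketch. It does not use flox and does no training; ``fake_client_task`` is a hypothetical stand-in for the work an endpoint would do, and how flox actually packages and submits its tasks may differ.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_client_task(endpoint_id, num_samples, epochs):
    """Stand-in for the work an endpoint would do (hypothetical, no training)."""
    return {"endpoint": endpoint_id, "samples_seen": num_samples * epochs}


endpoints = ["simulated_endpoint_1", "simulated_endpoint_2", "simulated_endpoint_3"]

with ThreadPoolExecutor(max_workers=len(endpoints)) as executor:
    # submit one task per endpoint, then collect the results in submission order
    futures = [executor.submit(fake_client_task, ep, 100, 1) for ep in endpoints]
    results = [future.result() for future in futures]

for result in results:
    print(result)
```

Because every task runs in a thread of the same process, the endpoint names here are only labels, which matches the note that their names do not matter for local execution.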

In [7]:
# since we are first executing the experiment locally, it does not matter what we name the endpoints:
eps = ["simulated_endpoint_1", "simulated_endpoint_2", "simulated_endpoint_3"]
logger.info(f"Endpoints: {eps}")

flox_controller = PyTorchController(
    endpoint_ids=eps,
    num_samples=100,
    epochs=1,
    rounds=3,
    client_logic=torch_client,
    model_trainer=torch_trainer,
    executor_type="local",  # choose "funcx" for FuncXExecutor, "local" for ThreadPoolExecutor
    testloader=testloader,
    dataset_name=torchvision.datasets.CIFAR10,
)

# Finally, let's launch the experiment
logger.info("STARTING FL LOCAL TORCH FLOW...")
flox_controller.run_federated_learning()

1676522661.787290 2023-02-16 12:44:21 INFO MainProcess-18532 MainThread-11664 __main__:3 <module> Endpoints: ['simulated_endpoint_1', 'simulated_endpoint_2', 'simulated_endpoint_3']
1676522661.788293 2023-02-16 12:44:21 INFO MainProcess-18532 MainThread-11664 __main__:18 <module> STARTING FL TORCH FLOW...
1676522661.790294 2023-02-16 12:44:21 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:166 on_model_init No executor was provided, trying to retrieve the provided executor type local from the list of available executors: {'local': <class 'concurrent.futures.thread.ThreadPoolExecutor'>, 'funcx': <class 'funcx.sdk.executor.FuncXExecutor'>}
1676522661.792294 2023-02-16 12:44:21 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:170 on_model_init The selected executor is <class 'concurrent.futures.thread.ThreadPoolExecutor'>
1676522663.218177 2023-02-16 12:44:23 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:209 on_mo

Files already downloaded and verified
Files already downloaded and verified


1676522707.808488 2023-02-16 12:45:07 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676522707.809497 2023-02-16 12:45:07 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
  a = np.asanyarray(a)
1676522707.813534 2023-02-16 12:45:07 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676522707.815490 2023-02-16 12:45:07 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:398 run_federated_learning Round 0 evaluation results: 
1676522717.986862 2023-02-16 12:45:17 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'concurrent.futures.thread.ThreadPoolExecutor'> executor
1676522717.987861 2023-02-16 12:45:17 INFO MainProcess-18532 MainThread-11664 fl

{'loss': 73.71587228775024, 'metrics': {'accuracy': 0.094}}


1676522734.670792 2023-02-16 12:45:34 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint simulated_endpoint_1


Files already downloaded and verified
Files already downloaded and verified


1676522764.506683 2023-02-16 12:46:04 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676522764.507685 2023-02-16 12:46:04 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676522764.508685 2023-02-16 12:46:04 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676522764.510688 2023-02-16 12:46:04 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:398 run_federated_learning Round 1 evaluation results: 
1676522774.168290 2023-02-16 12:46:14 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'concurrent.futures.thread.ThreadPoolExecutor'> executor
1676522774.169294 2023-02-16 12:46:14 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainCont

{'loss': 73.71451902389526, 'metrics': {'accuracy': 0.094}}


1676522791.292581 2023-02-16 12:46:31 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint simulated_endpoint_1


Files already downloaded and verified
Files already downloaded and verified


1676522816.476501 2023-02-16 12:46:56 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676522816.477504 2023-02-16 12:46:56 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676522816.479505 2023-02-16 12:46:56 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676522816.480503 2023-02-16 12:46:56 INFO MainProcess-18532 MainThread-11664 flox.controllers.MainController:398 run_federated_learning Round 2 evaluation results: 


{'loss': 73.71229720115662, 'metrics': {'accuracy': 0.094}}
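
The accuracy of 0.094 in all three rounds is close to the 0.1 expected from guessing among CIFAR10's ten classes, so the model has learned very little in this short run. The loss is consistent with that reading if the reported value is the cross-entropy summed over test batches rather than averaged. That is an assumption about how the trainer reports loss, not something this output confirms. Under that assumption, a classifier at chance has a cross-entropy near ln(10) per batch, and 1000 test samples at batch size 32 give 32 batches:

```python
import math

num_classes = 10
num_test_samples = 1000  # "num_samples" in data_config
batch_size = 32          # "batch_size" in data_config

chance_accuracy = 1 / num_classes
chance_loss_per_batch = math.log(num_classes)  # about 2.303
num_batches = math.ceil(num_test_samples / batch_size)
summed_chance_loss = num_batches * chance_loss_per_batch

print(chance_accuracy)               # 0.1
print(num_batches)                   # 32
print(round(summed_chance_loss, 2))  # 73.68
```

The result is within a few hundredths of the reported losses of about 73.71.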


### Real Endpoint (FuncX) Execution 

Now, let's switch ``executor_type`` to "funcx" and provide actual endpoints. Before running this cell, make sure to follow the instructions in this directory's README to set up your clients.

In [8]:
eps = ["fb93a1c2-a8d7-49f3-ad59-375f4e298784", "c7487b2b-b129-47e2-989b-5a9ac361befc"]
logger.info(f"Endpoints: {eps}")

flox_controller = PyTorchController(
    endpoint_ids=eps,
    num_samples=100,
    epochs=1,
    rounds=3,
    client_logic=torch_client,
    model_trainer=torch_trainer,
    executor_type="funcx",  # choose "funcx" for FuncXExecutor, "local" for ThreadPoolExecutor
    testloader=testloader,
    dataset_name=torchvision.datasets.CIFAR10,
)

# Finally, let's launch the experiment
logger.info("STARTING FL FUNCX TORCH FLOW...")
flox_controller.run_federated_learning()

1676523188.972713 2023-02-16 12:53:08 INFO MainProcess-18532 MainThread-11664 __main__:2 <module> Endpoints: ['fb93a1c2-a8d7-49f3-ad59-375f4e298784', 'c7487b2b-b129-47e2-989b-5a9ac361befc']
1676523188.973714 2023-02-16 12:53:08 INFO MainProcess-18532 MainThread-11664 __main__:17 <module> STARTING FL FUNCX TORCH FLOW...
1676523188.975717 2023-02-16 12:53:08 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:166 on_model_init No executor was provided, trying to retrieve the provided executor type funcx from the list of available executors: {'local': <class 'concurrent.futures.thread.ThreadPoolExecutor'>, 'funcx': <class 'funcx.sdk.executor.FuncXExecutor'>}
1676523188.976714 2023-02-16 12:53:08 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:170 on_model_init The selected executor is <class 'funcx.sdk.executor.FuncXExecutor'>
1676523190.428348 2023-02-16 12:53:10 DEBUG MainProcess-18532 MainThread-11664 flox.controllers.MainController:209 on_

ValueError: The tasks queue is empty, no tasks were submitted for training!
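
The output above does not show why no tasks were submitted, so the cause has to be investigated on the endpoints themselves. One cheap pre-flight check that is independent of flox is to confirm that every endpoint ID is a well-formed UUID, which is the format of the IDs used in this cell. The sketch below uses only the standard library; ``is_uuid`` is a hypothetical helper, and passing this check does not mean an endpoint is online or reachable.

```python
import uuid


def is_uuid(value):
    """Return True if value is a canonically formatted UUID string."""
    try:
        return str(uuid.UUID(value)) == value.lower()
    except (ValueError, AttributeError, TypeError):
        return False


eps = ["fb93a1c2-a8d7-49f3-ad59-375f4e298784", "c7487b2b-b129-47e2-989b-5a9ac361befc"]
malformed = [ep for ep in eps if not is_uuid(ep)]
print(malformed)  # []
```

A placeholder name such as ``"simulated_endpoint_1"`` fails this check, which makes it easy to catch a list of simulated names passed to the "funcx" executor by mistake.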