# Quickstart PyTorch Tutorial

In this tutorial, you will see how to use flox to run federated learning (FL) experiments with PyTorch, first using a local executor and then using real physical endpoints. We will train a model to classify images from the CIFAR10 dataset.

In [1]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch import Tensor

from flox.clients.PyTorchClient import PyTorchClient
from flox.controllers.PyTorchController import PyTorchController
from flox.model_trainers.PyTorchTrainer import PyTorchTrainer

logger = logging.getLogger(__name__)

### Getting Data

First, let's get some test data so we can evaluate our model later on. The function below takes a dictionary with variables that specify the dataset, batch size, and so on, and returns a train and a test ``torch.utils.data.DataLoader`` instance. Here we only keep the test loader, which we will use for evaluating our models.

In [2]:
def get_test_data(config):
    import torch
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    batch_size = config.get("batch_size", 32)
    dataset_name = config["dataset_name"]
    num_workers = config.get("num_workers", 8)
    root = config.get("data_root", "./data")

    # create train DataLoader
    trainset = dataset_name(root=root, train=True, download=True, transform=transform)

    train_split_len = (
        len(trainset) if "num_samples" not in config.keys() else config["num_samples"]
    )

    train_subpart = torch.utils.data.random_split(
        trainset, [train_split_len, len(trainset) - train_split_len]
    )[0]
    trainloader = torch.utils.data.DataLoader(
        train_subpart, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    # create test DataLoader
    testset = dataset_name(root=root, train=False, download=True, transform=transform)
    # use the size of the test set (not the train set) when no "num_samples" is given
    test_split_len = (
        len(testset) if "num_samples" not in config.keys() else config["num_samples"]
    )

    test_subpart = torch.utils.data.random_split(
        testset, [test_split_len, len(testset) - test_split_len]
    )[0]
    testloader = torch.utils.data.DataLoader(
        test_subpart, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return trainloader, testloader

data_config = {
    "num_samples": 1000,
    "batch_size": 32,
    "dataset_name": torchvision.datasets.CIFAR10,
    "num_workers": 4,
}

_, testloader = get_test_data(data_config)

Files already downloaded and verified
Files already downloaded and verified
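
To see what the subsetting in `get_test_data` does without downloading CIFAR10, the following sketch applies the same `random_split` and `DataLoader` calls to a synthetic `TensorDataset`. The synthetic tensors and the `subset_loader` helper name are assumptions made for illustration; they are not part of flox.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def subset_loader(dataset, num_samples, batch_size, shuffle=False):
    """Keep a random subset of `num_samples` items and wrap it in a DataLoader."""
    # random_split needs lengths that sum to len(dataset); the second part is discarded
    subset, _ = random_split(dataset, [num_samples, len(dataset) - num_samples])
    return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)


# Synthetic stand-in for CIFAR10: 200 RGB images of 32x32 pixels with 10 classes
images = torch.randn(200, 3, 32, 32)
labels = torch.randint(0, 10, (200,))
loader = subset_loader(TensorDataset(images, labels), num_samples=100, batch_size=32)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)   # torch.Size([32, 3, 32, 32])
print(len(loader.dataset))  # 100
```

Because the two lengths passed to `random_split` must add up to the dataset size, `num_samples` should not exceed the number of items in the dataset being split.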


### Defining the model

Now, let's define our PyTorch model architecture.

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        import torch

        x = self.pool(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool(torch.nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
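
The `16 * 5 * 5` input size of `fc1` follows from the 32x32 CIFAR10 images shrinking through two unpadded 5x5 convolutions and two 2x2 poolings. A small sketch of that arithmetic, using the usual output-size formula for convolution and pooling windows; the helper name `conv_out` is illustrative only:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR10 images are 32x32
size = conv_out(size, kernel=5)            # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_out(size, kernel=5)            # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5

flat_features = 16 * size * size           # 16 channels leave conv2
print(size, flat_features)                 # 5 400
```

If you change the input resolution or the kernel sizes, the first argument of `fc1` has to be updated to match.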

### Instantiating Model Trainer and Client instances

Next, we will initialize an instance of a PyTorch Model Trainer and Client. You can check out their implementation under ``flox/model_trainers`` and ``flox/clients``, respectively. You can also extend or modify these classes to fit your needs.

In [4]:
torch_trainer = PyTorchTrainer(net)
torch_client = PyTorchClient()

### Instantiating the Controller (Local Execution)

Let's also define our endpoints and initialize the PyTorch *Controller*, which does the heavy lifting of deploying tasks to the endpoints. We will run three rounds of FL, with 100 samples and 1 training epoch on each device. Note that we are setting ``executor_type`` to ``"local"``, which uses ``concurrent.futures.ThreadPoolExecutor`` to execute the tasks locally. We are also providing the dataset name and the test data. Finally, we'll launch the experiment.

In [5]:
# since we are first executing the experiment locally, it does not matter what we name the endpoints:
eps = ["simulated_endpoint_1", "simulated_endpoint_2", "simulated_endpoint_3"]
logger.info(f"Endpoints: {eps}")

flox_controller = PyTorchController(
    endpoint_ids=eps,
    num_samples=100,
    epochs=1,
    rounds=3,
    client_logic=torch_client,
    model_trainer=torch_trainer,
    executor_type="local",  # choose "funcx" for FuncXExecutor, "local" for ThreadPoolExecutor
    testloader=testloader,
    dataset_name=torchvision.datasets.CIFAR10,
)

# Finally, let's launch the experiment
logger.info("STARTING FL LOCAL TORCH FLOW...")
flox_controller.run_federated_learning()

1676800283.857043 2023-02-19 17:51:23 INFO MainProcess-24960 MainThread-23204 __main__:3 <module> Endpoints: ['simulated_endpoint_1', 'simulated_endpoint_2', 'simulated_endpoint_3']
1676800283.858041 2023-02-19 17:51:23 INFO MainProcess-24960 MainThread-23204 __main__:18 <module> STARTING FL LOCAL TORCH FLOW...
1676800283.859042 2023-02-19 17:51:23 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:166 on_model_init No executor was provided, trying to retrieve the provided executor type local from the list of available executors: {'local': <class 'concurrent.futures.thread.ThreadPoolExecutor'>, 'funcx': <class 'funcx.sdk.executor.FuncXExecutor'>}
1676800283.861044 2023-02-19 17:51:23 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:170 on_model_init The selected executor is <class 'concurrent.futures.thread.ThreadPoolExecutor'>
1676800285.333203 2023-02-19 17:51:25 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:209

Files already downloaded and verified
Files already downloaded and verified


1676800321.902282 2023-02-19 17:52:01 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676800321.904283 2023-02-19 17:52:01 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676800321.937283 2023-02-19 17:52:01 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676800321.939285 2023-02-19 17:52:01 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:398 run_federated_learning Round 0 evaluation results: 
1676800330.456416 2023-02-19 17:52:10 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'concurrent.futures.thread.ThreadPoolExecutor'> executor
1676800330.456416 2023-02-19 17:52:10 INFO MainProcess-24960 MainThread-23204 fl

{'loss': 73.78003287315369, 'metrics': {'accuracy': 0.082}}


1676800350.927112 2023-02-19 17:52:30 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint simulated_endpoint_1


Files already downloaded and verified
Files already downloaded and verified


1676800371.149551 2023-02-19 17:52:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676800371.151550 2023-02-19 17:52:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676800371.152550 2023-02-19 17:52:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676800371.154548 2023-02-19 17:52:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:398 run_federated_learning Round 1 evaluation results: 
1676800378.159143 2023-02-19 17:52:58 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'concurrent.futures.thread.ThreadPoolExecutor'> executor
1676800378.159143 2023-02-19 17:52:58 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainCont

{'loss': 73.77825355529785, 'metrics': {'accuracy': 0.082}}


1676800396.352078 2023-02-19 17:53:16 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint simulated_endpoint_1


Files already downloaded and verified
Files already downloaded and verified


1676800417.115994 2023-02-19 17:53:37 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676800417.117954 2023-02-19 17:53:37 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676800417.118958 2023-02-19 17:53:37 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676800417.121957 2023-02-19 17:53:37 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:398 run_federated_learning Round 2 evaluation results: 


{'loss': 73.7755799293518, 'metrics': {'accuracy': 0.082}}
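
The local run above relies on `concurrent.futures.ThreadPoolExecutor`. As a rough illustration of that pattern (not flox's actual controller code), the sketch below submits one task per endpoint name and collects the results. `fake_client_task` is a hypothetical stand-in for the work a client would do.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_client_task(endpoint_id, num_samples):
    """Hypothetical stand-in for on-device training; returns a small result dict."""
    return {"endpoint": endpoint_id, "samples_count": num_samples}


eps = ["simulated_endpoint_1", "simulated_endpoint_2", "simulated_endpoint_3"]

with ThreadPoolExecutor(max_workers=len(eps)) as executor:
    # submit() returns a Future right away; the tasks run concurrently in threads
    futures = [executor.submit(fake_client_task, ep, 100) for ep in eps]
    # result() blocks until the corresponding task has finished
    results = [future.result() for future in futures]

print([r["endpoint"] for r in results])
```

Since everything runs in threads of the same process, the endpoint names are only labels here, which is why arbitrary strings work for local execution.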


### Real Endpoint (FuncX) Execution 

Now, let's switch ``executor_type`` to ``"funcx"`` and provide actual endpoints. Before running this cell, make sure to follow the instructions in this directory's README to set up your clients.

In [6]:
eps = ["fb93a1c2-a8d7-49f3-ad59-375f4e298784"]
logger.info(f"Endpoints: {eps}")

flox_controller = PyTorchController(
    endpoint_ids=eps,
    num_samples=100,
    epochs=1,
    rounds=3,
    client_logic=torch_client,
    model_trainer=torch_trainer,
    executor_type="funcx",  # choose "funcx" for FuncXExecutor, "local" for ThreadPoolExecutor
    testloader=testloader,
    dataset_name=torchvision.datasets.CIFAR10,
)

# Finally, let's launch the experiment
logger.info("STARTING FL FUNCX TORCH FLOW...")
flox_controller.run_federated_learning()

1676800423.875944 2023-02-19 17:53:43 INFO MainProcess-24960 MainThread-23204 __main__:2 <module> Endpoints: ['fb93a1c2-a8d7-49f3-ad59-375f4e298784']
1676800423.876940 2023-02-19 17:53:43 INFO MainProcess-24960 MainThread-23204 __main__:17 <module> STARTING FL FUNCX TORCH FLOW...
1676800423.877940 2023-02-19 17:53:43 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:166 on_model_init No executor was provided, trying to retrieve the provided executor type funcx from the list of available executors: {'local': <class 'concurrent.futures.thread.ThreadPoolExecutor'>, 'funcx': <class 'funcx.sdk.executor.FuncXExecutor'>}
1676800423.878938 2023-02-19 17:53:43 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:170 on_model_init The selected executor is <class 'funcx.sdk.executor.FuncXExecutor'>
1676800425.342097 2023-02-19 17:53:45 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'fu

{'loss': 73.77296876907349, 'metrics': {'accuracy': 0.082}}


1676800462.290899 2023-02-19 17:54:22 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:215 on_model_broadcast Starting to broadcast a task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800464.405772 2023-02-19 17:54:24 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800475.895391 2023-02-19 17:54:35 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676800475.897361 2023-02-19 17:54:35 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676800475.901360 2023-02-19 17:54:35 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676800475.905354 2023-02-19 17:54:35 INFO MainProcess-24960 MainThrea

{'loss': 73.76958775520325, 'metrics': {'accuracy': 0.082}}


1676800483.926109 2023-02-19 17:54:43 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:215 on_model_broadcast Starting to broadcast a task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800485.526042 2023-02-19 17:54:45 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800497.148367 2023-02-19 17:54:57 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:290 on_model_receive Starting to retrieve results from endpoints
1676800497.149407 2023-02-19 17:54:57 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:305 on_model_receive Finished retrieving all results from the endpoints
1676800497.150407 2023-02-19 17:54:57 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:335 on_model_aggregate Finished aggregating weights
1676800497.153367 2023-02-19 17:54:57 INFO MainProcess-24960 MainThrea

{'loss': 73.76791477203369, 'metrics': {'accuracy': 0.082}}


### Real Endpoint (FuncX) Execution with Running Average

With a large number of endpoints, aggregating all of their updated model weights at once can be computationally heavy and time-consuming. Instead, we can make use of the waiting time by aggregating the models as they come back from the endpoints. In this example, we set ``running_average=True`` when creating ``flox_controller`` and run the same experiment again.

In [8]:
eps = ["fb93a1c2-a8d7-49f3-ad59-375f4e298784"]
logger.info(f"Endpoints: {eps}")

flox_controller = PyTorchController(
    endpoint_ids=eps,
    num_samples=100,
    epochs=1,
    rounds=3,
    client_logic=torch_client,
    model_trainer=torch_trainer,
    executor_type="funcx",  # choose "funcx" for FuncXExecutor, "local" for ThreadPoolExecutor
    testloader=testloader,
    dataset_name=torchvision.datasets.CIFAR10,
    running_average=True,
)

# Finally, let's launch the experiment
logger.info("STARTING FL FUNCX TORCH FLOW...")
flox_controller.run_federated_learning()

1676800614.600824 2023-02-19 17:56:54 INFO MainProcess-24960 MainThread-23204 __main__:2 <module> Endpoints: ['fb93a1c2-a8d7-49f3-ad59-375f4e298784']
1676800614.602821 2023-02-19 17:56:54 INFO MainProcess-24960 MainThread-23204 __main__:18 <module> STARTING FL FUNCX TORCH FLOW...
1676800614.603822 2023-02-19 17:56:54 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:166 on_model_init No executor was provided, trying to retrieve the provided executor type funcx from the list of available executors: {'local': <class 'concurrent.futures.thread.ThreadPoolExecutor'>, 'funcx': <class 'funcx.sdk.executor.FuncXExecutor'>}
1676800614.604822 2023-02-19 17:56:54 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:170 on_model_init The selected executor is <class 'funcx.sdk.executor.FuncXExecutor'>
1676800616.099678 2023-02-19 17:56:56 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:209 on_model_broadcast Launching the <class 'fu

{'loss': 73.70644783973694, 'metrics': {'accuracy': 0.074}}


1676800658.246667 2023-02-19 17:57:38 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:215 on_model_broadcast Starting to broadcast a task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800660.193863 2023-02-19 17:57:40 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800691.254356 2023-02-19 17:58:11 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:437 tasks_to_running_average Starting to retrieve results from endpoints
1676800691.257351 2023-02-19 17:58:11 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:449 tasks_to_running_average the running average is NONE, instantiating it for the first time
1676800691.260353 2023-02-19 17:58:11 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:475 tasks_to_running_average Finished retrieving all results from the endpoints
1676800691.

{'loss': 73.67541027069092, 'metrics': {'accuracy': 0.12}}


1676800699.527335 2023-02-19 17:58:19 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:215 on_model_broadcast Starting to broadcast a task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800701.460689 2023-02-19 17:58:21 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:247 on_model_broadcast Deployed the task to endpoint fb93a1c2-a8d7-49f3-ad59-375f4e298784
1676800731.164322 2023-02-19 17:58:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:437 tasks_to_running_average Starting to retrieve results from endpoints
1676800731.165288 2023-02-19 17:58:51 DEBUG MainProcess-24960 MainThread-23204 flox.controllers.MainController:449 tasks_to_running_average the running average is NONE, instantiating it for the first time
1676800731.166315 2023-02-19 17:58:51 INFO MainProcess-24960 MainThread-23204 flox.controllers.MainController:475 tasks_to_running_average Finished retrieving all results from the endpoints
1676800731.

{'loss': 73.62477946281433, 'metrics': {'accuracy': 0.103}}
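
As a sketch of the idea behind aggregating results as they arrive, an incremental mean can be updated one model at a time without holding all weight sets in memory. The code below uses plain Python lists as stand-ins for weight tensors and is an assumption about the general technique, not flox's `tasks_to_running_average` implementation.

```python
def update_running_average(average, new_weights, count):
    """Fold the `count`-th set of weights into the running mean (count starts at 1)."""
    if average is None:
        # first result: the mean of one item is the item itself
        return list(new_weights)
    return [a + (w - a) / count for a, w in zip(average, new_weights)]


# Three hypothetical endpoints each return a tiny "model" of two weights
incoming = [[1.0, 10.0], [2.0, 20.0], [6.0, 30.0]]

average = None
for count, weights in enumerate(incoming, start=1):
    average = update_running_average(average, weights, count)

print(average)  # [3.0, 20.0]
```

The final value equals the ordinary mean of all three weight sets, but each update only needs the current average and the newest result.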
