# An Example of the Horizontal Federated Learning Task

This is an example of running a horizontal federated learning Delta Task on multiple Delta Nodes.

The data ([MNIST Dataset](http://yann.lecun.com/exdb/mnist/)) is distributed across several nodes, with each node holding only part of the dataset.
The task is to train a Convolutional Neural Network model to recognize handwritten digits.
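"Horizontal" means that every node holds different samples that share the same features (here, 28x28 pixel images). The sketch below is purely illustrative and not part of the Delta Task API: the helper `split_horizontally` is hypothetical, and zero-filled stand-in arrays replace the real MNIST data. It shows how one dataset could be partitioned row-wise into disjoint shards with NumPy.

```python
import numpy as np


def split_horizontally(x: np.ndarray, y: np.ndarray, num_nodes: int, seed: int = 0):
    """Shuffle the samples and split them row-wise into `num_nodes` disjoint shards.

    Every shard keeps the full feature space; only the set of samples differs
    between nodes, which is what "horizontal" partitioning means.
    """
    assert len(x) == len(y), "features and labels must have the same length"
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))            # random sample order
    shards = np.array_split(order, num_nodes)  # near-equal, disjoint index groups
    return [(x[idx], y[idx]) for idx in shards]


# Stand-in data with the MNIST image shape; real MNIST images would be loaded instead.
images = np.zeros((10, 28, 28), dtype=np.uint8)
labels = np.arange(10)
parts = split_horizontally(images, labels, num_nodes=3)
print([p[0].shape for p in parts])  # [(4, 28, 28), (3, 28, 28), (3, 28, 28)]
```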

This example can be executed directly in Deltaboard. <span style="color:#FF8F8F;font-weight:bold">Before hitting the run button, modify the Delta Node API address according to your configuration, as explained in section 4 below.</span>


## 1. Import the Required Packages

The computation logic is written in PyTorch, so we import ```numpy``` and ```torch```, along with some typing helpers. Then we import the Delta Task framework components from the Python package ```delta-task```: ```DeltaNode``` for connecting to the Delta Node API, ```HorizontalTask```, the base class of the task we'll run in this example, and ```FedAvg```, the aggregation algorithm the task uses:

In [None]:
from typing import Dict, Iterable, List, Tuple, Any, Union

import numpy as np
import torch

from delta import DeltaNode
from delta.task import HorizontalTask
from delta.algorithm.horizontal import FedAvg

## 2. Define the Neural Network Model

Now let's define the CNN model. It is written exactly as it would be for ordinary, non-federated training in PyTorch:

In [None]:
class LeNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5, padding=2)
        self.pool1 = torch.nn.AvgPool2d(2, stride=2)
        self.conv2 = torch.nn.Conv2d(16, 16, 5)
        self.pool2 = torch.nn.AvgPool2d(2, stride=2)
        self.dense1 = torch.nn.Linear(400, 100)
        self.dense2 = torch.nn.Linear(100, 10)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool2(x)

        x = x.view(-1, 400)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        return x
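The `400` passed to `dense1` and to `x.view(-1, 400)` is the flattened size of the feature map after the second pooling layer. A quick arithmetic check, using the standard output-size formula for convolution and pooling layers (dilation 1, no ceil mode), confirms it for a 28x28 input:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a Conv2d / AvgPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


s = 28                           # MNIST images are 28x28
s = conv_out(s, 5, padding=2)    # conv1: 28 -> 28
s = conv_out(s, 2, stride=2)     # pool1: 28 -> 14
s = conv_out(s, 5)               # conv2: 14 -> 10
s = conv_out(s, 2, stride=2)     # pool2: 10 -> 5
flat = 16 * s * s                # 16 channels of 5x5 feature maps
print(flat)  # 400, the in_features of dense1
```

If the input size or any kernel, stride or padding changes, recomputing `flat` this way shows the new value that `dense1` and the `view` call need.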

## 3. Define the Horizontal Federated Learning Task

The next step is to define our horizontal federated learning task to train the above model on multiple nodes.

There are several parts of the PPC Task that need to be implemented by the developer:

* ***Model Training Method***: Which loss function and optimizer to use, and how to perform the training steps.
* ***Data Pre-process Method***: Before training, the ```preprocess``` method can be used to transform each sample of the training data. For a detailed explanation of the arguments, please refer to [this document](https://docs.deltampc.com/network-deployment/prepare-data).
* ***Model Validation Method***: How to calculate validation scores (here, the loss and the fraction of correctly classified samples) every ```validate_interval``` rounds.
* ***Horizontal Federated Learning Config***: The minimum/maximum number of nodes required in each round, the maximum number of rounds, etc.
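The ```preprocess``` method in the task below rescales each MNIST image from [0, 255] to [-1, 1] and adds a channel axis. The standalone NumPy sketch below reproduces that transform outside the task (the helper name `scale_to_unit_range` is ours, for illustration). It also shows why casting to `float32` first is a good idea: in-place division on a `uint8` array raises an error, and `float64` input would not match the model's `float32` weights.

```python
import numpy as np


def scale_to_unit_range(pixels: np.ndarray) -> np.ndarray:
    """Map 8-bit pixel values in [0, 255] to float32 values in [-1, 1]."""
    x = np.asarray(pixels, dtype=np.float32) / 255.0  # [0, 255] -> [0, 1]
    return x * 2 - 1                                  # [0, 1]   -> [-1, 1]


sample = np.array([[0, 128, 255]], dtype=np.uint8)
print(scale_to_unit_range(sample))  # first value is -1.0, last value is 1.0

# A full 28x28 image, reshaped to (channels, height, width) as Conv2d expects
image = np.full((28, 28), 255, dtype=np.uint8)
channel_first = scale_to_unit_range(image).reshape((1, 28, 28))
```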


In [None]:
class ExampleTask(HorizontalTask):
    def __init__(self):
        super().__init__(
            name="example",  # The task name, used for display purposes.
            dataset="mnist",  # The file/folder name of the dataset. The file/folder should be placed under the data folder of every Delta Node.
            max_rounds=2,  # The total number of training rounds. In every round, each node computes its own partial result and submits it to the server.
            validate_interval=1,  # The number of rounds between two validation scores.
            validate_frac=0.1,  # The fraction of samples in the whole dataset used as the validation set, in the range (0, 1).
        )

        # Instantiate the NN model we just defined
        self.model = LeNet()
        
        # Define the loss function
        self.loss_func = torch.nn.CrossEntropyLoss()
        
        # Define the optimizer
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=1e-3,
            nesterov=True,
        )

    def preprocess(self, x, y=None):
        """
        The data pre-processing method.
        After data loading, every sample is passed through this method to be transformed.
        For a detailed explanation of the input arguments, please refer to https://docs.deltampc.com/network-deployment/prepare-data
        x: a sample from the dataset; its type depends on the data provided.
        y: the label of the sample, None if no label is attached to the sample.
        return: the data and label after processing; the type should be torch.Tensor or np.ndarray
        """
        # Cast to float32 first: in-place division would fail on an integer array,
        # and the model's weights are float32
        x = np.asarray(x, dtype=np.float32)
        x = x / 255.0 * 2 - 1  # [0, 255] -> [-1, 1], without modifying the input array in place
        x = x.reshape((1, 28, 28))  # add the channel axis expected by Conv2d
        return torch.from_numpy(x), torch.tensor(int(y), dtype=torch.long)

    def train(self, dataloader: Iterable):
        """
        The training step definition.
        dataloader: the dataloader corresponding to the dataset.
        return: None
        """
        for batch in dataloader:
            x, y = batch
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def validate(self, dataloader: Iterable) -> Dict[str, float]:
        """
        Validation method.
        Calculates validation scores on each node after training.
        The result also goes through secure aggregation before being sent back to the server.
        dataloader: the dataloader corresponding to the dataset.
        return: Dict[str, float], a dictionary in which each key (str) is a score's name and each value (float) is the score's value.
        """
        total_loss = 0.0
        count = 0
        ys = []
        y_s = []
        # Gradients are not needed for validation; disabling them saves memory and time
        with torch.no_grad():
            for batch in dataloader:
                x, y = batch
                y_pred = self.model(x)
                loss = self.loss_func(y_pred, y)
                total_loss += loss.item()
                count += 1

                y_ = torch.argmax(y_pred, dim=1)
                y_s.extend(y_.tolist())
                ys.extend(y.tolist())
        avg_loss = total_loss / count  # mean of the per-batch losses
        # Fraction of correctly classified samples (i.e. accuracy), reported under the name "precision"
        tp = sum(1 for label, pred in zip(ys, y_s) if label == pred)
        precision = tp / len(ys)

        return {"loss": avg_loss, "precision": precision}

    def get_params(self) -> List[torch.Tensor]:
        """
        The params that need to be trained.
        Only the params returned by this function will be updated and saved during aggregation.
        return: List[torch.Tensor], the list of model params.
        """
        return list(self.model.parameters())

    def algorithm(self):
        """
        The algorithm used to perform result aggregation. All available algorithms are in the package delta.algorithm.horizontal
        """
        return FedAvg(
            merge_interval_epoch=0,  # The number of epochs to run before aggregation is performed.
            merge_interval_iter=20,  # The number of iterations to run before aggregation is performed. One of this and merge_interval_epoch must be 0.
            wait_timeout=10,  # Timeout when waiting for nodes to participate.
            connection_timeout=10,  # Connection timeout for each communication in the aggregation algorithm.
            min_clients=2,  # Minimum number of nodes required in each round.
            max_clients=2,  # Maximum number of nodes allowed in each round.
        )

    def dataloader_config(
        self,
    ) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]:
        """
        The configs for the training and validation dataloaders.
        Each config is a dictionary corresponding to the arguments of PyTorch's DataLoader.
        The details are in https://pytorch.org/docs/stable/data.html
        return: One or two Dict[str, Any]. When a single dict is returned, it is used for both the training and the validation dataloader.
        """
        train_config = {"batch_size": 64, "shuffle": True, "drop_last": True}
        val_config = {"batch_size": 64, "shuffle": False, "drop_last": False}
        return train_config, val_config
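Conceptually, FedAvg with ```merge_interval_iter=20``` lets each node run 20 local training iterations and then averages the nodes' parameters (the ones returned by ```get_params```). The in-process NumPy sketch below illustrates that idea on a toy linear-regression problem with two simulated nodes. It is not Delta's implementation: the real ```FedAvg``` also handles node coordination, timeouts and secure aggregation, and weighting each node by its local sample count is our assumption, following the classic FedAvg formulation. The functions `fed_avg` and `local_sgd` are hypothetical.

```python
from typing import List, Sequence

import numpy as np


def fed_avg(client_params: Sequence[List[np.ndarray]], client_sizes: Sequence[int]) -> List[np.ndarray]:
    """Average each parameter array across clients, weighted by local sample count."""
    total = float(sum(client_sizes))
    weights = [n / total for n in client_sizes]
    averaged = []
    for tensors in zip(*client_params):  # the i-th parameter from every client
        averaged.append(sum(weight * t for weight, t in zip(weights, tensors)))
    return averaged


def local_sgd(w: np.ndarray, x: np.ndarray, y: np.ndarray, steps: int, lr: float = 0.1) -> np.ndarray:
    """Run `steps` full-batch gradient steps on the mean squared error, starting from w."""
    for _ in range(steps):
        grad = 2 * x.T @ (x @ w - y) / len(x)
        w = w - lr * grad
    return w


# Two simulated nodes with different amounts of (noise-free) data from the same model
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (60, 40):
    features = rng.normal(size=(n, 2))
    clients.append((features, features @ true_w))

w = np.zeros(2)  # the global model held by the server
for round_idx in range(20):
    # Each node starts from the current global model and trains locally
    updates = [[local_sgd(w, x, y, steps=20)] for x, y in clients]
    w = fed_avg(updates, [len(x) for x, _ in clients])[0]
print(w)  # close to [2, -1]
```

Because both simulated nodes see noise-free data generated from the same weights, the averaged model converges to ```true_w```. With real, non-identically distributed data, longer local training pulls the nodes' models apart, so the merge interval trades communication cost against that drift.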


## 4. Set the API Address of the Delta Node

After defining the task details, we're ready to run the task on the Delta Nodes.

The Delta Task framework can send the task to a Delta Node directly, as long as the Delta Node API address is specified.

Here we use the Delta Node API provided by Deltaboard. Deltaboard provides a separate API address for each user, and the tasks submitted through that address are listed inside Deltaboard. Developers can also use the API of a Delta Node directly.

Click "Profiles" on the Deltaboard sidebar, copy the API address in the Deltaboard API section, and paste it here:

In [None]:
DELTA_NODE_API = "http://127.0.0.1:6704"
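Instead of editing the notebook for every environment, the address could also be read from an environment variable. This is a minimal sketch under our own conventions: the variable name `DELTA_NODE_API`, the helper `get_delta_node_api` and the scheme check are assumptions, not something the ```delta-task``` package reads or enforces itself.

```python
import os


def get_delta_node_api(default: str = "http://127.0.0.1:6704") -> str:
    """Return the Delta Node API address from the DELTA_NODE_API environment variable.

    DELTA_NODE_API is a hypothetical variable name of our own choosing; the
    delta-task package does not read it. Falls back to `default` when unset.
    """
    address = os.environ.get("DELTA_NODE_API", default).strip()
    # Basic sanity check (our assumption): the address includes a scheme, as in the example above
    if not address.startswith(("http://", "https://")):
        raise ValueError(f"expected an http:// or https:// address, got {address!r}")
    return address


DELTA_NODE_API = get_delta_node_api()
```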

## 5. Run the PPC Task

Finally we can start the task:

In [None]:
if __name__ == "__main__":
    task = ExampleTask()

    delta_node = DeltaNode(DELTA_NODE_API)
    delta_node.create_task(task)

## 6. Check the Running Status

After clicking the run button, some logs will be printed, showing that the task has been submitted to the Delta Node successfully.

To see the task execution details, go to "My Tasks" on the Deltaboard sidebar, where the task should be listed.
Click the item to view the execution logs.