# FLSim Tutorial: Adding a custom communication channel

### Introduction

In this tutorial, you will learn how to implement a custom channel through which the server and the clients communicate. A custom channel is useful to simulate various real-life scenarios, for instance compressing the model updates on the client side before sending them to the server, which decompresses them before aggregating them.

### Prerequisites

To get the most out of this tutorial, you should be comfortable training machine learning models with FLSim. If you are not familiar with standard compression techniques, please take a look at the following resources:
- [PyTorch intro to Scalar Quantization](https://pytorch.org/docs/stable/quantization.html) (in particular `int8`).
- Some papers on quantization in the context of Federated Learning (FL) or distributed server-side training: [Quantized-SGD](https://arxiv.org/abs/1610.02132), [Atomo](https://proceedings.neurips.cc/paper/2018/file/33b3214d792caf311e1f00fd22b392c5-Paper.pdf), [Power-SGD](https://proceedings.neurips.cc/paper/2019/file/d9fbed9da256e344c1fa46bb46c34c5f-Paper.pdf).

Now that you're familiar with FLSim and compression, let's move on!

### Objectives

In this tutorial, you will learn:
- The structure of a [`Message`](https://github.com/facebookresearch/FLSim/blob/main/channels/message.py#L15), used to communicate any relevant information between the server and the clients.
- The API of a channel, using the [`IdentityChannel`](https://github.com/facebookresearch/FLSim/blob/main/channels/base_channel.py#L59) as an example. It implements a pass-through that leaves the underlying message unmodified.
- How to implement a custom channel, using the [`ScalarQuantizationChannel`](https://github.com/facebookresearch/FLSim/blob/main/channels/scalar_quantization_channel.py#L23) as an example.
- How to train in FLSim with your custom channel.
- How to measure the number of bytes sent from the client to the server and from the server to the client.

# 1 - The `Message` dataclass

Before diving into the channel component, what is a [`Message`](https://github.com/facebookresearch/FLSim/blob/main/channels/message.py#L15)?

```python
from flsim.channels.message import Message
from flsim.tests import utils


net = utils.SampleNet(utils.TwoFC())
message = Message(net)
```

Simply put, a message contains:
- The model (`nn.Module`) that is being trained.
- Any meta information such as the weight used when aggregating updates from multiple clients.

Sometimes (see scalar quantization below), it is easier for the channel to work on the `state_dict()` of the model. Hence, a message can be populated with the state dict of the model used to instantiate it. Conversely, after the channel has manipulated the state dict, the model attribute of the message can be updated from it. It is the responsibility of the user to make sure that, when the state dict is populated, it coincides with the state dict of the model.

```python
print("Before populating", message.model_state_dict)
message.populate_state_dict()
print("After populating", message.model_state_dict)
```

```
Before populating OrderedDict()
After populating OrderedDict([('fc1.weight', tensor([[ 0.5033,  0.6609],
        [ 0.3683, -0.1048],
        [ 0.4713,  0.4913],
        [-0.0736,  0.3803],
        [-0.1871,  0.0246]])), ('fc1.bias', tensor([-0.5928, -0.0246, -0.0602,  0.5279, -0.0213])), ('fc2.weight', tensor([[ 0.2424, -0.0060, -0.3095,  0.4453,  0.2060]])), ('fc2.bias', tensor([0.1402]))])
```


```python
message.model_state_dict["fc1.weight"].fill_(0)
print("Before updating model with state dict", message.model.sample_nn.fc1.weight)
message.update_model_()
print("After updating model", message.model.sample_nn.fc1.weight)
```

```
Before updating model with state dict Parameter containing:
tensor([[ 0.5033,  0.6609],
        [ 0.3683, -0.1048],
        [ 0.4713,  0.4913],
        [-0.0736,  0.3803],
        [-0.1871,  0.0246]], requires_grad=True)
After updating model Parameter containing:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], requires_grad=True)
```
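The contract between a message, its model, and its state dict can be summarized with a pure-Python sketch. `ToyModel` and `ToyMessage` below are hypothetical stand-ins (plain lists instead of tensors, and a hypothetical `weight` field for the meta information), not FLSim classes. They only mimic the behavior shown above: populating stores a *copy* of the model's state dict, which is why zeroing `fc1.weight` in the state dict left the model untouched until `update_model_()` was called.

```python
import copy
from collections import OrderedDict
from dataclasses import dataclass, field


class ToyModel:
    """Hypothetical stand-in for an nn.Module, with plain lists as parameters."""

    def __init__(self):
        self.params = OrderedDict(
            [("fc1.weight", [[0.5, 0.6], [0.3, -0.1]]), ("fc1.bias", [-0.5, 0.1])]
        )

    def state_dict(self):
        return copy.deepcopy(self.params)

    def load_state_dict(self, state_dict):
        self.params = copy.deepcopy(state_dict)


@dataclass
class ToyMessage:
    """Hypothetical sketch of the Message contract, not FLSim's implementation."""

    model: ToyModel
    weight: float = 1.0  # hypothetical meta information, e.g. an aggregation weight
    model_state_dict: OrderedDict = field(default_factory=OrderedDict)

    def populate_state_dict(self):
        # Work on a copy: edits to the state dict leave the model untouched.
        self.model_state_dict = self.model.state_dict()

    def update_model_(self):
        # Push the (possibly modified) state dict back into the model.
        self.model.load_state_dict(self.model_state_dict)


message = ToyMessage(ToyModel())
message.populate_state_dict()
message.model_state_dict["fc1.weight"] = [[0.0, 0.0], [0.0, 0.0]]
print(message.model.params["fc1.weight"])  # still the original values
message.update_model_()
print(message.model.params["fc1.weight"])  # now all zeros
```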


# 2 - The Identity Channel

Open the file [`channels/base_channel.py`](https://github.com/facebookresearch/FLSim/blob/main/channels/base_channel.py). The public API of `IdentityChannel` consists of two main methods:
- [`client_to_server`](https://github.com/facebookresearch/FLSim/blob/main/channels/base_channel.py#L162), which performs three successive steps to send a message from a client to the server.
- [`server_to_client`](https://github.com/facebookresearch/FLSim/blob/main/channels/base_channel.py#L173), which performs three successive steps to send a message from the server to a client.


```python
from flsim.channels.base_channel import FLChannelConfig
from hydra.utils import instantiate


config = FLChannelConfig()
identity_channel = instantiate(config)
```

Let's verify that the channel implements a pass-through, *i.e.* that the channel does not modify the message it transmits.

```python
net = utils.SampleNet(utils.TwoFC())
message_before = Message(net)

message_after = identity_channel.client_to_server(message_before)
assert message_after == message_before

message_after = identity_channel.server_to_client(message_before)
assert message_after == message_before
```

# 3 - Implement your own channel

Now, let's dive deeper into the channel API. Under the hood, calling the [`client_to_server`](https://github.com/facebookresearch/FLSim/blob/main/channels/base_channel.py#L162) method of `IdentityChannel` runs the following three steps:

```python
message = self._on_client_before_transmission(message)
message = self._during_transmission_client_to_server(message)
message = self._on_server_after_reception(message)
```

Since every channel inherits from `IdentityChannel`, *we only need to override the parts that change*. For instance, a channel that keeps the identity in the server->client direction does not need to override the three functions called by the `server_to_client` method.
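This structure is the template method pattern: the public methods fix the order of the three steps, and subclasses override only the hooks they need. Below is a minimal pure-Python sketch of the idea. `SketchIdentityChannel` and `SketchFixedPointChannel` are hypothetical names, the message is a plain dict, the server->client hook names are assumed to mirror the client->server ones, and the "compression" (encoding floats as hundredths) is a toy placeholder rather than a real codec.

```python
class SketchIdentityChannel:
    """Hypothetical pass-through channel with the same three-step structure."""

    def client_to_server(self, message):
        message = self._on_client_before_transmission(message)
        message = self._during_transmission_client_to_server(message)
        message = self._on_server_after_reception(message)
        return message

    def server_to_client(self, message):
        message = self._on_server_before_transmission(message)
        message = self._during_transmission_server_to_client(message)
        message = self._on_client_after_reception(message)
        return message

    # Every hook defaults to the identity.
    def _on_client_before_transmission(self, message):
        return message

    def _during_transmission_client_to_server(self, message):
        return message

    def _on_server_after_reception(self, message):
        return message

    def _on_server_before_transmission(self, message):
        return message

    def _during_transmission_server_to_client(self, message):
        return message

    def _on_client_after_reception(self, message):
        return message


class SketchFixedPointChannel(SketchIdentityChannel):
    """Hypothetical lossy channel: overrides only the client->server hooks."""

    def _on_client_before_transmission(self, message):
        # Encode floats as integers (hundredths), losing some precision.
        return {name: [round(v * 100) for v in vals] for name, vals in message.items()}

    def _on_server_after_reception(self, message):
        # Decode the integers back to floats on the server side.
        return {name: [q / 100 for q in vals] for name, vals in message.items()}


channel = SketchFixedPointChannel()
update = {"fc1.weight": [0.1234, -0.5678]}
print(channel.client_to_server(update))  # {'fc1.weight': [0.12, -0.57]}
print(channel.server_to_client(update) is update)  # True: identity on the way back
```

Note that `SketchFixedPointChannel` never touches `client_to_server` itself: the order of the steps stays defined in one place, and the subclass only fills in the encoding and decoding.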


Let's break these three steps down using the concrete example of the [`ScalarQuantizationChannel`](https://github.com/facebookresearch/FLSim/blob/main/channels/scalar_quantization_channel.py#L23). The goal of this channel is to compress the model state dict *only* in the client->server direction and to implement the identity in the server->client direction.

```python
from flsim.channels.scalar_quantization_channel import ScalarQuantizationChannelConfig
from hydra.utils import instantiate


config = ScalarQuantizationChannelConfig(n_bits=8, report_communication_metrics=True)
sq_channel = instantiate(config)
```

First, let's print the original network weights:

```python
net = utils.SampleNet(utils.TwoFC())
message = Message(net)

print(message.model.fl_get_module().state_dict()["fc1.weight"])
```

```
tensor([[-0.0155,  0.6545],
        [-0.4280, -0.3730],
        [ 0.0649, -0.0991],
        [ 0.1245, -0.5659],
        [ 0.3703,  0.4525]])
```


**`_on_client_before_transmission`**: here we need to:
1. Populate the state dict of the message (see part 1 of this tutorial).
2. Quantize all parameters except the biases, which are too small for compression to pay off.

```python
# Quantize over 8 bits: the state dict now holds PyTorch quantized tensors.
# Note that the weights have lost some resolution, which is expected.
message = sq_channel._on_client_before_transmission(message)
print(message.model_state_dict["fc1.weight"])
```

```
tensor([[-0.0144,  0.6557],
        [-0.4260, -0.3733],
        [ 0.0670, -0.1005],
        [ 0.1244, -0.5648],
        [ 0.3685,  0.4547]], size=(5, 2), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.004786025732755661,
       zero_point=118)
```
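The `scale` and `zero_point` printed above define an affine mapping between floats and 8-bit integers: `q = clamp(round(x / scale) + zero_point, 0, 2**n_bits - 1)`, and back, `x_hat = (q - zero_point) * scale`. The pure-Python sketch below assumes the scale and zero point are computed from the tensor's minimum and maximum, as PyTorch's default `MinMaxObserver` does. That FLSim computes them exactly this way is our assumption, but applied to the original weights printed above it yields a zero point of 118 and a scale of about 0.00479, consistent with the output (the small difference in scale comes from the 4-decimal rounding of the printed weights). The helper names are hypothetical.

```python
def quantize_per_tensor(values, n_bits=8):
    """Affine per-tensor quantization with min/max calibration (a sketch)."""
    qmin, qmax = 0, 2**n_bits - 1
    # Extend the range to include 0 so that 0.0 is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = max((hi - lo) / (qmax - qmin), 1e-8)
    zero_point = min(max(qmin - round(lo / scale), qmin), qmax)
    q = [min(max(round(x / scale) + zero_point, qmin), qmax) for x in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map the integers back to floats."""
    return [(qi - zero_point) * scale for qi in q]


# Original fc1.weight values printed above, flattened row by row.
weights = [-0.0155, 0.6545, -0.4280, -0.3730, 0.0649, -0.0991,
           0.1245, -0.5659, 0.3703, 0.4525]
q, scale, zero_point = quantize_per_tensor(weights, n_bits=8)
print(zero_point, round(scale, 5))  # 118 0.00479
print([round(x, 4) for x in dequantize(q, scale, zero_point)][:2])  # [-0.0144, 0.6557]
```

Since every value is rounded to the nearest integer step, the reconstruction error per value is at most `scale / 2`, which is why the quantized weights above differ from the originals only around the third decimal.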


**`_during_transmission_client_to_server`**: here we can optionally measure the size of the (compressed) message; more on this in part 4 of this tutorial.

```python
from flsim.channels.communication_stats import ChannelDirection


message = sq_channel._during_transmission_client_to_server(message)


direction = ChannelDirection.CLIENT_TO_SERVER
stats = sq_channel.stats_collector.get_channel_stats()[direction]
print(f"Number of bytes sent from client to server: {stats.mean()}")
```

```
Number of bytes sent from client to server: 63.0
```
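The `stats.mean()` call above reflects that sizes are aggregated over transmissions: in a real run, the channel presumably carries one message per client and per round. A minimal pure-Python sketch of such a per-direction collector follows. `Direction` and `SketchStatsCollector` are hypothetical stand-ins, not the actual `ChannelDirection` and `StatsCollector` classes, whose internals may differ.

```python
from collections import defaultdict
from enum import Enum
from statistics import mean


class Direction(Enum):
    """Hypothetical stand-in for FLSim's ChannelDirection."""

    CLIENT_TO_SERVER = "client_to_server"
    SERVER_TO_CLIENT = "server_to_client"


class SketchStatsCollector:
    """Hypothetical collector: records one size per transmission and direction."""

    def __init__(self):
        self._sizes = defaultdict(list)

    def collect(self, direction, size_bytes):
        self._sizes[direction].append(size_bytes)

    def mean(self, direction):
        return mean(self._sizes[direction])


collector = SketchStatsCollector()
# For example, three clients each send a 63-byte compressed update...
for _ in range(3):
    collector.collect(Direction.CLIENT_TO_SERVER, 63.0)
# ...and the server sends one uncompressed 84-byte model.
collector.collect(Direction.SERVER_TO_CLIENT, 84.0)
print(collector.mean(Direction.CLIENT_TO_SERVER))  # 63.0
```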


**`_on_server_after_reception`**: here we need to:
1. Dequantize all the parameters that were quantized.
2. Update `message.model` to match the dequantized state dict (see part 1 of this tutorial).

```python
# We are back to PyTorch fp32 tensors.
message = sq_channel._on_server_after_reception(message)
print(message.model_state_dict["fc1.weight"])
```

```
tensor([[-0.0144,  0.6557],
        [-0.4260, -0.3733],
        [ 0.0670, -0.1005],
        [ 0.1244, -0.5648],
        [ 0.3685,  0.4547]])
```


# 4 - Measure message size

The message size is gathered through a [`StatsCollector`](https://github.com/facebookresearch/FLSim/blob/main/channels/communication_stats.py#L14) for both directions (client to server and server to client). Let's check that the server->client message is larger than the client->server message, since only the latter is compressed.

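Note that the compression ratio here is far from 4x (84 bytes versus 63 bytes, about 1.3x): besides the 8-bit weights, the client also transmits the per-tensor scales and zero points, and the biases are sent uncompressed. The weight matrices of this toy network are so small that this overhead dominates; for larger networks, the ratio would be close to 4x.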
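The claim can be checked with back-of-the-envelope arithmetic. The sketch below counts only the payload: 4 bytes per fp32 value, `n_bits / 8` bytes per quantized weight, and fp32 biases. It deliberately leaves out the scales and zero points, whose exact byte count in FLSim's `_calc_message_size_client_to_server` we do not reproduce. The `TwoFC` shapes (`fc1`: 5x2, `fc2`: 1x5) are read off the state dict printed in part 1, and the helper names are hypothetical.

```python
import math

BYTES_PER_FP32 = 4


def fp32_bytes(shapes):
    """Size of an uncompressed state dict, in bytes."""
    return sum(math.prod(s) * BYTES_PER_FP32 for s in shapes.values())


def quantized_payload_bytes(shapes, n_bits=8):
    """Quantized weights (ndim > 1) plus fp32 biases; scales/zero points excluded."""
    total = 0.0
    for shape in shapes.values():
        if len(shape) > 1:
            total += math.prod(shape) * n_bits / 8
        else:
            total += math.prod(shape) * BYTES_PER_FP32
    return total


two_fc = {"fc1.weight": (5, 2), "fc1.bias": (5,), "fc2.weight": (1, 5), "fc2.bias": (1,)}
print(fp32_bytes(two_fc))               # 84
print(quantized_payload_bytes(two_fc))  # 39.0

# A larger layer: the biases (and per-tensor metadata) become negligible.
big = {"fc.weight": (1024, 1024), "fc.bias": (1024,)}
print(fp32_bytes(big) / quantized_payload_bytes(big))  # close to 4
```

The fp32 size matches the 84 bytes measured for the server->client direction. Assuming FLSim counts the payload the same way, the remaining 63 - 39 = 24 bytes of the client->server message would correspond to the scales and zero points of the two weight matrices, a large share of such a small message.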

```python
from flsim.channels.communication_stats import ChannelDirection


# We need to send the message at least once to measure its size.
message = sq_channel.server_to_client(message)


direction = ChannelDirection.SERVER_TO_CLIENT
stats = sq_channel.stats_collector.get_channel_stats()[direction]
print(f"Number of bytes sent from server to client: {stats.mean()}")
```

```
Number of bytes sent from server to client: 84.0
```


To tailor the measurement of the message size to your custom channel, override the [`_calc_message_size_client_to_server`](https://github.com/facebookresearch/FLSim/blob/main/channels/scalar_quantization_channel.py#L68) method. See its implementation in [`ScalarQuantizationChannel`](https://github.com/facebookresearch/FLSim/blob/main/channels/scalar_quantization_channel.py#L68) for an example!

# 5 - Training in FLSim with a custom channel

To train with your custom channel, specify its config under the `channel` key of the trainer config, as in the example below. Then, follow the FLSim training tutorial.

```python
json_config = {
    "trainer": {
        "_base_": "base_sync_trainer",
        # there are different types of aggregators:
        # FedAvg does not require a learning rate, while others such as FedSGD and FedAdam do
        "aggregator": {"_base_": "base_fed_avg_sync_aggregator"},
        "client": {
            # number of client's local epochs
            "epochs": 1,
            "optimizer": {
                "_base_": "base_optimizer_sgd",
                # client's local learning rate
                "lr": 0.01,
                # client's local momentum
                "momentum": 0.9,
            },
        },
        # insert here your favourite channel along with its config!
        "channel": {
            "_base_": "base_scalar_quantization_channel",
            "n_bits": 8,
            "quantize_per_tensor": True,
        },
        # type of user selection sampling
        "active_user_selector": {"_base_": "base_sequential_active_user_selector"},
        # number of users per round for aggregation
        "users_per_round": 5,
        # total number of global epochs
        # total #rounds = ceil(total_users / users_per_round) * epochs
        "epochs": 1,
        # frequency of reporting train metrics
        "train_metrics_reported_per_epoch": 10,
        # frequency of evaluation per epoch
        "eval_epoch_frequency": 1,
        "do_eval": True,
        # should we report train metrics after global aggregation
        "report_train_metrics_after_aggregation": True,
    }
}
```
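To compare several configurations of the same channel, for instance different bit widths, you can derive variants of the config programmatically. The sketch below uses a trimmed-down stand-in for `json_config` (only the `channel` entry matters here) and a hypothetical helper `with_channel`. It only builds the dictionaries: training still follows the FLSim training tutorial, and whether a given `n_bits` value is supported depends on the channel implementation.

```python
import copy


def with_channel(config, channel_config):
    """Return a deep copy of `config` whose trainer uses `channel_config`."""
    new_config = copy.deepcopy(config)
    new_config["trainer"]["channel"] = dict(channel_config)
    return new_config


# Trimmed-down stand-in for the json_config above.
base_config = {
    "trainer": {
        "_base_": "base_sync_trainer",
        "users_per_round": 5,
        "channel": {"_base_": "base_scalar_quantization_channel", "n_bits": 8},
    }
}

# One config per bit width; the deep copy leaves base_config untouched.
sweep = {
    n_bits: with_channel(
        base_config,
        {
            "_base_": "base_scalar_quantization_channel",
            "n_bits": n_bits,
            "quantize_per_tensor": True,
        },
    )
    for n_bits in (4, 8)
}
print(sweep[4]["trainer"]["channel"]["n_bits"])  # 4
```

Deep-copying matters here: a shallow copy would share the nested `trainer` dict, so every variant would end up with the last channel written.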