# Model-Centric Federated Learning for Mobile - MNIST Example

This notebook will walk you through creating a simple model and a training plan, and hosting both as a federated learning process
for further training using OpenMined mobile FL workers.

This notebook is similar to "[MCFL - Create Plan](mcfl_create_plan.ipynb)";
however, due to mobile limitations, the training plan is different.

In [None]:
# stdlib
import base64
import json

# third party
import torch as th

# syft absolute
import syft as sy
from syft.core.plan.plan_builder import ROOT_CLIENT
from syft.core.plan.plan_builder import PLAN_BUILDER_VM
from syft.core.plan.plan_builder import make_plan
from syft.core.plan.translation.torchscript.plan_translate import (
    translate as translate_to_ts,
)
from syft.federated.model_centric_fl_client import ModelCentricFLClient
from syft.lib.python.int import Int
from syft.lib.python.list import List

In [None]:
th.random.manual_seed(42)

## Step 1: Define the model

This model will train on MNIST data. It's very simple, yet it can demonstrate the learning process.
It has two linear layers with a ReLU activation between them:

* Linear 784x100
* ReLU
* Linear 100x10 
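
As a quick check of the sizes involved, the sketch below lists the parameter tensors this architecture produces and counts them. It assumes PyTorch's `nn.Linear` convention of storing `weight` as `(out_features, in_features)` and registering `weight` before `bias`; ReLU has no parameters. This positional order matters later, because the training plan's `params` list, the gradients returned by `backward`, and the diffs are all matched up by position.

```python
import math

# (name, in_features, out_features) for the two linear layers of the MLP.
LAYERS = [("fc1", 784, 100), ("fc2", 100, 10)]


def parameter_shapes(layers):
    """Return (name, shape) pairs in nn.Linear registration order: weight, then bias."""
    shapes = []
    for name, n_in, n_out in layers:
        shapes.append((f"{name}.weight", (n_out, n_in)))
        shapes.append((f"{name}.bias", (n_out,)))
    return shapes


shapes = parameter_shapes(LAYERS)
total = sum(math.prod(shape) for _, shape in shapes)
for name, shape in shapes:
    print(name, shape)
print("total parameters:", total)
```

For this model the total is 79,510 parameters (78,400 + 100 + 1,000 + 10).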

Note that the model contains additional convenience methods, which call torch through the model's `torch_ref` reference:

* `backward` - calculates the backward-pass gradients by hand, because autograd doesn't work on mobile (yet)
* `softmax_cross_entropy_with_logits` - computes the loss and its gradient with respect to the logits
* `accuracy` - calculates the accuracy of the predictions

In [None]:
class MLP(sy.Module):
    """
    Simple model with method for loss and hand-written backprop.
    """

    def __init__(self, torch_ref) -> None:
        super(MLP, self).__init__(torch_ref=torch_ref)
        self.fc1 = torch_ref.nn.Linear(784, 100)
        self.relu = torch_ref.nn.ReLU()
        self.fc2 = torch_ref.nn.Linear(100, 10)

    def forward(self, x):
        self.z1 = self.fc1(x)
        self.a1 = self.relu(self.z1)
        return self.fc2(self.a1)

    def backward(self, X, error):
        z1_grad = (error @ self.fc2.state_dict()["weight"]) * (self.a1 > 0).float()
        fc1_weight_grad = z1_grad.t() @ X
        fc1_bias_grad = z1_grad.sum(0)
        fc2_weight_grad = error.t() @ self.a1
        fc2_bias_grad = error.sum(0)
        return fc1_weight_grad, fc1_bias_grad, fc2_weight_grad, fc2_bias_grad

    def softmax_cross_entropy_with_logits(self, logits, target, batch_size):
        probs = self.torch_ref.softmax(logits, dim=1)
        loss = -(target * self.torch_ref.log(probs)).sum(dim=1).mean()
        loss_grad = (probs - target) / batch_size
        return loss, loss_grad

    def accuracy(self, logits, targets, batch_size):
        pred = self.torch_ref.argmax(logits, dim=1)
        targets_idx = self.torch_ref.argmax(targets, dim=1)
        acc = pred.eq(targets_idx).sum().float() / batch_size
        return acc
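
Because `backward` is written by hand rather than derived by autograd, it is worth checking numerically. The sketch below re-implements the same forward pass, loss gradient, and backward formulas in NumPy (an assumption for illustration; the notebook itself uses torch) on a deliberately tiny network, then compares the analytic gradients with central finite differences. Differences close to zero suggest the formulas are consistent. Note that the loss is averaged over the batch, so the check relies on the batch-size divisor matching the actual number of rows.

```python
import numpy as np

# Hypothetical tiny sizes so the finite-difference loop is fast;
# the notebook's MLP uses 784 -> 100 -> 10.
rng = np.random.default_rng(0)
n_in, n_hidden, n_out, bs = 5, 4, 3, 6

params = {
    "W1": rng.normal(size=(n_hidden, n_in)),
    "b1": rng.normal(size=n_hidden),
    "W2": rng.normal(size=(n_out, n_hidden)),
    "b2": rng.normal(size=n_out),
}
X = rng.normal(size=(bs, n_in))
Y = np.eye(n_out)[rng.integers(0, n_out, size=bs)]  # one-hot targets


def forward(p, X):
    z1 = X @ p["W1"].T + p["b1"]  # nn.Linear computes x @ W.T + b
    a1 = np.maximum(z1, 0.0)  # ReLU
    return a1, a1 @ p["W2"].T + p["b2"]


def loss_and_grad(logits, Y):
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    loss = -(Y * np.log(probs)).sum(axis=1).mean()
    return loss, (probs - Y) / len(Y)  # same form as loss_grad in the notebook


def backward(p, X, a1, error):
    # Same formulas as MLP.backward above.
    z1_grad = (error @ p["W2"]) * (a1 > 0)
    return {
        "W1": z1_grad.T @ X,
        "b1": z1_grad.sum(0),
        "W2": error.T @ a1,
        "b2": error.sum(0),
    }


def numeric_grad(p, name, eps=1e-6):
    """Central finite differences of the loss with respect to one parameter array."""
    grad = np.zeros_like(p[name])
    for idx in np.ndindex(p[name].shape):
        original = p[name][idx]
        p[name][idx] = original + eps
        loss_plus = loss_and_grad(forward(p, X)[1], Y)[0]
        p[name][idx] = original - eps
        loss_minus = loss_and_grad(forward(p, X)[1], Y)[0]
        p[name][idx] = original  # restore the parameter
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
    return grad


a1, logits = forward(params, X)
_, error = loss_and_grad(logits, Y)
analytic = backward(params, X, a1, error)
max_diff = {k: float(np.abs(analytic[k] - numeric_grad(params, k)).max()) for k in params}
print(max_diff)
```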

## Step 2: Define Training Plan

In [None]:
def set_remote_model_params(module_ptrs, params_list_ptr):
    """Set the traced model's parameters from a flat list of parameter pointers."""
    param_idx = 0
    for module_name, module_ptr in module_ptrs.items():
        for param_name, _ in PLAN_BUILDER_VM.store[
            module_ptr.id_at_location
        ].data.named_parameters():
            module_ptr.register_parameter(param_name, params_list_ptr[param_idx])
            param_idx += 1

# Create the model
local_model = MLP(th)

# Dummy inputs
bs = 3
classes_num = 10
model_params_zeros = sy.lib.python.List(
    [th.nn.Parameter(th.zeros_like(param)) for param in local_model.parameters()]
)

@make_plan
def training_plan(
    xs=th.randn(bs, 28 * 28),
    ys=th.nn.functional.one_hot(th.randint(0, classes_num, [bs]), classes_num),
    batch_size=th.tensor([bs]),
    lr=th.tensor([0.1]),
    params=model_params_zeros,
):
    # send the model to plan builder (but not its default params)
    # this is required to build the model inside the Plan
    model = local_model.send(ROOT_CLIENT, send_parameters=False)

    # set model params from input
    set_remote_model_params(model.modules, params)

    # forward
    logits = model(xs)

    # loss
    loss, loss_grad = model.softmax_cross_entropy_with_logits(
        logits, ys, batch_size
    )

    # backward
    grads = model.backward(xs, loss_grad)

    # SGD step
    updated_params = tuple(
        param - lr * grad for param, grad in zip(model.parameters(), grads)
    )

    # accuracy
    acc = model.accuracy(logits, ys, batch_size)

    # return things
    return (loss, acc, *updated_params)
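
The plan behaves like a pure function: parameters go in through the `params` argument, and updated parameters come back after `loss` and `acc`. A worker can therefore drive it by calling it repeatedly, feeding each step's output back in, and report a diff at the end. The NumPy sketch below illustrates that loop. `toy_training_plan` is a hypothetical stand-in with the same inputs and output layout that trains a single softmax layer, and computing the diff as original minus trained parameters is an assumed sign convention, not something this notebook specifies.

```python
import numpy as np


def toy_training_plan(xs, ys, batch_size, lr, params):
    """Hypothetical stand-in for `training_plan`: same inputs and the same
    (loss, acc, *updated_params) output layout, but a single softmax layer."""
    W, b = params
    logits = xs @ W.T + b
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    loss = -(ys * np.log(probs)).sum(axis=1).mean()
    error = (probs - ys) / batch_size
    grads = (error.T @ xs, error.sum(0))
    acc = (probs.argmax(axis=1) == ys.argmax(axis=1)).mean()
    updated = [p - lr * g for p, g in zip(params, grads)]  # plain SGD step
    return (loss, acc, *updated)


def run_worker(plan, batches, params, lr, max_updates):
    """Call the plan repeatedly, feeding each step's updated params back in."""
    original = [p.copy() for p in params]
    losses = []
    for xs, ys in batches[:max_updates]:
        loss, _acc, *params = plan(xs, ys, len(xs), lr, params)
        losses.append(loss)
    # Assumed sign convention: diff = original - trained.
    diff = [o - t for o, t in zip(original, params)]
    return losses, diff


rng = np.random.default_rng(0)
n_in, n_out, bs = 8, 3, 16
true_W = rng.normal(size=(n_out, n_in))  # hypothetical "teacher" that labels the data


def make_batch():
    xs = rng.normal(size=(bs, n_in))
    ys = np.eye(n_out)[(xs @ true_W.T).argmax(axis=1)]  # one-hot labels
    return xs, ys


batches = [make_batch() for _ in range(200)]
init = [np.zeros((n_out, n_in)), np.zeros(n_out)]
losses, diff = run_worker(toy_training_plan, batches, init, lr=0.5, max_updates=100)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
print([d.shape for d in diff])
```

The diff has the same number of tensors and the same shapes as the parameters, which is what the averaging plan in the next step consumes.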

Translate the training plan to TorchScript so that mobile workers can use it.

In [None]:
# Translate to torchscript
ts_plan = translate_to_ts(training_plan)

# Let's examine its contents
print(ts_plan.torchscript.code)

## Step 3: Define Averaging Plan

The averaging plan is executed by PyGrid at the end of each cycle
to average the _diffs_ submitted by workers, update the model,
and create a new checkpoint for the next cycle.

A _diff_ is the difference between the client-trained
model params and the original model params,
so it has the same number of tensors, with the same shapes,
as the model parameters.

We define a plan that processes one diff at a time.
Such plans require the `iterative_plan` flag to be set to `True`
in `server_config` when hosting the FL model in PyGrid.

The plan below calculates a simple mean of each parameter, updating it incrementally as each diff arrives.

In [None]:
@make_plan
def avg_plan(
    avg=List(local_model.parameters()), item=List(local_model.parameters()), num=Int(0)
):
    new_avg = []
    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg
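
To see why this per-diff update yields the plain mean, the sketch below folds diffs in one at a time with the same formula, `(avg * num + item) / (num + 1)`, and compares the result with an ordinary average. It assumes the server seeds `avg` with the first diff and passes `num` as the number of diffs already averaged; plain Python floats stand in for parameter tensors.

```python
def iterative_mean(diffs):
    """Running mean over a list of diffs, one diff at a time, as in avg_plan."""
    avg = list(diffs[0])  # seed with the first diff (one diff already averaged)
    for num, item in enumerate(diffs[1:], start=1):
        avg = [(a * num + x) / (num + 1) for a, x in zip(avg, item)]
    return avg


# Each inner list is one worker's diff; each position is one parameter.
diffs = [[1.0, -2.0], [3.0, 0.0], [5.0, 4.0], [-1.0, 2.0]]
running = iterative_mean(diffs)
plain = [sum(column) / len(diffs) for column in zip(*diffs)]
print(running, plain)
```

The iterative form only needs the running average and the current diff at each step, rather than all diffs at once.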

## Step 4: Define Federated Learning Process Configuration

Before hosting the model and training plan in PyGrid,
we need to define some configuration parameters, such as
the FL process name, version, worker configuration,
authentication method, etc.

In [None]:
name = "mnist"
version = "1.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.01,
    "max_updates": 100,  # number of updates to execute on workers
}

server_config = {
    "num_cycles": 30,  # total number of cycles (how many times global model is updated)
    "cycle_length": 60 * 60 * 24,  # max duration of the training cycle in seconds
    "max_diffs": 1,  # number of diffs to collect before updating global model
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
    "iterative_plan": True,  # tells PyGrid that avg plan is executed per diff
}
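
It can help to sanity-check what these numbers imply before hosting. The sketch below is a hypothetical helper (not part of PySyft or PyGrid) that derives a few quantities from the two dictionaries, assuming each of the `max_updates` updates consumes one full batch. It also rejects a configuration without `iterative_plan`, since `avg_plan` above processes one diff at a time.

```python
def summarize_fl_config(client_config: dict, server_config: dict) -> dict:
    """Hypothetical helper: derive a few quantities from the FL configuration."""
    # avg_plan averages one diff per call, so PyGrid must be told the plan is iterative.
    if not server_config.get("iterative_plan"):
        raise ValueError("avg_plan processes one diff at a time; set iterative_plan=True")
    return {
        # Upper bound on samples one worker processes per cycle (one batch per update).
        "samples_per_worker": client_config["batch_size"] * client_config["max_updates"],
        "cycle_hours": server_config["cycle_length"] / 3600,
        # Diffs needed to complete every cycle.
        "diffs_for_all_cycles": server_config["num_cycles"] * server_config["max_diffs"],
    }


summary = summarize_fl_config(
    {"batch_size": 64, "lr": 0.01, "max_updates": 100},
    {"num_cycles": 30, "cycle_length": 60 * 60 * 24, "max_diffs": 1, "iterative_plan": True},
)
print(summary)
```

The literal dictionaries keep the sketch self-contained; in the notebook you could pass `client_config` and `server_config` directly. With the values above, a worker processes at most 6,400 samples per cycle, each cycle lasts up to 24 hours, and 30 diffs complete all cycles.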

This FL process will require workers to authenticate with a signed JWT.
Providing the `pub_key` in the FL configuration allows PyGrid to verify these tokens.

In [None]:
def read_file(fname):
    with open(fname, "r") as f:
        return f.read()

In [None]:
public_key = read_file("example_rsa.pub").strip()

server_config["authentication"] = {
    "type": "jwt",
    "pub_key": public_key,
}
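
When debugging worker authentication, it can be useful to look inside a token. A JWT in compact form is three base64url-encoded segments (header, payload, signature) joined by dots. The stdlib sketch below decodes the header and payload for inspection only: it does not verify the signature, which is what PyGrid does with `pub_key`. The demo token and its `sub` claim are made up; the claims PyGrid actually expects are not covered here.

```python
import base64
import json
from typing import Any, Dict, Tuple


def _b64url_decode(segment: str) -> bytes:
    # JWT segments are base64url without "=" padding; restore it before decoding.
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def inspect_jwt(token: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Decode a JWT's header and payload WITHOUT verifying its signature."""
    header_b64, payload_b64, _signature = token.split(".")
    return json.loads(_b64url_decode(header_b64)), json.loads(_b64url_decode(payload_b64))


def _b64url_encode(obj: Dict[str, Any]) -> str:
    raw = json.dumps(obj, separators=(",", ":")).encode()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


# Made-up, unsigned demo token: the "sub" claim and "sig" segment are placeholders.
demo_token = ".".join(
    [_b64url_encode({"alg": "RS256", "typ": "JWT"}), _b64url_encode({"sub": "worker-1"}), "sig"]
)
header, payload = inspect_jwt(demo_token)
print(header, payload)
```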

## Step 5: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries.

Note: this assumes the PyGrid Domain is running locally on port 7000.

### Step 5.1: Start a PyGrid Domain
- Clone the PyGrid GitHub repository from https://github.com/OpenMined/PyGrid

- Install Poetry using pip:
```
$ pip install poetry
```

- Go to `apps/domain` and install the requirements:
```
$ poetry install
```

- Run a Grid Domain using the command:
```
$ ./run.sh --name bob --port 7000 --start_local_db
```
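
Before connecting, you can confirm that something is accepting TCP connections on port 7000. This stdlib sketch only checks the port; it does not confirm that the listener is actually a PyGrid Domain.

```python
import socket


def is_listening(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # covers connection refused and timeouts
        return False


print(is_listening("localhost", 7000))
```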

In [None]:
from syft.grid.client.client import connect
from syft.grid.client.grid_connection import GridHTTPConnection

domain = connect(
    url="http://localhost:7000",
    conn_type=GridHTTPConnection,
)

domain.setup(
    email="owner@openmined.org",
    password="owerpwd",
    domain_name="OpenMined Node",
    token="9G9MJ06OQH",
)

In [None]:
grid_address = "localhost:7000"
grid = ModelCentricFLClient(address=grid_address, secure=False)
grid.connect()

The following code sends the FL model, training plans, and configuration to PyGrid:

In [None]:
response = grid.host_federated_training(
    model=local_model,
    client_plans={
        # Grid can store both types of plans (regular for python worker, torchscript for mobile):
        "training_plan": training_plan,
        "training_plan:ts": ts_plan,
    },
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config,
)

In [None]:
response

If you see a successful response, you've just hosted your first FL process in PyGrid!

If you see an error saying that the FL process already exists,
this means an FL process with that name and version is already hosted.
You might want to update the name/version in the configuration above, or clean up the PyGrid database.

To clean up the database, set the path below to match your PyGrid checkout and run:

In [1]:
!rm ~/Projects/PyGrid/apps/domain/src/nodedatabase.db
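
If you prefer not to shell out with `!rm` (for example on Windows, where `rm` may not exist), the same cleanup can be done from Python with `pathlib`. The path below mirrors the command above and is only an example, so point it at your own PyGrid checkout. The helper is hypothetical and only removes the file if it exists, so rerunning it is harmless.

```python
from pathlib import Path


def remove_pygrid_db(db_path: Path) -> bool:
    """Delete the database file if it exists; return True if a file was removed."""
    if db_path.exists():
        db_path.unlink()
        return True
    return False


# Same location as the `rm` command above; adjust to your PyGrid checkout.
db_file = Path.home() / "Projects" / "PyGrid" / "apps" / "domain" / "src" / "nodedatabase.db"
print(remove_pygrid_db(db_file))
```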

To train the hosted model, use one of the existing mobile FL workers:
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft) (see included worker example)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft) (see included worker example)

Support for the JavaScript worker is coming soon:
 * [syft.js](https://github.com/OpenMined/syft.js)
