# Model-Centric Federated Learning for Mobile - MNIST Example

This notebook will walk you through creating a simple model and a training plan, and hosting both as a federated learning process
for further training using OpenMined mobile FL workers.

This notebook is similar to "[MCFL - Create Plan](mcfl_create_plan.ipynb)"
however due to mobile limitations, the training plan is different.

In [1]:
# stdlib
import base64
import json

# third party
import torch as th

# syft absolute
import syft as sy
from syft.core.plan.plan_builder import ROOT_CLIENT
from syft.core.plan.plan_builder import PLAN_BUILDER_VM
from syft.core.plan.plan_builder import make_plan
from syft.core.plan.translation.torchscript.plan_translate import (
    translate as translate_to_ts,
)
from syft.federated.model_centric_fl_client import ModelCentricFLClient
from syft.lib.python.int import Int
from syft.lib.python.list import List
from itertools import zip_longest

In [2]:
th.random.manual_seed(42)

<torch._C.Generator at 0x7f2d200c6310>

## Step 1: Define the model

This model will train on MNIST data, it's very simple yet can demonstrate learning process.
There're 2 linear layers: 

* Linear 784x100
* ReLU
* Linear 100x10 

Note that the model contains additional methods for convenience of torch reference usage:

* `backward` - calculates backward pass gradients because autograd doesn't work on mobile (yet).
* `softmax_cross_entropy_with_logits` - loss function
* `accuracy` - calculates accuracy of prediction

Our knowledge
1. grads for FCN ==> weight/grad/weight/grad ... From earliest to latest. Weight expand from later to earlier


In [16]:
class EmbeddingNet(sy.Module): #This version follows hinnes' implmentation (first assume only one user and using numeric attributes of the song)
    """
    Simple model with method for loss and hand-written backprop.
    """

    def __init__(self, torch_ref,n_users) -> None:
        super(EmbeddingNet, self).__init__(torch_ref=torch_ref)
        #This is purely fully connected.
        # def gen_layers(n_in):
        #     """
        #     A generator that yields a sequence of hidden layers and 
        #     their activations/dropouts.
            
        #     Note that the function captures `hidden` and `dropouts` 
        #     values from the outer scope.
        #     """
        #     nonlocal hidden, dropouts
        #     assert len(dropouts) <= len(hidden)
            
        #     for n_out, rate in zip_longest(hidden, dropouts):
        #         yield torch_ref.nn.Linear(n_in, n_out)
        #         yield torch_ref.nn.ReLU()
        #         if rate is not None and rate > 0.:
        #             yield torch_ref.nn.Dropout(rate)
        #         n_in = n_out
        self.torch_ref = torch_ref
        self.embedlayer = torch_ref.nn.Embedding(n_users,50)
        self.embeddrop = torch_ref.nn.Dropout(0.02)
        self.fc1 = torch_ref.nn.Linear(60,50)
        self.a1 = torch_ref.nn.ReLU()
        self.d1 = torch_ref.nn.Dropout(0.25)
        self.fc2 = torch_ref.nn.Linear(50,100)
        self.a2 = torch_ref.nn.ReLU()
        self.d2 = torch_ref.nn.Dropout(0.5)
        self.fc3 = torch_ref.nn.Linear(100,75)
        self.a3 = torch_ref.nn.ReLU()
        self.fc4 = torch_ref.nn.Linear(75,1)
        #layersizes contain the size of each layer, including output but exclude input


    def forward(self, users,features):
        self.embedout = self.embedlayer(users.long()) 
        x = th.cat((self.embedout.get(),features.get()),dim=1)
        self.catout = self.embeddrop(x)
        x = self.fc1(self.catout)
        x = self.a1(x)
        self.l1out = self.d1(x)
        y = self.fc2(self.l1out)
        y = self.a2(y)
        self.l2out = self.d2(y)
        z = self.fc3(self.l2out)
        self.l3out = self.a3(z)
        k = self.fc4(self.l3out)
        self.l4out = self.torch_ref.sigmoid(k)
        return self.l4out
    
    def one_hot(self,user,num):
        y = []
        user = user.get()
        for use in user:
            x = [0] * num
            x[use] = 1
            y.append(x)
        return th.tensor(y,dtype=th.int32)

    def backward(self,user,error): #BATCH_SIZE=1 is assumed.
        pgs,flayers,aouts = [],[],[]
        latest_grad = error
        aouts.extend([self.catout,self.l1out,self.l2out,self.l3out])
        flayers.extend([self.fc2,self.fc3,self.fc4])
        for i in range(len(aouts)):
            j = len(aouts) - i - 1
            pgs.append(latest_grad.sum(0)) # bias grad
            pgs.append(latest_grad.t() @ aouts[j]) # weight grad
            # print(latest_grad.shape, self.flayers[j-1].state_dict()['weight'].shape, self.aouts[j].shape)
            if j-1 >= 0:
                latest_grad = (latest_grad @ flayers[j-1].state_dict()['weight']) * (aouts[j] > 0).float() 
            else: #Embedding layer
                latest_grad = (latest_grad @ self.fc1.state_dict()['weight']) #no ReLU in embedding
                latest_grad = latest_grad[:,:50]
        user = self.one_hot(user,9)
        embedgrad = latest_grad.t() @ user.float()
        pgs.append(embedgrad.t())
        pgs.reverse()
        return tuple(pgs)
    
    def get_model_parameters(self):
        params = []
        for i,layer in enumerate(self.hiddens):
            if i % 3 == 0 :
                params.append(layer.state_dict()['weight'])
                params.append(layer.state_dict()['bias'])
        params.append(self.fc.state_dict()['weight'])
        params.append(self.fc.state_dict()['bias'])
        return params

    def mse_loss(self, fpass, target, batch_size):# MSE first
        squared_error = self.torch_ref.pow(fpass - target,2)
        loss = squared_error.mean()
        loss_grad = (target-fpass) * fpass * (1-fpass) / batch_size
        return loss,loss_grad


In [6]:
model = EmbeddingNet(th,n_users=9)


# model.get_paremeters()
xs = th.randn(32,10)
xs2 = th.randint(1,8,(32,))
fpass = model(xs2,xs)
lr = 0.001
bs = 32
batch_size = th.tensor([bs])
ys=th.rand(32,1)
# # # print(ys)
loss, loss_grad = model.mse_loss(fpass,ys,batch_size)
# # # print(loss,loss_grad)
# # print(model.flayers)
# # # # backward
print("Results here:")
grads = model.backward(xs2, loss_grad)
# print(grads.shape)
for x in grads:
    print(x.shape)
print("Results sdf here:")
for y in model.parameters():
    print(y.shape)
# updated_params = tuple(
#         param - lr * grad for param, grad in zip(model.get_model_parameters(), grads)
#     )
# print(model.fc2.state_dict()["weight"].shape)


AttributeError: 'Tensor' object has no attribute 'get'

## Step 2: Define Training Plan

In [17]:
def set_remote_model_params(module_ptrs, params_list_ptr):
    """Sets the model parameters into traced model"""
    param_idx = 0
    for module_name, module_ptr in module_ptrs.items():
        for param_name, _ in PLAN_BUILDER_VM.store[
            module_ptr.id_at_location
        ].data.named_parameters():
            module_ptr.register_parameter(param_name, params_list_ptr[param_idx])
            param_idx += 1

# Create the model
local_model = EmbeddingNet(th,n_users=9)

# Dummy inputs
bs = 32
model_params_zeros = sy.lib.python.List(
    [th.nn.Parameter(th.zeros_like(param)) for param in local_model.parameters()]
)

@make_plan
def training_plan(
    xs=th.randint(1,8,(bs,)),
    xs2=th.randn(bs, 10),
    ys=th.rand(bs,1),
    batch_size=th.tensor([bs]),
    lr=th.tensor([0.01]),
    params=model_params_zeros,
):
    # send the model to plan builder (but not its default params)
    # this is required to build the model inside the Plan
    model = local_model.send(ROOT_CLIENT, send_parameters=False)

    # set model params from input
    set_remote_model_params(model.modules, params)
    
    print(xs,xs2)
    # forward
    fpass = model(xs,xs2)

    # loss
    loss, loss_grad = model.mse_loss(fpass,ys,batch_size)

    # backward
    grads = model.backward(xs, loss_grad)

    # SGD step
    updated_params = tuple(
        param - lr * grad for param, grad in zip(model.parameters(), grads)
    )

    # return things
    return (loss, *updated_params)

<syft.proxy.torch.TensorPointer object at 0x7f2c85bd44c0> <syft.proxy.torch.TensorPointer object at 0x7f2c85bd4400>


  grad = getattr(obj, "grad", None)


Translate the training plan to torchscript so it can be used with mobile workers.

In [18]:
# Translate to torchscript
ts_plan = translate_to_ts(training_plan)

# Let's examine its contents
print(ts_plan.torchscript.code)

[2021-12-22T03:20:09.543137+0000][CRITICAL][logger]][303705] <class 'syft.core.store.store_memory.MemoryStore'> __getitem__ error <UID: 54316dcb944d483fa5ede01753452ac8> <UID: 54316dcb944d483fa5ede01753452ac8>
[2021-12-22T03:20:09.544916+0000][CRITICAL][logger]][303705] <UID: 54316dcb944d483fa5ede01753452ac8>


KeyError: <UID: 54316dcb944d483fa5ede01753452ac8>

## Step 3: Define Averaging Plan

Averaging Plan is executed by PyGrid at the end of the cycle,
to average _diffs_ submitted by workers and update the model
and create new checkpoint for the next cycle.

_Diff_ is the difference between client-trained
model params and original model params,
so it has same number of tensors and tensor's shapes
as the model parameters.

We define Plan that processes one diff at a time.
Such Plans require `iterative_plan` flag set to `True`
in `server_config` when hosting FL model to PyGrid.

Plan below will calculate simple mean of each parameter.

In [30]:
@make_plan
def avg_plan(
    avg=List(local_model.get_model_parameters()), item=List(local_model.get_model_parameters()), num=Int(0)
):
    new_avg = []
    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg

## Step 4: Define Federated Learning Process Configuration

Before hosting the model and training plan to PyGrid,
we need to define some configuration parameters, such as
FL process name, version, workers configuration,
authentication method, etc.

In [31]:
name = "mnist"
version = "1.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.01,
    "max_updates": 100,  # number of updates to execute on workers
}

server_config = {
    "num_cycles": 30,  # total number of cycles (how many times global model is updated)
    "cycle_length": 60*60*24,  # max duration of the training cycle in seconds
    "max_diffs": 1,  # number of diffs to collect before updating global model
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
    "iterative_plan": True,  # tells PyGrid that avg plan is executed per diff
}

This FL process will require workers to authenticate with signed JWT token.
Providing the `pub_key` in FL configuration allows PyGrid to verify JWT tokens.

In [32]:
def read_file(fname):
    with open(fname, "r") as f:
        return f.read()

In [33]:
public_key = read_file("example_rsa.pub").strip()

server_config["authentication"] = {
    "type": "jwt",
    "pub_key": public_key,
}

## Step 5: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries.

Note: assuming the PyGrid Domain is running locally on port 7000.

### Step 5.1: Start a PyGrid Domain
- Clone PyGrid Github repository from https://github.com/OpenMined/PyGrid

- Install poetry using pip:
```
$ pip install poetry
```

- Go to apps/domain and install requirements:
```
$ poetry install
```

- run a Grid domain using the command:
```
$ ./run.sh --name bob --port 7000 --start_local_db
```

In [34]:
grid_address = "localhost:7000"
grid = ModelCentricFLClient(address=grid_address, secure=False)
grid.connect()

Following code sends FL model, training plans, and configuration to the PyGrid:

In [35]:
response = grid.host_federated_training(
    model=local_model,
    client_plans={
        # Grid can store both types of plans (regular for python worker, torchscript for mobile):
        "training_plan": training_plan,
        "training_plan:ts": ts_plan,
    },
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config,
)

In [36]:
response

{'type': 'model-centric/host-training', 'data': {'status': 'success'}}

If you see successful response, you've just hosted your first FL process into PyGrid!

If you see error that FL process already exists,
this means FL process with such name and version is already hosted.
You might want to update name/version in configuration above, or cleanup PyGrid database.

To cleanup database, set path below correctly and run:

In [37]:
# !rm PyGrid/apps/domain/src/nodedatabase.db


To train hosted model, use one of the existing mobile FL workers:
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft) (see included worker example)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft) (see included worker example)

Support for javascript worker is coming soon:
 * [syft.js](https://github.com/OpenMined/syft.js)
