# Model-Centric Federated Learning for Mobile - Spotify Recommendation

This notebook is adapted from the PySyft model-centric federated learning [MNIST example](https://github.com/OpenMined/PySyft/blob/syft_0.5.0/packages/syft/examples/federated-learning/model-centric/mcfl_create_plan_mobile.ipynb), modified for training a Spotify track recommendation model.

This is a simple model. It takes a user index and some features of a track (such as tempo) as input, and predicts whether the user likes the track.

In [None]:
# stdlib
import base64
import json

# third party
import torch as th

# syft absolute
import syft as sy
from syft.core.plan.plan_builder import ROOT_CLIENT
from syft.core.plan.plan_builder import PLAN_BUILDER_VM
from syft.core.plan.plan_builder import make_plan
from syft.core.plan.translation.torchscript.plan_translate import (
    translate as translate_to_ts,
)
from syft.federated.model_centric_fl_client import ModelCentricFLClient
from syft.lib.python.int import Int
from syft.lib.python.list import List
from itertools import zip_longest

In [None]:
# Fix the PyTorch RNG seed so the randomly generated example tensors and the
# initial model weights are reproducible between notebook runs.
th.random.manual_seed(42)

## Model and Training Configurations

* n_users: number of users joining the training

In [None]:
# Hyper-parameters shared by the model, the training plan and the FL config.
n_users = 10  # width of the one-hot user vector
song_features = 10  # number of numeric track features
bs = 64  # batch_size
lr = 0.0005  # SGD learning rate
embedding_size = 50  # width of the user embedding

# Layer widths from the concatenated (embedding + features) input to the
# single output unit. The number of layers is hard-coded: changing how many
# widths are listed here requires changing the EmbeddingNet class as well.
layer_widths = [embedding_size + song_features, 150, 300, 200, 1]
# Consecutive (in_features, out_features) pairs, one per fully connected layer.
layer_sizes = list(zip(layer_widths[:-1], layer_widths[1:]))


## Step 1: Define the model

Let's create an embedding net. It contains an embedding layer that embeds the user identity, followed by a fully connected net whose input is the concatenation of the embedding output and the song features.

In [None]:
# This version follows hinnes' implementation (first assume only one user and
# using numeric attributes of the song).
class EmbeddingNet(sy.Module):
    """
    Simple model with method for loss and hand-written backprop.

    A bias-free-in-practice linear layer applied to a one-hot user vector acts
    as the user embedding. Its output is written next to the song features in
    a caller-supplied buffer, which is then fed through four fully connected
    layers ending in a sigmoid.

    The layer sizes come from the module-level `n_users`, `embedding_size`,
    `song_features` and `layer_sizes`.
    """

    def __init__(self, torch_ref) -> None:
        super(EmbeddingNet, self).__init__(torch_ref=torch_ref)
        self.torch_ref = torch_ref
        # The "embedding" is a Linear layer over one-hot users; its bias is
        # zero-initialised and `backward` always returns a zero gradient for it.
        self.embedlayer = torch_ref.nn.Linear(n_users, embedding_size)
        torch_ref.nn.init.constant_(self.embedlayer.bias.data, 0)
        self.embeddrop = torch_ref.nn.Dropout(0.02)
        self.fc1 = torch_ref.nn.Linear(layer_sizes[0][0], layer_sizes[0][1])
        self.a1 = torch_ref.nn.ReLU()
        self.d1 = torch_ref.nn.Dropout(0.3)
        self.fc2 = torch_ref.nn.Linear(layer_sizes[1][0], layer_sizes[1][1])
        self.a2 = torch_ref.nn.ReLU()
        self.d2 = torch_ref.nn.Dropout(0.3)
        self.fc3 = torch_ref.nn.Linear(layer_sizes[2][0], layer_sizes[2][1])
        self.a3 = torch_ref.nn.ReLU()
        self.fc4 = torch_ref.nn.Linear(layer_sizes[3][0], layer_sizes[3][1])

    def forward(self, users, features, x):
        """
        Run the forward pass and cache the activations needed by `backward`.

        users: one-hot tensor of shape (batch, n_users) identifying the user.
        features: tensor of shape (batch, song_features) with the
            spotify-provided feature values.
        x: scratch tensor of shape (batch, embedding_size + song_features)
            supplied by the caller. Its contents are overwritten IN PLACE with
            the concatenation of the user embedding and the features.

        Returns the sigmoid output of the last layer, shape (batch, 1).
        """
        self.embedout = self.embedlayer(users.float())
        self.embedout = self.embeddrop(self.embedout)
        # Emulate a concatenation by writing into the caller's buffer.
        x[:, :embedding_size] = self.embedout
        x[:, embedding_size : embedding_size + song_features] = features
        self.catout = x
        x = self.fc1(self.catout)
        x = self.a1(x)
        self.l1out = self.d1(x)
        y = self.fc2(self.l1out)
        y = self.a2(y)
        self.l2out = self.d2(y)
        z = self.fc3(self.l2out)
        self.l3out = self.a3(z)
        k = self.fc4(self.l3out)
        self.l4out = self.torch_ref.sigmoid(k)
        return self.l4out

    def backward(self, user, error):
        """
        Hand-written backprop using the activations cached by `forward`.

        user: the one-hot user tensor that was passed to `forward`.
        error: gradient of the loss w.r.t. the pre-sigmoid output of `fc4`,
            i.e. the second value returned by `mse_loss`.

        Returns a tuple of gradients ordered as: embedding weight, embedding
        bias, then (weight, bias) for fc1, fc2, fc3 and fc4.

        NOTE(review): the ReLU mask below is taken from the post-dropout
        activations, so dropped units are masked, but the rescaling that
        torch.nn.Dropout applies in training mode is not reflected in these
        gradients, and the embedding dropout is not accounted for at all.
        """
        pgs, flayers, aouts = [], [], []
        latest_grad = error
        # aouts[j] is the input that was fed to layer fc{j+1}.
        aouts.extend([self.catout, self.l1out, self.l2out, self.l3out])
        flayers.extend([self.fc2, self.fc3, self.fc4])
        # Walk the layers from the output (fc4) back to the input (fc1).
        for i in range(len(aouts)):
            j = len(aouts) - i - 1
            pgs.append(latest_grad.sum(0))  # bias grad
            pgs.append(latest_grad.t() @ aouts[j])  # weight grad
            if j - 1 >= 0:
                # Propagate through the weights, then through the ReLU that
                # produced aouts[j].
                latest_grad = (latest_grad @ flayers[j - 1].state_dict()["weight"]) * (
                    aouts[j] > 0
                ).float()
            else:  # Embedding layer
                latest_grad = (
                    latest_grad @ self.fc1.state_dict()["weight"]
                )  # no ReLU in embedding
                # Only the first `embedding_size` columns of the concatenated
                # input came from the embedding; the rest are raw features.
                latest_grad = latest_grad[:, :embedding_size]
        embedgrad = latest_grad.t() @ user.float()
        # For embedding layer, we mimic using a linear layer.
        # Therefore, bias exists but we do not need it.
        # Hence providing a zero tensor so that the bias weight remains 0.
        pgs.append(th.zeros((embedding_size,)))
        pgs.append(embedgrad)
        # Gradients were collected output-to-input; flip to parameter order.
        pgs.reverse()
        return tuple(pgs)

    def mse_loss(self, fpass, target):  # MSE first
        """
        Mean squared error and its gradient w.r.t. the pre-sigmoid output.

        fpass: sigmoid output returned by `forward`.
        target: tensor of the same shape holding the labels.

        Returns (loss, loss_grad). `loss_grad` is
        d(loss)/d(pre-sigmoid) = 2 * (fpass - target) * fpass * (1 - fpass) / bs,
        where `bs` is the module-level batch size, so the batch is assumed to
        contain exactly `bs` rows.
        """
        squared_error = self.torch_ref.pow(fpass - target, 2)
        loss = squared_error.mean()
        # The sign must be (prediction - target): the training plan applies
        # `param - lr * grad`, so a flipped sign would ascend the loss.
        loss_grad = (fpass - target) * fpass * (1 - fpass) * 2 / bs
        return loss, loss_grad


In [None]:
# Smoke test: run one forward/backward pass on random data and print the
# gradient shapes next to the parameter shapes so they can be compared by eye.
model = EmbeddingNet(th)

features = th.randn(bs, song_features)
# Sample user ids over the full range [0, n_users) and pin the one-hot width
# to n_users; otherwise the width depends on the largest id that happens to be
# drawn and may not match the model's input size.
users = th.randint(0, n_users, (bs,))
users = th.nn.functional.one_hot(users, num_classes=n_users)
# Scratch buffer that forward() overwrites with [embedding | features].
x = th.randn(bs, song_features + embedding_size)
fpass = model(users, features, x)
batch_size = th.tensor([bs])
ys = th.rand(bs, 1)
loss, loss_grad = model.mse_loss(fpass, ys)
print("Results here:")
grads = model.backward(users, loss_grad)
for grad in grads:
    print(grad.shape)
print("Results sdf here:")
for param in model.parameters():
    print(param.shape)


## Step 2: Define Training Plan

In [None]:
def set_remote_model_params(module_ptrs, params_list_ptr):
    """Register the entries of `params_list_ptr` as parameters of the traced model.

    Parameters are consumed in order: for every module pointer in
    `module_ptrs`, the parameter names are read from the module stored in
    PLAN_BUILDER_VM and each name is bound to the next list entry.
    """
    next_param = 0
    for _, remote_module in module_ptrs.items():
        # Look up the real module behind the pointer to learn its param names.
        stored_module = PLAN_BUILDER_VM.store[remote_module.id_at_location].data
        for name, _ in stored_module.named_parameters():
            remote_module.register_parameter(name, params_list_ptr[next_param])
            next_param += 1

# Instantiate the model that will be traced into the plan and hosted on PyGrid.
local_model = EmbeddingNet(torch_ref=th)

# Zero-valued parameters with the same shapes as the model's own parameters;
# used as the example `params` input when the training plan is built.
model_params_zeros = List(
    [th.nn.Parameter(th.zeros_like(p)) for p in local_model.parameters()]
)

@make_plan
def training_plan(
    features=th.randn(bs, song_features),
    users=th.nn.functional.one_hot(
        th.randint(0, n_users, (bs,)), num_classes=n_users
    ),
    xs=th.randn(bs, embedding_size + song_features),
    ys=th.rand(bs, 1),
    lr=th.tensor([lr]),
    params=model_params_zeros,
):
    """One SGD step of EmbeddingNet on a single batch.

    The default values are example inputs used to build the plan:

    features: (bs, song_features) track features.
    users: (bs, n_users) one-hot user ids. The width is pinned with
        `num_classes` so it always matches the model input, regardless of
        which ids were sampled.
    xs: (bs, embedding_size + song_features) scratch buffer that the model
        overwrites with the concatenated embedding and features.
    ys: (bs, 1) targets.
    lr: learning rate as a one-element tensor.
    params: list of model parameters to train from.

    Returns a tuple (loss, *updated_params).
    """
    # send the model to plan builder (but not its default params)
    # this is required to build the model inside the Plan
    model = local_model.send(ROOT_CLIENT, send_parameters=False)

    # set model params from input
    set_remote_model_params(model.modules, params)

    # forward
    fpass = model(users, features, xs)

    # loss
    loss, loss_grad = model.mse_loss(fpass, ys)

    # backward
    grads = model.backward(users, loss_grad)

    # SGD step
    updated_params = tuple(
        param - lr * grad for param, grad in zip(model.parameters(), grads)
    )

    # return things
    return (loss, *updated_params)

Translate the training plan to torchscript so it can be used with mobile workers.

In [None]:
# Translate to torchscript
# The translated plan is hosted later under the "training_plan:ts" key.
ts_plan = translate_to_ts(training_plan)

# Let's examine its contents
# Prints the generated TorchScript source of the translated plan.
print(ts_plan.torchscript.code)

## Step 3: Define Averaging Plan

Averaging Plan is executed by PyGrid at the end of the cycle,
to average _diffs_ submitted by workers and update the model
and create new checkpoint for the next cycle.

_Diff_ is the difference between client-trained
model params and original model params,
so it has same number of tensors and tensor's shapes
as the model parameters.

We define Plan that processes one diff at a time.
Such Plans require `iterative_plan` flag set to `True`
in `server_config` when hosting FL model to PyGrid.

Plan below will calculate simple mean of each parameter.

In [None]:
@make_plan
def avg_plan(
    avg=List(local_model.parameters()), item=List(local_model.parameters()), num=Int(0)
):
    """Fold one diff into the running mean, parameter by parameter.

    The default values are example inputs used to build the plan.

    avg: list of tensors holding the current running average.
    item: list of tensors for the incoming diff, same layout as `avg`.
    num: presumably the number of diffs already folded into `avg` -- this
        follows from the update formula below; confirm against PyGrid.

    Returns a list with the updated average for every parameter.
    """
    new_avg = []
    for i, param in enumerate(avg):
        # Incremental mean: new = (old * count + sample) / (count + 1).
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg

## Step 4: Define Federated Learning Process Configuration

Before hosting the model and training plan to PyGrid,
we need to define some configuration parameters, such as
FL process name, version, workers configuration,
authentication method, etc.

In [None]:
# Identity of the FL process; workers request it by this name and version.
name = "spotify_recommendation"
version = "1.0"

# Settings handed to every worker.
client_config = dict(
    name=name,
    version=version,
    batch_size=bs,
    lr=lr,
    max_updates=100,  # custom syft.js option that limits number of training loops per worker
)

# Settings that control how PyGrid runs the cycles.
# Optional keys left disabled here:
#   "min_workers": 2,
#   "max_workers": 2,
#   "pool_selection": "random",
#   "do_not_reuse_workers_until_cycle": 6,
server_config = dict(
    cycle_length=28800,  # max cycle length in seconds
    num_cycles=30,  # max number of cycles
    max_diffs=1,  # number of diffs to collect before avg
    minimum_upload_speed=0,
    minimum_download_speed=0,
    iterative_plan=True,  # tells PyGrid that avg plan is executed per diff
)

This FL process will require workers to authenticate with signed JWT token.
Providing the `pub_key` in FL configuration allows PyGrid to verify JWT tokens.

In [None]:
def read_file(fname):
    """Return the full text content of the file at path `fname`.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(fname, "r", encoding="utf-8") as f:
        return f.read()

In [None]:
# Load the RSA public key (surrounding whitespace removed) and attach it to
# the server configuration as the JWT authentication settings.
public_key = read_file("example_rsa.pub").strip()

server_config["authentication"] = dict(type="jwt", pub_key=public_key)

## Step 5: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries.

Note: assuming the PyGrid Domain is running locally on port 7001.

In [None]:
# host:port of the PyGrid Domain; used below both for the HTTP setup call and
# for the model-centric FL client.
grid_address = "localhost:7001"

### Setup the domain

Run this once and only once after the PyGrid is cleared.

In [None]:
from syft.grid.client.client import connect
from syft.grid.client.grid_connection import GridHTTPConnection

# Open an HTTP connection to the PyGrid Domain.
domain = connect(url="http://" + grid_address, conn_type=GridHTTPConnection)

# One-time initialisation of the Domain with its owner account.
# NOTE(review): these credentials are hard-coded example values; do not reuse
# them outside a local test deployment.
owner_setup = {
    "email": "owner@openmined.org",
    "password": "owerpwd",
    "domain_name": "OpenMined Node",
    "token": "9G9MJ06OQH",
}
domain.setup(**owner_setup)

In [None]:
# Client used to host the FL process on PyGrid.
# NOTE(review): secure=False presumably selects an unencrypted connection,
# which matches the plain http:// URL used for the domain setup -- confirm.
grid = ModelCentricFLClient(address=grid_address, secure=False)
# Rebinds `domain`, replacing the connection object created during setup.
domain = grid.connect()

Following code sends FL model, training plans, and configuration to the PyGrid:

In [None]:
# Grid can store both types of plans (regular for python worker, torchscript for mobile):
hosted_plans = {
    "training_plan": training_plan,
    "training_plan:ts": ts_plan,
}

# Upload the model, the plans and both configurations in a single request.
response = grid.host_federated_training(
    model=local_model,
    client_plans=hosted_plans,
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config,
)

In [None]:
# Bare expression so the notebook displays PyGrid's reply to the hosting request.
response

If you see a successful response, you've just hosted your first FL process in PyGrid!

If you see an error saying that the FL process already exists,
an FL process with this name and version is already hosted.
You might want to update the name/version in the configuration above, or clean up the PyGrid database.

To cleanup database, set path below correctly and run:

In [None]:
# !rm PyGrid/apps/domain/src/nodedatabase.db


To train hosted model, use one of the existing mobile FL workers:
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft) (see included worker example)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft) (see included worker example)

Support for javascript worker is coming soon:
 * [syft.js](https://github.com/OpenMined/syft.js)
