# Fed-BioMed secure aggregation tutorial


<font size=+2>
    Warning: secure aggregation is a work in progress. In the current version it is not fully implemented and does not provide any effective security or functionality. This notebook exists for demonstration purposes only.
</font>


## Example experimentation setup

This part sets up a basic Fed-BioMed example. At this point, nothing is specific to secure aggregation.

### Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

### Setting nodes up
You must first configure **at least two nodes**:
1. `./scripts/fedbiomed_run node config config_node1.ini add` (for the second node: `./scripts/fedbiomed_run node config config_node2.ini add`)
  * Select option 2 (default) to add MNIST to the node
  * Confirm the default tags by typing "y" and pressing ENTER
  * Pick the folder where MNIST is downloaded (this is due to a PyTorch issue: https://github.com/pytorch/vision/issues/3549)
  * The data should now be added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node config config_node1.ini list`
3. Run the node using `./scripts/fedbiomed_run node config config_node1.ini run`. Wait until you get `Starting task manager`, which means the node is online.

### Define an experiment model and parameters

Declare a torch training plan class, `MyTrainingPlan`, to send to the nodes for training.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # used as `F` in Net.forward below
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms


# Here we define the training plan to be used.
# You can use any class name (here 'MyTrainingPlan')
class MyTrainingPlan(TorchTrainingPlan):
    
    # Defines and returns the model
    def init_model(self, model_args):
        return self.Net(model_args = model_args)
    
    # Defines and returns the optimizer
    def init_optimizer(self, optimizer_args):
        return torch.optim.Adam(self.model().parameters(), lr = optimizer_args["lr"])
    
    # Declares and returns dependencies
    def init_dependencies(self):
        deps = ["from torchvision import datasets, transforms"]
        return deps
    
    class Net(nn.Module):
        def __init__(self, model_args):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)


            output = F.log_softmax(x, dim=1)
            return output

    def training_data(self, batch_size = 48):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset=dataset1, **train_kwargs)
    
    def training_step(self, data, target):
        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss
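
The `9216` input size of `fc1` in the `Net` class follows from the layer shapes. A minimal sketch of the arithmetic, assuming 28x28 single-channel MNIST images and no padding (the helper name `conv_out` is illustrative only):

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                      # MNIST images are 28x28
side = conv_out(side, 3)       # conv1: 3x3 kernel, stride 1 gives 26
side = conv_out(side, 3)       # conv2: 3x3 kernel, stride 1 gives 24
side = conv_out(side, 2, 2)    # max_pool2d with kernel 2 (stride defaults to kernel size) gives 12
flattened = 64 * side * side   # conv2 has 64 output channels
print(flattened)               # 9216
```

If the convolution or pooling settings change, the first argument of `nn.Linear` in `fc1` has to change accordingly.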


The arguments defined in the next cell are:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓

In [None]:
model_args = {}

training_args = {
    'batch_size': 48, 
    'optimizer_args': {
        "lr" : 1e-3
    },
    'epochs': 1, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
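
As a quick check of the `batch_maxnum` comment above, the cap on samples can be worked out directly. This is a minimal sketch: the 60,000 figure assumes a node holds the full MNIST training split, and `max_samples` is an illustrative helper, not a Fed-BioMed function.

```python
def max_samples(batch_size: int, batch_maxnum: int, dataset_size: int) -> int:
    """Upper bound on samples used when at most `batch_maxnum` batches are processed."""
    return min(batch_size * batch_maxnum, dataset_size)


samples = max_samples(batch_size=48, batch_maxnum=100, dataset_size=60_000)
print(samples)  # 4800
```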

### Declare and run the experiment

- search nodes serving data for these `tags`, optionally filtering on a list of node IDs with `nodes`
- run a round of local training on nodes with the training plan defined in `training_plan_class`, then federate the results with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between rounds

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 2

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)
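
To make the role of the aggregator concrete, here is a minimal sketch of the generic federated averaging rule, which weights each node's parameters by its number of training samples. It is written in plain Python for illustration and is not taken from Fed-BioMed's `FedAverage` implementation, which may differ; the function name and the parameter values are hypothetical.

```python
from typing import Dict, List


def federated_average(params: List[Dict[str, List[float]]],
                      n_samples: List[int]) -> Dict[str, List[float]]:
    """Average per-node parameters, weighting each node by its sample count."""
    total = sum(n_samples)
    weights = [n / total for n in n_samples]
    averaged = {}
    for name in params[0]:
        length = len(params[0][name])
        averaged[name] = [
            sum(w * p[name][i] for w, p in zip(weights, params))
            for i in range(length)
        ]
    return averaged


# Two hypothetical nodes reporting the same parameter after local training
node_params = [
    {"fc2.bias": [1.0, 2.0]},
    {"fc2.bias": [3.0, 6.0]},
]
result = federated_average(node_params, n_samples=[100, 300])
print(result)  # {'fc2.bias': [2.5, 5.0]}
```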

## Secure aggregation setup

Check the experiment's secure aggregation status: no context is configured yet, so the experiment does not use secure aggregation.

In [None]:
print("Using secagg: ", exp.use_secagg())
exp_servkey, exp_biprime = exp.secagg_context()
if exp_servkey:
    print(f"Secagg servkey:\n- status {exp_servkey.status()}\n- secagg_id {exp_servkey.secagg_id()}"
          f"\n- context {exp_servkey.context()}")
else:
    print("No secagg servkey")
if exp_biprime:
    print(f"Secagg biprime:\n- status {exp_biprime.status()}\n- secagg_id {exp_biprime.secagg_id()}"
          f"\n- context {exp_biprime.context()}")
else:
    print("No secagg biprime")
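
The status check above can be wrapped in a small helper so that it is easy to rerun on any experiment. This is a sketch: `describe_secagg` is a hypothetical name, and it relies only on the methods already used in the cell above (`use_secagg`, `secagg_context`, `status`, `secagg_id`, `context`).

```python
def describe_secagg(experiment) -> str:
    """Return a printable summary of an experiment's secure aggregation status."""
    lines = [f"Using secagg: {experiment.use_secagg()}"]
    servkey, biprime = experiment.secagg_context()
    for label, element in (("servkey", servkey), ("biprime", biprime)):
        if element:
            lines.append(
                f"Secagg {label}:\n- status {element.status()}"
                f"\n- secagg_id {element.secagg_id()}"
                f"\n- context {element.context()}"
            )
        else:
            lines.append(f"No secagg {label}")
    return "\n".join(lines)


# Usage, given an existing experiment object `exp`:
# print(describe_secagg(exp))
```
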

Negotiate a secure aggregation cryptographic context among the experiment parties (researcher and nodes), providing a sufficient `timeout` for the negotiation.
If the context negotiation succeeds, the experiment is required to use secure aggregation from now on (`use_secagg`).

In [None]:
exp.set_use_secagg(use_secagg=True, timeout=10)
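
For intuition about what a shared cryptographic context makes possible, here is a toy sketch of additive masking: each node masks its value with a private key, and the aggregating party holds a key that cancels the sum of the node keys. This is a conceptual illustration only. It is not Fed-BioMed's protocol, it does not describe what the servkey and biprime elements contain, and all names and the modulus are assumptions made for the example.

```python
import secrets

MODULUS = 2**61 - 1  # toy modulus, chosen only for this illustration


def make_keys(n_nodes: int):
    """Draw one random key per node and a server key that cancels their sum."""
    node_keys = [secrets.randbelow(MODULUS) for _ in range(n_nodes)]
    server_key = (-sum(node_keys)) % MODULUS
    return node_keys, server_key


def mask(value: int, key: int) -> int:
    """Hide a node's value by adding its key modulo MODULUS."""
    return (value + key) % MODULUS


def unmask_sum(masked_values, server_key: int) -> int:
    """Recover only the sum of the values; individual values stay masked."""
    return (sum(masked_values) + server_key) % MODULUS


updates = [12, 30, 7]  # hypothetical integer-encoded updates from three nodes
node_keys, server_key = make_keys(len(updates))
masked = [mask(v, k) for v, k in zip(updates, node_keys)]
total = unmask_sum(masked, server_key)
print(total)  # 49
```
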

Check the experiment's secure aggregation status: a context exists and secure aggregation is activated, so the experiment uses secure aggregation.

In [None]:
print("Using secagg: ", exp.use_secagg())
exp_servkey, exp_biprime = exp.secagg_context()
if exp_servkey:
    print(f"Secagg servkey:\n- status {exp_servkey.status()}\n- secagg_id {exp_servkey.secagg_id()}"
          f"\n- context {exp_servkey.context()}")
else:
    print("No secagg servkey")
if exp_biprime:
    print(f"Secagg biprime:\n- status {exp_biprime.status()}\n- secagg_id {exp_biprime.secagg_id()}"
          f"\n- context {exp_biprime.context()}")
else:
    print("No secagg biprime")

In [None]:
exp.info()

Run the experiment, using secure aggregation.

In [None]:
exp.run_once(increase=True)

## Misc secure aggregation commands

You can toggle secure aggregation off and on. Note that the *same* secure aggregation context is reused; no new context needs to be negotiated.

In [None]:
exp.set_use_secagg(False)
exp.set_use_secagg(True)

In [None]:
print("Using secagg: ", exp.use_secagg())
exp_servkey, exp_biprime = exp.secagg_context()
if exp_servkey:
    print(f"Secagg servkey:\n- status {exp_servkey.status()}\n- secagg_id {exp_servkey.secagg_id()}"
          f"\n- context {exp_servkey.context()}")
else:
    print("No secagg servkey")
if exp_biprime:
    print(f"Secagg biprime:\n- status {exp_biprime.status()}\n- secagg_id {exp_biprime.secagg_id()}"
          f"\n- context {exp_biprime.context()}")
else:
    print("No secagg biprime")

You can use breakpoints along with secure aggregation:

In [None]:
exp.set_save_breakpoints(True)

In [None]:
exp.run_once(increase=True)

In [None]:
del exp

In [None]:
loaded_exp = Experiment.load_breakpoint()
loaded_exp.info()

In [None]:
print("Using secagg: ", loaded_exp.use_secagg())
exp_servkey, exp_biprime = loaded_exp.secagg_context()
if exp_servkey:
    print(f"Secagg servkey:\n- status {exp_servkey.status()}\n- secagg_id {exp_servkey.secagg_id()}"
          f"\n- context {exp_servkey.context()}")
else:
    print("No secagg servkey")
if exp_biprime:
    print(f"Secagg biprime:\n- status {exp_biprime.status()}\n- secagg_id {exp_biprime.secagg_id()}"
          f"\n- context {exp_biprime.context()}")
else:
    print("No secagg biprime")

In [None]:
loaded_exp.run_once(increase=True)

Alternative usage: you can set up secure aggregation from the experiment constructor instead of using `set_use_secagg`:

In [None]:
exp2 = Experiment(tags=tags,
                  model_args=model_args,
                  training_plan_class=MyTrainingPlan,
                  training_args=training_args,
                  round_limit=rounds,
                  aggregator=FedAverage(),
                  node_selection_strategy=None,
                  use_secagg=True,
                  secagg_timeout=10)

Check the experiment's secure aggregation status: a context exists and secure aggregation is activated, so the experiment uses secure aggregation.

In [None]:
exp2.info()

## Annex: direct use of secure aggregation contexts

Secure aggregation contexts can be used directly, without an experiment. This is currently intended mainly for educational and debugging purposes.

### Discover online nodes

In [None]:
from fedbiomed.researcher.requests import Requests

requests = Requests()
nodes = requests.ping_nodes()
print(f'Online nodes:\n {nodes}')

### Build parties list for secagg

In [None]:
from fedbiomed.researcher.environ import environ

parties = [environ['RESEARCHER_ID']] + nodes
if len(parties) < 3:
    print("Need at least 3 parties for secure aggregation")
print(f'Secure aggregation parties:\n {parties}')
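
The cell above only prints a message when there are too few parties and then carries on. A stricter variant could stop early instead. The sketch below is illustrative: `build_parties` and the example IDs are hypothetical, and it assumes the node list is a plain list of node ID strings.

```python
from typing import List

MIN_PARTIES = 3  # the researcher plus at least two nodes


def build_parties(researcher_id: str, node_ids: List[str]) -> List[str]:
    """Return the parties list, or raise if there are too few parties."""
    parties = [researcher_id] + list(node_ids)
    if len(parties) < MIN_PARTIES:
        raise ValueError(
            f"Need at least {MIN_PARTIES} parties for secure aggregation, "
            f"got {len(parties)}"
        )
    return parties


print(build_parties("researcher_1", ["node_1", "node_2"]))
```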

### Setup secagg with all online nodes

Example of a successful secagg context negotiation.

In [None]:
from fedbiomed.researcher.secagg import SecaggServkeyContext

secagg_servkey = SecaggServkeyContext(parties, 'DUMMY_JOB')

In [None]:
secagg_servkey.setup()
print("Status: ", secagg_servkey.status())
print("Context: ", secagg_servkey.context())

In [None]:
from fedbiomed.researcher.secagg import SecaggBiprimeContext

secagg_biprime = SecaggBiprimeContext(parties)

In [None]:
secagg_biprime.setup(timeout=10)
print("Status: ", secagg_biprime.status())
print("Context: ", secagg_biprime.context())

### Handle timeouts

If `timeout` is insufficient to complete the context negotiation, setup fails with a timeout error:

In [None]:
from fedbiomed.researcher.secagg import SecaggServkeyContext
secagg_servkey2 = SecaggServkeyContext(parties, 'DUMMY_JOB')
secagg_servkey2.setup(2)

In [None]:
print("Status: ", secagg_servkey2.status())
print("Context: ", secagg_servkey2.context())

Retry the negotiation with a sufficient timeout: the secagg context is now successfully established (setup handles failures and retries).

In [None]:
secagg_servkey2.setup()
print("Status: ", secagg_servkey2.status())
print("Context: ", secagg_servkey2.context())

Note: it may also succeed with a short timeout (like `secagg_servkey2.setup(1.5)`) if you wait long enough before retrying, so that the secagg computation completes in the meantime.
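
The fail-then-retry pattern shown above can be automated with a small loop over increasing timeouts. This is a sketch under assumptions: `setup_with_retries` is a hypothetical helper, it assumes a failed negotiation raises an exception (the exact exception type is not specified here), and `fake_setup` is a stand-in so the example runs without a Fed-BioMed network.

```python
from typing import Callable, Iterable


def setup_with_retries(setup: Callable[[float], object],
                       timeouts: Iterable[float]) -> bool:
    """Call `setup(timeout)` with each timeout in turn until one succeeds."""
    for timeout in timeouts:
        try:
            setup(timeout)
            return True
        except Exception as exc:  # assumption: failure is signalled by an exception
            print(f"setup(timeout={timeout}) failed: {exc}")
    return False


def fake_setup(timeout: float) -> None:
    """Stand-in for a context's setup: needs at least 5 seconds to complete."""
    if timeout < 5:
        raise TimeoutError("negotiation did not complete in time")


succeeded = setup_with_retries(fake_setup, timeouts=[2, 5, 10])
print(succeeded)  # True
```

With the objects from this notebook, the analogous call would be `setup_with_retries(secagg_servkey2.setup, [2, 10])`.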

### Handle `job_id`

If another job tries to use the servkey secagg context element, the setup fails.

In [None]:
secagg_servkey.set_job_id('ANOTHER_JOB')
secagg_servkey.setup()

Note: this does not apply to the biprime secagg context element, as it does not use `job_id`.


### Clean out secagg

In [None]:
secagg_servkey2.delete()

In [None]:
print("Status: ", secagg_servkey2.status())
print("Context: ", secagg_servkey2.context())

In [None]:
secagg_biprime.delete()

Note: `secagg_servkey.delete()` would fail, as the object lost track of the successfully established context after its `job_id` was changed and another setup was attempted.

In [None]:
secagg_servkey.delete()