# Fed-BioMed Researcher base example

This example uses the MNIST dataset. Please check the `README.md` file in the `notebooks` directory for instructions on how to load the MNIST dataset and configure the nodes.

## Define an experiment model and parameters

Declare a torch training plan class, `MyTrainingPlan`, that will be sent to the nodes for training.

In [None]:
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms
from fedbiomed.common.optimizers.optimizer import Optimizer
from fedbiomed.common.optimizers.declearn import ScaffoldClientModule, AdamModule, FedProxRegularizer


# Here we define the training plan to be used.
# You can use any class name (here 'MyTrainingPlan'); the model itself is the nested 'Net' class.
class MyTrainingPlan(TorchTrainingPlan):
    """Torch training plan that trains a small convolutional network on MNIST.

    The plan bundles the model (`Net`), a declearn-based optimizer, the extra
    import statements declared as dependencies, the training data manager and
    the loss computation for one training step.
    """

    def init_model(self, model_args):
        """Build and return the model.

        Args:
            model_args: model arguments, forwarded unchanged to `Net`.

        Returns:
            A new `Net` instance.
        """
        return self.Net(model_args=model_args)

    def init_optimizer(self, optimizer_args):
        """Build and return the declearn-based optimizer.

        Args:
            optimizer_args: mapping that must contain the learning rate under
                the key "lr".

        Returns:
            An `Optimizer` combining `AdamModule` with `FedProxRegularizer`.
        """
        # Alternative setup relying on Scaffold instead of Adam + FedProx:
        # return Optimizer(lr=optimizer_args['lr'], modules=[ScaffoldClientModule()])
        return Optimizer(
            lr=optimizer_args["lr"],
            modules=[AdamModule()],
            regularizers=[FedProxRegularizer()],
        )

    def init_dependencies(self):
        """Return the import statements this plan declares as dependencies.

        Returns:
            A list of import statements, one string per statement.
        """
        return [
            "from torchvision import datasets, transforms",
            "from fedbiomed.common.optimizers.optimizer import Optimizer",
            "from fedbiomed.common.optimizers.declearn import ScaffoldClientModule, AdamModule, FedProxRegularizer",
        ]

    class Net(nn.Module):
        """Two-convolution CNN returning log-probabilities over 10 classes."""

        def __init__(self, model_args):
            """Create the layers.

            Args:
                model_args: accepted to match `init_model`; not used here.
            """
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            # 9216 = 64 channels * 12 * 12: the flattened feature map obtained
            # from a 1x28x28 input after two 3x3 convolutions and a 2x2 pooling.
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Run the forward pass.

            Args:
                x: input batch; the layer sizes require single-channel
                    28x28 images, i.e. a (batch, 1, 28, 28) tensor.

            Returns:
                Log-softmax scores computed over dimension 1.
            """
            # `nn.functional` is used explicitly (rather than an `F` alias) so
            # the class only relies on names it actually has in scope.
            x = nn.functional.relu(self.conv1(x))
            x = nn.functional.relu(self.conv2(x))
            x = nn.functional.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = nn.functional.relu(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            return nn.functional.log_softmax(x, dim=1)

    def training_data(self):
        """Build the data manager wrapping the MNIST training set.

        The dataset is read from `self.dataset_path` with `download=False`,
        so the data must already be present at that location.

        Returns:
            A `DataManager` over the MNIST training split, with shuffling
            enabled.
        """
        # Normalization constants: presumably the usual MNIST mean / std.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist_train = datasets.MNIST(
            self.dataset_path, train=True, download=False, transform=transform
        )
        return DataManager(dataset=mnist_train, shuffle=True)

    def training_step(self, data, target):
        """Compute the loss for one batch.

        Args:
            data: batch of inputs fed to the model.
            target: expected class indices for the batch.

        Returns:
            The negative log-likelihood loss between the model output and
            `target`.
        """
        output = self.model().forward(data)
        return torch.nn.functional.nll_loss(output, target)


In [None]:
# No model-specific arguments are needed for this example.
model_args = {}

# Training settings handed to the experiment.
training_args = {
    "loader_args": {"batch_size": 48},
    # `lr` is the key read by `MyTrainingPlan.init_optimizer`.
    "optimizer_args": {"lr": 1e-3},
    "epochs": 2,
    "dry_run": False,
    # Fast pass for development: only use (batch_maxnum * batch_size) samples.
    "batch_maxnum": 100,
}

## Declare and run the experiment

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
from fedbiomed.common.optimizers.declearn import YogiModule as FedYogi
from fedbiomed.common.optimizers.declearn import ScaffoldServerModule

# The MNIST dataset must be loaded in the nodes (with these tags) before the
# experiment is created.
tags = ["#MNIST", "#dataset"]

rounds = 2
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=MyTrainingPlan,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
    secagg=True,
    tensorboard=True,
)

# Testing configuration: a test ratio of 0.1, with testing enabled on both
# local and global updates.
exp.set_test_ratio(0.1)
exp.set_test_on_local_updates(True)
exp.set_test_on_global_updates(True)

# Researcher-side optimizer built around the Scaffold server module.
# NOTE(review): it is only constructed here; the call that would register it
# on the experiment is left disabled below.
fed_opt = Optimizer(lr=0.8, modules=[ScaffoldServerModule()])
#exp.set_agg_optimizer(fed_opt)

Let's start the experiment.

By default, `exp.run()` does not return until all the `round_limit` rounds are done for all the nodes.

In [None]:
# Run the experiment with the settings configured above.
exp.run()

In [None]:
# Show the 'enc_specs' entry of the optimizer auxiliary variables returned by
# one node in the first round (index 0).
# Node ids differ from one deployment to another, so pick the first node found
# in the replies instead of hard-coding an id.
first_round_replies = exp.training_replies()[0]
node_id = next(iter(first_round_replies))
first_round_replies[node_id]['optim_aux_var']['enc_specs']

In [None]:
# Show entry 1 of the `conv1.weight` tensor in the aggregated parameters of the
# last training round. The round index is derived from `rounds` so that it
# always exists, whatever number of rounds was configured.
exp.aggregated_params()[rounds - 1]['params']['conv1.weight'][1]
# Example output from an earlier run (values will differ):
# tensor([[[ 0.1946, -0.2137, -0.1822],
#          [ 0.3152, -0.2212,  0.2557],
#          [-0.3067,  0.0849, -0.0262]]])

In [None]:
# Fetch the optimizer attached to the experiment's aggregation step.
# NOTE(review): `exp.set_agg_optimizer(fed_opt)` is commented out in the setup
# cell, so this presumably returns None unless that call is re-enabled — confirm.
opt = exp.agg_optimizer()

In [None]:

# Show entry 1 of the `conv1.weight` coefficients stored in the "state" of the
# 'scaffold' auxiliary variables of the optimizer fetched above.
# NOTE(review): needs `opt` to be an optimizer exposing a 'scaffold' entry in
# its auxiliary variables — verify the optimizer was actually set.
opt.get_aux()['scaffold'].to_dict()["state"].coefs['conv1.weight'][1]

# tensor([[[-1.1387e-03,  9.5662e-05,  4.2048e-04],
#          [-1.1957e-03, -1.2749e-04,  1.2802e-04],
#          [-9.8669e-04, -5.4240e-04, -3.7431e-04]]])

Save trained model to file

In [None]:
# Export the model held by the experiment's training plan to './trained_model'.
exp.training_plan().export_model('./trained_model')

In [None]:
# Look up the TensorBoard results directory in the researcher environment;
# it is passed to the `%tensorboard` magic in a later cell.
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']

In [None]:
# Load the TensorBoard notebook extension (IPython magic).
%load_ext tensorboard

In [None]:
# Start TensorBoard on the directory stored in `tensorboard_dir`.
%tensorboard --logdir "$tensorboard_dir"

In [None]:
# Rounds for which training replies are available.
print("\nList the training rounds : ", exp.training_replies().keys())

# Per-node timings reported for the last training round.
print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1]
for reply in round_data.values():
    timing = reply['timing']
    print(
        f"\t- {reply['node_id']} :    "
        f"\n\t\trtime_training={timing['rtime_training']:.2f} seconds    "
        f"\n\t\tptime_training={timing['ptime_training']:.2f} seconds    "
        f"\n\t\trtime_total={timing['rtime_total']:.2f} seconds"
    )
print('\n')

In [None]:
# Aggregated parameters, indexed by training round.
aggregated = exp.aggregated_params()
print("\nList the training rounds : ", aggregated.keys())

# Parameter names available for the last training round.
print("\nAccess the federated params for the last training round :")
last_round_params = aggregated[rounds - 1]['params']
print("\t- parameter data: ", last_round_params.keys())
