# Fedbiomed Researcher base example

Use this cell during development: it automatically reloads changes made across packages.

In [1]:
%load_ext autoreload
%autoreload 2

## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
A node must be configured beforehand:
1. `./scripts/fedbiomed_run node add`
  * Select option 2 (default)
  * Write MNIST to add MNIST to the node through `torchvision.datasets.MNIST`
  * Select the desired ratio of the MNIST dataset to be added to the current node
  * Confirm default tags by hitting "y" and ENTER
  * Pick the folder where MNIST is downloaded (this is due to torch issue https://github.com/pytorch/vision/issues/3549)
  * The data should now be added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node run`. Wait until you get `Starting task manager`, which means the node is online.

4. Following the same procedure, create another node with MNIST.

## Define an experiment model and parameters

Declare a `MyTrainingPlan` class (a torch.nn model) to send to the nodes for training.

In [2]:
from fedbiomed.researcher.environ import environ
import tempfile
tmp_dir_model = tempfile.TemporaryDirectory(dir=environ['TMP_DIR']+'/')
model_file = tmp_dir_model.name + '/class_export_mnist.py'

2022-01-21 09:42:56,469 fedbiomed INFO - Component environment:
2022-01-21 09:42:56,470 fedbiomed INFO - - type = ComponentType.RESEARCHER


Note : write **only** the code to export in the following cell

In [3]:
%%writefile "$model_file"

import torch
import torch.nn as nn
from fedbiomed.common.torchnnDP import TorchTrainingDPPlan
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

# Here we define the model to be used.
# You can use any class name (here 'MyTrainingPlan'), as long as it matches
# the `model_class` passed to `Experiment`.
class MyTrainingPlan(TorchTrainingDPPlan):
    def __init__(self):
        super(MyTrainingPlan, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader classes
        # Since we will train on MNIST, we need datasets and transform from torchvision
        deps = ["from torchvision import datasets, transforms",
               "from torch.utils.data import DataLoader",
               "import torch.nn.functional as F"]
        self.add_dependency(deps)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        
        
        output = F.log_softmax(x, dim=1)
        return output

    def training_data(self, batch_size = 48):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        data_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
        return data_loader
    
    def training_step(self, data, target):
        output = self.forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss


Writing /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/tmp/tmp95w166dq/class_export_mnist.py
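
The input size of `fc1` (9216) follows from the layer shapes in `MyTrainingPlan`. The sketch below reproduces that arithmetic, assuming 28x28 single-channel MNIST images and no padding. The helper `conv_out` is hypothetical, written only for this illustration, and is not part of Fed-BioMed or torch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                            # assumed input: 28x28 MNIST image
side = conv_out(side, 3)             # conv1: kernel 3, stride 1
side = conv_out(side, 3)             # conv2: kernel 3, stride 1
side = conv_out(side, 2, stride=2)   # max_pool2d(x, 2): kernel 2, stride 2
flat_features = 64 * side * side     # conv2 produces 64 channels
print(side, flat_features)
```

If you change the kernel sizes or the number of channels in the model, the first argument of `nn.Linear` for `fc1` has to be recomputed the same way.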


The two groups of arguments defined below are:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.

**NOTE:** typos and/or missing required (positional) arguments will raise an error. 🤓

In [4]:
model_args = {}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'epochs': 3, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
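
As a rough check of what `batch_maxnum` implies, the sketch below computes the upper bound on the number of samples a node processes with the values used in this notebook. It assumes each node holds at least that many samples; the variable names `samples_per_epoch` and `samples_per_round` are introduced only for this illustration.

```python
batch_size = 48
batch_maxnum = 100
epochs = 3

# Upper bound on samples seen by one node in one epoch when batch_maxnum is set
samples_per_epoch = batch_size * batch_maxnum
# Upper bound on samples seen by one node in one round of local training
samples_per_round = samples_per_epoch * epochs
print(samples_per_epoch, samples_per_round)
```

This is far less than the full MNIST training set, which is why the setting is described as a fast pass for development.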

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filtering on a list of node IDs with `nodes`
- run a round of local training on nodes with model defined in `model_path` + federation with `aggregator`
- run for `rounds` rounds, applying the `node_selection_strategy` between the rounds
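
To give an idea of what the federation step does, here is a minimal sketch of federated averaging over plain Python lists. It illustrates the general technique only and is not the code of Fed-BioMed's `FedAverage`. The function `federated_average`, the parameter names, and the choice of node weights are all assumptions made for this example.

```python
def federated_average(node_params, node_weights):
    """Weighted average of per-node parameter dicts (name -> list of floats)."""
    total = sum(node_weights)
    averaged = {}
    for name in node_params[0]:
        length = len(node_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for params, w in zip(node_params, node_weights)) / total
            for i in range(length)
        ]
    return averaged

# Two hypothetical nodes, the second weighted three times as much as the first
node_a = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
node_b = {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]}
avg = federated_average([node_a, node_b], node_weights=[1, 3])
print(avg)
```

In practice the weights are often chosen proportional to the number of samples on each node, so that nodes with more data contribute more to the aggregated model.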

In [5]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 3

exp = Experiment(tags=tags,
                 #nodes=None,
                 model_path=model_file,
                 model_args=model_args,
                 model_class='MyTrainingPlan',
                 training_args=training_args,
                 rounds=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2022-01-21 09:43:05,034 fedbiomed INFO - Messaging researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9 successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x132cf0820>
2022-01-21 09:43:05,071 fedbiomed INFO - Searching dataset with data tags: ['#MNIST', '#dataset'] for all nodes
2022-01-21 09:43:05,119 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Message received: {'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'tags': ['#MNIST', '#dataset'], 'command': 'search'}
2022-01-21 09:43:05,129 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Message received: {'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'tags': ['#MNIST', '#dataset'], 'command': 'search'}
2022-01-21 09:43:15,093 fedbiomed INFO - Node selected for training -> node_edb44109-8e5f-4741-adfb-5e68136b3bab
2022-01-21 09:43:15,094 fedbiomed INFO - Node selected for training -> nod

Let's start the experiment.

By default, this function doesn't stop until all the `rounds` are done for all the nodes

In [6]:
exp.run()

2022-01-21 09:43:33,374 fedbiomed INFO - Sampled nodes in round 0 ['node_edb44109-8e5f-4741-adfb-5e68136b3bab', 'node_8dce5575-7403-48f0-91b6-e07e58ae5c47']
01/21/2022 09:43:33:INFO:Sampled nodes in round 0 ['node_edb44109-8e5f-4741-adfb-5e68136b3bab', 'node_8dce5575-7403-48f0-91b6-e07e58ae5c47']
2022-01-21 09:43:33,375 fedbiomed INFO - Send message to node node_edb44109-8e5f-4741-adfb-5e68136b3bab - {'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'job_id': 'c9be5f7d-533e-4224-b0df-7a1169c0e542', 'training_args': {'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}, 'model_args': {}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/01/21/my_model_80387a84-c8eb-4f87-aaae-d6accbef52ef.py', 'params_url': 'http://localhost:8844/media/uploads/2022/01/21/aggregated_params_init_664613e5-2329-4fab-a1bc-720113dd9d24.pt', 'model_class': 'MyTrainingPlan', 'training_data': {'node_edb44109-8e5f-4741-adfb-5e68136b3bab': ['

01/21/2022 09:43:33:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - [TASKS QUEUE] Item:{'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'job_id': 'c9be5f7d-533e-4224-b0df-7a1169c0e542', 'params_url': 'http://localhost:8844/media/uploads/2022/01/21/aggregated_params_init_664613e5-2329-4fab-a1bc-720113dd9d24.pt', 'training_args': {'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}, 'training_data': {'node_8dce5575-7403-48f0-91b6-e07e58ae5c47': ['dataset_7adabeb5-d317-4167-a9e5-71b3cd07ab42']}, 'model_args': {}, 'model_url': 'http://localhost:8844/media/uploads/2022/01/21/my_model_80387a84-c8eb-4f87-aaae-d6accbef52ef.py', 'model_class': 'MyTrainingPlan', 'command': 'train'}
2022-01-21 09:43:34,587 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - {'monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x1315eefa0>, 'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'b

2022-01-21 09:44:07,080 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:44:07:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:44:38,466 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:44:38:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:44:38,863 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:44:38:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data


2022-01-21 09:45:02,886 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:45:02:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:45:03,556 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:45:03:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:45:05,284 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / INFO - results uploaded successfully 
01/21/2022 09:45:05:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / INFO - results uploaded successfully 
2022-01-21 09:45:05,536 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - results uploaded successfully 
01/21

2022-01-21 09:45:15,592 fedbiomed DEBUG - researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9
01/21/2022 09:45:15:DEBUG:researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9
2022-01-21 09:45:15,599 fedbiomed INFO - Send message to node node_8dce5575-7403-48f0-91b6-e07e58ae5c47 - {'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'job_id': 'c9be5f7d-533e-4224-b0df-7a1169c0e542', 'training_args': {'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}, 'model_args': {}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/01/21/my_model_80387a84-c8eb-4f87-aaae-d6accbef52ef.py', 'params_url': 'http://localhost:8844/media/uploads/2022/01/21/aggregated_params_f5259ecc-10f4-4358-a2d9-a3253052357b.pt', 'model_class': 'MyTrainingPlan', 'training_data': {'node_8dce5575-7403-48f0-91b6-e07e58ae5c47': ['dataset_7adabeb5-d317-4167-a9e5-71b3cd07ab42']}}
01/21/2022 09:45:15:INFO:Send message to node node_8dce5575-7403-48f0-91b6-e07e58ae5c47 - 

2022-01-21 09:45:16,969 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Dataset_path/Users/balelli/data
01/21/2022 09:45:16:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Dataset_path/Users/balelli/data
2022-01-21 09:45:17,068 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - {'monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x1315ee280>, 'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}
01/21/2022 09:45:17:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - {'monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x1315ee280>, 'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}
2022-01-21 09:45:17,124 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Dataset_path/Users/balelli/data
01/21/2022 09:45:17:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Data

2022-01-21 09:46:23,417 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:46:23:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:46:25,400 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:46:25:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data


2022-01-21 09:46:48,071 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:46:48:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:46:49,717 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - results uploaded successfully 
01/21/2022 09:46:49:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - results uploaded successfully 
2022-01-21 09:46:50,563 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:46:50:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:46:52,580 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / INFO - results uploaded successfully 
01/21

2022-01-21 09:47:02,402 fedbiomed DEBUG - researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9
01/21/2022 09:47:02:DEBUG:researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9
2022-01-21 09:47:02,486 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Message received: {'researcher_id': 'researcher_bc5777a3-5027-465f-8742-3d3aa04a29c9', 'job_id': 'c9be5f7d-533e-4224-b0df-7a1169c0e542', 'training_args': {'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 100}, 'model_args': {}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/01/21/my_model_80387a84-c8eb-4f87-aaae-d6accbef52ef.py', 'params_url': 'http://localhost:8844/media/uploads/2022/01/21/aggregated_params_e52a3826-ace9-4217-9ee5-f02dba48c9a6.pt', 'model_class': 'MyTrainingPlan', 'training_data': {'node_edb44109-8e5f-4741-adfb-5e68136b3bab': ['dataset_7f491de7-9375-4f96-a07d-f58bca25deb5']}}
01/21/2022 09:47:02:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3ba

2022-01-21 09:47:26,462 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:47:26:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:47:27,103 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:47:27:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data


2022-01-21 09:47:54,197 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:47:54:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:47:57,394 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:47:57:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data


2022-01-21 09:48:18,682 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:48:18:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:48:20,469 fedbiomed INFO - log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - results uploaded successfully 
01/21/2022 09:48:20:INFO:log from: node_edb44109-8e5f-4741-adfb-5e68136b3bab / INFO - results uploaded successfully 
2022-01-21 09:48:21,156 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
01/21/2022 09:48:21:INFO:log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / DEBUG - Reached 100 batches for this epoch, ignore remaining data
2022-01-21 09:48:23,012 fedbiomed INFO - log from: node_8dce5575-7403-48f0-91b6-e07e58ae5c47 / INFO - results uploaded successfully 
01/21

Local training results for each round and each node are available in `exp.training_replies` (indexed from 0 to `rounds` - 1).

For example, you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round:
- `rtime_training` real time (clock time) spent in the training function on the node
- `ptime_training` process time (user and system CPU) spent in the training function on the node
- `rtime_total` real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer
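
The difference between real time and process time can be illustrated with Python's standard `time` module. This is only a sketch: whether Fed-BioMed measures its timings with these exact clocks is an assumption, and the `timed` helper is hypothetical.

```python
import time

def timed(func):
    """Run func and return (real_seconds, process_seconds)."""
    r0, p0 = time.perf_counter(), time.process_time()
    func()
    return time.perf_counter() - r0, time.process_time() - p0

# Sleeping consumes wall-clock time but almost no CPU time
sleep_r, sleep_p = timed(lambda: time.sleep(0.2))
# A busy loop consumes CPU time for most of its wall-clock duration
busy_r, busy_p = timed(lambda: sum(i * i for i in range(200_000)))

print(f"sleep: real={sleep_r:.3f}s process={sleep_p:.3f}s")
print(f"busy:  real={busy_r:.3f}s process={busy_p:.3f}s")
```

A gap between real time and process time, as in the timings reported for this notebook, typically indicates time spent waiting (for example on I/O or on other processes) rather than computing.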

In [7]:
print("\nList the training rounds : ", exp.training_replies.keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies[rounds - 1].data
for c in range(len(round_data)):
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['node_id'],
        rtraining = round_data[c]['timing']['rtime_training'],
        ptraining = round_data[c]['timing']['ptime_training'],
        rtotal = round_data[c]['timing']['rtime_total']))
print('\n')
    
exp.training_replies[rounds - 1].dataframe


List the training rounds :  dict_keys([0, 1, 2])

List the nodes for the last training round and their timings : 
	- node_edb44109-8e5f-4741-adfb-5e68136b3bab :    
		rtime_training=75.38 seconds    
		ptime_training=57.99 seconds    
		rtime_total=90.12 seconds
	- node_8dce5575-7403-48f0-91b6-e07e58ae5c47 :    
		rtime_training=77.89 seconds    
		ptime_training=59.39 seconds    
		rtime_total=90.45 seconds




| | success | msg | dataset_id | node_id | params_path | params | timing |
|---|---|---|---|---|---|---|---|
| 0 | True | | dataset_7f491de7-9375-4f96-a07d-f58bca25deb5 | node_edb44109-8e5f-4741-adfb-5e68136b3bab | /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed... | {'conv1.weight': [[tensor([[ 0.2246, -0.2867, ... | {'rtime_training': 75.37591360500028, 'ptime_t... |
| 1 | True | | dataset_7adabeb5-d317-4167-a9e5-71b3cd07ab42 | node_8dce5575-7403-48f0-91b6-e07e58ae5c47 | /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed... | {'conv1.weight': [[tensor([[ 0.2099, -0.2984, ... | {'rtime_training': 77.89482873199995, 'ptime_t... |


Federated parameters for each round are available in `exp.aggregated_params` (indexed from 0 to `rounds` - 1).

For example, you can view the federated parameters for the last round of the experiment:

In [8]:
print("\nList the training rounds : ", exp.aggregated_params.keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", exp.aggregated_params[rounds - 1]['params_path'])
print("\t- parameter data: ", exp.aggregated_params[rounds - 1]['params'].keys())



List the training rounds :  dict_keys([0, 1, 2])

Access the federated params for the last training round :
	- params_path:  /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0003/aggregated_params_4e15b612-c480-4d83-89ea-9d42b2f549bf.pt
	- parameter data:  odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])


Feel free to run other sample notebooks or try your own models :D

# Testing

We define a small testing routine to compute the loss and accuracy on the test dataset.

In [9]:
import torch
import torch.nn.functional as F


def testing_Accuracy(model, data_loader):
    """Return (average NLL loss, accuracy in percent) of `model` over `data_loader`."""
    model.eval()
    test_loss = 0
    correct = 0
    device = 'cpu'

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    # Average over the whole dataset
    test_loss /= len(data_loader.dataset)
    accuracy = 100 * correct / len(data_loader.dataset)

    return test_loss, accuracy

In [10]:
from torchvision import datasets, transforms
import os

local_mnist = os.path.join(environ['TMP_DIR'], 'local_mnist')

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

test_set = datasets.MNIST(root = local_mnist, download = True, train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

fed_model = exp.model_instance
fed_model.load_state_dict(exp.aggregated_params[rounds - 1]['params'])

acc_federated = testing_Accuracy(fed_model, test_loader)

print('\nAccuracy federated training:  {:.4f}'.format(acc_federated[1]))

print('\nError federated training:  {:.4f}'.format(acc_federated[0]))


Accuracy federated training:  98.4100

Error federated training:  0.0485
