# Simulation of FL experiment

## Libraries & auxiliary functions

In [1]:
import torch
import traceback

import tqdm as tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torchvision import models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import precision_recall_fscore_support
from copy import deepcopy
from collections import OrderedDict


In [2]:
# Load the scalar series of a TensorBoard run into a pandas DataFrame.
def tflog2pandas(path):
    """Return all scalar events found at ``path`` as a long-format DataFrame.

    The result has the columns ``metric`` (the scalar tag), ``value`` and
    ``step``, with one row per recorded event. Rows are appended tag by tag
    and the index is not reset, so index values repeat across tags.

    Any exception raised while reading the events is reported on stdout
    together with its traceback, and whatever was collected up to that point
    (possibly an empty frame) is returned instead of propagating the error.
    """
    runlog_data = pd.DataFrame({"metric": [], "value": [], "step": []})
    try:
        accumulator = EventAccumulator(path)
        accumulator.Reload()
        for tag in accumulator.Tags()["scalars"]:
            events = accumulator.Scalars(tag)
            steps = [event.step for event in events]
            tag_frame = pd.DataFrame(
                {
                    "metric": [tag] * len(steps),
                    "value": [event.value for event in events],
                    "step": steps,
                }
            )
            runlog_data = pd.concat([runlog_data, tag_frame])
    except Exception:
        # Best effort: report the problem and fall through to the return.
        print("Event file possibly corrupt: {}".format(path))
        traceback.print_exc()
    return runlog_data

## Setting the experiment

* Set the director node's FQDN and port, then connect to the federation

In [2]:
# Connection settings for the federation director.
client_id = 'api'
# NOTE(review): cert_dir is not referenced anywhere else in the visible code.
cert_dir = 'cert'
# Placeholder host name; replace the XX parts with the real director address.
director_node_fqdn = 'ec2-XX-XX-XX-XXX.compute-1.amazonaws.com'
director_port = 4444

# Federation handle used by the cells below; TLS is switched off here.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port, 
    tls=False
)

In [3]:
# Ask the federation for its shard registry.
shard_registry = federation.get_shard_registry()
# Bare expression so the notebook renders the value as the cell output.
shard_registry

({}, ['224', '224', '3'])

In [4]:
# Fetch a dummy shard descriptor from the federation so the data format can
# be inspected locally (presumably size=2 is the number of samples -- TODO confirm).
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=2)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Look at the first (sample, target) pair and print the shape of each part.
sample, target = dummy_shard_dataset[0]
print(sample.shape)
print(target.shape)

(224, 224, 3)
(224, 224, 3)


## Data Preparation

In [5]:
class TransformedDataset(Dataset):
    """Wrap an indexable dataset and apply optional transforms on access.

    ``dataset[index]`` must yield an ``(image, label)`` pair. ``transform`` is
    applied to the image and ``target_transform`` to the label; a falsy value
    (such as ``None``) leaves the corresponding element untouched.
    """

    def __init__(self, dataset, transform=None, target_transform=None):
        """Remember the wrapped dataset and the two optional callables."""
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index`` after transforms."""
        sample, target = self.dataset[index]
        # The label transform runs first, then the image transform.
        if self.target_transform:
            target = self.target_transform(target)
        if self.transform:
            sample = self.transform(sample)
        return sample, target

In [6]:
class ChestXrayDataset(DataInterface):
    """Data interface that serves the train/val splits of a shard.

    Keyword arguments given at construction are kept in ``self.kwargs``;
    ``train_bs`` and ``valid_bs`` are read from it as the batch sizes of the
    training and validation loaders.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """The shard descriptor most recently assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Store the descriptor and build the train/validation datasets from it.

        Per the original note, this is invoked during collaborator
        initialization, with the local shard descriptor supplied by the Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # No image transform is applied to either split.
        train_split = shard_descriptor.get_dataset('train')
        self.train_set = TransformedDataset(train_split, transform=None)

        val_split = shard_descriptor.get_dataset('val')
        self.valid_set = TransformedDataset(val_split, transform=None)

    def get_train_loader(self, **kwargs):
        """
        Build the loader handed to tasks that have an optimizer in their contract.

        Shuffling uses a generator re-seeded with 0 on every call.
        """
        seeded_rng = torch.Generator().manual_seed(0)
        return DataLoader(
            self.train_set,
            batch_size=self.kwargs['train_bs'],
            shuffle=True,
            generator=seeded_rng,
        )

    def get_valid_loader(self, **kwargs):
        """
        Build the loader handed to tasks without an optimizer in their contract.
        """
        batch_size = self.kwargs['valid_bs']
        return DataLoader(self.valid_set, batch_size=batch_size)

    def get_train_data_size(self):
        """
        Number of training samples (information for aggregation).
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Number of validation samples (information for aggregation).
        """
        return len(self.valid_set)

In [7]:
# Data interface for the experiment; both loaders use a batch size of 8.
fed_dataset = ChestXrayDataset(train_bs=8, valid_bs=8)

## Describe the model and optimizer

In [None]:
class ConvNet(torch.nn.Module):
    """Small CNN: three conv/batch-norm/ReLU stages, max-pool, linear head.

    Args:
        width: Base channel count; the stages output ``width``, ``2 * width``
            and ``2 * width`` channels.
        num_classes: Number of outputs of the final linear layer.
        num_channels: Number of channels of the input images.
    """

    def __init__(self, width=32, num_classes=2, num_channels=3):
        """Build the layer stack under the attribute ``model``."""
        super().__init__()
        channel_plan = (num_channels, 1 * width, 2 * width, 2 * width)

        layers = OrderedDict()
        # Layer names and creation order are part of the state-dict layout.
        for idx, (in_ch, out_ch) in enumerate(zip(channel_plan, channel_plan[1:])):
            layers[f'conv{idx}'] = torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            layers[f'bn{idx}'] = torch.nn.BatchNorm2d(out_ch)
            layers[f'relu{idx}'] = torch.nn.ReLU()

        layers['pool1'] = torch.nn.MaxPool2d(3)
        layers['flatten'] = torch.nn.Flatten()
        # Disabled experiment: ('var_bottleneck', VariationalBottleneck((10952 * width,)))
        # 10952 * width = (2 * width) channels * 74 * 74, i.e. the flattened size
        # for a 224x224 spatial input after the 3x3 max-pool.
        layers['linear'] = torch.nn.Linear(10952 * width, num_classes)

        self.model = torch.nn.Sequential(layers)

    def forward(self, input):
        """Run ``input`` through the sequential stack and return its output."""
        return self.model(input)

# Model instance with the default settings (width=32, 2 classes, 3 input channels).
model_net = ConvNet()

In [10]:
# Hand only the parameters that require gradients to the optimizer.
params_to_update = [param for param in model_net.parameters() if param.requires_grad]

# --- FedProx alternative (disabled) ---
#from openfl.utilities.optimizers.torch import FedProxAdam
#optimizer = FedProxAdam(params_to_update, lr=1e-4, mu=0.01)

# --- Original setup: Adadelta ---
optimizer = optim.Adadelta(params_to_update, lr=0.05)

# --- Other optimizers that were tried (disabled) ---
#optimizer = optim.Adam(params_to_update, lr = 0.01)
#optimizer = optim.AdamW(params_to_update, lr=0.001, weight_decay=0.02)
#optimizer = optim.SGD(params_to_update, lr=0.01, momentum=0.9, weight_decay=0.0005)

# --- Scheduler (disabled) ---
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
def cross_entropy(output, target):
    """Return the cross-entropy loss of ``output`` against ``target``.

    Uses the library defaults, i.e. the same result as calling a freshly
    constructed ``nn.CrossEntropyLoss()`` (mean reduction over the batch).
    """
    return nn.functional.cross_entropy(output, target)

In [11]:
# Dotted path of the framework adapter plugin, passed to ModelInterface below.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model and its optimizer for the experiment.
model_interface = ModelInterface(model=model_net, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state
# (a deep copy, so later changes to model_net do not affect this snapshot).
initial_model = deepcopy(model_net)

In [13]:
# Registry object whose decorators are applied to the train/validate tasks below.
task_interface = TaskInterface()

# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a one-line message that echoes ``some_parameter``."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
#@task_interface.set_aggregation_function(FedCurvWeightedAverage())
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    """Train ``net_model`` for two local epochs and report the mean loss.

    Args:
        net_model: Model to train; it is switched to train mode and moved to
            the CPU.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer that updates ``net_model``'s parameters.
        device: Accepted for the task contract but ignored; training is pinned
            to the CPU and batches are used as delivered by the loader.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        some_parameter: Extra task kwarg injected by ``add_kwargs``; unused.

    Returns:
        ``{'train_loss': value}`` where ``value`` is the per-sample mean loss,
        averaged over the local epochs.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.

    Side effects:
        Writes the model's state dict to ``saved_state.pth`` whenever an
        epoch's mean loss improves on the best one seen during this call.
    """
    # Pinned to the CPU on purpose; the incoming ``device`` is not honoured.
    device = 'cpu'
    # A new scheduler is created on every call, so its plateau history only
    # covers the epochs of this call.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.75, patience=2)
    net_model.train()
    net_model.to(device)

    epochs = 2
    epoch_losses = []
    best_loss = float('inf')

    for _ in tqdm.tqdm(range(epochs)):
        running_loss = 0.0
        samples_seen = 0

        # Wrap the loader per epoch so every epoch gets its own progress bar.
        for data, target in tqdm.tqdm(train_loader, desc="train"):
            optimizer.zero_grad()
            output = net_model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller last batch is not over-counted.
            batch_size = data.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        epoch_loss = running_loss / samples_seen
        epoch_losses.append(epoch_loss)

        # Step the plateau scheduler with the loss actually measured this epoch.
        scheduler.step(epoch_loss)

        if epoch_loss < best_loss:
            print('Train loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(best_loss, epoch_loss))
            torch.save(net_model.state_dict(), 'saved_state.pth')
            best_loss = epoch_loss

    return {'train_loss': np.mean(epoch_losses),}

@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """Evaluate ``net_model`` on the validation data and report its accuracy.

    Args:
        net_model: Model to evaluate; it is switched to eval mode and moved to
            the CPU.
        val_loader: Iterable of ``(data, target)`` batches.
        device: Accepted for the task contract but ignored; evaluation is
            pinned to the CPU.

    Returns:
        ``{'acc': value}`` where ``value`` is the percentage (0-100) of
        samples whose highest-scoring class equals the target.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.

    Only accuracy is reported. The loss, precision, recall and F1 values that
    used to be computed here never reached the returned dict, so they are no
    longer calculated.
    """
    # Re-seeds the global torch RNG on every validation call.
    torch.manual_seed(0)
    # Pinned to the CPU on purpose; the incoming ``device`` is not honoured.
    device = torch.device('cpu')
    net_model.eval()
    net_model.to(device)
    number_correct, number_data = 0, 0

    with torch.no_grad():
        for data, target in tqdm.tqdm(val_loader, desc="validate"):
            output = net_model(data)
            _, pred = torch.max(output, 1)
            # Count on the tensor itself: squeezing to NumPy first turned a
            # batch of one into a 0-d array that can neither be summed with
            # the built-in ``sum`` nor indexed via ``shape[0]``.
            hits = pred.eq(target.data.view_as(pred))
            number_correct += int(hits.sum().item())
            number_data += pred.numel()

    accuracy = 100 * number_correct / number_data
    return {'acc': np.float64(accuracy),}

## Define and start the experiment

In [14]:
# Name under which this experiment is created on the federation.
experiment_name = 'ChestXray_EPOCHS2_ROUND5_CNN'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [15]:
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
# and starts the experiment for 5 rounds with the model, tasks and data interface defined above.
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL'
)



In [16]:
# IPython magic (not plain Python): opens TensorBoard on the logs/fit directory.
# NOTE(review): no `%load_ext tensorboard` is visible in this notebook -- confirm
# the extension is loaded before this cell runs.
%tensorboard --logdir logs/fit
# Stream the experiment's metrics, with TensorBoard logging enabled.
# NOTE(review): the captured output below shows this call ending with a gRPC
# "Socket closed" (UNAVAILABLE) error from the director connection.
fl_experiment.stream_metrics(tensorboard_logs=True)

_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "Socket closed"
	debug_error_string = "UNKNOWN:Error received from peer ipv4:172.31.0.197:4444 {grpc_message:"Socket closed", grpc_status:14, created_time:"2023-05-10T19:24:22.461920377+00:00"}"
>