# Open Federated Learning Training

## Imports

In [None]:
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
import os
import glob
import tqdm
import numpy as np
from copy import deepcopy
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms as T
from torch.utils.data import Dataset

In [None]:
# Seed NumPy's and PyTorch's global RNGs so that the initial model weights
# created later in this notebook (Net() uses torch's RNG) are reproducible.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

## Connect to the Federation

In [None]:
# Connection settings for the director.
client_id = 'api'
cert_dir = 'cert'  # NOTE(review): not referenced in this notebook while tls=False
director_node_fqdn = 'localhost'

federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port='50051',
    tls=False,
)

In [None]:
# Display the federation's `target_shape` as this cell's output.
federation.target_shape

In [None]:
# Fetch the shard registry from the federation and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Build a dummy shard descriptor (size=10) and take the first item of its
# 'train' split to inspect the data format locally. `sample` and `target`
# are not used further in this notebook section.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
sample, target = dummy_shard_desc.get_dataset('train')[0]

## Dataset

In [None]:
# Per-channel ImageNet mean/std normalization.
# NOTE(review): defined but not applied by either transform below.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Random flip/rotation/crop, applied as a group with probability 0.8.
# NOTE(review): not applied by either transform below; RandomResizedCrop(64)
# would yield 64x64 images, which appears incompatible with the model's fc1
# sizing (expects ~150x150 inputs) — confirm before wiring it in.
augmentation = T.RandomApply(
    [
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.RandomResizedCrop(64),
    ],
    p=0.8,
)

# Both splits currently only convert images to tensors.
training_transform = T.ToTensor()
valid_transform = T.ToTensor()

In [None]:
class TransformedDataset(Dataset):
    """Wrap an indexable dataset of ``(image, label)`` pairs with optional transforms.

    Transforms are applied lazily on every ``__getitem__``; nothing is cached.
    A falsy ``transform`` / ``target_transform`` leaves that element unchanged.
    """

    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        # The label is transformed before the image, matching the original order.
        if self.target_transform:
            target = self.target_transform(target)
        if self.transform:
            sample = self.transform(sample)
        return sample, target
    
class HistologyDataset(DataInterface):
    """Data interface that builds train/validation DataLoaders from a shard descriptor.

    Keyword arguments are stored in ``self.kwargs``. Keys read by this class:
        train_bs (int): training batch size (required by ``get_train_loader``).
        valid_bs (int): validation batch size (required by ``get_valid_loader``).
        num_workers (int, optional): DataLoader worker processes; defaults to 8.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """The shard descriptor most recently assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """Store the descriptor and wrap its 'train'/'val' splits with the module-level transforms."""
        self._shard_descriptor = shard_descriptor
        self.train_set = TransformedDataset(
            shard_descriptor.get_dataset('train'),
            transform=training_transform,
        )
        self.valid_set = TransformedDataset(
            shard_descriptor.get_dataset('val'),
            transform=valid_transform,
        )

    def _num_workers(self):
        # A hard-coded 8 workers can oversubscribe small machines; allow an override.
        return self.kwargs.get('num_workers', 8)

    def get_train_loader(self, **kwargs):
        """Return a shuffled DataLoader over the training split."""
        return DataLoader(
            self.train_set,
            num_workers=self._num_workers(),
            batch_size=self.kwargs['train_bs'],
            shuffle=True,
        )

    def get_valid_loader(self, **kwargs):
        """Return an unshuffled DataLoader over the validation split."""
        return DataLoader(
            self.valid_set,
            num_workers=self._num_workers(),
            batch_size=self.kwargs['valid_bs'],
        )

    def get_train_data_size(self):
        """Number of training samples in the assigned shard."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation samples in the assigned shard."""
        return len(self.valid_set)

In [None]:
# Data interface with batch size 4 for both the train and validation loaders
# (read back via self.kwargs['train_bs'] / self.kwargs['valid_bs']).
fed_dataset = HistologyDataset(train_bs=4, valid_bs=4)

## Model

In [None]:
class Net(nn.Module):
    """Small CNN with dense-style skip concatenations producing 2 class logits.

    Expects 3-channel input whose height and width lie in [144, 159]
    (e.g. 150x150): four 2x2 max-pools must reduce the map to 9x9 so the
    flattened features match ``fc1``.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        # Layers are created in the original order so seeded initialization is unchanged.
        self.conv1 = nn.Conv2d(3, 16, **conv_kwargs)
        self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)
        self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)
        self.conv4 = nn.Conv2d(64, 128, **conv_kwargs)
        # Input channels include the skip features concatenated by the previous stage.
        self.conv5 = nn.Conv2d(128 + 32, 256, **conv_kwargs)
        self.conv6 = nn.Conv2d(256, 512, **conv_kwargs)
        self.conv7 = nn.Conv2d(512 + 128 + 32, 256, **conv_kwargs)
        self.conv8 = nn.Conv2d(256, 512, **conv_kwargs)
        # 1184 = 512 + 512 + 128 + 32 channels after the last concatenation, on a 9x9 map.
        self.fc1 = nn.Linear(1184 * 9 * 9, 128)
        self.fc2 = nn.Linear(128, 2)

    @staticmethod
    def _dense_stage(x, conv_a, conv_b):
        """Apply two ReLU convs, concatenate their output with ``x`` on channels, then 2x2 max-pool."""
        out = F.relu(conv_b(F.relu(conv_a(x))))
        return F.max_pool2d(torch.cat([x, out], dim=1), 2, 2)

    def forward(self, x):
        """Return class logits of shape ``(batch, 2)``.

        The global RNG is deliberately not reseeded here: reseeding on every
        call reuses one dropout mask for every batch and resets RNG state that
        the rest of the program (e.g. DataLoader shuffling) relies on.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        x = self._dense_stage(x, self.conv3, self.conv4)
        x = self._dense_stage(x, self.conv5, self.conv6)
        x = self._dense_stage(x, self.conv7, self.conv8)

        x = x.flatten(start_dim=1)
        # F.dropout defaults to training=True; tie it to the module mode so eval() disables it.
        x = F.dropout(self.fc1(x), p=0.5, training=self.training)
        return self.fc2(x)

# Global model instance and the Adam optimizer over its parameters.
model_net = Net()
optimizer_adam = optim.Adam(params=model_net.parameters(), lr=0.0001)

In [None]:
# Dotted path of the PyTorch framework adapter plugin used by ModelInterface.
framework_adapter = (
    'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
)
model_interface = ModelInterface(
    model=model_net,
    optimizer=optimizer_adam,
    framework_plugin=framework_adapter,
)
# Independent copy of the model as it is before the experiment starts.
initial_model = deepcopy(model_net)

## FL Tasks

In [None]:
# Task registry: FL task functions are attached to it with its
# `register_fl_task` and `add_kwargs` decorators.
task_interface = TaskInterface()

def function_defined_in_notebook(some_parameter):
    """Print a demo message that includes ``some_parameter``.

    Called from the ``train`` task to show that helpers defined in the
    notebook are available inside tasks.
    """
    message = 'Also I accept a parameter and it is {}'.format(some_parameter)
    print(message)

@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader',
                                 device='device', optimizer='optimizer')
def train(net_model, train_loader, optimizer, device, loss_fn=F.cross_entropy, some_parameter=None):
    """Run one local training pass over ``train_loader``.

    Args:
        net_model: the ``nn.Module`` to train; it is moved to the chosen device.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``net_model``'s parameters.
        device: device assigned by the framework (e.g. ``'cuda'``, ``'cuda:1'``,
            ``'cpu'``). It is honoured, except that a CUDA device falls back to
            CPU when CUDA is unavailable. ``None`` means "CUDA if available".
        loss_fn: callable ``loss_fn(output, target)``; cross-entropy by default.
        some_parameter: demo kwarg injected via ``add_kwargs``; only printed.

    Returns:
        ``{'train_loss': mean per-batch loss}``, or NaN if the loader yielded no batches.
    """
    # Respect the framework's device assignment instead of always picking 'cuda'.
    device = torch.device(device if device is not None else 'cuda')
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    function_defined_in_notebook(some_parameter)

    net_model.train()
    net_model.to(device)

    losses = []
    for data, target in tqdm.tqdm(train_loader, desc="train"):
        # as_tensor avoids the copy (and UserWarning) torch.tensor() makes for tensor input.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        loss = loss_fn(net_model(data), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # np.mean([]) warns and yields nan; make the empty-loader case explicit.
    train_loss = np.mean(losses) if losses else float('nan')
    return {'train_loss': train_loss}


@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """Compute top-1 accuracy of ``net_model`` over ``val_loader``.

    Args:
        net_model: the ``nn.Module`` to evaluate; switched to eval mode and
            moved to the chosen device.
        val_loader: iterable of ``(data, target)`` batches.
        device: device assigned by the framework. It is honoured, except that a
            CUDA device falls back to CPU when CUDA is unavailable. ``None``
            means "CUDA if available".

    Returns:
        ``{'acc': correct / total}``, or NaN if the loader yielded no samples
        (instead of raising ZeroDivisionError).
    """
    # Respect the framework's device assignment instead of always picking 'cuda'.
    device = torch.device(device if device is not None else 'cuda')
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    net_model.eval()
    net_model.to(device)

    correct = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in tqdm.tqdm(val_loader, desc="validate"):
            # as_tensor avoids the copy (and UserWarning) torch.tensor() makes for tensor input.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device)
            total_samples += target.shape[0]
            pred = net_model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()

    if total_samples == 0:
        return {'acc': float('nan')}
    return {'acc': correct / total_samples}

## Start the Federated Learning Experiment

In [None]:
# Name passed to FLExperiment for this run (a plain string; no interpolation needed).
experiment_name = 'test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
fl_experiment.start(
    # What to train, with which data, using which registered tasks.
    model_provider=model_interface,
    data_loader=fed_dataset,
    task_keeper=task_interface,
    # How to run it.
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL',  # optimizer-state handling across rounds — see OpenFL docs
    device_assignment_policy='CUDA_PREFERRED',
)

In [None]:
# Stream experiment metrics into the notebook output.
# NOTE(review): presumably blocks until the experiment finishes — confirm.
fl_experiment.stream_metrics()