In [1]:
import os
import glob

from PIL import Image

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from copy import deepcopy
import torchvision
from torchvision import transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tqdm

import albumentations as A
from model import Model

# Fix the global RNG seeds (NumPy and PyTorch) so repeated runs start from the same state.
np.random.seed(0)
torch.manual_seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Report whether PyTorch can use CUDA in this kernel.
torch.cuda.is_available()

False

In [4]:
# Number of CUDA devices PyTorch reports for this kernel.
torch.cuda.device_count()

0

In [24]:
# Settings used to reach the director from this notebook.
client_id = 'api'
cert_dir = 'cert'  # defined here but not passed to Federation below (tls=False)
director_node_fqdn = 'localhost'

federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port='50051',
    tls=False,
)

In [25]:
# Display the target shape reported by the federation object.
federation.target_shape

['1']

In [26]:
# Fetch the federation's shard registry and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

{'envoy_0': {'shard_info': node_info {
    name: "envoy_0"
  }
  shard_description: "Local MRI Shard Descriptor is working."
  sample_shape: "4"
  sample_shape: "256"
  sample_shape: "256"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-05-24 17:01:14',
  'current_time': '2022-05-24 17:01:39',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'envoy_2': {'shard_info': node_info {
    name: "envoy_2"
  }
  shard_description: "Local MRI Shard Descriptor is working."
  sample_shape: "4"
  sample_shape: "256"
  sample_shape: "256"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-05-24 17:01:14',
  'current_time': '2022-05-24 17:01:39',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'envoy_1': {'shard_info': node_info {
    name: "envoy_1"
  }
  shard_description: "Local MRI Shard Descrip

In [27]:
# Pull a 10-element dummy shard descriptor and look at the shapes of its first
# training item (sample first, target second, one per line).
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
print(sample.shape, target.shape, sep='\n')


(4, 256, 256, 3)
(1,)


In [28]:
# Augmentation pipeline; every transform in it is configured with p=0.5.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
])

In [29]:
class MRIImageDataset(DataInterface):
    """Data interface that builds PyTorch loaders from a shard descriptor.

    The keyword arguments given to the constructor are stored unchanged;
    ``train_bs`` and ``valid_bs`` are read later as the batch sizes of the
    training and validation loaders.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Shard descriptor most recently assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """Store the descriptor, configure it and fetch the train/val datasets.

        The descriptor is asked to split its data with ``test_size=0.2`` and is
        given the module-level ``train_transform`` before the 'train' and 'val'
        datasets are requested from it.
        """
        self._shard_descriptor = shard_descriptor

        self._shard_descriptor.split_dataset(test_size=0.2)
        self._shard_descriptor.set_transform_params(train_transform)

        self.train_set = self._shard_descriptor.get_dataset('train')
        self.valid_set = self._shard_descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """Return a shuffled training loader with batch size ``train_bs``.

        A fresh generator seeded with 0 is created on every call, so each
        loader returned here starts from the same shuffle seed.
        """
        generator = torch.Generator()
        generator.manual_seed(0)
        return DataLoader(
            self.train_set,
            batch_size=self.kwargs['train_bs'],
            shuffle=True,
            generator=generator,
            num_workers=0,
        )

    def get_valid_loader(self, **kwargs):
        """Return the validation loader with batch size ``valid_bs``.

        Validation data is iterated in its stored order: shuffling gives no
        benefit for evaluation and would change how samples are grouped into
        batches from run to run.
        """
        return DataLoader(
            self.valid_set,
            batch_size=self.kwargs['valid_bs'],
            shuffle=False,
            num_workers=0,
        )

    def get_train_data_size(self):
        """Number of items in the training dataset."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of items in the validation dataset."""
        return len(self.valid_set)

In [30]:
# Batch sizes: 4 for the training loader, 8 for the validation loader.
fed_dataset = MRIImageDataset(train_bs=4, valid_bs=8)

In [31]:
# Model, optimizer and loss function used by the tasks registered further down.
model = Model()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = F.binary_cross_entropy_with_logits

# Hand the model and optimizer to OpenFL together with the PyTorch adapter plugin path.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
model_interface = ModelInterface(
    model=model,
    optimizer=optimizer,
    framework_plugin=framework_adapter,
)

# Keep a separate copy of the model as it is right after construction.
initial_model = deepcopy(model)

In [32]:
# Task registry; the train/validate functions below are attached to it with
# its register_fl_task decorator.
task_interface = TaskInterface()

# The Interactive API supports registering functions defined in the main module or imported.
    
class Loss:
    """Running arithmetic mean of the values passed to ``update``."""

    def __init__(self):
        self.avg, self.n = 0, 0

    def update(self, val):
        """Fold one more value into the running mean and bump the count."""
        seen_before = self.n
        self.n = seen_before + 1
        # New mean = share of the new value + re-weighted previous mean.
        self.avg = val / self.n + seen_before / self.n * self.avg

class Acc:
    """Running accuracy of binary predictions given as raw logits.

    A logit ``>= 0`` counts as a positive prediction. ``avg`` is the fraction
    of correct predictions over the ``n`` samples seen so far.
    """

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        """Add one batch of labels and logits to the running accuracy.

        Args:
            y_true: tensor of 0/1 labels (any shape; it is flattened).
            y_pred: tensor of logits with the same number of elements.

        Raises:
            ValueError: if the two tensors hold a different number of elements.
        """
        # Flatten both sides so that e.g. (B, 1) labels against (B,) logits are
        # compared element-wise instead of broadcasting to a (B, B) matrix.
        labels = y_true.detach().cpu().numpy().astype(int).reshape(-1)
        predicted = (y_pred.detach().cpu().numpy() >= 0).reshape(-1)
        if labels.size != predicted.size:
            raise ValueError(
                f'y_true has {labels.size} elements but y_pred has {predicted.size}'
            )
        if labels.size == 0:
            # Nothing to add; also avoids dividing by zero on the first call.
            return

        seen_before = self.n
        self.n += labels.size
        hits = np.sum(labels == predicted)
        self.avg = hits / self.n + seen_before / self.n * self.avg

# Task interface currently supports only standalone functions.
@task_interface.register_fl_task(
    model='model', data_loader='train_loader', device='device', optimizer='optimizer'
)
def train(model, train_loader, optimizer, device, loss_fn=criterion):
    """Run one pass over ``train_loader`` and return the mean loss and accuracy.

    Args:
        model: network to optimise; put in train mode and moved to ``device``.
        train_loader: iterable of batches; each batch is indexed with "X" and "y".
        optimizer: optimizer, stepped once per batch.
        device: device requested by the caller for the model and batch tensors.
        loss_fn: callable taking ``(outputs, targets)``; defaults to ``criterion``.

    Returns:
        dict with 'train_loss' (mean of the per-batch losses) and 'train_acc'.
    """
    # Use the device handed in by the caller; only fall back to the CPU when a
    # CUDA device is requested on a host where CUDA is not available.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    model.to(device)
    model.train()

    train_loss = Loss()
    train_score = Acc()

    train_loader = tqdm.tqdm(train_loader)

    for batch in train_loader:
        X = batch["X"].to(device)
        targets = batch["y"].to(device)

        optimizer.zero_grad()
        outputs = model(X).squeeze(1)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.detach().item())
        train_score.update(targets, outputs.detach())

    return {'train_loss': train_loss.avg, 'train_acc': train_score.avg}


@task_interface.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, loss_fn=criterion):
    """Evaluate ``model`` on ``val_loader`` and return the mean loss and accuracy.

    Args:
        model: network to evaluate; put in eval mode and moved to ``device``.
        val_loader: iterable of batches; each batch is indexed with "X" and "y".
        device: device requested by the caller for the model and batch tensors.
        loss_fn: callable taking ``(outputs, targets)``; defaults to ``criterion``.

    Returns:
        dict with 'val_loss' (mean of the per-batch losses) and 'val_acc'.
    """
    # Use the device handed in by the caller; only fall back to the CPU when a
    # CUDA device is requested on a host where CUDA is not available.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
    # The batches are moved to `device`, so the model has to live there too.
    model.to(device)
    model.eval()

    valid_loss = Loss()
    valid_score = Acc()

    val_loader = tqdm.tqdm(val_loader)

    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for batch in val_loader:
            X = batch["X"].to(device)
            targets = batch["y"].to(device)

            outputs = model(X).squeeze(1)
            loss = loss_fn(outputs, targets)

            valid_loss.update(loss.detach().item())
            valid_score.update(targets, outputs)

    return {'val_loss': valid_loss.avg, 'val_acc': valid_score.avg}

In [33]:
# Create the experiment object for the federation under a fixed name.
experiment_name = 'MRI_classifier_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [34]:
# Start the experiment with the model, task and data interfaces defined above,
# requesting 5 training rounds.
# NOTE(review): 'CONTINUE_GLOBAL' presumably controls how optimizer state is
# handled between rounds - confirm against the OpenFL documentation.
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL'
)

In [None]:
# Stream the experiment's metrics into this cell, with tensorboard_logs enabled.
# NOTE(review): presumably this blocks while the experiment runs - confirm.
fl_experiment.stream_metrics(tensorboard_logs=True)