In [1]:
import os
import glob

from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from copy import deepcopy
import torchvision
from torchvision import transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tqdm

import albumentations as A
from model import Model

torch.manual_seed(0)
np.random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Connection settings for the Director this notebook talks to.
client_id = 'api'
cert_dir = 'cert'  # defined here but not passed to Federation below
director_node_fqdn = 'localhost'

# TLS is switched off for this connection.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port='50051',
    tls=False,
)

In [3]:
# Show the target shape reported by the federation as the cell output.
federation.target_shape

['1']

In [4]:
# Fetch the federation's shard registry and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

{'env_zero': {'shard_info': node_info {
    name: "env_zero"
  }
  shard_description: "Local MRI Shard Descriptor is working."
  sample_shape: "4"
  sample_shape: "256"
  sample_shape: "256"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-05-23 16:39:53',
  'current_time': '2022-05-23 16:40:06',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'env_two': {'shard_info': node_info {
    name: "env_two"
  }
  shard_description: "Local MRI Shard Descriptor is working."
  sample_shape: "4"
  sample_shape: "256"
  sample_shape: "256"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-05-23 16:39:53',
  'current_time': '2022-05-23 16:40:06',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [5]:
# Build a 10-element dummy shard descriptor so sample/target shapes can be
# inspected locally before any task is sent to the federation.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
print(sample.shape, target.shape, sep='\n')


(4, 256, 256, 3)
(1,)


In [6]:
# Training augmentations: each transform is applied with probability 0.5.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
])
# Validation pipeline is an empty composition (no transforms listed).
valid_transform = A.Compose([])

In [7]:
class TinyImageNetDataset(DataInterface):
    """Data interface that builds train/validation loaders from an assigned shard descriptor.

    Constructor keyword arguments are stored unchanged in ``self.kwargs``;
    ``train_bs`` and ``valid_bs`` are read from it as the loader batch sizes.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return the shard descriptor most recently assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        descriptor = shard_descriptor
        self._shard_descriptor = descriptor

        # Split first, then register the training transform, then fetch both subsets.
        descriptor.split_dataset(test_size=0.2)
        descriptor.set_transform_params(train_transform)

        self.train_set = descriptor.get_dataset('train')
        self.valid_set = descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        # A freshly seeded generator is created on every call, so the shuffle
        # order starts from the same seed each time a loader is requested.
        shuffle_rng = torch.Generator()
        shuffle_rng.manual_seed(0)
        loader_options = dict(
            batch_size=self.kwargs['train_bs'],
            shuffle=True,
            generator=shuffle_rng,
            num_workers=0,
        )
        return DataLoader(self.train_set, **loader_options)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        loader_options = dict(
            batch_size=self.kwargs['valid_bs'],
            shuffle=True,
            num_workers=0,
        )
        return DataLoader(self.valid_set, **loader_options)

    def get_train_data_size(self):
        """
        Information for aggregation: number of items in the training subset.
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation: number of items in the validation subset.
        """
        return len(self.valid_set)

In [8]:
# Data interface for the experiment; both loaders use a batch size of 8.
fed_dataset = TinyImageNetDataset(train_bs=8, valid_bs=8)

In [9]:
model = Model()

# Adam over every model parameter; the criterion is BCE applied to raw logits.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = F.binary_cross_entropy_with_logits

# Bundle model and optimizer with the PyTorch framework adapter plugin path.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
model_interface = ModelInterface(
    model=model,
    optimizer=optimizer,
    framework_plugin=framework_adapter,
)

# Keep an independent copy of the untrained model.
initial_model = deepcopy(model)

In [10]:
# Task registry; its decorators are used below to register the train and validate tasks.
task_interface = TaskInterface()


# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a fixed message that embeds ``some_parameter``."""
    announcement = f'Also I accept a parameter and it is {some_parameter}'
    print(announcement)
    
class Loss:
    """Running mean of scalar values, one value per ``update`` call.

    Attributes are ``avg`` (current mean, 0 before any update) and ``n``
    (number of values seen so far).
    """

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, val):
        """Fold ``val`` into the running mean with weight 1/n."""
        count = self.n + 1
        self.n = count
        previous_weight = (count - 1) / count
        self.avg = val / count + previous_weight * self.avg

class Acc:
    """Running binary accuracy over batches of labels and logits.

    A logit ``>= 0`` counts as a positive prediction. ``avg`` holds the
    fraction of correct predictions over all ``n`` samples seen so far.
    """

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        """Fold one batch into the running accuracy.

        Args:
            y_true: tensor of 0/1 labels (cast to int).
            y_pred: tensor of raw logits with the same number of elements.

        Raises:
            ValueError: if the two tensors hold a different number of elements.
        """
        # Flatten both sides: comparing (B, 1) labels with (B,) predictions
        # would otherwise broadcast to a (B, B) matrix and inflate the count.
        labels = y_true.detach().cpu().numpy().astype(int).reshape(-1)
        predictions = (y_pred.detach().cpu().numpy() >= 0).reshape(-1)
        if labels.size != predictions.size:
            raise ValueError(
                f'y_true has {labels.size} elements but y_pred has {predictions.size}'
            )

        batch_size = labels.size
        if batch_size == 0:
            # Nothing to add; also avoids dividing by zero on an empty first batch.
            return

        seen_before = self.n
        self.n = seen_before + batch_size
        correct = np.sum(labels == predictions)
        self.avg = correct / self.n + seen_before / self.n * self.avg

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='model', data_loader='train_loader',
                                 device='device', optimizer='optimizer')
def train(model, train_loader, optimizer, device, loss_fn=criterion, some_parameter=None):
    """Run one pass over ``train_loader`` and return the mean training loss.

    Args:
        model: network to optimise; it is moved to ``device`` before training.
        train_loader: iterable of batches indexed by the keys ``"X"`` and ``"y"``.
        optimizer: optimizer stepped once per batch.
        device: device supplied through the task contract; inputs, targets
            and the model are all placed on it.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        some_parameter: extra keyword registered via ``add_kwargs``; unused here.

    Returns:
        dict with the single key ``'train_loss'`` (running mean of batch losses).
    """
    # Honour the device from the task contract instead of forcing 'cpu'.
    model.to(device)
    model.train()

    train_loss = Loss()
    train_score = Acc()

    progress = tqdm.tqdm(train_loader, desc='train')
    for batch in progress:
        X = batch["X"].to(device)
        targets = batch["y"].to(device)

        optimizer.zero_grad()
        outputs = model(X).squeeze(1)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item())
        # Detach so the metric never holds on to the autograd graph.
        train_score.update(targets, outputs.detach())
        progress.set_postfix(train_loss=train_loss.avg, train_score=train_score.avg)

    return {'train_loss': train_loss.avg}


@task_interface.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return its accuracy.

    Args:
        model: network to evaluate; it is moved to ``device`` and put in eval mode.
        val_loader: iterable of batches indexed by the keys ``"X"`` and ``"y"``.
        device: device supplied through the task contract.

    Returns:
        dict with the single key ``'acc'`` (running accuracy over all batches).
        The validation loss is tracked for the progress display only.
    """
    model.to(device)
    model.eval()

    valid_loss = Loss()
    valid_score = Acc()

    progress = tqdm.tqdm(val_loader, desc='validate')
    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for batch in progress:
            X = batch["X"].to(device)
            targets = batch["y"].to(device)

            outputs = model(X).squeeze(1)
            loss = criterion(outputs, targets)

            valid_loss.update(loss.item())
            valid_score.update(targets, outputs)
            progress.set_postfix(valid_loss=valid_loss.avg, valid_score=valid_score.avg)

    return {'acc': valid_score.avg}

In [11]:
# Create the experiment handle on the federation under the chosen name.
experiment_name = 'tinyimagenet_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [12]:
# Launch the experiment for 5 rounds with the model, task and data
# interfaces defined in the cells above.
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL'
)



In [None]:
# Stream metrics from the experiment, with TensorBoard logging turned off.
fl_experiment.stream_metrics(tensorboard_logs=False)