# Federated Market with Director example
## Using low-level Python API

# Long-Living entities update

* The Director may now run on another machine.
* We use Federation API to communicate with Director.
* Federation object should hold a Director's client (for user service)
* Keep in mind that several API instances may be connected to one Director.


* For now, we do not consider how the Director is started.
* However, the Director knows the data shape and target shape for the data science problem in the Federation.
* Director holds the list of connected envoys, we do not need to specify it anymore.
* Director and Envoys are responsible for encrypting connections, we do not need to worry about certs.


* However, we MUST have a certificate to communicate with the Director.
* We MUST know the FQDN of a Director.
* Director communicates data and target shape to the Federation interface object.


* Experiment API may use this info to construct a dummy dataset and a `shard descriptor` stub.

In [None]:
# Install dependencies if not already installed.
# These are IPython `!` shell escapes: they only work inside Jupyter/IPython.
!pip install torch==1.9.0
!pip install torchvision==0.10.0

# Connect to the Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate.
# It is passed to both Federation variants below so it only has to be set once.
client_id = 'frontend'

# 1) Run with API layer - Director mTLS
# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair
# to the federation interface.
# cert_chain = 'cert/root_ca.crt'
# API_certificate = 'cert/frontend.crt'
# API_private_key = 'cert/frontend.key'

# federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', disable_tls=False,
#                        cert_chain=cert_chain, api_cert=API_certificate, api_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', disable_tls=True)

In [None]:
# Ask the Director for its shard registry and display it
# (presumably one entry per connected envoy — confirm against the openfl version in use).
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Display the target shape that the Director communicated to this Federation object.
federation.target_shape

In [None]:
# First, request a dummy_shard_desc that holds information about the federated dataset.
# NOTE(review): `size=10` presumably sets the number of dummy samples — confirm.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Fetch the first (sample, target) pair to inspect the dataset's structure.
sample, target = dummy_shard_desc[0]

## Creating a FL experiment using Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

We extract the user's dataset class implementation.
Is this convenient?
What if the dataset is not a class?

In [1]:
import numpy as np
from torch.utils.data import Dataset, DataLoader

from dataset import Market1501
from tools import ImageDataset, RandomIdentitySampler
import transforms as T

# Now you can implement your data loaders using dummy_shard_desc
class MarketSD(DataInterface, Dataset):
    """Market-1501 data interface that builds the train, query and gallery loaders.

    Keyword arguments are forwarded to ``DataInterface`` and read back through
    ``self.kwargs``:

    * ``data_path`` (required) -- a digit string (index 1 or 2) asks ``Market1501``
      to split the data; any other value is used as the dataset root path.
    * ``train_bs`` -- training batch size; 64 when missing or falsy.
    * ``valid_bs`` -- query/gallery batch size; 512 when missing or falsy.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        data_path = self.kwargs['data_path']
        # A digit data_path (index 1 or 2) splits the aggregator data as
        # data[index-1::2]; anything else is an absolute path. ``str`` lets a
        # path-like object be passed as well.
        split_data = str(data_path).isdigit()
        self.dataset = Market1501(root=data_path, split_data=split_data)

        # Prepare transforms
        self.transform_train = T.Compose([
            T.RandomCroping(256, 128, p=0.5),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            T.RandomErasing(probability=0.5)
        ])
        # Evaluation images are resized to the same (256, 128) size that is
        # passed to RandomCroping in the training pipeline.
        self.transform_test = T.Compose([
            T.Resize((256, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, index):
        # TODO: per-item access is not implemented; data is served through the loaders.
        raise NotImplementedError

    def __len__(self):
        # TODO: not implemented yet.
        raise NotImplementedError

    def _batch_size(self, key, default):
        """Return ``self.kwargs[key]`` when it is present and truthy, else ``default``."""
        return self.kwargs.get(key) or default

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Batches are drawn by ``RandomIdentitySampler`` (``num_instances=4``) and the
        last incomplete batch is dropped.
        """
        return DataLoader(
            ImageDataset(self.dataset.train, transform=self.transform_train),
            sampler=RandomIdentitySampler(self.dataset.train, num_instances=4),
            batch_size=self._batch_size('train_bs', 64),
            num_workers=4, pin_memory=True, drop_last=True,
        )

    def get_query_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Iterates the query split in a fixed order with the evaluation transforms.
        """
        return DataLoader(
            ImageDataset(self.dataset.query, transform=self.transform_test),
            batch_size=self._batch_size('valid_bs', 512),
            num_workers=4, pin_memory=True,
            drop_last=False, shuffle=False,
        )

    def get_gallery_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Iterates the gallery split in a fixed order with the evaluation transforms.
        """
        return DataLoader(
            ImageDataset(self.dataset.gallery, transform=self.transform_test),
            batch_size=self._batch_size('valid_bs', 512),
            num_workers=4, pin_memory=True,
            drop_last=False, shuffle=False,
        )

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        # NOTE(review): this is the number of training identities, not images. If the
        # aggregator weights collaborators by sample count, len(self.dataset.train)
        # may be intended — confirm.
        return self.dataset.num_train_pids

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        # NOTE(review): number of gallery identities, not images — see the note above.
        return self.dataset.num_gallery_pids

NameError: name 'DataInterface' is not defined

In [None]:
# NOTE(review): MarketSD.__init__ reads self.kwargs['data_path'], which this call does
# not pass; it will raise KeyError unless a data_path is supplied — confirm.
fed_dataset = MarketSD(train_bs=4, valid_bs=8)
fed_dataset.shard_descriptor = dummy_shard_desc
# Train batches are three-element tuples (images, person ids, and a third field),
# the same way the train task unpacks them.
for imgs, pids, _ in fed_dataset.get_train_loader():
    print(imgs.shape)

### Describe a model and optimizer

In [None]:
import torch.nn as nn
import torch.optim as optim
import torchvision

In [None]:
"""
ResNet and Classifier definition
"""

class ResNet50(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        resnet50 = torchvision.models.resnet50(pretrained=True)
        resnet50.layer4[0].conv2.stride = (1, 1)
        resnet50.layer4[0].downsample[0].stride = (1, 1)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])

        self.bn = nn.BatchNorm1d(2048)
        nn.init.normal_(self.bn.weight.data, 1.0, 0.02)
        nn.init.constant_(self.bn.bias.data, 0.0)

    def forward(self, x):
        x = self.base(x)
        x = nn.functional.avg_pool2d(x, x.size()[2:])
        x = x.view(x.size(0), -1)
        f = self.bn(x)

        return f


class NormalizedClassifier(nn.Module):
    """Cosine classifier.

    Logits are the cosine similarities between L2-normalised input features and
    L2-normalised class weight vectors, so every logit lies in [-1, 1].

    Args:
        num_classes: number of output classes (rows of ``weight``). Defaults to 1501.
        feature_dim: size of each input feature vector. Defaults to 2048, matching the
            ``BatchNorm1d(2048)`` output of ``ResNet50``.
    """

    def __init__(self, num_classes=1501, feature_dim=2048):
        # ``import torch.nn as nn`` does not bind the name ``torch``. Import it here
        # so the class can be instantiated even before a plain ``import torch``.
        import torch

        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_classes, feature_dim))
        # Uniform init, then rescale every row to (at most) unit L2 norm:
        # renorm_ caps each row's norm at 1e-5 and mul_ scales it back by 1e5.
        self.weight.data.uniform_(-1, 1).renorm_(2, 0, 1e-5).mul_(1e5)

    def forward(self, x):
        x = nn.functional.normalize(x, p=2, dim=1)
        w = nn.functional.normalize(self.weight, p=2, dim=1)
        return nn.functional.linear(x, w)


# Instantiate the backbone (ResNet50 loads torchvision's pretrained weights) and the
# cosine classifier head used by the train task.
resnet = ResNet50()
classifier = NormalizedClassifier()

In [None]:
from losses import ArcFaceLoss, TripletLoss

# Loss terms used by the train task: a classification loss on the classifier
# logits and a cosine-distance triplet loss on the backbone features.
criterion_cla = ArcFaceLoss(scale=16., margin=0.1)
criterion_pair = TripletLoss(margin=0.3, distance='cosine')

# A single Adam optimizer over both the backbone and the classifier head.
parameters = [*resnet.parameters(), *classifier.parameters()]
optimizer_adam = optim.Adam(parameters, lr=1e-4)

# Multiply the learning rate by 0.1 after 20 and again after 40 scheduler steps.
# NOTE(review): `scheduler` is never stepped or handed to the tasks in this notebook.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer_adam, milestones=[20, 40], gamma=0.1)

#### Register model

In [None]:
from copy import deepcopy

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Register the backbone and its optimizer with the interactive API.
# NOTE(review): only `resnet` is registered as the model; `classifier` reaches the train
# task through TI.add_kwargs, so its weights are presumably not aggregated — confirm.
MI = ModelInterface(model=resnet, optimizer=optimizer_adam, framework_plugin=framework_adapter)    # todo

# Save the initial model state as an independent copy, so later training does not modify it
initial_model = deepcopy(resnet)    # todo

### Define and register FL tasks

In [None]:
# Task registry: the @TI.add_kwargs / @TI.register_fl_task decorators in this cell
# attach the train and validate tasks (and their extra arguments) to it.
TI = TaskInterface()
import torch

import tqdm

from tools import AverageMeter, evaluate, extract_feature


# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{
    'classifier': classifier,
    'criterion_cla': criterion_cla,
    'criterion_pair': criterion_pair,
})
@TI.register_fl_task(model='resnet', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, classifier, train_loader, optimizer, device, criterion_cla, criterion_pair):
    """Run one pass over ``train_loader``, training the backbone and the classifier head.

    Args:
        model: feature extractor whose output is fed to ``classifier``.
        classifier: head producing per-class logits from the features.
        train_loader: iterable of ``(imgs, pids, _)`` batches.
        optimizer: optimizer stepped once per batch.
        device: device that the model, classifier and every batch are moved to.
        criterion_cla: loss applied to ``(logits, pids)``.
        criterion_pair: loss applied to ``(features, pids)``.

    Returns:
        dict with the averages ``'ClaLoss'``, ``'PairLoss'`` and ``'Accuracy'``
        tracked by ``AverageMeter`` (each batch weighted by its size).
    """
    batch_cla_loss = AverageMeter()
    batch_pair_loss = AverageMeter()
    corrects = AverageMeter()

    model.train()
    model.to(device)
    classifier.train()
    classifier.to(device)

    for imgs, pids, _ in tqdm.tqdm(train_loader, desc='train'):
        # The model and classifier live on `device`, so the batch must too;
        # otherwise the forward pass fails on a CUDA device.
        imgs, pids = imgs.to(device), pids.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward
        features = model(imgs)
        outputs = classifier(features)
        _, preds = torch.max(outputs.data, 1)
        # Compute loss
        cla_loss = criterion_cla(outputs, pids)
        pair_loss = criterion_pair(features, pids)
        loss = cla_loss + pair_loss
        # Backward + Optimize
        loss.backward()
        optimizer.step()

        # Statistics: plain Python floats, so all returned metrics have the same type.
        batch_size = pids.size(0)
        accuracy = (torch.sum(preds == pids.data).float() / batch_size).item()
        corrects.update(accuracy, batch_size)
        batch_cla_loss.update(cla_loss.item(), batch_size)
        batch_pair_loss.update(pair_loss.item(), batch_size)

    return {'ClaLoss': batch_cla_loss.avg,
            'PairLoss': batch_pair_loss.avg,
            'Accuracy': corrects.avg}


@TI.register_fl_task(model='resnet', data_loader='queryloader', device='device')
def validate(model, queryloader, galleryloader, device):
    """Rank the gallery for every query image and report re-identification metrics.

    Args:
        model: feature extractor passed to ``extract_feature``.
        queryloader: loader over the query images.
        galleryloader: loader over the gallery images.
        device: device that the model is moved to.

    Returns:
        dict with ``'top1'``, ``'top5'`` and ``'top10'`` (entries 0, 4 and 9 of the
        CMC curve returned by ``evaluate``) and ``'mAP'``.
    """
    # NOTE(review): the task contract above only names `queryloader`. `galleryloader`
    # is neither registered nor added through TI.add_kwargs, so a framework call that
    # passes only the contract arguments will not supply it — confirm.
    model.eval()
    model.to(device)

    # Features are only needed for ranking, so no autograd graph is built.
    with torch.no_grad():
        # Extract features for query set
        qf, q_pids, q_camids = extract_feature(model, queryloader)
        print(f'Extracted features for query set, obtained {qf.shape} matrix')
        # Extract features for gallery set
        gf, g_pids, g_camids = extract_feature(model, galleryloader)
        print(f'Extracted features for gallery set, obtained {gf.shape} matrix')

        # Distance = negative cosine similarity (smaller means more similar),
        # computed for all query/gallery pairs in one matrix product.
        qf = nn.functional.normalize(qf, p=2, dim=1)
        gf = nn.functional.normalize(gf, p=2, dim=1)
        distmat = (-torch.mm(qf, gf.t())).cpu().numpy()

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    return {'top1': cmc[0], 'top5': cmc[4], 'top10': cmc[9], 'mAP': mAP}


## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation
experiment_name = 'market_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# NOTE: running this cell with IPython autoreload enabled caused a pickling error.

# The following command zips the workspace and Python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=2,
                    opt_treatment='CONTINUE_GLOBAL')


In [None]:
# If the user wants to stop the IPython session, they can reconnect later and check how
# the experiment is going by restoring its state first:
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()

## Now we validate the best model!

In [None]:
# Retrieve the best model from the experiment.
best_model = fl_experiment.get_best_model()

In [None]:
# Remove the experiment data from the Director
fl_experiment.remove_experiment_data()

In [None]:
# Inspect the weights of the backbone's first layer. ResNet50.base keeps torchvision's
# children in order, so base[0] is the stem convolution (conv1).
best_model.base[0].weight

In [None]:
# Validating initial model: validate(model, queryloader, galleryloader, device)
validate(initial_model, fed_dataset.get_query_loader(), fed_dataset.get_gallery_loader(), 'cpu')

In [None]:
# Validating trained model: validate(model, queryloader, galleryloader, device)
validate(best_model, fed_dataset.get_query_loader(), fed_dataset.get_gallery_loader(), 'cpu')

## We can tune model further!

In [None]:
# An optimizer must reference the parameters of the model it updates. `optimizer_adam`
# was built over `resnet`'s parameters, and `best_model` is a different object, so a
# new Adam optimizer is created with the same settings (lr=1e-4, backbone + classifier).
optimizer_finetune = optim.Adam(
    list(best_model.parameters()) + list(classifier.parameters()), lr=1e-4
)
MI = ModelInterface(model=best_model, optimizer=optimizer_finetune, framework_plugin=framework_adapter)
# Same `start` call as the first training run, with more rounds.
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=4,
                    opt_treatment='CONTINUE_GLOBAL')

In [None]:
best_model = fl_experiment.get_best_model()
# Validating trained model: validate(model, queryloader, galleryloader, device)
validate(best_model, fed_dataset.get_query_loader(), fed_dataset.get_gallery_loader(), 'cpu')

In [None]:
# Promote 1-D arrays to column vectors so every piece is 2-D with the same number
# of rows, then stack them side by side along the column axis.
a = (np.zeros((2, 4)), np.ones(2,), 2 * np.ones(2,))
a = [elem if elem.ndim == 2 else elem[:, np.newaxis] for elem in a]
np.concatenate(a, axis=1)

In [None]:
class A:
    """Last class before ``object`` in C's MRO; prints the kwargs it receives."""

    def __init__(self, **kwargs):
        print(f"Class A {kwargs}")


class B:
    """Calls ``super().__init__()`` without forwarding its keyword arguments."""

    def __init__(self, **kwargs):
        super().__init__()
        print(f"Class B {kwargs}")


class C(B, A):
    """MRO is C -> B -> A -> object, so ``super()`` inside B resolves to A.

    Neither C nor B forwards ``**kwargs``, so A and B both print ``{}`` and only C
    prints the keyword arguments it was called with.
    """

    def __init__(self, **kwargs):
        super().__init__()
        print(f"Class C {kwargs}")

In [None]:
# Instantiate C to see which __init__ methods run and which kwargs each one receives.
c = C(x=1, z=5)