# Federated Next Word Prediction with Director example
## Using the low-level Python API

# Connect to the Federation

In [2]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate.
# It is passed to Federation below so the identity is configured in one place.
client_id = 'frontend'

# 1) Run with API layer - Director mTLS
# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair
# to the federation interface.
# cert_chain = 'cert/root_ca.crt'
# API_certificate = 'cert/frontend.crt'
# API_private_key = 'cert/frontend.key'

# federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=True,
#                        cert_chain=cert_chain, api_cert=API_certificate, api_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=False)

In [3]:
# Ask the director for the registry of connected envoys/shards. The recorded
# output is a dict keyed by shard name with 'shard_info', 'is_online',
# 'is_experiment_running' and timing fields. The bare last line makes the
# notebook display it.
shard_registry = federation.get_shard_registry()
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "Market dataset, shard number 1 out of 2"
  n_samples: 6468
  sample_shape: "64"
  sample_shape: "128"
  sample_shape: "3"
  target_shape: "1501",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-08-16 21:04:28',
  'current_time': '2021-08-16 21:04:45',
  'valid_duration': seconds: 60},
 'env_two': {'shard_info': node_info {
    name: "env_two"
  }
  shard_description: "Market dataset, shard number 2 out of 2"
  n_samples: 6468
  sample_shape: "64"
  sample_shape: "128"
  sample_shape: "3"
  target_shape: "1501",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-08-16 21:04:31',
  'current_time': '2021-08-16 21:04:45',
  'valid_duration': seconds: 60}}

In [4]:
# Target shape reported by the federation (displayed as this cell's value).
# NOTE(review): the recorded value ['1501'] and the "Market dataset" shard
# descriptions look like leftovers of a Market-1501 re-identification setup,
# not a next-word dataset (the model cells use vocab_size = 2573) -- confirm
# the envoys are serving the intended data.
federation.target_shape

['1501']

In [5]:
# Request a small (size=10) dummy shard descriptor from the federation first.
# It exposes the federated dataset's interface locally, so the data loaders
# below can be written and checked against it.
dummy_shard_desc = federation.get_dummy_shard_descriptor(
    size=10,
)
(sample, target) = dummy_shard_desc[0]

## Creating an FL experiment using the Interactive API

In [6]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

In [7]:
# Now you can implement your data loaders using dummy_shard_desc
class NextWordSD(DataInterface, Dataset):
    """Data interface for the federated next-word-prediction experiment.

    Keyword Args:
        train_bs: training batch size; 64 is used when it is missing or falsy.
        valid_bs: validation batch size; 512 is used when it is missing or falsy.

    Both values are read back from ``self.kwargs``, which is presumably
    populated by ``DataInterface.__init__`` -- confirm against openfl.

    NOTE(review): most of this class is still the person re-identification
    template it was copied from (see the ``# todo`` markers). ``Dataset``,
    ``DataLoader``, ``ImageDataset``, ``RandomIdentitySampler``,
    ``self.img_trans``, ``self.transform_train`` and ``self.transform_test``
    are not defined anywhere in this notebook. They must be provided or ported
    to next-word data before these loaders can run.
    """

    def __init__(self, **kwargs):
        # Forward all keyword arguments (e.g. train_bs / valid_bs) to DataInterface.
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Shard descriptor assigned to this collaborator."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

    def __getitem__(self, index):  # todo
        """Return the transformed sample and its label at ``index``.

        NOTE(review): ``self.img_trans`` is never assigned in this class; it is an
        image-pipeline leftover and needs replacing with next-word preprocessing.
        """
        img, pid = self.shard_descriptor[index]
        img = self.img_trans(img).numpy()
        return img, pid

    def __len__(self):
        """Number of items in the local shard."""
        return len(self.shard_descriptor)

    def get_train_loader(self, **kwargs):  # todo
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Returns a shuffling loader over ``shard_descriptor.train`` with
        ``drop_last=True``.
        """
        # Missing or falsy 'train_bs' falls back to 64 (indexing used to raise KeyError).
        batch_size = self.kwargs.get('train_bs') or 64

        return DataLoader(
            ImageDataset(self.shard_descriptor.train, transform=self.transform_train),
            sampler=RandomIdentitySampler(self.shard_descriptor.train, num_instances=4),
            batch_size=batch_size, num_workers=4, pin_memory=True, drop_last=True
        )

    def get_valid_loader(self, **kwargs):  # todo
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Returns a ``(query_loader, gallery_loader)`` pair of non-shuffled loaders.
        """
        # Missing or falsy 'valid_bs' falls back to 512 (indexing used to raise KeyError).
        batch_size = self.kwargs.get('valid_bs') or 512

        return (
            DataLoader(ImageDataset(self.shard_descriptor.query, transform=self.transform_test),
                       batch_size=batch_size, num_workers=4, pin_memory=True,
                       drop_last=False, shuffle=False),
            DataLoader(ImageDataset(self.shard_descriptor.gallery, transform=self.transform_test),
                       batch_size=batch_size, num_workers=4, pin_memory=True,
                       drop_last=False, shuffle=False)
        )

    def get_train_data_size(self):  # todo
        """
        Information for aggregation.

        NOTE(review): returns the re-id identity count (``num_train_pids``). If
        aggregation should be weighted by sample count, this likely needs to be
        the number of training samples instead -- confirm.
        """
        return self.shard_descriptor.num_train_pids

    def get_valid_data_size(self):  # todo
        """
        Information for aggregation.

        NOTE(review): returns ``num_gallery_pids`` (re-id template leftover);
        confirm the intended validation size for next-word data.
        """
        return self.shard_descriptor.num_gallery_pids

In [8]:
# Instantiate the data interface; the batch sizes travel as keyword arguments
# and are read back by get_train_loader / get_valid_loader.
fed_dataset = NextWordSD(
    train_bs=64,
    valid_bs=512,
)

### Describe a model and optimizer

In [None]:
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

vocab_size = 2573

# Next-word model: embed a single token id (input_length=1), pass it through
# two stacked 1000-unit LSTMs and a ReLU layer, then output a softmax over the
# vocabulary.
# NOTE(review): `model` is rebound by the subclassed Model cell that follows.
model = Sequential([
    Embedding(vocab_size, 10, input_length=1),
    LSTM(1000, return_sequences=True),
    LSTM(1000),
    Dense(1000, activation='relu'),
    Dense(vocab_size, activation='softmax'),
])

model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=0.001),
)

In [None]:
import tensorflow as tf  # needed by train_step (tf.GradientTape); was never imported
from tensorflow.keras import Model as tf_model
from tensorflow.keras.layers import Embedding, LSTM, Dense


class Model(tf_model):
    """Subclassed Keras next-word model with an explicit training step.

    Layer stack: Embedding(vocab_size, 10, input_length=1) -> LSTM(1000,
    return_sequences=True) -> LSTM(1000) -> Dense(1000, relu) ->
    Dense(vocab_size, softmax).

    Args:
        vocab_size: number of token ids; sizes both the embedding table and the
            output softmax. Defaults to 2573, the value previously hard-coded.
        **kwargs: accepted for call compatibility and currently ignored.
    """

    def __init__(self, vocab_size=2573, **kwargs):
        super().__init__()

        self.emb = Embedding(vocab_size, 10, input_length=1)
        self.lstm1 = LSTM(1000, return_sequences=True)
        self.lstm2 = LSTM(1000)
        self.d1 = Dense(1000, activation='relu')
        self.d2 = Dense(vocab_size, activation='softmax')

    def call(self, x):
        """Forward pass: token ids -> per-token softmax over the vocabulary."""
        x = self.emb(x)
        x = self.lstm1(x)
        x = self.lstm2(x)
        x = self.d1(x)
        x = self.d2(x)
        return x

    def train_step(self, data):
        """One optimization step, following the TF guide
        https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit

        Requires the model to be compiled first: it uses ``self.optimizer``,
        ``self.compiled_loss`` and ``self.compiled_metrics``.

        Returns:
            dict mapping each metric name to its current value.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = Model()

#### Register model

In [12]:
from copy import deepcopy

from tensorflow.keras.optimizers import Adam

framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'

# Define the optimizer handed to ModelInterface (it was referenced but never
# defined). Compile the model with it because the custom train_step relies on
# self.optimizer / self.compiled_loss, which exist only after compile(). The
# loss and learning rate match the Sequential model cell.
optimizer = Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)
# Save the initial model state
initial_model = deepcopy(model)

### Define and register FL tasks

In [13]:
from logging import getLogger

logger = getLogger(__name__)

# Registry for the FL tasks that follow; each task function is attached to it
# through the @TI.register_fl_task decorator.
TI = TaskInterface()

# Task interface currently supports only standalone functions.
@TI.register_fl_task(model='model', data_loader='train_loader',
                     device='device', optimizer='optimizer')
def train(model, train_loader, optimizer, device=''):
    """Train ``model`` for one pass over ``train_loader`` and return averaged metrics.

    NOTE(review): this body is PyTorch re-identification template code (ArcFace
    and triplet losses, ``model.classifier``). It does not match the Keras
    next-word model registered in this notebook, and ``torch``, ``tqdm``,
    ``ArcFaceLoss``, ``TripletLoss`` and ``AverageMeter`` are not imported here.
    Port it before running the experiment.

    Args:
        model: feature extractor exposing a ``classifier`` head.
        train_loader: iterable of ``(imgs, pids, _)`` batches.
        optimizer: optimizer providing ``zero_grad()`` / ``step()``.
        device: device supplied through the task interface (presumably a string
            such as ``'cpu'`` or ``'cuda:0'`` -- confirm). An empty value falls
            back to ``'cuda'``, the previously hard-coded device.

    Returns:
        dict with the averaged ``'ArcFaceLoss'``, ``'TripletLoss'`` and
        ``'Accuracy'`` (the latter moved to CPU).
    """
    # Honor the device handed in by the task interface instead of always forcing
    # CUDA, so collaborators on CPU or on a specific GPU index can run the task.
    device = torch.device(device or 'cuda')

    criterion_cla = ArcFaceLoss(scale=16., margin=0.1)
    criterion_pair = TripletLoss(margin=0.3, distance='cosine')

    # Running averages; each update passes the batch size as its second argument.
    batch_cla_loss = AverageMeter()
    batch_pair_loss = AverageMeter()
    corrects = AverageMeter()

    model.train()
    model.to(device)
    model.classifier.train()
    model.classifier.to(device)

    logger.info('==> Start training')
    train_loader = tqdm.tqdm(train_loader, desc='train')

    for batch_idx, (imgs, pids, _) in enumerate(train_loader):
        imgs, pids = torch.tensor(imgs).to(device), torch.tensor(pids).to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward
        features = model(imgs)
        outputs = model.classifier(features)
        _, preds = torch.max(outputs.data, 1)
        # Compute loss
        cla_loss = criterion_cla(outputs, pids)
        pair_loss = criterion_pair(features, pids)
        loss = cla_loss + pair_loss
        # Backward + Optimize
        loss.backward()
        optimizer.step()
        # statistics
        corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0))
        batch_cla_loss.update(cla_loss.item(), pids.size(0))
        batch_pair_loss.update(pair_loss.item(), pids.size(0))

    return {'ArcFaceLoss': batch_cla_loss.avg,
            'TripletLoss': batch_pair_loss.avg,
            'Accuracy': corrects.avg.cpu()}


@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device):
    """Evaluate query-vs-gallery retrieval and return ranking metrics.

    NOTE(review): like ``train``, this is PyTorch re-identification template
    code. ``torch``, ``nn``, ``extract_feature`` and ``evaluate`` are not defined
    in this notebook, and it does not fit the Keras next-word model.

    Args:
        model: feature extractor.
        val_loader: ``(query_loader, gallery_loader)`` pair.
        device: device supplied through the task interface; an empty value
            falls back to ``'cuda'``, the previously hard-coded device.

    Returns:
        dict with ``'top1'``, ``'top5'``, ``'top10'`` (entries 0, 4 and 9 of the
        CMC curve from ``evaluate``) and ``'mAP'``.
    """
    queryloader, galleryloader = val_loader
    # Honor the device handed in by the task interface instead of always forcing CUDA.
    device = torch.device(device or 'cuda')

    logger.info('==> Start validating')
    model.eval()
    model.to(device)

    # Extract features for query set
    qf, q_pids, q_camids = extract_feature(model, queryloader)
    logger.info(f'Extracted features for query set, obtained {qf.shape} matrix')
    # Extract features for gallery set
    gf, g_pids, g_camids = extract_feature(model, galleryloader)
    logger.info(f'Extracted features for gallery set, obtained {gf.shape} matrix')
    # Compute distance matrix between query and gallery
    m, n = qf.size(0), gf.size(0)
    distmat = torch.zeros((m,n))
    # Cosine similarity: after L2-normalizing rows, distance = -(q . g),
    # filled one query row at a time.
    qf = nn.functional.normalize(qf, p=2, dim=1)
    gf = nn.functional.normalize(gf, p=2, dim=1)
    for i in range(m):
        distmat[i] = - torch.mm(qf[i:i+1], gf.t())
    distmat = distmat.numpy()

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    return {'top1': cmc[0], 'top5': cmc[4], 'top10': cmc[9], 'mAP': mAP}

## Time to start a federated learning experiment

In [14]:
# Create an experiment in the federation.
experiment_name = 'market_test_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [16]:
# NOTE: running this cell with IPython autoreload enabled has produced a pickling error.
# NOTE(review): the recorded output of this cell is
# "TypeError: build() missing 2 required positional arguments: 'template' and 'settings'",
# possibly an openfl version mismatch -- confirm.

# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=3,
                    opt_treatment='RESET')

TypeError: build() missing 2 required positional arguments: 'template' and 'settings'

In [None]:
# If you stop the IPython session and later reconnect to check how the experiment
# is going, restore the experiment state first:
# fl_experiment.restore_experiment_state(MI)

# Stream the experiment's metrics into this notebook.
fl_experiment.stream_metrics()