# Federated Next Word Prediction with Director example
## Using low-level Python API

In [2]:
# Install dependencies if not already installed
# !pip install tensorflow==2.5.1

# Connect to the Federation

In [2]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'frontend'

# 1) Run with API layer - Director mTLS
# To enable mTLS the user must provide the CA root chain and a signed key pair to the federation interface
# NOTE(review): the keyword below is aligned with the live call further down (`tls`);
# confirm against the installed OpenFL version.
# cert_chain = 'cert/root_ca.crt'
# API_certificate = 'cert/frontend.crt'
# API_private_key = 'cert/frontend.key'

# federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=True,
#                        cert_chain=cert_chain, api_cert=API_certificate, api_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=False)

In [3]:
# Fetch the shard registry from the federation; the bare name on the last line displays it
shard_registry = federation.get_shard_registry()
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "Market dataset, shard number 1 out of 2"
  n_samples: 6468
  sample_shape: "64"
  sample_shape: "128"
  sample_shape: "3"
  target_shape: "1501",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-08-16 21:04:28',
  'current_time': '2021-08-16 21:04:45',
  'valid_duration': seconds: 60},
 'env_two': {'shard_info': node_info {
    name: "env_two"
  }
  shard_description: "Market dataset, shard number 2 out of 2"
  n_samples: 6468
  sample_shape: "64"
  sample_shape: "128"
  sample_shape: "3"
  target_shape: "1501",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-08-16 21:04:31',
  'current_time': '2021-08-16 21:04:45',
  'valid_duration': seconds: 60}}

In [4]:
# Display the target shape exposed by the federation object
federation.target_shape

['1501']

In [5]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Indexing the dummy descriptor yields a (sample, target) pair
sample, target = dummy_shard_desc[0]

## Creating a FL experiment using Interactive API

In [6]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

In [7]:
import numpy as np
from tensorflow.keras.utils import Sequence

class DataGenerator(Sequence):
    """Keras ``Sequence`` that serves ``(X, y)`` batches drawn from a shard descriptor.

    Args:
        shard_descriptor: Indexable dataset; ``shard_descriptor[i]`` must return a
            ``(sample, target)`` pair.
        indices: Sample indices (into ``shard_descriptor``) this generator serves.
        batch_size: Number of samples per batch.
    """

    def __init__(self, shard_descriptor, indices, batch_size):
        self.shard_descriptor = shard_descriptor
        self.batch_size = batch_size
        self.indices = indices
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch, including a trailing partial batch."""
        # Ceiling division: floor division silently dropped the remainder and produced
        # zero batches whenever there were fewer samples than ``batch_size``.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as a pair of stacked arrays ``(X, y)``."""
        positions = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        batch = [self.indices[k] for k in positions]

        # The shard descriptor is an indexable object, not a callable, so samples are
        # fetched one by one and stacked along a new leading batch axis.
        samples = [self.shard_descriptor[i] for i in batch]
        X = np.stack([sample for sample, _ in samples])
        y = np.stack([target for _, target in samples])
        return X, y

    def on_epoch_end(self):
        """Reset the per-epoch ordering to sequential positions over ``indices``."""
        self.index = np.arange(len(self.indices))


# Now you can implement your data loaders using dummy_shard_desc
class NextWordSD(DataInterface):
    """Data interface that splits the local shard 80/20 into train and validation loaders.

    Keyword Args:
        train_bs: Batch size of the training loader (defaults to 64 when missing or falsy).
        valid_bs: Batch size of the validation loader (defaults to 512 when missing or falsy).
    """

    def __init__(self, **kwargs):
        # The loaders below read ``self.kwargs``; presumably it is stored by
        # ``DataInterface.__init__`` -- verify against the base class.
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Return the shard descriptor currently attached to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        train = round(len(self) * 0.8)
        # Start at 0: ``range(1, train)`` left the first sample out of both splits.
        self.train_indeces = list(range(train))
        self.valid_indeces = list(range(train, len(self)))

    def __getitem__(self, index):
        """Delegate item access to the attached shard descriptor."""
        return self.shard_descriptor[index]

    def __len__(self):
        """Return the number of samples in the attached shard descriptor."""
        return len(self.shard_descriptor)

    def get_train_loader(self):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        # ``.get`` keeps the default reachable when the key was never supplied
        # (plain subscripting raised KeyError before the fallback could apply).
        batch_size = self.kwargs.get('train_bs') or 64

        return DataGenerator(self.shard_descriptor, self.train_indeces, batch_size=batch_size)

    def get_valid_loader(self):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        batch_size = self.kwargs.get('valid_bs') or 512

        return DataGenerator(self.shard_descriptor, self.valid_indeces, batch_size=batch_size)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_indeces)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.valid_indeces)


In [8]:
# Instantiate the data interface; train_bs / valid_bs are the batch sizes its loaders read
fed_dataset = NextWordSD(train_bs=64, valid_bs=512)

### Describe a model and optimizer

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.optimizers import Adam


class Model(tf.keras.Model):
    """Embedding -> two stacked LSTMs -> tanh dense -> softmax over the vocabulary."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for signature compatibility but are not forwarded.
        super().__init__()

        n_tokens = 48904
        self.emb = Embedding(input_dim=n_tokens, output_dim=10, input_length=3)
        self.lstm1 = LSTM(units=1000, return_sequences=True)
        self.lstm2 = LSTM(units=1000)
        self.d1 = Dense(units=1000, activation='tanh')
        self.d2 = Dense(units=n_tokens, activation='softmax')

    def call(self, x):
        """Feed ``x`` through every layer in order and return the softmax output."""
        for layer in (self.emb, self.lstm1, self.lstm2, self.d1, self.d2):
            x = layer(x)
        return x

# Construct an instance of Model
model = Model()
optimizer = Adam(learning_rate=0.001)
# `Accuracy` tests exact equality between labels and predictions, which is meaningless for
# softmax probabilities. `CategoricalAccuracy` compares the argmax of one-hot labels with the
# argmax of the predictions. The metric keeps the name 'accuracy' so reported keys are unchanged.
model.compile(loss='categorical_crossentropy', optimizer=optimizer,
              metrics=[tf.keras.metrics.CategoricalAccuracy(name='accuracy')])

#### Register model

In [12]:
from copy import deepcopy

# Dotted path of the framework adapter plugin handed to ModelInterface as `framework_plugin`
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)
# Save the initial model state
initial_model = deepcopy(model)

### Define and register FL tasks

In [13]:
import tqdm

# Task registry; the functions below are attached to it via `TI.register_fl_task`
TI = TaskInterface()

@TI.register_fl_task(model='model', data_loader='train_loader')
def train(model, train_loader):  # https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
    """Run one pass over ``train_loader``, updating ``model`` weights in place.

    Args:
        model: Keras model already configured with ``compile()`` (loss, optimizer, metrics).
        train_loader: Iterable yielding ``(x, y)`` batches.

    Returns:
        dict mapping each metric name in ``model.metrics`` to its current value.
    """
    # `tf.test.is_gpu_available()` is deprecated; query the visible physical devices instead.
    if tf.config.list_physical_devices('GPU'):
        device = tf.device('/gpu:0')
    else:
        device = tf.device('/cpu:0')

    # NOTE(review): metric state is never reset here, so values accumulate across calls
    # (and are shared with `validate`) -- confirm this is intended.
    with device:
        # tqdm only wraps the iterable for progress display; it is not callable,
        # so every batch is obtained by iterating over it.
        for x, y in tqdm.tqdm(train_loader, desc='train'):
            with tf.GradientTape() as tape:
                y_pred = model(x, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = model.compiled_loss(y, y_pred, regularization_losses=model.losses)

            # Compute gradients
            trainable_vars = model.trainable_variables
            gradients = tape.gradient(loss, trainable_vars)

            # Update weights
            model.optimizer.apply_gradients(zip(gradients, trainable_vars))

            # Update metrics (includes the metric that tracks the loss)
            model.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in model.metrics}


@TI.register_fl_task(model='model', data_loader='val_loader')
def validate(model, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader`` without updating weights.

    Args:
        model: Keras model already configured with ``compile()`` (loss and metrics).
        val_loader: Iterable yielding ``(x, y)`` batches.

    Returns:
        dict mapping each metric name in ``model.metrics`` to its current value.
    """
    # `tf.test.is_gpu_available()` is deprecated; query the visible physical devices instead.
    if tf.config.list_physical_devices('GPU'):
        device = tf.device('/gpu:0')
    else:
        device = tf.device('/cpu:0')

    # NOTE(review): metric state is never reset here, so values accumulate across calls
    # (and are shared with `train`) -- confirm this is intended.
    with device:
        # The loader yields batches; unpacking the loader object itself into `x, y`
        # would only succeed if it happened to contain exactly two batches.
        for x, y in val_loader:
            # Compute predictions
            y_pred = model(x, training=False)
            # Updates the metrics tracking the loss
            model.compiled_loss(y, y_pred, regularization_losses=model.losses)
            # Update the metrics.
            model.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value.
    # Note that it will include the loss (tracked in self.metrics).
    return {m.name: m.result() for m in model.metrics}

## Time to start a federated learning experiment

In [14]:
# Create an experiment in the federation
experiment_name = 'word_prediction_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [16]:
# Note: with autoreload enabled, this cell raised a pickling error

# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=60,
                    opt_treatment='RESET')

TypeError: build() missing 2 required positional arguments: 'template' and 'settings'

In [None]:
# If the user wants to stop the IPython session, they can reconnect later and check how the experiment is going
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()

In [None]:
# todo: add testing on metamorphosis