In [None]:
# Mount Google Drive in the Colab runtime; the dataset paths and the output
# folder used later in this notebook live under /content/drive.
from google.colab import drive
# force_remount=True re-mounts even if the drive is already mounted.
drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


In [None]:
#@test {"skip": true}
# Install TensorFlow Federated and nest_asyncio into the runtime.
# NOTE(review): neither package is version-pinned; the TFF calls used below
# (e.g. build_federated_averaging_process, ClientData.from_clients_and_fn)
# may not exist in newer releases -- consider pinning the version that was
# used to produce the recorded outputs.
!pip install --quiet --upgrade tensorflow_federated
!pip install nest_asyncio
import nest_asyncio
# Patch asyncio so an event loop can be re-entered from inside the notebook
# (presumably required by the TFF runtime when running under Jupyter/Colab).
nest_asyncio.apply()



In [None]:
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    # Memory growth has to be configured identically on every visible GPU,
    # so enable it on all of them rather than only on the first one.
    for gpu_device in physical_devices:
        tf.config.experimental.set_memory_growth(gpu_device, True)
else:
    print("No GPU found, model running on CPU")
import tensorflow_federated as tff
#from chexpert_parser import load_dataset, feature_description
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import IPython.display as display
import collections
from skimage import io
import matplotlib.pyplot as plt

# Fix the TensorFlow and NumPy global seeds for reproducibility.
tf.random.set_seed(123456789)
np.random.seed(123456789)


In [None]:
# Custom metric classes

class LabelAUC_alt(tf.keras.metrics.AUC):
    """AUC computed on a single label column of multi-label targets.

    Alternative spelling of the per-label AUC metric. The previous version
    stored the return value of ``update_state`` in ``self.auc``, returned
    ``print(...)`` (i.e. ``None``) from ``result`` and called ``assign`` on a
    non-variable in ``reset_states``, so it could never report a value.
    The accumulated state now lives entirely in the parent ``AUC`` metric;
    ``result`` and ``reset_states`` are simply inherited from it.

    Args:
        label_id: index of the label column to evaluate.
        name: metric name.
        **kwargs: forwarded to ``tf.keras.metrics.AUC``.
    """

    def __init__(self, label_id, name='label_auc_alt', **kwargs):
        super(LabelAUC_alt, self).__init__(name=name, **kwargs)
        self.label_id = label_id

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate statistics for column ``label_id`` only.

        Args:
            y_true: targets indexed as ``[:, label_id]``.
            y_pred: predictions indexed as ``[:, label_id]``.
            sample_weight: optional weights, forwarded to the parent metric.
        """
        return super(LabelAUC_alt, self).update_state(
            y_true[:, self.label_id],
            y_pred[:, self.label_id],
            sample_weight=sample_weight)

class LabelAUC(tf.keras.metrics.AUC):
    """AUC metric restricted to one label column of multi-label targets.

    Args:
        label_id: index of the column of ``y_true`` / ``y_pred`` to evaluate.
        name: metric name.
        **kwargs: forwarded to ``tf.keras.metrics.AUC``.
    """

    def __init__(self, label_id, name="label_auc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.label_id = label_id

    def update_state(self, y_true, y_pred, **kwargs):
        """Slice out column ``label_id`` and delegate to the parent AUC."""
        column = self.label_id
        true_column = y_true[:, column]
        pred_column = y_pred[:, column]
        return super().update_state(true_column, pred_column, **kwargs)

    def result(self):
        """Return the parent AUC value for the selected column."""
        return super().result()
    
 
class MeanAUC_alt(tf.keras.metrics.AUC): # mean
    """Mean of the per-label AUCs for label columns 2, 5, 6, 8 and 10.

    Fixes over the previous version:
      * the constructor called ``super(MeanAUC, self)``, naming a class that
        is not a base of this one, so instantiation failed;
      * the mean was cached in ``update_state`` through ``tf.constant(...)``
        and ``result`` failed if called before the first update;
      * ``reset_states`` did not reset the per-label metrics.

    The parent ``AUC`` state is not used; the value comes only from the
    ``LabelAUC`` sub-metrics held in ``self.aucs``.
    NOTE(review): ``**kwargs`` configure the parent only, not the sub-metrics.
    """

    def __init__(self, name="label_mean_auc", **kwargs):
        super(MeanAUC_alt, self).__init__(name=name, **kwargs)
        self.aucs = [LabelAUC(label_id=label_id) for label_id in (2, 5, 6, 8, 10)]

    def update_state(self, y_true, y_pred, **kwargs):
        """Update every per-label AUC with the same batch."""
        for auc in self.aucs:
            auc.update_state(y_true, y_pred, **kwargs)

    def result(self):
        """Return the arithmetic mean of the per-label AUC values."""
        return tf.reduce_mean(tf.stack([auc.result() for auc in self.aucs]))

    def reset_states(self):
        """Reset every per-label AUC as well as the (unused) parent state."""
        for auc in self.aucs:
            auc.reset_states()
        return super(MeanAUC_alt, self).reset_states()

class MeanAUC(LabelAUC): 
    """Mean of the per-label AUCs over a set of label columns.

    One ``LabelAUC`` sub-metric is kept per entry of ``label_id``; the value
    reported is the arithmetic mean of their results. The state inherited
    from the parent ``AUC`` is never updated and is only reset for tidiness.

    Changes over the previous version:
      * ``reset_states`` now explicitly resets every sub-metric. Previously
        only the parent reset was invoked, and whether that also cleared the
        nested metrics is TensorFlow-version dependent, so the mean could
        keep accumulating across rounds.
      * ``result`` no longer calls ``.numpy()`` and therefore also works
        inside ``tf.function``.
      * any number of label ids is accepted (previously exactly five).
      * extra ``update_state`` keyword arguments (e.g. ``sample_weight``) are
        forwarded to the sub-metrics instead of being dropped.

    Args:
        label_id: non-empty sequence of label column indices.
        name: metric name.
        **kwargs: forwarded to the parent metric constructor.
            NOTE(review): they are not applied to the per-label sub-metrics.

    Raises:
        ValueError: if ``label_id`` is empty.
    """

    def __init__(self, label_id, name="label_mean_auc", **kwargs):
        super(MeanAUC, self).__init__(label_id=label_id, name=name, **kwargs)
        label_ids = list(label_id)
        if len(label_ids) == 0:
            raise ValueError("MeanAUC needs at least one label id")
        self.aucs = [LabelAUC(label_id=single_id) for single_id in label_ids]

    def update_state(self, y_true, y_pred, **kwargs):
        """Update every per-label AUC with the same batch."""
        for auc in self.aucs:
            auc.update_state(y_true, y_pred, **kwargs)
    
    def result(self):
        """Return the mean of the per-label AUC values as a scalar tensor."""
        return tf.reduce_mean(tf.stack([auc.result() for auc in self.aucs]))

    def reset_states(self):
        """Reset all per-label AUCs and the parent metric state."""
        for auc in self.aucs:
            auc.reset_states()
        return super(MeanAUC, self).reset_states()

In [None]:
# Functions for parsing the TFRecord files

def record_parser(example):
	"""Parse one serialized example into an ``(image, label)`` pair.

	The record is expected to hold a 14-element float32 ``label`` vector and
	a PNG-encoded ``image`` string. The image is decoded with 3 channels,
	centre-cropped or zero-padded to 224x224 and converted to float32 with
	``tf.image.convert_image_dtype``.
	"""
	feature_spec = {
		'label': tf.io.FixedLenFeature([14], tf.float32),
		'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
	}
	record = tf.io.parse_single_example(example, feature_spec)
	pixels = tf.io.decode_png(record["image"], channels=3)
	pixels = tf.image.resize_with_crop_or_pad(pixels, 224, 224)
	return tf.image.convert_image_dtype(pixels, tf.float32), record['label']

def normalize_image(img,labels):
	"""Standardise the last (channel) axis of ``img`` and pass ``labels`` through.

	Each of the three channels has a fixed mean subtracted and is divided by
	a fixed standard deviation (the variable names in the original code
	identify these constants as the ImageNet statistics).

	Returns:
		Tuple ``(normalised_img, labels)``.
	"""
	channel_mean = np.array([0.485, 0.456, 0.406])
	channel_std = np.array([0.229, 0.224, 0.225])
	centered = img - channel_mean
	return centered / channel_std, labels

def make_dataset(filename):
	"""Build an unbatched dataset of normalised ``(image, label)`` pairs.

	Reads the TFRecord file(s) at ``filename``, parses every record with
	``record_parser`` and then applies ``normalize_image``; both maps run
	with an auto-tuned level of parallelism.
	"""
	autotune = tf.data.experimental.AUTOTUNE
	return (
		tf.data.TFRecordDataset(filename)
		.map(record_parser, num_parallel_calls=autotune)
		.map(normalize_image, num_parallel_calls=autotune)
	)

In [None]:
# Load datasets
# Each of the 5 clients is split into 3 parts (shards) to simulate fine-grained clients.

TAKE_ONLY = None
TFRECORDS_DIR = '/content/drive/MyDrive/tfrecords/nolat/Unbalanced'
SHARDS_DIR = TFRECORDS_DIR + '/Shards for TFF'
NUM_CLIENTS = 5
NUM_PARTS = 3
BATCH_SIZE = 32

# Keys look like 'client_<c>_part<p>' and map to '<SHARDS_DIR>/client<c>_part-<p>.tfrecords'.
# Insertion order matters: all part-0 shards come first, then part-1, then
# part-2, because the training loop selects clients by slicing the id list
# in consecutive groups of five.
dataset_paths = {
    f'client_{client}_part{part}': f'{SHARDS_DIR}/client{client}_part-{part}.tfrecords'
    for part in range(NUM_PARTS)
    for client in range(NUM_CLIENTS)
}
# Materialise the ids as a list (a snapshot that supports indexing) instead of a live dict view.
client_list = list(dataset_paths)
client_datasets = {
    client: make_dataset(path).batch(BATCH_SIZE, drop_remainder=False).prefetch(1)
    for client, path in dataset_paths.items()
}

val_path = TFRECORDS_DIR + '/valid_norm.tfrecords'
val_dataset = make_dataset(val_path).batch(BATCH_SIZE, drop_remainder=False).prefetch(1)

### Create Federated Data

In [None]:
# Create the ClientData abstraction for the federated dataset
# Every id in client_list is mapped to the batched dataset already built for it in client_datasets.
chex_train = tff.simulation.ClientData.from_clients_and_fn(client_list, lambda client: client_datasets[client])

# Fetch the dataset for each client
# federated_train_data = [chex_train.create_tf_dataset_for_client(client) for client in client_list]

In [None]:
# Build the list of per-client datasets for the clients selected for a round

def make_federated_data(client_data, client_ids):
    """Return the list of datasets for ``client_ids``, in the same order.

    Args:
        client_data: object exposing ``create_tf_dataset_for_client(client_id)``.
        client_ids: iterable of client identifiers.

    Returns:
        A list with one entry per client id, each obtained from
        ``client_data.create_tf_dataset_for_client``.
    """
    datasets = []
    for client_id in client_ids:
        datasets.append(client_data.create_tf_dataset_for_client(client_id))
    return datasets

In [None]:
# Functions for the training and validation steps

@tf.function
def train_step(model, x, y):
    """Run one optimisation step on a batch and update the training metrics.

    Relies on the module-level ``loss_fn``, ``optimizer`` and
    ``compute_metrics``.

    Args:
        model: callable Keras model exposing ``trainable_weights``.
        x: input batch.
        y: target batch.

    Returns:
        The loss value computed on this batch.
    """
    # Gradients are taken exactly once, so a regular (non-persistent) tape is
    # enough; its resources are released right after the gradient() call and
    # no manual `del tape` is required.
    with tf.GradientTape() as tape:
        output = model(x, training=True)
        loss_value = loss_fn(y, output)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    compute_metrics(y, output, run='training')
    return loss_value

@tf.function
def validation_step(model, x, y):
    """Evaluate one batch in inference mode and update the validation metrics.

    Relies on the module-level ``loss_fn`` and ``compute_metrics``.

    Returns:
        Tuple ``(loss, predictions)`` for the batch.
    """
    predictions = model(x, training=False)
    batch_loss = loss_fn(y, predictions)
    compute_metrics(y, predictions, run='validation')
    return batch_loss, predictions

def compute_metrics(y_true, y_pred, run):
    """Update the module-level metric objects for the given phase.

    Args:
        y_true: target batch.
        y_pred: prediction batch.
        run: ``'training'`` or ``'validation'``; any other value is a no-op.

    The metric objects are looked up as globals at call time, so the ones
    for the requested phase must already be defined.
    """
    if run == 'training':
        training_metrics = (
            auc_train,
            mean_auc_train,
            auc_train_card,
            auc_train_edema,
            auc_train_cons,
            auc_train_atel,
            auc_train_peff,
        )
        for metric in training_metrics:
            metric.update_state(y_true, y_pred)
    if run == 'validation':
        validation_metrics = (
            auc_valid,
            mean_auc_valid,
            auc_valid_card,
            auc_valid_edema,
            auc_valid_cons,
            auc_valid_atel,
            auc_valid_peff,
        )
        for metric in validation_metrics:
            metric.update_state(y_true, y_pred)

def callback_earlyStopping(MetricList, min_delta=0.1, patience=20, mode='min'):
    """Decide whether training should stop early.

    The most recent value of ``MetricList`` is compared against the
    ``patience`` values that precede it, each shifted by ``min_delta``
    (subtracted in ``'min'`` mode, added otherwise).

    Args:
        MetricList: sequence of metric values, oldest first.
        min_delta: margin applied to the earlier values; only its absolute
            value is used.
        patience: number of earlier epochs to compare against. No stopping
            happens while ``len(MetricList) <= patience``.
        mode: ``'min'`` if lower is better; any other value means higher is
            better.

    Returns:
        ``True`` when, in ``'min'`` mode, the latest value is >= the largest
        shifted earlier value, or, otherwise, when it is <= the smallest
        shifted earlier value; ``False`` in every other case.
    """
    # Not enough history yet: never stop during the first `patience` epochs.
    if len(MetricList) <= patience:
        return False

    lower_is_better = mode == 'min'
    margin = abs(min_delta)
    if lower_is_better:
        margin *= -1

    newest_first = MetricList[::-1]
    current_metric = newest_first[0]
    shifted_history = [value + margin for value in newest_first[1:patience + 1]]

    if lower_is_better:
        if current_metric >= max(shifted_history):
            print(f'Metric did not decrease for the last {patience} epochs.')
            return True
        return False

    if current_metric <= min(shifted_history):
        print(f'Metric did not increase for the last {patience} epochs.')
        return True
    return False

### Create Federated Model

In [None]:
from tensorflow.keras.applications.densenet import DenseNet201,DenseNet121,DenseNet169
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.applications.nasnet import NASNetMobile
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Activation, Lambda
from tensorflow.keras.models import Model
import functools


def create_mobilenet():
    """Build the classifier: MobileNet backbone + pooling + 14 sigmoid outputs.

    The backbone takes 224x224x3 inputs, is created without its top layers
    and with ``weights='imagenet'``. Its output goes through global average
    pooling and a 14-unit sigmoid ``Dense`` layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    backbone = MobileNet(input_shape=(224, 224, 3), weights='imagenet', include_top=False)
    pooled = GlobalAveragePooling2D()(backbone.output)
    classifier_head = Dense(14, activation='sigmoid')
    return Model(inputs=backbone.inputs, outputs=classifier_head(pooled))

def model_fn():
    """Model factory handed to TFF: returns a fresh wrapped Keras model.

    The wrapped model is trained with binary cross-entropy and reports
    binary accuracy. Batches are ``(images, labels)`` pairs with shapes
    ``[None, 224, 224, 3]`` and ``[None, 14]`` (float32).

    NOTE(review): the model passed to TFF is a ``clone_model`` copy of the
    one built by ``create_mobilenet``; cloning presumably re-initialises the
    weights, so the ImageNet weights would not be carried over here -- a
    later cell writes pretrained weights into the server state explicitly.
    """
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    keras_model = tf.keras.models.clone_model(create_mobilenet())
    batch_spec = (
        tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[None, 14], dtype=tf.float32),
    )
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=batch_spec,
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()])

In [None]:
# Create the federated averaging process.
# Clients train with SGD (learning rate 0.4); the server applies the
# aggregated update with SGD (learning rate 0.9).

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.4),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))



In [None]:
# Check the type signature of the `initialize` computation
str(iterative_process.initialize.type_signature)

'( -> <model=<trainable=<float32[3,3,3,32],float32[32],float32[32],float32[3,3,32,1],float32[32],float32[32],float32[1,1,32,64],float32[64],float32[64],float32[3,3,64,1],float32[64],float32[64],float32[1,1,64,128],float32[128],float32[128],float32[3,3,128,1],float32[128],float32[128],float32[1,1,128,128],float32[128],float32[128],float32[3,3,128,1],float32[128],float32[128],float32[1,1,128,256],float32[256],float32[256],float32[3,3,256,1],float32[256],float32[256],float32[1,1,256,256],float32[256],float32[256],float32[3,3,256,1],float32[256],float32[256],float32[1,1,256,512],float32[512],float32[512],float32[3,3,512,1],float32[512],float32[512],float32[1,1,512,512],float32[512],float32[512],float32[3,3,512,1],float32[512],float32[512],float32[1,1,512,512],float32[512],float32[512],float32[3,3,512,1],float32[512],float32[512],float32[1,1,512,512],float32[512],float32[512],float32[3,3,512,1],float32[512],float32[512],float32[1,1,512,512],float32[512],float32[512],float32[3,3,512,1],float

In [None]:
# Initialise the state of the federated process.
# The (Italian) note kept in the string below says: this call is expected to
# emit the `extract_sub_graph` deprecation warning; if the warning does not
# appear, the state was not initialised correctly and the iterative_process
# has to be re-created.
'''
N.B. Questo deve restituire Instructions for updating:
    Use `tf.compat.v1.graph_util.extract_sub_graph`

    WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:59: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
    Instructions for updating:
    Use `tf.compat.v1.graph_util.extract_sub_graph`

Se questo non viene restituito lo stato non è inizializzato correttamente, bisogna quindi ricreare l'iterative_process
'''

state = iterative_process.initialize()

Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


### Federated Training

In [None]:
# Validation metrics: overall AUC, mean AUC over label columns 2/5/6/8/10,
# and one AUC per each of those columns.
auc_valid = tf.keras.metrics.AUC(name='auc_valid')
mean_auc_valid = MeanAUC(label_id=[2,5,6,8,10], name='mean_auc_valid')
auc_valid_card = LabelAUC(2, name='auc_valid_card')
auc_valid_edema = LabelAUC(5, name='auc_valid_edema')
auc_valid_cons = LabelAUC(6, name='auc_valid_cons')
auc_valid_atel = LabelAUC(8, name='auc_valid_atel')
auc_valid_peff = LabelAUC(10, name='auc_valid_peff')

# Loss and optimizer read as globals by train_step / validation_step.
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(1e-3)

# Folder receiving the per-round models, weights and the metrics log.
outputFolder = "/content/drive/MyDrive/Modelli_Tesi/NOLAT/Federated/Unbalanced_Fine_A"
# exist_ok=True replaces the check-then-create pattern: it is a no-op when the
# folder already exists and cannot fail if it appears between check and create.
os.makedirs(outputFolder, exist_ok=True)

In [None]:
# Initialise the server weights, because a pretrained MobileNet is being used.
# This must be done only when pretrained models are used.

# Build a regular Keras model (create_mobilenet requests ImageNet weights) to read the weights from.
keras_model = create_mobilenet()
keras_model.compile(loss=tf.keras.losses.BinaryCrossentropy(), metrics=[LabelAUC(2, name='label_auc_2')])

# Replace the model weights held in the federated server state with the Keras model's weights.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])

In [None]:
# Sample the clients participating in each training round.
# NB: client ids 0-4 are the first shard of the 5 clients, 5-9 the second
# shard and 10-14 the third. The shard group rotates at every round, so the
# number of rounds is multiplied by 3 when every client has 3 shards.
NUM_ROUNDS = 15
CLIENTS_PER_ROUND = 5
# Number of client groups to rotate through (3 with the 15 shards above).
# Derived from the data so the rotation keeps working for any NUM_ROUNDS,
# instead of relying on hard-coded round numbers up to 15.
num_client_groups = max(1, len(chex_train.client_ids) // CLIENTS_PER_ROUND)

validation_metrics = (
    auc_valid,
    mean_auc_valid,
    auc_valid_card,
    auc_valid_edema,
    auc_valid_cons,
    auc_valid_atel,
    auc_valid_peff,
)
template = 'VALIDATION: Round {}, Step {}, AUC MEAN: {}, AUC_cardiomegaly: {}, AUC_edema: {}, AUC_consolidation: {}, AUC_atelectasis: {}, AUC_pleural_effusion: {}, AUC_keras: {}'

# The global weights are copied into a plain Keras model to validate them
# (currently the only way to validate the federated model). The model is
# built once and reused: assign_weights_to overwrites its weights every
# round, and reusing the same object avoids rebuilding MobileNet and
# re-tracing validation_step at each round.
keras_model = create_mobilenet()
keras_model.compile(loss=tf.keras.losses.BinaryCrossentropy(), metrics=[LabelAUC(2, name='label_auc_2')])

round_logs = []  # one dict of validation results per completed round
log_tot = pd.DataFrame()
for round_num in range(1, NUM_ROUNDS+1):
    # Rounds 1,4,7,... use ids [0:5]; 2,5,8,... use [5:10]; 3,6,9,... use [10:15].
    group_index = (round_num - 1) % num_client_groups
    first_client = group_index * CLIENTS_PER_ROUND
    sample_clients = chex_train.client_ids[first_client:first_client + CLIENTS_PER_ROUND]
    federated_train_data = make_federated_data(chex_train, sample_clients)
    print('Following clients are partecipating to training:')
    for client in sample_clients:
        print(client)

    # TRAINING
    state, metrics = iterative_process.next(state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))

    # Copy the current global weights into the evaluation model.
    state.model.assign_weights_to(keras_model)

    # VALIDATION
    # Start from clean accumulators so an interrupted or re-run cell cannot
    # leak counts from an earlier pass into this round.
    for metric in validation_metrics:
        metric.reset_states()
    for step, row in enumerate(val_dataset):
        # Only the metric side effects are needed; the returned (loss, output) is ignored.
        validation_step(keras_model, row[0], row[1])
        if step % 100 == 0:
            print(template.format(round_num, step, mean_auc_valid.result().numpy(), auc_valid_card.result().numpy(), auc_valid_edema.result().numpy(), auc_valid_cons.result().numpy(), auc_valid_atel.result().numpy(), auc_valid_peff.result().numpy(), auc_valid.result().numpy()))

    keras_model.save(outputFolder+'/model_'+str(round_num)+'.h5')
    keras_model.save_weights(outputFolder+'/weights_'+str(round_num)+'.h5')

    round_logs.append({
        'Round': round_num,
        'AUC': auc_valid.result().numpy(),
        'AUC Mean': mean_auc_valid.result().numpy(),
        'AUC_cardiomegaly': auc_valid_card.result().numpy(),
        'AUC_edema': auc_valid_edema.result().numpy(),
        'AUC_consolidation': auc_valid_cons.result().numpy(),
        'AUC_atelectasis': auc_valid_atel.result().numpy(),
        'AUC_pleural_effusion': auc_valid_peff.result().numpy(),
    })
    # Build the frame from the accumulated records instead of DataFrame.append
    # (removed in pandas 2.x and quadratic when used in a loop).
    log_tot = pd.DataFrame(round_logs)

    print(log_tot)

    # Persist the log after every round so an interrupted session keeps its history.
    log_tot.to_csv(outputFolder+'/log.csv',index=False)

Following clients are partecipating to training:
client_0_part0
client_1_part0
client_2_part0
client_3_part0
client_4_part0
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('binary_accuracy', 0.83963686), ('loss', 0.351564)])), ('stat', OrderedDict([('num_examples', 59380)]))])
VALIDATION: Round 1, Step 0, AUC MEAN: 0.5999886393547058, AUC_cardiomegaly: 0.5, AUC_edema: 0.625, AUC_consolidation: 0.5, AUC_atelectasis: 0.5507246255874634, AUC_pleural_effusion: 0.82421875, AUC_keras: 0.6743844151496887
VALIDATION: Round 1, Step 100, AUC MEAN: 0.5736873149871826, AUC_cardiomegaly: 0.5, AUC_edema: 0.5511252284049988, AUC_consolidation: 0.5, AUC_atelectasis: 0.6254855990409851, AUC_pleural_effusion: 0.6918256878852844, AUC_keras: 0.6395471692085266
VALIDATION: Round 1, Step 200, AUC MEAN: 0.5773312449455261, AUC_cardiomegaly: 0.5, AUC_edema: 0.5544095635414124, AUC_consolidation: 0.5, AUC_atelect