In [1]:
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

from tensorflow.keras import datasets

from tensorflow import keras

Num GPUs Available:  1


In [2]:
import numpy as np

In [3]:
from fdavg.models import count_weights

In [4]:
(X_train, y_train), (X_test, y_test) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
#X_train, X_test = X_train / 255.0, X_test / 255.0

In [5]:
X_train.shape

(50000, 32, 32, 3)

In [6]:
y_train.shape

(50000, 1)

In [7]:
y_train = np.squeeze(y_train)
y_test = np.squeeze(y_test)

In [8]:
y_train.shape

(50000,)

# DenseNet

In [12]:
from tensorflow.keras import layers, models

""" 
Implementation from https://github.com/keras-team/keras-applications/blob/master/keras_applications/densenet.py

DenseNet, not pre-trained, intended for the CIFAR-10 dataset.

Note:
    - Inputs are assumed to be preprocessed with `tensorflow.keras.applications.densenet.preprocess_input`.
    - Inputs are given in NHWC layout; the model transposes them to NCHW (channels-first) internally.

Deviations from the original Keras implementation:
    1) Dropout layers with rate=0.2 are added after each convolution of a conv block, as suggested by Huang et al., 2016 for training on CIFAR-10.
    2) `he_normal` weight initialization (He et al., 2015) is used, as suggested by Huang et al., 2016.
"""

def dense_block(x, blocks, name, growth_rate=32):
    """Stack ``blocks`` DenseNet conv blocks on top of ``x``.

    Each conv block concatenates ``growth_rate`` new feature maps onto its input,
    so the channel count grows by ``blocks * growth_rate``.

    Args:
        x: Input tensor; expected channels-first (NCHW), since ``conv_block``
            concatenates along axis 1.
        blocks: Number of conv blocks to stack.
        name: Layer-name prefix; block ``i`` (1-based) is named ``name + '_block' + str(i)``.
        growth_rate: Feature maps added per conv block. Defaults to 32, the value
            previously hard-coded here.

    Returns:
        The output tensor after all conv blocks.
    """
    for block_idx in range(1, blocks + 1):
        x = conv_block(x, growth_rate, name=name + '_block' + str(block_idx))
    return x


def transition_block(x, reduction, name):
    """DenseNet transition layer: BN -> ReLU -> 1x1 conv (channel compression) -> 2x2 average pooling.

    Args:
        x: Channels-first (NCHW) input tensor.
        reduction: Fraction of the input channels kept by the 1x1 convolution.
        name: Prefix for the names of the layers created here.

    Returns:
        The channel-compressed tensor, spatially downsampled by 2.
    """
    channel_axis = 1  # NCHW: (batch_size, channels, height, width)

    out = layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5, name=name + '_bn')(x)
    out = layers.Activation('relu', name=name + '_relu')(out)

    # BN and ReLU keep the channel count, so this equals the input channel count times `reduction`.
    kept_channels = int(out.shape[channel_axis] * reduction)
    out = layers.Conv2D(
        kept_channels,
        1,
        kernel_initializer='he_normal',
        use_bias=False,
        name=name + '_conv',
        data_format='channels_first',
    )(out)

    pooling = layers.AveragePooling2D(2, strides=2, name=name + '_pool', data_format='channels_first')
    return pooling(out)


def conv_block(x, growth_rate, name):
    """DenseNet bottleneck block whose output is concatenated onto its input.

    Branch: BN -> ReLU -> 1x1 conv (4*growth_rate) -> dropout ->
            BN -> ReLU -> 3x3 conv (growth_rate) -> dropout.

    Args:
        x: Channels-first (NCHW) input tensor.
        growth_rate: Number of feature maps produced by the 3x3 convolution.
        name: Prefix for the names of the layers created here.

    Returns:
        ``x`` concatenated along the channel axis with the ``growth_rate`` new feature maps.
    """
    channel_axis = 1  # NCHW: (batch_size, channels, height, width)
    bottleneck_width = 4 * growth_rate

    # Bottleneck: 1x1 convolution; dropout 0.2 after it as Huang et al. suggest for CIFAR-10.
    branch = layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5, name=name + '_0_bn')(x)
    branch = layers.Activation('relu', name=name + '_0_relu')(branch)
    branch = layers.Conv2D(
        bottleneck_width, 1,
        use_bias=False,
        kernel_initializer='he_normal',
        name=name + '_1_conv',
        data_format='channels_first',
    )(branch)
    branch = layers.Dropout(0.2, name=name + '_1_dropout')(branch)

    # 3x3 convolution producing the new feature maps; again followed by dropout 0.2.
    branch = layers.BatchNormalization(axis=channel_axis, epsilon=1.001e-5, name=name + '_1_bn')(branch)
    branch = layers.Activation('relu', name=name + '_1_relu')(branch)
    branch = layers.Conv2D(
        growth_rate, 3,
        padding='same',
        use_bias=False,
        kernel_initializer='he_normal',
        name=name + '_2_conv',
        data_format='channels_first',
    )(branch)
    branch = layers.Dropout(0.2, name=name + '_2_dropout')(branch)

    merge = layers.Concatenate(axis=channel_axis, name=name + '_concat')
    return merge([x, branch])


def dense_net_fn(blocks, input_shape, classes):
    """Build a DenseNet classifier as a Keras functional model (channels-first internally).

    The model input is NHWC (e.g. ``(32, 32, 3)``). It is transposed to NCHW right after
    the input layer, and every conv/pool layer uses ``data_format='channels_first'``.

    Args:
        blocks: Four ints, the number of conv blocks in each of the four dense blocks.
            Any sequence (list, tuple, 1-D array) is accepted.
        input_shape: Per-example input shape in NHWC order, without the batch dimension.
        classes: Number of output classes; the head is a softmax ``Dense`` layer.

    Returns:
        An uncompiled Keras ``Model``. It is named ``'densenet121'``, ``'densenet169'`` or
        ``'densenet201'`` for the standard block configurations, and ``'densenet'`` otherwise.

    NOTE(review): the stem (7x7/2 conv + 3x3/2 max-pool) is the ImageNet stem of the
    reference implementation. On 32x32 inputs it shrinks the feature maps to 8x8 before
    the first dense block, and to 1x1 after the last transition. Huang et al. used a
    single 3x3 conv stem for CIFAR; confirm this is intended.
    """
    bn_axis = 1  # NCHW: (batch_size, channels, height, width)

    # Names of the standard configurations. Keyed by tuple so that list, tuple and
    # array `blocks` arguments all resolve; a list `==` comparison only matched lists.
    standard_names = {
        (6, 12, 24, 16): 'densenet121',
        (6, 12, 32, 32): 'densenet169',
        (6, 12, 48, 32): 'densenet201',
    }

    img_input = layers.Input(shape=input_shape)
    x = tf.transpose(img_input, [0, 3, 1, 2])  # NHWC -> NCHW

    # Stem: pad, 7x7/2 conv, BN-ReLU, pad, 3x3/2 max-pool.
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), data_format='channels_first')(x)
    x = layers.Conv2D(64, 7, strides=2, use_bias=False, kernel_initializer='he_normal', name='conv1/conv', data_format='channels_first')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
    x = layers.Activation('relu', name='conv1/relu')(x)
    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), data_format='channels_first')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1', data_format='channels_first')(x)

    # Four dense blocks separated by three transitions that halve the channel count.
    x = dense_block(x, blocks[0], name='conv2')
    x = transition_block(x, 0.5, name='pool2')
    x = dense_block(x, blocks[1], name='conv3')
    x = transition_block(x, 0.5, name='pool3')
    x = dense_block(x, blocks[2], name='conv4')
    x = transition_block(x, 0.5, name='pool4')
    x = dense_block(x, blocks[3], name='conv5')

    # Head: BN-ReLU, global average pooling, softmax classifier.
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
    x = layers.Activation('relu', name='relu')(x)
    x = layers.GlobalAveragePooling2D(name='avg_pool', data_format='channels_first')(x)
    x = layers.Dense(classes, kernel_initializer='he_normal', activation='softmax', name='fc10')(x)

    model_name = standard_names.get(tuple(blocks), 'densenet')
    return models.Model(img_input, x, name=model_name)


class DenseNet:
    """Wrapper around a functional Keras DenseNet with helpers used for federated averaging.

    Any attribute not defined on this class (``compile``, ``fit``, ``evaluate``,
    ``train_on_batch``, ``layers``, ``trainable_variables``, ...) is looked up on the
    wrapped ``self.model``.

    Args:
        name: One of ``'DenseNet121'``, ``'DenseNet169'``, ``'DenseNet201'``.
        input_shape: Per-example input shape (NHWC, without the batch dimension).
        classes: Number of output classes.

    Raises:
        ValueError: If ``name`` is not one of the supported variants. Previously an
            unknown name silently left ``self.model = None``.
    """

    # Number of conv blocks in each of the four dense blocks, per supported variant.
    _VARIANT_BLOCKS = {
        'DenseNet121': (6, 12, 24, 16),
        'DenseNet169': (6, 12, 32, 32),
        'DenseNet201': (6, 12, 48, 32),
    }

    def __init__(self, name, input_shape=(32, 32, 3), classes=10):
        self.model = None

        blocks = self._VARIANT_BLOCKS.get(name)
        if blocks is None:
            raise ValueError(
                f"Unknown DenseNet variant {name!r}; expected one of {sorted(self._VARIANT_BLOCKS)}"
            )
        # Pass a fresh list: dense_net_fn picks the model name by comparing against lists.
        self.model = dense_net_fn(list(blocks), input_shape, classes)

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped Keras model.

        Python only calls this after normal lookup fails. ``'model'`` itself is
        guarded: without the guard, an instance lacking ``self.model`` (e.g. one
        created via ``__new__``, ``copy`` or unpickling) would recurse forever.
        """
        if name == 'model':
            raise AttributeError(name)
        return getattr(self.model, name)

    def step(self, batch):
        """Run one ``train_on_batch`` update on an ``(x, y)`` batch and return its result."""
        x_batch, y_batch = batch
        return self.train_on_batch(x=x_batch, y=y_batch)

    def train(self, dataset):
        """Call ``step`` on every batch yielded by ``dataset``."""
        for batch in dataset:
            self.step(batch)

    def set_trainable_variables(self, trainable_vars):
        """Assign ``trainable_vars`` to the model's trainable variables, in order.

        ``zip`` stops at the shorter sequence, so extra entries on either side are ignored.
        """
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

    def set_non_trainable_variables(self, non_trainable_vars):
        """Assign ``non_trainable_vars`` to the model's non-trainable variables, in order.

        ``zip`` stops at the shorter sequence, so extra entries on either side are ignored.
        """
        for model_var, var in zip(self.non_trainable_variables, non_trainable_vars):
            model_var.assign(var)

    @tf.function
    def trainable_vars_as_vector(self):
        """All trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

    def per_layer_trainable_vars_as_vector(self):
        """One flattened 1-D tensor per layer that has trainable weights, in ``self.layers`` order."""
        layer_vectors = [
            tf.concat([tf.reshape(var, [-1]) for var in layer.trainable_weights], axis=0)
            for layer in self.layers
            if layer.trainable_weights
        ]

        # Previously returned the undefined name `layer_vectorsdd` (NameError).
        return layer_vectors

    def set_layer_weights(self, layer_i, weights):
        """Assign ``weights`` to the trainable weights of ``self.layers[layer_i]``, in order."""
        for model_var, var in zip(self.layers[layer_i].trainable_weights, weights):
            model_var.assign(var)

    def get_trainable_layers_indices(self):
        """Indices into ``self.layers`` of the layers that have trainable weights."""
        trainable_layers_idx = [
            i for i, layer in enumerate(self.layers)
            if layer.trainable_weights
        ]

        return trainable_layers_idx

In [13]:
def get_compiled_and_built_densenet(name, cnn_batch_input, learning_rate_schedule):
    """Create a DenseNet variant, compile it, and build it for a given batch input shape.

    Compilation uses Nesterov SGD (momentum 0.9, weight decay 1e-4) driven by
    ``learning_rate_schedule``, sparse categorical cross-entropy, and a sparse
    categorical accuracy metric named ``'accuracy'``.

    Args:
        name: DenseNet variant name passed to ``DenseNet`` (e.g. ``'DenseNet121'``).
        cnn_batch_input: Batch input shape for ``build``, e.g. ``(None, 32, 32, 3)``.
        learning_rate_schedule: Learning rate or schedule given to the SGD optimizer.

    Returns:
        The compiled and built ``DenseNet`` wrapper.
    """
    net = DenseNet(name)

    sgd = tf.keras.optimizers.SGD(
        learning_rate=learning_rate_schedule,
        momentum=0.9,
        weight_decay=1e-4,
        nesterov=True,
    )
    # The model already ends in softmax, so the loss consumes probabilities, not logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    net.compile(optimizer=sgd, loss=loss_fn, metrics=[accuracy_metric])
    net.build(cnn_batch_input)
    return net


def create_learning_rate_schedule(total_epochs, steps_per_epoch):
    """Step-wise learning-rate schedule from the DenseNet paper.

    The rate is 0.1 until 50% of the total training steps, then 0.01 until 75%,
    then 0.001.

    Args:
        total_epochs: Number of training epochs.
        steps_per_epoch: Optimizer steps per epoch. May be fractional; the resulting
            step boundaries are truncated to integers.

    Returns:
        A ``tf.keras.optimizers.schedules.PiecewiseConstantDecay`` keyed on the
        optimizer step.

    Ref: https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/PiecewiseConstantDecay
    """
    total_steps = total_epochs * steps_per_epoch

    # PiecewiseConstantDecay documents that boundaries must have the same type as
    # the (integer) optimizer step. Truncate explicitly instead of passing floats
    # and relying on an implicit cast; the resulting values are the same.
    boundaries = [int(0.5 * total_steps), int(0.75 * total_steps)]
    values = [0.1, 0.01, 0.001]

    return tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)

In [14]:
num_clients = 2
batch_size = 256

In [15]:
def create_unbiased_federated_data(X_train, y_train, num_clients):
    """Split the training data into ``num_clients`` contiguous shards, one per client.

    Uses ``np.array_split``, so shard sizes differ by at most one when the data
    does not divide evenly. No shuffling happens here: client ``i`` receives the
    ``i``-th contiguous slice of ``X_train`` / ``y_train``.

    Args:
        X_train: Array of features, split along axis 0.
        y_train: Array of labels aligned with ``X_train``.
        num_clients: Number of shards to create.

    Returns:
        A list of ``num_clients`` ``tf.data.Dataset`` objects yielding ``(x, y)`` pairs.
    """
    feature_shards = np.array_split(X_train, num_clients)
    label_shards = np.array_split(y_train, num_clients)

    client_datasets = []
    for shard_x, shard_y in zip(feature_shards, label_shards):
        client_datasets.append(tf.data.Dataset.from_tensor_slices((shard_x, shard_y)))
    return client_datasets

In [16]:
def prepare_federated_data(federated_dataset, batch_size, num_steps_until_rtc_check, seed=None):
    """Turn each client dataset into a shuffled, repeated, batched stream of a fixed length.

    For every client dataset:

    - it is shuffled with a buffer as large as its cardinality (a uniform shuffle);
    - it is repeated indefinitely and batched into ``batch_size`` examples;
    - the stream is truncated to ``num_steps_until_rtc_check`` batches.

    Args:
        federated_dataset: Iterable of per-client ``tf.data.Dataset`` objects.
        batch_size: Examples per batch.
        num_steps_until_rtc_check: Number of batches each prepared dataset yields.
        seed: Optional shuffle seed; the same seed is used for every client.

    Returns:
        A list with one prepared dataset per client, in input order.
    """
    def _prepare(client_ds):
        full_buffer = client_ds.cardinality()  # whole dataset in the buffer -> uniform shuffle
        shuffled = client_ds.shuffle(full_buffer, seed=seed)
        return shuffled.repeat().batch(batch_size).take(num_steps_until_rtc_check)

    return list(map(_prepare, federated_dataset))

# New paradigm

In [17]:
def count_weights2(model):
    """Total number of scalar weights, trainable and non-trainable, over ``model.layers``.

    Args:
        model: Object exposing ``layers``; each layer exposes ``trainable_weights``
            and ``non_trainable_weights``, whose items have a ``shape``.

    Returns:
        The parameter count as a Python ``int``.
    """
    per_group_counts = (
        np.sum([np.prod(weight.shape) for weight in weight_group])
        for layer in model.layers
        for weight_group in (layer.trainable_weights, layer.non_trainable_weights)
    )
    return int(sum(per_group_counts, 0))

In [18]:
def average_client_weights2(client_models):
    """Element-wise mean over clients of each trainable and non-trainable variable.

    The ``i``-th variable of every client is stacked and averaged along the new
    client axis. The variable lists are assumed to line up across clients
    (same architecture, same order).

    NOTE(review): averaging BatchNorm moving variances this way ignores the spread
    between client means. Also, if any non-trainable variable has an integer dtype,
    ``tf.reduce_mean`` performs integer division. Confirm both are acceptable.

    Args:
        client_models: Sequence of models exposing ``trainable_variables`` and
            ``non_trainable_variables``.

    Returns:
        ``(avg_trainable_weights, avg_non_trainable_weights)``: two lists of tensors.
    """
    def _mean_per_variable(per_client_lists):
        # zip(*...) groups the i-th variable of every client together.
        return [tf.reduce_mean(same_var_across_clients, axis=0)
                for same_var_across_clients in zip(*per_client_lists)]

    trainable_per_client = [m.trainable_variables for m in client_models]
    non_trainable_per_client = [m.non_trainable_variables for m in client_models]

    return _mean_per_variable(trainable_per_client), _mean_per_variable(non_trainable_per_client)

In [19]:
def synchronize_clients2(server_model, client_models):
    """Copy the server's trainable and then non-trainable variables into every client.

    All clients receive the trainable variables first; a second pass copies the
    non-trainable ones.

    Args:
        server_model: Model whose ``trainable_variables`` / ``non_trainable_variables``
            are broadcast.
        client_models: Iterable of clients exposing ``set_trainable_variables`` and
            ``set_non_trainable_variables``. It is materialized once, so a generator
            is not exhausted by the first pass. Previously the second pass silently
            did nothing for generators.
    """
    client_models = list(client_models)

    for client_model in client_models:
        client_model.set_trainable_variables(server_model.trainable_variables)

    for client_model in client_models:
        client_model.set_non_trainable_variables(server_model.non_trainable_variables)

# Training

In [20]:
def clients_train_synchronous(client_cnns, federated_dataset):
    """Train the clients one after another, each on its own local dataset.

    Client ``i`` is paired with dataset ``i``. As with ``zip``, surplus entries on
    the longer side are ignored.
    """
    paired = zip(client_cnns, federated_dataset)
    for client, local_data in paired:
        client.train(local_data)

# Data - models

In [21]:
from tensorflow.keras.applications.densenet import preprocess_input

In [22]:
# Preprocess the input data
X_train_dense = preprocess_input(X_train)
X_test_dense = preprocess_input(X_test)

In [23]:
test_ds = tf.data.Dataset.from_tensor_slices((X_test_dense, y_test)).batch(1024)

In [24]:
fed_ds = prepare_federated_data(
    create_unbiased_federated_data(X_train_dense, y_train, num_clients), batch_size, 1
)

In [25]:
server_cnn = DenseNet('DenseNet121')

server_cnn.compile(
    optimizer=tf.keras.optimizers.Adam(weight_decay=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
)

server_cnn.build((None, 32, 32, 3))

In [26]:
server_cnn = DenseNet('DenseNet121')

server_cnn.compile(
    optimizer=tf.keras.optimizers.Adam(weight_decay=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
)

server_cnn.build((None, 32, 32, 3))

In [27]:
count_weights(server_cnn)

6964106

In [193]:
# Define the model
client_cnns = [DenseNet('DenseNet121') for _ in range(num_clients)]
server_cnn = DenseNet('DenseNet121')

for cnn in client_cnns:
    cnn.compile(
        optimizer=tf.keras.optimizers.Adam(weight_decay=1e-4),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
    )
    
    cnn.build((None, 32, 32, 3))

server_cnn.compile(
    optimizer=tf.keras.optimizers.Adam(weight_decay=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
)

server_cnn.build((None, 32, 32, 3))

In [194]:
# Optimizer steps per epoch. `fit` also runs the final partial batch, so round up
# (e.g. 50_000 / 256 -> 196 steps, not 195.3125). Derived from the data instead of a
# hard-coded 50_000.
steps_per_epoch = int(np.ceil(len(X_train_dense) / batch_size))
num_epochs = 4

In [195]:
lr_schedule = create_learning_rate_schedule(num_epochs, steps_per_epoch)

In [196]:
server_cnn = get_compiled_and_built_densenet('DenseNet121', (None, 32, 32, 3), lr_schedule) 

In [197]:
server_cnn.fit(X_train_dense, y_train, batch_size=batch_size, validation_data=(X_test_dense, y_test), epochs=1)

I0000 00:00:1702379568.063747  312111 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




<keras.src.callbacks.History at 0x7f25c1fd0c90>

In [None]:
server_cnn.optimizer.learning_rate.numpy()

# Federated training

In [None]:
from fdavg.models.miscellaneous import average_client_weights, synchronize_clients

In [None]:
for cnn in client_cnns:
    cnn.metrics[1].reset_state()

In [None]:
synchronize_clients2(server_cnn, client_cnns)

In [None]:
# Federated training rounds. In each round every client trains locally on its shard;
# `fed_ds` yields one batch per client per round (prepared with
# num_steps_until_rtc_check=1). The server's trainable variables are then set to the
# value returned by `average_client_weights`, and `synchronize_clients` pushes the
# server state back to the clients.
# NOTE(review): `average_client_weights` / `synchronize_clients` come from
# fdavg.models.miscellaneous and are not visible here. If they only handle trainable
# variables, BatchNorm moving statistics are never averaged inside this loop.
# `average_client_weights2` / `synchronize_clients2` defined above also cover
# non-trainable variables; confirm which pair is intended.
for step in range(500):
    
    clients_train_synchronous(client_cnns, fed_ds)
    
    avg_trainable_weights = average_client_weights(client_cnns)
    server_cnn.set_trainable_variables(avg_trainable_weights)
    synchronize_clients(server_cnn, client_cnns)

In [None]:
for var in server_cnn.non_trainable_variables:
    if not 'bn' in var.name:
        print(var.name)

In [None]:
client_cnns[1].evaluate(test_ds)

In [None]:
client_cnns[0].evaluate(test_ds)

In [None]:
avg_trainable_weights, avg_non_trainable_weights = average_client_weights2(client_cnns)

In [None]:
server_cnn.set_non_trainable_variables(avg_non_trainable_weights)

In [None]:
server_cnn.evaluate(test_ds)

In [None]:
tf.reduce_mean([cnn.metrics[1].result() for cnn in client_cnns]).numpy()

In [None]:
tf.reduce_mean([cnn.metrics[0].result() for cnn in client_cnns]).numpy()

In [None]:
server_cnn.evaluate(X_train_dense, y_train, batch_size=1024)

In [None]:
server_cnn.evaluate(test_ds)

In [None]:
client_cnns[0].evaluate(test_ds)

In [None]:
client_cnns[1].evaluate(test_ds)

In [None]:
tmp = DenseNet('DenseNet121')

tmp.compile(
    optimizer=tf.keras.optimizers.Adam(weight_decay=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
)

tmp.build((None, 32, 32, 3))

avg_trainable_weights, avg_non_trainable_weights = average_client_weights2(client_cnns)
tmp.set_trainable_variables(avg_trainable_weights)
tmp.set_non_trainable_variables(avg_non_trainable_weights)

In [None]:
tmp.evaluate(test_ds)

In [None]:
def current_accuracy2(client_models, test_dataset, tmp_model):
    """Test accuracy of the model obtained by averaging the clients.

    ``tmp_model`` is overwritten in place with the clients' averaged trainable and
    non-trainable variables and then evaluated. ``evaluate`` must return exactly two
    values (loss and one metric); the second one is returned.

    Args:
        client_models: Client models to average (see ``average_client_weights2``).
        test_dataset: Dataset passed to ``tmp_model.evaluate``.
        tmp_model: Scratch model with the same architecture as the clients; mutated.

    Returns:
        The metric value reported by ``evaluate`` (accuracy, given how models are
        compiled in this notebook).
    """
    averaged_trainable, averaged_non_trainable = average_client_weights2(client_models)
    tmp_model.set_trainable_variables(averaged_trainable)
    tmp_model.set_non_trainable_variables(averaged_non_trainable)

    evaluation = tmp_model.evaluate(test_dataset, verbose=0)
    _loss, accuracy = evaluation
    return accuracy

In [None]:
np.prod([])

In [None]:
server_cnn

In [None]:
class AdvancedCNN(tf.keras.Model):
    """
    Convolutional Neural Network (CNN) for image classification.

    Architecture: reshape -> three stages of (conv, conv, 2x2 max-pool) with
    64 / 128 / 256 filters -> flatten -> Dense(512) + Dropout(0.5) ->
    Dense(512) + Dropout(0.5) -> Dense(num_classes, softmax).

    Attributes:
    - Layers for the CNN architecture (reshape, convolutional, pooling, dense and dropout layers).

    Methods:
    - call: Forward pass for the model.
    - step: Run one `train_on_batch` update on an (x, y) batch.
    - train: Call `step` on every batch of a dataset.
    - set_trainable_variables / set_non_trainable_variables: Overwrite the model's variables in order.
    - trainable_vars_as_vector: Return the trainable variables flattened into one 1-D tensor.
    - per_layer_trainable_vars_as_vector: One flattened 1-D tensor per layer that has trainable weights.
    - set_layer_weights: Overwrite the trainable weights of a single layer.
    - get_trainable_layers_indices: Indices into `self.layers` of layers that have trainable weights.
    """
    
    def __init__(self, cnn_input_reshape, num_classes):
        """
        Create the layers of the model.

        Args:
        - cnn_input_reshape (tuple): Per-example shape the input is reshaped to (e.g., (28, 28, 1)).
        - num_classes (int): Number of output classes.
        """
        super(AdvancedCNN, self).__init__()
        
        # NOTE(review): presumably Keras tracks sub-layers in attribute-assignment order,
        # which fixes the order of `self.layers` and `trainable_variables` that the
        # averaging / per-layer helpers rely on. Keep this order stable.
        self.reshape = tf.keras.layers.Reshape(cnn_input_reshape)
        
        # Stage 1: 64 filters, then 2x2 pooling.
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.conv2 = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.max_pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        
        # Stage 2: 128 filters, then 2x2 pooling.
        self.conv3 = tf.keras.layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.conv4 = tf.keras.layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.max_pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        
        # Stage 3: 256 filters, then 2x2 pooling.
        self.conv5 = tf.keras.layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.conv6 = tf.keras.layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.max_pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

        # Classifier head; the softmax output pairs with a non-logits SparseCategoricalCrossentropy.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(512, activation='relu')
        self.dropout1 = tf.keras.layers.Dropout(0.5)
        self.dense2 = tf.keras.layers.Dense(512, activation='relu')
        self.dropout2 = tf.keras.layers.Dropout(0.5)
        self.dense3 = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Forward pass; `training` is forwarded only to the dropout layers."""
        x = self.reshape(inputs)  # Add a channel dimension
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.max_pool1(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.max_pool2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.max_pool3(x)

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.dropout2(x, training=training)
        x = self.dense3(x)
        return x
        
    def step(self, batch):
        """Run one `train_on_batch` update on an (x, y) batch and return its result."""
        x_batch, y_batch = batch
        return self.train_on_batch(x=x_batch, y=y_batch)
       
    def train(self, dataset):
        """Call `step` on every batch yielded by `dataset`."""
        for batch in dataset:
            self.step(batch)

    def set_trainable_variables(self, trainable_vars):
        """Assign `trainable_vars` to the trainable variables in order; `zip` ignores any length mismatch."""
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

    def set_non_trainable_variables(self, non_trainable_vars):
        """Assign `non_trainable_vars` to the non-trainable variables in order; `zip` ignores any length mismatch."""
        for model_var, var in zip(self.non_trainable_variables, non_trainable_vars):
            model_var.assign(var)

    @tf.function
    def trainable_vars_as_vector(self):
        """All trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

    def per_layer_trainable_vars_as_vector(self):
        """One flattened 1-D tensor per layer that has trainable weights, in `self.layers` order."""

        layer_vectors = [
            tf.concat([tf.reshape(var, [-1]) for var in layer.trainable_weights], axis=0)
            for layer in self.layers
            if layer.trainable_weights
        ]

        return layer_vectors

    def set_layer_weights(self, layer_i, weights):
        """Assign `weights` to the trainable weights of `self.layers[layer_i]`, in order."""

        for model_var, var in zip(self.layers[layer_i].trainable_weights, weights):
            model_var.assign(var)

    def get_trainable_layers_indices(self):
        """Indices into `self.layers` of the layers that have trainable weights."""

        trainable_layers_idx = [
            i for i, layer in enumerate(self.layers)
            if layer.trainable_weights
        ]

        return trainable_layers_idx
    

def get_compiled_and_built_advanced_cnn(cnn_batch_input, cnn_input_reshape, num_classes):
    """Instantiate an ``AdvancedCNN``, compile it with Adam, and build it.

    Args:
        cnn_batch_input: Batch input shape for ``build``, e.g. ``(None, 28, 28)``.
        cnn_input_reshape: Per-example shape the model reshapes inputs to, e.g. ``(28, 28, 1)``.
        num_classes: Number of output classes.

    Returns:
        The compiled and built ``AdvancedCNN``, with a sparse categorical accuracy
        metric named ``'accuracy'``.
    """
    cnn = AdvancedCNN(cnn_input_reshape, num_classes)

    adam = tf.keras.optimizers.Adam()
    # The model ends in softmax, so the loss consumes probabilities, not logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    cnn.compile(optimizer=adam, loss=loss_fn, metrics=[accuracy_metric])
    cnn.build(cnn_batch_input)
    return cnn

In [None]:
adv = get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10)

In [None]:
adv.metrics

In [None]:
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

In [None]:
X_train, X_test = X_train / 255.0, X_test / 255.0

In [None]:
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(256)

In [None]:
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(256).repeat()

In [None]:
adv.evaluate(test_ds)

In [None]:
it = iter(train_ds)

In [None]:
#%%timeit

adv.step(next(it))

In [None]:
adv.metrics

In [None]:
#%%timeit
for _ in range(100):
    adv.step(next(it))

In [None]:
adv.metrics[1].reset_state()

In [None]:
adv.evaluate(test_ds)