In [1]:
import os
# Expose only the CUDA device with index 1 to this process. This is set before
# TensorFlow is imported so that it is already in place when TensorFlow
# enumerates GPUs.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import tensorflow as tf

# Cap how much memory TensorFlow may take on the first visible GPU
# (memory_limit=12120; the TensorFlow API documents this value as megabytes).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=12120)],
        )
    except RuntimeError as err:
        # Best effort: report the problem and carry on with default settings.
        print(err)

In [2]:
import flwr as fl
import tensorflow as tf
from tensorflow import keras
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Auxiliary methods
def getDist(y):
    """Show a bar chart of how many samples each class label has.

    Args:
        y: One-dimensional sequence or array of class labels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Pass the labels by keyword. seaborn 0.11 deprecated positional data
    # arguments, and from 0.12 the first positional slot is `data`, so
    # `sns.countplot(y)` no longer means "count the values of y".
    ax = sns.countplot(x=y)
    ax.set(title="Count of data classes")
    plt.show()

def getData(dist, x, y):
    """Select a class-capped subset of a labelled dataset.

    Samples are visited in their original order; a sample is kept only while
    fewer than ``dist[label]`` samples of its class have been kept so far.

    Args:
        dist: Sequence indexed by integer class label; ``dist[c]`` is the
            maximum number of samples of class ``c`` to keep. Its length
            determines how many classes are tracked.
        x: Indexable collection of samples.
        y: Indexable collection of integer class labels aligned with ``x``.

    Returns:
        Tuple ``(dx, dy)`` of NumPy arrays holding the kept samples and their
        labels, in original order.

    Raises:
        IndexError: If a label is outside the range of ``dist`` or ``y`` is
            shorter than ``x``.
    """
    dx = []
    dy = []
    # One counter per class described by `dist` (previously fixed at 10,
    # which broke for label sets with more than ten classes).
    counts = [0] * len(dist)
    for i in range(len(x)):
        label = y[i]
        if counts[label] < dist[label]:
            dx.append(x[i])
            dy.append(label)
            counts[label] += 1

    return np.array(dx), np.array(dy)

# Build and compile the Keras model: flatten the image, two hidden ReLU
# layers (128 and 256 units), then a 10-way softmax output.
# NOTE(review): input_shape is (28, 28) while the data prepared further down
# in this file gets an extra trailing axis -- confirm this is intended.
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28,28)))
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Append a trailing axis to every image and divide the pixel values by 255.
x_train = x_train[..., np.newaxis]/255.0
x_test = x_test[..., np.newaxis]/255.0

# Per-class cap on the number of training samples this client keeps
# (list index = class label). The test set is left untouched.
dist = [0, 10, 10, 10, 4000, 3000, 4000, 5000, 10, 4500]
x_train, y_train = getData(dist, x_train, y_train)
# Uncomment to plot the resulting class distribution:
# getDist(y_train)

# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client backed by the notebook-level Keras ``model``.

    All methods operate on the module-level ``model``, ``x_train``,
    ``y_train``, ``x_test`` and ``y_test`` objects rather than on instance
    state.
    """

    def get_parameters(self, config=None):
        """Return the current model weights as a list of NumPy arrays.

        Args:
            config: Optional configuration sent by the server. It is accepted
                (and ignored) because Flower >= 1.0 calls
                ``get_parameters(config=...)``, whereas earlier releases call
                the method with no arguments; the default keeps both working.
        """
        return model.get_weights()

    def fit(self, parameters, config):
        """Train for one epoch starting from the given weights.

        Args:
            parameters: Weights to load into the model before training.
            config: Server-provided configuration (unused).

        Returns:
            Tuple of (updated weights, number of training samples, empty
            metrics dict).
        """
        model.set_weights(parameters)
        # validation_data makes Keras also report val_loss/val_accuracy on the
        # test split; these only appear in the printed history.
        r = model.fit(x_train, y_train, epochs=1, validation_data=(x_test, y_test), verbose=0)
        hist = r.history
        print("Fit history : " ,hist)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate the given weights on the test split.

        Args:
            parameters: Weights to load into the model before evaluating.
            config: Server-provided configuration (unused).

        Returns:
            Tuple of (loss, number of test samples, ``{"accuracy": accuracy}``).
        """
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
        print("Eval accuracy : ", accuracy)
        return loss, len(x_test), {"accuracy": accuracy}

# Start Flower client: connect a FlowerClient instance to the server at
# localhost:8090. In the captured output this call runs the fit/evaluate
# rounds and returns once the gRPC channel is closed.
fl.client.start_numpy_client(
        server_address="localhost:8090", 
        client=FlowerClient(), 
        # 1024*1024*1024 = 1,073,741,824 (presumably bytes, i.e. 1 GiB -- confirm
        # against the Flower documentation for grpc_max_message_length).
        grpc_max_message_length = 1024*1024*1024
)

INFO flower 2022-07-11 11:33:36,493 | connection.py:102 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flower 2022-07-11 11:33:36,496 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flower 2022-07-11 11:33:36,500 | connection.py:39 | ChannelConnectivity.CONNECTING
DEBUG flower 2022-07-11 11:33:36,501 | connection.py:39 | ChannelConnectivity.READY


Fit history :  {'loss': [0.2160894125699997], 'accuracy': [0.934420645236969], 'val_loss': [3.4698874950408936], 'val_accuracy': [0.4690000116825104]}
Eval accuracy :  0.836899995803833
Fit history :  {'loss': [0.11109969764947891], 'accuracy': [0.9663583040237427], 'val_loss': [2.2349488735198975], 'val_accuracy': [0.5335000157356262]}
Eval accuracy :  0.9136000275611877
Fit history :  {'loss': [0.07357343286275864], 'accuracy': [0.9770204424858093], 'val_loss': [1.8115103244781494], 'val_accuracy': [0.6697999835014343]}
Eval accuracy :  0.9308000206947327
Fit history :  {'loss': [0.052647534757852554], 'accuracy': [0.9833495616912842], 'val_loss': [1.9127275943756104], 'val_accuracy': [0.6215999722480774]}
Eval accuracy :  0.9351000189781189
Fit history :  {'loss': [0.046924371272325516], 'accuracy': [0.9846153855323792], 'val_loss': [1.4337106943130493], 'val_accuracy': [0.7037000060081482]}
Eval accuracy :  0.9466000199317932
Fit history :  {'loss': [0.04099899157881737], 'accuracy

DEBUG flower 2022-07-11 11:34:25,074 | connection.py:121 | gRPC channel closed
INFO flower 2022-07-11 11:34:25,076 | app.py:101 | Disconnect and shut down


Eval accuracy :  0.9577999711036682
