In [None]:
import math
import random 
import collections
import numpy as np

import nest_asyncio
import tensorflow as tf
import tensorflow_federated as tff
import seaborn as sns

from tensorflow.keras import activations
from statistics import mean
from collections import Counter

from tqdm.notebook import tqdm as tqdm_notebook

# Apply seaborn's 'whitegrid' theme to the figures drawn later in the notebook.
sns.set_theme(style= 'whitegrid')
# Patch asyncio so an already-running event loop can be re-entered
# (presumably needed because the notebook kernel owns a loop -- confirm).
nest_asyncio.apply()
# Select TFF's reference execution context for the federated computations below.
tff.backends.reference.set_reference_context()

### Synthetic data generation and client partitioning

In [None]:
import matplotlib.colors
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import sklearn.datasets as dt

from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

# Colour map interpolated between four named colours; used to colour the class labels.
color_map_discrete = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["red", "cyan", "magenta", "blue"]
)

# Synthetic classification problem: 50000 samples, 2 informative features,
# 3 classes with one cluster each, generated with a fixed random_state.
x, y = dt.make_classification(
    n_samples=50000,
    n_features=2,
    n_redundant=0,
    n_repeated=0,
    n_classes=3,
    n_clusters_per_class=1,
    class_sep=2,
    random_state=2000,
)

# Scatter the two features against each other, coloured by class label.
fig = plt.figure(figsize=(10, 8))
my_scatter_plot = plt.scatter(
    x[:, 0],
    x[:, 1],
    c=y,
    cmap=color_map_discrete,
    vmin=y.min(),
    vmax=y.max(),
    s=5,
)

plt.show()

# Group the samples by class label. Every row is [feature_0, feature_1, label].
list0 = []
list1 = []
list2 = []

for features, label in zip(x, y):
    row = features.tolist()
    row.append(label)
    if label == 0:
        list0.append(row)
    elif label == 1:
        list1.append(row)
    else:
        list2.append(row)

# Cut points inside each class list. Each pair must be derived from the length
# of the list it slices; previously all six were computed from len(list0),
# which skews the split whenever the classes differ in size.
l01 = math.ceil(len(list0) * 0.7)
l02 = math.ceil(len(list0) * 0.8)

l11 = math.ceil(len(list1) * 0.6)
l12 = math.ceil(len(list1) * 0.75)

l21 = math.ceil(len(list2) * 0.65)
l22 = math.ceil(len(list2) * 0.85)

# Non-IID client datasets; the slices of each class are disjoint across clients.
#   CLIENT1: 70% of class 0, 15% of class 1, 15% of class 2
#   CLIENT2: 20% of class 0, 60% of class 1, 20% of class 2
#   CLIENT3: 10% of class 0, 25% of class 1, 65% of class 2
CLIENT1 = list0[:l01] + list1[l11:l12] + list2[l22:]
CLIENT2 = list0[l02:] + list1[:l11] + list2[l21:l22]
CLIENT3 = list0[l01:l02] + list1[l12:] + list2[:l21]

# Shuffle the populated client datasets in place (the original shuffled three
# empty placeholder lists, leaving CLIENT1..3 ordered by class).
random.shuffle(CLIENT1)
random.shuffle(CLIENT2)
random.shuffle(CLIENT3)

### Clients Data

In [None]:
# Per-round history of training-label counts: reselectClientData() appends one
# [Counter, Counter, Counter] entry (client 1, 2, 3) every time it is called.
ARR = []
# NOTE(review): not written or read anywhere in the visible notebook -- confirm whether it is still needed.
ARR_MEAN = []

In [None]:
def get_batches(HOLDER, batch_size=100):
    """Split every client's rows into fixed-size batches of NumPy arrays.

    Args:
        HOLDER: iterable of client datasets. Each client dataset is a sequence
            of rows whose first two entries are the features and whose third
            entry is the class label.
        batch_size: number of rows per batch (default 100, the value that was
            previously hard-coded).

    Returns:
        A list with one entry per client. Each entry is a list of dicts with
        keys 'x' (float32 array of shape (batch_size, 2)) and 'y' (int32 array
        of shape (batch_size,)). Trailing rows that do not fill a complete
        batch are dropped.

    Raises:
        ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    federated_data = []
    for client_rows in HOLDER:
        client_batches = []
        # Stop before any start index that would give an incomplete batch.
        for start in range(0, len(client_rows) - batch_size + 1, batch_size):
            chunk = client_rows[start:start + batch_size]
            client_batches.append({
                'x': np.array([[row[0], row[1]] for row in chunk], dtype=np.float32),
                'y': np.array([row[2] for row in chunk], dtype=np.int32),
            })
        federated_data.append(client_batches)
    return federated_data

In [None]:
def poissonDistributiondata(clientData):
    """Draw 2000 rows from a client dataset and split them 1000/500/500.

    The rows are drawn with ``random.sample`` (without replacement); no
    Poisson sampling is involved despite the function name.

    Args:
        clientData: sequence of rows holding at least 2000 entries
            (``random.sample`` raises ``ValueError`` otherwise).

    Returns:
        ``[train_rows, validation_rows, test_rows]`` with 1000, 500 and 500
        rows respectively.
    """
    sampled = random.sample(clientData, 2000)
    train_rows = sampled[:1000]
    validation_rows = sampled[1000:1500]
    test_rows = sampled[1500:]
    return [train_rows, validation_rows, test_rows]

In [None]:
def reselectClientData():
    """Resample every client's data and return it as batched federated splits.

    Each of CLIENT1, CLIENT2 and CLIENT3 is passed through
    ``poissonDistributiondata`` to obtain fresh train/validation/test rows,
    which are then batched with ``get_batches``.

    Side effect: appends to the global ``ARR`` a list with one ``Counter`` per
    client, counting the labels (last element of each row) of that client's
    training rows for this call.

    Returns:
        Tuple ``(federated_train_data, federated_validation_data,
        federated_test_data)``; each element holds one list of batches per client.
    """
    # One [train, val, test] triple per client, sampled in client order.
    per_client_splits = [
        poissonDistributiondata(client_rows)
        for client_rows in (CLIENT1, CLIENT2, CLIENT3)
    ]

    # Regroup from "per client" to "per split".
    train = [splits[0] for splits in per_client_splits]
    val = [splits[1] for splits in per_client_splits]
    test = [splits[2] for splits in per_client_splits]

    batched = (get_batches(train), get_batches(val), get_batches(test))

    # Record this round's training label counts for every client.
    ARR.append([Counter(row[-1] for row in client_train) for client_train in train])

    return batched

In [None]:
def get_mean(arr_history):
    """Average the per-round class counts of each of the three clients.

    Args:
        arr_history: sequence of rounds; each round is a sequence whose first
            three items are dict-like label counts (e.g. ``Counter``) for
            client 1, 2 and 3.

    Returns:
        ``{'client1': {0: m, 1: m, 2: m}, 'client2': {...}, 'client3': {...}}``
        where each ``m`` is the ``statistics.mean`` of that class's count
        across all rounds.

    Raises:
        statistics.StatisticsError: if ``arr_history`` is empty.
    """
    class_ids = (0, 1, 2)
    means = {}
    for client_index in range(3):
        per_round_counts = [round_counts[client_index] for round_counts in arr_history]
        means['client{}'.format(client_index + 1)] = {
            # A class that was not sampled in a round counts as 0; without the
            # default, .get() returns None and mean() raises TypeError.
            class_id: mean([counts.get(class_id, 0) for counts in per_round_counts])
            for class_id in class_ids
        }
    return means

#### The following shows an example of the `get_mean` output
```get_mean(ARR)```

```{'client1': {0: 706.5, 1: 154, 2: 139.5},
 'client2': {0: 188, 1: 611, 2: 201},
 'client3': {0: 92, 1: 245, 2: 663}}```

#### Custom mean aggregation function goes here

In [None]:
def AGG_FUNCTION1(mean):
    """Turn per-client mean class counts into one aggregation weight per client.

    Args:
        mean: dict with keys 'client1', 'client2' and 'client3', each mapping
            class ids to that client's mean count (the ``get_mean`` output).

    Returns:
        List of three ``np.float32`` values: the largest mean class count of
        client 1, 2 and 3, in that order.
    """
    client_keys = ('client1', 'client2', 'client3')
    return [np.float32(max(mean.get(key).values())) for key in client_keys]

### Federated Data Types [Client and Server]

In [None]:
# One batch of examples: 'x' is float32 with 2 columns, 'y' is an int32 vector;
# the leading (batch) dimension is left unspecified.
BATCH_SPEC = collections.OrderedDict(x=tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
                                     y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)


# Model parameters: a [2, 3] float32 weight matrix and a [3] float32 bias.
MODEL_SPEC = collections.OrderedDict(weights=tf.TensorSpec(shape=[2, 3], dtype=tf.float32),
                                     bias=tf.TensorSpec(shape=[3], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)


# Per-client aggregation weight: a float32 tensor type placed at the clients.
# NOTE(review): shape=None is presumably interpreted by TFF as a scalar -- confirm
# against the installed TFF version.
WEIGHT_SPEC = tff.TensorType(dtype=tf.float32, shape=None)
WEIGHT_TYPE = tff.to_type(WEIGHT_SPEC)
WEIGHT_DATA_TYPE = tff.FederatedType(WEIGHT_TYPE, tff.CLIENTS)

### Federated Client

In [None]:
@tf.function
def forward_pass(model, batch):
    """Mean softmax cross-entropy loss of a linear model on one batch.

    Args:
        model: mapping with 'weights' and 'bias' tensors/variables.
        batch: mapping with 'x' (features) and 'y' (integer class labels).

    Returns:
        Scalar tensor: the batch mean of -log p(true class), using 3 classes.
    """
    logits = tf.matmul(batch['x'], model['weights']) + model['bias']
    # log_softmax is computed in a numerically stable way; log(softmax(z)) can
    # underflow to log(0) = -inf and turn the loss into NaN via 0 * -inf.
    log_probabilities = tf.nn.log_softmax(logits)
    one_hot_labels = tf.one_hot(batch['y'], 3)
    return -tf.reduce_mean(tf.reduce_sum(one_hot_labels * log_probabilities, axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
    """TFF TensorFlow computation wrapping `forward_pass`: loss of `model` on one `batch`."""
    return forward_pass(model, batch)

In [None]:
@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
    """Apply a single SGD step on one batch and return the updated model.

    Args:
        initial_model: model parameters (MODEL_TYPE) to start from.
        batch: one batch of examples (BATCH_TYPE).
        learning_rate: float32 step size handed to the SGD optimizer.

    Returns:
        OrderedDict of the model variables after one optimizer step.
    """
    # Copy each parameter into a tf.Variable so the optimizer can update it.
    model_vars = collections.OrderedDict([
        (name, tf.Variable(name=name, initial_value=value))
        for name, value in initial_model.items()
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate)
    
    @tf.function
    def _train_on_batch(model_vars, batch):
        # Record the forward pass so gradients w.r.t. the variables are available.
        with tf.GradientTape() as tape:
            loss = forward_pass(model_vars, batch)
        grads = tape.gradient(loss, model_vars)
        # Flatten both structures so gradients and variables pair up one-to-one.
        optimizer.apply_gradients(zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
        return model_vars
  
    return _train_on_batch(model_vars, batch)

In [None]:
# A client's local dataset: a sequence of batches.
LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):
    """Train on every batch of one client's data, in sequence.

    Starting from `initial_model`, `tff.sequence_reduce` folds `batch_train`
    over `all_batches`, feeding each step's result into the next.

    Returns:
        The model obtained after the last batch.
    """
    # Bind the learning rate so the reduction operator takes (model, batch) only.
    @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
    def batch_fn(model, batch):
        return batch_train(model, batch, learning_rate)
    return tff.sequence_reduce(all_batches, initial_model, batch_fn)

In [None]:
@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
    """Sum of `batch_loss(model, batch)` over all of one client's batches.

    Note that the per-batch losses are summed (`tff.sequence_sum`), not
    averaged, so the result grows with the number of batches.
    """
    return tff.sequence_sum(tff.sequence_map(tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),all_batches))

### Federated Server

In [None]:
# The model, placed at the server.
SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
# Each client's local sequence of batches, placed at the clients.
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

In [None]:
@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
    """Broadcast the server model, run `local_eval` on every client, and average the results.

    Returns:
        The `tff.federated_mean` of the per-client `local_eval` values.
    """
    return tff.federated_mean(tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

#### Server-placed float type (`SERVER_FLOAT_TYPE`), used for the learning rate

In [None]:
# A float32 value placed at the server; used as the type of the learning-rate
# argument of the federated training computations.
SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)

#### With custom weighting factors

In [None]:
@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,CLIENT_DATA_TYPE, WEIGHT_DATA_TYPE)
def federated_train(model, learning_rate, data, weights):
    """One federated training round with custom per-client aggregation weights.

    The server model and learning rate are broadcast to the clients, each
    client runs `local_train` on its own data, and the resulting client models
    are combined with `tff.federated_mean`, passing `weights` as the
    per-client weight.

    Returns:
        The aggregated model.
    """
    return tff.federated_mean(
          tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data]), weights)

def modelR(weights, biases, data):
    """Classification accuracy of a linear model over federated batched data.

    The previous implementation ignored all three arguments and read the
    globals `model` and `federated_test_data` instead, left out the bias term
    and scored only the first client. This version evaluates exactly the
    model and data it is given.

    Args:
        weights: array-like of shape (n_features, n_classes).
        biases: array-like of shape (n_classes,).
        data: iterable of clients; each client is an iterable of batches, and
            each batch is a mapping with 'x' (features) and 'y' (labels).

    Returns:
        Fraction of examples whose highest-scoring class equals the label,
        over all batches of all clients. Returns 0.0 when `data` holds no
        examples.
    """
    correct = 0
    total = 0
    for client_batches in data:
        for batch in client_batches:
            # Softmax/sigmoid are monotonic, so the arg-max of the raw logits
            # already identifies the predicted class.
            logits = np.matmul(batch.get('x'), weights) + biases
            predictions = np.argmax(logits, axis=1)
            labels = np.asarray(batch.get('y'))
            correct += int(np.sum(predictions == labels))
            total += int(labels.size)

    if total == 0:
        return 0.0
    return correct / total

#### Without custom weighting factors

In [None]:
@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,CLIENT_DATA_TYPE)
def federated_train_2(model, learning_rate, data):
    """One federated training round without custom aggregation weights.

    The server model and learning rate are broadcast to the clients, each
    client runs `local_train` on its own data, and the client models are
    combined with `tff.federated_mean` called without a weight argument.

    Returns:
        The aggregated model.
    """
    return tff.federated_mean(
          tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data]))

In [None]:
# Both runs start from the same all-zero linear model.
initial_model = collections.OrderedDict(
    weights=np.zeros([2, 3], dtype=np.float32),
    bias=np.zeros([3], dtype=np.float32))

# Two independent training trajectories are compared on the same data:
#   modelW / learning_rateW -> aggregation with custom client weights
#   model  / learning_rate  -> aggregation without custom client weights
# Previously `learning_rateW` was read before being assigned (NameError), the
# weighted model was re-derived from the unweighted `model` every round, and
# its accuracy was computed with the unweighted model's bias.
model = initial_model
modelW = initial_model
learning_rate = 0.1
learning_rateW = 0.1

accuracy = 0
count = 0

accuracy_list_with_wights = []
accuracy_list_without_wights = []


for i in range(10):
    # Fresh train/validation/test samples for every client; this call also
    # appends this round's label counts to ARR.
    federated_train_data,federated_validation_data,federated_test_data = reselectClientData()

    # One weight per client, derived from the label counts seen so far.
    client_weights = AGG_FUNCTION1(get_mean(ARR))

    # --- run with custom weighting factors ---
    modelW = federated_train(modelW, learning_rateW, federated_train_data, client_weights)
    learning_rateW = learning_rateW * 0.9
    lossW = federated_eval(modelW, federated_train_data)
    accuracyW = modelR(modelW.get('weights'), modelW.get('bias'), federated_test_data)
    accuracy_list_with_wights.append(accuracyW)

    # --- run without custom weighting factors ---
    model = federated_train_2(model, learning_rate, federated_train_data)
    learning_rate = learning_rate * 0.9
    loss = federated_eval(model, federated_train_data)
    accuracy = modelR(model.get('weights'), model.get('bias'), federated_test_data)
    accuracy_list_without_wights.append(accuracy)

    print('round {}, loss(with weights)={}, accuracy(with weights)={} | loss(without weights)={}, accuracy(without weights)={}'.format(count, lossW, accuracyW,loss, accuracy)) 

    count+=1

### Plotting

In [None]:
import pandas as pd
def createDF(arr):
    """Build one DataFrame per client from the per-round count history.

    Args:
        arr: sequence of rounds; each round is a sequence whose first three
            items are dict-like records for client 1, 2 and 3.

    Returns:
        Tuple of three DataFrames (client 1, client 2, client 3), each with
        one row per round.
    """
    return tuple(
        pd.DataFrame([round_entry[client_index] for round_entry in arr])
        for client_index in range(3)
    )

In [None]:
# One DataFrame per client, built from ARR: one row per recorded round,
# columns come from the label keys of the recorded Counters.
c1,c2,c3 = createDF(ARR)

In [None]:
# Display client 1's per-round training label counts.
c1

In [None]:
# Display client 2's per-round training label counts.
c2

In [None]:
# Display client 3's per-round training label counts.
c3

Client 1

In [None]:
import pandas as pd


# Client 1: training label count per round, one panel per class.
# Written as a loop over the three panels so the cell no longer overwrites the
# notebook-level `count` (the training round counter) or leaks `langs`.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
fig.suptitle('Client 1')

# Make sure all three class columns exist; a class that was never sampled is
# shown as zero instead of raising KeyError.
client_counts = c1.reindex(columns=[0, 1, 2]).fillna(0)
round_labels = ['round {}'.format(r) for r in range(1, len(client_counts) + 1)]

for class_id, ax in enumerate(axes):
    ax.set_title('class {}'.format(class_id))
    ax.tick_params(labelrotation=90)
    ax.bar(round_labels, list(client_counts[class_id]))

fig.tight_layout()
plt.show()

Client 2

In [None]:
import pandas as pd


# Client 2: training label count per round, one panel per class.
# The round labels are derived from client 2's own frame (the original used
# len(c1)), and the loop no longer overwrites the notebook-level `count`
# (the training round counter) or leaks `langs`.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
fig.suptitle('Client 2')

# Make sure all three class columns exist; a class that was never sampled is
# shown as zero instead of raising KeyError.
client_counts = c2.reindex(columns=[0, 1, 2]).fillna(0)
round_labels = ['round {}'.format(r) for r in range(1, len(client_counts) + 1)]

for class_id, ax in enumerate(axes):
    ax.set_title('class {}'.format(class_id))
    ax.tick_params(labelrotation=90)
    ax.bar(round_labels, list(client_counts[class_id]))

fig.tight_layout()
plt.show()

Client 3

In [None]:
import pandas as pd


# Client 3: training label count per round, one panel per class.
# The round labels are derived from client 3's own frame (the original used
# len(c1)), and the loop no longer overwrites the notebook-level `count`
# (the training round counter) or leaks `langs`.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
fig.suptitle('Client 3')

# Make sure all three class columns exist; a class that was never sampled is
# shown as zero instead of raising KeyError.
client_counts = c3.reindex(columns=[0, 1, 2]).fillna(0)
round_labels = ['round {}'.format(r) for r in range(1, len(client_counts) + 1)]

for class_id, ax in enumerate(axes):
    ax.set_title('class {}'.format(class_id))
    ax.tick_params(labelrotation=90)
    ax.bar(round_labels, list(client_counts[class_id]))

fig.tight_layout()
plt.show()