# In [None]:  (Jupyter notebook cell marker left over from the export)
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import numpy as np
from scipy.spatial.distance import jensenshannon
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation

# Phase 1: Clustering

# Load Fashion-MNIST data (labels and the test split are discarded here)
(x_train, _), (_, _) = fashion_mnist.load_data()
# Convert to float32 and divide by 255 so the sigmoid/binary-crossentropy
# reconstruction target is a fraction rather than a raw byte value.
x_train = x_train.astype('float32') / 255.
# Flatten every image to a single vector: the autoencoder declares
# Input(shape=(784,)), so feeding the un-flattened image array would fail
# with a shape mismatch in fit()/predict().
x_train = x_train.reshape((len(x_train), -1))

# Define VAE architecture
latent_dim = 10  # width of the bottleneck layer read by vae_model()

def vae_model():
    """Build the dense autoencoder and an encoder that shares its layers.

    The encoder maps a 784-wide input through 256 -> 128 -> ``latent_dim``
    ReLU layers; the decoder mirrors it through 128 -> 256 ReLU layers and a
    784-wide sigmoid output.

    NOTE: despite the name, the bottleneck output is fed straight into the
    decoder -- there is no sampling layer and no KL term in this function, so
    the network built here is a plain (deterministic) autoencoder.

    Returns:
        tuple: ``(autoencoder, encoder)`` -- the full input->reconstruction
        model and the input->bottleneck model, both over the same layers.
    """
    pixels = Input(shape=(784,))

    # Encoder stack, narrowing down to the bottleneck.
    hidden = pixels
    for width in (256, 128, latent_dim):
        hidden = Dense(width, activation='relu')(hidden)
    bottleneck = hidden

    # Decoder stack, widening back to the input size.
    reconstruction = bottleneck
    for width in (128, 256):
        reconstruction = Dense(width, activation='relu')(reconstruction)
    reconstruction = Dense(784, activation='sigmoid')(reconstruction)

    autoencoder = Model(pixels, reconstruction)
    latent_encoder = Model(pixels, bottleneck)
    return autoencoder, latent_encoder

# Build the autoencoder and train it to reconstruct its own input
vae, encoder = vae_model()
vae.compile(loss='binary_crossentropy', optimizer='adam')
vae.fit(x=x_train, y=x_train, batch_size=128, epochs=20)

# Get latent representations of all clients
# (one row per training sample; downstream code treats each row as a client)
latent_representations = encoder.predict(x=x_train)

# Calculate Jensen-Shannon Divergence between latent embeddings
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    ``scipy.spatial.distance.jensenshannon`` already forms the mixture
    ``m = (p + q) / 2`` internally and returns the Jensen-Shannon *distance*,
    i.e. the square root of the divergence (natural logarithm by default).
    Squaring that distance therefore gives the divergence
    ``0.5 * KL(p || m) + 0.5 * KL(q || m)``.  Averaging two distances to a
    hand-built mixture, as an earlier version did, does not compute that
    quantity.

    Args:
        p: 1-D array-like of non-negative values.  SciPy rescales it to sum
            to 1 before comparing.
        q: 1-D array-like of the same length as ``p``.

    Returns:
        The divergence, between 0 (identical distributions) and ``ln 2``
        (disjoint support).  An all-zero input cannot be normalised by SciPy
        and results in ``nan``.
    """
    distance = jensenshannon(p, q)
    # distance == sqrt(divergence), so square it to recover the divergence.
    return distance ** 2

# Perform clustering based on JSD
# assign_clusters() puts client j into the cluster seeded by client i when
# divergence_matrix[i, j] is strictly below this value.
threshold = 0.1  # Adjust threshold as needed

# Compute Jensen-Shannon Divergence matrix
def js_divergence_matrix(latent_representations):
    """Return the symmetric matrix of pairwise ``js_divergence`` values.

    Args:
        latent_representations: sequence of latent vectors, one per client.

    Returns:
        A square NumPy array with one row/column per client.  Entry
        ``[i, j]`` holds ``js_divergence`` of clients ``i`` and ``j``; the
        diagonal is left at zero.

    NOTE(review): both the work and the output grow with the square of the
    number of clients, so this is only practical for modest client counts.
    """
    size = len(latent_representations)
    matrix = np.zeros((size, size))
    for row, anchor in enumerate(latent_representations):
        # Only the strict upper triangle is evaluated; each value is mirrored
        # into the lower triangle.
        for col in range(row + 1, size):
            value = js_divergence(anchor, latent_representations[col])
            matrix[row, col] = value
            matrix[col, row] = value
    return matrix

# Define a function to assign clients to clusters based on divergence matrix and threshold
def assign_clusters(divergence_matrix, threshold):
    """Greedily partition clients into clusters of mutually close latents.

    Clients are scanned in index order.  The lowest-indexed client that has
    not been assigned yet seeds a new cluster, and every later unassigned
    client whose divergence to that seed is strictly below ``threshold``
    joins it.  Membership is decided against the seed only (it is not
    transitive).

    The result is a true partition: every client index appears in exactly
    one returned list.  (Previously each non-seed client also received its
    own one-element list containing just its seed, so the seed was counted
    in several "clusters" and the member never appeared in its own entry.)

    Args:
        divergence_matrix: square 2-D array indexable as ``[i, j]``.
        threshold: upper bound (exclusive) on the divergence to the seed.

    Returns:
        list[list[int]]: one list of client indices per cluster, ordered by
        seed index, each list starting with its seed.  Empty input gives
        ``[]``.
    """
    num_clients = len(divergence_matrix)
    clusters = []
    assigned = set()
    for seed in range(num_clients):
        if seed in assigned:
            continue
        assigned.add(seed)
        members = [seed]
        for candidate in range(seed + 1, num_clients):
            # A nan divergence compares False here, so such a client is
            # never merged and ends up seeding its own cluster.
            if candidate not in assigned and divergence_matrix[seed, candidate] < threshold:
                assigned.add(candidate)
                members.append(candidate)
        clusters.append(members)
    return clusters

# Build the pairwise divergence matrix over every latent vector
divergence_matrix = js_divergence_matrix(latent_representations)
# Group clients whose divergence falls under the threshold
clusters = assign_clusters(divergence_matrix, threshold)

# Determine cluster heads: for each cluster, keep the latent vector of the
# member with the largest Euclidean norm.
cluster_heads = {}
for cluster_id, member_ids in enumerate(clusters):
    member_latents = latent_representations[member_ids]
    member_norms = np.linalg.norm(member_latents, axis=1)
    cluster_heads[cluster_id] = member_latents[np.argmax(member_norms)]

# Phase 2: Global Training

# Define SimpleMLP class for local training
class SimpleMLP:
    """Factory for the fully connected classifier used by every client."""

    @staticmethod
    def build(shape, classes):
        """Return an uncompiled three-hidden-layer MLP.

        Args:
            shape: number of input features per sample.
            classes: number of output units (softmax).

        Returns:
            A ``Sequential`` model: three Dense(200)+ReLU stages followed by
            Dense(classes)+softmax, with activations as separate layers.
        """
        stack = [Dense(200, input_shape=(shape,)), Activation("relu")]
        # Two further hidden stages of the same width.
        for _ in range(2):
            stack.append(Dense(200))
            stack.append(Activation("relu"))
        stack.append(Dense(classes))
        stack.append(Activation("softmax"))
        return Sequential(stack)

# Perform local training using SimpleMLP
def local_training(x_train, y_train, num_classes):
    """Train a freshly initialised SimpleMLP on one client's data.

    Args:
        x_train: 2-D feature array; its second dimension sets the input width.
        y_train: targets for ``categorical_crossentropy``
            (presumably one-hot rows of width ``num_classes`` -- confirm
            against the caller).
        num_classes: number of softmax outputs.

    Returns:
        The compiled model after one silent epoch with batch size 32.
    """
    feature_count = x_train.shape[1]
    classifier = SimpleMLP.build(feature_count, num_classes)
    classifier.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    # Adjust epochs and batch_size as needed
    classifier.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)
    return classifier

# Define aggregation weightage calculation function
def calculate_aggregation_weight(alpha, beta, gamma, latent_values, num_classes, num_samples, total_samples, total_classes, total_latent):
    """Combine latent, class and sample shares into one aggregation weight.

    The result is::

        (1 - alpha - beta - gamma) * ||latent_values|| / ||total_latent||
        + (1 - alpha - beta) * num_classes / total_classes
        + (1 - alpha) * num_samples / total_samples
        + alpha

    Args:
        alpha, beta, gamma: scalar coefficients of the formula above.
        latent_values: latent vector of the participant being weighted.
        num_classes: class count of the participant.
        num_samples: sample count of the participant.
        total_samples: denominator for the sample share.
        total_classes: denominator for the class share.
        total_latent: vector whose norm is the denominator of the latent share.

    Returns:
        The scalar weight.  It is not normalised; callers that need weights
        summing to one must rescale.
    """
    latent_share = np.linalg.norm(latent_values) / np.linalg.norm(total_latent)
    class_share = num_classes / total_classes
    sample_share = num_samples / total_samples
    return (
        (1 - alpha - beta - gamma) * latent_share
        + (1 - alpha - beta) * class_share
        + (1 - alpha) * sample_share
        + alpha
    )

# Perform cluster aggregation with aggregation weightage
def cluster_aggregation(cluster, trained_models, alpha, beta, gamma):
    """Train one local model per client and merge them by weighted averaging.

    Each client is trained with ``local_training``; its aggregation weight
    comes from ``calculate_aggregation_weight`` using the client's own mean
    latent vector, its sample count and the model's class count.  The weights
    are rescaled to sum to one and the clients' parameter tensors are averaged
    with them into a fresh ``SimpleMLP``.

    Fixes over the earlier version: the averaging step was a no-op (an
    untrained model was returned); ``total_samples`` was the true total
    multiplied by the number of clients; ``total_latent`` collapsed the mean
    latent vector to a scalar; and an empty ``trained_models`` raised
    ``IndexError``.

    Args:
        cluster: iterable of ``(x_train, y_train)`` pairs, one per client.
        trained_models: optional list of models; only the first one is read,
            and only for its input/output widths.  When empty, the widths are
            taken from the first client's arrays (assumes ``y_train`` is
            one-hot with shape ``(n, classes)`` -- confirm for new callers).
        alpha, beta, gamma: coefficients forwarded to
            ``calculate_aggregation_weight``.

    Returns:
        A compiled ``SimpleMLP`` holding the weighted average of the locally
        trained models' parameters.

    Raises:
        ValueError: if ``cluster`` is empty or the weights sum to zero.
    """
    cluster = list(cluster)
    if not cluster:
        raise ValueError("cluster must contain at least one (x_train, y_train) pair")

    if trained_models:
        input_dim = trained_models[0].input_shape[1]
        num_classes = trained_models[0].output_shape[1]
    else:
        first_x, first_y = cluster[0]
        input_dim = first_x.shape[1]
        num_classes = first_y.shape[1]

    # Per-client statistics and the totals they are measured against.
    client_latents = [np.mean(encoder.predict(client_x), axis=0) for client_x, _ in cluster]
    total_samples = sum(client_x.shape[0] for client_x, _ in cluster)
    # Every client contributes the same class count, so the class share of
    # each client is 1 / len(cluster).
    total_classes = num_classes * len(cluster)
    total_latent = np.sum(client_latents, axis=0)

    cluster_models_weighted = []
    for (client_x, client_y), client_latent in zip(cluster, client_latents):
        model = local_training(client_x, client_y, num_classes)
        weight = calculate_aggregation_weight(
            alpha, beta, gamma, client_latent, num_classes,
            client_x.shape[0], total_samples, total_classes, total_latent,
        )
        cluster_models_weighted.append((model, weight))

    weight_sum = sum(weight for _, weight in cluster_models_weighted)
    if weight_sum == 0:
        raise ValueError("aggregation weights sum to zero; cannot average models")

    # Perform weighted aggregation: parameter-wise convex combination.
    aggregated_model = SimpleMLP.build(input_dim, num_classes)
    averaged = [np.zeros_like(tensor) for tensor in aggregated_model.get_weights()]
    for model, weight in cluster_models_weighted:
        share = float(weight / weight_sum)
        for accumulator, tensor in zip(averaged, model.get_weights()):
            accumulator += share * tensor
    aggregated_model.set_weights(averaged)

    # Compile the aggregated model
    aggregated_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return aggregated_model

# Perform global aggregation with aggregation weightage
def global_aggregation(cluster_models, alpha, beta, gamma):
    """Merge per-cluster models into one global model by weighted averaging.

    ``cluster_models`` is a list of Keras models, which is how the function
    is called in this file.  The earlier body also tried to iterate each
    element as a list of ``(x_train, y_train)`` pairs, multiplied a model by
    a float and then unpacked the product as ``(model, weight)``, so it could
    not run for any input; and its averaging step was a no-op.

    No per-cluster sample or latent statistics accompany a bare model, so --
    as in the earlier weight call, which passed the totals for both the
    per-cluster and the total arguments -- every share term is 1.  All
    clusters therefore receive the same weight and the result is the plain
    mean of their parameters.

    Args:
        cluster_models: non-empty list of models sharing one architecture
            (the one produced by ``SimpleMLP.build``).
        alpha, beta, gamma: coefficients forwarded to
            ``calculate_aggregation_weight``.

    Returns:
        A compiled ``SimpleMLP`` holding the averaged parameters.

    Raises:
        ValueError: if ``cluster_models`` is empty or the weights sum to zero.
    """
    cluster_models = list(cluster_models)
    if not cluster_models:
        raise ValueError("cluster_models must contain at least one model")

    # Unit placeholders make each share term in the weight formula equal 1.
    unit_latent = np.ones(1)
    global_models_weighted = []
    for model in cluster_models:
        weight = calculate_aggregation_weight(alpha, beta, gamma, unit_latent, 1, 1, 1, 1, unit_latent)
        global_models_weighted.append((model, weight))

    weight_sum = sum(weight for _, weight in global_models_weighted)
    if weight_sum == 0:
        raise ValueError("aggregation weights sum to zero; cannot average models")

    # Perform weighted aggregation: parameter-wise convex combination.
    reference = cluster_models[0]
    global_model = SimpleMLP.build(reference.input_shape[1], reference.output_shape[1])
    averaged = [np.zeros_like(tensor) for tensor in global_model.get_weights()]
    for model, weight in global_models_weighted:
        share = float(weight / weight_sum)
        for accumulator, tensor in zip(averaged, model.get_weights()):
            accumulator += share * tensor
    global_model.set_weights(averaged)

    # Compile the global model
    global_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return global_model


# Cluster aggregation
alpha = 0.333
beta = 0.333
gamma = 0.333

# The labels were discarded when the images were first loaded, so fetch them
# again and one-hot encode them for categorical_crossentropy.
(_, train_labels), (_, _) = fashion_mnist.load_data()
num_classes = int(train_labels.max()) + 1
y_train = tf.keras.utils.to_categorical(train_labels, num_classes)
# The classifiers take flat feature vectors; this is a no-op if x_train is
# already two-dimensional.
x_flat = x_train.reshape((len(x_train), -1))

# cluster_aggregation expects (features, labels) pairs, not index lists:
# each cluster of client indices becomes one participant's dataset.
cluster_datasets = [(x_flat[members], y_train[members]) for members in clusters]
# A reference model is passed so the aggregator can read the layer widths.
reference_model = SimpleMLP.build(x_flat.shape[1], num_classes)
aggregated_cluster_model = cluster_aggregation(cluster_datasets, [reference_model], alpha, beta, gamma)

# Global aggregation
aggregated_global_model = global_aggregation([aggregated_cluster_model], alpha, beta, gamma)

# Evaluate the aggregated global model
loss, accuracy = aggregated_global_model.evaluate(x_flat, y_train)
print(f"Loss: {loss}, Accuracy: {accuracy}")