In [43]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from tensorboardX import SummaryWriter
import copy
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, silhouette_score
from sklearn.decomposition import PCA
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from scipy.cluster.hierarchy import linkage, fcluster

In [44]:
class CNNMnist(nn.Module):
    def __init__(self):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

In [45]:
class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        #self.idxs = list(idxs)
        self.idxs = np.array(idxs).astype(int)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        #image, label = self.dataset[self.idxs[item]]
        item = int(item)
        idx = int(self.idxs[item])
        image,label = self.dataset[idx]
        return image, label

In [46]:
# Cell 4: Dataset Loading
data_dir = 'D:/mnist_data'
os.makedirs(data_dir, exist_ok=True)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=True,
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root=data_dir,
    train=False,
    transform=transform,
    download=True
)

In [47]:
class LocalUpdate(object):
    def __init__(self, dataset, idxs, local_bs=32, local_ep=3, lr=0.01, momentum=0.5):
        self.trainloader = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=local_bs,
            shuffle=True
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.criterion = nn.CrossEntropyLoss()
        self.local_ep = local_ep
        self.lr = lr
        self.momentum = momentum

    def update_weights(self, model):
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=self.lr,
            momentum=self.momentum
        )

        epoch_loss = []
        for _ in range(self.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        return model.state_dict(), sum(epoch_loss)/len(epoch_loss)

In [48]:
def mnist_iid(dataset, num_users):
    """
    Sample IID client data from MNIST dataset
    :param dataset: MNIST dataset
    :param num_users: Number of users
    :return: Dictionary of user groups
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def mnist_noniid(dataset, num_users=10):
    """
    Sample non-I.I.D client data by assigning specific digits to each client
    """
    labels = dataset.targets.numpy()
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    
    # Digit distribution for 5 clients (can be adjusted)
    client_digits = {
        0: [0, 1, 2],      # Client 0 gets digits 0 and 1
        1: [2, 3],      # Client 1 gets digits 2 and 3
        2: [4, 5],      # Client 2 gets digits 4 and 5
        3: [5, 6, 7],      # Client 3 gets digits 6 and 7
        4: [8, 9, 1],
        5: [1, 2, 3],
        6 : [4, 5, 6],
        7 : [7,  9],
        8 : [0, 1],
        9 : [3, 5]
                    # Client 4 gets digits 8 and 9
    }
    
    # Assign data to each client based on digits
    for client_id, digits in client_digits.items():
        for digit in digits:
            digit_indices = np.where(labels == digit)[0]
            dict_users[client_id] = np.concatenate((dict_users[client_id], digit_indices))
            
    return dict_users


In [49]:
def average_weights(w):
    """Average the weights of models"""
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg

In [50]:
def test_model(model, test_dataset, device):
    """
    Test the model and return accuracy, precision, recall, and F1 score
    Returns macro-averaged metrics
    """
    model.eval()
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    
    correct = 0
    total = 0
    
    # Initialize confusion matrix (10 classes for MNIST)
    confusion_matrix = torch.zeros(10, 10)
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            
            # Get predictions
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # Update confusion matrix
            for t, p in zip(target.view(-1), predicted.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
    
    # Calculate metrics for each class and average
    precisions = []
    recalls = []
    f1_scores = []
    
    for i in range(10):
        # True Positives: diagonal elements
        tp = confusion_matrix[i, i]
        # False Positives: sum of column i minus diagonal element
        fp = confusion_matrix[:, i].sum() - tp
        # False Negatives: sum of row i minus diagonal element
        fn = confusion_matrix[i, :].sum() - tp
        
        # Calculate precision
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        precisions.append(precision)
        
        # Calculate recall
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        recalls.append(recall)
        
        # Calculate F1 score
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        f1_scores.append(f1)
    
    # Calculate macro averages
    macro_precision = sum(precisions) / len(precisions)
    macro_recall = sum(recalls) / len(recalls)
    macro_f1 = sum(f1_scores) / len(f1_scores)
    
    # Calculate accuracy
    accuracy = 100 * correct / total
    
    # Print metrics
    print("\nTest Metrics:")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Precision: {macro_precision:.4f}")
    print(f"Recall: {macro_recall:.4f}")
    print(f"F1 Score: {macro_f1:.4f}")
    
    return accuracy, macro_precision, macro_recall, macro_f1

In [51]:
def calculate_communication_costs(self, cluster_participation_history, num_rounds=10):
    """
    Calculate communication costs with:
    - 10 clients in traditional FL
    - 3 clusters in hierarchical FL with selective participation
    - No compression
    
    Args:
        cluster_participation_history: List of lists containing participating cluster IDs for each round
        num_rounds: Total number of training rounds
    """
    # Model parameters for each layer
    model_params = {
        'conv1': (5*5*1*10) + 10,          # 260 parameters
        'conv2': (5*5*10*20) + 20,         # 5,020 parameters
        'fc1': (320*50) + 50,              # 16,050 parameters
        'fc2': (50*10) + 10                # 510 parameters
    }
    
    total_params = sum(model_params.values())  # 21,840 parameters
    bytes_per_param = 4  # Float32
    num_clients = 10  # Total number of clients
    num_clusters = 3  # Number of clusters
    
    # Calculate base model sizes
    model_size_bytes = total_params * bytes_per_param
    model_size_mb = model_size_bytes / (1024 * 1024)
    
    # 1. Traditional FL Communication (10 clients)
    traditional_costs = {
        'initial_distribution': num_clients * total_params,  # Initial model to 10 clients
        'per_round': {
            'client_to_server': num_clients * total_params,  # Updates from 10 clients
            'server_to_client': num_clients * total_params   # New model to 10 clients
        }
    }
    
    traditional_params_per_round = (
        traditional_costs['per_round']['client_to_server'] + 
        traditional_costs['per_round']['server_to_client']
    )
    
    traditional_total_params = (
        traditional_costs['initial_distribution'] +  # Initial distribution
        (traditional_params_per_round * num_rounds)  # Training rounds both ways
    )
    
    traditional_total_mb = (traditional_total_params * bytes_per_param) / (1024 * 1024)
    
    # 2. Hierarchical FL Communication with Selective Participation
    
    # A. Initial Setup Phase (Clustering)
    clustering_metadata = {
        'data_distribution': 10,    # Distribution over 10 classes
        'data_size': 1             # Number of samples
    }
    total_clustering_metadata = sum(clustering_metadata.values())
    
    initial_setup_costs = {
        'client_metadata': num_clients * total_clustering_metadata,  # All clients send metadata
        'cluster_assignments': num_clients,                         # Server assigns clusters
        'head_announcements': num_clusters,                         # Server announces heads
        'initial_model': total_params * num_clusters               # Model to cluster heads
    }
    
    total_initial_setup_params = sum(initial_setup_costs.values())
    
    # B. Calculate actual communication per round based on participation history
    total_round_params = 0
    participation_stats = {
        'rounds': [],
        'avg_participation': 0
    }
    
    for round_idx, participating_clusters in enumerate(cluster_participation_history):
        num_participating = len(participating_clusters)
        round_params = {
            'head_to_server': total_params * num_participating,  # Updates from participating heads
            'server_to_head': total_params * num_participating   # New model to participating heads
        }
        round_total = sum(round_params.values())
        total_round_params += round_total
        
        participation_stats['rounds'].append({
            'round': round_idx + 1,
            'num_clusters': num_participating,
            'params_transferred': round_total
        })
    
    participation_stats['avg_participation'] = (
        sum(len(r) for r in cluster_participation_history) / len(cluster_participation_history)
    )
    
    # Total Hierarchical Communication
    hierarchical_total_params = (
        total_initial_setup_params +  # Setup phase
        total_round_params           # Actual training rounds
    )
    
    hierarchical_total_mb = (hierarchical_total_params * bytes_per_param) / (1024 * 1024)
    
    # Calculate reductions
    param_reduction = ((traditional_total_params - hierarchical_total_params) / traditional_total_params) * 100
    size_reduction = ((traditional_total_mb - hierarchical_total_mb) / traditional_total_mb) * 100
    
    # Prepare detailed results
    results = {
        'model_architecture': {
            'total_parameters': total_params,
            'size_mb': model_size_mb,
            'parameter_breakdown': model_params
        },
        'traditional_fl': {
            'num_clients': num_clients,
            'total_params': traditional_total_params,
            'total_mb': traditional_total_mb
        },
        'hierarchical_fl': {
            'num_clusters': num_clusters,
            'initial_setup_mb': (total_initial_setup_params * bytes_per_param) / (1024 * 1024),
            'total_params': hierarchical_total_params,
            'total_mb': hierarchical_total_mb,
            'participation_stats': participation_stats
        },
        'reduction': {
            'parameters_percent': param_reduction,
            'size_percent': size_reduction
        }
    }
    
    # Print detailed analysis
    print("\nDetailed Communication Cost Analysis:")
    print(f"\nModel Base Size: {model_size_mb:.2f} MB")
    
    print("\nTraditional FL (10 clients):")
    print(f"Total Communication: {traditional_total_mb:.2f} MB")
    
    print("\nHierarchical FL with Selective Participation:")
    print(f"Initial Setup: {results['hierarchical_fl']['initial_setup_mb']:.4f} MB")
    print(f"Average Participating Clusters per Round: {participation_stats['avg_participation']:.2f}")
    print(f"Total Communication: {hierarchical_total_mb:.2f} MB")
    
    print("\nReduction Achieved:")
    print(f"Parameter reduction: {param_reduction:.2f}%")
    print(f"Size reduction: {size_reduction:.2f}%")
    
    return results

In [52]:
class Client:
    def __init__(self, client_id, dataset, idxs, args):
        self.client_id = client_id
        self.dataset = dataset
        self.idxs = idxs
        self.args = args
        self.cluster_id = None
        self.is_cluster_head = False
        self.performance_history = []  # Track training performance
        self.member_performance = {}  # Track member performance
        self.last_contribution_round = 0  # Track last round of contribution
        self.cluster_performance = []  # Track cluster's overall performance
        self.participation_score = 1.0  # Dynamic participation score
    
    def should_participate(self, current_round, threshold=0.5):
        """Decide whether cluster should participate in current round"""
        # Calculate participation score based on multiple factors
        
        # 1. Performance trend (40%)
        performance_weight = 0.6
        if len(self.cluster_performance) >= 2:
            performance_trend = (self.cluster_performance[-1] - self.cluster_performance[-2])
            # More stringent performance scoring
            if performance_trend > 2:  # Significant improvement
                performance_score = 1.0
            elif performance_trend > 0:  # Modest improvement
                performance_score = 0.7
            else:  # No improvement or decline
                performance_score = 0.3
        else:
            performance_score = 0.5  # More conservative initial score
            
        # 2. Rounds since last contribution (30%)
        staleness_weight = 0.4
        rounds_since_contribution = current_round - self.last_contribution_round
        # Force participation after 4 rounds of non-participation
        if rounds_since_contribution >= 4:
            return True
        staleness_score = min(0.8, rounds_since_contribution / 4.0)
    
        
        # Calculate final participation score
        self.participation_score = (
            performance_weight * performance_score +
            staleness_weight * staleness_score 
            
        )
        
        # Debug print using cluster_id instead of client_id
        print(f"\nParticipation Score Components for Cluster {self.cluster_id}:")  # Changed from client_id to cluster_id
        print(f"Performance Score: {performance_score:.3f}")
        print(f"Staleness Score: {staleness_score:.3f}")
        print(f"Final Score: {self.participation_score:.3f}")
        print(f"Threshold: {threshold}")
        
        # Return True if score exceeds threshold
        return self.participation_score > threshold
    
  

    def update_cluster_performance(self, performance_metric):
        """Track cluster's performance"""
        self.cluster_performance.append(performance_metric)
        if len(self.cluster_performance) > 5:  # Keep last 5 rounds
            self.cluster_performance.pop(0)
        
    def compute_data_distribution(self):
        """Compute distribution of data labels"""
        labels = [self.dataset.targets[i].item() for i in self.idxs]
        distribution = np.zeros(10)  # 10 classes for MNIST
        for label in labels:
            distribution[label] += 1
        return distribution / len(labels)  # Normalize

    def train_local(self, global_model):
        """Local training"""
        local_model = LocalUpdate(
            dataset=self.dataset,
            idxs=self.idxs,
            local_bs=self.args.local_bs,
            local_ep=self.args.local_ep,
            lr=self.args.lr,
            momentum=self.args.momentum
        )
        weights, loss = local_model.update_weights(global_model)
        self.performance_history.append(loss)
        return weights, loss

class ClusterHead(Client):
    def __init__(self, client_id, dataset, idxs, args):
        super().__init__(client_id, dataset, idxs, args)
        self.is_cluster_head = True
        self.cluster_members = []
        self.member_performance = {}  # Track member performance
        
    def aggregate_cluster_updates(self, updates, losses):
        """Aggregate updates from cluster members with performance tracking"""
        # Update member performance
        for client_id, loss in losses.items():
            if client_id not in self.member_performance:
                self.member_performance[client_id] = []
            self.member_performance[client_id].append(loss)
            
        return average_weights(updates)

class HierarchicalServer:
    def __init__(self, args, n_clusters=3):
        self.args = args
        self.n_clusters = n_clusters
        self.clients = {}
        self.clusters = {}
        self.cluster_heads = {}
        self.max_client_data = 0  # For normalization
        self.client_history = {}  # Track client performance
        
    def initialize_clients(self, train_dataset, user_groups):
        """Initialize all clients and find max data size"""
        for client_id in user_groups.keys():
            self.clients[client_id] = Client(
                client_id,
                train_dataset,
                user_groups[client_id],
                self.args
            )
            self.max_client_data = max(
                self.max_client_data, 
                len(user_groups[client_id])
            )
    
    def calculate_distribution_score(self, client):
        """Calculate how well-distributed a client's data is"""
        distribution = client.compute_data_distribution()
        # Entropy as a measure of distribution
        entropy = -np.sum(distribution * np.log(distribution + 1e-10))
        # Normalize entropy
        max_entropy = -np.log(1/10)  # Maximum entropy for 10 classes
        return entropy / max_entropy
    
    def get_client_resources(self, client_id):
        """Placeholder for resource scoring"""
        # Could be extended to include actual resource metrics
        return 1.0
        
    def select_cluster_head(self, cluster_members):
        """Select cluster head based on multiple criteria"""
        head_scores = {}
        
        for client_id in cluster_members:
            client = self.clients[client_id]
            
            # Data quantity score (normalized)
            data_score = len(client.idxs) / self.max_client_data
            
            # Distribution score
            dist_score = self.calculate_distribution_score(client)
            
            
            # Combined weighted score
            head_scores[client_id] = (
                0.6 * data_score +
                0.4 * dist_score 
            )
            
        return max(head_scores.items(), key=lambda x: x[1])[0]

    def perform_clustering(self):
        """Perform hierarchical clustering with improved head selection"""
        distributions = []
        client_ids = []
        
        # Collect client distributions
        for client_id, client in self.clients.items():
            distributions.append(client.compute_data_distribution())
            client_ids.append(client_id)
                
        distributions = np.array(distributions)
        
        # Perform hierarchical clustering
        linkage_matrix = linkage(distributions, method='ward')
        cluster_labels = fcluster(linkage_matrix, 
                                t=self.n_clusters, 
                                criterion='maxclust')
        
        # Group clients by cluster
        cluster_members = {}
        for i, label in enumerate(cluster_labels):
            cluster_id = int(label - 1)  # Convert to 0-based indexing
            if cluster_id not in cluster_members:
                cluster_members[cluster_id] = []
            cluster_members[cluster_id].append(client_ids[i])
        
        # Select heads and assign clusters
        for cluster_id, members in cluster_members.items():
            # Select best head for this cluster
            head_id = self.select_cluster_head(members)
            
            self.clusters[cluster_id] = members
            self.cluster_heads[cluster_id] = ClusterHead(
                head_id,
                self.clients[head_id].dataset,
                self.clients[head_id].idxs,
                self.args
            )
            self.cluster_heads[cluster_id].cluster_id = cluster_id  # Set the cluster_id
            
            print(f"\nCluster {cluster_id}:")
            print(f"  Head: Client {head_id}")
            print(f"  Head Score Components:")
            client = self.clients[head_id]
            print(f"    - Data Score: {len(client.idxs)/self.max_client_data:.3f}")
            print(f"    - Distribution Score: {self.calculate_distribution_score(client):.3f}")
            print(f"  Members: {members}")
            
            # Assign cluster IDs to clients
            for member_id in members:
                self.clients[member_id].cluster_id = cluster_id

    def calculate_communication_costs(self, cluster_participation_history, num_rounds=10):
        """
        Calculate communication costs with:
        - 10 clients in traditional FL
        - 3 clusters in hierarchical FL with selective participation
        - No compression
        """
        # Model parameters for each layer
        model_params = {
            'conv1': (5*5*1*10) + 10,          # 260 parameters
            'conv2': (5*5*10*20) + 20,         # 5,020 parameters
            'fc1': (320*50) + 50,              # 16,050 parameters
            'fc2': (50*10) + 10                # 510 parameters
        }
        
        total_params = sum(model_params.values())  # 21,840 parameters
        bytes_per_param = 4  # Float32
        num_clients = 10  # Total number of clients
        num_clusters = 3  # Number of clusters
        
        # Calculate base model sizes
        model_size_bytes = total_params * bytes_per_param
        model_size_mb = model_size_bytes / (1024 * 1024)
        
        # 1. Traditional FL Communication (10 clients)
        traditional_costs = {
            'per_round': {
                'client_to_server': num_clients * total_params,  # Updates from 10 clients
                'server_to_client': num_clients * total_params   # New model to 10 clients
            }
        }
        
        traditional_params_per_round = (
            traditional_costs['per_round']['client_to_server'] + 
            traditional_costs['per_round']['server_to_client']
        )
        
        traditional_total_params = traditional_params_per_round * num_rounds
        traditional_total_mb = (traditional_total_params * bytes_per_param) / (1024 * 1024)
        
        # 2. Hierarchical FL Communication with Selective Participation
        
        # A. Initial Setup Phase (Clustering)
        clustering_metadata = {
            'data_distribution': 10,    # Distribution over 10 classes
            'data_size': 1             # Number of samples
        }
        total_clustering_metadata = sum(clustering_metadata.values())
        
        initial_setup_costs = {
            'client_metadata': num_clients * total_clustering_metadata,  # All clients send metadata
            'cluster_assignments': num_clients,                         # Server assigns clusters
            'head_announcements': num_clusters                         # Server announces heads
        }
        
        total_initial_setup_params = sum(initial_setup_costs.values())
        
        # B. Calculate actual communication based on participation history
        total_round_params = 0
        total_participating_clusters = 0
        
        for participating_clusters in cluster_participation_history:
            num_participating = len(participating_clusters)
            total_participating_clusters += num_participating
            round_total = total_params * num_participating * 2  # *2 for bidirectional communication
            total_round_params += round_total
        
        avg_participation = total_participating_clusters / len(cluster_participation_history)
        
        # Total Hierarchical Communication
        hierarchical_total_params = total_initial_setup_params + total_round_params
        hierarchical_total_mb = (hierarchical_total_params * bytes_per_param) / (1024 * 1024)
        
        # Calculate reductions
        param_reduction = ((traditional_total_params - hierarchical_total_params) / traditional_total_params) * 100
        size_reduction = ((traditional_total_mb - hierarchical_total_mb) / traditional_total_mb) * 100
        
        # Print detailed analysis
        print("\nDetailed Communication Cost Analysis:")
        print("=" * 50)
        
        print("\nModel Architecture:")
        print(f"Base Model Size: {model_size_mb:.2f} MB ({total_params:,} parameters)")
        print("\nParameter Breakdown:")
        for layer, params in model_params.items():
            print(f"- {layer}: {params:,} parameters")
        
        print("\nTraditional FL (10 clients):")
        print(f"Per Round: {traditional_params_per_round:,} parameters ({traditional_params_per_round * bytes_per_param / (1024 * 1024):.2f} MB)")
        print(f"Total Communication: {traditional_total_params:,} parameters ({traditional_total_mb:.2f} MB)")
        
        print("\nHierarchical FL with Selective Participation:")
        print(f"Initial Clustering Setup: {total_initial_setup_params:,} parameters ({total_initial_setup_params * bytes_per_param / (1024 * 1024):.2f} MB)")
        print(f"Average Participating Clusters per Round: {avg_participation:.2f}")
        print(f"Total Round Communication: {total_round_params:,} parameters ({total_round_params * bytes_per_param / (1024 * 1024):.2f} MB)")
        print(f"Total Communication: {hierarchical_total_params:,} parameters ({hierarchical_total_mb:.2f} MB)")
        
        print("\nReduction Achieved:")
        print(f"Parameter reduction: {param_reduction:.2f}%")
        print(f"Size reduction: {size_reduction:.2f}%")
        
        return {
            'model_architecture': {
                'total_parameters': total_params,
                'size_mb': model_size_mb,
                'parameter_breakdown': model_params
            },
            'traditional_fl': {
                'num_clients': num_clients,
                'total_params': traditional_total_params,
                'total_mb': traditional_total_mb,
                'per_round_params': traditional_params_per_round
            },
            'hierarchical_fl': {
                'num_clusters': num_clusters,
                'initial_setup_params': total_initial_setup_params,
                'total_params': hierarchical_total_params,
                'total_mb': hierarchical_total_mb,
                'avg_participation': avg_participation
            },
            'reduction': {
                'parameters_percent': param_reduction,
                'size_percent': size_reduction
            }
        }

    def train_federated(self, train_dataset, test_dataset, user_groups):
        """Modified training loop with selective cluster participation"""
        try:
            # Initialize tracking for cluster participation
            cluster_participation_history = []
            global_model = CNNMnist().to(self.args.device)
            global_model.train()
            
            self.initialize_clients(train_dataset, user_groups)
            self.perform_clustering()
            
            # Track performance history for all clusters
            cluster_performance_history = {cluster_id: [] for cluster_id in self.cluster_heads.keys()}
            
            # Use tqdm for epoch progress
            for epoch in tqdm(range(self.args.epochs), desc="Training Progress"):
                print(f'\nEpoch {epoch+1}/{self.args.epochs}')
                
                cluster_updates = {}
                participating_clusters = []
                
                # Determine participating clusters
                for cluster_id, head in self.cluster_heads.items():
                    should_participate = head.should_participate(epoch)
                    print(f"\nCluster {cluster_id} evaluation:")
                    print(f"Score: {head.participation_score:.2f}")
                    print(f"Threshold: {0.5}")
                    print(f"Decision: {'participating' if should_participate else 'skipping'}")
                    
                    if should_participate:
                        participating_clusters.append(cluster_id)
                
                # Store participation history for this round
                cluster_participation_history.append(participating_clusters.copy())
                
                # Modified minimum participation requirement
                if len(participating_clusters) < max(1, len(self.cluster_heads) // 3):
                    if epoch == 0:
                        # For first epoch, select based on data distribution and quantity
                        cluster_scores = {}
                        for cluster_id, head in self.cluster_heads.items():
                            dist_score = self.calculate_distribution_score(head)
                            data_score = len(head.idxs) / max(len(c.idxs) for c in self.cluster_heads.values())
                            cluster_scores[cluster_id] = 0.6 * dist_score + 0.4 * data_score
                        
                        best_cluster = max(cluster_scores.items(), key=lambda x: x[1])[0]
                    else:
                        # Use performance history for subsequent epochs
                        avg_performances = {
                            cid: np.mean(history) if history else 0 
                            for cid, history in cluster_performance_history.items()
                        }
                        best_cluster = max(avg_performances.items(), key=lambda x: x[1])[0]
                    
                    if best_cluster not in participating_clusters:
                        participating_clusters.append(best_cluster)
                        print(f"\nForcing cluster {best_cluster} to participate")
                
                # Train participating clusters
                for cluster_id in participating_clusters:
                    head = self.cluster_heads[cluster_id]
                    cluster_members = self.clusters[cluster_id]
                    member_updates = []
                    member_losses = {}
                    
                    print(f"\nTraining Cluster {cluster_id}")
                    
                    # Train cluster members
                    for client_id in tqdm(cluster_members, desc=f"Training Cluster {cluster_id} Members"):
                        if client_id != head.client_id:
                            try:
                                weights, loss = self.clients[client_id].train_local(
                                    copy.deepcopy(global_model)
                                )
                                member_updates.append(weights)
                                member_losses[client_id] = loss
                            except Exception as e:
                                print(f"Error training client {client_id}: {str(e)}")
                                continue
                    
                    # Train cluster head
                    try:
                        head_weights, head_loss = head.train_local(copy.deepcopy(global_model))
                        member_updates.append(head_weights)
                        member_losses[head.client_id] = head_loss
                    except Exception as e:
                        print(f"Error training cluster head {head.client_id}: {str(e)}")
                        continue
                    
                    if member_updates:
                        cluster_updates[cluster_id] = head.aggregate_cluster_updates(
                            member_updates,
                            member_losses
                        )
                        head.last_contribution_round = epoch
                
                # Global aggregation
                if cluster_updates:
                    try:
                        global_weights = average_weights(list(cluster_updates.values()))
                        global_model.load_state_dict(global_weights)
                        
                        # Silent evaluation for updating cluster performance
                        accuracy, _, _, _ = test_model(
                            model=global_model, 
                            test_dataset=test_dataset, 
                            device=self.args.device
                        )
                        
                        # Update performance metrics for participating clusters
                        for cluster_id in participating_clusters:
                            cluster_performance_history[cluster_id].append(accuracy)
                            self.cluster_heads[cluster_id].update_cluster_performance(accuracy)
                        
                    except Exception as e:
                        print(f"\nError in evaluation: {str(e)}")
                        continue
                
                print('-' * 50)
            
            # Calculate final communication costs
            communication_results = self.calculate_communication_costs(
                cluster_participation_history,
                num_rounds=self.args.epochs
            )
            
            # Final evaluation
            final_accuracy, final_precision, final_recall, final_f1 = test_model(
                model=global_model,
                test_dataset=test_dataset,
                device=self.args.device
            )
            
            # Prepare final results
            final_results = {
                'model': global_model,
                'final_accuracy': final_accuracy,
                'communication_costs': communication_results,
                'cluster_participation_history': cluster_participation_history
            }
            
            return final_results
            
        except KeyboardInterrupt:
            print("\nTraining interrupted by user")
            return None
        except Exception as e:
            print(f"\nUnexpected error: {str(e)}")
            raise



In [53]:
# Usage
if __name__ == '__main__':
    class Args:
        def __init__(self):
            # Training parameters
            self.epochs = 10
            self.num_users = 10
            self.frac = 1
            self.local_ep = 3
            self.local_bs = 32
            self.lr = 0.01
            self.momentum = 0.5
            
            # System parameters
            self.iid = 0  # 1 for IID, 0 for non-IID
            # Set device after initialization
            self.device = None
            self.gpu = None
            
        def initialize_device(self):
            """Initialize device after torch is fully imported"""
            self.gpu = torch.cuda.is_available()
            self.device = 'cuda' if self.gpu else 'cpu'
    # Initialize arguments
    args = Args()
    args.initialize_device()
    torch.manual_seed(42)
    np.random.seed(42)

    # Generate user groups
    user_groups = mnist_noniid(train_dataset, args.num_users)
    
    # Initialize and train
    server = HierarchicalServer(args, n_clusters=3)
    results = server.train_federated(train_dataset, test_dataset, user_groups)
    
    if results is not None:
        print("\nTraining Complete!")
        print(f"Final Test Accuracy: {results['final_accuracy']:.2f}%")
        print("\nCommunication Cost Summary:")
        print(f"Traditional FL Total: {results['communication_costs']['traditional_fl']['total_mb']:.2f} MB")
        print(f"Hierarchical FL Total: {results['communication_costs']['hierarchical_fl']['total_mb']:.2f} MB")
        print(f"Reduction: {results['communication_costs']['reduction']['size_percent']:.2f}%")

    



Cluster 1:
  Head: Client 0
  Head Score Components:
    - Data Score: 0.989
    - Distribution Score: 0.476
  Members: [0, 4, 8]

Cluster 2:
  Head: Client 5
  Head Score Components:
    - Data Score: 1.000
    - Distribution Score: 0.477
  Members: [1, 5, 9]

Cluster 0:
  Head: Client 3
  Head Score Components:
    - Data Score: 0.935
    - Distribution Score: 0.476
  Members: [2, 3, 6, 7]


Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]


Epoch 1/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.000
Final Score: 0.300
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.30
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 2:
Performance Score: 0.500
Staleness Score: 0.000
Final Score: 0.300
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.30
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.000
Final Score: 0.300
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.30
Threshold: 0.5
Decision: skipping

Forcing cluster 2 to participate

Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:25<00:00,  8.53s/it]
Training Progress:  10%|█         | 1/10 [00:48<07:14, 48.24s/it]


Test Metrics:
Accuracy: 12.84%
Precision: 0.0441
Recall: 0.1266
F1 Score: 0.0493
--------------------------------------------------

Epoch 2/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.250
Final Score: 0.400
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.40
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 2:
Performance Score: 0.500
Staleness Score: 0.250
Final Score: 0.400
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.40
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.250
Final Score: 0.400
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.40
Threshold: 0.5
Decision: skipping

Forcing cluster 2 to participate

Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:29<00:00,  9.72s/it]
Training Progress:  20%|██        | 2/10 [01:42<06:53, 51.65s/it]


Test Metrics:
Accuracy: 20.29%
Precision: 0.1499
Recall: 0.1980
F1 Score: 0.0898
--------------------------------------------------

Epoch 3/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.500
Final Score: 0.500
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.50
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 2:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.500
Final Score: 0.500
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.50
Threshold: 0.5
Decision: skipping

Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:36<00:00, 12.08s/it]
Training Progress:  30%|███       | 3/10 [02:44<06:34, 56.41s/it]


Test Metrics:
Accuracy: 29.88%
Precision: 0.2335
Recall: 0.2844
F1 Score: 0.1796
--------------------------------------------------

Epoch 4/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.750
Final Score: 0.600
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.60
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 2:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.750
Final Score: 0.600
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.60
Threshold: 0.5
Decision: participating

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:33<00:00, 11.10s/it]



Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:28<00:00,  9.52s/it]



Training Cluster 0


Training Cluster 0 Members: 100%|██████████| 4/4 [00:49<00:00, 12.40s/it]
Training Progress:  40%|████      | 4/10 [05:48<10:42, 107.03s/it]


Test Metrics:
Accuracy: 40.06%
Precision: 0.1847
Recall: 0.3935
F1 Score: 0.2431
--------------------------------------------------

Epoch 5/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.250
Final Score: 0.400
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.40
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 2:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.250
Final Score: 0.400
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.40
Threshold: 0.5
Decision: skipping

Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:45<00:00, 15.25s/it]
Training Progress:  50%|█████     | 5/10 [07:14<08:16, 99.31s/it] 


Test Metrics:
Accuracy: 34.71%
Precision: 0.2145
Recall: 0.3345
F1 Score: 0.2270
--------------------------------------------------

Epoch 6/10

Participation Score Components for Cluster 1:
Performance Score: 0.500
Staleness Score: 0.500
Final Score: 0.500
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.50
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 2:
Performance Score: 0.300
Staleness Score: 0.250
Final Score: 0.280
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.28
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.500
Final Score: 0.500
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.50
Threshold: 0.5
Decision: skipping

Forcing cluster 1 to participate

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:53<00:00, 17.91s/it]
Training Progress:  60%|██████    | 6/10 [08:40<06:19, 94.85s/it]


Test Metrics:
Accuracy: 41.91%
Precision: 0.4958
Recall: 0.4050
F1 Score: 0.3198
--------------------------------------------------

Epoch 7/10

Participation Score Components for Cluster 1:
Performance Score: 0.700
Staleness Score: 0.250
Final Score: 0.520
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.52
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 2:
Performance Score: 0.300
Staleness Score: 0.500
Final Score: 0.380
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.38
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 0:
Performance Score: 0.500
Staleness Score: 0.750
Final Score: 0.600
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.60
Threshold: 0.5
Decision: participating

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:39<00:00, 13.27s/it]



Training Cluster 0


Training Cluster 0 Members: 100%|██████████| 4/4 [01:07<00:00, 16.86s/it]
Training Progress:  70%|███████   | 7/10 [11:24<05:51, 117.27s/it]


Test Metrics:
Accuracy: 64.48%
Precision: 0.7060
Recall: 0.6402
F1 Score: 0.5720
--------------------------------------------------

Epoch 8/10

Participation Score Components for Cluster 1:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 2:
Performance Score: 0.300
Staleness Score: 0.750
Final Score: 0.480
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.48
Threshold: 0.5
Decision: skipping

Participation Score Components for Cluster 0:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:41<00:00, 13.99s/it]



Training Cluster 0


Training Cluster 0 Members: 100%|██████████| 4/4 [01:07<00:00, 16.95s/it]
Training Progress:  80%|████████  | 8/10 [14:20<04:31, 135.91s/it]


Test Metrics:
Accuracy: 77.69%
Precision: 0.8494
Recall: 0.7753
F1 Score: 0.7674
--------------------------------------------------

Epoch 9/10

Participation Score Components for Cluster 1:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Cluster 2 evaluation:
Score: 0.48
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 0:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:48<00:00, 16.15s/it]



Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:44<00:00, 14.96s/it]



Training Cluster 0


Training Cluster 0 Members: 100%|██████████| 4/4 [01:01<00:00, 15.40s/it]
Training Progress:  90%|█████████ | 9/10 [18:37<02:53, 173.93s/it]


Test Metrics:
Accuracy: 83.51%
Precision: 0.8693
Recall: 0.8333
F1 Score: 0.8253
--------------------------------------------------

Epoch 10/10

Participation Score Components for Cluster 1:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 1 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 2:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 2 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Participation Score Components for Cluster 0:
Performance Score: 1.000
Staleness Score: 0.250
Final Score: 0.700
Threshold: 0.5

Cluster 0 evaluation:
Score: 0.70
Threshold: 0.5
Decision: participating

Training Cluster 1


Training Cluster 1 Members: 100%|██████████| 3/3 [00:56<00:00, 18.91s/it]



Training Cluster 2


Training Cluster 2 Members: 100%|██████████| 3/3 [00:46<00:00, 15.41s/it]



Training Cluster 0


Training Cluster 0 Members: 100%|██████████| 4/4 [01:17<00:00, 19.28s/it]
Training Progress: 100%|██████████| 10/10 [23:10<00:00, 139.09s/it]


Test Metrics:
Accuracy: 85.01%
Precision: 0.8783
Recall: 0.8483
F1 Score: 0.8383
--------------------------------------------------

Detailed Communication Cost Analysis:

Model Architecture:
Base Model Size: 0.08 MB (21,840 parameters)

Parameter Breakdown:
- conv1: 260 parameters
- conv2: 5,020 parameters
- fc1: 16,050 parameters
- fc2: 510 parameters

Traditional FL (10 clients):
Per Round: 436,800 parameters (1.67 MB)
Total Communication: 4,368,000 parameters (16.66 MB)

Hierarchical FL with Selective Participation:
Initial Clustering Setup: 123 parameters (0.00 MB)
Average Participating Clusters per Round: 1.50
Total Round Communication: 655,200 parameters (2.50 MB)
Total Communication: 655,323 parameters (2.50 MB)

Reduction Achieved:
Parameter reduction: 85.00%
Size reduction: 85.00%






Test Metrics:
Accuracy: 85.01%
Precision: 0.8783
Recall: 0.8483
F1 Score: 0.8383

Training Complete!
Final Test Accuracy: 85.01%

Communication Cost Summary:
Traditional FL Total: 16.66 MB
Hierarchical FL Total: 2.50 MB
Reduction: 85.00%
