# Explanations

### **Summary of the Deployment**

This deployment implements a Federated Learning (FL) system for large-scale, decentralized IoT networks. It combines GNN-based device clustering, hybrid data redistribution, DRL-based device assignment, and semi-synchronous training. The sections below give an overview of each component:

---

### **1. Device Clustering Using GNN-KMeans**
- **Objective**: Cluster devices into logical groups based on their data distribution, device characteristics, and network constraints.
- **Process**:
  - A **Graph Neural Network (GNN)** is used to model the relationship between devices, capturing both their features and interconnectivity.
  - A **KMeans clustering algorithm** operates on the embeddings generated by the GNN, grouping devices with similar characteristics.
  - Clusters represent groups of devices that share common traits, enabling efficient resource allocation and training at the edge level.
- **Outcome**: Devices are assigned to clusters, ensuring devices with similar data and characteristics are grouped, which reduces variance in model aggregation and improves performance.
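
The clustering step can be illustrated with a minimal sketch. It is not the deployment's actual model: a single untrained mean-aggregation pass stands in for the GNN, the k-means routine is a plain Lloyd's loop with farthest-first initialization, and the feature values and adjacency lists are hypothetical.

```python
import math

def aggregate_neighbors(features, adjacency):
    """One GNN-style pass: average each device's features with its neighbors'."""
    embeddings = []
    for i, own in enumerate(features):
        group = [own] + [features[j] for j in adjacency[i]]
        embeddings.append([sum(col) / len(group) for col in zip(*group)])
    return embeddings

def kmeans(points, k, iters=20):
    """Plain Lloyd's k-means with deterministic farthest-first initialization."""
    centroids = [list(points[0])]
    while len(centroids) < k:
        farthest = max(points, key=lambda p: min(math.dist(p, c) for c in centroids))
        centroids.append(list(farthest))
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest centroid for every embedding.
        labels = [min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points]
        # Update step: recompute each non-empty cluster's centroid.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels

# Hypothetical device features: [normalized CPU power, normalized bandwidth].
features = [[0.9, 0.8], [0.85, 0.9], [0.8, 0.85], [0.1, 0.2], [0.15, 0.1], [0.2, 0.15]]
# Hypothetical connectivity graph: device index -> neighbor indices.
adjacency = {0: [1], 1: [0, 2], 2: [1], 3: [4], 4: [3, 5], 5: [4]}

embeddings = aggregate_neighbors(features, adjacency)
cluster_labels = kmeans(embeddings, k=2)
print(cluster_labels)
```

A trained GNN would replace `aggregate_neighbors` with learned layers; the k-means step on the resulting embeddings stays the same.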

---

### **2. Hybrid Data Redistribution**
- **Objective**: Balance the data distribution across devices within a cluster to enhance the representativeness of local models.
- **Process**:
  - Devices within a cluster exchange a portion of their local datasets based on a **hybrid redistribution policy**.
  - A **threshold-based strategy** ensures that redistribution prioritizes devices with less diverse data while maintaining privacy constraints.
  - Label diversity and sample quantity are considered during redistribution to improve local model training.
- **Outcome**: Each device within a cluster holds a more balanced dataset, reducing model bias and improving global model accuracy.
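
A threshold-based policy of this kind might look like the sketch below. The details are assumptions made for illustration: diversity is measured as the number of distinct labels (`min_labels`), samples are copied rather than moved, and `share_fraction` caps how much a donor shares as a stand-in for the privacy constraint.

```python
def redistribute(cluster, min_labels=2, share_fraction=0.2):
    """Copy a capped share of samples from diverse devices to low-diversity ones.

    cluster: dict mapping device id -> list of (sample, label) pairs.
    Returns a new dict; the input is left unmodified.
    """
    diversity = {d: len({y for _, y in data}) for d, data in cluster.items()}
    receivers = [d for d, n in diversity.items() if n < min_labels]
    donors = [d for d, n in diversity.items() if n >= min_labels]
    result = {d: list(data) for d, data in cluster.items()}
    for r in receivers:
        known = {y for _, y in cluster[r]}
        for d in donors:
            budget = int(len(cluster[d]) * share_fraction)  # cap per donor
            missing = [(x, y) for x, y in cluster[d] if y not in known]
            result[r].extend(missing[:budget])
    return result

# Hypothetical cluster: device "a" holds a single label, device "b" holds three.
cluster = {
    "a": [(i, 0) for i in range(10)],
    "b": [(i, i % 3) for i in range(10, 20)],
}
balanced = redistribute(cluster)
print(len(cluster["a"]), "->", len(balanced["a"]))
```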

---

### **3. Device Assignment Using DRL PPO**
- **Objective**: Optimize the assignment of devices to edge servers and the scheduling of active devices for FL tasks.
- **Key Steps**:
  - **Cluster Assignment**:
    - A **Deep Reinforcement Learning (DRL)** agent, using the **Proximal Policy Optimization (PPO)** algorithm, determines the best edge server for each cluster based on available bandwidth and server capacity.
    - The goal is to minimize load imbalance and maximize resource utilization across edge servers.
  - **Device Scheduling**:
    - A second DRL PPO agent selects the optimal subset of devices within each cluster to participate in training.
    - Scheduling decisions consider factors such as energy usage, bandwidth availability, and data diversity.
- **Outcome**:
  - Efficient utilization of edge server resources and reduced communication overhead.
  - Active devices are strategically chosen to improve model training efficiency and accuracy.
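
Two pieces of this step can be sketched in a few lines. `ppo_clip_objective` is the standard PPO clipped surrogate objective; `assignment_reward` is a hypothetical reward that penalizes load imbalance and overload, written only to illustrate the stated goal. The deployment's actual reward is not given in this document.

```python
from statistics import pvariance

def ppo_clip_objective(ratios, advantages, clip_eps=0.2):
    """Mean PPO clipped surrogate objective (to be maximized).

    ratios: new-policy / old-policy action probabilities per sample.
    advantages: advantage estimates per sample.
    """
    total = 0.0
    for r, a in zip(ratios, advantages):
        clipped = max(1.0 - clip_eps, min(r, 1.0 + clip_eps))
        total += min(r * a, clipped * a)
    return total / len(ratios)

def assignment_reward(loads, capacities, overload_penalty=10.0):
    """Hypothetical reward: low utilization variance, no server over capacity."""
    utilization = [load / cap for load, cap in zip(loads, capacities)]
    overload = sum(max(0.0, u - 1.0) for u in utilization)
    return -pvariance(utilization) - overload_penalty * overload

print(ppo_clip_objective([1.5, 0.5], [1.0, -1.0]))
print(assignment_reward([5, 5], [10, 10]), assignment_reward([9, 1], [10, 10]))
```

The clipping keeps each policy update close to the previous policy, which is the property that makes PPO a common choice for assignment problems with noisy rewards.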

---

### **4. Semi-Synchronous Federated Learning**
- **Objective**: Combine the benefits of synchronous and asynchronous training to balance efficiency and performance.
- **Key Features**:
  - **Edge-Level Independence**:
    - Each edge server trains its devices sequentially but independently of other edge servers.
    - Training on edge devices incorporates **device-specific characteristics** like CPU power and memory.
  - **Synchronous Aggregation**:
    - After training, each edge server synchronously aggregates the models of all its devices.
    - The aggregated models from edge servers are then sent to the cloud for global aggregation.
  - **Dynamic Adaptation**:
    - Device characteristics (e.g., energy, bandwidth) are dynamically updated after each global iteration.
    - Device assignments and schedules are recalculated using the trained DRL agents.
- **Outcome**:
  - Reduced global synchronization bottlenecks due to edge-level independence.
  - Enhanced scalability and adaptability for large-scale IoT networks.
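
The two-level aggregation can be sketched as follows, assuming sample-count-weighted averaging (FedAvg-style) at both levels and models represented as flat parameter lists. Both are assumptions; the weighting scheme is not specified above.

```python
def weighted_average(models, weights):
    """Weighted average of flat parameter vectors."""
    total = sum(weights)
    dim = len(models[0])
    return [sum(w * m[i] for m, w in zip(models, weights)) / total for i in range(dim)]

def global_round(edge_groups):
    """edge_groups: one list per edge server of (params, num_samples) tuples."""
    edge_models, edge_sizes = [], []
    for devices in edge_groups:
        params = [p for p, _ in devices]
        sizes = [n for _, n in devices]
        # Synchronous aggregation inside one edge server.
        edge_models.append(weighted_average(params, sizes))
        edge_sizes.append(sum(sizes))
    # Cloud aggregation over the edge-level models.
    return weighted_average(edge_models, edge_sizes)

# Hypothetical round: two devices on the first edge server, one on the second.
edge_groups = [
    [([1.0, 1.0], 10), ([3.0, 3.0], 30)],
    [([5.0, 5.0], 60)],
]
print(global_round(edge_groups))
```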

---

### **5. Key Innovations and Metrics**
#### **CPU and Memory Effects in Training**
- Local training accounts for device-specific CPU power and memory capacity, influencing:
  - Batch size (larger for devices with higher memory).
  - Training time (faster for devices with higher CPU power).
  - Energy consumption (lower for devices with efficient CPUs).
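
One way to model these effects is sketched below. The formulas are assumptions chosen for illustration: batch size scales linearly with memory, training time is total CPU cycles divided by clock frequency, and energy follows the common `kappa * cycles * f^2` model, where a lower `kappa` represents a more efficient CPU.

```python
from dataclasses import dataclass

@dataclass
class DeviceProfile:
    cpu_hz: float     # clock frequency in cycles per second
    memory_gb: float  # memory available for training
    kappa: float      # energy coefficient (lower means a more efficient CPU)

def local_training_plan(dev, num_samples, epochs=1, cycles_per_sample=1e6,
                        batch_per_gb=16, max_batch=128):
    """Estimate batch size, training time, and energy for one local round."""
    batch_size = int(min(max_batch, max(1, batch_per_gb * dev.memory_gb)))
    cycles = num_samples * epochs * cycles_per_sample
    time_s = cycles / dev.cpu_hz
    energy_j = dev.kappa * cycles * dev.cpu_hz ** 2
    return {"batch_size": batch_size, "time_s": time_s, "energy_j": energy_j}

# Hypothetical devices.
fast = DeviceProfile(cpu_hz=2e9, memory_gb=4, kappa=1e-28)
slow = DeviceProfile(cpu_hz=1e9, memory_gb=1, kappa=1e-28)
print(local_training_plan(fast, num_samples=1000))
print(local_training_plan(slow, num_samples=1000))
```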

#### **Concurrency Control**
- Limits the number of concurrently active edge servers during training, ensuring resource constraints are respected.
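
A minimal way to cap concurrency is a bounded worker pool, as sketched here. `run_edges` and `fake_train` are hypothetical names; the real training routine would take the place of `fake_train`.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

def run_edges(edge_ids, train_edge, max_concurrent=2):
    """Run train_edge(edge_id) for every edge server, at most max_concurrent at once."""
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        return dict(zip(edge_ids, pool.map(train_edge, edge_ids)))

# Track how many edge servers are training at the same time.
state = {"active": 0, "peak": 0}
lock = threading.Lock()

def fake_train(edge_id):
    """Stand-in for edge-level training; sleeps instead of training."""
    with lock:
        state["active"] += 1
        state["peak"] = max(state["peak"], state["active"])
    time.sleep(0.01)
    with lock:
        state["active"] -= 1
    return f"edge-{edge_id} done"

results = run_edges([0, 1, 2, 3], fake_train, max_concurrent=2)
print(results, "peak concurrency:", state["peak"])
```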

#### **Comprehensive Metrics**
- Tracks and logs **energy consumption**, **time delays**, and **bandwidth usage** across devices, edge servers, and the cloud.
- Metrics are saved as JSON files for analysis and debugging.

---

### **System Workflow**
1. **Data Distribution**:
   - Data is distributed across devices, followed by clustering using GNN-KMeans.
2. **Data Balancing**:
   - Hybrid redistribution balances data within clusters.
3. **Device Assignment**:
   - DRL PPO agents assign clusters to edge servers and schedule active devices.
4. **Training**:
   - Each edge server independently trains its devices.
   - Models are aggregated at the edge and then globally in a semi-synchronous manner.
5. **Dynamic Adaptation**:
   - Device assignments and schedules are updated after each global iteration.
6. **Metrics Logging**:
   - Energy, time, and bandwidth metrics are tracked and saved for analysis.

---

### **Benefits of the System**
- **Efficiency**:
  - Optimized use of computational resources at devices and edge servers.
  - Reduced communication costs and delays.
- **Scalability**:
  - Supports large-scale IoT networks with diverse devices and dynamic conditions.
- **Accuracy**:
  - Balanced data distribution and optimal scheduling enhance global model performance.
- **Adaptability**:
  - Dynamic adjustments in device assignments and training configurations ensure sustained performance in changing environments.

---

This deployment integrates GNN-based clustering, hybrid data redistribution, DRL-based assignment, and semi-synchronous training into a scalable and adaptive Federated Learning system tailored for decentralized IoT networks.

# Gaps & Weaknesses

---

### **1. Gaps in Device Clustering**
- **Static Clustering**:
  - The current clustering approach is static, performed at the initial stage. This assumes that device characteristics remain constant throughout the process. However, real-world IoT devices often experience dynamic changes in energy levels, connectivity, or data characteristics.
  - **Improvement**: Implement periodic re-clustering or incremental updates to clusters as device characteristics change over time.

- **Limited Clustering Criteria**:
  - GNN-KMeans primarily focuses on predefined features like data similarity, energy, and bandwidth. It does not account for latent patterns or new relationships that may emerge during training.
  - **Improvement**: Introduce more sophisticated clustering methods, such as self-organizing maps or adaptive GNNs, to capture evolving patterns.

---

### **2. Gaps in Hybrid Data Redistribution**
- **Privacy Concerns**:
  - Although data redistribution enhances diversity, it may lead to privacy concerns if sensitive data is shared, even partially.
  - **Improvement**: Introduce differential privacy mechanisms to ensure that data redistribution does not compromise user privacy.

- **Data Heterogeneity**:
  - Redistribution assumes that sharing data samples improves diversity, but it may not adequately handle devices with fundamentally skewed data distributions (e.g., devices with single-label datasets).
  - **Improvement**: Include a more robust mechanism to detect and address extreme data heterogeneity using synthetic data generation or federated augmentation.

---

### **3. Weaknesses in Device Assignment Using DRL PPO**
- **Training Overhead**:
  - Training DRL agents for cluster assignment and device scheduling introduces computational overhead, particularly if the system requires frequent retraining due to dynamic conditions.
  - **Improvement**: Explore lightweight or offline RL models that are faster to train and deploy in dynamic environments.

- **Single Objective Optimization**:
  - The current DRL approach optimizes resource utilization and scheduling but does not account for multi-objective trade-offs, such as fairness in device participation or minimizing latency for specific applications.
  - **Improvement**: Extend the DRL framework to support multi-objective optimization with trade-offs between fairness, latency, and resource efficiency.

- **Scalability Issues**:
  - As the number of devices and clusters increases, the action space for DRL agents grows exponentially, making training and decision-making more complex.
  - **Improvement**: Use hierarchical DRL or distributed RL frameworks to handle large-scale deployments more effectively.

---

### **4. Weaknesses in Semi-Synchronous Federated Learning**
- **Limited Parallelism at Edge**:
  - The concurrency control limits the number of edge servers that can train simultaneously. While this ensures resource constraints are respected, it can reduce overall training efficiency in underutilized environments.
  - **Improvement**: Introduce dynamic concurrency based on resource availability, allowing more edge servers to operate in low-load conditions.

- **Device-Specific Constraints**:
  - The system considers CPU power and memory when setting batch size and estimating training time, but it does not address device-specific constraints like hardware failures, intermittent connectivity, or overheating.
  - **Improvement**: Incorporate failure detection and mitigation strategies, such as proactive device handover or redundant training on backup devices.

- **Synchronization Delays**:
  - Semi-synchronous aggregation requires edge servers to wait for slow devices within their cluster, potentially leading to bottlenecks.
  - **Improvement**: Introduce straggler mitigation techniques, such as partial aggregation or stale model updates, to reduce delays.
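
Partial aggregation, one of the straggler mitigation options named above, can be sketched as a deadline-based filter. The function name, the `min_fraction` quorum, and the sample-weighted average are assumptions made for illustration.

```python
def partial_aggregate(updates, deadline_s, min_fraction=0.5):
    """Aggregate only the updates that arrived before the deadline.

    updates: list of (params, num_samples, arrival_time_s) tuples.
    Returns None when fewer than min_fraction of devices reported in time,
    so the caller can extend the deadline or skip the round.
    """
    on_time = [(p, n) for p, n, t in updates if t <= deadline_s]
    if len(on_time) < min_fraction * len(updates):
        return None
    total = sum(n for _, n in on_time)
    dim = len(on_time[0][0])
    return [sum(p[i] * n for p, n in on_time) / total for i in range(dim)]

# Hypothetical round: the third device is a straggler arriving after 30 s.
updates = [([1.0], 10, 1.0), ([3.0], 10, 2.0), ([100.0], 10, 30.0)]
print(partial_aggregate(updates, deadline_s=5.0))
print(partial_aggregate(updates, deadline_s=0.5))
```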

---

### **5. Gaps in Metrics and Monitoring**
- **Lack of Real-Time Feedback**:
  - Metrics are logged and saved but do not provide real-time feedback or visualization for dynamic decision-making.
  - **Improvement**: Implement a real-time monitoring dashboard to visualize energy, time, and bandwidth usage, helping operators make proactive adjustments.

- **Limited Metrics Granularity**:
  - Current metrics focus on high-level energy, time, and bandwidth usage but lack detailed insights into individual device contributions or network bottlenecks.
  - **Improvement**: Add more granular metrics, such as per-device latency, packet loss, or model divergence, for detailed analysis.

---

### **6. Potential Weaknesses in Energy Efficiency**
- **High Communication Overhead**:
  - While communication energy is calculated, the frequent exchange of model updates during aggregation can still be expensive in large-scale deployments.
  - **Improvement**: Use model compression techniques like quantization, sparsification, or knowledge distillation to reduce communication costs.

- **Unbalanced Energy Usage**:
  - Devices with higher energy reserves are not prioritized for participation, potentially leading to early dropouts of low-energy devices.
  - **Improvement**: Implement energy-aware scheduling to balance the energy consumption across devices and prolong the overall system lifetime.

---

### **7. Weaknesses in Security and Privacy**
- **Vulnerability to Model Poisoning**:
  - The system does not explicitly address the risk of malicious devices introducing poisoned updates during model aggregation.
  - **Improvement**: Introduce defense mechanisms like robust aggregation techniques (e.g., median aggregation) or anomaly detection for updates.

- **Data Privacy Compliance**:
  - FL is meant to keep raw data on devices, but hybrid redistribution moves samples between devices, which may conflict with privacy regulations like GDPR if not properly managed.
  - **Improvement**: Ensure compliance with privacy laws by integrating secure multi-party computation or federated differential privacy.
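
The median aggregation mentioned under model poisoning can be sketched as a coordinate-wise median over flat parameter vectors. The update values are hypothetical, with the last one standing in for a poisoned device.

```python
from statistics import mean, median

def median_aggregate(updates):
    """Coordinate-wise median of flat parameter vectors."""
    return [median(col) for col in zip(*updates)]

def mean_aggregate(updates):
    """Plain coordinate-wise mean, shown for comparison."""
    return [mean(col) for col in zip(*updates)]

# Three honest updates and one outlier.
updates = [[1.0, 2.0], [1.2, 1.8], [0.9, 2.1], [50.0, -40.0]]
print("median:", median_aggregate(updates))
print("mean:  ", mean_aggregate(updates))
```

The median stays near the honest updates while the mean is pulled far away by the single outlier. The trade-off is that the median ignores sample counts, so it is not a drop-in replacement for weighted averaging.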

---

### **8. Gaps in Adaptability and Scalability**
- **Edge Server Scalability**:
  - The system assumes a fixed number of edge servers, which may not scale well as the number of devices grows.
  - **Improvement**: Introduce dynamic edge server provisioning, where additional edge servers can be deployed as needed.

- **Cloud Aggregation Bottleneck**:
  - The cloud server becomes a single point of aggregation, potentially leading to delays in large-scale systems.
  - **Improvement**: Use a hierarchical aggregation structure, where intermediate aggregators reduce the load on the central cloud server.

---

### **Summary of Gaps and Weaknesses**
While the current system combines several useful techniques, addressing the gaps mentioned above would improve its robustness, scalability, and adaptability. Future improvements should focus on dynamic clustering, advanced data redistribution, multi-objective DRL optimization, enhanced privacy and security measures, and real-time monitoring to make the system more resilient and efficient in real-world scenarios.

# Metrics

| **Step**                           | **Metrics**                                                                                         | **Description**                                                                                               |
|-------------------------------------|-----------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|
| **1. Device Clustering**            | - **Intra-Cluster Similarity**                                                                      | Measures similarity of devices within a cluster based on characteristics (e.g., hardware, datasets).         |
|                                     | - **Inter-Cluster Separation**                                                                     | Measures distinction between clusters to ensure clear boundaries.                                             |
|                                     | - **Time to Cluster**                                                                              | Time taken to cluster devices.                                                                                |
|                                     | - **Energy Consumption**                                                                           | Energy used by devices during clustering (e.g., for communication or feature extraction).                     |
|                                     | - **Cluster Size Balance**                                                                         | Checks if clusters are balanced in size or data diversity.                                                    |
|                                     | - **Communication Overhead**                                                                       | Measures data exchanged during clustering (e.g., device features).                                            |
| **2. Hybrid Data Redistribution**   | - **Dataset Label Distribution**                                                                   | Tracks the uniformity of specific labels across clusters.                                                     |
|                                     | - **Data Transfer Time**                                                                           | Time required to transfer datasets between devices.                                                           |
|                                     | - **Energy Consumption**                                                                           | Energy used by devices for transferring data.                                                                 |
|                                     | - **Communication Bandwidth Utilization**                                                          | Measures bandwidth usage during data redistribution.                                                          |
|                                     | - **Redundancy Reduction**                                                                         | Assesses whether redundant data copies have been minimized.                                                   |
|                                     | - **Data Quality**                                                                                 | Evaluates if data is corrupted or degraded during redistribution.                                             |
| **3. Resource Allocation**          | - **Edge Server Load**                                                                             | Tracks the current load (CPU, memory, bandwidth) on each edge server.                                         |
|                                     | - **Device-Edge Latency**                                                                          | Measures the delay between devices and their assigned edge servers.                                           |
|                                     | - **Energy Consumption**                                                                           | Energy used for communication and computation during allocation.                                              |
|                                     | - **Bandwidth Utilization**                                                                        | Tracks how well edge server bandwidth is being used.                                                          |
|                                     | - **Fairness Index**                                                                               | Measures how evenly resources are distributed across devices and servers.                                     |
|                                     | - **Resource Utilization Efficiency**                                                              | Percentage of available resources (e.g., CPU, memory) effectively utilized.                                   |
|                                     | - **Mapping Accuracy**                                                                             | Evaluates how well the resource allocation matches device needs.                                              |
| **4. Federated Learning**           | - **Model Accuracy**                                                                               | Tracks the performance of the global model.                                                                   |
|                                     | - **Convergence Speed**                                                                            | Measures how quickly the global model converges to an optimal state.                                          |
|                                     | - **Communication Rounds**                                                                         | Number of communication rounds required to reach a certain accuracy.                                          |
|                                     | - **Device Contribution Diversity**                                                                | Evaluates the diversity of data contributions from devices.                                                   |
|                                     | - **Energy Consumption**                                                                           | Energy used during model training and communication.                                                          |
|                                     | - **Straggler Count**                                                                              | Number of slow devices causing delays in training.                                                            |
|                                     | - **Latency (Device-Edge, Edge-Cloud)**                                                            | Measures delays in transferring model updates between layers.                                                 |
|                                     | - **Resource Utilization (Edge and Cloud)**                                                        | Monitors CPU, memory, and storage usage at edge and cloud levels.                                             |
|                                     | - **Training Completion Time**                                                                     | Overall time taken to complete a training iteration.                                                          |
|                                     | - **Global Model Stability**                                                                       | Tracks stability in model updates to avoid oscillations.                                                      |

---

### **How Metrics Interrelate**
- **Clustering Metrics** directly influence **Redistribution Metrics**, as well-formed clusters lead to efficient data balancing.
- **Resource Allocation Metrics** depend on the output of clustering and redistribution to allocate resources optimally.
- **Federated Learning Metrics** are influenced by all previous steps, as better clustering, redistribution, and allocation enhance training efficiency.

Using these metrics allows you to identify bottlenecks, monitor performance, and ensure efficient system operation. Moreover, you can feed these metrics as part of the **state** in your **DRL framework** for adaptive optimization.

## **1. Device Clustering**

### **Metrics:**

1. **Clustering Quality Metrics:**
   - **Silhouette Score:**
     - Measures how similar a device is to its own cluster compared to other clusters.
     - **Range:** -1 (incorrect clustering) to +1 (ideal clustering).
   - **Davies-Bouldin Index (DBI):**
     - Evaluates intra-cluster similarity and inter-cluster differences.
     - **Lower values** indicate better clustering.
   - **Calinski-Harabasz Index (CHI):**
     - Ratio of between-cluster dispersion to within-cluster dispersion.
     - **Higher values** indicate better-defined clusters.

2. **Cluster Size Distribution:**
   - **Number of Devices per Cluster:**
     - Shows how devices are distributed among clusters.
   - **Cluster Imbalance Ratio:**
     - Ratio of sizes between the largest and smallest clusters.

3. **Cluster Feature Statistics:**
   - **Within-Cluster Variance:**
     - Variance of device features (e.g., CPU power, memory) within each cluster.
   - **Between-Cluster Variance:**
     - Variance of cluster centroids to assess distinctiveness.

4. **Visualization Metrics:**
   - **t-SNE or PCA Projections:**
     - Visual representation of high-dimensional data in 2D or 3D space.
     - Devices colored by cluster assignments.
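
To make the quality and balance metrics concrete, here is a dependency-free sketch of the silhouette score and the cluster imbalance ratio. The points and labels are hypothetical, and singleton clusters are given a score of 0 by convention.

```python
import math
from collections import Counter

def silhouette_score(points, labels):
    """Mean silhouette coefficient over all points (Euclidean distance)."""
    scores = []
    for i, p in enumerate(points):
        dists = {}
        for j, q in enumerate(points):
            if i != j:
                dists.setdefault(labels[j], []).append(math.dist(p, q))
        own = dists.pop(labels[i], [])
        if not own or not dists:
            scores.append(0.0)  # singleton cluster or a single cluster overall
            continue
        a = sum(own) / len(own)                           # mean intra-cluster distance
        b = min(sum(d) / len(d) for d in dists.values())  # nearest other cluster
        scores.append((b - a) / max(a, b))
    return sum(scores) / len(scores)

def imbalance_ratio(labels):
    """Size of the largest cluster divided by the size of the smallest."""
    sizes = Counter(labels).values()
    return max(sizes) / min(sizes)

points = [[0, 0], [0, 1], [10, 10], [10, 11]]
print(silhouette_score(points, [0, 0, 1, 1]))  # well-separated clusters
print(silhouette_score(points, [0, 1, 0, 1]))  # mixed-up assignment
print(imbalance_ratio([0, 0, 0, 1]))
```

If scikit-learn is available, `sklearn.metrics` provides `silhouette_score`, `davies_bouldin_score`, and `calinski_harabasz_score`, which cover all three quality metrics listed above.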

### **Plots:**

- **Bar Chart:** Number of devices per cluster.
- **Box Plot:** Feature distributions (e.g., energy usage) within clusters.
- **Silhouette Plot:** Silhouette scores for each device.
- **Scatter Plot (t-SNE/PCA):** Visualizing clusters in reduced dimensions.

### **Examples:**

1. **Cluster Size Distribution Plot:**

   <img src="https://example.com/cluster_size_distribution.png" alt="Cluster Size Distribution" width="400"/>

2. **Silhouette Scores Plot:**

   <img src="https://example.com/silhouette_scores.png" alt="Silhouette Scores" width="400"/>

3. **t-SNE Visualization:**

   <img src="https://example.com/tsne_clusters.png" alt="t-SNE Clusters" width="400"/>

---

## **2. Data Redistribution**

### **Metrics:**

1. **Data Distribution Metrics:**
   - **Label Distribution per Device (Before and After):**
     - Proportion of samples for each label on each device.
   - **Average Label Entropy:**
     - Measures diversity of labels on devices.
     - **Higher entropy** indicates more balanced label distribution.

2. **Redistribution Effectiveness:**
   - **Kullback-Leibler (KL) Divergence:**
     - Quantifies the difference between device label distributions and the global distribution.
   - **Total Data Transferred:**
     - Amount of data moved during redistribution.

3. **Resource Impact:**
   - **Change in Device Data Volume:**
     - Number of samples before and after redistribution.
   - **Energy and Bandwidth Consumption:**
     - Resources used during data redistribution.
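
Label entropy and KL divergence can both be computed directly from label counts. The sketch below uses base-2 logarithms and a small `eps` guard for labels missing from the reference distribution; both choices are assumptions, and the label lists are hypothetical.

```python
import math
from collections import Counter

def label_distribution(labels, classes):
    """Proportion of samples per class, in the order given by classes."""
    counts = Counter(labels)
    return [counts[c] / len(labels) for c in classes]

def entropy(p):
    """Shannon entropy in bits; higher means a more balanced distribution."""
    return -sum(x * math.log2(x) for x in p if x > 0)

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) in bits; eps guards against zero entries in q."""
    return sum(x * math.log2(x / max(y, eps)) for x, y in zip(p, q) if x > 0)

classes = [0, 1]
global_dist = [0.5, 0.5]
before = label_distribution([0] * 9 + [1], classes)  # skewed device
after = label_distribution([0, 1] * 5, classes)      # after redistribution
print(entropy(before), entropy(after))
print(kl_divergence(before, global_dist), kl_divergence(after, global_dist))
```

Effective redistribution should raise the average entropy and push each device's KL divergence from the global distribution toward zero.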

### **Plots:**

- **Histogram:** Label distribution per device before and after redistribution.
- **Box Plot:** Entropy of label distributions across devices.
- **Line Plot:** KL divergence over iterations (if redistribution is iterative).
- **Heatmap:** Device-to-device data transfer matrix.

### **Examples:**

1. **Label Distribution Histogram (Before and After):**

   <img src="https://example.com/label_distribution.png" alt="Label Distribution" width="400"/>

2. **Entropy Box Plot:**

   <img src="https://example.com/entropy_boxplot.png" alt="Label Entropy" width="400"/>

3. **KL Divergence Over Time:**

   <img src="https://example.com/kl_divergence.png" alt="KL Divergence" width="400"/>

---

## **3. Device Assignment**

### **Metrics:**

1. **Edge Server Load Metrics:**
   - **Number of Devices per Edge Server:**
     - Shows load distribution across edge servers.
   - **Edge Server Capacity Utilization:**
     - Ratio of assigned load to edge server capacity.
   - **Load Variance:**
     - Variance in the number of devices or total bandwidth per edge server.

2. **Scheduling Metrics:**
   - **Scheduled vs. Unscheduled Devices:**
     - Counts and percentages of devices participating in training.
   - **Average Device Resources:**
     - Mean CPU power, memory, energy, and bandwidth of scheduled devices.

3. **Assignment Efficiency Metrics:**
   - **Resource Utilization Efficiency:**
     - Comparison of actual resource usage against maximum available.
   - **Assignment Optimality Score:**
     - Evaluated using custom criteria or baseline comparisons.
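
The edge server load metrics can be derived from a single assignment table, as in this sketch. It assumes load is measured as the summed bandwidth demand of the assigned devices; the server names and numbers are hypothetical.

```python
from statistics import pvariance

def edge_load_metrics(assignment, capacities):
    """Compute per-server device counts, utilization, and load variance.

    assignment: edge id -> list of per-device bandwidth demands.
    capacities: edge id -> bandwidth capacity of that server.
    """
    loads = {e: sum(demands) for e, demands in assignment.items()}
    return {
        "devices_per_edge": {e: len(d) for e, d in assignment.items()},
        "utilization": {e: loads[e] / capacities[e] for e in loads},
        "load_variance": pvariance(list(loads.values())),
    }

assignment = {"e1": [2.0, 3.0], "e2": [5.0, 5.0, 5.0]}
capacities = {"e1": 10.0, "e2": 20.0}
print(edge_load_metrics(assignment, capacities))
```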

### **Plots:**

- **Bar Chart:** Number of devices assigned to each edge server.
- **Pie Chart:** Proportion of scheduled vs. unscheduled devices.
- **Stacked Bar Chart:** Resource usage per edge server.
- **Heatmap:** Edge server capacity utilization.

### **Examples:**

1. **Edge Server Load Bar Chart:**

   <img src="https://example.com/edge_server_load.png" alt="Edge Server Load" width="400"/>

2. **Device Scheduling Pie Chart:**

   <img src="https://example.com/device_scheduling.png" alt="Device Scheduling" width="400"/>

3. **Resource Utilization Heatmap:**

   <img src="https://example.com/resource_utilization.png" alt="Resource Utilization" width="400"/>

---

## **4. Federated Learning**

### **Metrics:**

1. **Performance Metrics:**
   - **Global Model Accuracy Over Iterations:**
     - Test accuracy after each global aggregation.
   - **Edge Model Accuracies:**
     - Validation accuracy at each edge server before global aggregation.
   - **Local Model Accuracies:**
     - Training accuracy on devices after local epochs.

2. **Convergence Metrics:**
   - **Loss Over Time:**
     - Training loss on devices, edge servers, and global model.
   - **Accuracy vs. Communication Rounds:**
     - Relationship between model performance and number of communication rounds.

3. **Resource Consumption Metrics:**
   - **Energy Consumption:**
     - Total energy used by devices, edge servers, and cloud server.
     - Energy consumption per iteration.
   - **Time Delays:**
     - Computation and communication time at each level.
     - Total time per global iteration.
   - **Bandwidth Usage:**
     - Amount of data transmitted during model updates.
     - Bandwidth used per iteration.

4. **Participation Metrics:**
   - **Active Devices per Iteration:**
     - Number of devices participating in each global iteration.
   - **Device Dropout Rate:**
     - Percentage of devices that fail or drop out over time.

5. **System Efficiency Metrics:**
   - **Training Speed:**
     - Time taken to reach specific accuracy thresholds.
   - **Communication Efficiency:**
     - Model accuracy achieved per unit of data transmitted.
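
Several of these metrics reduce to short helper functions. The definitions below are one reasonable reading of the list above, not fixed formulas from the deployment, and the accuracy and traffic numbers are made up for illustration.

```python
def rounds_to_accuracy(accuracies, target):
    """1-based index of the first round that reaches target, or None."""
    for i, acc in enumerate(accuracies, start=1):
        if acc >= target:
            return i
    return None

def communication_efficiency(accuracies, bytes_per_round):
    """Final accuracy divided by total megabytes transmitted."""
    return accuracies[-1] / (sum(bytes_per_round) / 1e6)

def dropout_rate(scheduled, completed):
    """Fraction of scheduled devices that did not finish the round."""
    return 1.0 - completed / scheduled if scheduled else 0.0

# Illustrative values only.
accuracies = [52.0, 68.5, 80.1, 85.3]
bytes_per_round = [2e6, 2e6, 2e6, 2e6]
print(rounds_to_accuracy(accuracies, 80.0))
print(communication_efficiency(accuracies, bytes_per_round))
print(dropout_rate(scheduled=20, completed=17))
```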

### **Plots:**

- **Line Plot:** Global accuracy over iterations.
- **Loss Curve:** Global and edge model loss over iterations.
- **Bar Chart:** Energy consumption per iteration.
- **Stacked Area Chart:** Bandwidth usage over time.
- **Scatter Plot:** Energy consumption vs. accuracy.
- **Box Plot:** Time delays across devices or edge servers.

### **Examples:**

1. **Global Accuracy Over Iterations:**

   <img src="https://example.com/global_accuracy.png" alt="Global Accuracy" width="400"/>

2. **Energy Consumption Bar Chart:**

   <img src="https://example.com/energy_consumption.png" alt="Energy Consumption" width="400"/>

3. **Bandwidth Usage Stacked Area Chart:**

   <img src="https://example.com/bandwidth_usage.png" alt="Bandwidth Usage" width="400"/>

4. **Participation Rate Line Plot:**

   <img src="https://example.com/participation_rate.png" alt="Participation Rate" width="400"/>

---

## **5. Additional Metrics and Comparisons**

### **Scalability Metrics:**

- **Performance vs. Number of Devices:**
  - Analyze how increasing the number of devices affects accuracy, training time, and resource consumption.

- **Communication Overhead:**
  - Total data transmitted as the network scales.

### **Comparative Analysis:**

- **Baseline Comparisons:**
  - Compare your system's metrics with baseline methods (e.g., centralized training, traditional FL without optimizations).

- **Impact of Device Characteristics:**
  - Correlate device-specific metrics (e.g., CPU power, memory) with local training performance.

### **Plots:**

- **Line Plot:** Accuracy vs. number of devices.
- **Bar Chart:** Communication overhead for different scaling scenarios.
- **Correlation Matrix:** Relationships between device characteristics and performance.

---

## **Implementing and Visualizing Metrics**

To effectively collect and visualize these metrics:

1. **Logging and Storage:**
   - Implement logging mechanisms at each step to record metrics.
   - Use structured formats (e.g., JSON, CSV) for easy data manipulation.

2. **Data Analysis Tools:**
   - Utilize data analysis libraries like **Pandas** for data manipulation.
   - Employ **Matplotlib** or **Seaborn** for plotting.
   - For interactive visualizations, consider **Plotly** or **Bokeh**.

3. **Reproducibility:**
   - Ensure that experiments are conducted under controlled settings to make fair comparisons.
   - Document hyperparameters and configurations for each run.

4. **Reporting:**
   - Summarize findings in reports or dashboards.
   - Highlight key improvements or insights gained from the metrics.
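
A small logger along these lines would cover the logging and storage step. `MetricsLogger`, its field names, and the file name are hypothetical; the sketch writes one JSON array per run.

```python
import json
import time
from pathlib import Path

class MetricsLogger:
    """Collects metric records in memory and writes them to a JSON file."""

    def __init__(self, path):
        self.path = Path(path)
        self.records = []

    def log(self, iteration, level, **metrics):
        """Record metrics for one iteration at one level (device, edge, or cloud)."""
        self.records.append(
            {"iteration": iteration, "level": level, "timestamp": time.time(), **metrics}
        )

    def save(self):
        self.path.write_text(json.dumps(self.records, indent=2))

logger = MetricsLogger("metrics_example.json")
logger.log(1, "device", device_id=3, energy_j=0.42, time_s=1.7)
logger.log(1, "edge", edge_id=0, bandwidth_mb=12.5)
print(json.dumps(logger.records[0], indent=2))
# Call logger.save() at the end of a run to write the file.
```

The resulting file loads directly into Pandas with `pd.read_json`, which fits the analysis tools listed above.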

---

## **Example Workflow for Generating Plots**

1. **Collect Metrics:**

   ```python
   # Example: Collecting global accuracy over iterations
   global_accuracies = federated_system.accuracies['global_iterations']
   iterations = list(range(1, len(global_accuracies) + 1))
   ```

2. **Plot Metrics:**

   ```python
   import matplotlib.pyplot as plt

   # Plot global accuracy over iterations
   fig, ax = plt.subplots(figsize=(10, 6))
   ax.plot(iterations, global_accuracies, marker='o')
   ax.set_title('Global Model Accuracy Over Iterations')
   ax.set_xlabel('Global Iteration')
   ax.set_ylabel('Accuracy (%)')
   ax.grid(True)
   ```

3. **Save Plots:**

   ```python
   # Save before showing: with most backends the figure is released once
   # plt.show() returns, so a later savefig() would write a blank image.
   fig.savefig('global_accuracy_over_iterations.png')
   plt.show()
   ```

4. **Analyze Results:**
   - Interpret the trends observed in the plots.
   - Identify any anomalies or unexpected behaviors.

---

## **Conclusion**

By systematically collecting and visualizing these metrics, you can gain deep insights into each component of your federated learning system. This allows you to:

- **Validate** the effectiveness of each step.
- **Identify** areas for improvement.
- **Demonstrate** the advantages of your approach over baseline methods.
- **Communicate** your findings to stakeholders through clear and informative visualizations.

Remember to tailor the metrics and plots to align with your specific objectives and the aspects of the system you wish to highlight.

### **Scenario-Based Explanation of Dynamic Device Assignment Using PPO in Federated Learning**

Imagine you are managing a large-scale **Federated Learning (FL)** system deployed across multiple **Edge Servers** and numerous **Devices** (like smartphones, IoT devices, etc.). Your goal is to efficiently distribute computational tasks to these edge servers and devices to optimize performance metrics such as **accuracy**, **energy consumption**, and **communication costs**. To achieve this, you employ **Deep Reinforcement Learning (DRL)** techniques, specifically **Proximal Policy Optimization (PPO)**, to dynamically assign clusters of devices to edge servers and further schedule individual devices within these clusters.

Let's walk through how each component of your system works together in this scenario.

---

### **1. Understanding the Environment: The Real-World Setup**

**Entities Involved:**
- **Devices:** Each device has unique characteristics like **bandwidth usage**, **energy consumption**, **memory**, and **CPU power**.
- **Clusters:** Devices are grouped into clusters based on certain criteria (e.g., geographical location, data similarity).
- **Edge Servers:** These servers handle computational tasks for the devices. Each edge server has a limited **capacity** in terms of **bandwidth** and **computational resources**.

**Objective:**
- **Cluster Assignment:** Assign each cluster of devices to an appropriate edge server such that the server's capacity is optimally utilized without overloading.
- **Device Scheduling:** Within each cluster, dynamically select which devices participate in each training round on the cluster's assigned edge server, to further optimize performance metrics.

---

### **2. Cluster Assignment with PPO: Assigning Clusters to Edge Servers**

**Environment Setup: `ClusterAssignmentEnv`**

- **Purpose:** This environment simulates the process of assigning clusters of devices to edge servers. The PPO agent learns to make optimal assignments based on the current state of the system.

**Key Components:**

1. **State Representation:**
   - **Normalized Bandwidth:** Represents the bandwidth requirements of each cluster, normalized to ensure numerical stability.
   - **Normalized Capacities:** Reflects the capacity of each edge server, also normalized.
   - **Unused Capacities:** Indicates how much capacity remains unused on each server after assignments.

2. **Action Space:**
   - **MultiDiscrete Actions:** Each action corresponds to assigning a cluster to one of the available edge servers. If there are `N` clusters and `M` edge servers, the action space is a vector of length `N`, where each element can take a value from `0` to `M-1` (representing the server IDs).

3. **Reward Function:**
   - **Overload Penalty:** Penalizes assignments that exceed an edge server's capacity.
   - **Load Variance Penalty:** Discourages imbalanced load distributions across servers.
   - **Resource Utilization Incentive:** Rewards assignments that maximize the usage of available resources without overloading.
   
   The reward is a weighted combination of these factors, encouraging the PPO agent to find assignments that are balanced, efficient, and within capacity constraints.
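
As a rough illustration of how the three reward terms might combine, here is a minimal sketch in plain Python. The `assignment` argument has the same shape as a `MultiDiscrete` action (one server index per cluster). The function name, the weights, and the exact form of each term are assumptions chosen for illustration; the actual `ClusterAssignmentEnv` may define them differently.

```python
from statistics import pvariance


def cluster_assignment_reward(assignment, cluster_bandwidths, server_capacities,
                              w_overload=1.0, w_variance=0.5, w_utilization=1.0):
    """Score one assignment; assignment[i] is the server index for cluster i."""
    loads = [0.0] * len(server_capacities)
    for cluster, server in enumerate(assignment):
        loads[server] += cluster_bandwidths[cluster]

    # Load on each server as a fraction of that server's capacity.
    ratios = [load / cap for load, cap in zip(loads, server_capacities)]

    # Overload penalty: total fraction by which servers exceed their capacity.
    overload = sum(max(0.0, r - 1.0) for r in ratios)
    # Load variance penalty: spread of the load fractions across servers.
    variance = pvariance(ratios)
    # Utilization incentive: mean load fraction, counting at most 100% per server.
    utilization = sum(min(r, 1.0) for r in ratios) / len(ratios)

    return w_utilization * utilization - w_overload * overload - w_variance * variance
```

With two clusters of bandwidth 5 and two servers of capacity 10, spreading the clusters (`[0, 1]`) scores higher than stacking both on one server (`[0, 0]`), because the stacked assignment incurs the variance penalty.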

**Training Process:**

- **Initialization:** The environment initializes with the current bandwidth requirements of clusters and the capacities of edge servers.
- **Learning:** The PPO agent interacts with this environment, making assignments and receiving rewards based on the effectiveness of those assignments.
- **Optimization:** Over multiple iterations (`timesteps`), the agent learns policies that maximize cumulative rewards, effectively learning to assign clusters to servers optimally.

**Deployment:**

Once trained, the **Cluster Assignment Agent** can predict the best cluster-to-server assignments for new or dynamic configurations of devices and servers.

---

### **3. Device Scheduling with PPO: Dynamically Assigning Devices Within Clusters**

**Environment Setup: `DeviceSchedulingEnv`**

- **Purpose:** After clusters are assigned to specific edge servers, this environment handles the selection of individual devices within those clusters for participation in training on the designated servers. The PPO agent optimizes device scheduling to further enhance performance metrics.

**Key Components:**

1. **State Representation:**
   - **Normalized Bandwidth Usage:** Represents the bandwidth consumption of each device.
   - **Normalized Energy Usage:** Reflects the energy consumption of each device.
   - **Diversity Scores:** Calculated using entropy to measure the heterogeneity of the data labels each device holds, normalized to the range `[0, 1]`.

2. **Action Space:**
   - **MultiBinary Actions:** Each action represents whether a device is selected (`1`) or not (`0`) for scheduling. The action space is a binary vector where each element corresponds to a device.

3. **Reward Function:**
   - **Accuracy Improvement:** Rewards the selection of devices that contribute to improving the global model's accuracy.
   - **Communication Cost Penalty:** Penalizes the total bandwidth used for communication.
   - **Energy Consumption Penalty:** Penalizes the total energy consumed by the selected devices.

   The reward function balances the benefits of selecting devices that enhance model accuracy against the costs associated with their bandwidth and energy usage.
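
The entropy-based diversity score and the cost-versus-benefit reward can be sketched as follows. This is a minimal illustration, not the `DeviceSchedulingEnv` implementation: the function names, the weights, and the choice to pass the accuracy gain in as a number are all assumptions.

```python
import math
from collections import Counter


def diversity_score(labels, num_classes):
    """Shannon entropy of a device's label distribution, scaled to [0, 1]."""
    if not labels or num_classes < 2:
        return 0.0
    total = len(labels)
    counts = Counter(labels)
    entropy = -sum((c / total) * math.log(c / total) for c in counts.values())
    # Dividing by log(num_classes) maps a uniform distribution to 1.0.
    return entropy / math.log(num_classes)


def scheduling_reward(selected, accuracy_gain, bandwidth, energy,
                      w_acc=1.0, w_comm=0.1, w_energy=0.1):
    """Reward for one MultiBinary action.

    selected:  0/1 flag per device
    bandwidth: normalized bandwidth usage per device
    energy:    normalized energy usage per device
    """
    comm_cost = sum(b for s, b in zip(selected, bandwidth) if s)
    energy_cost = sum(e for s, e in zip(selected, energy) if s)
    return w_acc * accuracy_gain - w_comm * comm_cost - w_energy * energy_cost
```

A device holding a single class scores 0, and a device holding all classes in equal proportion scores 1, which gives the agent a bounded signal for label heterogeneity.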

**Training Process:**

- **Initialization:** The environment is initialized with device characteristics and the static `cluster_to_server_map`.
- **Learning:** The PPO agent learns to select devices that offer the best trade-off between improving model accuracy and minimizing costs.
- **Optimization:** Through extensive training (`timesteps`), the agent fine-tunes its policies to make effective scheduling decisions.

**Deployment:**

The **Device Scheduling Agent** can dynamically decide which devices within each cluster should participate in training rounds, adapting to changes in device availability, energy levels, or network conditions.

---

### **4. Orchestrating Assignments: The Role of `MainAgent`**

**Class Overview: `MainAgent`**

- **Purpose:** Acts as the central coordinator that integrates both the Cluster Assignment and Device Scheduling agents. It manages the overall workflow, ensuring that clusters are assigned to edge servers and devices are scheduled appropriately before initiating the federated learning process.

**Key Functions:**

1. **Initialization:**
   - **Data Preparation:** Aggregates device bandwidth requirements and determines edge server capacities.
   - **Environment Creation:** Sets up both `ClusterAssignmentEnv` and `DeviceSchedulingEnv` with the necessary parameters.
   - **Agent Loading:** Loads the trained PPO agents for cluster assignment and device scheduling.

2. **Run Method:**
   - **Cluster Assignment:**
     - **State Reset:** Resets the cluster assignment environment to get the initial state.
     - **Action Prediction:** Uses the Cluster Assignment Agent to predict the best assignments of clusters to edge servers.
     - **DataFrame Update:** Updates the `devices_df` with the predicted `assigned_servers` based on cluster assignments.
   
   - **Device Scheduling:**
     - **State Reset:** Resets the device scheduling environment to get the initial state.
     - **Action Prediction:** Uses the Device Scheduling Agent to decide which devices to schedule for participation.
     - **DataFrame Update:** Updates the `devices_df` with the `is_scheduled` flag based on scheduling decisions.
   
   - **Federated Learning Integration:**
     - **Parameter Preparation:** Extracts necessary parameters for the federated learning system.
     - **System Initialization:** Initializes the `FederatedLearningSystem` with the updated device assignments and scheduling information.
     - **Learning Execution:** Initiates the federated learning process, which involves training global models based on the scheduled devices and assigned clusters.
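
The two "DataFrame Update" steps amount to broadcasting the cluster-level action to every device and attaching the per-device scheduling flag. The sketch below shows that logic with plain dictionaries standing in for rows of `devices_df`; the function name and the `cluster` and `assigned_server` keys are hypothetical, and only `is_scheduled` is taken from the description above.

```python
def apply_decisions(devices, cluster_action, schedule_action):
    """Return new device records with the agents' decisions attached.

    devices:         list of dicts, each with a "cluster" key
    cluster_action:  cluster_action[c] is the server index chosen for cluster c
    schedule_action: schedule_action[i] is 1 if device i participates, else 0
    """
    if len(schedule_action) != len(devices):
        raise ValueError("schedule_action needs one entry per device")

    # Same shape as the cluster-to-server map: {cluster id: server id}.
    cluster_to_server = dict(enumerate(cluster_action))

    updated = []
    for device, flag in zip(devices, schedule_action):
        record = dict(device)  # copy so the input list is left untouched
        record["assigned_server"] = cluster_to_server[device["cluster"]]
        record["is_scheduled"] = bool(flag)
        updated.append(record)
    return updated
```

In the real system the two action vectors would come from the trained agents' predictions, and the result would be passed on to `FederatedLearningSystem`.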

**Deployment:**

When `MainAgent.run()` is called, it seamlessly integrates the cluster and device assignments into the federated learning workflow, ensuring that each training round is optimized based on current system states and learned policies.

---

### **5. Deploying Dynamic Device Assignment: Step-by-Step Workflow**

**Step 1: Configuration and Initialization**

- **Configurations:** Define different setups for experiments, specifying parameters like the number of devices, edge servers, dataset names, and iteration counts.
- **Metrics Directory:** Create a dedicated folder to store performance metrics for analysis.

**Step 2: Data Distribution and Clustering**

- **Data Distributor (`GNNClustering`):**
  - **Distribute Data:** Allocates data to devices based on the specified dataset.
  - **Clustering Devices:** Groups devices into clusters using Graph Neural Networks (GNN) or other clustering methods.

**Step 3: Hybrid Data Redistribution**

- **Data Redistributor (`HybridDataRedistributor`):**
  - **Redistribute Data:** Adjusts data distribution among devices to balance workloads and ensure diversity, based on a percentage threshold.

**Step 4: Training the Cluster Assignment Agent**

- **Bandwidth and Capacity Calculation:**
  - **Cluster Bandwidth:** Sum of bandwidth requirements for each cluster.
  - **Edge Server Capacities:** Determined based on total bandwidth needed and individual server capabilities.
  
- **Agent Training (`train_cluster_assignment_agent`):**
  - **Environment Setup:** Initializes `ClusterAssignmentEnv` with current cluster bandwidths and edge server capacities.
  - **PPO Training:** Trains the PPO agent to learn optimal cluster-to-server assignments.
  - **Model Saving:** Saves the trained Cluster Assignment Agent for future use.
  
- **Cluster-to-Server Map Generation:**
  - **Prediction:** Uses the trained agent to predict the best assignments.
  - **Mapping:** Creates a dictionary mapping each cluster to its assigned edge server.

**Step 5: Training the Device Scheduling Agent**

- **Agent Training (`train_device_scheduling_agent`):**
  - **Environment Setup:** Initializes `DeviceSchedulingEnv` with device data and the static `cluster_to_server_map`.
  - **PPO Training:** Trains the PPO agent to learn optimal device scheduling within clusters.
  - **Model Saving:** Saves the trained Device Scheduling Agent for future use.

**Step 6: Orchestrating with MainAgent**

- **Initialization:**
  - **Agent Loading:** Loads both the Cluster Assignment and Device Scheduling agents.
  - **Environment Creation:** Sets up the respective environments with the necessary mappings.
  
- **Execution:**
  - **Cluster Assignment:** Uses the Cluster Assignment Agent to assign clusters to servers.
  - **Device Scheduling:** Uses the Device Scheduling Agent to schedule devices within clusters.
  - **Federated Learning Integration:** Passes the assignments and schedules to the `FederatedLearningSystem` and initiates training.

**Step 7: Federated Learning Execution**

- **FederatedLearningSystem:**
  - **Local Training:** Devices train local models based on assignments and schedules.
  - **Edge Aggregation:** Edge servers aggregate models from their assigned devices.
  - **Global Aggregation:** The cloud server aggregates models from all edge servers to update the global model.
  - **Evaluation:** Periodically evaluates the global model's accuracy and logs performance metrics.
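
The two aggregation levels can be sketched as a pair of weighted averages. This example assumes FedAvg-style weighting by sample count and represents each model as a flat list of floats; both are simplifying assumptions, and the names are hypothetical rather than taken from `FederatedLearningSystem`.

```python
def weighted_average(models, weights):
    """Average parameter vectors, weighting each by its sample count."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    size = len(models[0])
    return [sum(w * m[i] for m, w in zip(models, weights)) / total
            for i in range(size)]


def hierarchical_aggregate(edge_groups):
    """Aggregate device models per edge server, then across edge servers.

    edge_groups maps a server id to a list of (model, num_samples) pairs
    from the devices scheduled on that server.
    """
    edge_models, edge_weights = [], []
    for updates in edge_groups.values():
        if not updates:
            continue  # a server with no scheduled devices contributes nothing
        models = [model for model, _ in updates]
        counts = [count for _, count in updates]
        edge_models.append(weighted_average(models, counts))  # edge aggregation
        edge_weights.append(sum(counts))
    return weighted_average(edge_models, edge_weights)  # global aggregation
```

When each edge model is weighted by the total samples behind it, a single pass of edge-then-global averaging gives the same result as averaging all device models directly; the hierarchy pays off when edge servers run several local rounds between global rounds.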

---

### **6. Deploying Dynamic Device Assignment: Detailed Process**

Let's visualize how dynamic device assignment is deployed using the provided code through a comprehensive scenario.

**Scenario:**  
You are deploying an FL system for the **MNIST** dataset with **20 devices** distributed across **5 edge servers**. Each device has varying bandwidth and energy profiles. Your goal is to optimize the assignment of clusters and devices to edge servers to achieve high model accuracy while minimizing communication costs and energy consumption.

**Workflow Steps:**

1. **Setup and Configuration:**
   - **Define Configuration:** Specify parameters such as the number of devices, edge servers, dataset, and iteration counts.
   - **Initialize Metrics Directory:** Create a folder (`metrics/`) to store training metrics for each configuration.

2. **Data Distribution and Clustering:**
   - **Initialize `GNNClustering`:**
     - Distribute the MNIST dataset among 20 devices.
     - Cluster the devices into 5 clusters based on data characteristics using GNN-based clustering.
   - **Redistribute Data:**
     - Use `HybridDataRedistributor` to balance data distribution among devices, ensuring diversity and load balancing.

3. **Cluster Assignment:**
   - **Calculate Bandwidth and Capacities:**
     - Sum the bandwidth requirements for each cluster.
     - Determine edge server capacities based on total bandwidth needs and server capabilities.
   - **Train Cluster Assignment Agent:**
     - Initialize `ClusterAssignmentEnv` with current cluster bandwidths and server capacities.
     - Train the PPO agent using `train_cluster_assignment_agent`, allowing it to learn optimal cluster-to-server assignments over 10,000 timesteps.
     - Save the trained Cluster Assignment Agent for future predictions.
   - **Generate Cluster-to-Server Map:**
     - Use the trained agent to predict assignments.
     - Create a mapping dictionary that assigns each cluster to a specific edge server.

4. **Device Scheduling:**
   - **Train Device Scheduling Agent:**
     - Initialize `DeviceSchedulingEnv` with device data and the static `cluster_to_server_map`.
     - Train the PPO agent using `train_device_scheduling_agent` over 10,000 timesteps to learn optimal device scheduling within clusters.
     - Save the trained Device Scheduling Agent.

5. **Orchestrate Assignments with `MainAgent`:**
   - **Initialize `MainAgent`:**
     - Load both the Cluster Assignment and Device Scheduling agents.
     - Set up environments with the current device data and cluster-to-server mappings.
   - **Run Assignments and Federated Learning:**
     - **Cluster Assignment:** Assign clusters to edge servers using the Cluster Assignment Agent.
     - **Device Scheduling:** Schedule devices within clusters using the Device Scheduling Agent.
     - **Initialize Federated Learning:** Pass the assignments and schedules to `FederatedLearningSystem`.
     - **Execute Federated Learning:** Start the training process, where devices train local models, edge servers aggregate them, and the global model is updated iteratively.

6. **Federated Learning Execution:**
   - **Local Model Training:**
     - Each scheduled device trains its local model based on assigned data and device-specific configurations.
     - Track and log energy consumption and communication costs.
   - **Edge Server Aggregation:**
     - Edge servers aggregate models from their assigned devices, considering communication costs and energy metrics.
     - Perform multiple edge iterations before global aggregation.
   - **Global Model Aggregation:**
     - The cloud server aggregates models from all edge servers to update the global model.
     - Evaluate the global model's accuracy using the test dataset.
     - Log performance metrics and save summaries for analysis.

7. **Dynamic Adaptation:**
   - **Device Characteristics Update:**
     - After each global iteration, simulate changes in device energy and bandwidth profiles to reflect real-world dynamics.
   - **Reassignment:**
     - Re-run the Cluster Assignment and Device Scheduling agents to adapt to the updated device states.
     - Ensure that assignments remain optimal as device conditions evolve.
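
The adaptation loop in step 7 can be sketched as follows. The random drift model, the `energy` and `bandwidth` keys, and the `decide` and `train_round` callables are all placeholders introduced for illustration; in the real system `decide` would wrap the two trained agents and `train_round` would run one global iteration.

```python
import random


def drift_device_state(devices, rng, max_change=0.1):
    """Randomly perturb normalized energy and bandwidth, clamped to [0, 1]."""
    drifted = []
    for device in devices:
        record = dict(device)  # copy so earlier states are not mutated
        for key in ("energy", "bandwidth"):
            change = rng.uniform(-max_change, max_change)
            record[key] = min(1.0, max(0.0, record[key] + change))
        drifted.append(record)
    return drifted


def run_with_reassignment(devices, decide, train_round, num_iterations, seed=0):
    """Re-run the assignment decision before every global iteration."""
    rng = random.Random(seed)  # seeded so the simulated drift is repeatable
    history = []
    for _ in range(num_iterations):
        plan = decide(devices)                      # agents see the current state
        history.append(train_round(devices, plan))  # one global iteration
        devices = drift_device_state(devices, rng)  # simulate changing conditions
    return history
```

Seeding the random generator keeps the simulated dynamics identical across runs, which matters when comparing assignment strategies.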

---

### **7. Key Benefits of This Deployment Approach**

- **Optimized Resource Utilization:**
  - **Cluster Assignment:** Ensures that each edge server is assigned clusters in a way that maximizes resource utilization without overloading any server.
  - **Device Scheduling:** Dynamically selects devices that contribute most effectively to model training, balancing accuracy improvements against energy and communication costs.

- **Scalability and Flexibility:**
  - The system can easily scale to accommodate more devices or edge servers by adjusting configurations and retraining the PPO agents as needed.
  - Dynamic adaptation allows the system to respond to changes in device availability or network conditions in real time.

- **Enhanced Performance Metrics:**
  - By leveraging DRL, the system continuously learns and improves its assignment strategies, leading to higher model accuracies and lower operational costs over time.

- **Modular Design:**
  - The separation of cluster assignment and device scheduling into distinct environments and agents promotes modularity, making the system easier to maintain and extend.

---

### **8. Practical Considerations for Deployment**

**A. Training Considerations:**

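- **Training Duration:** The PPO agents here are trained for 10,000 timesteps each (`timesteps=10000`). If `n_steps=2048` is used, that amounts to only about five rollout-and-update cycles, so check that the reward curve has plateaued and raise the budget if it has not. Ensure sufficient computational resources and time.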
- **Environment Accuracy:** The state representations and reward functions must accurately reflect real-world conditions to guide the PPO agents effectively.
- **Reward Balancing:** Properly balance the reward components to ensure that agents prioritize critical metrics (e.g., accuracy over energy consumption).

**B. System Integration:**

- **Agent Synchronization:** Ensure that both Cluster Assignment and Device Scheduling agents are trained and loaded correctly before initiating federated learning.
- **Error Handling:** Implement robust error handling, especially in multi-threaded environments (e.g., during edge server training) to prevent failures from cascading.

**C. Real-Time Adaptation:**

- **Dynamic Updates:** The system should periodically reassess and update assignments based on changing device states to maintain optimal performance.
- **Feedback Loops:** Incorporate feedback mechanisms to continuously monitor performance metrics and trigger retraining or adjustments as necessary.

**D. Monitoring and Logging:**

- **Metrics Tracking:** Consistently log key metrics like model accuracy, energy consumption, communication costs, and assignment efficiencies for analysis and debugging.
- **Visualization:** Utilize dashboards or visualization tools to monitor the system's performance in real-time, facilitating quick responses to issues or inefficiencies.

---

### **9. Conclusion**

Deploying dynamic device assignment in a federated learning system using PPO-based DRL agents involves a well-coordinated interplay between cluster assignment and device scheduling. By leveraging the strengths of PPO in optimizing complex decision-making processes, your system can achieve efficient resource utilization, adapt to dynamic environments, and maintain high model accuracies while minimizing operational costs.

This scenario-based approach provides a comprehensive understanding of how each component functions and interacts within the system, ensuring that your federated learning deployment is both effective and resilient.

This section introduces **Proximal Policy Optimization (PPO)**, a prominent algorithm in **Deep Reinforcement Learning (DRL)**, and explains how it is integrated into your federated learning system. It covers the foundations of PPO and its specific implementation within your cluster assignment and device scheduling environments.

---

## **1. Understanding Proximal Policy Optimization (PPO)**

### **1.1. What is Proximal Policy Optimization (PPO)?**

**Proximal Policy Optimization (PPO)** is a **policy gradient** algorithm in reinforcement learning, introduced by OpenAI in 2017 as a simpler-to-implement alternative to Trust Region Policy Optimization (TRPO) that keeps much of TRPO's stability. PPO optimizes the policy (the strategy by which an agent selects actions) while limiting how far each update can move the policy away from the one that collected the data.

### **1.2. Key Features of PPO**

1. **Clipping Mechanism:**
   - PPO introduces a **clipping** strategy that restricts the policy updates to stay within a small, predefined range. This prevents the new policy from deviating excessively from the old policy, enhancing training stability.

2. **Surrogate Objective Function:**
   - PPO optimizes a **surrogate objective** that balances between improving the policy and maintaining proximity to the previous policy. This dual focus helps in efficient learning without large policy shifts.

3. **Sample Efficiency:**
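   - PPO is an on-policy algorithm, so each update uses data collected by the current policy. Unlike vanilla policy gradient methods, which take one gradient step per batch, PPO can reuse each batch for several epochs of minibatch updates, which improves sample efficiency. It generally remains less sample-efficient than off-policy methods such as DQN, which replay old experience.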

4. **Ease of Implementation:**
   - PPO is relatively straightforward to implement compared to more complex algorithms, while still achieving high performance across a variety of tasks.
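
To make the clipping mechanism concrete, the sketch below computes the clipped surrogate objective for a batch in plain Python. For each sample, the probability ratio between the new and old policy is clipped to `[1 - clip_range, 1 + clip_range]`, and the objective takes the smaller of the clipped and unclipped terms. The function name is an assumption; libraries such as Stable Baselines3 compute the same quantity on tensors.

```python
import math


def clipped_surrogate(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Mean PPO clipped objective over a batch (a quantity to maximize)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probs.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Taking the minimum removes any incentive to push the ratio outside
        # the clip range in the direction that would increase the objective.
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)
```

For a positive advantage, a ratio of 1.5 is credited only as 1.2 (with `clip_range=0.2`), so the gradient stops rewarding further movement in that direction. For a negative advantage the same ratio is not clipped, so the penalty for a bad move stays in full.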

### **1.3. Why Choose PPO?**

- **Stability:** The clipping mechanism ensures that policy updates remain within a "proximal" region, preventing unstable or divergent learning.
- **Performance:** PPO consistently delivers strong performance across diverse environments and tasks.
- **Flexibility:** It is versatile, working effectively with both discrete and continuous action spaces.

---

## **2. PPO in the Context of Your Federated Learning System**

Your federated learning system involves optimizing two critical components:

1. **Cluster Assignment:** Deciding which clusters of devices should be assigned to which edge servers.
2. **Device Scheduling:** Selecting which devices should participate in the training process at any given time.

Both of these components are formulated as reinforcement learning problems, where PPO serves as the underlying algorithm to learn optimal policies for decision-making.

### **2.1. Cluster Assignment with PPO**

- **Environment:** The `ClusterAssignmentEnv` simulates the scenario where clusters of devices need to be assigned to edge servers based on factors like bandwidth and server capacities.
  
- **Agent:** A PPO-based agent interacts with this environment to learn how to assign clusters to servers optimally. The agent observes the current state of server loads, bandwidth requirements, and capacities, then decides on assignments that maximize overall system efficiency.

- **Reward Function:** The reward is crafted to encourage:
  - **High Resource Utilization:** Maximizing the use of server capacities.
  - **Minimizing Overload:** Avoiding assignments that exceed server capacities.
  - **Balancing Load:** Ensuring that server loads are evenly distributed to prevent bottlenecks.

### **2.2. Device Scheduling with PPO**

- **Environment:** The `DeviceSchedulingEnv` models the process of selecting devices for participation in training. It considers device-specific metrics like bandwidth usage and energy consumption.

- **Agent:** Another PPO-based agent operates within this environment to determine which devices should be active in the training process, balancing the trade-off between improving model accuracy and conserving resources.

- **Reward Function:** The reward structure incentivizes:
  - **Accuracy Improvement:** Selecting more devices can lead to better model performance.
  - **Reducing Communication Costs:** Minimizing the bandwidth used for communication between devices and servers.
  - **Lowering Energy Consumption:** Choosing devices that consume less energy during training.

---

## **3. How PPO is Implemented in Your Code**

### **3.1. Initialization and Environment Setup**

- **Environment Instances:**
  - **ClusterAssignmentEnv:** Initialized with parameters like cluster bandwidths, edge server capacities, and device information. It defines the state and action spaces relevant to cluster assignments.
  
  - **DeviceSchedulingEnv:** Initialized with device data and the current mapping of clusters to servers. It defines the state and action spaces pertinent to device scheduling.

- **Action and Observation Spaces:**
  - **Cluster Assignment:** Uses a `MultiDiscrete` action space where each action corresponds to assigning a cluster to one of the available servers.
  
  - **Device Scheduling:** Utilizes a `MultiBinary` action space where each action determines whether a device is selected (`1`) or not (`0`).

### **3.2. Agent Configuration**

- **Policy Network:**
  - Both agents use a **Multi-Layer Perceptron (MLP)** policy, which is a type of neural network suitable for handling the state representations defined in their respective environments.

- **PPO Parameters:**
  - **Learning Rate:** Determines the step size during optimization. A typical starting point is `3e-4`.
  
  - **Entropy Coefficient (`ent_coef`):** Encourages exploration by penalizing certainty in action choices. A common value is `0.01`.
  
  - **Clipping Range (`clip_range`):** Controls the extent to which the policy can change in a single update. Values between `0.1` and `0.3` are typical; `0.2` is the value used in the original PPO paper and the Stable Baselines3 default.
  
  - **Batch Size and Steps (`batch_size`, `n_steps`):** Influence how data is sampled and processed during training. For instance, `batch_size=64` and `n_steps=2048` help in balancing computational efficiency and learning stability.
  
  - **Other Parameters:** Include `max_grad_norm` to prevent gradient explosion, and `gae_lambda` for Generalized Advantage Estimation, enhancing the quality of the advantage estimates used in updates.
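
The parameters above map onto keyword arguments of the Stable Baselines3 `PPO` constructor. The dictionary below is a sketch of one plausible configuration, not the values your code necessarily uses: `gae_lambda=0.95`, `max_grad_norm=0.5`, and `clip_range=0.2` are the library defaults, and the rest follow the typical values listed above.

```python
# Keyword arguments for stable_baselines3.PPO (illustrative values).
ppo_kwargs = {
    "policy": "MlpPolicy",   # multi-layer perceptron policy
    "learning_rate": 3e-4,
    "n_steps": 2048,         # environment steps collected per rollout
    "batch_size": 64,        # minibatch size for each gradient step
    "ent_coef": 0.01,        # entropy bonus weight (library default is 0.0)
    "clip_range": 0.2,       # Stable Baselines3 default
    "gae_lambda": 0.95,      # Stable Baselines3 default
    "max_grad_norm": 0.5,    # Stable Baselines3 default
}

# With a single environment, each rollout is split into this many minibatches
# per epoch. Keeping n_steps a multiple of batch_size avoids a truncated batch.
minibatches_per_epoch = ppo_kwargs["n_steps"] // ppo_kwargs["batch_size"]

# Usage, assuming stable-baselines3 is installed and `env` is a Gymnasium
# environment such as ClusterAssignmentEnv:
#
#   from stable_baselines3 import PPO
#   model = PPO(env=env, verbose=1, **ppo_kwargs)
#   model.learn(total_timesteps=10_000)
```

Keeping the hyperparameters in one dictionary also makes them easy to log alongside each run's metrics.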

### **3.3. Training Process**

- **Learning Loop:**
  - The agents undergo training over a specified number of timesteps (e.g., `10000`), during which they interact with their respective environments, collect experiences, and update their policies based on the PPO algorithm.

- **Callbacks and Scheduling:**
  - **Custom Callbacks:** Implemented to adjust learning rates dynamically and log training progress. These callbacks interact with the PPO agent at specific points in the training process, such as after a certain number of timesteps.
  
  - **Learning Rate Schedulers:** Employed within callbacks to modify the learning rate based on training progress, ensuring that the agent doesn't overshoot optimal policies as training progresses.
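
One way to decay the learning rate, shown here as an alternative sketch rather than a description of your callbacks, relies on the fact that Stable Baselines3 accepts a callable for `learning_rate`. The callable receives `progress_remaining`, which goes from `1.0` at the start of training to `0.0` at the end. The `linear_schedule` helper name is an assumption.

```python
def linear_schedule(initial_value, final_value=0.0):
    """Build a schedule that decays linearly from initial_value to final_value."""
    def schedule(progress_remaining):
        # progress_remaining is 1.0 at the first update and 0.0 at the last.
        return final_value + progress_remaining * (initial_value - final_value)
    return schedule


lr = linear_schedule(3e-4)

# Usage, assuming stable-baselines3 is installed:
#   model = PPO("MlpPolicy", env, learning_rate=linear_schedule(3e-4))
```

A schedule function keeps the decay logic out of the callback, leaving callbacks free to handle logging only.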

### **3.4. Evaluation and Deployment**

- **Model Evaluation:**
  - After training, the agents are evaluated by resetting their environments, predicting actions based on current states, and assessing the resulting performance using predefined metrics.
  
  - **Metrics Captured:** Include rewards, resource utilization, overload penalties, load variance, and accuracy improvements, providing a comprehensive view of agent performance.

- **Model Saving:**
  - Trained PPO models are saved for future deployment, allowing the federated learning system to utilize these agents for ongoing cluster assignments and device scheduling without retraining.

---

## **4. The Role of PPO in Enhancing Federated Learning**

### **4.1. Optimizing Resource Allocation**

- **Cluster Assignment:** PPO agents learn to allocate clusters to edge servers in a manner that maximizes resource utilization while preventing server overloads and ensuring balanced loads across servers.

- **Device Scheduling:** PPO agents determine the optimal set of devices to participate in training, balancing the trade-off between enhancing model accuracy and minimizing communication and energy costs.

### **4.2. Adaptive Decision-Making**

- PPO-trained policies let your system adapt to dynamic conditions, such as fluctuating server capacities, varying device performance, and changing network conditions. Re-running the trained agents on the updated state yields new assignments without retraining; if conditions drift far from those seen during training, the agents should be retrained.

### **4.3. Balancing Multiple Objectives**

- **Multi-Objective Rewards:** The reward structures for both cluster assignment and device scheduling encapsulate multiple objectives (e.g., utilization, overload penalties, variance penalties), allowing PPO to navigate the trade-offs and find balanced solutions.

---

## **5. Advantages of Using PPO in Your System**

1. **Stability and Reliability:**
   - The clipping mechanism in PPO ensures that policy updates are stable, preventing drastic changes that could destabilize the learning process.

2. **Sample Efficiency:**
   - PPO reuses each collected batch for several epochs, which makes it more sample-efficient than vanilla policy gradient methods. This helps when data collection is computationally expensive, although off-policy methods typically need fewer environment steps still.

3. **Scalability:**
   - PPO can handle large and complex state and action spaces, making it ideal for the intricate tasks of cluster assignment and device scheduling in a federated learning setup.

4. **Ease of Integration:**
   - Leveraging libraries like Stable Baselines3 allows for straightforward implementation and integration of PPO agents into your existing system.

---

## **6. Considerations and Best Practices**

### **6.1. Hyperparameter Tuning**

- **Importance:** The performance of PPO agents heavily depends on the chosen hyperparameters. Fine-tuning parameters like learning rate, clipping range, and entropy coefficient can lead to significant performance improvements.

- **Approach:** Use techniques like grid search, random search, or Bayesian optimization to explore different hyperparameter configurations systematically.
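
A grid search over a small hyperparameter space can be written with the standard library alone. In this sketch, `evaluate` is a placeholder for a function that trains an agent with the given configuration and returns a score such as mean episode reward; it and the helper names are assumptions.

```python
from itertools import product


def grid(search_space):
    """Yield every combination of values in search_space as a dict."""
    keys = list(search_space)
    for values in product(*(search_space[key] for key in keys)):
        yield dict(zip(keys, values))


def best_config(search_space, evaluate):
    """Return the configuration with the highest score from evaluate(config)."""
    return max(grid(search_space), key=evaluate)


search_space = {
    "learning_rate": [1e-4, 3e-4, 1e-3],
    "clip_range": [0.1, 0.2, 0.3],
    "ent_coef": [0.0, 0.01],
}
num_configs = sum(1 for _ in grid(search_space))  # 3 * 3 * 2 combinations
```

Grid search grows multiplicatively with each added parameter, so random search or Bayesian optimization is usually preferable once the space has more than a few dimensions.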

### **6.2. Enhanced State Representations**

- **Richer Features:** Incorporate additional contextual information into the state representations to provide agents with a more comprehensive understanding of the environment.

- **Temporal Information:** Including historical data or trends can help agents make more informed decisions based on past states.

### **6.3. Reward Shaping**

- **Balanced Rewards:** Ensure that the reward functions adequately balance the different objectives, preventing the agent from focusing excessively on one aspect at the expense of others.

- **Normalization:** Normalize rewards to keep learning dynamics consistent and to prevent poorly scaled rewards from destabilizing training.
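
One simple way to normalize rewards is to track a running mean and standard deviation and standardize each reward against them. The class below is a minimal sketch using Welford's online algorithm; the class name is an assumption. Stable Baselines3 provides a `VecNormalize` wrapper that serves a similar purpose.

```python
import math


class RunningRewardNormalizer:
    """Standardize rewards using a running mean and standard deviation."""

    def __init__(self, epsilon=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0          # running sum of squared deviations from the mean
        self.epsilon = epsilon

    def update(self, reward):
        """Fold one reward into the running statistics (Welford's algorithm)."""
        self.count += 1
        delta = reward - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (reward - self.mean)

    @property
    def std(self):
        """Population standard deviation; 1.0 until two rewards have been seen."""
        if self.count < 2:
            return 1.0
        return math.sqrt(self.m2 / self.count)

    def normalize(self, reward):
        # epsilon guards against division by zero when all rewards are equal.
        return (reward - self.mean) / (self.std + self.epsilon)
```

Because the statistics update incrementally, the normalizer can run inside an environment's `step` method without storing past rewards.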

### **6.4. Monitoring and Logging**

- **Progress Tracking:** Utilize tools like TensorBoard or custom logging to monitor training progress, rewards, and other key metrics in real-time.

- **Debugging:** Detailed logs help in diagnosing issues, understanding agent behaviors, and refining training processes.

### **6.5. Exploration vs. Exploitation**

- **Encouraging Exploration:** Adjust the entropy coefficient to strike a balance between exploration (trying new actions) and exploitation (leveraging known good actions).

- **Preventing Premature Convergence:** Ensure that the agent doesn't settle into suboptimal policies by promoting sufficient exploration during training.

---

## **7. Summary**

**Proximal Policy Optimization (PPO)** is a robust and efficient **Deep Reinforcement Learning (DRL)** algorithm well-suited for complex decision-making tasks like cluster assignment and device scheduling in a federated learning environment. By leveraging PPO's strengths—such as stability, sample efficiency, and scalability—your system can dynamically and optimally allocate resources, balance multiple objectives, and adapt to changing conditions.

In your implementation:

- **Two PPO Agents:** One for cluster assignment and another for device scheduling, each interacting with their respective environments (`ClusterAssignmentEnv` and `DeviceSchedulingEnv`).
  
- **Environment Interaction:** The agents observe the current state, take actions (assignments or scheduling decisions), receive rewards based on their actions' effectiveness, and iteratively refine their policies to maximize cumulative rewards.
  
- **Integration with Federated Learning:** The PPO agents' decisions directly influence the federated learning process, optimizing how clusters and devices are managed to enhance overall system performance and efficiency.

By adhering to best practices—such as hyperparameter tuning, enhancing state representations, careful reward shaping, and thorough monitoring—you can ensure that your PPO-based agents perform optimally, contributing significantly to the success of your federated learning system.

Switching from **Proximal Policy Optimization (PPO)** to a **Dueling Double Deep Q-Network (D3QN)** in your federated learning system's reinforcement learning (RL) framework is a significant architectural change. To determine whether this is a **good idea**, it's essential to understand the core differences between these algorithms, assess their suitability for your specific tasks, and consider the potential benefits and challenges involved in such a transition.

---

## **1. Understanding the Algorithms**

### **1.1. Proximal Policy Optimization (PPO)**

- **Type:** **Policy-Based** Deep Reinforcement Learning (DRL) algorithm.
- **Mechanism:** PPO directly optimizes the policy (the agent's decision-making strategy) by adjusting its parameters to maximize expected rewards. It uses a **clipping mechanism** to ensure that policy updates remain within a certain range, enhancing training stability.
- **Action Space Compatibility:** Handles both **discrete** and **continuous** action spaces effectively.
- **Strengths:**
  - **Stability:** Clipping prevents large, destabilizing policy updates.
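  - **Sample Efficiency:** Reuses each batch of collected experience for several gradient epochs, which makes it more sample-efficient than vanilla policy-gradient methods, although it remains on-policy.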
  - **Flexibility:** Suitable for complex environments with high-dimensional state spaces.
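The clipping mechanism can be illustrated for a single sample. This is a minimal sketch of the clipped surrogate objective, assuming a probability ratio (new policy over old policy) and an advantage estimate are already available; the function name and default `clip_range` are illustrative choices.

```python
def clipped_surrogate(ratio, advantage, clip_range=0.2):
    """PPO clipped objective for one sample (a quantity to be maximized).

    ratio     : pi_new(a|s) / pi_old(a|s)
    advantage : estimated advantage of the action taken
    """
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    # Taking the minimum keeps the pessimistic estimate, so the policy gains
    # nothing from moving the ratio outside [1 - clip_range, 1 + clip_range].
    return min(ratio * advantage, clipped_ratio * advantage)


# A large ratio with a positive advantage is capped at (1 + clip_range) * advantage.
print(clipped_surrogate(1.5, 2.0))
# A large ratio with a negative advantage is not clipped: the penalty is kept in full.
print(clipped_surrogate(1.5, -1.0))
```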

### **1.2. Dueling Double Deep Q-Network (D3QN)**

- **Type:** **Value-Based** Deep Reinforcement Learning (DRL) algorithm.
- **Mechanism:** Combines two enhancements to the standard Deep Q-Network (DQN), and is often paired with a third:
  - **Double DQN:** Mitigates overestimation bias in action-value estimates by decoupling action selection from evaluation.
  - **Dueling Network Architecture:** Separates the estimation of **state-value** and **advantage** functions, allowing the network to learn which states are (or are not) valuable without having to learn the effect of each action for each state.
  - **Prioritized Experience Replay (optional):** Samples transitions with larger temporal-difference error more frequently. It is not part of D3QN itself.
- **Action Space Compatibility:** Primarily designed for **discrete** action spaces.
- **Strengths:**
  - **Efficiency:** Often faster to train on discrete tasks with well-defined action spaces.
  - **Bias Reduction:** Double DQN reduces overestimation of Q-values.
  - **State Evaluation:** Dueling architecture enhances the ability to evaluate state values independently of actions.
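The two core ideas can be written down in a few lines. The sketch below uses plain Python lists in place of network outputs, so the function names and inputs are illustrative assumptions rather than an actual D3QN implementation.

```python
def dueling_q_values(state_value, advantages):
    """Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean(A(s, .)).

    Subtracting the mean advantage keeps V and A identifiable.
    """
    mean_adv = sum(advantages) / len(advantages)
    return [state_value + a - mean_adv for a in advantages]


def double_dqn_target(reward, done, next_q_online, next_q_target, gamma=0.99):
    """Double DQN target: the online network selects the next action,
    and the target network evaluates it."""
    if done:
        return reward
    best_action = max(range(len(next_q_online)), key=next_q_online.__getitem__)
    return reward + gamma * next_q_target[best_action]


q_values = dueling_q_values(1.0, [1.0, 2.0, 3.0])
target = double_dqn_target(1.0, False, [0.1, 0.9], [5.0, 2.0], gamma=0.5)
```

In the second call the online network prefers action 1, so the target uses the target network's value for action 1 (2.0) rather than its own maximum (5.0). This decoupling is what reduces overestimation.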

---

## **2. Comparing PPO and D3QN for Your Use Case**

Your federated learning system utilizes two primary RL tasks:

1. **Cluster Assignment:** Assigning clusters of devices to edge servers.
2. **Device Scheduling:** Selecting which devices participate in training.

Both tasks involve **discrete action spaces**:
- **Cluster Assignment:** MultiDiscrete action space where each cluster is assigned to one of the available servers.
- **Device Scheduling:** MultiBinary action space where each device is either selected or not.

### **2.1. Suitability of PPO**

- **Pros:**
  - **Handles Complex Action Spaces:** PPO can manage MultiDiscrete and MultiBinary action spaces more naturally.
  - **Stability in Training:** The clipping mechanism ensures stable policy updates, which is beneficial for environments with multiple interacting actions.
  - **Policy Flexibility:** Directly models the policy, allowing for more nuanced decision-making in complex scenarios.

- **Cons:**
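  - **Sample Inefficiency:** As an on-policy method, PPO discards experience after each update, so it generally needs more environment samples to converge than off-policy, value-based methods like DQN variants, which reuse transitions from a replay buffer.
  - **Computational Overhead:** Every update requires freshly collected rollouts from the current policy, which can make training more expensive when environment steps are costly.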

### **2.2. Suitability of D3QN**

- **Pros:**
  - **Efficiency in Discrete Actions:** Excels in environments with discrete, well-defined actions.
  - **Bias Reduction:** Double DQN reduces the overestimation of action values, leading to more accurate learning.
  - **Enhanced State Evaluation:** Dueling architecture allows the network to distinguish between the value of being in a state and the advantages of actions, potentially improving decision-making.

- **Cons:**
  - **Complexity with MultiDiscrete/MultiBinary Actions:** D3QN is primarily designed for single discrete actions. Extending it to handle MultiDiscrete or MultiBinary actions can be non-trivial and may require architectural modifications.
  - **Less Flexibility:** Value-based methods derive a greedy policy from Q-values and need a Q-value for every action, so they can struggle when tasks call for stochastic policies or have large combinatorial action spaces.
  - **Exploration Challenges:** DQN variants rely heavily on exploration strategies (like ε-greedy), which might be less effective in complex action spaces compared to policy-based methods.

---

## **3. Practical Considerations for Switching to D3QN**

### **3.1. Action Space Handling**

- **MultiDiscrete Actions:**
  - **Challenge:** D3QN naturally handles single discrete actions. MultiDiscrete actions involve selecting multiple discrete actions simultaneously, which complicates the action-value estimation.
  - **Potential Solutions:**
    - **Independent D3QNs:** Train separate D3QN agents for each component of the MultiDiscrete action space. However, this can lead to increased computational resources and coordination challenges.
    - **Joint Action Representation:** Represent the entire MultiDiscrete action as a single composite action. The number of composite actions grows exponentially with the number of clusters (S^C joint actions for C clusters and S servers), making learning more difficult and resource-intensive.

- **MultiBinary Actions:**
  - **Challenge:** Selecting multiple binary actions simultaneously can be approached by treating each binary decision independently, but D3QN isn't inherently designed for this.
  - **Potential Solutions:**
    - **Separate Outputs:** Modify the D3QN architecture to output multiple Q-values corresponding to each binary decision. However, this deviates from the standard D3QN implementation and may require significant customization.
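A quick calculation shows how the two options differ in scale. The sketch below uses the default sizes from this notebook's `GNNClustering` class (5 clusters, 5 edge servers, 20 devices); the helper functions themselves are hypothetical and only illustrate the counting and the index mapping a joint representation would need.

```python
def joint_action_count(num_clusters, num_servers):
    """Number of composite actions when every assignment is one discrete action."""
    return num_servers ** num_clusters


def factored_output_count(num_clusters, num_servers):
    """Number of outputs when each cluster gets its own head over the servers."""
    return num_clusters * num_servers


def encode_joint_action(assignment, num_servers):
    """Map a per-cluster server assignment to a single index (mixed radix)."""
    index = 0
    for server in assignment:
        index = index * num_servers + server
    return index


def decode_joint_action(index, num_clusters, num_servers):
    """Inverse of encode_joint_action."""
    assignment = []
    for _ in range(num_clusters):
        assignment.append(index % num_servers)
        index //= num_servers
    return assignment[::-1]


print(joint_action_count(5, 5))     # composite cluster-assignment actions
print(factored_output_count(5, 5))  # outputs with one head per cluster
print(2 ** 20)                      # composite MultiBinary actions for 20 devices
```

Even at this small scale the joint representation needs thousands of Q-values for cluster assignment and over a million for device scheduling, while a factored output grows only linearly.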

### **3.2. Network Architecture Modifications**

- **Dueling Architecture Integration:** Ensure that the dueling components (separating state-value and advantage functions) are appropriately integrated to handle the complexity of your tasks.
  
- **Scalability:** The network must efficiently scale with the number of clusters and devices, avoiding exponential growth in parameters.

### **3.3. Training Stability and Sample Efficiency**

- **PPO's Stability vs. D3QN's Efficiency:** While D3QN might train faster on discrete actions, PPO offers more stability, especially in complex environments with multiple interacting decisions.

- **Exploration Strategies:** Implement effective exploration mechanisms to ensure that the agent adequately explores the multi-action space, which is more challenging with D3QN.

### **3.4. Computational Resources**

- **Increased Complexity:** Handling MultiDiscrete or MultiBinary actions with D3QN can significantly increase the computational burden, requiring more memory and processing power.

- **Training Time:** Due to the larger action spaces and potential architectural modifications, training time may increase compared to PPO.

---

## **4. Recommendations**

### **4.1. Assess the Complexity of Action Spaces**

Given that your tasks involve MultiDiscrete and MultiBinary action spaces, **PPO remains a more suitable choice** due to its inherent ability to handle complex and multi-dimensional actions without requiring extensive architectural changes.

### **4.2. Consider Hybrid or Alternative Approaches**

If you're inclined to explore value-based methods like D3QN, consider the following:

- **Hybrid Models:** Combine policy-based and value-based methods to leverage the strengths of both. For instance, using PPO for policy updates while incorporating value estimates for certain decisions.

- **Actor-Critic Architectures:** These architectures, which PPO already utilizes, can benefit from incorporating aspects of D3QN's value estimation enhancements.

### **4.3. Incremental Testing**

Before fully transitioning to D3QN, implement a smaller-scale version to test its effectiveness in handling your specific action spaces. This approach allows you to identify and address challenges without overhauling the entire system.

### **4.4. Leverage Existing Libraries and Frameworks**

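Utilize RL libraries like **Stable Baselines3**, which offers a robust implementation of PPO. Note that its built-in DQN is the vanilla algorithm for single `Discrete` action spaces, without Double or Dueling extensions, so a D3QN agent would need a custom implementation or another library. Starting from a maintained library still makes experimentation easier and helps ensure that best practices are followed.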

### **4.5. Prioritize Stability and Performance**

In federated learning environments, **training stability** and **reliability** are paramount. PPO's design inherently supports these aspects, making it a strong candidate for your tasks.

---

## **5. Conclusion**

While **Dueling Double Deep Q-Network (D3QN)** is a powerful algorithm for **value-based** reinforcement learning, its applicability to your federated learning system's **multi-dimensional discrete action spaces** is limited and would require significant modifications. On the other hand, **Proximal Policy Optimization (PPO)**, being a **policy-based** method, naturally accommodates complex action spaces and offers stability and flexibility that align well with the demands of cluster assignment and device scheduling tasks.

**Therefore, continuing with PPO is likely the more effective and efficient choice** for your current setup. However, if you have specific reasons or requirements that you believe D3QN can better address, consider experimenting with hybrid approaches or exploring advanced architectures that can bridge the gap between value-based and policy-based methods.

Should you decide to proceed with D3QN or any other alternative, ensure thorough testing and validation to assess its performance and stability within your federated learning framework.

Absolutely, your proposed strategy is both **possible** and **doable**. Leveraging **Proximal Policy Optimization (PPO)** within a **Deep Reinforcement Learning (DRL)** framework to assign clusters of devices—each focusing on specific data labels—to corresponding edge servers can significantly enhance the performance and accuracy of your federated learning system. By tailoring each edge server's model to specialize in particular labels and subsequently aggregating these specialized models, you can achieve more nuanced and accurate global models. Below, I provide a comprehensive overview of how this can be accomplished, the benefits it offers, potential challenges, and strategic implementation guidelines.

---

## **1. Conceptual Overview**

### **1.1. Specialized Edge Server Models**

- **Cluster Specialization:** Each cluster of devices is inherently focused on specific data labels, leading to models trained on diverse but targeted datasets.
  
- **Edge Server Assignment:** Assigning these clusters to corresponding edge servers allows each server to train a model specialized in certain labels, enhancing the model's performance on those specific tasks.

### **1.2. Model Aggregation for Enhanced Global Performance**

- **Aggregated Learning:** After specialized training, aggregating these models can combine their strengths, resulting in a comprehensive global model that benefits from specialized expertise across various labels.

- **Improved Accuracy:** This approach can lead to improved accuracy, as each specialized model contributes its refined knowledge to the global model, ensuring robust performance across all labels.

---

## **2. Feasibility and Doability**

### **2.1. Feasibility**

- **Technical Viability:** Modern federated learning frameworks support heterogeneous model training and aggregation, making the implementation of specialized models feasible.

- **DRL Applicability:** PPO, as a DRL algorithm, is well-suited for complex decision-making tasks like dynamic cluster-to-server assignments, especially when considering multiple objectives such as diversity, load balancing, and resource optimization.

### **2.2. Doability**

- **Existing Infrastructure:** If your current system already employs PPO for cluster assignments, extending it to incorporate label-specific assignments is a logical and manageable progression.

- **Scalability:** This approach scales well with an increasing number of clusters and edge servers, provided that the assignment strategy remains efficient and the aggregation process is optimized.

---

## **3. Benefits of Specialized Assignments**

### **3.1. Enhanced Model Performance**

- **Targeted Learning:** Specialized models can achieve higher accuracy on their respective labels due to focused training, leading to better performance on specific tasks.

- **Reduced Overfitting:** By concentrating on specific labels, models can generalize better within their domain, reducing the risk of overfitting to irrelevant data.

### **3.2. Efficient Resource Utilization**

- **Balanced Loads:** Assigning clusters based on their resource requirements ensures that edge servers are optimally utilized, preventing scenarios where some servers are overburdened while others are underutilized.

- **Scalability:** As the number of devices and clusters grows, this approach facilitates efficient scaling by maintaining balanced workloads across servers.

### **3.3. Improved Data Privacy and Security**

- **Data Segmentation:** Clustering devices by label keeps each label's training traffic on a smaller set of edge servers, which can simplify access control. Clustering alone is not a privacy mechanism, so it should complement, not replace, protections such as encryption.

---

## **4. Strategic Implementation Guidelines**

To effectively implement this strategy, consider the following steps and best practices:

### **4.1. Define Clear Objectives and Metrics**

- **Performance Metrics:** Establish metrics to evaluate both individual edge server model performance and the aggregated global model's accuracy.

- **Diversity Metrics:** Implement metrics that quantify data diversity within each server, ensuring that specialization does not lead to skewed or imbalanced learning.
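One possible diversity metric is the normalized Shannon entropy of the labels held by a server. The sketch below is an assumption about how such a metric could look, not the metric this notebook uses; `label_entropy` is a hypothetical name. The notebook already imports `scipy.stats.entropy`, which could replace the manual sum.

```python
import math
from collections import Counter


def label_entropy(labels, num_classes):
    """Normalized Shannon entropy of a label list, in [0, 1].

    0 means a single label is present; 1 means all classes are equally frequent.
    """
    if not labels or num_classes < 2:
        return 0.0
    total = len(labels)
    counts = Counter(labels)
    h = -sum((c / total) * math.log(c / total) for c in counts.values())
    return h / math.log(num_classes)


single_label = [7] * 100            # fully specialized server
balanced = list(range(10)) * 10     # every class equally represented
skewed = [0] * 90 + [1] * 10        # mostly one class

print(label_entropy(single_label, 10), label_entropy(skewed, 10), label_entropy(balanced, 10))
```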

### **4.2. Enhance the PPO-Based Cluster Assignment Agent**

- **State Representation:**
  - **Incorporate Label Distribution:** Extend the state to include information about the label distribution within each cluster. This allows the agent to consider data diversity during assignments.
  
  - **Resource Metrics:** Include additional resource-related features such as CPU/GPU availability, memory usage, and network bandwidth of edge servers.

- **Action Space Adjustment:**
  - **Cluster-to-Server Mapping:** Ensure that the action space accurately represents possible mappings of clusters to edge servers, considering both diversity and resource constraints.
  
  - **Hierarchical Actions:** If necessary, implement a hierarchical action space where higher-level actions determine cluster groupings before assigning them to servers.

- **Reward Function Refinement:**
  - **Diversity Incentives:** Modify the reward function to include positive rewards for maintaining high data diversity within edge servers.
  
  - **Penalize Overlaps:** Introduce penalties for assigning multiple clusters with similar label distributions to the same server, preventing redundancy and promoting diversity.

### **4.3. Implement Clustering Constraints**

- **Constraint Enforcement:**
  - **Action Masking:** Utilize action masking within the environment to prevent the agent from selecting assignments that violate clustering constraints (e.g., assigning multiple similar clusters to a single server).
  
  - **Post-Assignment Validation:** After an assignment action, validate the cluster-to-server mapping and adjust rewards or enforce corrections if constraints are violated.

- **Constraint-Based Reward Shaping:**
  - **Strong Penalties:** Assign substantial negative rewards for violating constraints to discourage the agent from making such assignments in the future.
  
  - **Balanced Rewards:** Weight the rewards for meeting objectives (e.g., diversity) so that they remain a meaningful signal alongside the penalties; if penalties dominate, the agent may learn only to avoid violations rather than to optimize the objectives.
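Action masking can be sketched for the example constraint above, where similar clusters should not share a server. The code assumes each cluster is summarized by a single dominant label, which is a simplification introduced here for illustration; both helper names are hypothetical. Note that the standard `PPO` in Stable Baselines3 does not consume masks, whereas `MaskablePPO` from the `sb3-contrib` package is designed for that.

```python
def build_server_mask(cluster_label, server_labels):
    """Return one boolean per server: True if the cluster may be assigned there.

    server_labels[s] is the set of dominant labels already hosted on server s.
    """
    return [cluster_label not in hosted for hosted in server_labels]


def masked_argmax(scores, mask):
    """Pick the highest-scoring action among those the mask allows."""
    allowed = [i for i, ok in enumerate(mask) if ok]
    if not allowed:
        raise ValueError("no valid action under the current mask")
    return max(allowed, key=lambda i: scores[i])


hosted = [{3, 1}, set(), {2}]          # labels already placed on servers 0, 1, 2
mask = build_server_mask(3, hosted)    # server 0 already hosts label 3
choice = masked_argmax([0.9, 0.2, 0.5], mask)
```

Server 0 has the highest score but is masked out, so the agent falls back to the best remaining server. Masking removes invalid choices entirely, while penalty-based shaping only discourages them.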

### **4.4. Optimize the Aggregation Process**

- **Model Aggregation Strategy:**
  - **Federated Averaging:** Utilize Federated Averaging (FedAvg) or more sophisticated aggregation methods that can effectively combine specialized models into a robust global model.
  
  - **Weighted Aggregation:** Assign weights to each edge server's model based on factors like data volume, diversity, and performance metrics to ensure that more informative models have a greater influence on the global model.

- **Consistency Across Models:**
  - **Uniform Architecture:** Ensure that all edge server models share the same architecture to facilitate seamless aggregation.
  
  - **Regular Synchronization:** Schedule periodic synchronization and aggregation steps to integrate learning from specialized models into the global model efficiently.
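Weighted aggregation of same-architecture models reduces to a weighted average of their parameters. The sketch below represents each model as a dictionary of flat parameter lists, which is an illustrative stand-in for real model state; `weighted_average` is a hypothetical helper. Using each server's sample count as its weight corresponds to the usual FedAvg weighting.

```python
def weighted_average(models, weights):
    """Average parameter dictionaries, weighting each model by `weights`.

    models  : list of dicts mapping parameter name -> list of floats
    weights : one non-negative number per model (e.g., its sample count)
    """
    if len(models) != len(weights):
        raise ValueError("need exactly one weight per model")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    averaged = {}
    for name in models[0]:
        size = len(models[0][name])
        averaged[name] = [
            sum(w * m[name][i] for m, w in zip(models, weights)) / total
            for i in range(size)
        ]
    return averaged


server_models = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
sample_counts = [1, 3]  # the second server trained on three times as much data
global_model = weighted_average(server_models, sample_counts)
```

The second server contributes three quarters of the result, so the global parameters land closer to its values. Other factors mentioned above, such as diversity or validation performance, could be folded into the weights in the same way.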

### **4.5. Monitor and Iterate**

- **Continuous Monitoring:** Implement monitoring tools to track assignment decisions, model performance, and aggregation outcomes in real-time.
  
- **Iterative Refinement:** Use insights from monitoring to iteratively refine the state representation, action space, and reward function, enhancing the agent's decision-making capabilities over time.

- **Feedback Loops:** Incorporate feedback mechanisms where the performance of the global model informs adjustments in cluster assignments, fostering a dynamic and responsive learning environment.

---

## **5. Potential Challenges and Mitigation Strategies**

### **5.1. Balancing Specialization and Generalization**

- **Challenge:** While specialization enhances performance on specific labels, it may lead to models that lack generalization across the entire data distribution.

- **Mitigation:**
  - **Controlled Specialization:** Limit the degree of specialization to ensure that each model retains sufficient generalization capabilities.
  
  - **Diverse Aggregation:** Ensure that the aggregation process effectively integrates specialized knowledge without sacrificing the ability to generalize.

### **5.2. Managing Increased Complexity**

- **Challenge:** Introducing clustering constraints and specialized assignments increases the complexity of the RL environment and the agent's decision-making process.

- **Mitigation:**
  - **Incremental Implementation:** Gradually introduce complexity, starting with basic clustering constraints and progressively adding more nuanced rules.
  
  - **Simplified Models:** Begin with simpler models and environments to ensure stability before scaling up to more complex scenarios.

### **5.3. Computational Overhead**

- **Challenge:** Specialized assignments and enhanced state representations may require additional computational resources.

- **Mitigation:**
  - **Efficient Computations:** Optimize the algorithms that calculate diversity metrics and enforce constraints so that they add little overhead to each environment step.
  
  - **Resource Allocation:** Ensure that edge servers have adequate computational resources to handle specialized training tasks without bottlenecks.

### **5.4. Data Privacy and Security**

- **Challenge:** Managing specialized clusters may introduce new vectors for data privacy and security vulnerabilities.

- **Mitigation:**
  - **Robust Security Protocols:** Implement strong encryption and access controls to protect data across all clusters and edge servers.
  
  - **Compliance Checks:** Ensure that data handling practices comply with relevant regulations and standards.

---

## **6. Validation and Evaluation**

### **6.1. Comprehensive Testing**

- **Unit Testing:** Validate individual components, such as diversity metric calculations and constraint enforcement mechanisms, to ensure they function as intended.

- **Integration Testing:** Test the entire pipeline—from cluster assignment to model aggregation—to identify and address any integration issues.

### **6.2. Performance Metrics**

- **Edge Server Performance:**
  - **Accuracy per Server:** Measure the accuracy of each specialized edge server model on its designated labels.
  
  - **Resource Utilization:** Track CPU, memory, and bandwidth usage to ensure balanced resource allocation.

- **Global Model Performance:**
  - **Aggregated Accuracy:** Evaluate the global model's accuracy across all labels to assess the effectiveness of the aggregation process.
  
  - **Generalization Capability:** Test the global model on diverse datasets to ensure it generalizes well beyond specialized training.

### **6.3. Iterative Refinement**

- **Analyze Results:** Use performance metrics to identify areas where specialization and aggregation can be improved.

- **Adjust Strategies:** Modify state representations, reward functions, or aggregation methods based on empirical results to enhance overall system performance.

---

## **7. Best Practices for Implementation**

### **7.1. Robust Environment Design**

- **Comprehensive State Information:** Ensure that the state representation captures all necessary aspects for informed decision-making, including diversity metrics and resource states.

- **Effective Constraint Enforcement:** Implement strict and efficient mechanisms to enforce clustering constraints, preventing the agent from making detrimental assignments.

### **7.2. Efficient Learning and Training**

- **Hyperparameter Tuning:** Systematically tune PPO hyperparameters (e.g., learning rate, batch size, clipping range) to optimize learning performance.

- **Regular Evaluation:** Conduct frequent evaluations during training to monitor progress and make timely adjustments.

### **7.3. Scalability and Flexibility**

- **Modular Design:** Design the system in a modular fashion, allowing for easy scaling and adjustments as the number of clusters and edge servers grows.

- **Adaptability:** Ensure that the system can adapt to changes in device distributions, server capacities, or data label distributions without requiring extensive reconfiguration.

### **7.4. Comprehensive Monitoring and Logging**

- **Detailed Logs:** Maintain detailed logs of assignment decisions, model performances, and training metrics to facilitate debugging and performance analysis.

- **Visualization Tools:** Utilize visualization tools like TensorBoard to gain real-time insights into training dynamics and agent behaviors.

---

## **8. Conclusion**

Your strategy to assign clusters—each focusing on specific data labels—to corresponding edge servers using PPO agents is both innovative and promising. By implementing this approach, you can achieve:

- **Targeted Model Training:** Specialized models that excel in particular label domains, enhancing overall system accuracy.

- **Balanced Resource Utilization:** Efficient distribution of workloads across edge servers, preventing bottlenecks and ensuring smooth operation.

- **Robust Global Models:** Aggregated models that benefit from the specialized knowledge of individual edge servers, leading to improved generalization and performance.

To successfully implement this strategy, focus on enhancing your PPO agents with comprehensive state representations, refined reward functions, and stringent constraint enforcement. Coupled with efficient model aggregation and continuous monitoring, this approach can significantly elevate the effectiveness and reliability of your federated learning system.

**Next Steps:**

1. **Prototype Implementation:** Start by implementing the proposed enhancements in a controlled environment to observe their effects.

2. **Iterative Testing:** Gradually introduce complexity, monitor outcomes, and refine strategies based on empirical data.

3. **Scale Up:** Once validated, scale the approach to handle larger clusters and more edge servers, ensuring that the system remains efficient and effective.

4. **Continuous Optimization:** Continuously optimize hyperparameters, reward structures, and aggregation methods to adapt to evolving data distributions and system requirements.

By following this comprehensive strategy, you position your federated learning system to harness the full potential of specialized model training and aggregation, ultimately achieving superior accuracy and performance across diverse data labels.

# Libraries

In [None]:
# %pip install torch==2.2.0 pyomo scikit-learn torch_geometric torchvision gym gymnasium stable_baselines3 'shimmy>=2.0' 
# %pip install gym gymnasium stable_baselines3 'shimmy>=2.0'

In [None]:
import os
import gc
import ast
import sys
import json
import math
import time
import glob
import random
import psutil
import asyncio
import threading
import numpy as np
import pandas as pd
import nest_asyncio
import seaborn as sns
from copy import deepcopy
import pyomo.environ as pyo
from scipy.stats import entropy
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

from torch_geometric.data import Data
import torchvision.transforms as transforms

from sklearn.utils import shuffle
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import (precision_score, recall_score, f1_score, silhouette_score, calinski_harabasz_score, davies_bouldin_score, 
                            adjusted_rand_score, normalized_mutual_info_score)
from sklearn.neighbors import NearestNeighbors, kneighbors_graph
from sklearn.model_selection import train_test_split

from cryptography.fernet import Fernet
import zlib

import gym
from gym import spaces
from stable_baselines3 import PPO

In [None]:
torch.cuda.is_available()

# GNN-Kmeans Clustering

In [None]:
class LocalDataset:
    def __init__(self):
        self.images = np.array([], dtype=np.float32)
        self.labels = np.array([], dtype=np.int64)

    def store_data(self, images, labels):
        """
        Store images and labels in the local dataset.
        Overwrites any existing data.
        """
        self.images = np.array(images)
        self.labels = np.array(labels)

    def get_data(self):
        """Return the images and labels."""
        return self.images, self.labels

    def count_labels(self):
        """Count the occurrences of each label in the dataset."""
        if len(self.labels) == 0:
            return {}
        label_counts = dict(Counter(self.labels))
        return label_counts  

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.labels)

    def __repr__(self):
        """String representation of the dataset."""
        return f"LocalDataset(num_samples={len(self)}, images_shape={self.images.shape}, labels_shape={self.labels.shape})"

    def remove_samples(self, label, count):
        """Remove samples of a specific label from the dataset."""
        indices = np.where(self.labels == label)[0]
        if len(indices) > 0:
            np.random.shuffle(indices)
            to_remove_indices = indices[:count]
            self.images = np.delete(self.images, to_remove_indices, axis=0)
            self.labels = np.delete(self.labels, to_remove_indices)
            # Optionally, update disk usage or other metrics here.

    def add_samples(self, new_images, new_labels):
        """Add new samples to the dataset."""
        new_images = np.asarray(new_images)
        new_labels = np.asarray(new_labels)

        # An empty dataset holds a 1-D placeholder array, which cannot be
        # concatenated with multi-dimensional image arrays, so adopt the new data directly.
        if self.images.size == 0:
            self.images = new_images
            self.labels = new_labels
            return

        # Reconcile a missing or extra channel axis, e.g. (N, 28, 28) vs (N, 1, 28, 28)
        if new_images.ndim != self.images.ndim:
            if {new_images.ndim, self.images.ndim} == {3, 4}:
                # Reshape to the per-sample shape already stored in this dataset
                new_images = new_images.reshape((new_images.shape[0], *self.images.shape[1:]))
            else:
                raise ValueError(f"Cannot match dimensions: existing {self.images.shape}, new {new_images.shape}")

        self.images = np.concatenate((self.images, new_images), axis=0)
        self.labels = np.concatenate((self.labels, new_labels))

class GNNClustering:
    def __init__(self, num_devices=20, dataset_name="mnist", mfactor=1, num_edge_servers=5, metrics_file='test.json'):
        self.mfactor = mfactor
        self.num_devices = num_devices
        self.dataset_name = dataset_name.lower()
        self.num_edge_servers = num_edge_servers
        self.devices_df = self.generate_iot_devices()
        self.labels = None
        self.pseudo_labels = None
        self.cluster_nums = num_edge_servers  # K is set to the number of edge servers
        self.clustering_metrics = {}
        self.metrics_file = metrics_file.replace('.json', '_clustering.json')

    def load_data(self):
        """
        Load the specified dataset and return original and augmented data.
        """
        datasets = {
            "mnist": torchvision.datasets.MNIST,
            "fashion_mnist": torchvision.datasets.FashionMNIST,
            "cifar10": torchvision.datasets.CIFAR10,
        }
        dataset_class = datasets.get(self.dataset_name, torchvision.datasets.MNIST)
        
        # Define transformations
        if self.dataset_name == "cifar10":
            augmentations = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
            normalization = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        else:
            augmentations = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
            normalization = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        
        # Load original dataset
        trainset_original = dataset_class(root='./data', train=True, download=True, transform=normalization)
        testset_original = dataset_class(root='./data', train=False, download=True, transform=normalization)

        # Load augmented dataset
        trainset_augmented = dataset_class(root='./data', train=True, download=True, transform=augmentations)
        augmented_indices = np.random.choice(len(trainset_augmented), size=len(trainset_original) * self.mfactor, replace=True)
        augmented_subset = Subset(trainset_augmented, augmented_indices)
        augmented_loader = DataLoader(augmented_subset, batch_size=len(augmented_subset))

        # Convert datasets to numpy arrays
        train_images_original = np.array([np.array(image) for image, _ in trainset_original])
        train_labels_original = np.array([label for _, label in trainset_original])
        augmented_images, augmented_labels = next(iter(augmented_loader))
        train_images_augmented = augmented_images.numpy()
        train_labels_augmented = augmented_labels.numpy()
        test_images = np.array([np.array(image) for image, _ in testset_original])
        test_labels = np.array([label for _, label in testset_original])

        return train_images_original, train_labels_original, train_images_augmented, train_labels_augmented, test_images, test_labels

    def generate_iot_devices(self):
        """Generate IoT devices with randomized properties and store in a DataFrame."""
        devices_df = pd.DataFrame({
            'device_id': range(self.num_devices),
            'cpu_power': np.round(np.random.uniform(1.0, 2.0, self.num_devices), 2),
            'memory': np.random.choice([1, 2, 4], self.num_devices),
            'bandwidth': np.round(np.random.uniform(10, 125, self.num_devices), 2),
            'local_storage': np.round(np.random.uniform(1, 5, self.num_devices), 2),  # Storage in GB
            'disk_usage': 0.0,
            'local_data': [LocalDataset() for _ in range(self.num_devices)],  # Use LocalDataset class
            'labels': [{} for _ in range(self.num_devices)],  # New column for label counts
            'energy_usage': 0.0,  # Initialize energy consumption
            'bandwidth_usage': 0.0  # Initialize bandwidth usage
        })
        return devices_df

    def split_dataset_among_devices(self, images, labels):
        """Split dataset among IoT devices and calculate each device's storage usage in GB."""
        total_samples = len(images)
        device_indices = np.array_split(np.random.permutation(total_samples), self.num_devices)

        def assign_data(row):
            indices = device_indices[row.device_id]
            images_device = images[indices]
            labels_device = labels[indices]
            row.local_data.store_data(images_device, labels_device)  # Store data in LocalDataset

            # Calculate disk usage in GB
            data_size_bytes = images_device.nbytes + labels_device.nbytes
            row.disk_usage = data_size_bytes / (1024 ** 3)  # Convert bytes to GB

            row.labels = row.local_data.count_labels()  # Update labels with counts
            return row

        self.devices_df = self.devices_df.apply(assign_data, axis=1)
        self.labels = labels[np.concatenate(device_indices)]
        return self.devices_df

    def build_feature_matrix(self):
        """Create a feature matrix for the devices with increased weight on label diversity."""
        # Extract label counts
        label_counts = np.array([
            np.bincount(data.labels.astype(int), minlength=10) if len(data.labels) > 0 else np.zeros(10)
            for data in self.devices_df['local_data']
        ])
        
        # Normalize label counts per device; rows with no samples stay all-zero
        row_sums = label_counts.sum(axis=1, keepdims=True)
        label_counts_normalized = np.divide(
            label_counts, row_sums,
            out=np.zeros(label_counts.shape, dtype=float),
            where=row_sums != 0,
        )

        # Assign higher weight to label diversity and normalized counts
        weight_label_counts = 0.5  # Increase weight for label counts
        weighted_labels = weight_label_counts * label_counts_normalized

        # Extract other characteristics
        cpu_values = self.devices_df['cpu_power'].values.reshape(-1, 1)
        bandwidth_values = self.devices_df['bandwidth'].values.reshape(-1, 1)
        local_disk_values = self.devices_df['local_storage'].values.reshape(-1, 1)

        # Normalize other characteristics
        scaler = MinMaxScaler()
        cpu_values_normalized = scaler.fit_transform(cpu_values)
        bandwidth_values_normalized = scaler.fit_transform(bandwidth_values)
        local_disk_values_normalized = scaler.fit_transform(local_disk_values)

        # Assign lower weights to other characteristics
        weight_cpu = 0.1
        weight_bandwidth = 0.2
        weight_local_disk = 0.2

        # Combine all features
        features = np.hstack([
            weighted_labels,
            weight_cpu * cpu_values_normalized,
            weight_bandwidth * bandwidth_values_normalized,
            weight_local_disk * local_disk_values_normalized
        ])

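        # Standardize the final feature matrix.
        # Note: per-column standardization rescales every column to unit variance,
        # so the constant weights applied above do not change the standardized values.
        features = StandardScaler().fit_transform(features)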

        return features

    def build_device_graph(self):
        """Construct a device graph using K-Nearest Neighbors (KNN)."""

        k = self.num_edge_servers
        # Obtain the feature matrix of the devices
        features = self.build_feature_matrix()
        
        # Fit the KNN model to the features
        nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='auto').fit(features)
        distances, indices = nbrs.kneighbors(features)
        
        # Initialize the list to store edge indices
        edge_index = []
        
        # Build the edge indices
        for i in range(indices.shape[0]):
            for j in indices[i]:
                if i != j:  # Exclude self-loops
                    edge_index.append([i, j])
        
        # Convert to tensor and ensure correct shape
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        return edge_index

    def gnn_clustering(self):
        """Perform GNN-based clustering on devices using GAT with Residual Connections."""
        # Build feature matrix
        features = self.build_feature_matrix()
        features_tensor = torch.tensor(features, dtype=torch.float)

        k = self.cluster_nums
        
        # Perform GNN-KMeans clustering
        print("\nPerforming GNN-KMeans clustering using GNN embeddings...")

        start_time_gnn = time.time()
        embeddings = self.get_gnn_embeddings(features_tensor)
        kmeans_gnn = KMeans(n_clusters=k, random_state=42)
        labels_gnn = kmeans_gnn.fit_predict(embeddings)
        self.devices_df['cluster'] = labels_gnn
        end_time_gnn = time.time()
        exec_time_gnn = end_time_gnn - start_time_gnn

        # Compute evaluation metrics for GNN-KMeans
        silhouette_gnn = silhouette_score(embeddings, labels_gnn)
        calinski_gnn = calinski_harabasz_score(embeddings, labels_gnn)
        davies_gnn = davies_bouldin_score(embeddings, labels_gnn)
        energy_gnn = self.calculate_energy_consumption(
            num_nodes=features_tensor.shape[0],
            num_features=features_tensor.shape[1],
            num_edges=self.build_device_graph().size(1) // 2,
            num_layers=2,
            epochs=200
        )

        print(f"GNN-KMeans Clustering Metrics:")
        print(f"  Silhouette Score: {silhouette_gnn:.4f}")
        print(f"  Calinski-Harabasz Index: {calinski_gnn:.4f}")
        print(f"  Davies-Bouldin Index: {davies_gnn:.4f}")
        print(f"  Execution Time: {exec_time_gnn:.4f} seconds")
        print(f"  Energy Consumption: {energy_gnn:.4f} J")

        # Store metrics for comparison in a DataFrame
        clustering_metrics_df = pd.DataFrame({
            'Metric': ['Silhouette Score', 'Calinski-Harabasz Index', 'Davies-Bouldin Index', 'Execution Time (s)', 'Energy Consumption (J)'],
            'GNN-KMeans': [silhouette_gnn, calinski_gnn, davies_gnn, exec_time_gnn, energy_gnn]
        })

        # Save metrics to JSON file
        clustering_metrics_file = self.metrics_file
        clustering_metrics_df.to_json(clustering_metrics_file, orient='split', indent=4)

        print(f"\nClustering metrics saved to '{clustering_metrics_file}'")

    def get_gnn_embeddings(self, features):
        """Train the GNN model and return embeddings."""
        # Build device graph
        edge_index = self.build_device_graph()

        # Prepare data
        data = Data(x=features, edge_index=edge_index)

        # Define the GAT Autoencoder with residual connections
        class GATAutoencoder(torch.nn.Module):
            def __init__(self, in_channels, hidden_channels):
                super(GATAutoencoder, self).__init__()
                self.encoder = GATConv(in_channels, hidden_channels, heads=2, concat=False)
                self.decoder = GATConv(hidden_channels, in_channels, heads=2, concat=False)
                self.relu = torch.nn.ReLU()
                self.res_proj_enc = None
                if in_channels != hidden_channels:
                    self.res_proj_enc = torch.nn.Linear(in_channels, hidden_channels)
                self.res_proj_dec = None
                if hidden_channels != in_channels:
                    self.res_proj_dec = torch.nn.Linear(hidden_channels, in_channels)

            def forward(self, x, edge_index):
                x_res_enc = x
                if self.res_proj_enc is not None:
                    x_res_enc = self.res_proj_enc(x)
                x_enc = self.encoder(x, edge_index)
                x_enc = self.relu(x_enc)
                x_enc = x_enc + x_res_enc  # Residual connection

                x_res_dec = x_enc  # Decoder residual comes from the embedding, not the raw input
                if self.res_proj_dec is not None:
                    x_res_dec = self.res_proj_dec(x_enc)
                x_dec = self.decoder(x_enc, edge_index)
                x_dec = self.relu(x_dec)
                x_dec = x_dec + x_res_dec  # Residual connection

                return x_dec, x_enc  # Return reconstructed input and embeddings

        # Instantiate the model
        in_channels = features.shape[1]
        hidden_channels = 16  # You can adjust this value
        model = GATAutoencoder(in_channels=in_channels, hidden_channels=hidden_channels)

        # Training setup
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model.train()

        # Training loop
        num_epochs = 200
        for epoch in range(num_epochs):
            optimizer.zero_grad()
            reconstructed_x, embeddings = model(data.x, data.edge_index)
            loss = F.mse_loss(reconstructed_x, data.x)  # Reconstruction loss
            loss.backward()
            optimizer.step()

            # Optionally print training progress
            # if (epoch + 1) % 10 == 0:
            #     print(f"GNN Training Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

        # Extract embeddings
        model.eval()
        _, embeddings = model(data.x, data.edge_index)
        embeddings = embeddings.detach().numpy()

        # Calculate energy consumption
        num_nodes = features.shape[0]
        num_features = features.shape[1]
        num_edges = edge_index.shape[1] // 2  # kNN edges are stored as directed (i, j) pairs; halving approximates the undirected count
        num_layers = 2  # Number of GNN layers
        self.gnn_energy_consumption = self.calculate_energy_consumption(
            num_nodes=num_nodes,
            num_features=num_features,
            num_edges=num_edges,
            num_layers=num_layers,
            epochs=num_epochs,  # Number of training epochs actually run
            energy_per_flop=1e-9  # Joules per FLOP
        )
        print(f"Total energy consumption for GNN clustering: {self.gnn_energy_consumption:.6f} J")

        return embeddings

    def calculate_energy_consumption(self, num_nodes, num_features, num_edges, num_layers, epochs, energy_per_flop=1e-9):
        """
        Calculate the energy consumption of a GNN during training.

        Parameters:
            num_nodes (int): Number of nodes in the graph.
            num_features (int): Number of features per node.
            num_edges (int): Number of edges in the graph.
            num_layers (int): Number of GNN layers.
            epochs (int): Number of training epochs.
            energy_per_flop (float): Energy consumed per FLOP in joules (default is 1e-9 J).

        Returns:
            float: Total energy consumption in joules.
        """
        # Forward pass FLOPs per layer (approximation for GAT)
        forward_flops_per_layer = 2 * num_edges * num_features  # Edge-based operations (attention and aggregation)
        forward_flops_per_layer += num_nodes * num_features**2  # Node feature transformation (matmul)

        # Backward pass FLOPs per layer (usually ~2x forward FLOPs)
        backward_flops_per_layer = 2 * forward_flops_per_layer

        # Total FLOPs per epoch
        flops_per_epoch = num_layers * (forward_flops_per_layer + backward_flops_per_layer)

        # Total energy consumption
        total_energy = epochs * flops_per_epoch * energy_per_flop

        return total_energy

    def compare_clustering_methods(self):
        """Compare GNN-KMeans clustering with simple KMeans clustering."""
        # Build feature matrix
        features = self.build_feature_matrix()
        features_tensor = torch.tensor(features, dtype=torch.float)

        # Set the number of clusters to the number of edge servers
        k = self.cluster_nums  # K is fixed to the number of edge servers

        # Perform simple KMeans clustering
        print("\nPerforming simple KMeans clustering on original features...")
        start_time_simple = time.time()
        kmeans_simple = KMeans(n_clusters=k, random_state=42)
        labels_simple = kmeans_simple.fit_predict(features)
        end_time_simple = time.time()
        exec_time_simple = end_time_simple - start_time_simple

        # Compute evaluation metrics for simple KMeans
        silhouette_simple = silhouette_score(features, labels_simple)
        calinski_simple = calinski_harabasz_score(features, labels_simple)
        davies_simple = davies_bouldin_score(features, labels_simple)

        print(f"Simple KMeans Clustering Metrics:")
        print(f"  Silhouette Score: {silhouette_simple:.4f}")
        print(f"  Calinski-Harabasz Index: {calinski_simple:.4f}")
        print(f"  Davies-Bouldin Index: {davies_simple:.4f}")
        print(f"  Execution Time: {exec_time_simple:.4f} seconds")

        # Perform GNN-KMeans clustering
        print("\nPerforming GNN-KMeans clustering using GNN embeddings...")
        start_time_gnn = time.time()
        embeddings = self.get_gnn_embeddings(features_tensor)
        kmeans_gnn = KMeans(n_clusters=k, random_state=42)
        labels_gnn = kmeans_gnn.fit_predict(embeddings)
        end_time_gnn = time.time()
        exec_time_gnn = end_time_gnn - start_time_gnn

        # Compute evaluation metrics for GNN-KMeans
        silhouette_gnn = silhouette_score(embeddings, labels_gnn)
        calinski_gnn = calinski_harabasz_score(embeddings, labels_gnn)
        davies_gnn = davies_bouldin_score(embeddings, labels_gnn)
        energy_gnn = self.calculate_energy_consumption(
            num_nodes=features_tensor.shape[0],
            num_features=features_tensor.shape[1],
            num_edges=self.build_device_graph().size(1) // 2,
            num_layers=2,
            epochs=200
        )

        print(f"GNN-KMeans Clustering Metrics:")
        print(f"  Silhouette Score: {silhouette_gnn:.4f}")
        print(f"  Calinski-Harabasz Index: {calinski_gnn:.4f}")
        print(f"  Davies-Bouldin Index: {davies_gnn:.4f}")
        print(f"  Execution Time: {exec_time_gnn:.4f} seconds")
        print(f"  Energy Consumption: {energy_gnn:.4f} J")

        # Store metrics for both methods in a DataFrame (no energy estimate is computed for simple KMeans)
        clustering_metrics_df = pd.DataFrame({
            'Metric': ['Silhouette Score', 'Calinski-Harabasz Index', 'Davies-Bouldin Index', 'Execution Time (s)', 'Energy Consumption (J)'],
            'Simple KMeans': [silhouette_simple, calinski_simple, davies_simple, exec_time_simple, None],
            'GNN-KMeans': [silhouette_gnn, calinski_gnn, davies_gnn, exec_time_gnn, energy_gnn]
        })

        # Save metrics to JSON file
        clustering_metrics_file = "clustering_metrics.json"
        clustering_metrics_df.to_json(clustering_metrics_file, orient='split', indent=4)

        print(f"\nClustering metrics saved to '{clustering_metrics_file}'")

    def distribute_data(self):
        """Load the data, merge original and augmented samples, and split them among devices."""
        # Load data
        train_images_original, train_labels_original, train_images_augmented, train_labels_augmented, test_images, test_labels = self.load_data()

        # Combine original and augmented data
        train_images = np.concatenate((train_images_original, train_images_augmented), axis=0)
        train_labels = np.concatenate((train_labels_original, train_labels_augmented), axis=0)

        # Split data among devices
        self.devices_df = self.split_dataset_among_devices(train_images, train_labels)

        return test_images, test_labels

    def clustering_devices(self):
        """Cluster the devices with GNN-KMeans and report the elapsed time."""
        # Perform clustering
        clustering_start_time = time.time()
        self.gnn_clustering()
        clustering_stop_time = time.time()
        print(f'Clustering took {clustering_stop_time - clustering_start_time:.2f} seconds.')

        return self.devices_df

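The FLOP-based energy estimate in `calculate_energy_consumption` can be checked by hand. The sketch below restates the same formula as a standalone function and evaluates it for an illustrative graph. The function name `estimate_gnn_energy` and the input sizes (100 devices, 250 edges) are assumptions made for this example only; the 13 features correspond to the 10 label columns plus CPU, bandwidth, and storage built in `build_feature_matrix`.

```python
def estimate_gnn_energy(num_nodes, num_features, num_edges,
                        num_layers, epochs, energy_per_flop=1e-9):
    """Restate the FLOP-based estimate: energy = epochs * FLOPs/epoch * J/FLOP."""
    # Forward pass per layer: edge-wise attention/aggregation plus node-wise matmul
    forward = 2 * num_edges * num_features + num_nodes * num_features ** 2
    # Backward pass is approximated as twice the forward cost
    backward = 2 * forward
    flops_per_epoch = num_layers * (forward + backward)
    return epochs * flops_per_epoch * energy_per_flop


# Hypothetical sizes, chosen only to make the arithmetic easy to follow:
# forward = 2*250*13 + 100*13^2 = 23,400 FLOPs per layer
# per epoch = 2 layers * 3 * 23,400 = 140,400 FLOPs
energy = estimate_gnn_energy(num_nodes=100, num_features=13, num_edges=250,
                             num_layers=2, epochs=200)
print(f"{energy:.5f} J")  # 0.02808 J
```

Because the node-wise term grows with the square of the feature count, it dominates the estimate in this example; the result is a rough proxy rather than a measured value.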
# Hybrid Data Redistribution
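The redistributor scores how far each device's label mix is from the global mix with the KL divergence, KL(P || Q) = Σ p · log(p / q). The following is a minimal standalone sketch of that calculation, assuming distributions are plain dictionaries mapping labels to probabilities; the function name `kl_divergence` and the toy distributions are illustrative and not part of the class.

```python
import math


def kl_divergence(device_dist, global_dist, eps=1e-10):
    """Return KL(P || Q) summed over the labels present on the device."""
    total = 0.0
    for label, p in device_dist.items():
        # Floor q so a missing or zero global probability cannot divide by zero
        q = max(global_dist.get(label, 0.0), eps)
        if p > 0:
            total += p * math.log(p / q)
    return total


global_dist = {0: 0.5, 1: 0.5}
balanced = {0: 0.5, 1: 0.5}   # matches the global mix
skewed = {0: 1.0}             # holds a single label only

print(kl_divergence(balanced, global_dist))  # 0.0
print(kl_divergence(skewed, global_dist))    # log(2), about 0.693
```

A value of zero means the device already mirrors the global distribution; larger values indicate a more skewed local dataset, which is what redistribution aims to reduce.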

In [None]:
class HybridDataRedistributor:
    def __init__(self, devices_df, dataset_name="mnist", metrics_file='test.json'):
        self.devices_df = devices_df.reset_index(drop=True)
        self.dataset_name = dataset_name.lower()
        self.metrics = {}  # Dictionary to store the results
        self.metrics_file = metrics_file.replace('.json', '_redistribution.json')

    def calculate_label_distribution(self):
        """Calculate the label distribution for each device."""
        label_distribution = {}

        for _, row in self.devices_df.iterrows():
            device_label_counts = row.local_data.count_labels()
            total_device_samples = sum(device_label_counts.values())
            label_distribution[row.device_id] = {
                label: count / total_device_samples if total_device_samples > 0 else 0
                for label, count in device_label_counts.items()
            }

        return label_distribution

    def calculate_global_distribution(self):
        """Calculate the global label distribution across all devices."""
        total_label_counts = Counter()
        for _, row in self.devices_df.iterrows():
            total_label_counts.update(row.local_data.count_labels())

        total_samples = sum(total_label_counts.values())
        global_distribution = {
            label: count / total_samples if total_samples > 0 else 0
            for label, count in total_label_counts.items()
        }

        return global_distribution

    def calculate_kl_divergence(self, device_distributions, global_distribution):
        """Calculate KL divergence for each device."""
        kl_divergences = {}

        for device_id, device_dist in device_distributions.items():
            kl_divergence = 0
            for label, p in device_dist.items():
                q = max(global_distribution.get(label, 0), 1e-10)  # Floor q to avoid division by zero
                if p > 0:
                    kl_divergence += p * np.log(p / q)

            kl_divergences[device_id] = kl_divergence

        return kl_divergences

    def summarize_data_volume(self):
        """Summarize the data volume per device."""
        data_volumes = {
            row.device_id: len(row.local_data)
            for _, row in self.devices_df.iterrows()
        }
        return data_volumes

    def summarize_label_presence_by_cluster(self):
        """Summarize and sort the presence of labels in each cluster."""
        label_presence = {}

        for cluster_id in self.devices_df['cluster'].unique():
            cluster_devices = self.devices_df[self.devices_df['cluster'] == cluster_id]
            combined_labels = Counter()

            for labels_dict in cluster_devices['labels']:
                if isinstance(labels_dict, dict):  # Ensure it's a dictionary
                    combined_labels.update(labels_dict)

            # Sort the labels in ascending order
            sorted_labels = dict(sorted(combined_labels.items()))
            label_presence[cluster_id] = sorted_labels

        return label_presence

    def assign_labels_by_density(self):
        """Assign labels to clusters based on density (starting from the highest density)."""
        # Calculate total occurrences of each label across all devices
        total_label_counts = self.calculate_total_label_counts()

        # Sort labels by their density in descending order
        sorted_labels = sorted(total_label_counts.items(), key=lambda x: x[1], reverse=True)
        
        # Initialize the cluster-to-label mapping from the cluster IDs actually present
        cluster_ids = sorted(int(c) for c in self.devices_df['cluster'].unique())
        cluster_labels = {cluster_id: [] for cluster_id in cluster_ids}

        # Distribute labels to clusters
        for label, _ in sorted_labels:
            # Find the cluster with the fewest total assigned labels
            cluster_id = min(cluster_labels.keys(), key=lambda k: len(cluster_labels[k]))
            cluster_labels[cluster_id].append(label)

        print(f"Assigned Labels for Clusters (by density): {cluster_labels}")
        return cluster_labels

    def print_cluster_summary(self, label_presence):
        """Print the summary of label presence in clusters."""
        print("\nCluster Summary:\n" + "-" * 30)

        for cluster_id, labels in sorted(label_presence.items()):
            print(f"Cluster {cluster_id}:")
            for label, count in labels.items():
                print(f"  {label}: {count}")
            print("-" * 30)

    def hybrid_data_redistribution(self, percentage_threshold=0.3):
        """Perform data redistribution to balance labels across clusters."""
        cluster_labels = self.assign_labels_by_density()
        total_label_counts = self.calculate_total_label_counts()
        print(f"Total label counts across all devices: {total_label_counts}\n")

        device_data_info = {
            row.device_id: {
                'local_data': row.local_data,
                'cluster': row.cluster,
                'label_counts': row.local_data.count_labels(),
                'energy_usage': row.energy_usage,
                'bandwidth_usage': row.bandwidth_usage,
                'bandwidth': row.bandwidth,
                'cpu_power': row.cpu_power,
            } for _, row in self.devices_df.iterrows()
        }

        total_time_delay = 0
        total_energy_consumption = 0
        total_bandwidth_usage = 0
        image_size_in_bytes = self.get_image_size_in_bytes()

        for cluster_id, labels in cluster_labels.items():
            print(f"\nProcessing Cluster {cluster_id} with assigned labels: {labels}")

            cluster_devices = self.devices_df[self.devices_df['cluster'] == cluster_id]
            combined_label_counts = Counter()
            for device in cluster_devices['local_data']:
                combined_label_counts.update(device.count_labels())

            for label in labels:
                current_label_count = combined_label_counts.get(label, 0)
                total_label_count_in_dataset = total_label_counts.get(label, 1)
                label_percentage = (current_label_count / total_label_count_in_dataset) if total_label_count_in_dataset > 0 else 0

                if label_percentage < percentage_threshold:
                    shortage = int((percentage_threshold * total_label_count_in_dataset) - current_label_count)
                    print(f"Cluster {cluster_id} needs {shortage} additional samples of label '{label}'.")

                    donor_devices = [
                        (dev_id, info) for dev_id, info in device_data_info.items()
                        if info['label_counts'].get(label, 0) > 0 and info['cluster'] != cluster_id
                    ]
                    donor_devices.sort(key=lambda x: (x[1]['label_counts'].get(label, 0), x[1]['bandwidth']), reverse=True)

                    for donor_device_id, donor_info in donor_devices:
                        available_samples = donor_info['label_counts'].get(label, 0)
                        request_count = min(shortage, available_samples)

                        if request_count > 0:
                            transfer_time, energy_consumed, bandwidth_used = self.simulate_data_transfer(
                                donor_device_id, cluster_devices, request_count, image_size_in_bytes,
                                donor_info['bandwidth'], 1e-9, device_data_info
                            )

                            total_time_delay += transfer_time
                            total_energy_consumption += energy_consumed
                            total_bandwidth_usage += bandwidth_used

                            self.update_devices_after_transfer(
                                donor_device_id, donor_info['local_data'], cluster_devices, label, request_count,
                                image_size_in_bytes, device_data_info
                            )

                            shortage -= request_count
                            if shortage <= 0:
                                break

        print(f"\nTotal Time Delay: {total_time_delay:.2f} seconds")
        print(f"Total Energy Consumption: {total_energy_consumption:.6f} joules")
        print(f"Total Bandwidth Usage: {total_bandwidth_usage / (1024 ** 2):.2f} MB")

        for device_id, info in device_data_info.items():
            idx = self.devices_df.index[self.devices_df['device_id'] == device_id][0]
            self.devices_df.at[idx, 'energy_usage'] = info['energy_usage']
            self.devices_df.at[idx, 'bandwidth_usage'] = info['bandwidth_usage']
            self.devices_df.at[idx, 'local_data'] = info['local_data']
            self.devices_df.at[idx, 'labels'] = info['local_data'].count_labels()

        return total_bandwidth_usage, total_energy_consumption, total_time_delay

    def calculate_total_label_counts(self):
        """Calculate the total label counts across all devices in devices_df."""
        total_label_counts = Counter()
        for _, row in self.devices_df.iterrows():
            device_labels = row.local_data.count_labels()
            total_label_counts.update(device_labels)
        return total_label_counts

    def get_image_size_in_bytes(self):
        """Get the image size in bytes based on the dataset."""
        if self.dataset_name in ["mnist", "fashion_mnist"]:
            return 28 * 28
        else:
            return 32 * 32 * 3

    def simulate_data_transfer(self, donor_device_id, cluster_devices, request_count, image_size_in_bytes,
                               donor_bandwidth, energy_per_byte, device_data_info):
        """
        Simulate data transfer and calculate time delay, energy consumption, and bandwidth usage,
        incorporating CPU power and bandwidth limitations.
        """
        total_data_size = request_count * image_size_in_bytes  # Total data size in bytes

        # Donor device CPU power and bandwidth
        donor_info = device_data_info[donor_device_id]
        donor_cpu_power = donor_info['cpu_power']
        donor_bandwidth = donor_info['bandwidth']  # Bandwidth in Mbps

        # Recipient devices' CPU power and bandwidth
        recipient_bandwidths = cluster_devices['bandwidth'].values
        recipient_cpu_powers = cluster_devices['cpu_power'].values

        # Effective bandwidth: consider the lowest bandwidth between donor and recipients
        effective_bandwidth = min(donor_bandwidth, recipient_bandwidths.min())  # in Mbps

        # Compute transfer time based on effective bandwidth
        transfer_time = total_data_size / (effective_bandwidth * 1e6 / 8)  # Convert Mbps (10^6 bits/s) to bytes/sec

        # Factor in the CPU processing time for both donor and recipients
        # Simplified: Processing time = Data size / CPU power
        processing_time_donor = total_data_size / (donor_cpu_power * 1e9)  # CPU power in GHz
        processing_time_recipients = total_data_size / (recipient_cpu_powers.min() * 1e9)

        # Total transfer time includes network transfer and processing delays
        total_transfer_time = transfer_time + processing_time_donor + processing_time_recipients

        # Energy consumption for data transfer
        energy_consumed_transfer = total_data_size * energy_per_byte  # Energy per byte transferred
        energy_consumed_processing_donor = processing_time_donor * donor_cpu_power * 1e9 * energy_per_byte
        energy_consumed_processing_recipients = (
            processing_time_recipients * recipient_cpu_powers.mean() * 1e9 * energy_per_byte
        )

        total_energy_consumption = (
            energy_consumed_transfer + energy_consumed_processing_donor + energy_consumed_processing_recipients
        )

        # Bandwidth usage
        total_bandwidth_used = total_data_size  # Total bytes transferred

        # Update donor metrics
        donor_info['energy_usage'] += (energy_consumed_transfer / 2) + energy_consumed_processing_donor
        donor_info['bandwidth_usage'] += total_bandwidth_used / 2  # Half bandwidth usage for donor

        # Update recipient devices' metrics
        num_recipients = len(cluster_devices)
        energy_per_recipient = (energy_consumed_transfer / 2) / num_recipients
        bandwidth_per_recipient = (total_bandwidth_used / 2) / num_recipients

        for _, device_row in cluster_devices.iterrows():
            device_id = device_row.device_id
            recipient_info = device_data_info[device_id]
            recipient_info['energy_usage'] += energy_per_recipient + energy_consumed_processing_recipients / num_recipients
            recipient_info['bandwidth_usage'] += bandwidth_per_recipient

        return total_transfer_time, total_energy_consumption, total_bandwidth_used

    def update_devices_after_transfer(self, donor_device_id, donor_device, cluster_devices, label, request_count,
                                      image_size_in_bytes, device_data_info):
        """Update devices after data transfer."""
        # Get the images and labels to transfer
        indices = np.where(donor_device.labels == label)[0][:request_count]
        images_to_transfer = donor_device.images[indices]
        labels_to_transfer = donor_device.labels[indices]

        # Distribute the samples to the devices in the cluster
        num_devices = len(cluster_devices)
        samples_per_device = request_count // num_devices
        remainder = request_count % num_devices

        start_idx = 0
        for idx, (_, device_row) in enumerate(cluster_devices.iterrows()):
            device = device_row.local_data
            device_id = device_row.device_id
            device_samples = samples_per_device + (1 if idx < remainder else 0)

            if device_samples > 0:
                end_idx = start_idx + device_samples
                # Add samples to the recipient device
                device.add_samples(images_to_transfer[start_idx:end_idx],
                                   labels_to_transfer[start_idx:end_idx])

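                # Update device data info (the label may be new to this device)
                device_data_info[device_id]['local_data'] = device
                recipient_counts = device_data_info[device_id]['label_counts']
                recipient_counts[label] = recipient_counts.get(label, 0) + device_samples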

                start_idx = end_idx

        # Remove samples from the donor device
        donor_device.remove_samples(label, request_count)
        device_data_info[donor_device_id]['local_data'] = donor_device
        device_data_info[donor_device_id]['label_counts'][label] -= request_count

    def save_metrics_to_json(self):
        """Save the collected metrics to a JSON file."""
        def convert_keys_to_serializable(data):
            """Recursively convert dict keys to strings and NumPy scalars to Python types."""
            if isinstance(data, dict):
                return {str(key): convert_keys_to_serializable(value) for key, value in data.items()}
            elif isinstance(data, list):
                return [convert_keys_to_serializable(element) for element in data]
            elif isinstance(data, np.generic):
                return data.item()
            else:
                return data

        # Convert the metrics dictionary so that keys and values are JSON-serializable
        serializable_metrics = convert_keys_to_serializable(self.metrics)
        with open(self.metrics_file, 'w') as f:
            json.dump(serializable_metrics, f, indent=4)

    def redistribute_data(self, percentage_threshold=0.3):
        """Main method to perform hybrid data redistribution."""
        # Show the local datasets before data redistribution
        label_presence = self.summarize_label_presence_by_cluster()
        self.print_cluster_summary(label_presence)

        print("\nLocal datasets before hybrid data redistribution:")
        print(self.devices_df[['device_id', 'cpu_power', 'bandwidth', 'local_storage', 'disk_usage', 'cluster',
                               'energy_usage', 'bandwidth_usage']])
        
        # Calculate metrics before redistribution
        label_distribution_before = self.calculate_label_distribution()
        global_distribution = self.calculate_global_distribution()
        kl_divergence_before = self.calculate_kl_divergence(label_distribution_before, global_distribution)
        data_volume_before = self.summarize_data_volume()

        # Store pre-redistribution metrics
        self.metrics['before'] = {
            'label_distribution': label_distribution_before,
            'kl_divergence': kl_divergence_before,
            'data_volume': data_volume_before,
        }        

        # Perform hybrid data redistribution
        total_bandwidth_usage, total_energy_consumption, total_time_delay = self.hybrid_data_redistribution(percentage_threshold=percentage_threshold)

        # Calculate metrics after redistribution
        label_distribution_after = self.calculate_label_distribution()
        kl_divergence_after = self.calculate_kl_divergence(label_distribution_after, global_distribution)
        data_volume_after = self.summarize_data_volume()

        # Store post-redistribution metrics
        self.metrics['after'] = {
            'label_distribution': label_distribution_after,
            'kl_divergence': kl_divergence_after,
            'data_volume': data_volume_after,
            'total_data_transferred': total_bandwidth_usage,
        }

        # Calculate resource impact
        self.metrics['impact'] = {
            'energy_consumption': self.devices_df['energy_usage'].sum(),
            'bandwidth_usage': self.devices_df['bandwidth_usage'].sum(),
        }

        # Save metrics to JSON
        self.save_metrics_to_json()

        # Show the local datasets after data redistribution
        label_presence = self.summarize_label_presence_by_cluster()
        self.print_cluster_summary(label_presence)

        print("\nLocal datasets after hybrid data redistribution:")
        print(self.devices_df[['device_id', 'cpu_power', 'bandwidth', 'local_storage', 'disk_usage', 'cluster',
                               'energy_usage', 'bandwidth_usage']])

        return self.devices_df, label_presence 
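
The metrics stored by `redistribute_data` rely on `calculate_kl_divergence`, whose body is not shown in this section. As a rough sketch of what such a metric might compute, the block below compares each device's label distribution with the pooled distribution of all devices. The function name `kl_to_global`, the `{device_id: {label: count}}` input layout, and the smoothing constant `eps` are assumptions made for illustration, not the notebook's implementation.

```python
import numpy as np


def kl_to_global(label_counts, num_classes=10, eps=1e-12):
    """Return {device_id: KL(device || global)} from per-device label counts.

    `label_counts` is a hypothetical {device_id: {label: count}} mapping.
    """
    device_ids = list(label_counts)

    # Build a (devices x classes) matrix of sample counts.
    counts = np.zeros((len(device_ids), num_classes), dtype=float)
    for row, device_id in enumerate(device_ids):
        for label, n in label_counts[device_id].items():
            counts[row, int(label)] = n

    # The global distribution pools the samples of every device.
    global_p = counts.sum(axis=0)
    global_p = global_p / global_p.sum()

    result = {}
    for row, device_id in enumerate(device_ids):
        p = counts[row] / max(counts[row].sum(), eps)
        mask = p > 0  # terms with p == 0 contribute nothing
        ratio = p[mask] / (global_p[mask] + eps)
        result[device_id] = float(np.sum(p[mask] * np.log(ratio)))
    return result


same = kl_to_global({"a": {0: 5, 1: 5}, "b": {0: 5, 1: 5}}, num_classes=2)
skewed = kl_to_global({"a": {0: 10}, "b": {1: 10}}, num_classes=2)
```

A value near zero means the device already mirrors the pooled label mix; larger values indicate a more skewed local dataset, which is what redistribution is meant to reduce.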

# DRL Assignment using Proximal Policy Optimization algorithm (PPO)

In [None]:
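import json
import logging

# Assumption: the classic Gym API is used (reset() returns only the observation
# and step() returns a 4-tuple), which matches the environments defined below.
import gym
import numpy as np
import torch
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from torch.optim.lr_scheduler import StepLR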

# Define a callback for logging and learning rate scheduling

class CustomCallback(BaseCallback):
    """
    A custom callback for logging and learning rate scheduling.
    """
    def __init__(self, verbose=0):
        super(CustomCallback, self).__init__(verbose)
        self.scheduler = None

    def _on_training_start(self) -> None:
        """
        Called before the first rollout starts.
        Initialize the learning rate scheduler here since `self.model` is now available.
        """
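        # `_on_step` steps the scheduler once every 1000 timesteps, so step_size=1
        # gives one decay per 1000 timesteps. Note that SB3 rewrites the optimizer's
        # learning rate from its own schedule at each policy update, which can
        # override the value set here.
        self.scheduler = StepLR(self.model.policy.optimizer, step_size=1, gamma=0.95)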
        if self.verbose > 0:
            logging.info("Learning rate scheduler initialized.")

    def _on_step(self) -> bool:
        """
        Called at each step after the action is taken.
        Step the scheduler at desired intervals.
        """
        if self.scheduler and self.num_timesteps % 1000 == 0:
            self.scheduler.step()
            if self.verbose > 0:
                current_lr = self.scheduler.get_last_lr()[0]
                logging.info(f"Step {self.num_timesteps}: Learning rate updated to {current_lr:.6f}")
        return True  # Returning False would stop training

    def _on_training_end(self) -> None:
        """
        Called after training ends.
        """
        if self.verbose > 0:
            logging.info("Training completed.")

class ClusterAssignmentEnv(gym.Env):
    def __init__(self, cluster_bandwidth, edge_server_capacities, devices_df):
        super(ClusterAssignmentEnv, self).__init__()
        self.devices_df = devices_df.reset_index(drop=True)
        self.cluster_bandwidth = cluster_bandwidth
        self.edge_server_capacities = edge_server_capacities
        self.num_clusters = len(cluster_bandwidth)
        self.num_servers = len(edge_server_capacities)

        # Define action and observation spaces
        self.action_space = spaces.MultiDiscrete([self.num_servers] * self.num_clusters)
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.num_clusters + self.num_servers * 2,),
            dtype=np.float32
        )

        self.reset()

    def reset(self):
        """Reset the environment and return the initial state."""
        self.state = self._get_state()
        return self.state

    def _get_state(self):
        """Generate a normalized state representation."""
        normalized_bandwidth = self.cluster_bandwidth / (np.max(self.cluster_bandwidth) + 1e-6)
        normalized_capacities = self.edge_server_capacities / (np.max(self.edge_server_capacities) + 1e-6)

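        # Calculate current server loads. `devices_df` has one row per device, so
        # look up each cluster's server through the 'cluster' column (cluster labels
        # are assumed to be 0..num_clusters-1, matching `cluster_bandwidth`).
        current_server_loads = np.zeros(self.num_servers)
        cluster_servers = self.devices_df.groupby('cluster')['assigned_servers'].first()
        for cluster_idx in range(self.num_clusters):
            server_idx = cluster_servers.get(cluster_idx, -1)
            if server_idx >= 0:
                current_server_loads[int(server_idx)] += self.cluster_bandwidth[cluster_idx]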

        # Calculate unused capacities
        unused_capacities = self.edge_server_capacities - current_server_loads
        normalized_unused_capacities = unused_capacities / (np.max(self.edge_server_capacities) + 1e-6)

        state = np.concatenate([
            normalized_bandwidth,
            normalized_capacities,
            normalized_unused_capacities
        ]).astype(np.float32)

        return state

    def step(self, action):
        """Execute an action and assign clusters to servers."""
        assignments = dict(zip(range(self.num_clusters), action))
        server_loads = np.zeros(self.num_servers)

        for cluster_idx, server_idx in assignments.items():
            server_loads[server_idx] += self.cluster_bandwidth[cluster_idx]

        reward = self._calculate_reward(server_loads)
        done = True  # Single-step environment
        info = {
            'server_loads': server_loads,
            'cluster_assignments': assignments,
        }
        return self.state, reward, done, info

    def _calculate_reward(self, server_loads):
        """Calculate reward based on server loads."""
        total_capacity = np.sum(self.edge_server_capacities)
        total_load = np.sum(server_loads)

        # Overload penalty: Penalize loads exceeding capacity
        overload = np.maximum(0, server_loads - self.edge_server_capacities)
        overload_penalty = np.sum(overload) / (total_capacity + 1e-6)  # Total overload as a fraction

        # Load variance penalty: Penalize imbalances in server loads
        load_variance = np.var(server_loads)
        load_variance_penalty = load_variance / (total_load + 1e-6)

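        # Resource utilization incentive. Every cluster is always assigned to some
        # server, so total_load (and therefore this term) is the same for every action.
        resource_utilization = total_load / (total_capacity + 1e-6)

        # Weighted sum: reward utilization, penalize overload and load imbalance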
        reward = (
            25 * resource_utilization
            - 3 * overload_penalty
            - 2 * load_variance_penalty
        )

        # Ensure non-negative reward
        reward = max(reward, 0.0)

        return reward

    def evaluate(self, assignments):
        """Evaluate a given set of assignments and return metrics."""
        server_loads = np.zeros(self.num_servers)
        for cluster_idx, server_idx in assignments.items():
            if server_idx >= 0:
                server_loads[server_idx] += self.cluster_bandwidth[cluster_idx]

        overload = np.maximum(0, server_loads - self.edge_server_capacities)
        overload_penalty = np.sum(overload) / (np.sum(self.edge_server_capacities) + 1e-6)

        load_variance = np.var(server_loads)
        load_variance_penalty = load_variance / (np.sum(server_loads) + 1e-6)

        resource_utilization = np.sum(server_loads) / (np.sum(self.edge_server_capacities) + 1e-6)

        # Evaluation score. The weights differ from `_calculate_reward`, and the
        # value is left unclipped so that penalties remain visible.
        reward = (
            10 * resource_utilization
            - 5 * overload_penalty
            - 3 * load_variance_penalty
        )

        metrics = {
            'overload_penalty': float(overload_penalty),
            'load_variance_penalty': float(load_variance_penalty),
            'resource_utilization': float(resource_utilization),
            'reward': float(reward),
            'server_loads': server_loads.tolist()
        }

        return metrics

def train_cluster_assignment_agent(cluster_bandwidth, edge_server_capacities, devices_df, timesteps=10000, metrics_file='metrics.json'):
    """Train a PPO-based agent for the cluster assignment problem."""
    if 'assigned_servers' not in devices_df.columns:
        devices_df['assigned_servers'] = -1

    cluster_env = ClusterAssignmentEnv(cluster_bandwidth, edge_server_capacities, devices_df)

    cluster_model = PPO(
        "MlpPolicy",
        cluster_env,
        verbose=1,
        learning_rate=3e-4,
        ent_coef=0.01,        # Encourage exploration
        n_steps=2048,
        batch_size=64,
        clip_range=0.1,       # Stability in updates
        max_grad_norm=0.5,    # Prevent exploding gradients
        gae_lambda=0.95,      # GAE parameter
        vf_coef=0.5,          # Value function coefficient
        tensorboard_log="./cluster_tensorboard/"
    )

    # Configure logging
    logging.basicConfig(filename='cluster_training.log', level=logging.INFO)

    # Initialize the custom callback
    custom_callback = CustomCallback(verbose=1)  # Set verbose=1 for detailed logging

    print(f"Starting cluster assignment training for {timesteps} timesteps.")
    cluster_model.learn(total_timesteps=timesteps, callback=custom_callback)

    # Save the model
    model_filename = metrics_file.replace('.json', '_cluster_assignment_agent')
    cluster_model.save(model_filename)
    print(f"Cluster Assignment Agent trained and saved to {model_filename}.")

    # Evaluate the trained model
    state = cluster_env.reset()
    action, _ = cluster_model.predict(state, deterministic=True)  # Use the greedy action for evaluation
    evaluation_state, evaluation_reward, _, evaluation_info = cluster_env.step(action)
    print(f"Evaluation Reward: {evaluation_reward}")
    print(f"Evaluation Info: {evaluation_info}")

    # Convert the action array to a {cluster: server} dict before evaluation
    assignments = dict(zip(range(cluster_env.num_clusters), action))
    evaluation_metrics = cluster_env.evaluate(assignments)
    print(f"Evaluation Metrics: {evaluation_metrics}")

    # Optionally, log evaluation metrics
    with open(metrics_file.replace('.json', '_cluster_evaluation.json'), 'w') as f:
        json.dump(evaluation_metrics, f, indent=4)

    return cluster_model

class DeviceSchedulingEnv(gym.Env):
    def __init__(self, devices_df, cluster_to_server_map):
        super(DeviceSchedulingEnv, self).__init__()
        self.devices_df = devices_df.reset_index(drop=True)
        self.cluster_to_server_map = cluster_to_server_map

        # Number of devices
        self.num_devices = len(devices_df)

        # Define action and observation spaces
        self.action_space = spaces.MultiBinary(self.num_devices)
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.num_devices * 3,),  # Bandwidth, energy, and diversity scores
            dtype=np.float32,
        )

        # Initialize potential-based reward shaping
        self.previous_potential = 0.0

    def reset(self):
        """Reset the environment and return the initial state."""
        self.state = self._get_state()
        self.previous_potential = self.calculate_potential(self.state)
        return self.state

    def _get_state(self):
        """Generate a normalized state representation."""
        bandwidth_usage = self.devices_df['bandwidth_usage'].values
        energy_usage = self.devices_df['energy_usage'].values

        # Normalize bandwidth and energy usage
        normalized_bandwidth_usage = bandwidth_usage / (np.max(bandwidth_usage) + 1e-6)
        normalized_energy_usage = energy_usage / (np.max(energy_usage) + 1e-6)

        # Generate diversity scores (mocked for now)
        diversity_scores = np.random.random(self.num_devices)

        # Combine into a single state
        state = np.concatenate([normalized_bandwidth_usage, normalized_energy_usage, diversity_scores]).astype(np.float32)
        return state

    def _calculate_intrinsic_reward(self, state, next_state):
        """Calculate intrinsic reward based on state novelty."""
        prediction_error = np.linalg.norm(next_state - self.predict_next_state(state))
        # Cap at 1.0 and cast to a built-in float (np.float32 is not JSON serializable)
        intrinsic_reward = float(min(prediction_error, 1.0))
        return intrinsic_reward

    def predict_next_state(self, state):
        """Placeholder for a predictive model to estimate the next state."""
        # Implement a simple prediction or use a trained model
        return state  # For simplicity, assume no change

    def step(self, action):
        """Perform an action and return the next state, reward, and status."""
        selected_devices = np.where(action == 1)[0]

        # Simulate training and calculate performance metrics
        accuracy_improvement, communication_cost, energy_consumption = self._simulate_training(selected_devices)

        # Normalize components
        normalized_accuracy = accuracy_improvement
        normalized_communication = communication_cost / (self.devices_df['bandwidth_usage'].sum() + 1e-6)
        normalized_energy = energy_consumption / (self.devices_df['energy_usage'].sum() + 1e-6)

        # Calculate intrinsic reward against one sampled next state, which also
        # becomes the environment state returned below
        next_state = self._get_state()
        current_potential = self.calculate_potential(self.state)
        intrinsic_reward = self._calculate_intrinsic_reward(self.state, next_state)
        self.previous_potential = current_potential

        # Reward with intrinsic motivation
        reward = (
            25 * normalized_accuracy
            - 3 * normalized_communication
            - 2 * normalized_energy
            + intrinsic_reward  # Add intrinsic reward
        )
        reward = max(reward, 0.0)  # Ensure reward is non-negative

        self.state = next_state  # Update the state
        done = True  # Single-step environment

        info = {
            'selected_devices': selected_devices.tolist(),
            'accuracy_improvement': accuracy_improvement,
            'communication_cost': communication_cost,
            'energy_consumption': energy_consumption,
            'intrinsic_reward': intrinsic_reward,
        }

        return self.state, reward, done, info

    def calculate_potential(self, state):
        """Calculate potential based on the current state."""
        return np.sum(state)  # Simple potential function

    def _simulate_training(self, selected_devices):
        """Simulate training metrics."""
        num_selected = len(selected_devices)
        accuracy_improvement = num_selected / (self.num_devices + 1e-6)  # Normalize by number of devices
        communication_cost = np.sum(self.devices_df.loc[selected_devices, 'bandwidth_usage'].values)
        energy_consumption = np.sum(self.devices_df.loc[selected_devices, 'energy_usage'].values)

        return accuracy_improvement, communication_cost, energy_consumption

    def evaluate(self, action):
        """Evaluate a given action."""
        selected_devices = np.where(action == 1)[0]
        accuracy_improvement, communication_cost, energy_consumption = self._simulate_training(selected_devices)

        normalized_accuracy = accuracy_improvement
        normalized_communication = communication_cost / (self.devices_df['bandwidth_usage'].sum() + 1e-6)
        normalized_energy = energy_consumption / (self.devices_df['energy_usage'].sum() + 1e-6)

        # Calculate intrinsic reward
        intrinsic_reward = self._calculate_intrinsic_reward(self.state, self._get_state())

        reward = (
            25 * normalized_accuracy
            - 3 * normalized_communication
            - 2 * normalized_energy
            + intrinsic_reward
        )
        reward = max(reward, 0.0)

        # Cast to built-in floats so the metrics can be written with json.dump
        metrics = {
            'accuracy_improvement': float(accuracy_improvement),
            'communication_cost': float(communication_cost),
            'energy_consumption': float(energy_consumption),
            'intrinsic_reward': float(intrinsic_reward),
            'reward': float(reward),
        }

        return metrics

def train_device_scheduling_agent(devices_df, timesteps=10000, metrics_file='test.json', cluster_to_server_map=None):
    """
    Train the PPO-based agent for device scheduling.
    """
    if cluster_to_server_map is None:
        raise ValueError("cluster_to_server_map is required for DeviceSchedulingEnv initialization.")

    # Initialize the environment
    env = DeviceSchedulingEnv(devices_df, cluster_to_server_map)

    # Initialize PPO model
    scheduling_model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        vf_coef=0.5,
        ent_coef=0.02,
        max_grad_norm=0.5
    )

    print(f"Starting device scheduling training for {timesteps} timesteps.")
    scheduling_model.learn(total_timesteps=timesteps)

    # Save the model
    model_filename = metrics_file.replace('.json', '') + "_device_scheduling_agent"
    scheduling_model.save(model_filename)
    print(f"Device Scheduling Agent trained and saved to {model_filename}.")

    # Perform evaluation
    state = env.reset()
    action, _ = scheduling_model.predict(state, deterministic=True)  # Use the greedy action for evaluation
    evaluation_metrics = env.evaluate(action)
    print(f"Final Evaluation Metrics: {evaluation_metrics}")

    # Save evaluation metrics (json is already used at module level elsewhere in this notebook)
    with open(metrics_file.replace('.json', '_scheduling_evaluation.json'), 'w') as f:
        json.dump(evaluation_metrics, f, indent=4)

    return scheduling_model
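
To make the reward used by `ClusterAssignmentEnv` easier to inspect, the sketch below recomputes its three terms outside the environment for a given assignment. The function name `assignment_reward_terms` and the example numbers are assumptions for illustration; the terms are returned unweighted because `_calculate_reward` and `evaluate` apply different weights. One property is worth noting: every cluster is always assigned to some server, so the total load, and with it the utilization term, is identical for every action. Only the overload and variance penalties distinguish one assignment from another.

```python
import numpy as np


def assignment_reward_terms(action, cluster_bandwidth, capacities, eps=1e-6):
    """Return the unweighted reward terms for one cluster-to-server assignment."""
    action = np.asarray(action, dtype=int)
    cluster_bandwidth = np.asarray(cluster_bandwidth, dtype=float)
    capacities = np.asarray(capacities, dtype=float)

    # Sum each cluster's bandwidth onto the server it was assigned to.
    server_loads = np.bincount(action, weights=cluster_bandwidth,
                               minlength=len(capacities))
    overload = np.maximum(0.0, server_loads - capacities)

    return {
        "utilization": server_loads.sum() / (capacities.sum() + eps),
        "overload_penalty": overload.sum() / (capacities.sum() + eps),
        "variance_penalty": np.var(server_loads) / (server_loads.sum() + eps),
    }


# Two clusters of bandwidth 4 on two servers of capacity 5 (made-up numbers).
balanced = assignment_reward_terms([0, 1], [4.0, 4.0], [5.0, 5.0])
skewed = assignment_reward_terms([0, 0], [4.0, 4.0], [5.0, 5.0])
```

Comparing `balanced` and `skewed` shows the same utilization in both cases, while the skewed assignment picks up both an overload and a variance penalty.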

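Stable-Baselines3 also accepts a callable for `learning_rate`. It is called with `progress_remaining`, which runs from 1.0 at the start of training to 0.0 at the end. A step decay similar in spirit to the `StepLR` settings in `CustomCallback` could be sketched as follows. The helper name `step_decay_schedule` and the choice of ten decay stages are assumptions, not part of the notebook.

```python
def step_decay_schedule(initial_lr, gamma=0.95, num_stages=10):
    """Build a schedule that multiplies the rate by `gamma` at each new stage."""

    def schedule(progress_remaining):
        # progress_remaining goes from 1.0 (start) to 0.0 (end of training).
        elapsed = 1.0 - progress_remaining
        stage = min(int(elapsed * num_stages), num_stages - 1)
        return initial_lr * (gamma ** stage)

    return schedule


lr_fn = step_decay_schedule(3e-4)
# Hypothetical usage: PPO("MlpPolicy", env, learning_rate=lr_fn, ...)
```

Because the schedule is owned by the algorithm itself, it is applied at every policy update and does not depend on a callback stepping a separate scheduler.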

# Semi-Synchronous Federated Learning

In [None]:
# MODELS
# CIFAR-10
# Models: https://github.com/kuangliu/pytorch-cifar
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

'''DLA in PyTorch CIFAR-10

Reference:
    Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
'''
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=1, padding=(kernel_size - 1) // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        x = torch.cat(xs, 1)
        out = F.relu(self.bn(self.conv(x)))
        return out

class Tree(nn.Module):
    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        super(Tree, self).__init__()
        self.level = level
        if level == 1:
            self.root = Root(2*out_channels, out_channels)
            self.left_node = block(in_channels, out_channels, stride=stride)
            self.right_node = block(out_channels, out_channels, stride=1)
        else:
            self.root = Root((level+2)*out_channels, out_channels)
            for i in reversed(range(1, level)):
                subtree = Tree(block, in_channels, out_channels,
                               level=i, stride=stride)
                self.__setattr__('level_%d' % i, subtree)
            self.prev_root = block(in_channels, out_channels, stride=stride)
            self.left_node = block(out_channels, out_channels, stride=1)
            self.right_node = block(out_channels, out_channels, stride=1)

    def forward(self, x):
        xs = [self.prev_root(x)] if self.level > 1 else []
        for i in reversed(range(1, self.level)):
            level_i = self.__getattr__('level_%d' % i)
            x = level_i(x)
            xs.append(x)
        x = self.left_node(x)
        xs.append(x)
        x = self.right_node(x)
        xs.append(x)
        out = self.root(xs)
        return out

class DLA(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(DLA, self).__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )

        self.layer3 = Tree(block,  32,  64, level=1, stride=1)
        self.layer4 = Tree(block,  64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.base(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

############################################

# MNIST & Fashion MNIST

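# NOTE: this redefinition shadows the CIFAR-10 BasicBlock above. Its constructor
# signature is compatible with Tree/DLA, and it is not used by DenseNet below.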
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
    
        if self.downsample is not None:
            identity = self.downsample(x)
    
        out += identity
        out = self.relu(out)
    
        return out

# Revised Bottleneck
class Bottleneck(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(Bottleneck, self).__init__()
        inter_channels = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_channels)
        self.conv2 = nn.Conv2d(inter_channels, growth_rate, kernel_size=3, padding=1, bias=False)
    
    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out

# Revised SingleLayer
class SingleLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)
    
    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out

# Revised Transition
class Transition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    
    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out

# Revised DenseNet
class DenseNet(nn.Module):
    def __init__(self, in_channels, growthRate, depth, reduction, nClasses, bottleneck=True):
        super(DenseNet, self).__init__()
    
        nDenseBlocks = (depth - 4) // 3
        if bottleneck:
            nDenseBlocks = nDenseBlocks // 2
    
        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(in_channels, nChannels, kernel_size=3, padding=1, bias=False)
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)
    
        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)
    
        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
    
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)
    
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    
    def _make_dense(self, in_channels, growth_rate, nDenseBlocks, bottleneck):
        layers = []
        for _ in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(in_channels, growth_rate))
            else:
                layers.append(SingleLayer(in_channels, growth_rate))
            in_channels += growth_rate
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = F.relu(self.bn1(out))
        out = F.adaptive_avg_pool2d(out, 1)  # Global average pooling to 1x1
        out = torch.flatten(out, 1)  # Flatten all dimensions except batch
        out = self.fc(out)
        # Returns log-probabilities (suitable for nn.NLLLoss); DLA above returns raw logits
        out = F.log_softmax(out, dim=1)
        return out
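
The channel bookkeeping in `DenseNet.__init__` can be hard to follow by eye. The sketch below repeats the same arithmetic without building any layers, so the width after each dense block and transition can be checked quickly. The function name `densenet_channel_plan` is an assumption for illustration; the default arguments mirror the MNIST and Fashion-MNIST settings listed in the commented-out `model_args` later in the notebook (`growthRate=12`, `depth=100`, `reduction=0.5`, `bottleneck=True`).

```python
import math


def densenet_channel_plan(growth_rate=12, depth=100, reduction=0.5, bottleneck=True):
    """Return the channel count after each stage of the DenseNet defined above."""
    layers_per_block = (depth - 4) // 3
    if bottleneck:
        layers_per_block //= 2  # each bottleneck layer holds two convolutions

    channels = 2 * growth_rate  # width after the stem convolution
    plan = {"layers_per_block": layers_per_block, "stem": channels}

    for stage in (1, 2, 3):
        # Every layer in a dense block concatenates `growth_rate` new channels.
        channels += layers_per_block * growth_rate
        plan[f"dense{stage}"] = channels
        if stage < 3:
            # Transitions compress the width by `reduction`.
            channels = int(math.floor(channels * reduction))
            plan[f"trans{stage}"] = channels
    return plan


plan = densenet_channel_plan()
```

The `dense3` entry is the number of input features the final `nn.Linear` layer receives.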


In [None]:
# class FederatedLearningSystem:
#     def __init__(self, devices_df, test_images, test_labels, dataset_name, metrics_file,
#                  cluster_agent, scheduling_agent, cluster_env, scheduling_env, **kwargs):
#         self.devices_df = devices_df.reset_index(drop=True)
#         self.test_images = test_images
#         self.test_labels = test_labels
#         self.dataset_name = dataset_name
#         self.metrics_file = metrics_file
#         self.edge_servers = {}
#         self.accuracies = {
#             'local_epochs': {},       # Store local epoch accuracies
#             'edge_iterations': {},    # Store edge iteration accuracies
#             'global_iterations': []   # Store global iteration accuracies
#         }

#         # Model selection based on dataset name
#         if dataset_name in ['mnist', 'fashion_mnist']:
#             self.model_class = DenseNet  # Assign the class, not an instance
#             self.model_args = {
#                 'in_channels': 1,
#                 'growthRate': 12,
#                 'depth': 100,
#                 'reduction': 0.5,
#                 'nClasses': 10,
#                 'bottleneck': True
#             }
#         elif dataset_name == 'cifar10':
#             self.model_class = DLA  # Replace with DLA if needed
#             self.model_args = {
#                 'block': BasicBlock,
#                 'num_classes': 10
#             }
#         else:
#             raise ValueError(f"Unsupported dataset: {dataset_name}")

#         # Additional parameters
#         self.global_iterations = kwargs.get('global_iterations', 5)
#         self.edge_iterations = kwargs.get('edge_iterations', 3)
#         self.local_epochs = kwargs.get('local_epochs', 1)
#         self.input_channels = kwargs.get('input_channels', 1)
#         self.num_classes = kwargs.get('num_classes', 10)
#         self.batch_size = kwargs.get('batch_size', 32)

#         # Parameters for Scenario 3
#         self.k_edge = kwargs.get('k_edge', 2)
#         self.m_global = kwargs.get('m_global', 1)
#         self.alpha = kwargs.get('alpha', 0.1)

#         # Initialize parameters for energy and time calculations
#         self.model_size = kwargs.get('model_size', 1.0)                # Size in MB
#         self.computation_energy_rate = kwargs.get('computation_energy_rate', 0.5)   # Energy per second
#         self.communication_energy_rate = kwargs.get('communication_energy_rate', 0.1)  # Energy per MB
#         self.device_latency = kwargs.get('device_latency', 0.1)        # Seconds
#         self.edge_latency = kwargs.get('edge_latency', 0.05)           # Seconds

#         # Dictionaries to store energy and time metrics
#         self.energy_consumption = {
#             'devices': {},
#             'edge_servers': {},
#             'cloud_server': 0.0
#         }
#         self.time_delays = {
#             'devices': {},
#             'edge_servers': {},
#             'cloud_server': 0.0
#         }
#         self.bandwidth_usage = {
#             'device_to_edge': 0.0,
#             'edge_to_cloud': 0.0
#         }

#         # Store agents and environments for reuse
#         self.cluster_agent = cluster_agent
#         self.scheduling_agent = scheduling_agent
#         self.cluster_env = cluster_env
#         self.scheduling_env = scheduling_env

#     def compute_computation_time(self, num_samples, cpu_power):
#         # Simple model: Time = (Number of samples) / (CPU Power)
#         return num_samples / cpu_power

#     def compute_computation_energy(self, computation_time):
#         # Energy = Computation Time * Energy Rate
#         return computation_time * self.computation_energy_rate

#     def compute_communication_time(self, data_size_mb, bandwidth_mbps, latency):
#         # Time = Data Size / Bandwidth + Latency
#         return (data_size_mb / bandwidth_mbps) + latency

#     def compute_communication_energy(self, data_size_mb):
#         # Energy = Data Size * Energy Rate
#         return data_size_mb * self.communication_energy_rate

#     def compute_batch_size(self, memory, cpu_power):
#         base_batch_size = 24
#         memory_factor = memory / 2  # Assume memory ranges from 1 to 4 GB
#         cpu_factor = cpu_power / 1.0  # Assume CPU power ranges from 1.0 to 2.0 GHz
#         batch_size = int(base_batch_size * memory_factor * cpu_factor)
#         return int(round(max(16, min(batch_size, 32))))  # Batch size capped between 16 and 32

#     def mix_models(self, local_state, global_state, alpha):
#         """
#         Gradually mix the global/edge model into the local model using the formula:
#         w_local = alpha * w_global/edge + (1 - alpha) * w_local
#         """
#         mixed_state = {}
#         for key in local_state.keys():
#             mixed_state[key] = alpha * global_state[key] + (1 - alpha) * local_state[key]
#         return mixed_state

#     def layer_wise_update(self, local_state, global_state, shared_layers):
#         """
#         Update only the shared layers in the local model with the global model.
#         """
#         updated_state = deepcopy(local_state)
#         for key in shared_layers:
#             if key in global_state:
#                 updated_state[key] = global_state[key]
#         return updated_state

#     def get_shared_layer_keys(self, model):
#         """
#         Get the keys of the shared layers (e.g., feature extraction layers).
#         """
#         shared_layers = []
#         for name, param in model.named_parameters():
#             if 'classifier' not in name and 'fc' not in name:
#                 shared_layers.append(name)
#         return shared_layers

#     def train_local_model(self, device_id, local_model_state, edge_model_state, global_model_state, train_loader, epochs, edge_iteration, global_iteration, lr=0.01):
#         """
#         Train the local model on a specific device while considering its CPU power and memory.
#         """
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         model = self.model_class(**self.model_args).to(device)

#         # Load the local model state
#         local_state = deepcopy(local_model_state)

#         # Apply gradual mixing with edge/global model periodically
#         if edge_iteration % self.k_edge == 0:
#             # Mix with edge model
#             alpha = min(self.alpha * (edge_iteration + 1), 1.0)  # Increase alpha over time
#             local_state = self.mix_models(local_state, edge_model_state, alpha)

#         if global_iteration % self.m_global == 0:
#             # Mix with global model
#             alpha = min(self.alpha * (global_iteration + 1), 1.0)
#             local_state = self.mix_models(local_state, global_model_state, alpha)

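#         # Apply layer-wise updating to only update shared layers
#         # NOTE: local_model_state is passed as the source here, so the shared layers revert to
#         # their pre-mixing values; pass global_model_state instead if the global layers are intended.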
#         shared_layers = self.get_shared_layer_keys(model)
#         local_state = self.layer_wise_update(local_state, local_model_state, shared_layers)

#         model.load_state_dict(local_state)
#         optimizer = optim.Adam(model.parameters(), lr=lr)
#         criterion = nn.CrossEntropyLoss()

#         # Fetch device-specific characteristics
#         device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
#         cpu_power = device_info['cpu_power']  # GHz
#         memory = device_info['memory']       # GB

#         # Adjust batch size based on memory and CPU power
#         adjusted_batch_size = self.compute_batch_size(memory, cpu_power)

#         # Simulate adjusted training time and energy
#         num_samples = len(train_loader.dataset)
#         computation_time_per_epoch = self.compute_computation_time(num_samples, cpu_power)
#         computation_energy_per_epoch = self.compute_computation_energy(computation_time_per_epoch)

#         # Update DataLoader with adjusted batch size
#         train_loader = DataLoader(train_loader.dataset, batch_size=adjusted_batch_size, shuffle=True)

#         # Store energy and time metrics
#         total_training_time = 0
#         total_training_energy = 0

#         # Training loop
#         model.train()
#         for epoch in range(epochs):
#             correct, total = 0, 0

#             # Run one real training pass over the local data (time and energy are simulated below)
#             for images, labels in train_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 optimizer.zero_grad()
#                 outputs = model(images)
#                 loss = criterion(outputs, labels)
#                 loss.backward()
#                 optimizer.step()

#                 # Calculate training accuracy for the current batch
#                 _, predicted = torch.max(outputs, 1)
#                 correct += (predicted == labels).sum().item()
#                 total += labels.size(0)

#             # Simulate time and energy per epoch
#             total_training_time += computation_time_per_epoch
#             total_training_energy += computation_energy_per_epoch

#             # Log accuracy for the epoch
#             epoch_accuracy = 100 * correct / total
#             print(f"Device {device_id} - Epoch {epoch + 1}/{epochs} - Accuracy: {epoch_accuracy:.2f}%")

#             # Store epoch accuracy
#             if device_id not in self.accuracies['local_epochs']:
#                 self.accuracies['local_epochs'][device_id] = []
#             self.accuracies['local_epochs'][device_id].append(epoch_accuracy)

#         # Update total time and energy for the device
#         if device_id not in self.energy_consumption['devices']:
#             self.energy_consumption['devices'][device_id] = 0.0
#         if device_id not in self.time_delays['devices']:
#             self.time_delays['devices'][device_id] = 0.0

#         self.energy_consumption['devices'][device_id] += total_training_energy
#         self.time_delays['devices'][device_id] += total_training_time

#         print(f"Device {device_id} - Total Training Time: {total_training_time:.2f}s")
#         print(f"Device {device_id} - Total Training Energy: {total_training_energy:.2f}J")

#         return device_id, deepcopy(model.state_dict())

#     def aggregate_models(self, models):
#         avg_model = deepcopy(models[0])
#         for key in avg_model.keys():
#             for model in models[1:]:
#                 avg_model[key] += model[key]
#             avg_model[key] = avg_model[key] / len(models)
#         return avg_model

#     def evaluate_model(self, model_state, test_loader):
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         model = self.model_class(**self.model_args).to(device)
#         model.load_state_dict(model_state)
#         model.eval()
#         total, correct = 0, 0

#         with torch.no_grad():
#             for images, labels in test_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 outputs = model(images)
#                 _, predicted = torch.max(outputs, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()

#         return 100 * correct / total

#     def save_summary_metrics(self, summary_file="metrics_summary.json"):
#         """
#         Consolidate and save key metrics for communication costs, energy consumption, 
#         and training performance.
#         """
#         summary_metrics = {
#             "global_iterations": self.global_iterations,
#             "edge_iterations": self.edge_iterations,
#             "local_epochs": self.local_epochs,
#             "overall_energy_consumption": self.energy_consumption,
#             "overall_time_delays": self.time_delays,
#             "overall_bandwidth_usage": self.bandwidth_usage,
#             "global_model_accuracy": self.accuracies.get("global_iterations", [])
#         }

#         # Save the summary to a JSON file
#         with open(summary_file, 'w') as f:
#             json.dump(summary_metrics, f, indent=4)

#         print(f"Summary metrics saved to {summary_file}")

#     def edge_server_training(self, edge_id, edge_devices, edge_model_state, edge_iteration, global_iteration):
#         """
#         Train devices assigned to the edge server independently and update the local model after each edge iteration.
#         """
#         edge_computation_time = 0.0
#         edge_computation_energy = 0.0

#         communication_costs = []
#         communication_latencies = []
#         communication_energies = []

#         print(f"Edge Server {edge_id}: Starting training")
#         # Edge server's current model state is passed as edge_model_state

#         local_results = []  # Collect local model updates from devices

#         for device_id, train_loader in edge_devices:
#             print(f"Edge Server {edge_id} sends model to Device {device_id} for training.")
#             # Train local model starting from edge_model_state
#             local_result = self.train_local_model(
#                 device_id,
#                 edge_model_state,          # local_model_state
#                 edge_model_state,          # edge_model_state
#                 self.global_model_state,   # global_model_state
#                 train_loader,
#                 self.local_epochs,
#                 edge_iteration,
#                 global_iteration
#             )
#             local_results.append(local_result)

#             # Communication metrics: time and energy from device to edge server
#             device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
#             bandwidth = device_info['bandwidth']  # in Mbps
#             communication_time = self.compute_communication_time(
#                 self.model_size,
#                 bandwidth,
#                 self.device_latency
#             )
#             communication_energy = self.compute_communication_energy(self.model_size)

#             communication_costs.append(self.model_size)
#             communication_latencies.append(communication_time)
#             communication_energies.append(communication_energy)

#             # Update device energy and time
#             self.time_delays['devices'][device_id] += communication_time
#             self.energy_consumption['devices'][device_id] += communication_energy

#             # Update total bandwidth usage
#             self.bandwidth_usage['device_to_edge'] += self.model_size

#         # Aggregate models from all devices at the edge server level
#         local_states = [state for _, state in local_results]
#         updated_edge_model_state = self.aggregate_models(local_states)
#         print(f"Edge Server {edge_id}: Aggregated models from devices")

#         # Edge server computation time for aggregation
#         num_models = len(local_states)
#         aggregation_time = num_models * 0.1  # 0.1 seconds per model
#         edge_computation_time += aggregation_time
#         edge_computation_energy += aggregation_time * self.computation_energy_rate

#         # Store edge server energy and time
#         if edge_id not in self.energy_consumption['edge_servers']:
#             self.energy_consumption['edge_servers'][edge_id] = 0.0
#         if edge_id not in self.time_delays['edge_servers']:
#             self.time_delays['edge_servers'][edge_id] = 0.0
#         self.time_delays['edge_servers'][edge_id] += edge_computation_time
#         self.energy_consumption['edge_servers'][edge_id] += edge_computation_energy

#         # Update total bandwidth usage for communication with cloud
#         self.bandwidth_usage['edge_to_cloud'] += self.model_size

#         # Evaluate aggregated edge model
#         edge_accuracy = self.evaluate_model(updated_edge_model_state, self.get_test_loader())
#         print(f"Edge Server {edge_id} - Edge Iteration {edge_iteration + 1} - Edge Model Accuracy: {edge_accuracy:.2f}%")

#         # Record edge_iteration accuracy
#         if global_iteration not in self.accuracies['edge_iterations']:
#             self.accuracies['edge_iterations'][global_iteration] = {}
#         if edge_iteration not in self.accuracies['edge_iterations'][global_iteration]:
#             self.accuracies['edge_iterations'][global_iteration][edge_iteration] = {}
#         self.accuracies['edge_iterations'][global_iteration][edge_iteration][edge_id] = edge_accuracy

#         # Return the updated model state for this edge server
#         return updated_edge_model_state

#     def recalculate_assignments(self):
#         # Re-run cluster assignment agent
#         self.cluster_env.devices_df = self.devices_df  # Update devices_df in the environment
#         cluster_state = self.cluster_env.reset()
#         cluster_action, _ = self.cluster_agent.predict(cluster_state)
#         assignments = dict(zip(range(self.cluster_env.num_clusters), cluster_action))
#         self.devices_df['assigned_servers'] = self.devices_df['cluster'].map(assignments)

#         # Re-run device scheduling agent
#         self.scheduling_env.devices_df = self.devices_df.reset_index(drop=True)  # Update devices_df in the environment
#         scheduling_state = self.scheduling_env.reset()
#         scheduling_action, _ = self.scheduling_agent.predict(scheduling_state)
#         scheduled_devices = [i for i, a in enumerate(scheduling_action) if a == 1]
#         self.devices_df['is_scheduled'] = False
#         self.devices_df.loc[scheduled_devices, 'is_scheduled'] = True

#         # Reset edge_servers
#         self.edge_servers = {}

#         # Recreate edge_servers with updated assignments, initializing model_state from global model
#         for idx, row in self.devices_df.iterrows():
#             if not row['is_scheduled']:
#                 continue  # Skip unscheduled devices

#             local_dataset = row["local_data"]
#             if not isinstance(local_dataset, LocalDataset):
#                 continue

#             images, labels = local_dataset.get_data()
#             if len(images) == 0 or len(labels) == 0:
#                 continue

#             memory = row["memory"]
#             cpu_power = row["cpu_power"]
#             batch_size = self.compute_batch_size(memory, cpu_power)

#             tensor_dataset = TensorDataset(
#                 torch.tensor(images, dtype=torch.float32),
#                 torch.tensor(labels, dtype=torch.long)
#             )
#             loader = DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)

#             edge_server_id = row["assigned_servers"]
#             if edge_server_id not in self.edge_servers:
#                 self.edge_servers[edge_server_id] = {
#                     'devices': [],
#                     'model_state': deepcopy(self.global_model_state)  # Initialize from global model
#                 }
#             self.edge_servers[edge_server_id]['devices'].append((row["device_id"], loader))

#         # Print updated device distribution across edge servers
#         print("Updated device distribution across edge servers:")
#         for edge_id, edge_info in self.edge_servers.items():
#             num_devices = len(edge_info['devices'])
#             print(f"Edge Server {edge_id}: {num_devices} devices")

#     def federated_learning(self, global_model_state, max_parallel_edge_servers=2):
#         """
#         Perform semi-synchronous federated learning with independent device training
#         and multiple edge server iterations, ensuring local model updates after each iteration.
#         """
#         cloud_computation_time = 0.0
#         cloud_computation_energy = 0.0

#         for global_iteration in range(self.global_iterations):
#             print("-" * 60)
#             print(f"Global Iteration {global_iteration + 1}/{self.global_iterations}")
#             print("-" * 60)

#             # Recalculate assignments at the beginning of each global iteration
#             self.global_model_state = deepcopy(global_model_state)  # Update global model state
#             self.recalculate_assignments()

#             # Initialize edge server models from the global model
#             for edge_id in self.edge_servers.keys():
#                 # Apply gradual mixing with the global model every m_global iterations
#                 if global_iteration % self.m_global == 0:
#                     alpha = min(self.alpha * (global_iteration + 1), 1.0)
#                     self.edge_servers[edge_id]['model_state'] = self.mix_models(
#                         self.edge_servers[edge_id]['model_state'],
#                         self.global_model_state,
#                         alpha
#                     )

#             for edge_iteration in range(self.edge_iterations):
#                 print(f"  Edge Iteration {edge_iteration + 1}/{self.edge_iterations}")

#                 edge_results = {}

#                 # Use ThreadPoolExecutor to limit concurrent edge server training
#                 with ThreadPoolExecutor(max_workers=max_parallel_edge_servers) as executor:
#                     future_to_edge = {
#                         executor.submit(
#                             self.edge_server_training,
#                             edge_id,
#                             edge_info['devices'],
#                             edge_info['model_state'],
#                             edge_iteration,
#                             global_iteration
#                         ): edge_id
#                         for edge_id, edge_info in self.edge_servers.items()
#                     }

#                     # Collect results as each edge server finishes
#                     for future in as_completed(future_to_edge):
#                         edge_id = future_to_edge[future]
#                         try:
#                             updated_state = future.result()  # Result is updated edge model state
#                             edge_results[edge_id] = updated_state
#                             # Update the edge server's model state
#                             self.edge_servers[edge_id]['model_state'] = updated_state
#                             print(f"  Edge Server {edge_id}: Updated model after Edge Iteration {edge_iteration + 1}")
#                         except Exception as e:
#                             print(f"Edge Server {edge_id} encountered an error: {e}")

#                 # Periodically update edge models with global model
#                 if (edge_iteration + 1) % self.k_edge == 0:
#                     for edge_id in self.edge_servers.keys():
#                         alpha = min(self.alpha * (edge_iteration + 1), 1.0)
#                         self.edge_servers[edge_id]['model_state'] = self.mix_models(
#                             self.edge_servers[edge_id]['model_state'],
#                             self.global_model_state,
#                             alpha
#                         )

#             # Perform global aggregation after all edge iterations
#             global_states = [edge_info['model_state'] for edge_info in self.edge_servers.values()]
#             self.global_model_state = self.aggregate_models(global_states)
#             global_model_state = deepcopy(self.global_model_state)
#             print(f"Global Iteration {global_iteration + 1}: Aggregated all edge models")

#             # Cloud server computation time for aggregation
#             num_edge_models = len(self.edge_servers)
#             aggregation_time = num_edge_models * 0.2  # 0.2 seconds per edge model
#             cloud_computation_time += aggregation_time
#             cloud_computation_energy += aggregation_time * self.computation_energy_rate

#             # Evaluate global model and store accuracy
#             global_accuracy = self.evaluate_model(global_model_state, self.get_test_loader())
#             print(f"Global Iteration {global_iteration + 1}: Global Model Accuracy: {global_accuracy:.2f}%")

#             self.accuracies['global_iterations'].append(global_accuracy)

#             # Write accuracies to file after every global iteration
#             with open(self.metrics_file, 'w') as f:
#                 json.dump(self.accuracies, f, indent=4)

#             # Note: Assignments will be recalculated at the beginning of the next global iteration

#         # Store cloud server energy and time
#         self.energy_consumption['cloud_server'] = cloud_computation_energy
#         self.time_delays['cloud_server'] = cloud_computation_time

#     def main(self, global_iterations, edge_iterations, local_epochs):
#         self.global_iterations = global_iterations
#         self.edge_iterations = edge_iterations
#         self.local_epochs = local_epochs

#         # Initialize global model state
#         global_model = self.model_class(**self.model_args)
#         self.global_model_state = deepcopy(global_model.state_dict())  # Initialize global model state

#         # Run federated learning
#         self.federated_learning(self.global_model_state, max_parallel_edge_servers=1)

#         final_accuracy = self.evaluate_model(self.global_model_state, self.get_test_loader())
#         print(f"Final Global Model Accuracy: {final_accuracy:.2f}%")

#         # Save the accuracies to JSON
#         with open(self.metrics_file.replace('.json', '_accuracies.json'), 'w') as json_file:
#             json.dump(self.accuracies, json_file, indent=4)

#         self.save_summary_metrics(self.metrics_file.replace('.json', '_full_metrics.json'))

#         # After training, consolidate metrics into a single dictionary
#         metrics_summary = {
#             "Energy Consumption": self.energy_consumption,
#             "Time Delays": self.time_delays,
#             "Bandwidth Usage": self.bandwidth_usage,
#         }

#         # Save metrics to a JSON file
#         with open(self.metrics_file.replace('.json', '_summary.json'), 'w') as f:
#             json.dump(metrics_summary, f, indent=4)

#     def get_test_loader(self):
#         test_dataset = TensorDataset(
#             torch.tensor(self.test_images, dtype=torch.float32),
#             torch.tensor(self.test_labels, dtype=torch.long)
#         )
#         test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
#         return test_loader
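
The mixing rule in `mix_models` and the plain averaging in `aggregate_models` can be checked on toy numbers without PyTorch. The sketch below is illustrative only: the helper names `mixing_alpha`, `mix_states`, and `average_states`, the scalar "weights", and the base value 0.25 are assumptions made for this example and are not part of the class above.

```python
def mixing_alpha(base_alpha, iteration):
    """Schedule used above: alpha grows linearly with the iteration, capped at 1."""
    return min(base_alpha * (iteration + 1), 1.0)


def mix_states(local_state, reference_state, alpha):
    """w_local = alpha * w_reference + (1 - alpha) * w_local, key by key."""
    return {
        key: alpha * reference_state[key] + (1 - alpha) * local_state[key]
        for key in local_state
    }


def average_states(states):
    """Unweighted mean of several state dicts (every model counts equally)."""
    return {
        key: sum(state[key] for state in states) / len(states)
        for key in states[0]
    }


# Toy "state dicts" with scalar weights (hypothetical values).
local = {"w": 0.0, "b": 4.0}
global_ = {"w": 1.0, "b": 0.0}

alpha = mixing_alpha(0.25, iteration=1)      # 0.25 * (1 + 1) = 0.5
mixed = mix_states(local, global_, alpha)    # halfway between local and global
averaged = average_states([local, global_, mixed])
print(alpha, mixed, averaged)
```

Two properties of the original code are visible here. With the default `alpha = 0.1`, the schedule reaches 1.0 at iteration index 9, so from that point the mixed state is simply the edge or global state. The average is also unweighted, so a device with few samples influences the edge model as much as a device with many.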

In [None]:
# class FederatedLearningSystem:
#     def __init__(self, devices_df, test_images, test_labels, dataset_name, metrics_file,
#                  cluster_agent, scheduling_agent, cluster_env, scheduling_env, **kwargs):
#         self.devices_df = devices_df.reset_index(drop=True)
#         self.test_images = test_images
#         self.test_labels = test_labels
#         self.dataset_name = dataset_name
#         self.metrics_file = metrics_file
#         self.edge_servers = {}
#         self.accuracies = {
#             'local_epochs': {},       # Store local epoch accuracies
#             'edge_iterations': {},    # Store edge iteration accuracies
#             'global_iterations': []   # Store global iteration accuracies
#         }

#         # Model selection based on dataset name
#         if dataset_name in ['mnist', 'fashion_mnist']:
#             self.model_class = DenseNet  # Assign the class, not an instance
#             self.model_args = {
#                 'in_channels': 1,
#                 'growthRate': 12,
#                 'depth': 100,
#                 'reduction': 0.5,
#                 'nClasses': 10,
#                 'bottleneck': True
#             }
#         elif dataset_name == 'cifar10':
#             self.model_class = DLA  # Assign the class, not an instance
#             self.model_args = {
#                 'block': BasicBlock,
#                 'num_classes': 10
#             }
#         else:
#             raise ValueError(f"Unsupported dataset: {dataset_name}")

#         # Additional parameters
#         self.global_iterations = kwargs.get('global_iterations', 5)
#         self.edge_iterations = kwargs.get('edge_iterations', 3)
#         self.local_epochs = kwargs.get('local_epochs', 1)
#         self.input_channels = kwargs.get('input_channels', 1)
#         self.num_classes = kwargs.get('num_classes', 10)
#         self.batch_size = kwargs.get('batch_size', 32)

#         # Parameters for Scenario 3
#         self.k_edge = kwargs.get('k_edge', 2)
#         self.m_global = kwargs.get('m_global', 1)
#         self.alpha = kwargs.get('alpha', 0.1)

#         # Parameters for FedProx
#         self.mu = kwargs.get('mu', 0.1)  # FedProx hyperparameter

#         # Initialize parameters for energy and time calculations
#         self.model_size = kwargs.get('model_size', 1.0)                # Size in MB
#         self.computation_energy_rate = kwargs.get('computation_energy_rate', 0.5)   # Energy per second
#         self.communication_energy_rate = kwargs.get('communication_energy_rate', 0.1)  # Energy per MB
#         self.device_latency = kwargs.get('device_latency', 0.1)        # Seconds
#         self.edge_latency = kwargs.get('edge_latency', 0.05)           # Seconds

#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#         # Dictionaries to store energy and time metrics
#         self.energy_consumption = {
#             'devices': {},
#             'edge_servers': {},
#             'cloud_server': 0.0
#         }
#         self.time_delays = {
#             'devices': {},
#             'edge_servers': {},
#             'cloud_server': 0.0
#         }
#         self.bandwidth_usage = {
#             'device_to_edge': 0.0,
#             'edge_to_cloud': 0.0
#         }

#         # Store agents and environments for reuse
#         self.cluster_agent = cluster_agent
#         self.scheduling_agent = scheduling_agent
#         self.cluster_env = cluster_env
#         self.scheduling_env = scheduling_env

#     def compute_computation_time(self, num_samples, cpu_power):
#         # Simple model: Time = (Number of samples) / (CPU Power)
#         return num_samples / cpu_power

#     def compute_computation_energy(self, computation_time):
#         # Energy = Computation Time * Energy Rate
#         return computation_time * self.computation_energy_rate

#     def compute_communication_time(self, data_size_mb, bandwidth_mbps, latency):
#         # Time = Data Size / Bandwidth + Latency
#         return (data_size_mb / bandwidth_mbps) + latency

#     def compute_communication_energy(self, data_size_mb):
#         # Energy = Data Size * Energy Rate
#         return data_size_mb * self.communication_energy_rate

#     def compute_batch_size(self, memory, cpu_power):
#         base_batch_size = 34
#         memory_factor = memory / 2  # Assume memory ranges from 1 to 4 GB
#         cpu_factor = cpu_power / 1.0  # Assume CPU power ranges from 1.0 to 2.0 GHz
#         batch_size = int(base_batch_size * memory_factor * cpu_factor)
#         return int(round(max(24, min(batch_size, 64))))  # Batch size clamped to the range [24, 64]

#     def mix_models(self, local_state, global_state, alpha):
#         """
#         Gradually mix the global/edge model into the local model using the formula:
#         w_local = alpha * w_global/edge + (1 - alpha) * w_local
#         """
#         mixed_state = {}
#         for key in local_state.keys():
#             mixed_state[key] = alpha * global_state[key].to(self.device) + (1 - alpha) * local_state[key].to(self.device)
#         return mixed_state

#     def layer_wise_update(self, local_state, global_state, shared_layers):
#         """
#         Update only the shared layers in the local model with the global model.
#         """
#         updated_state = deepcopy(local_state)
#         for key in shared_layers:
#             if key in global_state:
#                 updated_state[key] = global_state[key]
#         return updated_state

#     def get_shared_layer_keys(self, model):
#         """
#         Get the keys of the shared layers (e.g., feature extraction layers).
#         """
#         shared_layers = []
#         for name, param in model.named_parameters():
#             if 'classifier' not in name and 'fc' not in name:
#                 shared_layers.append(name)
#         return shared_layers

#     def train_local_model(self, device_id, local_model_state, edge_model_state, global_model_state, train_loader, epochs, edge_iteration, global_iteration, lr=0.0001):
#         """
#         Train the local model on a specific device while considering its CPU power and memory.
#         Incorporates FedProx by adding a proximal term to the loss function.
#         Includes ReduceLROnPlateau to dynamically adjust the learning rate.
#         """
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         model = self.model_class(**self.model_args).to(device)
    
#         # Load the local model state
#         local_state = deepcopy(local_model_state)
    
#         # Apply gradual mixing with edge/global model periodically
#         if edge_iteration % self.k_edge == 0:
#             alpha = min(self.alpha * (edge_iteration + 1), 1.0)  # Increase alpha over time
#             local_state = self.mix_models(local_state, edge_model_state, alpha)
    
#         if global_iteration % self.m_global == 0:
#             alpha = min(self.alpha * (global_iteration + 1), 1.0)
#             local_state = self.mix_models(local_state, global_model_state, alpha)
    
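#         # Apply layer-wise updating to only update shared layers
#         # NOTE: local_model_state is passed as the source here, so the shared layers revert to
#         # their pre-mixing values; pass global_model_state instead if the global layers are intended.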
#         shared_layers = self.get_shared_layer_keys(model)
#         local_state = self.layer_wise_update(local_state, local_model_state, shared_layers)
    
#         model.load_state_dict(local_state)
#         optimizer = optim.Adam(model.parameters(), lr=lr)
#         scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)
#         criterion = nn.CrossEntropyLoss()
    
#         # Fetch device-specific characteristics
#         device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
#         cpu_power = device_info['cpu_power']  # GHz
#         memory = device_info['memory']       # GB
    
#         # Adjust batch size based on memory and CPU power
#         adjusted_batch_size = self.compute_batch_size(memory, cpu_power)
    
#         # Simulate adjusted training time and energy
#         num_samples = len(train_loader.dataset)
#         computation_time_per_epoch = self.compute_computation_time(num_samples, cpu_power)
#         computation_energy_per_epoch = self.compute_computation_energy(computation_time_per_epoch)
    
#         # Update DataLoader with adjusted batch size
#         train_loader = DataLoader(train_loader.dataset, batch_size=adjusted_batch_size, shuffle=True)
    
#         # Store energy and time metrics
#         total_training_time = 0
#         total_training_energy = 0
    
#         # Training loop
#         model.train()
#         for epoch in range(epochs):
#             correct, total = 0, 0
#             running_loss = 0.0
    
#             for images, labels in train_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 optimizer.zero_grad()
    
#                 outputs = model(images)
#                 loss = criterion(outputs, labels)
    
#                 # Compute FedProx proximal term
#                 proximal_loss = 0.0
#                 for name, param in model.named_parameters():
#                     if name in shared_layers:
#                         gl_param = global_model_state[name].clone().detach().to(device) 
#                         proximal_loss += torch.norm(param - gl_param)**2
                        
#                 proximal_loss = (self.mu / 2) * proximal_loss
    
#                 # Total loss
#                 total_loss = loss + proximal_loss
#                 total_loss.backward()
#                 optimizer.step()
    
#                 running_loss += total_loss.item()
    
#                 # Calculate training accuracy for the current batch
#                 _, predicted = torch.max(outputs, 1)
#                 correct += (predicted == labels).sum().item()
#                 total += labels.size(0)
    
#             # Simulate time and energy per epoch
#             total_training_time += computation_time_per_epoch
#             total_training_energy += computation_energy_per_epoch
    
#             # Log accuracy for the epoch
#             epoch_accuracy = 100 * correct / total
#             epoch_loss = running_loss / len(train_loader)
#             print(f"Device {device_id} - Epoch {epoch + 1}/{epochs} - Accuracy: {epoch_accuracy:.2f}%, Loss: {epoch_loss:.4f}")
    
#             # Adjust learning rate based on loss
#             scheduler.step(epoch_loss)
    
#             # Store epoch accuracy
#             if device_id not in self.accuracies['local_epochs']:
#                 self.accuracies['local_epochs'][device_id] = []
#             self.accuracies['local_epochs'][device_id].append(epoch_accuracy)
    
#         # Update total time and energy for the device
#         if device_id not in self.energy_consumption['devices']:
#             self.energy_consumption['devices'][device_id] = 0.0
#         if device_id not in self.time_delays['devices']:
#             self.time_delays['devices'][device_id] = 0.0
    
#         self.energy_consumption['devices'][device_id] += total_training_energy
#         self.time_delays['devices'][device_id] += total_training_time
    
#         print(f"Device {device_id} - Total Training Time: {total_training_time:.2f}s")
#         print(f"Device {device_id} - Total Training Energy: {total_training_energy:.2f}J")
    
#         return device_id, deepcopy(model.state_dict())

#     def aggregate_models(self, models):
#         avg_model = deepcopy(models[0])
#         for key in avg_model.keys():
#             for model in models[1:]:
#                 avg_model[key] += model[key].to(self.device)
#             avg_model[key] = avg_model[key].to(self.device) / len(models)
#         return avg_model

#     def evaluate_model(self, model_state, test_loader):
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         model = self.model_class(**self.model_args).to(device)
#         model.load_state_dict(model_state)
#         model.eval()
#         total, correct = 0, 0

#         with torch.no_grad():
#             for images, labels in test_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 outputs = model(images)
#                 _, predicted = torch.max(outputs, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()

#         return 100 * correct / total

#     def save_summary_metrics(self, summary_file="metrics_summary.json"):
#         """
#         Consolidate and save key metrics for communication costs, energy consumption, 
#         and training performance.
#         """
#         summary_metrics = {
#             "global_iterations": self.global_iterations,
#             "edge_iterations": self.edge_iterations,
#             "local_epochs": self.local_epochs,
#             "overall_energy_consumption": self.energy_consumption,
#             "overall_time_delays": self.time_delays,
#             "overall_bandwidth_usage": self.bandwidth_usage,
#             "global_model_accuracy": self.accuracies.get("global_iterations", [])
#         }

#         # Save the summary to a JSON file
#         with open(summary_file, 'w') as f:
#             json.dump(summary_metrics, f, indent=4)

#         print(f"Summary metrics saved to {summary_file}")

#     def edge_server_training(self, edge_id, edge_devices, edge_model_state, edge_iteration, global_iteration):
#         """
#         Train devices assigned to the edge server independently and update the local model after each edge iteration.
#         """
#         edge_computation_time = 0.0
#         edge_computation_energy = 0.0

#         communication_costs = []
#         communication_latencies = []
#         communication_energies = []

#         print(f"Edge Server {edge_id}: Starting training")
#         # Edge server's current model state is passed as edge_model_state

#         local_results = []  # Collect local model updates from devices

#         for device_id, train_loader in edge_devices:
#             print(f"Edge Server {edge_id} sends model to Device {device_id} for training.")
#             # Train local model starting from edge_model_state
#             local_result = self.train_local_model(
#                 device_id,
#                 edge_model_state,          # local_model_state
#                 edge_model_state,          # edge_model_state
#                 self.global_model_state,   # global_model_state
#                 train_loader,
#                 self.local_epochs,
#                 edge_iteration,
#                 global_iteration
#             )
#             local_results.append(local_result)

#             # Communication metrics: time and energy from device to edge server
#             device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
#             bandwidth = device_info['bandwidth']  # in Mbps
#             communication_time = self.compute_communication_time(
#                 self.model_size,
#                 bandwidth,
#                 self.device_latency
#             )
#             communication_energy = self.compute_communication_energy(self.model_size)

#             communication_costs.append(self.model_size)
#             communication_latencies.append(communication_time)
#             communication_energies.append(communication_energy)

#             # Update device energy and time
#             self.time_delays['devices'][device_id] += communication_time
#             self.energy_consumption['devices'][device_id] += communication_energy

#             # Update total bandwidth usage
#             self.bandwidth_usage['device_to_edge'] += self.model_size

#         # Aggregate models from all devices at the edge server level
#         local_states = [state for _, state in local_results]
#         updated_edge_model_state = self.aggregate_models(local_states)
#         print(f"Edge Server {edge_id}: Aggregated models from devices")

#         # Edge server computation time for aggregation
#         num_models = len(local_states)
#         aggregation_time = num_models * 0.1  # 0.1 seconds per model
#         edge_computation_time += aggregation_time
#         edge_computation_energy += aggregation_time * self.computation_energy_rate

#         # Store edge server energy and time
#         if edge_id not in self.energy_consumption['edge_servers']:
#             self.energy_consumption['edge_servers'][edge_id] = 0.0
#         if edge_id not in self.time_delays['edge_servers']:
#             self.time_delays['edge_servers'][edge_id] = 0.0
#         self.time_delays['edge_servers'][edge_id] += edge_computation_time
#         self.energy_consumption['edge_servers'][edge_id] += edge_computation_energy

#         # Update total bandwidth usage for communication with cloud
#         self.bandwidth_usage['edge_to_cloud'] += self.model_size

#         # Evaluate aggregated edge model
#         edge_accuracy = self.evaluate_model(updated_edge_model_state, self.get_test_loader())
#         print(f"Edge Server {edge_id} - Edge Iteration {edge_iteration + 1} - Edge Model Accuracy: {edge_accuracy:.2f}%")

#         # Record edge_iteration accuracy
#         if global_iteration not in self.accuracies['edge_iterations']:
#             self.accuracies['edge_iterations'][global_iteration] = {}
#         if edge_iteration not in self.accuracies['edge_iterations'][global_iteration]:
#             self.accuracies['edge_iterations'][global_iteration][edge_iteration] = {}
#         self.accuracies['edge_iterations'][global_iteration][edge_iteration][edge_id] = edge_accuracy

#         # Return the updated model state for this edge server
#         return updated_edge_model_state

#     def recalculate_assignments(self):
#         # Re-run cluster assignment agent
#         self.cluster_env.devices_df = self.devices_df  # Update devices_df in the environment
#         cluster_state = self.cluster_env.reset()
#         cluster_action, _ = self.cluster_agent.predict(cluster_state)
#         assignments = dict(zip(range(self.cluster_env.num_clusters), cluster_action))
#         self.devices_df['assigned_servers'] = self.devices_df['cluster'].map(assignments)

#         # Re-run device scheduling agent
#         self.scheduling_env.devices_df = self.devices_df.reset_index(drop=True)  # Update devices_df in the environment
#         scheduling_state = self.scheduling_env.reset()
#         scheduling_action, _ = self.scheduling_agent.predict(scheduling_state)
#         scheduled_devices = [i for i, a in enumerate(scheduling_action) if a == 1]
#         self.devices_df['is_scheduled'] = False
#         self.devices_df.loc[scheduled_devices, 'is_scheduled'] = True

#         # Reset edge_servers
#         self.edge_servers = {}

#         # Recreate edge_servers with updated assignments, initializing model_state from global model
#         for idx, row in self.devices_df.iterrows():
#             if not row['is_scheduled']:
#                 continue  # Skip unscheduled devices

#             local_dataset = row["local_data"]
#             if not isinstance(local_dataset, LocalDataset):
#                 continue

#             images, labels = local_dataset.get_data()
#             if len(images) == 0 or len(labels) == 0:
#                 continue

#             memory = row["memory"]
#             cpu_power = row["cpu_power"]
#             batch_size = self.compute_batch_size(memory, cpu_power)

#             tensor_dataset = TensorDataset(
#                 torch.tensor(images, dtype=torch.float32),
#                 torch.tensor(labels, dtype=torch.long)
#             )
#             loader = DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)

#             edge_server_id = row["assigned_servers"]
#             if edge_server_id not in self.edge_servers:
#                 self.edge_servers[edge_server_id] = {
#                     'devices': [],
#                     'model_state': deepcopy(self.global_model_state)  # Initialize from global model
#                 }
#             self.edge_servers[edge_server_id]['devices'].append((row["device_id"], loader))

#         # Print updated device distribution across edge servers
#         print("Updated device distribution across edge servers:")
#         for edge_id, edge_info in self.edge_servers.items():
#             num_devices = len(edge_info['devices'])
#             print(f"Edge Server {edge_id}: {num_devices} devices")

#     def federated_learning(self, global_model_state, max_parallel_edge_servers=2):
#         """
#         Perform semi-synchronous federated learning with independent device training
#         and multiple edge server iterations, ensuring local model updates after each iteration.
#         Incorporates FedProx to stabilize training in heterogeneous environments.
#         """
#         cloud_computation_time = 0.0
#         cloud_computation_energy = 0.0

#         for global_iteration in range(self.global_iterations):
#             print("-" * 60)
#             print(f"Global Iteration {global_iteration + 1}/{self.global_iterations}")
#             print("-" * 60)

#             # Recalculate assignments at the beginning of each global iteration
#             self.global_model_state = deepcopy(global_model_state)  # Update global model state
#             self.recalculate_assignments()

#             # Initialize edge server models from the global model
#             for edge_id in self.edge_servers.keys():
#                 # Apply gradual mixing with the global model every m_global iterations
#                 if global_iteration % self.m_global == 0:
#                     alpha = min(self.alpha * (global_iteration + 1), 1.0)
#                     self.edge_servers[edge_id]['model_state'] = self.mix_models(
#                         self.edge_servers[edge_id]['model_state'],
#                         self.global_model_state,
#                         alpha
#                     )

#             for edge_iteration in range(self.edge_iterations):
#                 print(f"  Edge Iteration {edge_iteration + 1}/{self.edge_iterations}")

#                 edge_results = {}

#                 # Use ThreadPoolExecutor to limit concurrent edge server training
#                 with ThreadPoolExecutor(max_workers=max_parallel_edge_servers) as executor:
#                     future_to_edge = {
#                         executor.submit(
#                             self.edge_server_training,
#                             edge_id,
#                             edge_info['devices'],
#                             edge_info['model_state'],
#                             edge_iteration,
#                             global_iteration
#                         ): edge_id
#                         for edge_id, edge_info in self.edge_servers.items()
#                     }

#                     # Collect results as each edge server finishes
#                     for future in as_completed(future_to_edge):
#                         edge_id = future_to_edge[future]
#                         try:
#                             updated_state = future.result()  # Result is updated edge model state
#                             edge_results[edge_id] = updated_state
#                             # Update the edge server's model state
#                             self.edge_servers[edge_id]['model_state'] = updated_state
#                             print(f"  Edge Server {edge_id}: Updated model after Edge Iteration {edge_iteration + 1}")
#                         except Exception as e:
#                             print(f"Edge Server {edge_id} encountered an error: {e}")

#                 # Periodically update edge models with global model
#                 if (edge_iteration + 1) % self.k_edge == 0:
#                     for edge_id in self.edge_servers.keys():
#                         alpha = min(self.alpha * (edge_iteration + 1), 1.0)
#                         self.edge_servers[edge_id]['model_state'] = self.mix_models(
#                             self.edge_servers[edge_id]['model_state'],
#                             self.global_model_state,
#                             alpha
#                         )

#             # Perform global aggregation after all edge iterations
#             global_states = [edge_info['model_state'] for edge_info in self.edge_servers.values()]
#             self.global_model_state = self.aggregate_models(global_states)
#             global_model_state = deepcopy(self.global_model_state)
#             print(f"Global Iteration {global_iteration + 1}: Aggregated all edge models")

#             # Cloud server computation time for aggregation
#             num_edge_models = len(self.edge_servers)
#             aggregation_time = num_edge_models * 0.2  # 0.2 seconds per edge model
#             cloud_computation_time += aggregation_time
#             cloud_computation_energy += aggregation_time * self.computation_energy_rate

#             # Evaluate global model and store accuracy
#             global_accuracy = self.evaluate_model(global_model_state, self.get_test_loader())
#             print(f"Global Iteration {global_iteration + 1}: Global Model Accuracy: {global_accuracy:.2f}%")

#             self.accuracies['global_iterations'].append(global_accuracy)

#             # Periodically write accuracies to file
#             with open(self.metrics_file, 'w') as f:
#                 json.dump(self.accuracies, f, indent=4)

#             # Note: Assignments will be recalculated at the beginning of the next global iteration

#         # Store cloud server energy and time
#         self.energy_consumption['cloud_server'] = cloud_computation_energy
#         self.time_delays['cloud_server'] = cloud_computation_time

#     def main(self, global_iterations, edge_iterations, local_epochs):
#         self.global_iterations = global_iterations
#         self.edge_iterations = edge_iterations
#         self.local_epochs = local_epochs

#         # Initialize global model state
#         global_model = self.model_class(**self.model_args).to(self.device)
#         self.global_model_state = deepcopy(global_model.state_dict())  # Initialize global model state

#         # Run federated learning
#         self.federated_learning(self.global_model_state, max_parallel_edge_servers=1)

#         final_accuracy = self.evaluate_model(self.global_model_state, self.get_test_loader())
#         print(f"Final Global Model Accuracy: {final_accuracy:.2f}%")

#         # Save the accuracies to JSON
#         with open(self.metrics_file.replace('.json', '_accuracies.json'), 'w') as json_file:
#             json.dump(self.accuracies, json_file, indent=4)

#         self.save_summary_metrics(self.metrics_file.replace('.json', '_full_metrics.json'))

#         # After training, consolidate metrics into a single dictionary
#         metrics_summary = {
#             "Energy Consumption": self.energy_consumption,
#             "Time Delays": self.time_delays,
#             "Bandwidth Usage": self.bandwidth_usage,
#         }

#         # Save metrics to a JSON file
#         with open(self.metrics_file.replace('.json', '_summary.json'), 'w') as f:
#             json.dump(metrics_summary, f, indent=4)

#     def get_test_loader(self):
#         test_dataset = TensorDataset(
#             torch.tensor(self.test_images, dtype=torch.float32),
#             torch.tensor(self.test_labels, dtype=torch.long)
#         )
#         test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
#         return test_loader
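
The mixing schedule used in the code above (`alpha = min(self.alpha * (iteration + 1), 1.0)` followed by `mix_models`) can be traced by hand. The sketch below is a minimal, framework-free illustration in which plain floats stand in for tensors. The helper names `alpha_schedule` and `mix_scalar` are hypothetical and not part of the class, and the base value 0.1 is only the default that the class assigns to `alpha`.

```python
def alpha_schedule(base_alpha: float, iteration: int) -> float:
    """Mixing weight for a 0-based iteration index, capped at 1.0."""
    return min(base_alpha * (iteration + 1), 1.0)


def mix_scalar(local_w: float, global_w: float, alpha: float) -> float:
    """Scalar form of w_local = alpha * w_global + (1 - alpha) * w_local."""
    return alpha * global_w + (1 - alpha) * local_w


# Mixing weights over twelve iterations with a base alpha of 0.1
alphas = [alpha_schedule(0.1, i) for i in range(12)]

# At iteration index 4 the weight is 0.5, so the result is the midpoint of the two values
mixed = mix_scalar(local_w=2.0, global_w=4.0, alpha=alpha_schedule(0.1, 4))

print(alphas)
print(mixed)
```

With a base of 0.1 the weight reaches the cap of 1.0 at iteration index 9; from that point on, mixing replaces the local value with the global one entirely.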

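The time and energy bookkeeping in `edge_server_training` above relies on simple linear cost formulas. As a rough worked example, the sketch below re-implements those formulas as standalone functions and evaluates them once. The model size, latency, and energy rates are the class defaults; the bandwidth, sample count, and CPU power are assumed values chosen only for illustration, and the function names are hypothetical stand-ins for the class methods.

```python
def communication_time(data_size, bandwidth, latency):
    """Time = data size / bandwidth + latency."""
    return data_size / bandwidth + latency


def communication_energy(data_size, energy_rate):
    """Energy = data size * energy per unit of data."""
    return data_size * energy_rate


def computation_time(num_samples, cpu_power):
    """Time = number of samples / CPU power."""
    return num_samples / cpu_power


# Class defaults
model_size = 1.0        # model size, as passed to the communication formulas
device_latency = 0.1    # seconds
comm_energy_rate = 0.1  # energy per unit of data
comp_energy_rate = 0.5  # energy per second

# Assumed device characteristics (illustrative only)
bandwidth = 10.0
num_samples = 600
cpu_power = 1.5

upload_time = communication_time(model_size, bandwidth, device_latency)
upload_energy = communication_energy(model_size, comm_energy_rate)
epoch_time = computation_time(num_samples, cpu_power)
epoch_energy = epoch_time * comp_energy_rate

print(upload_time, upload_energy, epoch_time, epoch_energy)
```

Under these assumptions one simulated epoch costs far more than one model upload, so the simulated computation term dominates the per-device totals.
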
In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
import json
import numpy as np
from collections import defaultdict
from math import log
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import parameters_to_vector
import os  # Added
from datetime import datetime  # Optional: For timestamped checkpoints

# Placeholder imports for model architectures and dataset
# Replace these with actual imports from your project
# from your_model_file import DenseNet, DLA, BasicBlock
# from your_dataset_file import LocalDataset

class FederatedLearningSystem:
    def __init__(self, devices_df, test_images, test_labels, dataset_name, metrics_file,
                 cluster_agent, scheduling_agent, cluster_env, scheduling_env, checkpoint_dir='checkpoints', **kwargs):
        self.devices_df = devices_df.reset_index(drop=True)
        self.test_images = test_images
        self.test_labels = test_labels
        self.dataset_name = dataset_name
        self.metrics_file = metrics_file
        self.edge_servers = {}
        self.accuracies = {
            'local_epochs': {},       # Store local epoch accuracies per device
            'edge_iterations': {},    # Store edge iteration accuracies per edge server
            'global_iterations': []   # Store global iteration accuracies
        }

        # Model selection based on dataset name
        if dataset_name in ['mnist', 'fashion_mnist']:
            self.model_class = DenseNet  # Assign the class, not an instance
            self.model_args = {
                'in_channels': 1,
                'growthRate': 12,
                'depth': 100,
                'reduction': 0.5,
                'nClasses': 10,
                'bottleneck': True
            }
        elif dataset_name == 'cifar10':
            self.model_class = DLA  # Assign the class, not an instance
            self.model_args = {
                'block': BasicBlock,
                'num_classes': 10
            }
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        # Additional parameters
        self.global_iterations = kwargs.get('global_iterations', 5)
        self.edge_iterations = kwargs.get('edge_iterations', 3)
        self.local_epochs = kwargs.get('local_epochs', 1)
        self.input_channels = kwargs.get('input_channels', 1)
        self.num_classes = kwargs.get('num_classes', 10)
        self.batch_size = kwargs.get('batch_size', 32)

        # Parameters for Scenario 3
        self.k_edge = kwargs.get('k_edge', 2)
        self.m_global = kwargs.get('m_global', 1)
        self.alpha = kwargs.get('alpha', 0.1)

        # Parameters for FedProx
        self.mu = kwargs.get('mu', 0.1)  # FedProx hyperparameter

        # Initialize parameters for energy and time calculations
        self.model_size = kwargs.get('model_size', 1.0)                # Size in MB
        self.computation_energy_rate = kwargs.get('computation_energy_rate', 0.5)   # Energy per second
        self.communication_energy_rate = kwargs.get('communication_energy_rate', 0.1)  # Energy per MB
        self.device_latency = kwargs.get('device_latency', 0.1)        # Seconds
        self.edge_latency = kwargs.get('edge_latency', 0.05)           # Seconds

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Dictionaries to store energy and time metrics
        self.energy_consumption = {
            'devices': {},
            'edge_servers': {},
            'cloud_server': 0.0
        }
        self.time_delays = {
            'devices': {},
            'edge_servers': {},
            'cloud_server': 0.0
        }
        self.bandwidth_usage = {
            'device_to_edge': 0.0,
            'edge_to_cloud': 0.0
        }

        # Store agents and environments for reuse
        self.cluster_agent = cluster_agent
        self.scheduling_agent = scheduling_agent
        self.cluster_env = cluster_env
        self.scheduling_env = scheduling_env

        # Checkpointing setup (Added)
        self.checkpoint_dir = checkpoint_dir
        # exist_ok=True avoids a separate existence check (and the race it implies)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        print(f"Checkpoint directory '{self.checkpoint_dir}' is ready.")

        # Verify and enforce parameter types
        self.verify_and_enforce_parameter_types()

    def verify_and_enforce_parameter_types(self):
        """
        Verify that all model parameters are floating point. If not, cast them to float.
        """
        model = self.model_class(**self.model_args).to(self.device)
        for name, param in model.named_parameters():
            if not param.dtype.is_floating_point:
                print(f"Parameter {name} is of type {param.dtype}. Casting to torch.float32.")
                param.data = param.data.float()
        # Update the global model state after enforcement
        self.global_model_state = deepcopy(model.state_dict())

    def compute_computation_time(self, num_samples, cpu_power):
        """
        Compute computation time based on number of samples and CPU power.
        Simple linear model: Time = Number of samples / CPU Power
        """
        return num_samples / cpu_power

    def compute_computation_energy(self, computation_time):
        """
        Compute computation energy based on computation time.
        Energy = Computation Time * Energy Rate
        """
        return computation_time * self.computation_energy_rate

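    def compute_communication_time(self, data_size_mb, bandwidth_mbps, latency):
        """
        Compute communication time based on data size, bandwidth, and latency.
        Time = Data Size / Bandwidth + Latency

        Note: the division is only dimensionally consistent if both arguments use the
        same unit base. With a size in megabytes and a bandwidth in megabits per second,
        the size would need to be multiplied by 8 first.
        """
        return (data_size_mb / bandwidth_mbps) + latency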

    def compute_communication_energy(self, data_size_mb):
        """
        Compute communication energy based on data size.
        Energy = Data Size * Energy Rate
        """
        return data_size_mb * self.communication_energy_rate

    def compute_batch_size(self, memory, cpu_power):
        """
        Adjust batch size based on device memory and CPU power.
        The result is clamped to the range [24, 64].
        """
        base_batch_size = 34
        memory_factor = memory / 2  # Assume memory ranges from 1 to 4 GB
        cpu_factor = cpu_power / 1.0  # Assume CPU power ranges from 1.0 to 2.0 GHz
        batch_size = int(base_batch_size * memory_factor * cpu_factor)
        return max(24, min(batch_size, 64))

    def mix_models(self, local_state, global_state, alpha):
        """
        Gradually mix the global/edge model into the local model using the formula:
        w_local = alpha * w_global/edge + (1 - alpha) * w_local
        """
        mixed_state = {}
        for key in local_state.keys():
            # Detach both tensors and move them to self.device before mixing
            global_tensor = global_state[key].clone().detach().to(self.device)
            local_tensor = local_state[key].clone().detach().to(self.device)
            mixed_state[key] = alpha * global_tensor + (1 - alpha) * local_tensor
        return mixed_state

    def layer_wise_update(self, local_state, global_state, shared_layers):
        """
        Update only the shared layers in the local model with the global model.
        """
        updated_state = deepcopy(local_state)
        for key in shared_layers:
            if key in global_state:
                updated_state[key] = global_state[key]
        return updated_state

    def get_shared_layer_keys(self, model):
        """
        Get the keys of the shared layers (e.g., feature extraction layers).
        """
        shared_layers = []
        for name, param in model.named_parameters():
            if 'classifier' not in name and 'fc' not in name:
                shared_layers.append(name)
        return shared_layers

    def calculate_label_distribution(self, edge_devices):
        """
        Calculate label distribution for each edge server.

        Args:
            edge_devices (list): List of tuples containing device_id and DataLoader.

        Returns:
            dict: Mapping from label to count.
        """
        label_counts = defaultdict(int)
        for device_id, loader in edge_devices:
            for _, labels in loader:
                # tolist() yields plain Python ints, so the keys stay JSON-serializable
                for label in labels.cpu().tolist():
                    label_counts[label] += 1
        return label_counts

    def save_checkpoint(self, model_state, global_iteration):
        """
        Save the global model's state_dict as a checkpoint.

        Args:
            model_state (dict): The state_dict of the global model.
            global_iteration (int): The current global iteration number.
        """
        checkpoint_filename = f"global_model_iter_{global_iteration + 1}.pth"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
        
        try:
            torch.save(model_state, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")
        except Exception as e:
            print(f"Error saving checkpoint at iteration {global_iteration + 1}: {e}")

    def load_checkpoint(self, checkpoint_path):
        """
        Load a model checkpoint.

        Args:
            checkpoint_path (str): Path to the checkpoint file.

        Returns:
            dict: The loaded state_dict.
        """
        if os.path.exists(checkpoint_path):
            try:
                state_dict = torch.load(checkpoint_path, map_location=self.device)
                print(f"Checkpoint loaded: {checkpoint_path}")
                return state_dict
            except Exception as e:
                print(f"Error loading checkpoint from {checkpoint_path}: {e}")
                return None
        else:
            print(f"Checkpoint file {checkpoint_path} does not exist.")
            return None

    def train_local_model(self, device_id, local_model_state, edge_model_state, global_model_state,
                          train_loader, epochs, edge_iteration, global_iteration, lr=0.0001):
        """
        Train the local model on a specific device while considering its CPU power and memory.
        Incorporates FedProx by adding a proximal term to the loss function.
        Includes ReduceLROnPlateau to dynamically adjust the learning rate.
        Returns:
            tuple: (device_id, delta_model, local_steps)
        """
        device = self.device
        model = self.model_class(**self.model_args).to(device)

        # Load the local model state
        local_state = deepcopy(local_model_state)

        # Apply gradual mixing with edge/global model periodically
        if edge_iteration % self.k_edge == 0:
            alpha = min(self.alpha * (edge_iteration + 1), 1.0)  # Increase alpha over time
            local_state = self.mix_models(local_state, edge_model_state, alpha)

        if global_iteration % self.m_global == 0:
            alpha = min(self.alpha * (global_iteration + 1), 1.0)
            local_state = self.mix_models(local_state, global_model_state, alpha)

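        # Apply layer-wise updating to the shared layers.
        # NOTE: the source state passed here is local_model_state, so the shared layers are
        # reset to their pre-mixing values and only the non-shared layers keep the mixed weights.
        # Pass global_model_state (or edge_model_state) instead if the shared layers should follow it.
        shared_layers = self.get_shared_layer_keys(model)
        local_state = self.layer_wise_update(local_state, local_model_state, shared_layers)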

        model.load_state_dict(local_state)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)  # Added weight_decay
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)
        criterion = nn.CrossEntropyLoss()

        # Fetch device-specific characteristics
        device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
        cpu_power = device_info['cpu_power']  # GHz
        memory = device_info['memory']       # GB

        # Adjust batch size based on memory
        adjusted_batch_size = self.compute_batch_size(memory, cpu_power)

        # Simulate adjusted training time and energy
        num_samples = len(train_loader.dataset)
        computation_time_per_epoch = self.compute_computation_time(num_samples, cpu_power)
        computation_energy_per_epoch = self.compute_computation_energy(computation_time_per_epoch)

        # Update DataLoader with adjusted batch size
        train_loader = DataLoader(train_loader.dataset, batch_size=adjusted_batch_size, shuffle=True)

        # Store energy and time metrics
        total_training_time = 0
        total_training_energy = 0

        # Initialize local steps counter
        local_steps = 0

        # Training loop
        model.train()
        best_loss = float('inf')
        patience = 3
        trigger_times = 0

        for epoch in range(epochs):
            correct, total = 0, 0
            running_loss = 0.0

            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()

                outputs = model(images)
                loss = criterion(outputs, labels)

                # FedProx proximal term: (mu / 2) * ||w_local - w_global||^2 over the shared layers.
                # Both vectors are built from the same ordered (name, param) list so their entries
                # line up even if `shared_layers` is ordered differently from the model.
                shared_named_params = [
                    (name, param) for name, param in model.named_parameters() if name in shared_layers
                ]
                local_params_vector = parameters_to_vector([param for _, param in shared_named_params])
                global_params_vector = parameters_to_vector(
                    [global_model_state[name].to(device) for name, _ in shared_named_params]
                )

                proximal_loss = (self.mu / 2) * (local_params_vector - global_params_vector).pow(2).sum()

                # Total loss
                total_loss = loss + proximal_loss
                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()

                # Increment local steps
                local_steps += 1

                # Calculate training accuracy for the current batch
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            # Simulate time and energy per epoch
            total_training_time += computation_time_per_epoch
            total_training_energy += computation_energy_per_epoch

            # Log accuracy and loss for the epoch (max(..., 1) guards against an empty loader)
            epoch_accuracy = 100 * correct / max(total, 1)
            epoch_loss = running_loss / max(len(train_loader), 1)
            print(f"Device {device_id} - Epoch {epoch + 1}/{epochs} - Accuracy: {epoch_accuracy:.2f}%, Loss: {epoch_loss:.4f}")

            # Store epoch accuracy before the early-stopping check so the last epoch is recorded too
            self.accuracies['local_epochs'].setdefault(device_id, []).append(epoch_accuracy)

            # Adjust learning rate based on loss
            scheduler.step(epoch_loss)

            # Early stopping logic
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                trigger_times = 0
            else:
                trigger_times += 1
                print(f"Device {device_id} - Early stopping trigger {trigger_times}/{patience}")
                if trigger_times >= patience:
                    print(f"Device {device_id} - Early stopping")
                    break

        # Update total time and energy for the device
        if device_id not in self.energy_consumption['devices']:
            self.energy_consumption['devices'][device_id] = 0.0
        if device_id not in self.time_delays['devices']:
            self.time_delays['devices'][device_id] = 0.0

        self.energy_consumption['devices'][device_id] += total_training_energy
        self.time_delays['devices'][device_id] += total_training_time

        print(f"Device {device_id} - Total Training Time: {total_training_time:.2f}s")
        print(f"Device {device_id} - Total Training Energy: {total_training_energy:.2f}J")

        # Calculate delta (model update) relative to the state the model started training from
        trained_state = model.state_dict()
        delta_model = {key: trained_state[key] - local_state[key] for key in trained_state}

        # Return the delta and effective number of local steps
        return device_id, delta_model, local_steps

    def aggregate_models_fednova(self, base_model_state, local_results, total_steps):
        """
        Aggregate local updates using FedNova.

        Args:
            base_model_state (dict): The model state before local updates (edge_model_state).
            local_results (list): List of tuples (delta_model, local_steps) from devices.
            total_steps (int): Total effective local steps from all devices.

        Returns:
            dict: Updated model state after aggregation.
        """
        aggregated_delta = {}
        for key in base_model_state.keys():
            # Initialize as float tensors
            aggregated_delta[key] = torch.zeros_like(base_model_state[key], dtype=torch.float32)

        # Aggregate normalized updates
        for delta_model, local_steps in local_results:
            scaling_factor = local_steps / total_steps
            for key in delta_model.keys():
                # Ensure delta_model[key] is float
                delta_tensor = delta_model[key].to(self.device).float()
                aggregated_delta[key] += delta_tensor * scaling_factor

        # Update the base model state
        updated_model_state = {}
        for key in base_model_state.keys():
            updated_model_state[key] = base_model_state[key].to(self.device).float() + aggregated_delta[key]

        return updated_model_state

    def aggregate_models_global_fednova(self, global_model_state, edge_deltas, edge_steps_list, total_global_steps):
        """
        Aggregate edge server updates using FedNova at the global level.

        Args:
            global_model_state (dict): The current global model state.
            edge_deltas (list): List of delta_model dicts from edge servers.
            edge_steps_list (list): List of local steps from edge servers.
            total_global_steps (int): Total steps from all edge servers.

        Returns:
            dict: Updated global model state after aggregation.
        """
        aggregated_delta = {}
        for key in global_model_state.keys():
            # Initialize as float tensors
            aggregated_delta[key] = torch.zeros_like(global_model_state[key], dtype=torch.float32)

        # Aggregate normalized updates
        for delta_model, edge_steps in zip(edge_deltas, edge_steps_list):
            scaling_factor = edge_steps / total_global_steps
            for key in delta_model.keys():
                # Ensure delta_model[key] is float
                delta_tensor = delta_model[key].to(self.device).float()
                aggregated_delta[key] += delta_tensor * scaling_factor

        # Update the global model state
        updated_global_model_state = {}
        for key in global_model_state.keys():
            updated_global_model_state[key] = global_model_state[key].to(self.device).float() + aggregated_delta[key]

        return updated_global_model_state

    def evaluate_model(self, model_state, test_loader):
        """
        Evaluate the model on the test dataset.

        Args:
            model_state (dict): Model state_dict.
            test_loader (DataLoader): DataLoader for the test dataset.

        Returns:
            float: Overall accuracy percentage.
        """
        device = self.device
        model = self.model_class(**self.model_args).to(device)
        model.load_state_dict(model_state)
        model.eval()
        total, correct = 0, 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

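        # Calculate overall accuracy (guard against an empty test set)
        overall_accuracy = 100 * correct / total if total > 0 else 0.0

        # Callers record the result themselves. This method also evaluates edge models,
        # so appending here would mix edge accuracies into the global-iteration history
        # and record each global accuracy twice.
        print(f"Model Accuracy: {overall_accuracy:.2f}%")

        return overall_accuracy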

    def save_summary_metrics(self, summary_file="metrics_summary.json"):
        """
        Consolidate and save key metrics for communication costs, energy consumption, 
        and training performance.
        """
        summary_metrics = {
            "global_iterations": self.global_iterations,
            "edge_iterations": self.edge_iterations,
            "local_epochs": self.local_epochs,
            "overall_energy_consumption": self.energy_consumption,
            "overall_time_delays": self.time_delays,
            "overall_bandwidth_usage": self.bandwidth_usage,
            "global_model_accuracy": self.accuracies.get("global_iterations", [])
        }

        # Save the summary to a JSON file
        with open(summary_file, 'w') as f:
            json.dump(summary_metrics, f, indent=4)

        print(f"Summary metrics saved to {summary_file}")

    def edge_server_training(self, edge_id, edge_devices, edge_model_state, edge_iteration, global_iteration):
        """
        Train devices assigned to the edge server independently and update the local model after each edge iteration.
        """
        edge_computation_time = 0.0
        edge_computation_energy = 0.0

        communication_costs = []
        communication_latencies = []
        communication_energies = []

        print(f"Edge Server {edge_id}: Starting training")
        # Edge server's current model state is passed as edge_model_state

        local_results = []    # Collect (delta_model, local_steps) from devices
        total_edge_steps = 0  # Total effective steps at the edge server

        for device_id, train_loader in edge_devices:
            print(f"Edge Server {edge_id} sends model to Device {device_id} for training.")
            # Train local model starting from global_model_state
            device_id, delta_model, local_steps = self.train_local_model(
                device_id,
                self.global_model_state,    # local_model_state initialized with global model
                edge_model_state,           # edge_model_state
                self.global_model_state,    # global_model_state
                train_loader,
                self.local_epochs,
                edge_iteration,
                global_iteration
            )
            local_results.append((delta_model, local_steps))
            total_edge_steps += local_steps

            # Communication metrics: time and energy from device to edge server
            device_info = self.devices_df.loc[self.devices_df['device_id'] == device_id].iloc[0]
            bandwidth = device_info['bandwidth']  # in Mbps
            communication_time = self.compute_communication_time(
                self.model_size,
                bandwidth,
                self.device_latency
            )
            communication_energy = self.compute_communication_energy(self.model_size)

            communication_costs.append(self.model_size)
            communication_latencies.append(communication_time)
            communication_energies.append(communication_energy)

            # Update device energy and time
            self.time_delays['devices'][device_id] += communication_time
            self.energy_consumption['devices'][device_id] += communication_energy

            # Update total bandwidth usage
            self.bandwidth_usage['device_to_edge'] += self.model_size

        # Aggregate models from devices using FedNova
        updated_edge_model_state = self.aggregate_models_fednova(
            edge_model_state,
            local_results,
            total_edge_steps
        )
        print(f"Edge Server {edge_id}: Aggregated models from devices")

        # Edge server computation time for aggregation
        num_models = len(local_results)
        aggregation_time = num_models * 0.1  # 0.1 seconds per model
        edge_computation_time += aggregation_time
        edge_computation_energy += aggregation_time * self.computation_energy_rate

        # Store edge server energy and time
        if edge_id not in self.energy_consumption['edge_servers']:
            self.energy_consumption['edge_servers'][edge_id] = 0.0
        if edge_id not in self.time_delays['edge_servers']:
            self.time_delays['edge_servers'][edge_id] = 0.0
        self.time_delays['edge_servers'][edge_id] += edge_computation_time
        self.energy_consumption['edge_servers'][edge_id] += edge_computation_energy

        # Update total bandwidth usage for communication with cloud
        self.bandwidth_usage['edge_to_cloud'] += self.model_size

        # Evaluate aggregated edge model
        edge_accuracy = self.evaluate_model(updated_edge_model_state, self.get_test_loader())
        print(f"Edge Server {edge_id} - Edge Iteration {edge_iteration + 1} - Edge Model Accuracy: {edge_accuracy:.2f}%")

        # Record edge_iteration accuracy
        if global_iteration not in self.accuracies['edge_iterations']:
            self.accuracies['edge_iterations'][global_iteration] = {}
        if edge_iteration not in self.accuracies['edge_iterations'][global_iteration]:
            self.accuracies['edge_iterations'][global_iteration][edge_iteration] = {}
        self.accuracies['edge_iterations'][global_iteration][edge_iteration][edge_id] = edge_accuracy

        # Update edge server's total steps
        self.edge_servers[edge_id]['total_edge_steps'] = total_edge_steps

        # Return the updated model state for this edge server
        return updated_edge_model_state

    def recalculate_assignments(self):
        """
        Re-run cluster and scheduling agents to assign devices to edge servers.
        Recreate edge servers with updated assignments, initializing model_state from global model.
        Calculate label distributions and total samples for each edge server.
        """
        # Re-run cluster assignment agent
        self.cluster_env.devices_df = self.devices_df  # Update devices_df in the environment
        cluster_state = self.cluster_env.reset()
        cluster_action, _ = self.cluster_agent.predict(cluster_state)
        assignments = dict(zip(range(self.cluster_env.num_clusters), cluster_action))
        self.devices_df['assigned_servers'] = self.devices_df['cluster'].map(assignments)

        # Re-run device scheduling agent
        self.scheduling_env.devices_df = self.devices_df.reset_index(drop=True)  # Update devices_df in the environment
        scheduling_state = self.scheduling_env.reset()
        scheduling_action, _ = self.scheduling_agent.predict(scheduling_state)
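        scheduled_devices = [i for i, a in enumerate(scheduling_action) if a == 1]
        self.devices_df['is_scheduled'] = False
        # Actions are positional (the environment sees a reset index), so index by position, not by label
        self.devices_df.iloc[scheduled_devices, self.devices_df.columns.get_loc('is_scheduled')] = True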

        # Reset edge_servers
        self.edge_servers = {}

        # Recreate edge_servers with updated assignments, initializing model_state from global model
        for idx, row in self.devices_df.iterrows():
            if not row['is_scheduled']:
                continue  # Skip unscheduled devices

            local_dataset = row["local_data"]
            if not isinstance(local_dataset, LocalDataset):
                continue

            images, labels = local_dataset.get_data()
            if len(images) == 0 or len(labels) == 0:
                continue

            memory = row["memory"]
            cpu_power = row["cpu_power"]
            batch_size = self.compute_batch_size(memory, cpu_power)

            # Ensure data types are correct
            loader = self.prepare_data_loader(images, labels, batch_size)

            edge_server_id = row["assigned_servers"]
            if edge_server_id not in self.edge_servers:
                self.edge_servers[edge_server_id] = {
                    'devices': [],
                    'model_state': deepcopy(self.global_model_state),  # Initialize from global model
                    'label_distribution': defaultdict(int),             # Initialize label distribution
                    'total_samples': 0                                # Initialize sample count
                }
            self.edge_servers[edge_server_id]['devices'].append((row["device_id"], loader))

            # Update label distribution and sample count
            label_counts = self.calculate_label_distribution([(row["device_id"], loader)])
            for label, count in label_counts.items():
                self.edge_servers[edge_server_id]['label_distribution'][label] += count
                self.edge_servers[edge_server_id]['total_samples'] += count

        # Print updated device distribution across edge servers
        print("Updated device distribution across edge servers:")
        for edge_id, edge_info in self.edge_servers.items():
            num_devices = len(edge_info['devices'])
            print(f"Edge Server {edge_id}: {num_devices} devices")
            print(f"  Label Distribution: {dict(edge_info['label_distribution'])}")

    def federated_learning(self, global_model_state, max_parallel_edge_servers=2):
        """
        Perform semi-synchronous federated learning with independent device training
        and multiple edge server iterations, ensuring local model updates after each iteration.
        Incorporates FedProx and FedNova to stabilize training in heterogeneous environments.
        """
        cloud_computation_time = 0.0
        cloud_computation_energy = 0.0

        for global_iteration in range(self.global_iterations):
            print("-" * 60)
            print(f"Global Iteration {global_iteration + 1}/{self.global_iterations}")
            print("-" * 60)

            # Recalculate assignments at the beginning of each global iteration
            self.global_model_state = deepcopy(global_model_state)  # Update global model state
            self.recalculate_assignments()

            # Initialize edge server models from the global model
            for edge_id in self.edge_servers.keys():
                # Apply gradual mixing with the global model every m_global iterations
                if global_iteration % self.m_global == 0:
                    alpha = min(self.alpha * (global_iteration + 1), 1.0)
                    self.edge_servers[edge_id]['model_state'] = self.mix_models(
                        self.edge_servers[edge_id]['model_state'],
                        self.global_model_state,
                        alpha
                    )

            for edge_iteration in range(self.edge_iterations):
                print(f"  Edge Iteration {edge_iteration + 1}/{self.edge_iterations}")

                # Use ThreadPoolExecutor to limit concurrent edge server training
                with ThreadPoolExecutor(max_workers=max_parallel_edge_servers) as executor:
                    future_to_edge = {
                        executor.submit(
                            self.edge_server_training,
                            edge_id,
                            edge_info['devices'],
                            edge_info['model_state'],
                            edge_iteration,
                            global_iteration
                        ): edge_id
                        for edge_id, edge_info in self.edge_servers.items()
                    }

                    # Collect results as each edge server finishes
                    for future in as_completed(future_to_edge):
                        edge_id = future_to_edge[future]
                        try:
                            updated_state = future.result()  # Result is updated edge model state
                            # Update the edge server's model state
                            self.edge_servers[edge_id]['model_state'] = updated_state
                            print(f"  Edge Server {edge_id}: Updated model after Edge Iteration {edge_iteration + 1}")
                        except Exception as e:
                            print(f"Edge Server {edge_id} encountered an error: {e}")

            # After all edge iterations, perform global aggregation using FedNova
            edge_deltas = []
            edge_steps_list = []
            total_global_steps = 0

            for edge_id, edge_info in self.edge_servers.items():
                # An edge server whose training raised never recorded its steps; skip it
                edge_steps = edge_info.get('total_edge_steps', 0)
                if edge_steps == 0:
                    continue
                delta_model = {}
                for key in edge_info['model_state'].keys():
                    delta_model[key] = edge_info['model_state'][key] - self.global_model_state[key]
                edge_deltas.append(delta_model)
                edge_steps_list.append(edge_steps)
                total_global_steps += edge_steps

            # Aggregate edge server updates using FedNova
            self.global_model_state = self.aggregate_models_global_fednova(
                self.global_model_state,
                edge_deltas,
                edge_steps_list,
                total_global_steps
            )
            global_model_state = deepcopy(self.global_model_state)
            print(f"Global Iteration {global_iteration + 1}: Aggregated all edge models using FedNova")

            # Update edge server models with the new global model
            for edge_id in self.edge_servers.keys():
                self.edge_servers[edge_id]['model_state'] = deepcopy(self.global_model_state)

            # Cloud server computation time for aggregation
            num_edge_models = len(self.edge_servers)
            aggregation_time = num_edge_models * 0.2  # 0.2 seconds per edge model
            cloud_computation_time += aggregation_time
            cloud_computation_energy += aggregation_time * self.computation_energy_rate

            # Evaluate global model and store accuracy
            global_accuracy = self.evaluate_model(global_model_state, self.get_test_loader())
            print(f"Global Iteration {global_iteration + 1}: Global Model Accuracy: {global_accuracy:.2f}%")

            self.accuracies['global_iterations'].append(global_accuracy)

            # Save the checkpoint after evaluating the global model
            self.save_checkpoint(self.global_model_state, global_iteration)

            # Periodically write accuracies to file
            with open(self.metrics_file, 'w') as f:
                json.dump(self.accuracies, f, indent=4)

            # Note: Assignments will be recalculated at the beginning of the next global iteration

        # Store cloud server energy and time
        self.energy_consumption['cloud_server'] = cloud_computation_energy
        self.time_delays['cloud_server'] = cloud_computation_time

    def main(self, global_iterations, edge_iterations, local_epochs):
        """
        Initialize the global model and start the federated learning process.
        """
        self.global_iterations = global_iterations
        self.edge_iterations = edge_iterations
        self.local_epochs = local_epochs

        # Initialize global model state on CPU/GPU
        global_model = self.model_class(**self.model_args).to(self.device)
        self.global_model_state = deepcopy(global_model.state_dict())  # Initialize global model state

        print("Starting Federated Learning...")
        # Run federated learning
        self.federated_learning(self.global_model_state, max_parallel_edge_servers=1)

        # Optional: re-evaluate the final global model. federated_learning already evaluates
        # the global model after every global iteration, so this is left disabled.
        # final_accuracy = self.evaluate_model(self.global_model_state, self.get_test_loader())
        # print(f"Final Global Model Accuracy: {final_accuracy:.2f}%")

        # Save the accuracies to JSON
        with open(self.metrics_file.replace('.json', '_accuracies.json'), 'w') as json_file:
            json.dump(self.accuracies, json_file, indent=4)

        # Save summary metrics
        self.save_summary_metrics(self.metrics_file.replace('.json', '_full_metrics.json'))

        # After training, consolidate metrics into a single dictionary
        metrics_summary = {
            "Energy Consumption": self.energy_consumption,
            "Time Delays": self.time_delays,
            "Bandwidth Usage": self.bandwidth_usage,
        }

        # Save metrics to a JSON file
        with open(self.metrics_file.replace('.json', '_summary.json'), 'w') as f:
            json.dump(metrics_summary, f, indent=4)

    def get_test_loader(self):
        """
        Create a DataLoader for the test dataset.
        """
        test_dataset = TensorDataset(
            torch.tensor(self.test_images, dtype=torch.float32),
            torch.tensor(self.test_labels, dtype=torch.long)
        )
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
        return test_loader

    def prepare_data_loader(self, images, labels, batch_size):
        """
        Prepare DataLoader with correct tensor types.

        Args:
            images (numpy array or similar): Input images.
            labels (numpy array or similar): Corresponding labels.
            batch_size (int): Batch size.

        Returns:
            DataLoader: Prepared DataLoader.
        """
        tensor_dataset = TensorDataset(
            torch.tensor(images, dtype=torch.float32),  # Ensure images are float
            torch.tensor(labels, dtype=torch.long)       # Labels should be long
        )
        loader = DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)
        return loader
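
The weighting used by `aggregate_models_fednova` and `aggregate_models_global_fednova` above can be sketched on its own. The following is a minimal illustration, not the class's actual code path: it swaps tensors for plain lists of floats so it runs without PyTorch, and the names `aggregate_step_weighted` and `State` are hypothetical. Each delta is scaled by its share of the total local steps and the sum is added to the base state.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a state_dict: parameter name -> flat list of values.
State = Dict[str, List[float]]


def aggregate_step_weighted(base: State, results: List[Tuple[State, int]]) -> State:
    """Add the step-weighted sum of the deltas to a copy of the base state."""
    updated = {key: list(values) for key, values in base.items()}
    total_steps = sum(steps for _, steps in results)
    if total_steps == 0:
        # No participant took a step, so there is nothing to aggregate.
        return updated

    for delta, steps in results:
        weight = steps / total_steps  # share of the total local steps
        for key, values in delta.items():
            for i, value in enumerate(values):
                updated[key][i] += weight * value
    return updated


base = {"w": [1.0, 2.0]}
results = [({"w": [1.0, 0.0]}, 30), ({"w": [0.0, 4.0]}, 10)]
new_state = aggregate_step_weighted(base, results)
print(new_state)  # {'w': [1.75, 3.0]}
```

With 30 and 10 steps the weights are 0.75 and 0.25, so a participant that ran more local steps moves the aggregate further. The zero-step guard is an addition of this sketch; the methods above divide by the total directly.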

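The proximal term that `train_local_model` adds to the loss penalizes the distance between the local and global shared parameters. As a rough sketch under simplifying assumptions (flat lists of floats instead of tensors, and hypothetical helper names `proximal_penalty` and `proximal_gradient`), the penalty and the gradient it contributes look like this:

```python
from typing import List, Sequence


def proximal_penalty(local: Sequence[float], global_: Sequence[float], mu: float) -> float:
    """Return (mu / 2) * ||local - global||^2 for two flat parameter vectors."""
    if len(local) != len(global_):
        raise ValueError("parameter vectors must have the same length")
    return 0.5 * mu * sum((l - g) ** 2 for l, g in zip(local, global_))


def proximal_gradient(local: Sequence[float], global_: Sequence[float], mu: float) -> List[float]:
    """Gradient of the penalty w.r.t. the local parameters: mu * (local - global)."""
    return [mu * (l - g) for l, g in zip(local, global_)]


local_w = [1.0, 2.0]
global_w = [0.0, 0.0]
penalty = proximal_penalty(local_w, global_w, mu=0.5)
gradient = proximal_gradient(local_w, global_w, mu=0.5)
print(penalty, gradient)  # 1.25 [0.5, 1.0]
```

The gradient points away from the global parameters, so each optimizer step pulls the local weights back toward the global model in proportion to `mu`.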

In [None]:
class MainAgent:
    def __init__(self, devices_df, num_edge_servers, metrics_file, cluster_to_server_map):
        self.devices_df = devices_df.reset_index(drop=True)
        self.num_edge_servers = num_edge_servers
        self.cluster_to_server_map = cluster_to_server_map  # Pass cluster-to-server map

        # Prepare data for cluster assignment
        self.cluster_bandwidth = self.devices_df.groupby('cluster')['bandwidth'].sum().values
        self.edge_server_capacities = self.generate_edge_server_capacities()

        # Create environments
        self.cluster_env = ClusterAssignmentEnv(self.cluster_bandwidth, self.edge_server_capacities, devices_df=self.devices_df)
        self.scheduling_env = DeviceSchedulingEnv(self.devices_df, self.cluster_to_server_map)  # Pass map

        # Load trained sub-agents
        self.cluster_agent = PPO.load(metrics_file.replace('.json', '') + "_cluster_assignment_agent")
        self.scheduling_agent = PPO.load(metrics_file.replace('.json', '') + "_device_scheduling_agent")

        self.accuracies = {
            'local_epochs': {},       # Store local epoch accuracies
            'edge_iterations': {},    # Store edge iteration accuracies
            'global_iterations': []   # Store global iteration accuracies
        }

    def generate_edge_server_capacities(self):
        max_cluster_bandwidth = self.cluster_bandwidth.max()
        total_bandwidth_needed = self.cluster_bandwidth.sum()
        base_capacity = max(total_bandwidth_needed / self.num_edge_servers, max_cluster_bandwidth)

        # Ensure a minimum capacity to prevent zero-capacity servers
        min_capacity = base_capacity * 0.8
        max_capacity = base_capacity * 1.6

        edge_server_capacities = np.random.uniform(
            min_capacity,
            max_capacity,
            size=self.num_edge_servers
        )
        edge_server_capacities = np.round(edge_server_capacities).astype(int)

        # Ensure no server has zero capacity
        edge_server_capacities = np.maximum(edge_server_capacities, 1)
        return edge_server_capacities

    def run(self, federated_system_params):
        """
        Executes cluster assignment and device scheduling, updates `devices_df`,
        and integrates it into the Federated Learning System.
        """
        print("Running Cluster Assignment and Device Scheduling Agents...")

        # Cluster Assignment
        cluster_state = self.cluster_env.reset()
        cluster_action, _ = self.cluster_agent.predict(cluster_state)
        assignments = dict(zip(range(self.cluster_env.num_clusters), cluster_action))
        self.devices_df['assigned_servers'] = self.devices_df['cluster'].map(assignments)

        # Evaluate Cluster Assignment
        cluster_metrics = self.cluster_env.evaluate(assignments)
        print(f"Cluster Assignment Evaluation Metrics: {cluster_metrics}")
        with open(federated_system_params['metrics_file'].replace('.json', '_cluster_assignment_evaluation.json'), 'w') as f:
            json.dump(cluster_metrics, f, indent=4)

        # Device Scheduling
        scheduling_state = self.scheduling_env.reset()
        scheduling_action, _ = self.scheduling_agent.predict(scheduling_state)
        scheduled_devices = [i for i, a in enumerate(scheduling_action) if a == 1]
        self.devices_df['is_scheduled'] = False
        self.devices_df.loc[scheduled_devices, 'is_scheduled'] = True

        # Evaluate Device Scheduling
        scheduling_metrics = self.scheduling_env.evaluate(scheduling_action)
        print(f"Device Scheduling Evaluation Metrics: {scheduling_metrics}")
        with open(federated_system_params['metrics_file'].replace('.json', '_device_scheduling_evaluation.json'), 'w') as f:
            json.dump(scheduling_metrics, f, indent=4)

        print("Updated `devices_df` with assignments and scheduling information.")

        # Prepare parameters for the FederatedLearningSystem
        filtered_params = {
            k: v for k, v in federated_system_params.items()
            if k not in ['test_images', 'test_labels', 'dataset_name', 'metrics_file']
        }

        # Create the FederatedLearningSystem
        federated_system = FederatedLearningSystem(
            devices_df=self.devices_df,
            test_images=federated_system_params['test_images'],
            test_labels=federated_system_params['test_labels'],
            dataset_name=federated_system_params['dataset_name'],
            metrics_file=federated_system_params['metrics_file'],
            cluster_agent=self.cluster_agent,
            scheduling_agent=self.scheduling_agent,
            cluster_env=self.cluster_env,
            scheduling_env=self.scheduling_env,
            **filtered_params
        )

        # Run federated learning
        print("Starting Federated Learning...")
        federated_system.main(
            global_iterations=federated_system_params['global_iterations'],
            edge_iterations=federated_system_params['edge_iterations'],
            local_epochs=federated_system_params['local_epochs']
        )
        print("Federated Learning completed.")
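
The capacity bounds in `generate_edge_server_capacities` can be checked in isolation. The sketch below is an approximation rather than the method itself: it uses the standard-library `random` module in place of NumPy, takes a seeded generator so the run is repeatable, and the function name `sample_edge_capacities` is hypothetical. The base capacity is the larger of the average bandwidth per server and the largest single cluster, and each server draws from 0.8 to 1.6 times that base.

```python
import random
from typing import List


def sample_edge_capacities(cluster_bandwidth: List[float], num_edge_servers: int,
                           rng: random.Random) -> List[int]:
    """Draw one integer capacity per edge server from [0.8 * base, 1.6 * base]."""
    base = max(sum(cluster_bandwidth) / num_edge_servers, max(cluster_bandwidth))
    low, high = 0.8 * base, 1.6 * base
    # Round to whole units and never return a zero-capacity server.
    return [max(round(rng.uniform(low, high)), 1) for _ in range(num_edge_servers)]


rng = random.Random(0)  # fixed seed, an assumption made for repeatability
cluster_bandwidth = [120.0, 80.0, 200.0]
capacities = sample_edge_capacities(cluster_bandwidth, num_edge_servers=5, rng=rng)
print(capacities)
```

Here the average per server is 80 and the largest cluster needs 200, so the base is 200 and every capacity falls between 160 and 320. Taking the maximum means any single server can host the largest cluster whenever the draw is at or above the base.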


# Execution Part

In [None]:
# Example dynamic usage
if __name__ == "__main__":

    gc.collect()    

    # Set parameters dynamically for experimentation
    configurations = [
        {"mfactor": 7, "dataset_name": "mnist", "edge_num": 5, "num_devices": 20, "global_iterations": 10, "edge_server_iterations": 3, "local_epochs": 5, "pr_data_redist": 0.7}, 
        # {"mfactor": 5, "dataset_name": "mnist", "edge_num": 5, "num_devices": 30, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "mnist", "edge_num": 5, "num_devices": 50, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "mnist", "edge_num": 5, "num_devices": 70, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "mnist", "edge_num": 5, "num_devices": 100, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        
        {"mfactor": 7, "dataset_name": "fashion_mnist", "edge_num": 5, "num_devices": 20, "global_iterations": 10, "edge_server_iterations": 3, "local_epochs": 5, "pr_data_redist": 0.7},
        # {"mfactor": 5, "dataset_name": "fashion_mnist", "edge_num": 5, "num_devices": 30, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "fashion_mnist", "edge_num": 5, "num_devices": 50, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "fashion_mnist", "edge_num": 5, "num_devices": 70, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "fashion_mnist", "edge_num": 5, "num_devices": 100, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        
        {"mfactor": 7, "dataset_name": "cifar10", "edge_num": 5, "num_devices": 20, "global_iterations": 10, "edge_server_iterations": 3, "local_epochs": 5, "pr_data_redist": 0.7},
        # {"mfactor": 5, "dataset_name": "cifar10", "edge_num": 5, "num_devices": 30, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "cifar10", "edge_num": 5, "num_devices": 50, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "cifar10", "edge_num": 5, "num_devices": 70, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5},
        # {"mfactor": 5, "dataset_name": "cifar10", "edge_num": 5, "num_devices": 100, "global_iterations": 10, "edge_server_iterations": 5, "local_epochs": 5}           
    ]
    
    # Create a folder to store metric files
    metrics_dir = "metrics"
    os.makedirs(metrics_dir, exist_ok=True)
    print(f"Metrics will be stored in the directory: {metrics_dir}")

    for config in configurations:
        print(f"\nRunning configuration: {config}")

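        # Construct the metrics file path from the dataset name and iteration counts.
        # The name does not include num_devices, so runs that differ only in device
        # count resolve to the same path.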
        metrics_file = os.path.join(
            metrics_dir,
            "{}_{}_global_{}_edge_{}_local.json".format(
                config["dataset_name"], 
                config["global_iterations"], 
                config["edge_server_iterations"], 
                config["local_epochs"]
            )
        )    

        # Initialize the data distributor with the configured number of edge servers
        data_distributor = GNNClustering(
            num_devices=config["num_devices"],
            dataset_name=config["dataset_name"],
            mfactor=config["mfactor"],
            num_edge_servers=config["edge_num"],
            metrics_file=metrics_file,
        )

        # Distribute the dataset across devices and keep the returned test set
        test_images, test_labels = data_distributor.distribute_data()

        # Cluster the devices and collect the resulting device table
        devices_df = data_distributor.clustering_devices()

        # Compare clustering methods
        # data_distributor.compare_clustering_methods()

        # Initialize the hybrid data redistributor
        data_redistributor = HybridDataRedistributor(
            devices_df,
            dataset_name=config["dataset_name"],
            metrics_file=metrics_file,
        )

        # Perform hybrid data redistribution
        devices_df, label_presence = data_redistributor.redistribute_data(
            percentage_threshold=config["pr_data_redist"]
        )

        # Sum device bandwidth per cluster as input for the cluster assignment agent.
        # The number of edge servers is taken from the number of distinct clusters.
        cluster_bandwidth_series = devices_df.groupby('cluster')['bandwidth'].sum()
        cluster_bandwidth = cluster_bandwidth_series.values
        num_edge_servers = devices_df['cluster'].nunique()

        print(f"\nNumber of Edge Servers (Clusters): {num_edge_servers}")
        print(f"Cluster Bandwidths: {cluster_bandwidth_series.to_dict()}")

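        # Draw each server's capacity uniformly from 0.8x to 1.6x a base capacity.
        # The base is the larger of the mean per-server demand and the largest
        # single-cluster demand.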
        max_cluster_bandwidth = cluster_bandwidth.max()
        total_bandwidth_needed = cluster_bandwidth.sum()
        base_capacity = max(total_bandwidth_needed / num_edge_servers, max_cluster_bandwidth)
        edge_server_capacities = np.random.uniform(
            base_capacity * 0.8,
            base_capacity * 1.6,
            size=num_edge_servers
        )
        edge_server_capacities = np.round(edge_server_capacities).astype(int)

        print(f"Edge Server Capacities: {edge_server_capacities}")

        # Train the cluster assignment agent and generate the cluster-to-server map
        cluster_model = train_cluster_assignment_agent(
            cluster_bandwidth, 
            edge_server_capacities, 
            devices_df, 
            timesteps=10000, 
            metrics_file=metrics_file
        )

        # Query the trained policy once; entry i of the action is the server chosen for cluster i
        cluster_state = ClusterAssignmentEnv(cluster_bandwidth, edge_server_capacities, devices_df).reset()
        cluster_action, _ = cluster_model.predict(cluster_state)
        cluster_to_server_map = dict(enumerate(cluster_action))

        print(f"Cluster-to-Server Map: {cluster_to_server_map}")

        # Train the device scheduling agent with the cluster-to-server map
        scheduling_model = train_device_scheduling_agent(
            devices_df, 
            timesteps=10000, 
            metrics_file=metrics_file, 
            cluster_to_server_map=cluster_to_server_map
        )

        # Initialize the main agent
        main_agent = MainAgent(
            devices_df, 
            config["edge_num"], 
            metrics_file, 
            cluster_to_server_map
        )

        # Parameters for federated learning
        federated_system_params = {
            'test_images': test_images,
            'test_labels': test_labels,
            'dataset_name': config["dataset_name"],
            'metrics_file': metrics_file,
            'global_iterations': config["global_iterations"],
            'edge_iterations': config["edge_server_iterations"],
            'local_epochs': config["local_epochs"],
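            # CIFAR-10 images are RGB; MNIST and Fashion-MNIST are grayscale.
            # Assumes the data loader keeps CIFAR-10's three channels.
            'input_channels': 3 if config["dataset_name"] == "cifar10" else 1,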
            'num_classes': 10
        }

        # Run the main agent
        main_agent.run(federated_system_params)

        print(f"Configuration {config} completed.\n")

        # Release objects from this run before starting the next configuration
        gc.collect()
