# Energy-Efficient Wireless Sensor Network (EE-WSN) Simulation

This notebook compares 6 clustering algorithms for WSN energy optimization:
1. **RLHC_PROPOSED** - Novel RL-driven hybrid clustering method
2. **LEACH_BASELINE** - Low-Energy Adaptive Clustering Hierarchy
3. **DEEC_BASELINE** - Distributed Energy-Efficient Clustering
4. **FUZZY_C_MEANS_BASELINE** - Fuzzy C-Means clustering
5. **PSO_BASELINE** - Particle Swarm Optimization with Genetic Algorithm
6. **ACO_BASELINE** - Ant Colony Optimization

## 1. Setup & Installation

In [None]:
# Install required packages (uncomment if needed)
# !pip install numpy matplotlib pandas scikit-fuzzy

import numpy as np
import random
import heapq
import matplotlib.pyplot as plt
import pandas as pd
import math
from collections import defaultdict
import skfuzzy as fuzz
from IPython.display import HTML
import matplotlib.animation as animation

print('All packages imported successfully!')

## 2. Configuration Constants

In [None]:
# Simulation Parameters
AREA_SIZE = 1000  # side length of the square deployment field, in meters
N_NODES = 150       # number of sensor nodes
INIT_ENERGY = 2.0   # joules per node
NUM_CLUSTERS = int(N_NODES * 0.1)  # target number of clusters: 10% of the nodes
ROUNDS = 1000  # default round count; overridden below for the mobile BS modes
PACKET_SIZE_BYTES = 128
PACKET_SIZE_BITS = PACKET_SIZE_BYTES * 8
N_HALF = N_NODES // 2  # half of the node count
SEED = 10  # seed handed to np.random.seed / random.seed at the start of a run
SECTOR_ROWS = 3  # rows of the static sector grid (see assign_static_sectors)
SECTOR_COLS = 3  # columns of the static sector grid
BASE_STATION_FIXED = (AREA_SIZE / 2, AREA_SIZE / 2)  # field center
BS_TRAJECTORY = []  # per-round BS positions, filled lazily by get_bs_position

# Routing Parameters
GLOBAL_ROUTING_INFO = {} # per-method log of CHs, clusters and links, keyed by method name
ENABLE_MULTI_HOP = True

# Base/Sink Mobility
BS_MODE = "mobile_circle"   # one of: fixed, mobile_circle, mobile_random

# NOTE: EFFECTIVE_ROUNDS only exists when BS_MODE is "mobile_circle";
# get_bs_position reads it only in that mode.
if BS_MODE == "mobile_circle":
    ROUNDS = 500
    EFFECTIVE_ROUNDS = 275  # the BS completes one full circle in this many rounds
elif BS_MODE == "mobile_random":
    ROUNDS = 150


# Radio Energy Model (First-order)
E_ELEC = 50e-9  # J/bit, electronics cost for both TX and RX
EPS_FS = 10e-12  # J/bit/m^2 (free-space amplifier)
EPS_MP = 0.0013e-12  # J/bit/m^4 (multipath amplifier)
# Crossover distance between the two amplifier models (~87.7 m for the values above).
D0 = math.sqrt(EPS_FS / EPS_MP) if EPS_MP > 0 else 87.7  # threshold


# DEEC Parameters
P_OPT = NUM_CLUSTERS / N_NODES  # desired CH fraction
MAX_CH = NUM_CLUSTERS
M_FRACTION_ADV = 0.1     # fraction of advanced nodes (e.g. 10%), i.e. heterogeneous network
A_ADV_ENERGY   = 1.0     # advanced nodes get (1 + a) times a normal node's energy, here 2x
E0 = INIT_ENERGY         # alias of INIT_ENERGY


# RL Parameters (tabular Q-learning)
STATE_BINS_E = 4       # energy levels
STATE_BINS_LOAD = 3    # channel load levels
STATE_BINS_PDR = 3     # packet delivery ratio levels
EPSILON = 1.0          # initial exploration rate
EPS_MIN = 0.05         # minimum epsilon
EPS_DECAY = 0.995      # multiplicative epsilon decay, applied on every update_q call
ALPHA = 0.1            # learning rate
GAMMA = 0.95           # discount factor

# RL Actions (the ids double as column indices into a Q-table row)
A_NOOP = 0          # no operation
A_REASSIGN_FEW = 1  # reassign few nodes
A_SWITCH_CH = 2     # switch cluster head
A_REDUCE_TX = 3     # reduce transmission power
A_INCREASE_TX = 4   # increase transmission power
ACTIONS = [A_NOOP, A_REASSIGN_FEW, A_SWITCH_CH, A_REDUCE_TX, A_INCREASE_TX]


# PSO parameters
W = 0.9       # inertia weight
C1 = 1.8      # cognitive weight
C2 = 1.8      # social weight
PSO_POP = 30  # number of particles
GA_GEN = 10    # number of GA generations per round


# ACO parameters
ACO_ALPHA = 1.0      # pheromone importance
ACO_BETA = 3.0       # heuristic importance
RHO = 0.3            # pheromone evaporation
ACO_Q = 50          # pheromone constant
ANTS = NUM_CLUSTERS  # number of ants


print('Configuration loaded successfully!')

## 3. Utility Functions

In [None]:
def energy_tx(bits, d):
    """Return the energy (J) needed to transmit `bits` over `d` meters.

    First-order radio model: a per-bit electronics cost plus an amplifier
    term that scales with d^2 below the crossover distance D0 (free space)
    and with d^4 otherwise (multipath).
    """
    electronics = bits * E_ELEC
    if d < D0:
        amplifier = bits * EPS_FS * (d**2)
    else:
        amplifier = bits * EPS_MP * (d**4)
    return electronics + amplifier

def energy_rx(bits):
    """Return the energy (J) needed to receive `bits` (electronics cost only)."""
    rx_cost = E_ELEC * bits
    return rx_cost

def distance(a, b):
    """Return the 2-D Euclidean distance between points `a` and `b`.

    Only the first two coordinates (index 0 and 1) of each point are read.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.hypot(dx, dy)

def discretize(value, bins):
    """Return the index of the bin that `value` falls into.

    `bins` holds the increasing bin edges; the result ranges from 0 (below
    the first edge) to len(bins) (at or above the last edge).
    """
    bucket = np.digitize([value], bins=bins)
    return int(bucket[0])

def build_bins():
    """Return the (load_bins, pdr_bins) edges used to discretize the RL state.

    - load imbalance: low (<1.2), medium (1.2-1.5), high (>1.5)
    - PDR reliability: low (<0.90), medium (0.90-0.97), high (>0.97)
    """
    return [1.2, 1.5], [0.90, 0.97]

def state_from_metrics(avg_energy, cluster_sizes, pdr, LOAD_BINS, PDR_BINS):
    """Map round metrics to the discrete RL state (e_bucket, load_bucket, pdr_bucket).

    - Energy: avg_energy as a fraction of INIT_ENERGY, split into
      STATE_BINS_E equal-width levels over [0, 1].
    - Load: largest cluster size divided by the mean cluster size
      (1.0 when there are no clusters), bucketed with LOAD_BINS.
    - PDR: bucketed with PDR_BINS.
    """
    # Interior edges only: the outer 0 and 1 edges are dropped.
    inner_edges = np.linspace(0, 1, STATE_BINS_E + 1)[1:-1]
    energy_level = discretize(avg_energy / INIT_ENERGY, bins=inner_edges)

    if len(cluster_sizes) != 0:
        # The small epsilon guards against a zero mean.
        imbalance = max(cluster_sizes) / (np.mean(cluster_sizes) + 1e-12)
    else:
        imbalance = 1.0

    load_level = discretize(imbalance, bins=LOAD_BINS)
    pdr_level = discretize(pdr, bins=PDR_BINS)
    return (energy_level, load_level, pdr_level)

def measure_metrics(nodes_energy, clusters, chs, successful_packets, total_packets):
    """Summarize one round of the simulation.

    Parameters
    ----------
    nodes_energy : array-like of float
        Residual energy per node. Any sequence is accepted (list, tuple,
        ndarray); it is not modified.
    clusters : list of lists or None
        Cluster membership; only the sizes are used.
    chs : list
        Cluster-head indices. Accepted for interface symmetry; not used here.
    successful_packets, total_packets : int
        Packet counters for the round.

    Returns
    -------
    (avg_e, pdr, alive, cluster_sizes)
        Mean residual energy (0.0 for an empty network), packet delivery
        ratio (0.0 when no packet was sent), number of nodes with energy > 0,
        and the list of cluster sizes.
    """
    # Coerce once so plain lists work with the vectorized comparison below.
    energies = np.asarray(nodes_energy, dtype=float)
    alive = int(np.count_nonzero(energies > 0))
    # np.mean of an empty array is NaN (with a warning); report 0.0 instead.
    avg_e = float(energies.mean()) if energies.size else 0.0
    # The tiny epsilon keeps the ratio defined when total_packets is 0.
    pdr = successful_packets / (total_packets + 1e-12)
    cluster_sizes = [len(c) for c in clusters] if clusters else []
    return avg_e, pdr, alive, cluster_sizes

def initialize_heterogeneous_energies():
    """
    Build the initial per-node energies for a 2-level heterogeneous network.

    Normal nodes start with E0_eff and advanced nodes with
    (1 + A_ADV_ENERGY) * E0_eff, where E0_eff is scaled down so that the
    network total stays at N_NODES * INIT_ENERGY (exact when
    M_FRACTION_ADV * N_NODES is a whole number, since the advanced-node
    count is truncated to an int).

    Returns
    -------
    nodes_energy : ndarray of float, shape (N_NODES,)
    is_advanced : ndarray of bool, shape (N_NODES,)
        True for the randomly chosen advanced nodes (drawn with np.random).
    """
    normal_energy = INIT_ENERGY / (1.0 + A_ADV_ENERGY * M_FRACTION_ADV)

    nodes_energy = np.full(N_NODES, normal_energy, dtype=float)
    is_advanced = np.zeros(N_NODES, dtype=bool)

    num_adv = int(M_FRACTION_ADV * N_NODES)
    if num_adv <= 0:
        return nodes_energy, is_advanced

    # Sample distinct node indices to promote to "advanced".
    promoted = np.random.choice(N_NODES, size=num_adv, replace=False)
    is_advanced[promoted] = True
    nodes_energy[promoted] = (1.0 + A_ADV_ENERGY) * normal_energy
    return nodes_energy, is_advanced


def assign_static_sectors(nodes_pos, area_size=1000, rows=3, cols=3):
    """
    Map every node to a cell of a fixed rows x cols grid laid over the field.

    nodes_pos: iterable of (x, y) positions
    area_size: the field is [0, area_size] x [0, area_size]
    rows, cols: grid dimensions (3x3 => 9 sectors)
    returns: integer ndarray where entry i is the row-major sector index
             (0 .. rows*cols-1) of node i
    """
    cell_w = area_size / cols
    cell_h = area_size / rows
    last_col = cols - 1
    last_row = rows - 1

    # min(...) folds points lying exactly on the far edge (x == area_size or
    # y == area_size) back into the last column / row.
    sector_ids = [
        min(int(y // cell_h), last_row) * cols + min(int(x // cell_w), last_col)
        for (x, y) in nodes_pos
    ]
    return np.array(sector_ids, dtype=int)


import numpy as np

def get_bs_position(rnd=None):
    """
    Return the base-station (x, y) position for round `rnd` under BS_MODE.

    - "fixed" (and any unrecognized mode): always BASE_STATION_FIXED.
    - "mobile_circle": the BS moves on a circle of radius 0.3 * AREA_SIZE
      around the field center, advancing 2*pi / EFFECTIVE_ROUNDS per round.
    - "mobile_random": the BS jumps to a uniformly random point each round.

    Positions are cached in the module-level BS_TRAJECTORY list (index
    rnd - 1), so asking for the same round again returns the same position
    and missing rounds up to `rnd` are generated on demand.

    With rnd=None the most recent cached position is returned; if nothing is
    cached yet, circle mode returns BASE_STATION_FIXED while random mode
    draws (and caches) a first random point.

    Rounds are 1-based; rnd <= 0 is not handled specially and simply indexes
    the cache from the end.
    """
    if BS_MODE not in ("mobile_circle", "mobile_random"):
        return BASE_STATION_FIXED

    on_circle = BS_MODE == "mobile_circle"

    def random_point():
        # x is drawn before y to keep the random stream order stable.
        px = np.random.uniform(0, AREA_SIZE)
        py = np.random.uniform(0, AREA_SIZE)
        return (px, py)

    if rnd is None:
        if BS_TRAJECTORY:
            return BS_TRAJECTORY[-1]
        if on_circle:
            return BASE_STATION_FIXED
        BS_TRAJECTORY.append(random_point())
        return BS_TRAJECTORY[-1]

    if len(BS_TRAJECTORY) < rnd:
        # Extend the cache from the first missing round up to `rnd`.
        first_missing = len(BS_TRAJECTORY) + 1
        if on_circle:
            cx, cy = AREA_SIZE / 2, AREA_SIZE / 2
            radius = 0.3 * AREA_SIZE
            step = 2 * np.pi / EFFECTIVE_ROUNDS
            for r in range(first_missing, rnd + 1):
                angle = step * r
                BS_TRAJECTORY.append(
                    (cx + radius * np.cos(angle), cy + radius * np.sin(angle))
                )
        else:
            for _ in range(first_missing, rnd + 1):
                BS_TRAJECTORY.append(random_point())

    return BS_TRAJECTORY[rnd - 1]


# To store routing info
def init_method_routing(method_name):
    """Create the empty routing record for `method_name` if it has none yet.

    The record lives in GLOBAL_ROUTING_INFO and holds three per-round dicts:
    "chs", "clusters" and "links". An existing record is left untouched.
    """
    if method_name in GLOBAL_ROUTING_INFO:
        return
    GLOBAL_ROUTING_INFO[method_name] = {
        section: {} for section in ("chs", "clusters", "links")
    }

def log_link(method_name, rnd, src, dst):
    """Append the hop (src, dst) to the links logged for `method_name` in round `rnd`."""
    init_method_routing(method_name)
    links_by_round = GLOBAL_ROUTING_INFO[method_name]["links"]
    links_by_round.setdefault(rnd, []).append((src, dst))

def log_chs_and_clusters(method_name, rnd, chs, clusters):
    """
    Record the cluster heads and cluster membership of round `rnd`.

    chs:      iterable of CH indices
    clusters: iterable of iterables (cluster membership)

    Copies are stored, so later changes to the arguments do not alter the log.
    An earlier entry for the same round is overwritten.
    """
    init_method_routing(method_name)
    record = GLOBAL_ROUTING_INFO[method_name]
    record["chs"][rnd] = list(chs)
    record["clusters"][rnd] = [list(members) for members in clusters]


print('Utility functions defined successfully!')

## 4. Communication Module

In [None]:
def radio_comm(clusters, chs, nodes_energy, nodes_pos, base_station, tx_power_factor, successful_packets=0, total_packets=0):
    """
    Simulate one round of single-hop clustered communication with end-to-end PDR.

    - Every alive member (energy > 0) of every cluster generates one packet
      and sends it to its cluster head chs[cidx]; a CH listed in its own
      cluster is treated like any other member.
    - Each CH holding at least one packet sends ONE aggregate transmission
      directly to the base station.
    - A member packet counts as successful only if its CH's transmission to
      the base station succeeds.

    `nodes_energy` is updated in place. `successful_packets` and
    `total_packets` are running counters that are incremented and returned.

    Returns
    -------
    (nodes_energy, successful_packets, total_packets)
    """
    # packets_at_head[c] = member packets currently stored at the CH of cluster c
    packets_at_head = [0 for _ in clusters]

    # ---- Phase 1: members -> CH ----
    for cluster_id, members in enumerate(clusters):
        head = chs[cluster_id]

        for sender in members:
            if nodes_energy[sender] <= 0:
                # Dead nodes neither generate nor transmit packets.
                continue

            total_packets += 1

            head_dead = nodes_energy[head] <= 0
            link_len = distance(nodes_pos[sender], nodes_pos[head])
            tx_cost = energy_tx(PACKET_SIZE_BITS, link_len) * tx_power_factor

            if not head_dead:
                rx_cost = energy_rx(PACKET_SIZE_BITS)
                if nodes_energy[sender] >= tx_cost and nodes_energy[head] >= rx_cost:
                    # Both ends can afford the exchange: packet reaches the CH.
                    nodes_energy[sender] -= tx_cost
                    nodes_energy[head] -= rx_cost
                    packets_at_head[cluster_id] += 1
                    continue

            # Delivery failed (dead CH or insufficient energy on either side).
            # The sender still pays for the attempt, or is drained if it
            # cannot afford it. The CH is not charged for a failed reception.
            if nodes_energy[sender] >= tx_cost:
                nodes_energy[sender] -= tx_cost
            elif head_dead or nodes_energy[sender] > 0:
                nodes_energy[sender] = 0

    # ---- Phase 2: CH -> base station ----
    for cluster_id, head in enumerate(chs):
        pending = packets_at_head[cluster_id]
        if pending == 0 or nodes_energy[head] <= 0:
            # Nothing to forward, or the CH died and its packets are lost.
            continue

        # One transmission carries the whole aggregate.
        uplink_len = distance(nodes_pos[head], base_station)
        uplink_cost = energy_tx(PACKET_SIZE_BITS, uplink_len) * tx_power_factor

        if nodes_energy[head] >= uplink_cost:
            nodes_energy[head] -= uplink_cost
            successful_packets += pending
        else:
            # Not enough energy: the packets are lost and the CH is drained.
            nodes_energy[head] = max(0, nodes_energy[head] - uplink_cost)

    return nodes_energy, successful_packets, total_packets


## 5. Proposed Method (RLHC) - Layer 1: DEEC

In [None]:
def deec_select(nodes_energy, p_opt, min_chs=1, max_chs_factor=3.0):
    """
    Layer-1: DEEC/DEECP-style probabilistic CH candidate selection.

    Each node i is elected with probability P_i = p_opt * (E_i / E_avg),
    clipped to [0, 1], so nodes with more residual energy (e.g. advanced
    nodes) are elected more often. One `random.random()` draw is made per
    node, in node order.

    The candidate list is then bounded:
    - at most int(max_chs_factor * p_opt * n_nodes) candidates are kept
      (those with the highest P_i), when that cap is positive;
    - if fewer than `min_chs` were elected, the `min_chs` nodes with the
      highest P_i are returned instead.

    Returns a list of node indices, or [] when the mean energy is <= 0.
    """
    n_nodes = len(nodes_energy)

    mean_energy = np.mean(nodes_energy)
    if mean_energy <= 0:
        return []  # network dead

    prob = np.clip(p_opt * (nodes_energy / mean_energy), 0.0, 1.0)

    # One Bernoulli trial per node.
    elected = [idx for idx, p in enumerate(prob) if random.random() <= p]

    # Upper bound on the number of candidates.
    cap = int(max_chs_factor * (p_opt * n_nodes))
    if cap > 0 and len(elected) > cap:
        by_prob = sorted(elected, key=lambda idx: prob[idx], reverse=True)
        elected = by_prob[:cap]

    # Lower bound: fall back to the most probable nodes.
    if len(elected) < min_chs:
        elected = list(np.argsort(-prob)[:min_chs])

    return elected

print('L1: DEECP selection defined!')


## 5. Proposed Method (RLHC) - Layer 2: EEKA

In [None]:
import numpy as np

def eeka_select_global(
    candidates,
    nodes_energy,
    nodes_pos,
    total_desired_k,
    alpha=0.7,
    beta=0.3
):
    """
    Simple global EEKA-like CH selection.

    - For each node i:
        w_e(i) = E_i / (mean(E_all) + eps)
        c(i)   = 1 / (avg distance to all other nodes + eps)
      Both are min-max normalized across all nodes, then:
        U(i) = alpha * e_norm + beta * c_norm

    - Candidate-first: the best live candidates (by U) are taken first;
      remaining slots are filled from all other nodes by U.

    - Live nodes (energy > 0) are always ranked ahead of dead ones, so a
      dead node is only returned when there are not enough live nodes to
      fill `total_desired_k` slots. The number of returned CHs is therefore
      min(total_desired_k, n_nodes).

    Parameters
    ----------
    candidates : sequence of int or None
        Preferred node indices (e.g. from Layer 1). Duplicates are ignored.
    nodes_energy : array-like of float, shape (N,)
    nodes_pos : array-like of float, shape (N, 2)
    total_desired_k : int
        Number of CHs wanted.
    alpha, beta : float
        Weights of the energy and centrality terms.

    Returns
    -------
    list of int
        Distinct node indices, best first within each group.
    """
    nodes_energy = np.asarray(nodes_energy, dtype=float)
    nodes_pos = np.asarray(nodes_pos, dtype=float)

    n_nodes = len(nodes_energy)
    if n_nodes == 0 or total_desired_k <= 0:
        return []

    eps = 1e-6

    # 1) Energy weight relative to the network mean.
    w_e = nodes_energy / (float(nodes_energy.mean()) + eps)

    # 2) Centrality: inverse of the mean distance to every other node.
    #    One vectorized row per node keeps memory at O(N). The distance of a
    #    node to itself is 0, so the row sum already excludes it.
    centrality = np.zeros(n_nodes, dtype=float)
    if n_nodes > 1:
        for i in range(n_nodes):
            dist_sum = np.linalg.norm(nodes_pos - nodes_pos[i], axis=1).sum()
            centrality[i] = 1.0 / (dist_sum / (n_nodes - 1) + eps)

    # 3) Min-max normalization; a constant array maps to all zeros.
    def normalize(arr):
        a_min = float(arr.min())
        a_max = float(arr.max())
        if a_max > a_min:
            return (arr - a_min) / (a_max - a_min)
        return np.zeros_like(arr)

    # 4) Utility
    utilities = alpha * normalize(w_e) + beta * normalize(centrality)
    alive = nodes_energy > 0

    selected = []

    # Candidate-first, restricted to live candidates. dict.fromkeys removes
    # duplicates while keeping the caller's order (which decides ties).
    if candidates is not None and len(candidates) > 0:
        unique_candidates = dict.fromkeys(int(c) for c in candidates)
        live_candidates = [c for c in unique_candidates if alive[c]]
        live_candidates.sort(key=lambda i: float(utilities[i]), reverse=True)
        selected = live_candidates[:total_desired_k]

    # Fill the remaining slots: live nodes first, then by utility. Dead nodes
    # only appear once every live node has been used.
    if len(selected) < total_desired_k:
        taken = set(selected)
        pool = sorted(
            (i for i in range(n_nodes) if i not in taken),
            key=lambda i: (bool(alive[i]), float(utilities[i])),
            reverse=True,
        )
        selected.extend(pool[:total_desired_k - len(selected)])

    return selected[:total_desired_k]

print('L2: EEKA filtering defined!')


## 5. Proposed Method (RLHC) - Layer 3: K-Means

In [None]:
def kmeans_layer(nodes_pos, k, init_ch_indices=None, max_iter=20):
    """
    Layer 3: K-means clustering seeded with the Layer-2 cluster heads.

    Parameters
    ----------
    nodes_pos : array-like of shape (N, 2)
        Positions (x, y) of the nodes to cluster.
    k : int
        Number of clusters.
    init_ch_indices : list of int or None
        Node indices whose positions seed the centroids. The first k are
        used; if None or shorter than k, k distinct nodes are drawn at
        random (np.random) instead.
    max_iter : int
        Maximum number of assignment/update iterations.

    Returns
    -------
    clusters : list of lists
        clusters[j] holds the node indices assigned to centroid j in the
        last assignment step.
    centroids : list of (x, y)
        Centroid positions after the last update step.

    Returns ([], []) when k <= 0 or there are no nodes.
    """
    nodes_pos = np.array(nodes_pos, dtype=float)
    n_nodes = len(nodes_pos)

    if k <= 0 or n_nodes == 0:
        return [], []

    # Seed centroids from the supplied CHs when there are enough of them.
    if init_ch_indices is None or len(init_ch_indices) < k:
        seed_indices = np.random.choice(range(n_nodes), k, replace=False)
    else:
        seed_indices = init_ch_indices[:k]
    centroids = [tuple(nodes_pos[i]) for i in seed_indices]

    for _ in range(max_iter):
        # Assignment: every node joins its nearest centroid.
        clusters = [[] for _ in range(k)]
        for node_idx, pos in enumerate(nodes_pos):
            nearest = int(np.argmin([distance(pos, c) for c in centroids]))
            clusters[nearest].append(node_idx)

        # Update: move each centroid to the mean of its members.
        new_centroids = []
        for members in clusters:
            if members:
                mean_x = float(np.mean([nodes_pos[i][0] for i in members]))
                mean_y = float(np.mean([nodes_pos[i][1] for i in members]))
                new_centroids.append((mean_x, mean_y))
            else:
                # Empty cluster: restart its centroid on a random node.
                restart = np.random.randint(0, n_nodes)
                new_centroids.append(tuple(nodes_pos[restart]))

        # Converged once no centroid moved by 1e-3 or more.
        converged = all(
            distance(old, new) < 1e-3
            for old, new in zip(centroids, new_centroids)
        )
        centroids = new_centroids
        if converged:
            break

    return clusters, centroids

print("L3: K-means clustering (Layer 3) defined!")


## 5. Proposed Method (RLHC) - Layer 4: RL Q-Learning

In [None]:
def choose_action(state, epsilon, Q):
    """Pick an action for `state` with an epsilon-greedy policy.

    With probability `epsilon` a random entry of ACTIONS is returned;
    otherwise the index of the largest Q-value in Q[state], with ties
    (within 1e-9) broken uniformly at random.
    """
    explore = random.random() < epsilon
    if explore:
        return random.choice(ACTIONS)

    q_row = Q[state]
    best_q = np.max(q_row)
    tied_best = [idx for idx, q in enumerate(q_row) if abs(q - best_q) < 1e-9]
    return random.choice(tied_best)

def compute_reward(avg_e_before, avg_e_after, pdr_after, action_cost, target_pdr=0.95):
    """
    Reward for one RL step.

    reward = 50 * (avg_e_after - avg_e_before)
           + 20 * (pdr_after - target_pdr)
           -  1 * action_cost

    - The energy term is usually a small negative number, so slower
      depletion is rewarded.
    - The PDR term is positive only when PDR exceeds `target_pdr`.
    - Expensive actions are penalized linearly.
    """
    energy_weight, pdr_weight, cost_weight = 50.0, 20.0, 1.0

    energy_term = energy_weight * (avg_e_after - avg_e_before)
    pdr_term = pdr_weight * (pdr_after - target_pdr)
    cost_term = cost_weight * action_cost

    return energy_term + pdr_term - cost_term


def update_q(Q, state, action, reward, avg_e_after, cluster_sizes, pdr_after, epsilon, LOAD_BINS, PDR_BINS):
    """Apply one tabular Q-learning update and decay epsilon.

    The next state is derived from the post-action metrics via
    state_from_metrics, then
        Q[s][a] <- Q[s][a] + ALPHA * (reward + GAMMA * max(Q[s']) - Q[s][a]).
    Epsilon is multiplied by EPS_DECAY and floored at EPS_MIN.

    `Q` is modified in place. Returns (Q, new_epsilon).
    """
    next_state = state_from_metrics(avg_e_after, cluster_sizes, pdr_after, LOAD_BINS, PDR_BINS)

    q_row = Q[state]
    old_value = q_row[action]
    td_target = reward + GAMMA * np.max(Q[next_state])
    q_row[action] = old_value + ALPHA * (td_target - old_value)

    decayed = epsilon * EPS_DECAY
    return Q, max(EPS_MIN, decayed)

print('L4: Q-Learning module defined!')

## 5. Proposed Method (RLHC) - Layer 5: Communication

In [None]:
def link_cost_ch_to_ch(i, j, bits, chs, ch_positions, nodes_energy,
                       alpha_energy=0.1, relay_min_energy=0.07):
    """
    Routing cost of the link CH_i -> CH_j (i, j index into `chs`).

    - Relays whose residual energy (plus a 1e-9 guard) is below
      `relay_min_energy` get a prohibitive cost of 1e12.
    - Otherwise the cost is the transmission energy for `bits` over the
      CH-to-CH distance plus alpha_energy / relay_energy, which favors
      relays with more energy left.
    """
    # The small offset keeps the division below defined for a zero-energy relay.
    relay_energy = nodes_energy[chs[j]] + 1e-9

    if relay_energy < relay_min_energy:
        return 1e12  # effectively removes this edge from shortest paths

    hop_energy = energy_tx(bits, distance(ch_positions[i], ch_positions[j]))
    low_energy_penalty = alpha_energy * (1.0 / relay_energy)
    return hop_energy + low_energy_penalty


def link_cost_ch_to_bs(i, bits, ch_positions, bs_pos):
    """
    Routing cost of the link CH_i -> BS: the transmission energy for `bits`
    over the distance between ch_positions[i] and bs_pos.
    """
    hop_len = distance(ch_positions[i], bs_pos)
    uplink_energy = energy_tx(bits, hop_len)
    return uplink_energy


def dijkstra_ch_to_bs(start_idx, neighbors, chs, ch_positions,
                      nodes_energy, bs_pos, bits,
                      alpha_energy=0.1, relay_min_energy=0.07):
    """
    Find the min-cost route from CH[start_idx] to the base station.

    Dijkstra is run over the CH graph given by `neighbors` (adjacency lists
    of indices into `chs`) with link_cost_ch_to_ch as edge weight. The route
    then leaves the CH graph through whichever reachable CH gives the lowest
    total cost once its link_cost_ch_to_bs is added; the direct link
    start -> BS wins ties.

    Returns:
        path_idx: list of indices (0..len(chs)-1) from start_idx to the
                  exit CH, followed by -1 representing the BS
        best_cost: total cost along this path
    """
    n_heads = len(chs)
    UNREACHED = 1e18
    cost_to = [UNREACHED] * n_heads
    parent = [None] * n_heads

    cost_to[start_idx] = 0.0
    frontier = [(0.0, start_idx)]

    while frontier:
        cost_u, u = heapq.heappop(frontier)
        if cost_u > cost_to[u]:
            continue  # stale heap entry

        for v in neighbors[u]:
            edge = link_cost_ch_to_ch(
                i=u, j=v, bits=bits, chs=chs, ch_positions=ch_positions,
                nodes_energy=nodes_energy, alpha_energy=alpha_energy,
                relay_min_energy=relay_min_energy,
            )
            cost_v = cost_u + edge
            if cost_v < cost_to[v]:
                cost_to[v] = cost_v
                parent[v] = u
                heapq.heappush(frontier, (cost_v, v))

    # Choose the exit CH: start with the direct link, keep strictly better options.
    best_cost = link_cost_ch_to_bs(start_idx, bits, ch_positions, bs_pos)
    exit_idx = None  # None => go straight from start_idx to the BS

    for k in range(n_heads):
        if cost_to[k] >= UNREACHED:
            continue
        via_k = cost_to[k] + link_cost_ch_to_bs(k, bits, ch_positions, bs_pos)
        if via_k < best_cost:
            best_cost = via_k
            exit_idx = k

    if exit_idx is None:
        return [start_idx, -1], best_cost

    # Walk the parent chain back from the exit CH; stop on a repeated node
    # as a guard against cycles.
    chain = [exit_idx]
    node = parent[exit_idx]
    while node is not None and node not in chain:
        chain.append(node)
        node = parent[node]

    if start_idx not in chain:
        chain.append(start_idx)

    # chain runs exit -> start; flip it and append the BS marker.
    return chain[::-1] + [-1], best_cost


def multi_hop_comm_layer5(clusters, chs, nodes_energy, nodes_pos, base_station, tx_power_factor, pkt_size, max_hop_dist, alpha_energy=0.1,
    relay_min_energy=0.7,
    max_hops_per_route=2,
    forward_limit=3
):
    """
    Layer 5 communication for one round:
      1) members -> CH (single-hop)
      2) CHs -> BS (multi-hop over other CHs, routes chosen by Dijkstra)

    Parameters:
      clusters, chs       : clusters[c] is the member list of cluster c and
                            chs[c] its cluster-head node index
      nodes_energy        : per-node residual energy, updated in place
      nodes_pos           : per-node (x, y) positions
      base_station        : (x, y) position of the BS for this round
      tx_power_factor     : multiplier applied to every transmission energy
      pkt_size            : packet size passed to energy_tx / energy_rx
      max_hop_dist        : two CHs are neighbors when at most this far apart
      alpha_energy,
      relay_min_energy    : forwarded to dijkstra_ch_to_bs
      max_hops_per_route  : routes with more hops (the final hop to the BS
                            included) are replaced by the direct CH -> BS link
      forward_limit       : maximum number of packets one relay CH accepts
                            per round

    Returns:
      nodes_energy (updated),
      successful_packets,
      total_packets,
      links_this_round  # list of (src, dst) hops for this round; dst is -1 for the BS

    NOTE(review): relay_min_energy defaults to 0.7 here but to 0.07 in
    link_cost_ch_to_ch and dijkstra_ch_to_bs; confirm which one is intended.
    NOTE(review): energy is subtracted without checking that the node can
    afford it, so entries of nodes_energy can become negative inside this
    function; confirm whether the caller clamps them.
    NOTE(review): total_packets and successful_packets are only ever
    incremented together, so their ratio is 1 whenever any packet is counted;
    packets dropped at a dead or overloaded relay are not counted at all.
    """
    successful_packets = 0
    total_packets = 0

    # Collect all hops of this round
    links_this_round = []

    # -----------------------------
    # Build CH graph & routes
    # -----------------------------
    ch_positions = [nodes_pos[ch] for ch in chs]
    N_ch = len(chs)

    # Directed adjacency lists over CH indices; an edge exists in both
    # directions whenever two CHs are within max_hop_dist of each other.
    neighbors = [[] for _ in range(N_ch)]
    for i in range(N_ch):
        for j in range(N_ch):
            if i == j:
                continue
            dij = distance(ch_positions[i], ch_positions[j])
            if dij <= max_hop_dist:
                neighbors[i].append(j)

    # ch_routes maps a CH node id to its route, e.g. [ch, relay, "BS"].
    ch_routes = {}
    # NOTE(review): ch_index is built but not read anywhere in this function.
    ch_index = {ch_id: idx for idx, ch_id in enumerate(chs)}

    # Routes are computed from the energies at the start of the round,
    # before any of this round's transmissions are charged.
    for idx, ch_node_id in enumerate(chs):
        if nodes_energy[ch_node_id] <= 0:
            # Dead CH: give it a placeholder direct route (it is skipped later).
            ch_routes[ch_node_id] = [ch_node_id, "BS"]
            continue

        path_idx, _ = dijkstra_ch_to_bs(
            start_idx=idx,
            neighbors=neighbors,
            chs=chs,
            ch_positions=ch_positions,
            nodes_energy=nodes_energy,
            bs_pos=base_station,
            bits=pkt_size,
            alpha_energy=alpha_energy,
            relay_min_energy=relay_min_energy
        )

        # Translate CH-list indices into node ids; -1 stands for the BS.
        node_path = []
        for p in path_idx:
            if p == -1:
                node_path.append("BS")
            else:
                node_path.append(chs[p])

        # Too many hops: fall back to sending directly to the BS.
        hops = len(node_path) - 1
        if hops > max_hops_per_route:
            node_path = [ch_node_id, "BS"]

        ch_routes[ch_node_id] = node_path

    # -----------------------------
    # 1) Intra-cluster: members -> CH
    # -----------------------------
    for cidx, cluster_nodes in enumerate(clusters):
        if not cluster_nodes:
            continue
        ch_node = chs[cidx]
        if nodes_energy[ch_node] <= 0:
            # Dead CH: its members do not transmit (and are not counted) this round.
            continue

        ch_pos = nodes_pos[ch_node]
        for node in cluster_nodes:
            if node == ch_node:
                continue  # the CH does not send to itself
            if nodes_energy[node] <= 0:
                continue

            d = distance(nodes_pos[node], ch_pos)
            E_tx = energy_tx(pkt_size, d) * tx_power_factor
            E_rx = energy_rx(pkt_size)

            # Charged unconditionally; no check that either side can afford it.
            nodes_energy[node]    -= E_tx
            nodes_energy[ch_node] -= E_rx

            # log member -> CH hop
            links_this_round.append((node, ch_node))

            total_packets += 1
            successful_packets += 1  # assuming reliable intra-cluster

    # -----------------------------
    # 2) Inter-cluster: CHs -> BS (multi-hop)
    # -----------------------------
    # Number of packets each CH has accepted as a relay this round.
    forward_count = {ch: 0 for ch in chs}

    for ch_node in chs:
        if nodes_energy[ch_node] <= 0:
            continue

        route = ch_routes[ch_node]  # e.g. [ch, relay1, ..., "BS"]
        if len(route) < 2:
            continue

        # The CH's packet is counted (as sent AND successful) once, when its
        # first hop completes; a failure on a later hop does not undo that.
        first_hop_counted = False
        sender = route[0]

        for hop_idx in range(len(route) - 1):
            receiver = route[hop_idx + 1]

            if sender == "BS":
                break
            if nodes_energy[sender] <= 0:
                break  # sender dead, packet lost

            if receiver == "BS":
                # last hop to BS
                d = distance(nodes_pos[sender], base_station)
                E_tx = energy_tx(pkt_size, d) * tx_power_factor
                nodes_energy[sender] -= E_tx

                # log CH -> BS hop as (sender, -1) so we can handle BS specially
                links_this_round.append((sender, -1))

                if not first_hop_counted:
                    total_packets += 1
                    successful_packets += 1
                    first_hop_counted = True

                break

            else:
                # relay CH
                if nodes_energy[receiver] <= 0:
                    break  # relay dead, packet lost

                if forward_count[receiver] >= forward_limit:
                    break  # relay overloaded

                d = distance(nodes_pos[sender], nodes_pos[receiver])
                E_tx = energy_tx(pkt_size, d) * tx_power_factor
                E_rx = energy_rx(pkt_size)

                nodes_energy[sender]   -= E_tx
                nodes_energy[receiver] -= E_rx

                forward_count[receiver] += 1

                # log CH -> relay hop
                links_this_round.append((sender, receiver))

                if not first_hop_counted:
                    total_packets += 1
                    successful_packets += 1
                    first_hop_counted = True

                sender = receiver  # move along the path

    return nodes_energy, successful_packets, total_packets, links_this_round


## 5. Proposed Method (RLHC) - Main Algorithm

In [None]:
def run_proposed():
    """Run the proposed RLHC (RL-driven hybrid clustering) simulation.

    Each round goes through five layers:
      1. ``deec_select`` proposes CH candidates.
      2. ``eeka_select_global`` narrows them to preliminary CHs.
      3. ``kmeans_layer`` builds clusters seeded with those CHs, then one CH
         is chosen per cluster.
      4. A tabular Q-learning agent picks a control action (no-op, reassign a
         few members, switch a CH, reduce or increase TX power).
      5. Data is delivered either via ``multi_hop_comm_layer5`` or the direct
         ``radio_comm`` model, depending on ``ENABLE_MULTI_HOP``.

    Both RNGs are seeded with ``SEED`` so runs are reproducible.

    Returns:
        tuple: 15 items, in this order -- avg_energy_history, pdr_history,
        alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history,
        pdr_percent_history, first_dead_round, half_dead_round,
        last_dead_round, sector_ids, nodes_pos, total_energy_history.
        The ``*_history`` lists hold one entry per simulated round; the
        ``*_dead_round`` values stay ``None`` if the event never happened.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    avg_energy_history = []
    pdr_history = []
    alive_nodes_history = []
    alive_nodes_per_round = []
    num_ch_history = []
    reward_history = []
    epsilon_history = []
    throughput_history = []
    pdr_percent_history = []
    total_energy_history = []
    first_dead_round = None
    half_dead_round = None
    last_dead_round = None

    # RL state-discretization bins and Q-table
    LOAD_BINS, PDR_BINS = build_bins()
    Q = defaultdict(lambda: np.zeros(len(ACTIONS)))
    epsilon = EPSILON   # e.g., 1.0

    # Network initialization
    nodes_pos = [(random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE)) for _ in range(N_NODES)]
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    # Heterogeneous initial energies ONCE (not every round)
    nodes_energy, _is_advanced = initialize_heterogeneous_energies()
    nodes_energy = np.array(nodes_energy, dtype=float)

    tx_power_factor = 1.0
    TARGET_PDR = 0.9   # desired PDR target for reward shaping

    # Round-invariant settings, hoisted out of the loop.
    # The method name must match the string used in sector_life_table["Method"].
    METHOD_NAME_RLHC = "RLHC_PROPOSED"
    MAX_HOP_DIST = AREA_SIZE * 0.4  # tune as needed

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)

        # -------------------------
        # Layer 1: DEEC/DEECP -> candidates
        # -------------------------
        candidates = deec_select(nodes_energy, p_opt=P_OPT)

        # -------------------------
        # Layer 2: EEKA based -> preliminary CHs
        # -------------------------
        final_CHs = eeka_select_global(
            candidates=candidates,
            nodes_energy=nodes_energy,
            nodes_pos=nodes_pos,
            total_desired_k=NUM_CLUSTERS,
            alpha=0.7,
            beta=0.3
        )

        # -------------------------
        # Layer 3: K-means clustering around Layer-2 CHs
        # -------------------------
        clusters, centroids = kmeans_layer(
            nodes_pos=nodes_pos,
            k=len(final_CHs),           # usually == NUM_CLUSTERS
            init_ch_indices=final_CHs,  # Layer-2 CH indices as init points
            max_iter=20
        )

        # Final CH refinement per cluster
        chs = []
        for cidx, cluster_nodes in enumerate(clusters):
            if len(cluster_nodes) == 0:
                # Empty cluster: choose the global highest-energy node as CH (fallback)
                chs.append(int(np.argmax(nodes_energy)))
                continue
            centroid = centroids[cidx]

            # CHs that EEKA (Layer 2) selected within this cluster:
            ee_ka_in_cluster = [ch for ch in final_CHs if ch in cluster_nodes]
            if ee_ka_in_cluster:
                chs.append(min(ee_ka_in_cluster, key=lambda ch: distance(nodes_pos[ch], centroid)))
            else:
                # If no EEKA CH in this cluster, pick highest-energy member
                chs.append(max(cluster_nodes, key=lambda n: nodes_energy[n]))

        # -------------------------
        # Layer 4: Q-learning based control
        # -------------------------
        avg_e_before = np.mean(nodes_energy)
        # Snapshot used after communication to detect deaths caused by THIS round.
        alive_before = int(np.count_nonzero(nodes_energy > 0))

        # Define RL state from current metrics (before action)
        cluster_sizes = [len(c) for c in clusters] if clusters else [0]
        state = state_from_metrics(
            avg_e_before,
            cluster_sizes,
            TARGET_PDR,   # using target PDR in state representation as well
            LOAD_BINS,
            PDR_BINS
        )

        # Choose action via epsilon-greedy
        action = choose_action(state, epsilon, Q)

        # Check if reassignment is worth doing (load imbalance condition)
        max_cluster = max(cluster_sizes)
        mean_cluster = np.mean(cluster_sizes)
        allow_reassign = (mean_cluster > 0 and max_cluster > 1.5 * mean_cluster)
        if action == A_REASSIGN_FEW and not allow_reassign:
            action = A_NOOP

        # Apply action
        action_cost = 0.0
        if action == A_NOOP:
            pass

        elif action == A_REASSIGN_FEW:
            largest_idx = int(np.argmax(cluster_sizes))
            if len(clusters[largest_idx]) > 0:
                m = max(1, int(0.05 * len(clusters[largest_idx])))
                members = clusters[largest_idx].copy()
                # Sort members: farthest from CH first
                members.sort(
                    key=lambda n: distance(nodes_pos[n], nodes_pos[chs[largest_idx]]),
                    reverse=True
                )
                moved = 0
                for node in members:
                    if moved >= m:
                        break
                    # Candidate clusters: those less loaded than mean.
                    # Recomputed per node because sizes change as nodes move.
                    candidates_ch = [
                        i for i in range(len(clusters))
                        if len(clusters[i]) < mean_cluster
                    ]
                    if not candidates_ch:
                        break
                    best = min(
                        candidates_ch,
                        key=lambda cidx: distance(nodes_pos[node], nodes_pos[chs[cidx]])
                    )
                    clusters[largest_idx].remove(node)
                    clusters[best].append(node)
                    moved += 1
                    action_cost += 0.5   # slightly higher cost for reassignment

        elif action == A_SWITCH_CH:
            # Guard: np.argmax raises on an empty sequence.
            if clusters:
                largest_idx = int(np.argmax(cluster_sizes))
                if len(clusters[largest_idx]) > 1:
                    # Promote the highest-energy member that is not the current CH.
                    chs[largest_idx] = max(
                        [n for n in clusters[largest_idx] if n != chs[largest_idx]],
                        key=lambda x: nodes_energy[x]
                    )
                    action_cost += 0.3

        elif action == A_REDUCE_TX:
            tx_power_factor = max(0.5, tx_power_factor * 0.9)
            action_cost += 0.05

        elif action == A_INCREASE_TX:
            tx_power_factor = min(1.5, tx_power_factor * 1.1)
            action_cost += 0.05

        # -------------------------
        # Layer 5: Multi-hop communication & metrics
        # -------------------------
        successful_packets = 0
        total_packets = 0
        links_this_round = []

        if ENABLE_MULTI_HOP:
            nodes_energy, successful_packets, total_packets, links_this_round = multi_hop_comm_layer5(
                clusters=clusters,
                chs=chs,
                nodes_energy=nodes_energy,
                nodes_pos=nodes_pos,
                base_station=base_station,
                tx_power_factor=tx_power_factor,
                pkt_size=PACKET_SIZE_BITS,
                max_hop_dist=MAX_HOP_DIST,
                alpha_energy=0.2,
                relay_min_energy=0.2,    # try 20% of E_INIT
                max_hops_per_route=3,     # try 2, 3, 4
                forward_limit=3            # try 2, 4
            )
        else:
            nodes_energy, successful_packets, total_packets = radio_comm(
                clusters, chs, nodes_energy, nodes_pos, base_station,
                tx_power_factor, successful_packets, total_packets
            )

            # For logging/visualization, define links_this_round as direct CH -> BS hops
            for cidx, ch in enumerate(chs):
                if len(clusters[cidx]) == 0:
                    continue
                # Log one logical link CH -> BS (dst = -1 convention)
                links_this_round.append((ch, -1))

        for (src, dst) in links_this_round:
            log_link(METHOD_NAME_RLHC, rnd, src, dst)

        log_chs_and_clusters(METHOD_NAME_RLHC, rnd, chs, clusters)

        # -------------------------
        # Metrics & logging
        # -------------------------
        nodes_energy = np.maximum(nodes_energy, 0.0)

        avg_e_after, pdr_after, alive_after, cluster_sizes = measure_metrics(
            nodes_energy,
            clusters,
            chs,
            successful_packets,
            total_packets
        )

        # -------------------------
        # Reward & Q update
        # -------------------------
        reward = compute_reward(
            avg_e_before,
            avg_e_after,
            pdr_after,
            action_cost,
            target_pdr=TARGET_PDR
        )

        # Stronger penalty only if NEW deaths occurred during this round.
        # (Testing `any(nodes_energy <= 0)` would penalise every round after
        # the first death, regardless of the action taken.)
        alive_mask = nodes_energy > 0
        if int(np.count_nonzero(alive_mask)) < alive_before:
            reward -= 10.0

        Q, epsilon = update_q(
            Q,
            state,
            action,
            reward,
            avg_e_after,
            cluster_sizes,
            pdr_after,
            epsilon,
            LOAD_BINS,
            PDR_BINS
        )

        # -------------------------
        # Lifetime tracking
        # -------------------------
        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        # -------------------------
        # Logging
        # -------------------------
        alive_nodes_per_round.append(np.where(alive_mask)[0].tolist())
        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(chs))
        reward_history.append(reward)
        epsilon_history.append(epsilon)
        throughput_history.append(successful_packets)
        # The tiny epsilon avoids division by zero when nothing was sent.
        pdr_percent_history.append(
            100.0 * successful_packets / (total_packets + 1e-12)
        )

        if alive_after == 0:
            print(f"[RLHC_PROPOSED] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history,
        pdr_history,
        alive_nodes_history,
        alive_nodes_per_round,
        num_ch_history,
        reward_history,
        epsilon_history,
        throughput_history,
        pdr_percent_history,
        first_dead_round,
        half_dead_round,
        last_dead_round,
        sector_ids,
        nodes_pos,
        total_energy_history
    )

print('Proposed RLHC algorithm defined (updated with tuned reward & fixes)!')


## 6. Baseline Algorithm: LEACH

In [None]:
def select_CHs_leach(alive_nodes, round_num, last_ch_round, p_opt=P_OPT):
    """
    Pick this round's cluster heads using the LEACH rotation rule.

    alive_nodes    : boolean array, True where the node is alive
    round_num      : current round index
    last_ch_round  : per-node array of the last round each node served as CH;
                     updated IN PLACE for every node elected here
    p_opt          : CH probability shared by all nodes

    A node is eligible (in the G-set) once at least one epoch
    (int(1 / p_opt) rounds, minimum 1) has passed since it was last CH.
    Eligible nodes are elected with probability
    p_opt / (1 - p_opt * (round_num % epoch)).

    Returns a boolean array of length N_NODES marking the elected CHs.
    If nobody is elected, one alive node is drawn uniformly at random.
    """
    elected = np.zeros(N_NODES, dtype=bool)

    # A non-positive probability means no node can ever be elected.
    if p_opt <= 0:
        return elected

    epoch = max(int(1.0 / p_opt), 1)

    # The threshold depends only on the round, so compute it once.
    remaining = 1.0 - p_opt * (round_num % epoch)
    threshold = p_opt / remaining if remaining > 0 else 0.0

    for node in range(N_NODES):
        if not alive_nodes[node]:
            continue

        # Not in G: served as CH too recently.
        if round_num - last_ch_round[node] < epoch:
            continue

        if np.random.rand() < threshold:
            elected[node] = True
            last_ch_round[node] = round_num

    # Safety net: guarantee at least one CH while any node is alive.
    if not elected.any():
        survivors = np.where(alive_nodes)[0]
        if len(survivors) > 0:
            pick = np.random.choice(survivors)
            elected[pick] = True
            last_ch_round[pick] = round_num

    return elected


def run_leach():
    """Run the LEACH baseline simulation.

    All nodes start with ``INIT_ENERGY``. Each round, CHs are elected with
    ``select_CHs_leach``, every alive node joins its nearest CH, and data is
    delivered through ``radio_comm`` at ``tx_power_factor=1.0``.

    Both RNGs are seeded with ``SEED`` so runs are reproducible.

    Returns:
        tuple: 15 items, in this order -- avg_energy_history, pdr_history,
        alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history,
        pdr_percent_history, first_dead_round, half_dead_round,
        last_dead_round, sector_ids, nodes_pos, total_energy_history.
        ``reward_history`` and ``epsilon_history`` are all zeros (no RL here)
        and exist only to keep the shape of ``run_proposed``'s result.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    avg_energy_history = []
    pdr_history = []
    alive_nodes_history = []
    alive_nodes_per_round = []
    num_ch_history = []
    reward_history = []
    epsilon_history = []
    throughput_history = []
    pdr_percent_history = []
    total_energy_history = []
    first_dead_round = None
    half_dead_round = None
    last_dead_round = None

    # Node positions
    nodes_pos = [(random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE)) for _ in range(N_NODES)]
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    nodes_energy = np.array([INIT_ENERGY] * N_NODES, dtype=float)

    # Track last CH round per node (for G-set); the large negative sentinel
    # makes every node eligible in the first round.
    last_ch_round = np.full(N_NODES, -1_000_000, dtype=int)

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)
        alive_nodes = nodes_energy > 0

        # LEACH CH selection (literature-style)
        is_CH = select_CHs_leach(
            alive_nodes=alive_nodes,
            round_num=rnd,
            last_ch_round=last_ch_round,
            p_opt=P_OPT
        )
        chs = np.where(is_CH)[0].tolist()

        if not chs:
            # Extra safety (should be rare due to fallback)
            alive_idx = np.where(alive_nodes)[0]
            if len(alive_idx) > 0:
                chs = [random.choice(alive_idx.tolist())]

        # Clustering: each alive node joins nearest CH
        clusters = [[] for _ in range(len(chs))]
        for i in range(N_NODES):
            if nodes_energy[i] <= 0:
                continue
            nearest_ch = min(
                range(len(chs)),
                key=lambda j: distance(nodes_pos[i], nodes_pos[chs[j]])
            )
            clusters[nearest_ch].append(i)

        # Communication (end-to-end PDR)
        nodes_energy, successful_packets, total_packets = radio_comm(
            clusters, chs, nodes_energy, nodes_pos, base_station,
            tx_power_factor=1.0,
            successful_packets=0,
            total_packets=0
        )

        nodes_energy = np.maximum(nodes_energy, 0.0)

        avg_e_after, pdr_after, alive_after, cluster_sizes = measure_metrics(
            nodes_energy, clusters, chs, successful_packets, total_packets
        )

        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        # Log the END-of-round alive set so it lines up with the other
        # per-round histories (and with run_proposed / run_fuzzy). Reusing the
        # start-of-round `alive_nodes` mask would lag one round behind.
        alive_nodes_per_round.append(np.where(nodes_energy > 0)[0].tolist())
        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(chs))
        reward_history.append(0.0)
        epsilon_history.append(0.0)
        throughput_history.append(successful_packets)
        # The tiny epsilon avoids division by zero when nothing was sent.
        pdr_percent_history.append(
            100.0 * successful_packets / (total_packets + 1e-12)
        )

        if alive_after == 0:
            print(f"[LEACH_BASELINE] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history, pdr_history, alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids, nodes_pos, total_energy_history
    )

print('LEACH baseline algorithm (with standard threshold) defined!')


## 6. Baseline Algorithm: DEEC

In [None]:
def select_CHs_deec(nodes_energy,
                    alive_nodes,
                    round_num,
                    last_ch_round,
                    p_opt=P_OPT,
                    m_fraction_adv=M_FRACTION_ADV,
                    a_adv=A_ADV_ENERGY):
    """
    Elect cluster heads with an energy-weighted (DEEC-style) probability.

    nodes_energy   : array of current residual energies
    alive_nodes    : boolean array, True if node is alive
    round_num      : current round index (1-based)
    last_ch_round  : array storing the last round when node i was CH;
                     updated IN PLACE for every node elected here
    p_opt          : reference CH probability
    m_fraction_adv : fraction of advanced nodes (heterogeneity factor input)
    a_adv          : extra-energy factor of advanced nodes

    Each alive node i gets
        Pi = p_opt * E_i / (E_avg * (1 + a_adv * m_fraction_adv)),
    clamped to [1e-6, 0.9], and its own epoch length max(1, int(1 / Pi)).
    A node is eligible only if at least one of its epochs has elapsed since
    it was last CH.

    Returns a boolean array of length N_NODES. If nobody is elected, the
    highest-energy alive node becomes CH. An all-False array is returned
    when there is no alive node or no residual energy left.
    """
    elected = np.zeros(N_NODES, dtype=bool)

    # Mean residual energy over the alive population.
    total_energy_alive = np.sum(nodes_energy[alive_nodes])
    num_alive = np.sum(alive_nodes)
    if num_alive <= 0 or total_energy_alive <= 0:
        return elected
    E_avg = total_energy_alive / float(num_alive)

    # Heterogeneity correction: E_bar0 = E0 * (1 + a * m)
    hetero_factor = 1.0 + a_adv * m_fraction_adv
    reference_energy = E_avg * hetero_factor

    for node in range(N_NODES):
        if not alive_nodes[node] or nodes_energy[node] <= 0:
            continue

        # Node-specific probability, clamped to avoid extreme values.
        prob = p_opt * (nodes_energy[node] / reference_energy)
        prob = max(1e-6, min(prob, 0.9))

        # Per-node epoch; skip nodes that were CH within it (not in G).
        epoch_len = max(1, int(1.0 / prob))
        if round_num - last_ch_round[node] < epoch_len:
            continue

        # LEACH/DEEC threshold.
        slack = 1.0 - prob * (round_num % epoch_len)
        threshold = prob / slack if slack > 0 else 0.0

        if np.random.rand() < threshold:
            elected[node] = True
            last_ch_round[node] = round_num

    # Safety net: fall back to the highest-energy alive node.
    if not elected.any():
        survivors = np.where(alive_nodes)[0]
        if len(survivors) > 0:
            strongest = survivors[np.argmax(nodes_energy[survivors])]
            elected[strongest] = True
            last_ch_round[strongest] = round_num

    return elected


def run_deec():
    """Run the DEEC baseline simulation.

    Nodes start with the energies returned by
    ``initialize_heterogeneous_energies``. Each round, CHs are elected with
    ``select_CHs_deec``, every alive node joins its nearest CH, and data is
    delivered through ``radio_comm`` at ``tx_power_factor=1.0``.

    Both RNGs are seeded with ``SEED`` so runs are reproducible.

    Returns:
        tuple: 15 items, in this order -- avg_energy_history, pdr_history,
        alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history,
        pdr_percent_history, first_dead_round, half_dead_round,
        last_dead_round, sector_ids, nodes_pos, total_energy_history.
        ``reward_history`` and ``epsilon_history`` are all zeros (no RL here)
        and exist only to keep the shape of ``run_proposed``'s result.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    avg_energy_history = []
    pdr_history = []
    alive_nodes_history = []
    alive_nodes_per_round = []
    num_ch_history = []
    reward_history = []
    epsilon_history = []
    throughput_history = []
    pdr_percent_history = []
    total_energy_history = []
    first_dead_round = None
    half_dead_round = None
    last_dead_round = None

    # Node positions
    nodes_pos = [
        (random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE))
        for _ in range(N_NODES)
    ]
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    # Heterogeneous energy initialization (normal & advanced).
    # Coerce to a float ndarray (as run_proposed does) because the loop below
    # relies on boolean-mask indexing and element-wise comparison.
    nodes_energy, _is_advanced = initialize_heterogeneous_energies()
    nodes_energy = np.array(nodes_energy, dtype=float)

    # Track last CH round for each node (for G-set logic); the large negative
    # sentinel makes every node eligible in the first round.
    last_ch_round = np.full(N_NODES, -1_000_000, dtype=int)

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)

        alive_nodes = nodes_energy > 0

        # DEEC CH selection with heterogeneity + G-set
        is_CH = select_CHs_deec(
            nodes_energy=nodes_energy,
            alive_nodes=alive_nodes,
            round_num=rnd,
            last_ch_round=last_ch_round,
            p_opt=P_OPT,
            m_fraction_adv=M_FRACTION_ADV,
            a_adv=A_ADV_ENERGY
        )
        chs = np.where(is_CH)[0].tolist()

        if not chs:
            print(f"[DEEC_BASELINE] No CHs at round {rnd}, skipping communication.")
            alive_energy = nodes_energy[alive_nodes]
            avg_energy = np.mean(alive_energy) if alive_energy.size > 0 else 0.0
            # Append to EVERY per-round history so all lists stay the same
            # length and remain index-aligned with the round number.
            alive_nodes_per_round.append(np.where(alive_nodes)[0].tolist())
            avg_energy_history.append(avg_energy)
            total_energy_history.append(float(np.sum(nodes_energy)))
            pdr_history.append(0.0)
            alive_nodes_history.append(int(np.sum(alive_nodes)))
            num_ch_history.append(0)
            reward_history.append(0.0)
            epsilon_history.append(0.0)
            throughput_history.append(0)
            pdr_percent_history.append(0.0)
            continue

        # Clustering: each alive node joins nearest CH
        clusters = [[] for _ in range(len(chs))]
        for i in range(N_NODES):
            if nodes_energy[i] <= 0:
                continue
            nearest_ch_idx = min(
                range(len(chs)),
                key=lambda j: distance(nodes_pos[i], nodes_pos[chs[j]])
            )
            clusters[nearest_ch_idx].append(i)

        # Communication (end-to-end PDR)
        nodes_energy, successful_packets, total_packets = radio_comm(
            clusters, chs, nodes_energy, nodes_pos, base_station,
            tx_power_factor=1.0,
            successful_packets=0,
            total_packets=0
        )

        nodes_energy = np.maximum(nodes_energy, 0.0)

        avg_e_after, pdr_after, alive_after, cluster_sizes = measure_metrics(
            nodes_energy, clusters, chs, successful_packets, total_packets
        )

        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        # Log the END-of-round alive set so it lines up with the other
        # per-round histories (and with run_proposed / run_fuzzy). Reusing the
        # start-of-round `alive_nodes` mask would lag one round behind.
        alive_nodes_per_round.append(np.where(nodes_energy > 0)[0].tolist())

        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(chs))
        reward_history.append(0.0)
        epsilon_history.append(0.0)
        throughput_history.append(successful_packets)
        # The tiny epsilon avoids division by zero when nothing was sent.
        pdr_percent_history.append(
            100.0 * successful_packets / (total_packets + 1e-12)
        )

        if alive_after == 0:
            print(f"[DEEC_BASELINE] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history, pdr_history, alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids, nodes_pos, total_energy_history
    )

print('DEEC baseline algorithm (heterogeneous + G-set) defined!')


## 6. Baseline Algorithm: Fuzzy C-Means

In [None]:
def run_fuzzy():
    """Run the Fuzzy C-Means baseline simulation.

    Node positions are partitioned once into ``NUM_CLUSTERS`` groups with
    ``fuzz.cluster.cmeans`` (each node takes its highest-membership cluster).
    Each round, every non-empty cluster draws one CH among its alive members
    with probability proportional to residual energy, and data is delivered
    through ``radio_comm`` at ``tx_power_factor=1.0``.

    Both RNGs are seeded with ``SEED`` so runs are reproducible.

    Returns:
        tuple: 15 items, in this order -- avg_energy_history, pdr_history,
        alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history,
        pdr_percent_history, first_dead_round, half_dead_round,
        last_dead_round, sector_ids, nodes_pos, total_energy_history.
        ``reward_history`` and ``epsilon_history`` are all zeros (no RL here)
        and exist only to keep the shape of ``run_proposed``'s result.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    avg_energy_history = []
    total_energy_history = []
    pdr_history = []
    alive_nodes_history = []
    alive_nodes_per_round = []
    num_ch_history = []
    reward_history = []
    epsilon_history = []
    throughput_history = []
    pdr_percent_history = []
    first_dead_round = None
    half_dead_round = None
    last_dead_round = None

    nodes_pos = np.array([(random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE)) for _ in range(N_NODES)])
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    nodes_energy = np.array([INIT_ENERGY] * N_NODES, dtype=float)

    m = 2.0
    error = 1e-5
    maxiter = 1000

    # Positions never change and the seed is fixed, so the fuzzy partition is
    # the same every round: compute it ONCE instead of re-running up to
    # `maxiter` iterations per round.
    # NOTE(review): depending on the scikit-fuzzy version, cmeans(seed=...)
    # may reseed NumPy's global RNG; calling it once also keeps the per-round
    # CH draws below from restarting at the same RNG state every round.
    _cntr, u, _u0, _d, _jm, _p, _fpc = fuzz.cluster.cmeans(
        data=nodes_pos.T,
        c=NUM_CLUSTERS,
        m=m,
        error=error,
        maxiter=maxiter,
        init=None,
        seed=SEED
    )
    # Hard assignment: each node belongs to its highest-membership cluster.
    labels = np.argmax(u, axis=0)

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)

        # Alive members of each fuzzy cluster for this round.
        members_by_label = [[] for _ in range(NUM_CLUSTERS)]
        for i, lbl in enumerate(labels):
            if nodes_energy[i] > 0:
                members_by_label[lbl].append(i)

        # One CH per non-empty cluster, drawn proportionally to residual
        # energy. Clusters whose members are all dead are dropped, which keeps
        # `clusters` and `chs` index-aligned.
        clusters = []
        chs = []
        for members in members_by_label:
            if not members:
                continue
            energies = nodes_energy[members]  # all > 0, so the sum is > 0
            probs = energies / energies.sum()
            chs.append(int(np.random.choice(members, p=probs)))
            clusters.append(members)

        nodes_energy, successful_packets, total_packets = radio_comm(
            clusters, chs, nodes_energy, nodes_pos, base_station, tx_power_factor=1.0,
            successful_packets=0, total_packets=0
        )
        nodes_energy = np.maximum(nodes_energy, 0.0)
        avg_e_after, pdr_after, alive_after, cluster_sizes = measure_metrics(
            nodes_energy, clusters, chs, successful_packets, total_packets
        )

        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        alive_nodes_per_round.append(np.where(nodes_energy > 0)[0].tolist())
        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(chs))
        reward_history.append(0.0)
        epsilon_history.append(0.0)
        throughput_history.append(successful_packets)
        # The tiny epsilon avoids division by zero when nothing was sent.
        pdr_percent_history.append(100.0 * successful_packets / (total_packets + 1e-12))

        if alive_after == 0:
            print(f"[FUZZY_C_MEANS_BASELINE] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history, pdr_history, alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids, nodes_pos, total_energy_history
    )

print('Fuzzy C-Means baseline algorithm defined!')

## 6. Baseline Algorithm: PSO + GA

In [None]:
def run_pso():
    np.random.seed(SEED)
    random.seed(SEED)

    # Histories
    avg_energy_history = []
    total_energy_history = []
    pdr_history = []
    alive_nodes_history = []
    alive_nodes_per_round = []
    num_ch_history = []
    reward_history = []
    epsilon_history = []
    throughput_history = []
    pdr_percent_history = []
    first_dead_round = None
    half_dead_round = None
    last_dead_round = None

    # Network initialization
    nodes_pos = np.array([
        (random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE))
        for _ in range(N_NODES)
    ])
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    nodes_energy = np.array([INIT_ENERGY] * N_NODES, dtype=float)

    # PSO+GA hyperparameters (tune as needed)
    dim = 2 * NUM_CLUSTERS              # each cluster has (x, y)
    POP = PSO_POP                       # population size
    W_inertia = W                       # inertia weight (e.g., 0.7)
    C1_pbest = C1                       # cognitive weight (e.g., 1.5)
    C2_gbest = C2                       # social weight (e.g., 1.5)
    GA_CROSSOVER_RATE = 0.7
    GA_MUTATION_RATE = 0.1
    GA_MUTATION_STD = AREA_SIZE * 0.02  # mutation step size (2% of area)

    # Initialize particles: random centroids in field
    particles = np.random.uniform(
        low=0.0, high=AREA_SIZE,
        size=(POP, dim)  # shape: (POP, 2*K)
    )
    velocities = np.zeros_like(particles)

    # Personal bests
    pbest = particles.copy()
    pbest_scores = np.full(POP, np.inf)

    # Global best
    gbest = None
    gbest_score = np.inf

    def decode_particle_to_chs(centroids, nodes_pos, nodes_energy, base_station):
        """
        From a particle (centroids), build clusters and pick CHs:
        - Assign each alive node to nearest centroid.
        - For each cluster, choose CH as node with highest energy;
          tie-break by closeness to centroid.
        """
        K = NUM_CLUSTERS
        centroids = centroids.reshape(K, 2)

        # Assign nodes to nearest centroid
        clusters = [[] for _ in range(K)]
        for i in range(N_NODES):
            if nodes_energy[i] <= 0:
                continue
            dists = np.linalg.norm(nodes_pos[i] - centroids, axis=1)
            nearest = int(np.argmin(dists))
            clusters[nearest].append(i)

        chs = []
        for k in range(K):
            cluster_nodes = clusters[k]
            if not cluster_nodes:
                # fallback: pick the highest-energy node globally
                chs.append(int(np.argmax(nodes_energy)))
                continue
            # pick highest-energy node in this cluster
            max_e = max(nodes_energy[n] for n in cluster_nodes)
            cand = [n for n in cluster_nodes if nodes_energy[n] == max_e]
            if len(cand) > 1:
                # tie-break: closest to centroid
                ck = centroids[k]
                best = min(cand, key=lambda n: np.linalg.norm(nodes_pos[n] - ck))
                chs.append(best)
            else:
                chs.append(cand[0])

        return clusters, chs

    def fitness_function(centroids, nodes_pos, nodes_energy, base_station):
        """
        Multi-criteria fitness:
        1) Intra-cluster distance (sum of distances node-CH)
        2) CH energy penalty (prefer high-energy CHs → penalize low avg CH energy)
        3) Distance of CHs to base station
        Lower is better.
        """
        clusters, chs = decode_particle_to_chs(
            centroids, nodes_pos, nodes_energy, base_station
        )

        # 1) Intra-cluster distance
        intra_dist = 0.0
        for cidx, cluster_nodes in enumerate(clusters):
            ch = chs[cidx]
            for node in cluster_nodes:
                intra_dist += np.linalg.norm(nodes_pos[node] - nodes_pos[ch])

        # 2) CH energy penalty (we want CH energies large)
        ch_energies = nodes_energy[chs]
        avg_ch_energy = np.mean(ch_energies) if len(ch_energies) > 0 else 0.0
        # penalty is inverse of energy (avoid division by zero)
        energy_penalty = 1.0 / (avg_ch_energy + 1e-9)

        # 3) CH to BS distance
        bs_dist = 0.0
        bs_pos = np.array(base_station)
        for ch in chs:
            bs_dist += np.linalg.norm(nodes_pos[ch] - bs_pos)

        # Weights (tuneable)
        w1 = 1.0     # intra-cluster distance
        w2 = 100.0   # CH energy penalty
        w3 = 0.1     # CH-BS distance

        fitness = w1 * intra_dist + w2 * energy_penalty + w3 * bs_dist
        return fitness

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)

        # ---------------
        # PSO Evaluation
        # ---------------
        scores = np.zeros(POP)
        for pidx in range(POP):
            # Make sure centroids stay in field
            particles[pidx] = np.clip(particles[pidx], 0.0, AREA_SIZE)
            score = fitness_function(
                particles[pidx], nodes_pos, nodes_energy, base_station
            )
            scores[pidx] = score

            # Update personal best
            if score < pbest_scores[pidx]:
                pbest_scores[pidx] = score
                pbest[pidx] = particles[pidx].copy()

            # Update global best
            if score < gbest_score:
                gbest_score = score
                gbest = particles[pidx].copy()

        # ---------------
        # PSO Update
        # ---------------
        for pidx in range(POP):
            r1 = np.random.rand(dim)
            r2 = np.random.rand(dim)
            velocities[pidx] = (
                W_inertia * velocities[pidx]
                + C1_pbest * r1 * (pbest[pidx] - particles[pidx])
                + C2_gbest * r2 * (gbest - particles[pidx])
            )
            particles[pidx] += velocities[pidx]
            particles[pidx] = np.clip(particles[pidx], 0.0, AREA_SIZE)

        # ---------------
        # GA: Crossover & Mutation on the best half
        # ---------------
        # Select best half by fitness
        best_indices = np.argsort(scores)[:POP // 2]

        # Crossover pairs
        for i in range(0, len(best_indices), 2):
            if i + 1 >= len(best_indices):
                break
            idx1 = best_indices[i]
            idx2 = best_indices[i + 1]

            parent1 = particles[idx1].copy()
            parent2 = particles[idx2].copy()

            if random.random() < GA_CROSSOVER_RATE:
                point = random.randint(1, dim - 1)
                child1 = np.concatenate([parent1[:point], parent2[point:]])
                child2 = np.concatenate([parent2[:point], parent1[point:]])
            else:
                child1, child2 = parent1, parent2

            # Mutation: small Gaussian noise on some dimensions
            for child in (child1, child2):
                if random.random() < GA_MUTATION_RATE:
                    mut_dims = np.random.rand(dim) < 0.1  # 10% of dims
                    noise = np.random.normal(0.0, GA_MUTATION_STD, size=dim)
                    child[mut_dims] += noise[mut_dims]
                    child[:] = np.clip(child, 0.0, AREA_SIZE)

            # Replace worst individuals with children
            worst_indices = np.argsort(scores)[-2:]
            particles[worst_indices[0]] = child1
            particles[worst_indices[1]] = child2

        # ---------------
        # Use global best centroids to define actual CHs and clusters
        # ---------------
        if gbest is None:
            # just in case, but should not happen
            gbest = particles[0].copy()

        clusters, chs = decode_particle_to_chs(
            gbest, nodes_pos, nodes_energy, base_station
        )

        # ---------------
        # Radio communication & metrics
        # ---------------
        nodes_energy, successful_packets, total_packets = radio_comm(
            clusters, chs, nodes_energy, nodes_pos, base_station, tx_power_factor=1.0
        )
        nodes_energy = np.maximum(nodes_energy, 0.0)

        avg_e_after, pdr_after, alive_after, cluster_sizes = measure_metrics(
            nodes_energy, clusters, chs, successful_packets, total_packets
        )

        # ---------------
        # Lifetime tracking
        # ---------------
        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        # ---------------
        # Logging
        # ---------------
        alive_nodes = nodes_energy > 0
        alive_nodes_indices = np.where(alive_nodes)[0].tolist()
        alive_nodes_per_round.append(alive_nodes_indices)
        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(chs))
        reward_history.append(0.0)      # no RL reward here
        epsilon_history.append(0.0)     # no epsilon-greedy here
        throughput_history.append(successful_packets)
        pdr_percent_history.append(
            100.0 * successful_packets / (total_packets + 1e-12)
        )

        if alive_after == 0:
            print(f"[PSO_GA] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history, pdr_history, alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids, nodes_pos, total_energy_history
    )

# Cell footer: confirms that this cell ran and the PSO+GA baseline is defined.
print('PSO+GA clustering baseline (literature-style) defined!')


## 6. Baseline Algorithm: ACO

In [None]:
def run_aco():
    """
    Run the Ant Colony Optimization (ACO) clustering baseline.

    Both RNGs are seeded with SEED, nodes are placed uniformly at random in
    the AREA_SIZE x AREA_SIZE field and every node starts with INIT_ENERGY.
    In each round:

    1) the heuristic of every alive node is refreshed as E_i / d(i, BS),
    2) each of the ANTS ants samples a cluster-head (CH) set with
       probability proportional to
       pheromone**ACO_ALPHA * heuristic**ACO_BETA and attaches every alive
       node to its nearest CH,
    3) the ant with the lowest fitness (intra-cluster distance, CH energy
       penalty, CH-BS distance) supplies the clustering for the round,
    4) pheromone evaporates by RHO and is reinforced on the chosen CHs,
    5) radio_comm() is applied and the resulting metrics are logged.

    The loop stops early once no node is alive.

    Returns:
        tuple: (avg_energy_history, pdr_history, alive_nodes_history,
        alive_nodes_per_round, num_ch_history, reward_history,
        epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids,
        nodes_pos, total_energy_history). reward_history and
        epsilon_history hold 0.0 per round because this method uses no RL.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    # Histories
    avg_energy_history, pdr_history, alive_nodes_history = [], [], []
    num_ch_history, reward_history, epsilon_history = [], [], []
    throughput_history, pdr_percent_history = [], []
    alive_nodes_per_round = []
    total_energy_history = []
    first_dead_round = half_dead_round = last_dead_round = None

    # Network initialization
    nodes_pos = np.array([
        (random.uniform(0, AREA_SIZE), random.uniform(0, AREA_SIZE))
        for _ in range(N_NODES)
    ])
    sector_ids = assign_static_sectors(nodes_pos, AREA_SIZE, SECTOR_ROWS, SECTOR_COLS)

    nodes_energy = np.array([INIT_ENERGY] * N_NODES, dtype=float)

    # ACO parameters (assumed already defined globally)
    # ACO_ALPHA, ACO_BETA, RHO, ACO_Q, ANTS, NUM_CLUSTERS

    # Initial pheromone and heuristic
    pheromone = np.ones(N_NODES, dtype=float)
    heuristic = np.zeros(N_NODES, dtype=float)

    # Fitness weights (tune as needed): intra-cluster distance,
    # CH energy penalty, CH-BS distance
    w1, w2, w3 = 1.0, 100.0, 0.1

    for rnd in range(1, ROUNDS + 1):
        base_station = get_bs_position(rnd)
        # -------------------------
        # Alive nodes
        # -------------------------
        alive_nodes = [i for i in range(N_NODES) if nodes_energy[i] > 0]
        if not alive_nodes:
            print(f"[ACO_BASELINE] All nodes died at round {rnd}")
            break

        alive_arr = np.array(alive_nodes, dtype=int)

        # -------------------------
        # Heuristic update
        # heuristic[i] = E_i / d(i, BS)
        # -------------------------
        for i in alive_arr:
            d_bs = np.linalg.norm(nodes_pos[i] - base_station)
            heuristic[i] = nodes_energy[i] / (d_bs + 1e-6)

        # -------------------------
        # Selection distribution
        # Pheromone and heuristic do not change while the ants of one round
        # build their solutions, so the distribution and the sampling plan
        # are computed once per round instead of once per ant.
        # -------------------------
        probs = (pheromone[alive_arr] ** ACO_ALPHA) * (heuristic[alive_arr] ** ACO_BETA)
        probs_sum = probs.sum()

        if probs_sum <= 0:
            # All zero → use uniform probability
            probs = np.ones_like(probs, dtype=float) / len(probs)
        else:
            probs = probs / probs_sum

        num_chs = min(NUM_CLUSTERS, len(alive_arr))
        everyone_is_ch = len(alive_arr) <= NUM_CLUSTERS

        # Default plan: uniform choice among all alive nodes (p=None). It is
        # replaced by a weighted plan only when enough candidates have a
        # non-zero probability to sample num_chs without replacement.
        candidates, candidate_probs = alive_arr, None
        if not everyone_is_ch:
            positive_mask = probs > 0
            if int(np.sum(positive_mask)) >= num_chs:
                candidates = alive_arr[positive_mask]
                candidate_probs = probs[positive_mask]
                candidate_probs = candidate_probs / candidate_probs.sum()  # renormalize

        all_solutions = []
        fitness_scores = []

        # -------------------------
        # Ant colony construction
        # -------------------------
        for _ in range(ANTS):
            if everyone_is_ch:
                ch_indices = alive_arr.copy()
            else:
                ch_indices = np.random.choice(
                    candidates, size=num_chs, replace=False, p=candidate_probs
                )

            # Assign each alive node to its nearest CH and accumulate the
            # intra-cluster distance in the same pass.
            clusters = [[] for _ in range(len(ch_indices))]
            ch_positions = nodes_pos[ch_indices]
            total_dist = 0.0

            for i in alive_arr:
                dists = np.linalg.norm(ch_positions - nodes_pos[i], axis=1)
                closest = int(np.argmin(dists))
                clusters[closest].append(i)
                total_dist += dists[closest]

            # CH energy penalty (prefer higher-energy CHs)
            ch_energies = nodes_energy[ch_indices]
            avg_ch_energy = np.mean(ch_energies) if len(ch_energies) > 0 else 0.0
            energy_penalty = 1.0 / (avg_ch_energy + 1e-6)

            # CH–BS distance
            bs_dists = np.linalg.norm(ch_positions - base_station, axis=1)
            ch_bs_dist = np.sum(bs_dists)

            fitness = w1 * total_dist + w2 * energy_penalty + w3 * ch_bs_dist

            all_solutions.append((clusters, ch_indices))
            fitness_scores.append(fitness)

        # -------------------------
        # Select best solution for this round
        # -------------------------
        best_idx = int(np.argmin(fitness_scores))
        best_clusters, best_chs = all_solutions[best_idx]

        # Remove empty clusters for safety (clusters and CHs are built with
        # the same length, so they can be walked in lockstep)
        valid_clusters, valid_chs = [], []
        for cluster, ch in zip(best_clusters, best_chs):
            if cluster:
                valid_clusters.append(cluster)
                valid_chs.append(int(ch))

        if not valid_chs:
            # Fallback: choose one random alive node as CH
            chosen = random.choice(alive_nodes)
            valid_chs = [chosen]
            valid_clusters = [[chosen]]

        best_clusters, best_chs = valid_clusters, valid_chs

        # -------------------------
        # Pheromone update
        # -------------------------
        pheromone *= (1 - RHO)  # evaporation
        best_fit = fitness_scores[best_idx]
        for ch in best_chs:
            pheromone[ch] += ACO_Q / (best_fit + 1e-6)

        # -------------------------
        # Radio communication
        # -------------------------
        nodes_energy, successful_packets, total_packets = radio_comm(
            best_clusters,
            best_chs,
            nodes_energy,
            nodes_pos,
            base_station,
            tx_power_factor=1.0,
            successful_packets=0,
            total_packets=0
        )
        nodes_energy = np.maximum(nodes_energy, 0.0)

        avg_e_after, pdr_after, alive_after, _ = measure_metrics(
            nodes_energy, best_clusters, best_chs, successful_packets, total_packets
        )

        # -------------------------
        # Lifetime tracking
        # -------------------------
        if first_dead_round is None and alive_after < N_NODES:
            first_dead_round = rnd
        if half_dead_round is None and alive_after <= N_HALF:
            half_dead_round = rnd
        if alive_after == 0 and last_dead_round is None:
            last_dead_round = rnd

        # -------------------------
        # Logging
        # -------------------------
        alive_mask = nodes_energy > 0
        alive_nodes_per_round.append(np.where(alive_mask)[0].tolist())
        avg_energy_history.append(avg_e_after)
        total_energy_history.append(float(np.sum(nodes_energy)))
        pdr_history.append(pdr_after)
        alive_nodes_history.append(alive_after)
        num_ch_history.append(len(best_chs))
        reward_history.append(0.0)      # no RL here
        epsilon_history.append(0.0)     # no RL here
        throughput_history.append(successful_packets)
        pdr_percent_history.append(
            100.0 * successful_packets / (total_packets + 1e-12)
        )

        if alive_after == 0:
            print(f"[ACO_BASELINE] All nodes died at round {rnd}")
            break

    return (
        avg_energy_history,
        pdr_history,
        alive_nodes_history,
        alive_nodes_per_round,
        num_ch_history,
        reward_history,
        epsilon_history,
        throughput_history,
        pdr_percent_history,
        first_dead_round,
        half_dead_round,
        last_dead_round,
        sector_ids,
        nodes_pos,
        total_energy_history
    )

# Cell footer: confirms that this cell ran and run_aco is defined.
print('ACO baseline algorithm defined (robust & literature-style)!')


## 7. Experiment Runner

Run all 6 algorithms and collect results.

In [None]:
# Methods to compare. The insertion order matters: the plotting cells look
# up each method's colour by its position in METHODS.
METHODS = {
    'LEACH_BASELINE': run_leach,
    'DEEC_BASELINE': run_deec,
    'FUZZY_C_MEANS_BASELINE': run_fuzzy,
    'PSO_BASELINE': run_pso,
    'ACO_BASELINE': run_aco,
    'RLHC_PROPOSED': run_proposed,
}

colors = ['blue', 'red', 'green', 'gold', 'orange', 'black']
results = {}

print(f"Starting experiments...\nSimulate for the rounds: {ROUNDS} \n")
for method_name, method_func in METHODS.items():
    print(f"Running method: {method_name}")
    (
        avg_energy_history, pdr_history, alive_nodes_history, alive_nodes_per_round, num_ch_history,
        reward_history, epsilon_history, throughput_history, pdr_percent_history,
        first_dead_round, half_dead_round, last_dead_round, sector_ids, nodes_pos, total_energy_history
    ) = method_func()

    # Keep the last-node-death round reported by the method. When the network
    # outlived the simulation (None), fall back to the number of rounds that
    # were actually logged, so the value is always an int that matches the
    # length of the stored histories (the plots slice and index with it).
    if last_dead_round is None:
        last_dead_round = len(alive_nodes_history)

    results[method_name] = {
        "avg_energy_history": avg_energy_history,
        "pdr_history": pdr_history,
        "alive_nodes_history": alive_nodes_history,
        "alive_nodes_per_round": alive_nodes_per_round,
        "num_ch_history": num_ch_history,
        "reward_history": reward_history,
        "epsilon_history": epsilon_history,
        "throughput_history": throughput_history,
        "pdr_percent_history": pdr_percent_history,
        "first_dead_round": first_dead_round,
        "half_dead_round": half_dead_round,
        "sector_ids": sector_ids,
        "nodes_pos": nodes_pos,
        "total_energy_history": total_energy_history,
        "last_dead_round": last_dead_round,
    }

print("\nAll experiments completed!")

## 8. Visualization

### Plot 1: Dead Nodes vs Number of Rounds

In [None]:
import pandas as pd
import numpy as np

def compute_sector_lifetimes(alive_nodes_per_round, sector_ids, num_sectors):
    """
    Return, for each sector, the last round in which it had an alive node.

    Args:
        alive_nodes_per_round: one entry per round (round numbers start at
            1), each a sequence of the node indices alive in that round.
        sector_ids: sector id of every node, indexable by node index; any
            array-like is accepted (list, tuple, ndarray).
        num_sectors: length of the returned list.

    Returns:
        list: entry ``s`` is the last round with at least one alive node in
        sector ``s``, or None if the sector never had one.

    Raises:
        IndexError: if a sector id is >= num_sectors.
    """
    # Coerce once so plain lists can be fancy-indexed below
    sector_ids = np.asarray(sector_ids)
    lifetimes = [None] * num_sectors

    for rnd, alive_nodes in enumerate(alive_nodes_per_round, start=1):
        if len(alive_nodes) == 0:
            continue
        alive_arr = np.array(alive_nodes, dtype=int)
        for sector in np.unique(sector_ids[alive_arr]):
            lifetimes[int(sector)] = rnd  # later rounds overwrite earlier ones
    return lifetimes

# Build one table row per method holding the lifetime of every sector.
rows_list = []
for method_name, res in results.items():
    alive_nodes_per_round = res["alive_nodes_per_round"]

    # get these from results instead of globals
    sector_ids = res["sector_ids"]

    # Size the row from the sector grid rather than from the ids that happen
    # to occur: counting unique ids under-sizes the list whenever a sector
    # holds no nodes, which breaks indexing by sector id.
    # NOTE(review): assumes sector ids are 0-based integers — confirm against
    # assign_static_sectors. The largest observed id is honoured as well.
    highest_id = int(np.max(sector_ids)) if len(sector_ids) else -1
    num_sectors = max(SECTOR_ROWS * SECTOR_COLS, highest_id + 1)

    lifetimes = compute_sector_lifetimes(alive_nodes_per_round, sector_ids, num_sectors)

    row = {"Method": method_name}
    for s, life in enumerate(lifetimes):
        row[f"Sector {s} lifetime (round)"] = life
    rows_list.append(row)

sector_life_table = pd.DataFrame(rows_list)
print("\nSector lifetime summary table:")
print(sector_life_table.to_string(index=False))


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# If you already have sector_life_table from your code:
# sector_life_table = pd.DataFrame(rows_list)

# --- One bar chart per method: lifetime of every sector ---
# The sector columns and bar positions are the same for every row.
sector_cols = [c for c in sector_life_table.columns if c.startswith("Sector")]
x = np.arange(len(sector_cols))

for idx, row in sector_life_table.iterrows():
    method_name = row["Method"]
    # Keep only the sector columns; missing lifetimes become NaN
    lifetimes = row[sector_cols].to_numpy(dtype=float)

    plt.figure(figsize=(10, 5))
    plt.bar(x, lifetimes, color="skyblue", edgecolor="black")
    plt.xticks(x, sector_cols, rotation=45, ha="right")
    plt.ylabel("Lifetime (round)")
    plt.title(f"Sector Lifetime per Sector - {method_name}")
    plt.tight_layout()
    plt.grid(axis="y", alpha=0.3)
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

# One map per method: nodes coloured by sector, sectors that died before the
# last round shaded grey, the base station (or its path) and each sector's
# lifetime printed at the sector centre.

# Loop-invariant pieces: sector columns of the table and the cell dimensions
sector_cols = [c for c in sector_life_table.columns if c.startswith("Sector")]
cell_w = AREA_SIZE / SECTOR_COLS
cell_h = AREA_SIZE / SECTOR_ROWS

# Line style of the base-station path for each mobile mode
bs_path_styles = {
    "mobile_circle": dict(linestyle="-", linewidth=1.5, zorder=3),
    "mobile_random": dict(linestyle="--", linewidth=0.2, zorder=2),
}

for idx, row in sector_life_table.iterrows():
    method_name = row["Method"]
    res = results[method_name]

    nodes_pos = np.array(res["nodes_pos"])
    sector_ids = np.array(res["sector_ids"])

    # Sector lifetimes of this method; missing values become NaN
    sector_lifetimes = row[sector_cols].values.astype(float)

    plt.figure(figsize=(7, 7))
    ax = plt.gca()

    # -----------------------------
    # Shade sectors that died before last round
    # (sectors are walked row by row, so the flat index is r * cols + c)
    # -----------------------------
    for r in range(SECTOR_ROWS):
        for c in range(SECTOR_COLS):
            lifetime = sector_lifetimes[r * SECTOR_COLS + c]

            # NaN compares False here, so a sector without data is not shaded
            if lifetime < ROUNDS:
                ax.add_patch(
                    Rectangle(
                        (c * cell_w, r * cell_h),  # bottom-left corner
                        cell_w,
                        cell_h,
                        facecolor="lightgrey",
                        alpha=0.5,
                        edgecolor="none",
                        zorder=0  # behind other elements
                    )
                )

    # -----------------------------
    # Scatter plot of nodes colored by sector
    # -----------------------------
    sc = plt.scatter(
        nodes_pos[:, 0],
        nodes_pos[:, 1],
        c=sector_ids,
        cmap="tab10",
        s=40,
        edgecolor="black",
        zorder=2
    )

    # -----------------------------
    # Base Station
    # -----------------------------
    if BS_MODE != "fixed":
        bs_traj = np.array(BS_TRAJECTORY)
        path_style = bs_path_styles.get(BS_MODE)

        # Draw the path only when a style is known for the mode and the
        # trajectory really holds (x, y) points; an empty trajectory would
        # otherwise fail on the column indexing below.
        if path_style is not None and bs_traj.ndim == 2 and len(bs_traj) > 0:
            plt.plot(
                bs_traj[:, 0],
                bs_traj[:, 1],
                color="red",
                label="BS path",
                **path_style
            )
    else:
        # Fixed BS: a single marker at its constant position
        base_station = BASE_STATION_FIXED
        plt.scatter(
            base_station[0], base_station[1],
            marker="^",
            s=200,
            color="red",
            edgecolor="black",
            label="Base Station",
            zorder=3
        )

    plt.colorbar(sc, label="Sector ID")

    # -----------------------------
    # Draw grid lines
    # -----------------------------
    for c in range(SECTOR_COLS + 1):
        plt.axvline(x=c * cell_w, color="gray", linestyle="--", linewidth=0.7, zorder=1)

    for r in range(SECTOR_ROWS + 1):
        plt.axhline(y=r * cell_h, color="gray", linestyle="--", linewidth=0.7, zorder=1)

    # -----------------------------
    # Annotate each sector with lifetime
    # -----------------------------
    for r in range(SECTOR_ROWS):
        for c in range(SECTOR_COLS):
            lifetime = sector_lifetimes[r * SECTOR_COLS + c]
            # A sector without a recorded lifetime is labelled explicitly
            # instead of printing "nan rounds"
            text = "n/a" if np.isnan(lifetime) else f"{lifetime:.0f} rounds"

            plt.text(
                c * cell_w + cell_w / 2,
                r * cell_h + cell_h / 2,
                text,
                ha="center", va="center",
                fontsize=9,
                fontweight="bold",
                color="black",
                bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"),
                zorder=4
            )

    # -----------------------------
    # Styling
    # -----------------------------
    plt.xlim(0, AREA_SIZE)
    plt.ylim(0, AREA_SIZE)
    plt.xlabel("X position")
    plt.ylabel("Y position")
    plt.title(f"Sector Lifetimes : {method_name}")
    ax.set_aspect("equal", adjustable="box")
    # Only build a legend when something carries a label
    legend_handles, _ = ax.get_legend_handles_labels()
    if legend_handles:
        plt.legend(loc="upper right")
    plt.tight_layout()
    plt.show()


In [None]:
# Dead nodes per round for every method, one dashed line each.
method_order = list(METHODS.keys())  # colour index follows METHODS order

plt.figure(figsize=(10,6))
for method_name, res in results.items():
    alive_history = res["alive_nodes_history"][:res["last_dead_round"]]
    dead_nodes = [N_NODES - a for a in alive_history]
    # The x-axis is derived from the series itself so both always have the
    # same length, even when the history is shorter than last_dead_round.
    rounds = range(1, len(dead_nodes) + 1)
    plt.plot(rounds, dead_nodes, label=method_name, linestyle="--", color=colors[method_order.index(method_name)])
plt.xlabel("Number of Rounds")
plt.ylabel("Dead Nodes")
plt.title("Dead Nodes vs Number of Rounds")
plt.legend()
plt.grid(True)
plt.show()

### Plot 2: Alive Nodes vs Number of Rounds

In [None]:
# Alive nodes per round for every method, one dashed line each.
method_order = list(METHODS.keys())  # colour index follows METHODS order

plt.figure(figsize=(10,6))
for method_name, res in results.items():
    alive_nodes = res["alive_nodes_history"][:res["last_dead_round"]]
    # The x-axis is derived from the series itself so both always have the
    # same length, even when the history is shorter than last_dead_round.
    rounds = range(1, len(alive_nodes) + 1)
    plt.plot(rounds, alive_nodes, label=method_name, linestyle="--", color=colors[method_order.index(method_name)])
plt.xlabel("Number of Rounds")
plt.ylabel("Alive Nodes")
plt.title("Alive Nodes vs Number of Rounds")
plt.legend()
plt.grid(True)
plt.show()

### Plot 3: Cumulative Throughput vs Number of Rounds

In [None]:
# Running total of the per-round throughput for every method.
method_order = list(METHODS.keys())  # colour index follows METHODS order

plt.figure(figsize=(10,6))
for method_name, res in results.items():
    cumulative_throughput = np.cumsum(res["throughput_history"][:res["last_dead_round"]])
    # The x-axis is derived from the series itself so both always have the
    # same length, even when the history is shorter than last_dead_round.
    rounds = range(1, len(cumulative_throughput) + 1)
    plt.plot(rounds, cumulative_throughput, label=method_name, linestyle="--", color=colors[method_order.index(method_name)])
plt.xlabel("Number of Rounds")
plt.ylabel("Cumulative Throughput")
plt.title("Cumulative Throughput vs Number of Rounds")
plt.legend()
plt.grid(True)
plt.show()

### Plot 4: Total Energy Consumed vs Number of Rounds

In [None]:
# Energy consumed so far = initial network energy minus the remaining total.
plt.figure(figsize=(10, 6))

for i, (method_name, res) in enumerate(results.items()):
    initial_total_energy = N_NODES * INIT_ENERGY  # sum of all initial energies

    # total_energy_history[r] must be: sum of remaining energy of ALL nodes at round r
    remaining = res["total_energy_history"][:res["last_dead_round"]]
    total_E = np.array(remaining)

    # Total energy consumed by all nodes up to each round
    total_energy_consumed = initial_total_energy - total_E
    rounds = np.arange(1, len(total_energy_consumed) + 1)

    plt.plot(rounds, total_energy_consumed, label=method_name, linestyle="--", color=colors[i])

plt.xlabel("Number of Rounds")
plt.ylabel("Total Energy Consumed (J)")
plt.title("Total Energy Consumption vs Number of Rounds")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


### Plot 5: Total PDR (%) Comparison

In [None]:
# Mean of the per-round PDR percentages of every method, up to its
# last_dead_round.
total_pdr = {}
for method_name, res in results.items():
    pdr_slice = res["pdr_percent_history"][:res["last_dead_round"]]
    total_pdr[method_name] = np.mean(pdr_slice)

plt.figure(figsize=(8,5))
plt.bar(total_pdr.keys(), total_pdr.values(), color=colors[:len(total_pdr)])
plt.ylabel("Total PDR (%)")
plt.title("Total Packet Delivery Ratio Comparison")
plt.xticks(rotation=45, ha='right')
plt.grid(True)
plt.tight_layout()
plt.show()

# Same numbers as a plain-text table
table_rule = "-" * 40
print("\nTotal PDR per method:")
print(table_rule)
print(f"{'Method':<25} {'PDR (%)':>10}")
print(table_rule)
for method, pdr in total_pdr.items():
    print(f"{method:<25} {pdr:>10.2f}")
print(table_rule)

## 9. Results Summary

### Node Death Summary Table

In [None]:
# Node-death milestones of every method, one row per method.
death_columns = {"Method": list(results.keys())}
for column_title, result_key in (
    ("First Node Death", "first_dead_round"),
    ("Half Nodes Death", "half_dead_round"),
    ("Last Node Death", "last_dead_round"),
):
    death_columns[column_title] = [res[result_key] for res in results.values()]

death_table = pd.DataFrame(death_columns)
print("\nNode Death Summary Table:")
display(death_table)

In [None]:
# Node-death milestones plus the length of the stable and unstable regions.
all_results = list(results.values())
death_table = pd.DataFrame({
    "Method": list(results.keys()),
    "First Node Death": [entry["first_dead_round"] for entry in all_results],
    "Half Nodes Death": [entry["half_dead_round"] for entry in all_results],
    "Last Node Death": [entry["last_dead_round"] for entry in all_results]
})

stable_region = []
unstable_region = []

for method_name, res in results.items():
    fnd = res["first_dead_round"]
    lnd = res["last_dead_round"]

    # Stable region: rounds 1 .. FND-1; the whole simulation (ROUNDS) when
    # no node died
    stable_len = ROUNDS if fnd is None else max(0, fnd - 1)

    # Unstable region: rounds FND .. LND inclusive; 0 when either is unknown
    if fnd is not None and lnd is not None:
        unstable_len = max(0, lnd - fnd + 1)
    else:
        unstable_len = 0

    stable_region.append(stable_len)
    unstable_region.append(unstable_len)

death_table["Stable Region (rounds)"] = stable_region
death_table["Unstable Region (rounds)"] = unstable_region

print("\nNode Death & Stability Summary Table:")
display(death_table)


In [None]:
def lifetime_until_fraction_dead(alive_history, fraction_dead, total_nodes):
    """
    Return the first round in which at least ``fraction_dead`` of the nodes
    are dead.

    Args:
        alive_history: number of alive nodes per round (round numbers start
            at 1).
        fraction_dead: required dead fraction, e.g. 0.75 for 75% dead.
        total_nodes: total number of nodes in the network.

    Returns:
        int | None: the 1-based round, or None if the fraction is not
        reached within the simulated rounds.
    """
    # Smallest whole number of dead nodes that reaches the fraction. Rounding
    # the *alive* threshold up instead would fire one node too early when
    # fraction * total is not an integer (e.g. 112 of 150 dead is < 75%).
    # The epsilon absorbs float noise such as 0.1 * 150 == 15.000000000000002.
    dead_needed = int(np.ceil(fraction_dead * total_nodes - 1e-9))
    for rnd, alive in enumerate(alive_history, start=1):
        if total_nodes - alive >= dead_needed:
            return rnd
    return None  # not reached within simulated rounds


In [None]:
# Round at which 75% of the nodes are dead, per method (None if not reached),
# appended as a new column of the death table.
lifetime_75_dead = [
    lifetime_until_fraction_dead(res["alive_nodes_history"], 0.75, N_NODES)
    for res in results.values()
]

death_table["75% Nodes Dead (round)"] = lifetime_75_dead

display(death_table)


In [None]:
def plot_routing_snapshot_round_global(
    method_name,
    rnd,
    nodes_pos,
    sector_ids,
    base_station,
    area_size,
    sector_rows,
    sector_cols
):
    """
    Draw the clusters and routing links logged for one round of one method.

    The cluster heads, clusters and links of round ``rnd`` are read from
    GLOBAL_ROUTING_INFO[method_name]; anything missing is treated as empty,
    so only the nodes, the sector grid and the base station are drawn then.

    Link colours: green = link to the base station (dst == -1),
    red = CH to CH, blue = any other logged link; light grey lines connect
    the members of a cluster to its head.
    """
    method_log = GLOBAL_ROUTING_INFO.get(method_name, {})
    chs = method_log.get("chs", {}).get(rnd, [])
    clusters = method_log.get("clusters", {}).get(rnd, [])
    links = method_log.get("links", {}).get(rnd, [])

    fig, ax = plt.subplots(figsize=(7, 7))

    # 1) Sector grid: one unfilled rectangle per cell
    cell_w = area_size / sector_cols
    cell_h = area_size / sector_rows
    for row_idx in range(sector_rows):
        for col_idx in range(sector_cols):
            cell_outline = plt.Rectangle(
                (col_idx * cell_w, row_idx * cell_h),
                cell_w,
                cell_h,
                fill=False,
                edgecolor="lightgray",
                linewidth=0.7,
                zorder=0
            )
            ax.add_patch(cell_outline)

    # 2) Nodes by sector
    nodes_pos_arr = np.array(nodes_pos)
    ax.scatter(
        nodes_pos_arr[:, 0], nodes_pos_arr[:, 1],
        c=sector_ids, cmap="tab10", s=40, edgecolor="black",
        zorder=2, label="Nodes"
    )

    # 3) CHs
    if len(chs) > 0:
        ch_positions = nodes_pos_arr[chs]
        ax.scatter(
            ch_positions[:, 0], ch_positions[:, 1],
            s=120, marker="s", color="yellow", edgecolor="red",
            linewidth=1.2, zorder=3, label="CHs"
        )

    # 4) Member -> CH links (from clusters & chs); clusters without a
    #    matching head are dropped by zip, empty clusters are skipped
    for ch_node, cluster_nodes in zip(chs, clusters):
        if not cluster_nodes:
            continue
        cx, cy = nodes_pos[ch_node]
        for node in cluster_nodes:
            if node == ch_node:
                continue
            nx, ny = nodes_pos[node]
            ax.plot(
                [nx, cx], [ny, cy],
                color="lightgray", linewidth=0.6, alpha=0.9, zorder=1
            )

    # 5) Routing links (CH-CH, CH-BS, member-CH as logged)
    ch_set = set(chs)
    bx, by = base_station
    node_count = len(nodes_pos)

    for src, dst in links:
        if not 0 <= src < node_count:
            continue
        x1, y1 = nodes_pos[src]

        if dst == -1:
            # CH -> BS
            x2, y2 = bx, by
            color, z = "green", 3.0   # CH–BS
        elif 0 <= dst < node_count:
            x2, y2 = nodes_pos[dst]
            # CH -> CH (inter-cluster hop) is red; any other logged link
            # (e.g., member -> CH) is blue
            inter_cluster = src in ch_set and dst in ch_set
            color, z = ("red", 3.0) if inter_cluster else ("blue", 2.5)
        else:
            continue

        ax.plot(
            [x1, x2], [y1, y2],
            color=color, linewidth=1.0, alpha=0.8, zorder=z
        )

    # 6) Base station
    ax.scatter(
        [bx], [by],
        s=180, marker="*", color="red", edgecolor="black",
        linewidth=1.2, zorder=4, label="BS"
    )

    # 7) Aesthetics
    ax.set_xlim(0, area_size)
    ax.set_ylim(0, area_size)
    ax.set_aspect("equal", adjustable="box")
    ax.set_title(f"{method_name} – Routing & Clusters – Round {rnd}")
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")

    # Collapse repeated labels so each legend entry appears once
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc="upper right")

    plt.tight_layout()
    plt.show()


In [None]:
# After run_proposed() and after simulation:
round_to_show = 50
method_name = "RLHC_PROPOSED"
base_station = get_bs_position(round_to_show)

# Take the node layout from the stored results of the method being plotted
# rather than from whatever globals an earlier loop happened to leave behind.
snapshot_res = results[method_name]

plot_routing_snapshot_round_global(
    method_name=method_name,
    rnd=round_to_show,
    nodes_pos=snapshot_res["nodes_pos"],
    sector_ids=snapshot_res["sector_ids"],
    base_station=base_station,
    area_size=AREA_SIZE,
    sector_rows=SECTOR_ROWS,
    sector_cols=SECTOR_COLS
)


In [None]:
def _is_node_index(idx, n_nodes):
    """Return True when idx is a usable, non-negative index into the node list."""
    return 0 <= idx < n_nodes


def _routing_link_style(src, dst, ch_set):
    """Return the (color, zorder) pair used to draw the routing link src -> dst."""
    if dst == -1:
        # -1 is the sentinel destination meaning "base station": CH -> BS
        return "green", 3.0
    if src in ch_set and dst in ch_set:
        # CH -> CH (inter-cluster hop)
        return "red", 3.0
    # other routing link (e.g., member -> CH logged, etc.)
    return "blue", 2.5


def animate_routing_all_rounds(
    method_name,
    nodes_pos,
    sector_ids,
    area_size,
    sector_rows,
    sector_cols,
    rounds,
    interval=300,  # ms between frames
    verbose=True
):
    """
    Animate routing and clusters across all rounds for a given method
    using GLOBAL_ROUTING_INFO.

    The per-round data is read from ``GLOBAL_ROUTING_INFO[method_name]``,
    which is looked up with the keys ``"chs"``, ``"clusters"`` and
    ``"links"``; each of those is a mapping keyed by round number. Rounds
    with no entry are drawn with no CHs, clusters or links. The base-station
    position of every frame comes from ``get_bs_position(rnd)``.

    Args:
        method_name: Key of the method in GLOBAL_ROUTING_INFO; also shown in
            the plot title.
        nodes_pos: Sequence of (x, y) node positions, indexed by node id.
        sector_ids: Per-node values used to color the node scatter.
        area_size: Side length of the square plotting area; sets both axis
            limits and the sector grid extent.
        sector_rows: Number of rows in the sector grid overlay.
        sector_cols: Number of columns in the sector grid overlay.
        rounds: Number of frames; frame ``i`` displays round ``i + 1``.
        interval: Delay between frames in milliseconds.
        verbose: When True (default, the historical behavior) print one
            status line per rendered frame; pass False to silence it.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object. Keep a reference
        to it, otherwise the animation may be garbage-collected.
    """
    info = GLOBAL_ROUTING_INFO.get(method_name, {})
    chs_dict = info.get("chs", {})
    clusters_dict = info.get("clusters", {})
    links_dict = info.get("links", {})

    nodes_pos_arr = np.array(nodes_pos)
    n_nodes = len(nodes_pos)

    fig, ax = plt.subplots(figsize=(7, 7))

    # -----------------------------
    # 1) Draw sector grid (static)
    # -----------------------------
    cell_w = area_size / sector_cols
    cell_h = area_size / sector_rows
    for r in range(sector_rows):
        for c in range(sector_cols):
            ax.add_patch(
                plt.Rectangle(
                    (c * cell_w, r * cell_h),
                    cell_w,
                    cell_h,
                    fill=False,
                    edgecolor="lightgray",
                    linewidth=0.7,
                    zorder=0
                )
            )

    # -----------------------------
    # 2) Static node scatter (nodes do not move)
    # -----------------------------
    node_scatter = ax.scatter(
        nodes_pos_arr[:, 0],
        nodes_pos_arr[:, 1],
        c=sector_ids,
        cmap="tab10",
        s=40,
        edgecolor="black",
        zorder=2,
        label="Nodes"
    )

    # -----------------------------
    # 3) CH markers (updated per frame)
    # -----------------------------
    ch_scatter = ax.scatter(
        [], [],  # will be updated
        s=120,
        marker="s",
        color="yellow",
        edgecolor="red",
        linewidth=1.2,
        zorder=3,
        label="CHs"
    )

    # -----------------------------
    # 4) BS marker (updated per frame)
    # -----------------------------
    bs_scatter = ax.scatter(
        [], [],
        s=180,
        marker="*",
        color="red",
        edgecolor="black",
        linewidth=1.2,
        zorder=4,
        label="BS"
    )

    # -----------------------------
    # 5) Line containers, rebuilt on every frame:
    #    - member -> CH links (intra-cluster, light gray)
    #    - routing links (CH-CH, CH-BS, member-CH as logged)
    # -----------------------------
    intra_lines = []
    route_lines = []

    # Setup axes
    ax.set_xlim(0, area_size)
    ax.set_ylim(0, area_size)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    title_text = ax.set_title("")

    # De-duplicate legend entries by label.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc="upper right")

    # -----------------------------
    # 6) Update function for each frame
    # -----------------------------
    def init():
        # nothing else to init; lines will be created in first update
        return node_scatter, ch_scatter, bs_scatter

    def update(frame_idx):
        # Frames are 0-based; the routing log is assumed to be keyed by
        # 1-based round numbers. Never look past the last requested round.
        rnd = min(frame_idx + 1, rounds)

        # Get data for this round
        chs = chs_dict.get(rnd, [])
        clusters = clusters_dict.get(rnd, [])
        links = links_dict.get(rnd, [])
        if verbose:
            print(f"Updating Frame={frame_idx}, Round={rnd}, CHs={len(chs)}, Links={len(links)}")

        # Update title
        title_text.set_text(f"{method_name} – Routing & Clusters – Round {rnd}")

        # Update CH scatter. Only indices that address a real node are drawn:
        # a negative value would otherwise wrap around to the end of the array
        # and an oversized one would abort the whole animation.
        drawable_chs = [ch for ch in chs if _is_node_index(ch, n_nodes)]
        if drawable_chs:
            ch_scatter.set_offsets(nodes_pos_arr[drawable_chs])
        else:
            ch_scatter.set_offsets(np.empty((0, 2)))

        # Update BS position
        bx, by = get_bs_position(rnd)
        bs_scatter.set_offsets(np.array([[bx, by]]))

        # Clear the lines drawn for the previous frame
        for stale in intra_lines + route_lines:
            stale.remove()
        intra_lines.clear()
        route_lines.clear()

        # Draw member -> CH (from clusters & chs) in light gray.
        # Cluster i is paired with chs[i]; clusters without a matching CH
        # entry are skipped (zip stops at the shorter sequence).
        for ch_node, cluster_nodes in zip(chs, clusters):
            if not cluster_nodes or not _is_node_index(ch_node, n_nodes):
                continue
            cx, cy = nodes_pos[ch_node]
            for node in cluster_nodes:
                if node == ch_node or not _is_node_index(node, n_nodes):
                    continue
                nx, ny = nodes_pos[node]
                ln, = ax.plot(
                    [nx, cx],
                    [ny, cy],
                    color="lightgray",
                    linewidth=0.6,
                    alpha=0.9,
                    zorder=1
                )
                intra_lines.append(ln)

        # ---------------------------------------
        # Draw routing links with different colors
        #   - CH -> BS : green
        #   - CH -> CH : red (inter-cluster)
        #   - others   : blue (default)
        # ---------------------------------------
        ch_set = set(chs)

        for src, dst in links:
            if not _is_node_index(src, n_nodes):
                continue

            # Determine target position (dst == -1 means the base station)
            if dst == -1:
                x2, y2 = bx, by
            elif _is_node_index(dst, n_nodes):
                x2, y2 = nodes_pos[dst]
            else:
                continue

            x1, y1 = nodes_pos[src]
            color, z = _routing_link_style(src, dst, ch_set)

            ln, = ax.plot(
                [x1, x2],
                [y1, y2],
                color=color,
                linewidth=1.2,
                alpha=0.9,
                zorder=z
            )
            route_lines.append(ln)

        return node_scatter, ch_scatter, bs_scatter, title_text

    # -----------------------------
    # 7) Create animation
    # -----------------------------
    ani = animation.FuncAnimation(
        fig,
        update,
        frames=rounds,
        init_func=init,
        interval=interval,
        blit=False,  # easier with multiple artists
        repeat=True
    )

    plt.tight_layout()
    # plt.show()

    return ani


In [None]:
import matplotlib as mpl

# Raise the size cap for embedded animations (value is in MB).
mpl.rcParams["animation.embed_limit"] = 200

# Arguments for the full-run routing animation of the proposed method.
animation_kwargs = dict(
    method_name="RLHC_PROPOSED",
    nodes_pos=nodes_pos,
    sector_ids=sector_ids,
    area_size=AREA_SIZE,
    sector_rows=SECTOR_ROWS,
    sector_cols=SECTOR_COLS,
    rounds=ROUNDS,
    interval=300,
)
ani = animate_routing_all_rounds(**animation_kwargs)

# Last expression of the cell: the animation rendered as interactive HTML.
HTML(ani.to_jshtml())
