In [None]:
import os, time, copy, random
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset

import matplotlib.pyplot as plt
from matplotlib import rcParams

from tensorflow import keras
!pip install codecarbon
from codecarbon import EmissionsTracker


SEED = 42
# Seed every RNG this script draws from: Python `random`, NumPy, and PyTorch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Prefer the GPU when one is visible; everything else falls back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("DEVICE:", DEVICE)


(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Per-channel (R, G, B) statistics, applied after pixels are scaled by 1/255.
cifar10_mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
cifar10_std  = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)


def _standardize_images(images):
    """Scale raw 0-255 pixels to float32 in [0, 1], then standardize each channel."""
    scaled = images.astype(np.float32) / 255.0
    return (scaled - cifar10_mean) / cifar10_std


x_train = _standardize_images(x_train)
x_test  = _standardize_images(x_test)
y_train = y_train.astype(np.int64).squeeze()
y_test  = y_test.astype(np.int64).squeeze()

# The arrays are channels-last (NHWC); PyTorch convolutions expect NCHW.
x_train_t, x_test_t = (
    torch.from_numpy(np.transpose(images, (0, 3, 1, 2))).float()
    for images in (x_train, x_test)
)
y_train_t, y_test_t = (torch.from_numpy(labels).long() for labels in (y_train, y_test))

train_ds = TensorDataset(x_train_t, y_train_t)
test_ds  = TensorDataset(x_test_t,  y_test_t)


class SmallCIFAR10CNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by a two-layer MLP head.

    Sized for 3-channel 32x32 inputs: each pool halves the spatial size
    (32 -> 16 -> 8 -> 4), so the flattened feature vector has 128 * 4 * 4 entries.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Creation order is significant: it fixes the order in which parameters
        # are randomly initialized, and the attribute names are read elsewhere.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(kernel_size=2)
        self.fc1   = nn.Linear(in_features=128 * 4 * 4, out_features=256)
        self.fc2   = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        features = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)


def count_params(model: nn.Module) -> int:
    """Return the total number of scalar entries across all of ``model``'s parameters."""
    return sum(map(torch.numel, model.parameters()))

def get_model_num_bytes(model: nn.Module) -> int:
    """Return the in-memory size of ``model``'s parameters, in bytes.

    Uses each parameter's real ``element_size()`` rather than assuming 4-byte
    float32, so half- or double-precision models are sized correctly. For the
    float32 models used in this script the result is unchanged. Buffers
    (e.g. BatchNorm running stats) are not included.
    """
    return sum(p.numel() * p.element_size() for p in model.parameters())

def _conv2d_flops(conv: nn.Conv2d, out_h: int, out_w: int) -> int:
    # FLOPs per output element: (Cin/groups * K*K) muls + same adds ~ 2 * muls
    cin = conv.in_channels
    cout = conv.out_channels
    k_h, k_w = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size)
    groups = conv.groups
    muls_per_out = (cin // groups) * k_h * k_w
    flops_per_out = 2 * muls_per_out
    return int(cout * out_h * out_w * flops_per_out)

def _linear_flops(fc: nn.Linear) -> int:
    # per sample: 2 * in_features * out_features
    return int(2 * fc.in_features * fc.out_features)

@torch.no_grad()
def estimate_model_flops_per_sample(model: nn.Module, input_shape=(1, 3, 32, 32)) -> int:
    """
    Rough forward-pass FLOPs per single sample for this specific model.
    Counts Conv2d + Linear only. ReLU/Pool ignored (small vs conv).
    """
    model.eval()
    x = torch.zeros(input_shape, device=next(model.parameters()).device)

    out_h, out_w = 32, 32
    flops = 0

    flops += _conv2d_flops(model.conv1, out_h, out_w)
    # pool: 32->16
    out_h, out_w = 16, 16

    flops += _conv2d_flops(model.conv2, out_h, out_w)
    # pool: 16->8
    out_h, out_w = 8, 8

    flops += _conv2d_flops(model.conv3, out_h, out_w)
    # pool: 8->4
    out_h, out_w = 4, 4

    # flatten: 128*4*4
    flops += _linear_flops(model.fc1)
    flops += _linear_flops(model.fc2)

    return int(flops)

# Eval / FedAvg

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode (and leaves it there) and runs without
    autograd. An empty loader yields 0.0 instead of dividing by zero.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        predictions = model(inputs).argmax(dim=1)
        n_correct += int((predictions == targets).sum())
        n_seen += targets.numel()
    return n_correct / max(n_seen, 1)

def fedavg(global_model: nn.Module, client_states: List[Dict[str, torch.Tensor]], client_sizes: List[int]):
    """Replace ``global_model``'s state with the sample-weighted average of client states (FedAvg).

    Each client's tensor is weighted by ``client_sizes[i] / sum(client_sizes)``.
    Floating-point entries are averaged in at least float32 (float64 stays
    float64) and cast back to their original dtype. Integer entries such as
    BatchNorm's ``num_batches_tracked`` are rounded to the nearest integer.

    Args:
        global_model: model updated in place via ``load_state_dict``.
        client_states: one state dict per participating client, all with the same keys.
        client_sizes: local training-sample count per client, aligned with ``client_states``.

    Raises:
        ValueError: if no clients are given, the two lists differ in length,
            or the sizes do not sum to a positive number.
    """
    if not client_states:
        raise ValueError("fedavg requires at least one client state")
    if len(client_states) != len(client_sizes):
        raise ValueError(
            f"got {len(client_states)} client states but {len(client_sizes)} client sizes"
        )
    total = float(sum(client_sizes))
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive number")

    weights = [sz / total for sz in client_sizes]
    new_state = {}
    for k, ref in client_states[0].items():
        acc_dtype = torch.float64 if ref.dtype == torch.float64 else torch.float32
        avg = sum(state[k].to(acc_dtype) * w for state, w in zip(client_states, weights))
        # Cast back so the averaged state keeps the model's dtypes; integer
        # buffers are rounded rather than truncated.
        new_state[k] = avg.to(ref.dtype) if ref.is_floating_point() else avg.round().to(ref.dtype)
    global_model.load_state_dict(new_state)

def _tracker_energy_kwh(tracker):
    """Return the energy (kWh) recorded by a stopped CodeCarbon tracker, or None if unavailable.

    Tries the public ``final_emissions_data`` first, then the private
    ``_last_emissions_data`` fallback. Attribute availability may vary
    across CodeCarbon versions, hence the defensive ``getattr`` chain.
    """
    for attr in ("final_emissions_data", "_last_emissions_data"):
        data = getattr(tracker, attr, None)
        energy = getattr(data, "energy_consumed", None)
        if energy is not None:
            return energy
    return None


def train_one_client(
    base_model: nn.Module,
    train_loader: DataLoader,
    lr: float,
    local_epochs: int,
) -> Tuple[Dict[str, torch.Tensor], float, float]:
    """Train a private copy of ``base_model`` on one client's data and measure its cost.

    ``base_model`` is not modified. It is deep-copied onto ``DEVICE``, and the
    copy is trained with SGD (momentum 0.9) on cross-entropy loss for
    ``local_epochs`` passes over ``train_loader``.

    Returns:
      - state_dict after local training
      - t_comp_sec: wall-clock seconds of the training loop only. CodeCarbon
        start/stop overhead is excluded, and pending CUDA work is
        synchronized before the clock is read.
      - E_comp_Wh: energy measured by CodeCarbon, converted from kWh to Wh
        (best effort; 0.0 when the tracker reported nothing)
    """
    model = copy.deepcopy(base_model).to(DEVICE)
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    output_dir = "./codecarbon_logs"
    os.makedirs(output_dir, exist_ok=True)

    tracker = EmissionsTracker(
        project_name="fl_client_train",
        output_dir=output_dir,
        log_level="error",
        measure_power_secs=1,
        save_to_file=False,
    )

    tracker.start()
    try:
        # Time only the training itself: t_comp is compared against the
        # round deadline, so tracker overhead must not inflate it.
        start_t = time.perf_counter()
        for _ in range(local_epochs):
            for x, y in train_loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                opt.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                opt.step()
        # CUDA kernels are queued asynchronously; wait for them to finish.
        if DEVICE.type == "cuda":
            torch.cuda.synchronize(DEVICE)
        t_comp_sec = time.perf_counter() - start_t
    finally:
        # Stop the tracker even if training fails, so its background
        # sampling does not outlive this call.
        tracker.stop()

    E_comp_kWh = _tracker_energy_kwh(tracker)
    E_comp_Wh = float(E_comp_kWh) * 1000.0 if E_comp_kWh is not None else 0.0
    return model.state_dict(), t_comp_sec, E_comp_Wh


NUM_CLIENTS = 10
BATCH_SIZE  = 64

# IID partition: shuffle all training indices once (NumPy RNG, seeded above),
# then cut them into NUM_CLIENTS contiguous shards of near-equal size.
indices = np.arange(len(train_ds))
np.random.shuffle(indices)
splits = np.array_split(indices, NUM_CLIENTS)

client_train_loaders = []
client_num_samples = []
for cid, shard in enumerate(splits):
    subset = Subset(train_ds, shard.tolist())
    client_num_samples.append(len(subset))
    client_train_loaders.append(
        DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    )

# Shared held-out evaluation loader; order does not matter for accuracy.
test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)


@dataclass
class ClientMeta:
    """Per-client device profile used to model upload cost and the energy budget."""

    bandwidth_mbps: float  # link bandwidth used for the model upload, Mbit/s
    tx_power_w: float      # transmit power while uploading, W
    battery_Wh: float      # energy budget compared against compute + upload energy, Wh


# One random profile per client. The three draws happen in field order for
# each client, which keeps the sequence of Python RNG draws unchanged.
client_meta: Dict[int, ClientMeta] = {
    cid: ClientMeta(
        bandwidth_mbps=random.uniform(2.0, 20.0),
        tx_power_w=random.uniform(1.0, 3.0),
        battery_Wh=random.uniform(10.0, 40.0),
    )
    for cid in range(NUM_CLIENTS)
}

def comm_time_sec(model_bytes: int, bandwidth_mbps: float) -> float:
    """Return the seconds needed to send ``model_bytes`` bytes over a ``bandwidth_mbps`` Mbit/s link.

    The link rate is clamped to at least 1 bit/s, so a zero or negative
    bandwidth gives a large finite time instead of a division error.
    """
    payload_bits = model_bytes * 8.0
    link_bps = max(bandwidth_mbps * 1e6, 1.0)
    return payload_bits / link_bps

def upload_energy_Wh(t_upload_sec: float, tx_power_w: float) -> float:
    """Return the energy (Wh) spent transmitting at ``tx_power_w`` watts for ``t_upload_sec`` seconds."""
    hours = t_upload_sec / 3600.0
    return tx_power_w * hours


def greedy_select_clients(
    I: List[int],
    t_comp: Dict[int, float],
    t_up: Dict[int, float],
    E_comp: Dict[int, float],
    E_up: Dict[int, float],
    E_batt: Dict[int, float],
    tau_sec: float,
    alpha: float
) -> List[int]:
    """Return the clients that can take part in this round, cheapest first.

    A client ``i`` is feasible when two conditions hold:

    - its battery covers the round: ``E_comp[i] + E_up[i] <= E_batt[i]``;
    - it meets the deadline on its own: ``t_comp[i] + t_up[i] <= tau_sec``.

    Feasible clients are returned sorted by ascending weighted energy score
    ``(alpha * E_comp + E_up) / (1 + alpha)``. The sort is stable, so ties
    keep their order from ``I``.

    In synchronous FL the round lasts as long as the slowest selected client,
    i.e. ``max(t_comp + t_up)`` over the selection. Every feasible client
    already meets ``tau_sec`` individually, so that maximum can never exceed
    the deadline and all feasible clients are selected. No incremental
    re-check of the round time is needed.
    """
    def score(i: int) -> float:
        return (alpha * E_comp[i] + E_up[i]) / (1.0 + alpha)

    feasible = [
        i for i in I
        if (E_comp[i] + E_up[i] <= E_batt[i]) and (t_comp[i] + t_up[i] <= tau_sec)
    ]
    return sorted(feasible, key=score)


# ---- Federated training configuration --------------------------------------
ROUNDS = 5
LR = 0.05
LOCAL_EPOCHS_MIN = 1   # each client's local epochs per round are drawn with
LOCAL_EPOCHS_MAX = 1   # random.randint(LOCAL_EPOCHS_MIN, LOCAL_EPOCHS_MAX)

TAU_SEC = 20.0   # per-client deadline (s) used by client selection
ALPHA   = 0.7    # weight on computation energy in the selection score

global_model = SmallCIFAR10CNN().to(DEVICE)

# Per-round metric series (one entry appended per round), used for plotting.
round_ids, round_acc, round_params, round_flops = [], [], [], []
round_avg_train_time, round_max_train_time, round_selected_k = [], [], []

# (round, #selected, test_acc, E_comp_Wh, E_upload_Wh) tuples for the summary table.
history = []

# ---- Federated training loop ----
# Each round: every client trains a local copy of the current global model,
# the measured time/energy feed client selection, and only the selected
# clients' states are averaged (FedAvg) into global_model.
for r in range(1, ROUNDS + 1):
    print(f"\n================ Round {r}/{ROUNDS} ================")

    # Per-client quantities for this round, keyed by client id.
    t_comp, t_up, E_comp, E_up, E_batt = {}, {}, {}, {}, {}
    model_bytes = get_model_num_bytes(global_model)
    # Draws from the global `random` RNG once per client, every round.
    local_epochs = {i: random.randint(LOCAL_EPOCHS_MIN, LOCAL_EPOCHS_MAX) for i in range(NUM_CLIENTS)}

    client_trained_states = {}

    # All clients train BEFORE selection, so t_comp/E_comp are measured rather
    # than predicted. States of clients that are not selected are discarded.
    for i in range(NUM_CLIENTS):
        meta = client_meta[i]
        # Upload time/energy come from the bandwidth/power model, not measurement.
        t_up[i] = comm_time_sec(model_bytes, meta.bandwidth_mbps)
        E_up[i] = upload_energy_Wh(t_up[i], meta.tx_power_w)
        # NOTE(review): the battery is re-read from client_meta every round and
        # nothing in this loop lowers battery_Wh, so spent energy does not
        # carry over between rounds — confirm this is intended.
        E_batt[i] = meta.battery_Wh

        state_i, t_i, E_i = train_one_client(
            base_model=global_model,
            train_loader=client_train_loaders[i],
            lr=LR,
            local_epochs=local_epochs[i],
        )
        client_trained_states[i] = state_i
        t_comp[i] = t_i
        E_comp[i] = E_i

    # Select clients
    selected = greedy_select_clients(
        I=list(range(NUM_CLIENTS)),
        t_comp=t_comp,
        t_up=t_up,
        E_comp=E_comp,
        E_up=E_up,
        E_batt=E_batt,
        tau_sec=TAU_SEC,
        alpha=ALPHA
    )

    # No feasible client: keep the global model unchanged, record zeros for
    # energy/time, and move on to the next round.
    if len(selected) == 0:
        print("No clients satisfy constraints. Relax TAU_SEC or increase battery_Wh.")
        acc = evaluate(global_model, test_loader)
        history.append((r, 0, acc, 0.0, 0.0))
        print(f"Test accuracy: {acc:.4f}")


        round_ids.append(r)
        round_selected_k.append(0)
        round_acc.append(acc)
        round_params.append(count_params(global_model))
        round_flops.append(estimate_model_flops_per_sample(global_model, (1, 3, 32, 32)))
        round_avg_train_time.append(0.0)
        round_max_train_time.append(0.0)
        continue

    # Aggregate: FedAvg weighted by each selected client's local sample count.
    selected_states = [client_trained_states[i] for i in selected]
    selected_sizes  = [client_num_samples[i] for i in selected]
    fedavg(global_model, selected_states, selected_sizes)

    # Evaluate
    acc = evaluate(global_model, test_loader)

    # Energy summary: only selected clients are summed. Non-selected clients
    # also trained this round, but their energy is not counted here.
    round_Ecomp_Wh = float(sum(E_comp[i] for i in selected))
    round_Eup_Wh   = float(sum(E_up[i]   for i in selected))

    # Training time summary (selected clients)
    sel_times = [t_comp[i] for i in selected]
    avg_t = float(np.mean(sel_times)) if len(sel_times) else 0.0
    max_t = float(np.max(sel_times))  if len(sel_times) else 0.0

    # ---- logging (per round) ----
    # The global model's architecture is not changed in this loop, so params
    # and FLOPs come out the same every round.
    params_r = count_params(global_model)
    flops_r  = estimate_model_flops_per_sample(global_model, (1, 3, 32, 32))

    round_ids.append(r)
    round_selected_k.append(len(selected))
    round_acc.append(acc)
    round_params.append(params_r)
    round_flops.append(flops_r)
    round_avg_train_time.append(avg_t)
    round_max_train_time.append(max_t)

    print("Selected clients:", selected)
    print(f"Test accuracy: {acc:.4f}")
    print(f"Round E_comp (Wh) [CodeCarbon]: {round_Ecomp_Wh:.4f}")
    print(f"Round E_upload (Wh) [model]:   {round_Eup_Wh:.6f}")
    print(f"Round total energy (Wh):       {round_Ecomp_Wh + round_Eup_Wh:.4f}")
    print(f"Avg selected client train time (s): {avg_t:.3f}")
    print(f"Max selected client train time (s): {max_t:.3f}")
    print(f"Model params: {params_r:,}")
    print(f"Model FLOPs/sample (forward): {flops_r:,}")

    history.append((r, len(selected), acc, round_Ecomp_Wh, round_Eup_Wh))

# ---- Per-round summary table ------------------------------------------------
print("\n================ Summary ================")
print("Round | #Selected | TestAcc | E_comp_Wh | E_upload_Wh | Total_Wh")
for r, k, acc, ec, eu in history:
    print(f"{r:5d} | {k:9d} | {acc:6.4f} | {ec:9.4f} | {eu:11.6f} | {ec + eu:8.4f}")


# ---- Publication-style matplotlib defaults -----------------------------------
rcParams.update({
    # fonts
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
    # axes and ticks
    "axes.linewidth": 0.8,
    "axes.labelsize": 11,
    "axes.titlesize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    # legend and output resolution
    "legend.fontsize": 10,
    "figure.dpi": 300,
})

# Convert the collected per-round series to arrays for plotting.
round_ids            = np.asarray(round_ids)
round_acc            = np.asarray(round_acc) * 100.0  # fraction -> percent
round_params         = np.asarray(round_params)
round_flops          = np.asarray(round_flops)
round_avg_train_time = np.asarray(round_avg_train_time)
round_selected_k     = np.asarray(round_selected_k)


def _plot_accuracy_vs(x_values, y_values, point_labels, xlabel, title):
    """Draw and show one accuracy-vs-metric line plot with a text label on every point.

    Args:
        x_values: per-round metric plotted on the x axis.
        y_values: per-round test accuracy (%) plotted on the y axis.
        point_labels: annotation text for each point, aligned with x/y.
        xlabel: x-axis label.
        title: figure title.
    """
    plt.figure(figsize=(5.2, 3.4))
    plt.plot(x_values, y_values, marker="o")
    for x_val, y_val, label in zip(x_values, y_values, point_labels):
        plt.annotate(label, (x_val, y_val), textcoords="offset points", xytext=(5, 4), fontsize=8)
    plt.xlabel(xlabel)
    plt.ylabel("Test Accuracy (%)")
    plt.title(title)
    plt.grid(True, linestyle="--", alpha=0.35)
    plt.tight_layout()
    plt.show()


# NOTE(review): params and FLOPs are recomputed from the same architecture
# every round, so in plots 4A/4B all points share one x value.
_plot_accuracy_vs(
    round_params, round_acc,
    [f"R{rid}" for rid in round_ids],
    "Model Parameters (count)",
    "Experiment 4A — Accuracy vs Params (per round)",
)

_plot_accuracy_vs(
    round_flops, round_acc,
    [f"R{rid}" for rid in round_ids],
    "Model FLOPs per Sample (forward pass)",
    "Experiment 4B — Accuracy vs FLOPs (per round)",
)

_plot_accuracy_vs(
    round_avg_train_time, round_acc,
    [f"R{rid} (k={k})" for rid, k in zip(round_ids, round_selected_k)],
    "Avg Selected Client Training Time per Round (s)",
    "Experiment 5C — Accuracy vs Client Training Time (per round)",
)