In [None]:

!pip -q install torch torchvision torchaudio codecarbon pandas numpy scikit-learn >/dev/null 2>&1 || true

import os, time, random, numpy as np, pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

# Seed Python's, NumPy's and PyTorch's random generators with one shared value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE)


# CIFAR-10 is loaded through the Keras dataset helper; only the raw arrays are
# taken from it.
from tensorflow import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Remove singleton label axes and store the labels as int64.
y_train = y_train.squeeze().astype(np.int64)
y_test  = y_test.squeeze().astype(np.int64)


# Convert pixels to float32, divide by 255, and move the last axis to position 1.
x_train = (x_train.astype(np.float32) / 255.0).transpose(0, 3, 1, 2)  # (N,3,32,32)
x_test  = (x_test.astype(np.float32) / 255.0).transpose(0, 3, 1, 2)

class CIFAR10NP(Dataset):
    """Dataset view over a pair of NumPy arrays (samples ``X``, labels ``y``).

    Indexing returns ``(tensor, label)`` where the tensor is produced by
    ``torch.from_numpy`` on ``X[i]`` and the label is a plain Python ``int``.
    The dataset length is ``len(y)``.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        sample = torch.from_numpy(self.X[i])
        label = int(self.y[i])
        return sample, label

# Module-level datasets wrapping the preprocessed train and test arrays.
full_train = CIFAR10NP(x_train, y_train)
full_test  = CIFAR10NP(x_test, y_test)


def dirichlet_split_noniid_with_size(labels, n_clients=10, alpha=0.3, size_skew=True, min_size=50):
    """Partition sample indices across clients with Dirichlet label skew.

    For every class, that class's sample indices are shuffled and cut into
    ``n_clients`` contiguous chunks whose relative sizes are drawn from a
    symmetric ``Dirichlet(alpha)``. With ``size_skew`` the proportions are
    additionally multiplied by log-normal weights (mean 0, sigma 0.5) and
    renormalised. The whole draw is repeated until every client holds at
    least ``min_size`` samples.

    Randomness comes from the global ``np.random`` state (shuffles and draws)
    and the global ``random`` state (final per-client shuffle), so seed both
    for reproducible partitions.

    Args:
        labels: 1-D sequence of non-negative integer class labels.
        n_clients: Number of partitions to produce.
        alpha: Dirichlet concentration; smaller values give more skew.
        size_skew: Whether to apply the extra log-normal size weighting.
        min_size: Minimum number of samples every client must receive.

    Returns:
        A list of ``n_clients`` lists of integer indices into ``labels``.
        Every index appears in exactly one list; each list is shuffled.

    Raises:
        ValueError: If ``labels`` is empty, ``n_clients < 1``, or
            ``n_clients * min_size`` exceeds the number of samples (the
            retry loop could never succeed).
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("labels must not be empty")
    if n_clients < 1:
        raise ValueError(f"n_clients must be >= 1, got {n_clients}")
    if n_clients * min_size > labels.size:
        raise ValueError(
            f"cannot give {n_clients} clients at least {min_size} samples each "
            f"from only {labels.size} samples"
        )

    n_classes = int(labels.max()) + 1
    # The per-class index sets never change between retries, so compute them once.
    class_indices = [np.where(labels == c)[0] for c in range(n_classes)]

    while True:
        per_client = [[] for _ in range(n_clients)]
        for base_idx in class_indices:
            # Shuffle a copy so every retry starts from the same sorted order.
            idx = base_idx.copy()
            np.random.shuffle(idx)
            props = np.random.dirichlet([alpha] * n_clients)
            if size_skew:
                w = np.random.lognormal(mean=0.0, sigma=0.5, size=n_clients)
                props = props * w
                props = props / props.sum()
            # Cumulative proportions -> cut points; the last one is the array end.
            splits = (np.cumsum(props) * len(idx)).astype(int)[:-1]
            for bucket, chunk in zip(per_client, np.split(idx, splits)):
                bucket.extend(chunk.tolist())

        if all(len(bucket) >= min_size for bucket in per_client):
            for bucket in per_client:
                random.shuffle(bucket)
            return per_client

# Partition the training labels across 10 clients (alpha=0.3, with size skew,
# at least 50 samples per client).
client_indices = dirichlet_split_noniid_with_size(
    labels=y_train, n_clients=10, alpha=0.3, size_skew=True, min_size=50
)


from codecarbon import EmissionsTracker

class CCTracker:
    """CodeCarbon tracker returning energy in Wh.

    Thin wrapper around ``EmissionsTracker``: ``start()`` creates and starts a
    new tracker (power sampled every second), ``stop_wh()`` stops it and
    converts the recorded energy from kWh to Wh.
    """

    def __init__(self, name="section"):
        self.name = name
        # Created by start(); stays None until then.
        self.tracker = None

    def start(self):
        """Create a fresh ``EmissionsTracker`` named after this section and start it."""
        self.tracker = EmissionsTracker(project_name=self.name, measure_power_secs=1)
        self.tracker.start()

    def stop_wh(self) -> float:
        """Stop the tracker and return the measured energy in watt-hours.

        The energy is read from the tracker's ``_total_energy.kWh`` when that
        attribute exists, otherwise from ``final_emissions_data.energy_consumed``;
        if neither is available the result is ``0.0``.

        Raises:
            RuntimeError: If ``start()`` has not been called on this object.
        """
        if self.tracker is None:
            raise RuntimeError("CCTracker.stop_wh() called before start()")

        self.tracker.stop()

        kwh = 0.0
        total_energy = getattr(self.tracker, "_total_energy", None)
        if total_energy is not None and hasattr(total_energy, "kWh"):
            kwh = float(total_energy.kWh)
        elif getattr(self.tracker, "final_emissions_data", None):
            kwh = float(self.tracker.final_emissions_data.energy_consumed or 0.0)
        return 1000.0 * kwh


# Homogeneous CNN search space

class HomoCNN(nn.Module):
    """
    CNN defined by (depth, width, kernel).
    depth ∈ {2,3,4}, width ∈ {16,32,64}, kernel ∈ {3,5}

    The network is ``depth`` stages of Conv2d (no bias, padding ``kernel // 2``)
    -> BatchNorm2d -> ReLU. The first stage has ``width`` output channels and
    every later stage has ``2 * width``. Only the first two stages are
    followed by a 2x2 max-pool. The head is global average pooling, flatten,
    and one linear layer to ``num_classes`` logits.

    Any ``depth >= 1`` is accepted; depths above 4 add further ``2 * width``
    stages. The chosen hyper-parameters are stored in ``self._arch``.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """
    def __init__(self, num_classes=10, depth=3, width=32, kernel=3):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self._arch = {"depth": depth, "width": width, "kernel": kernel}

        # One entry per stage, so the built network always has exactly
        # `depth` conv stages and agrees with what `_arch` records.
        stage_channels = [width] + [width * 2] * (depth - 1)

        layers = []
        in_ch = 3
        for stage, out_ch in enumerate(stage_channels):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel, padding=kernel // 2, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
            # Down-sample after the first two stages only.
            if stage < 2:
                layers.append(nn.MaxPool2d(2))
            in_ch = out_ch
        self.features = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_ch, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.features(x)
        return self.head(x)

def build_model(arch: Dict):
    """Create a ``HomoCNN`` for ``arch`` and place it on ``DEVICE``.

    ``arch`` must provide the keys "depth", "width" and "kernel". A copy of
    the whole mapping is attached to the returned model as ``_arch``.
    """
    depth, width, kernel = (arch[key] for key in ("depth", "width", "kernel"))
    net = HomoCNN(depth=depth, width=width, kernel=kernel)
    net = net.to(DEVICE)
    net._arch = arch.copy()
    return net

def count_params(model: nn.Module) -> int:
    """Return the total number of elements across ``model.parameters()``."""
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    return n_elements

def _conv2d_flops(conv: nn.Conv2d, out_h: int, out_w: int) -> int:
    """Return the forward FLOPs of ``conv`` for an ``out_h`` x ``out_w`` output map.

    Each output element costs ``(in_channels // groups) * k_h * k_w``
    multiplies; every multiply is counted together with one add.
    """
    kernel = conv.kernel_size
    if not isinstance(kernel, tuple):
        kernel = (kernel, kernel)
    k_h, k_w = kernel

    in_per_group = conv.in_channels // conv.groups
    # multiply + add ≈ 2 FLOPs per kernel tap
    flops_per_output = 2 * (in_per_group * k_h * k_w)
    n_outputs = conv.out_channels * out_h * out_w
    return int(n_outputs * flops_per_output)

def _linear_flops(fc: nn.Linear) -> int:
    """Return the forward FLOPs of ``fc``: two per (input, output) weight pair."""
    weight_count = fc.in_features * fc.out_features
    return int(weight_count * 2)

@torch.no_grad()
def estimate_homocnn_flops_per_sample(model: HomoCNN, input_hw=(32, 32)) -> int:
    """
    Estimate forward FLOPs for one sample through a HomoCNN.

    Only Conv2d layers in ``model.features`` and Linear layers in
    ``model.head`` are counted. Each convolution is costed at the current
    spatial size, and every MaxPool2d halves both spatial dimensions
    (integer division). BN/ReLU/Pool FLOPs are ignored.
    """
    height, width = input_hw

    conv_flops = 0
    for module in model.features:
        if isinstance(module, nn.Conv2d):
            conv_flops += _conv2d_flops(module, height, width)
            continue
        if isinstance(module, nn.MaxPool2d):
            height, width = height // 2, width // 2

    linear_flops = sum(
        _linear_flops(module) for module in model.head if isinstance(module, nn.Linear)
    )
    return int(conv_flops + linear_flops)


def evaluate(model: nn.Module, dataset: Dataset, batch_size=256):
    """Return the top-1 accuracy of ``model`` on ``dataset`` as a fraction in [0, 1].

    The dataset is read in order through a ``DataLoader`` with two workers and
    scored by ``evaluate_loader``, so the model is left in eval mode and an
    empty dataset yields ``0.0`` instead of a division by zero.

    Args:
        model: Network to score; it is switched to eval mode.
        dataset: Dataset yielding ``(input, label)`` pairs.
        batch_size: Evaluation batch size.
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        # Pinned host memory is only requested when batches go to the GPU.
        pin_memory=(DEVICE == "cuda"),
    )
    return evaluate_loader(model, loader)

def evaluate_loader(model: nn.Module, loader: DataLoader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and gradients are disabled. Batches
    are moved to ``DEVICE`` before the forward pass. A loader that yields no
    samples gives ``0.0``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predictions = model(inputs).argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


class Client:
    """One federated client: a subset of ``full_train`` plus its data loaders.

    Attributes:
        cid: Integer client id (used in the energy tracker's project name).
        dataset: ``Subset`` of ``full_train`` selected by ``indices``.
        loader: Shuffled training loader over ``dataset``.
        eval_loader: Ordered loader (batch size 256) over the same data, used
            to measure local accuracy after training.
    """

    def __init__(self, cid: int, indices: List[int], batch_size=32):
        self.cid = cid
        self.dataset = Subset(full_train, indices)
        # Pinned host memory is only requested when batches go to the GPU.
        pin = DEVICE == "cuda"
        self.loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin)
        self.eval_loader = DataLoader(self.dataset, batch_size=256, shuffle=False, num_workers=2, pin_memory=pin)

    def local_train(self, global_model: nn.Module, epochs=1, lr=1e-3, weight_decay=1e-4):
        """Train a private copy of ``global_model`` on this client's data.

        A new model with the architecture in ``global_model._arch`` is built,
        loaded with the global weights, and optimised with Adam and
        cross-entropy for ``epochs`` passes over the local loader. The global
        model itself is not modified.

        Args:
            global_model: Current global model; must carry an ``_arch`` dict.
            epochs: Number of local passes over the client's data.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.

        Returns:
            A 5-tuple ``(state_dict, n_samples, energy_wh, local_acc,
            train_time_sec)``: the trained weights on CPU, the client's
            sample count, the energy reported by ``CCTracker`` in Wh, the
            accuracy on the client's own data, and the wall-clock time of the
            training loop in seconds.
        """
        model = build_model(global_model._arch)
        model.load_state_dict(global_model.state_dict(), strict=True)
        model.train()
        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        cc = CCTracker(name=f"client_{self.cid}")
        cc.start()
        # Start the clock only after the tracker is running so tracker
        # construction and shutdown are not counted as training time.
        start_t = time.perf_counter()
        try:
            for _ in range(epochs):
                for xb, yb in self.loader:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    opt.zero_grad()
                    logits = model(xb)
                    loss = F.cross_entropy(logits, yb)
                    loss.backward()
                    opt.step()
            if DEVICE == "cuda":
                # CUDA kernels run asynchronously; wait for them before
                # reading the clock.
                torch.cuda.synchronize()
            train_time_sec = float(time.perf_counter() - start_t)
        finally:
            # Always stop the tracker, even when training raises.
            e_wh = cc.stop_wh()

        local_acc = evaluate_loader(model, self.eval_loader)

        cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
        return cpu_state, len(self.dataset), e_wh, local_acc, train_time_sec


def fedavg(global_model: nn.Module, client_payloads: List[Tuple[dict,int,float,float,float]]):
    """Load the sample-weighted average of the client weights into ``global_model``.

    Each payload is a tuple whose first item is a client ``state_dict`` and
    whose second item is that client's sample count; remaining items are
    ignored. Every entry of the first client's state dict is averaged with
    weights ``n_i / sum(n)``.

    Floating-point tensors are averaged directly. Integer buffers (for
    example BatchNorm's ``num_batches_tracked``) are averaged in float64,
    rounded, and cast back to their original dtype, so identical counters
    average to themselves.

    Args:
        global_model: Model that receives the averaged weights (in place).
        client_payloads: Non-empty list of client result tuples.

    Returns:
        ``global_model``, after ``load_state_dict(..., strict=True)``.

    Raises:
        ValueError: If ``client_payloads`` is empty or the total sample
            count is not positive.
    """
    if not client_payloads:
        raise ValueError("fedavg requires at least one client payload")

    total = sum(payload[1] for payload in client_payloads)
    if total <= 0:
        raise ValueError(f"total client sample count must be positive, got {total}")

    state_dicts = [payload[0] for payload in client_payloads]
    weights = [payload[1] / total for payload in client_payloads]

    new_sd = {}
    for key, reference in state_dicts[0].items():
        if torch.is_floating_point(reference):
            new_sd[key] = sum(sd[key] * w for sd, w in zip(state_dicts, weights))
        else:
            mean = sum(sd[key].to(torch.float64) * w for sd, w in zip(state_dicts, weights))
            new_sd[key] = torch.round(mean).to(reference.dtype)

    global_model.load_state_dict(new_sd, strict=True)
    return global_model


class RLController(nn.Module):
    """
    Policy over the architecture choices depth {2,3,4}, width {16,32,64}
    and kernel {3,5}. Each choice has its own learnable logit vector,
    initialised to zeros; an architecture is drawn by sampling each
    categorical distribution independently.
    """
    def __init__(self):
        super().__init__()
        self.depth_logits = nn.Parameter(torch.zeros(3))
        self.width_logits = nn.Parameter(torch.zeros(3))
        self.kernel_logits = nn.Parameter(torch.zeros(2))

    def sample_arch(self):
        """Sample one architecture.

        Returns ``(arch, logp)`` where ``arch`` is a dict with keys "depth",
        "width" and "kernel", and ``logp`` is the summed log-probability of
        the three sampled choices.
        """
        search_space = (
            ("depth", self.depth_logits, (2, 3, 4)),
            ("width", self.width_logits, (16, 32, 64)),
            ("kernel", self.kernel_logits, (3, 5)),
        )
        dists = [torch.distributions.Categorical(logits=logits) for _, logits, _ in search_space]
        # Draw in depth, width, kernel order.
        picks = [dist.sample() for dist in dists]

        arch = {
            name: options[int(pick)]
            for (name, _, options), pick in zip(search_space, picks)
        }
        logp_depth, logp_width, logp_kernel = (
            dist.log_prob(pick) for dist, pick in zip(dists, picks)
        )
        return arch, logp_depth + logp_width + logp_kernel

    def parameters_list(self):
        """Return the three logit parameters as a list (depth, width, kernel)."""
        return [self.depth_logits, self.width_logits, self.kernel_logits]


def _plot_accuracy_against(x_values, acc, point_labels, xlabel, title):
    """Draw one accuracy-vs-metric line plot with a text label on every point."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(5.2, 3.4))
    plt.plot(x_values, acc, marker="o")
    for x, y, text in zip(x_values, acc, point_labels):
        plt.annotate(text, (x, y),
                     textcoords="offset points", xytext=(5, 4), fontsize=8)
    plt.xlabel(xlabel)
    plt.ylabel("Test Accuracy (%)")
    plt.title(title)
    plt.grid(True, linestyle="--", alpha=0.35)
    plt.tight_layout()
    plt.show()


def run_fednas_5epochs(
    epochs=5,
    local_epochs=1,
    clients_per_round=10, # all clients participate each epoch
    lr=1e-3, weight_decay=1e-4, batch_size=32,
    beta_energy=0.02      # trade-off weight for energy in reward
):
    """Run the RL-driven federated architecture search and report the results.

    One client is created per entry of the module-level ``client_indices``.
    In every global epoch the controller samples an architecture, a freshly
    initialised global model of that architecture is trained for one FedAvg
    round by ``clients_per_round`` randomly selected clients, the aggregated
    model is scored on ``full_test``, and the controller gets a REINFORCE
    update with reward ``accuracy - beta_energy * total_energy_wh``.

    Despite the function name, the number of global epochs is ``epochs``.

    Side effects: prints progress, writes ``results_global.csv`` and
    ``results_clients.csv`` to the working directory, updates matplotlib's
    global ``rcParams`` and shows three figures (skipped when no epoch ran).

    Args:
        epochs: Number of global epochs (architecture samples).
        local_epochs: Local training passes per client per round.
        clients_per_round: Number of clients sampled for each round.
        lr: Client Adam learning rate.
        weight_decay: Client Adam weight decay.
        batch_size: Client training batch size.
        beta_energy: Weight of the energy term in the reward.

    Returns:
        ``(df_global, df_clients)``: one row per epoch, and one row per
        client per epoch.
    """
    print(f"DEVICE: {DEVICE}")

    # Build clients
    clients = [
        Client(cid, indices, batch_size=batch_size)
        for cid, indices in enumerate(client_indices)
    ]
    controller = RLController().to(DEVICE)
    opt_ctrl = torch.optim.Adam(controller.parameters_list(), lr=5e-2)

    global_rows = []   # one row per epoch
    client_rows = []   # one row per client per epoch

    def init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        if isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    for ep in range(1, epochs+1):
        print("\n==============================")
        print(f" Global Epoch {ep}/{epochs} (NAS sample + 1 FL round)")
        arch, logp = controller.sample_arch()
        print(f"  Sampled arch: depth={arch['depth']}, width={arch['width']}, kernel={arch['kernel']}")

        # fresh global model for this epoch's architecture
        global_model = build_model(arch)
        global_model.apply(init_weights)

        # one FL round
        selected = sorted(random.sample(range(len(clients)), clients_per_round))
        client_payloads = []
        total_local_energy_wh = 0.0
        train_times_sec = []

        print("  Client metrics:")
        for cid in selected:
            sd, n, eWh, lacc, tsec = clients[cid].local_train(
                global_model, epochs=local_epochs, lr=lr, weight_decay=weight_decay
            )
            client_payloads.append((sd, n, eWh, lacc, tsec))
            total_local_energy_wh += eWh
            train_times_sec.append(tsec)

            print(f"    • Client {cid:02d} | samples={n:5d} | local_acc={lacc*100:6.2f}% | "
                  f"energy={eWh:.4f} Wh | time={tsec:.3f} s")

            client_rows.append({
                "epoch": ep,
                "client_id": cid,
                "samples": n,
                "local_accuracy": float(lacc),
                "local_energy_wh": float(eWh),
                "train_time_sec": float(tsec),
                "depth": arch["depth"],
                "width": arch["width"],
                "kernel": arch["kernel"],
            })

        # Server-side aggregation, with its own energy measurement.
        agg_cc = CCTracker(name=f"aggregation_ep{ep}")
        agg_cc.start()
        global_model = fedavg(global_model, client_payloads)
        agg_energy_wh = agg_cc.stop_wh()

        global_acc = evaluate(global_model, full_test, batch_size=256)
        total_energy_wh = total_local_energy_wh + agg_energy_wh

        params_cnt = count_params(global_model)
        flops_ps   = estimate_homocnn_flops_per_sample(global_model, input_hw=(32,32))

        avg_tsec   = float(np.mean(train_times_sec)) if train_times_sec else 0.0
        max_tsec   = float(np.max(train_times_sec))  if train_times_sec else 0.0

        print(f"  Aggregation energy: {agg_energy_wh:.6f} Wh")
        print(f"  Model params: {params_cnt:,} | FLOPs/sample: {flops_ps:,}")
        print(f"  Avg client train time: {avg_tsec:.3f}s | Max client train time: {max_tsec:.3f}s")
        print(f"  → GLOBAL after aggregation: Acc={global_acc*100:.2f}% | TotalEnergy={total_energy_wh:.4f} Wh")

        # REINFORCE update (reward: accuracy - beta*total_energy)
        reward = float(global_acc) - float(beta_energy) * float(total_energy_wh)
        opt_ctrl.zero_grad()
        loss = -logp.to(DEVICE) * torch.tensor(reward, dtype=torch.float32, device=DEVICE)
        loss.backward()
        opt_ctrl.step()

        global_rows.append({
            "epoch": ep,
            "depth": arch["depth"],
            "width": arch["width"],
            "kernel": arch["kernel"],
            "params": int(params_cnt),
            "flops_per_sample": int(flops_ps),
            "avg_client_train_time_sec": float(avg_tsec),
            "max_client_train_time_sec": float(max_tsec),
            "global_accuracy": float(global_acc),
            "total_local_energy_wh": float(total_local_energy_wh),
            "aggregation_energy_wh": float(agg_energy_wh),
            "total_energy_wh": float(total_energy_wh),
            "reward": float(reward),
            "k_selected": int(len(selected)),
        })

    df_global = pd.DataFrame(global_rows)
    df_clients = pd.DataFrame(client_rows)
    df_global.to_csv("results_global.csv", index=False)
    df_clients.to_csv("results_clients.csv", index=False)

    print("\n==============================")
    print(" Per-epoch GLOBAL results:")
    for r in global_rows:
        print(f"  Epoch {r['epoch']}: Acc={r['global_accuracy']*100:.2f}%, "
              f"Params={r['params']}, FLOPs={r['flops_per_sample']}, "
              f"AvgT={r['avg_client_train_time_sec']:.3f}s, "
              f"LocalEnergy={r['total_local_energy_wh']:.4f} Wh, "
              f"AggEnergy={r['aggregation_energy_wh']:.6f} Wh, "
              f"Total={r['total_energy_wh']:.4f} Wh")

    print("\n📄 Saved CSVs: results_global.csv, results_clients.csv")

    # With no completed epoch the frame has no columns, so there is nothing to plot.
    if df_global.empty:
        return df_global, df_clients

    from matplotlib import rcParams

    rcParams.update({
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
        "axes.linewidth": 0.8,
        "axes.labelsize": 11,
        "axes.titlesize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "figure.dpi": 300,
    })

    E = df_global["epoch"].to_numpy()
    acc = (df_global["global_accuracy"].to_numpy() * 100.0)
    params = df_global["params"].to_numpy()
    flops = df_global["flops_per_sample"].to_numpy()
    tavg = df_global["avg_client_train_time_sec"].to_numpy()
    ksel = df_global["k_selected"].to_numpy()

    round_labels = [f"R{int(e)}" for e in E]
    round_k_labels = [f"R{int(e)} (k={int(k)})" for e, k in zip(E, ksel)]

    _plot_accuracy_against(
        params, acc, round_labels,
        "Model Parameters (count)",
        "Experiment 4A — Accuracy vs Params (per round)",
    )
    _plot_accuracy_against(
        flops, acc, round_labels,
        "Model FLOPs per Sample (forward pass)",
        "Experiment 4B — Accuracy vs FLOPs (per round)",
    )
    _plot_accuracy_against(
        tavg, acc, round_k_labels,
        "Avg Client Training Time per Round (s)",
        "Experiment 5C — Accuracy vs Client Training Time (per round)",
    )

    return df_global, df_clients


# Run the experiment; every argument repeats the function's default value.
df_global, df_clients = run_fednas_5epochs(
    epochs=5,
    local_epochs=1,
    clients_per_round=10,
    lr=0.001,
    weight_decay=0.0001,
    batch_size=32,
    beta_energy=0.02
)

# Bare trailing expression: previews the first rows of both result tables
# (presumably for notebook cell output — it has no effect in a plain script).
df_global.head(), df_clients.head()
