In [None]:
import os, random, copy, glob
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tensorflow import keras

!pip install -qqq codecarbon
from codecarbon import EmissionsTracker


# Seed the three global RNGs this script draws from: Python's `random`
# (architecture sampling, client selection), NumPy (Dirichlet partitioning)
# and PyTorch (weight initialisation, DataLoader shuffling).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Every model and batch is moved to this device; falls back to CPU when CUDA is unavailable.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Config
# =========================
# Dataset description. Only "num_classes" is read in this file (size of the
# classifier head); "img_size" is not referenced by the code visible here.
DATA_CFG = {
    "img_size": (32, 32, 3),
    "num_classes": 10,
}

# Federated-learning hyper-parameters.
FL_CFG = {
    "num_clients": 10,              # overwritten by run_one_experiment(num_clients)
    "rounds": 5,                    # number of federated rounds per experiment
    "client_participation": 0.8,    # fraction of clients selected per round (at least one)
    "batch_size": 64,               # batch size of the per-client loaders
    "micro_epochs_min": 1,          # inclusive lower bound of the random local-epoch count
    "micro_epochs_max": 3,          # inclusive upper bound of the random local-epoch count
    "lr": 1e-3,                     # Adam learning rate for local training and distillation
    "candidates_per_client": 5,     # random architectures sampled per client per round
    "top_k": 2,                     # best proxy-scored candidates that are actually trained
    "dirichlet_alpha": 0.5,         # Dirichlet concentration used to split labels across clients
    "max_mutants": 2,               # mutated architectures added to the model pool each round
}

# Distillation settings.
# NOTE(review): not read anywhere in the code visible here; project_align_and_avg
# hard-codes T=2.0 and alpha=0.7 instead — confirm whether these should be wired in.
AGG_CFG = {"kd_T": 2.0, "kd_alpha": 0.7}

# CodeCarbon settings: where the tracker writes its CSV and how verbose it is.
ENERGY_CFG = {
    "save_dir": "runs_energy",
    "log_level": "warning",
}

# Client-selection settings used by energy_aware_select.
SEL_CFG = {
    "lambda_div": 0.2,        # weight of the architecture-diversity bonus
    "epsilon_explore": 0.2    # probability of picking a client without history
}

# Module-level side effect: make sure the energy output directory exists.
os.makedirs(ENERGY_CFG["save_dir"], exist_ok=True)

# =========================
# Data Loading
# =========================
def load_keras_data(dataset="cifar10"):
    """Load a Keras image dataset and return it as PyTorch tensors.

    Args:
        dataset: Dataset name (case-insensitive). Only "cifar10" is supported.

    Returns:
        Tuple ``(Xtr, ytr, Xte, yte)``: float32 image tensors scaled by 1/255
        and permuted to channels-first, and int64 label tensors.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset.lower() != "cifar10":
        raise ValueError("Unsupported dataset")
    (train_imgs, train_lbls), (test_imgs, test_lbls) = keras.datasets.cifar10.load_data()

    def _to_image_tensor(arr):
        # Scale to float32 and reorder N,H,W,C -> N,C,H,W for torch convolutions.
        scaled = arr.astype("float32") / 255.0
        return torch.tensor(scaled).permute(0, 3, 1, 2)

    def _to_label_tensor(arr):
        # Drop singleton dimensions so labels form a flat int64 vector.
        return torch.tensor(arr.squeeze(), dtype=torch.long)

    return (
        _to_image_tensor(train_imgs),
        _to_label_tensor(train_lbls),
        _to_image_tensor(test_imgs),
        _to_label_tensor(test_lbls),
    )

def make_loader(X, y, batch_size=64, shuffle=True):
    """Wrap paired feature/label tensors in a ``DataLoader``.

    Args:
        X: Feature tensor; first dimension indexes samples.
        y: Label tensor with the same first dimension as ``X``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data on each pass.

    Returns:
        A ``DataLoader`` over ``TensorDataset(X, y)``.
    """
    paired = TensorDataset(X, y)
    return DataLoader(paired, batch_size=batch_size, shuffle=shuffle)

def dirichlet_partition(y: torch.Tensor, num_clients:int, alpha:float) -> List[np.ndarray]:
    """Split sample indices across clients with a per-class Dirichlet prior.

    For every class, the class's sample indices are shuffled and cut into
    ``num_clients`` contiguous pieces whose sizes follow one draw from
    ``Dirichlet([alpha] * num_clients)``. Each client's final index list is
    shuffled once more.

    Args:
        y: Label tensor (converted with ``.numpy()``).
        num_clients: Number of partitions to produce.
        alpha: Dirichlet concentration parameter.

    Returns:
        One integer index array per client. Some arrays may be empty.
    """
    labels = y.numpy()
    buckets = [[] for _ in range(num_clients)]

    for cls in np.unique(labels):
        members = np.where(labels == cls)[0]
        np.random.shuffle(members)
        shares = np.random.dirichlet([alpha] * num_clients)
        # Cumulative shares give the cut points; the last one (== len) is dropped
        # so np.split yields exactly num_clients pieces.
        boundaries = (np.cumsum(shares) * len(members)).astype(int)[:-1]
        for bucket, piece in zip(buckets, np.split(members, boundaries)):
            bucket.extend(piece.tolist())

    for bucket in buckets:
        random.shuffle(bucket)

    return [np.array(bucket, dtype=int) for bucket in buckets]

# =========================
# NAS Search Space & Model
# =========================
# Activation name -> module class; instantiated in make_block.
ACTS = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
# Candidate layer operations. Keys as consumed by make_block / arch_signature:
#   "type": label used in architecture signatures,
#   "k":    kernel size (or pooling window when "pool" is set),
#   "sep":  use the depthwise-separable SepConv instead of a plain Conv2d,
#   "pool": "max" selects the 1x1-conv + MaxPool2d block, None a conv block.
OPSET = [
    {"type": "conv3", "k": 3, "sep": False, "pool": None},
    {"type": "conv5", "k": 5, "sep": False, "pool": None},
    {"type": "sep3",  "k": 3, "sep": True,  "pool": None},
    {"type": "max2",  "k": 2, "sep": False, "pool": "max"},
]
# Per-layer output channel choices.
WIDTHS = [8,16,32,64]
# Allowed numbers of layers; min/max also bound the depth mutation in mutate_arch.
DEPTHS = [2,3,4,5]
# Activation names sampled by random_arch / mutate_arch (keys of ACTS).
ACT_CHOICES = ["relu","gelu","silu"]

def random_arch():
    """Sample a random architecture description from the search space.

    Returns:
        Dict with keys "depth" (int), "channels" (one width per layer),
        "ops" (one OPSET entry per layer, shared by reference) and
        "act" (activation name).
    """
    n_layers = random.choice(DEPTHS)
    activation = random.choice(ACT_CHOICES)
    layer_widths = [random.choice(WIDTHS) for _ in range(n_layers)]
    layer_ops = [random.choice(OPSET) for _ in range(n_layers)]
    return {
        "depth": n_layers,
        "channels": layer_widths,
        "ops": layer_ops,
        "act": activation,
    }

class SepConv(nn.Module):
    """Depthwise-separable convolution followed by batch normalisation.

    A per-channel ``k x k`` convolution (``groups=c_in``) is followed by a
    ``1 x 1`` pointwise convolution that maps ``c_in`` to ``c_out`` channels.
    Neither convolution has a bias.
    """

    def __init__(self, c_in, c_out, k):
        super().__init__()
        # padding = k // 2 keeps the spatial size unchanged for odd kernel sizes.
        self.depth = nn.Conv2d(
            c_in, c_in, k, padding=k // 2, groups=c_in, bias=False
        )
        self.point = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, then batch norm."""
        return self.bn(self.point(self.depth(x)))

def make_block(c_in, c_out, op, act_name):
    """Build one network stage from an OPSET entry.

    Args:
        c_in: Input channel count.
        c_out: Output channel count.
        op: Operation dict with keys "pool", "sep" and "k".
        act_name: Key into ``ACTS`` selecting the activation.

    Returns:
        ``nn.Sequential`` of either
        (1x1 conv, BN, activation, max-pool) when ``op["pool"] == "max"``, or
        (conv, BN, activation) otherwise.
    """
    if op["pool"] == "max":
        # Pooling stage: 1x1 conv adapts the channel count before down-sampling.
        stages = [
            nn.Conv2d(c_in, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
            ACTS[act_name](),
            nn.MaxPool2d(op["k"]),
        ]
        return nn.Sequential(*stages)

    kernel = op["k"]
    # NOTE: SepConv already ends in its own BatchNorm2d, so separable stages
    # pass through two batch-norm layers in a row.
    conv = (
        SepConv(c_in, c_out, kernel)
        if op["sep"]
        else nn.Conv2d(c_in, c_out, kernel, padding=kernel // 2, bias=False)
    )
    return nn.Sequential(conv, nn.BatchNorm2d(c_out), ACTS[act_name]())

class SubNet(nn.Module):
    """Convolutional classifier assembled from an architecture description.

    The network is ``arch["depth"]`` stages built by ``make_block`` (starting
    from 3 input channels), global average pooling, and a linear head.
    """

    def __init__(self, arch, num_classes):
        super().__init__()
        # Keep a private copy so later mutation of the caller's dict has no effect.
        self.arch = copy.deepcopy(arch)

        stages = []
        in_ch = 3
        for idx in range(arch["depth"]):
            out_ch = arch["channels"][idx]
            stages.append(make_block(in_ch, out_ch, arch["ops"][idx], arch["act"]))
            in_ch = out_ch

        self.features = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(in_ch, num_classes)

    def forward(self, x, return_feats=False):
        """Return logits, or ``(logits, pooled_features)`` if ``return_feats``."""
        feature_map = self.features(x)
        pooled = self.pool(feature_map).view(feature_map.size(0), -1)
        logits = self.head(pooled)
        if return_feats:
            return (logits, pooled)
        return logits

# =========================
# Training / KD / Aggregation / Mutation
# =========================
def mgm_proxy(model: nn.Module, loader: DataLoader, num_batches:int=1) -> float:
    """Score an untrained model with a cheap gradient/feature proxy.

    For each batch the score is ``mean(F @ F.T) / (1 + sum_p ||grad_p||_2)``,
    where ``F`` are the pooled features returned by the model and the
    gradients come from one cross-entropy backward pass. Scores are averaged
    over at most ``num_batches`` batches (at least one batch is always used
    when the loader is non-empty).

    The model is put in train mode and its gradients are cleared before and
    after every batch, so no gradient state leaks in or out of this function.

    Args:
        model: Module whose ``forward(x, return_feats=True)`` returns
            ``(logits, feats)``.
        loader: Iterable of ``(inputs, labels)`` batches.
        num_batches: Maximum number of batches to evaluate.

    Returns:
        Mean proxy score, or ``0.0`` when the loader yields no batch.
    """
    model.train()
    crit = nn.CrossEntropyLoss()
    vals = []
    for seen, (xb, yb) in enumerate(loader, start=1):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        # Clear stale gradients so the norm below reflects this batch only.
        model.zero_grad(set_to_none=True)
        logits, feats = model(xb, return_feats=True)
        # Only parameter gradients are needed: no retain_graph and no input
        # gradient, so the autograd graph is released right after backward.
        crit(logits, yb).backward()
        with torch.no_grad():
            score = (feats @ feats.t()).mean().item()
            gnorm = sum(
                (p.grad.norm(2).item() for p in model.parameters() if p.grad is not None),
                0.0,
            )
        vals.append(score / (1.0 + gnorm))
        model.zero_grad(set_to_none=True)
        if seen >= num_batches:
            break
    return float(np.mean(vals)) if vals else 0.0

def train_epochs(model, loader, epochs=1, lr=1e-3):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module mapping an input batch to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of full passes over ``loader``.
        lr: Adam learning rate.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _epoch in range(epochs):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

def evaluate(model, loader):
    """Compute accuracy and mean cross-entropy loss of ``model`` on ``loader``.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Module mapping an input batch to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(accuracy, mean_loss)`` over all samples seen. When the loader
        yields no samples, ``(0.0, 0.0)`` is returned instead of dividing by
        zero.
    """
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    # Sum (not mean) per batch so the final average is exact across uneven batches.
    crit = nn.CrossEntropyLoss(reduction="sum")
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            loss_sum += crit(logits, yb).item()
            correct += (logits.argmax(1) == yb).sum().item()
            total += yb.size(0)
    if total == 0:
        # An empty loader (e.g. a client that received no data) has no defined metrics.
        return 0.0, 0.0
    return correct / total, loss_sum / total

def kd_train(student, teacher_logits_fn, loader, epochs=1, lr=1e-3, T=2.0, alpha=0.7):
    """Train ``student`` in place with a distillation + cross-entropy loss.

    The loss per batch is
    ``alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
    + (1 - alpha) * CE(student, labels)``.

    Args:
        student: Module mapping an input batch to class logits.
        teacher_logits_fn: Callable returning teacher logits for an input batch;
            it is invoked under ``torch.no_grad()``.
        loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.
        T: Softmax temperature.
        alpha: Weight of the distillation term.
    """
    student.train()
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    hard_loss_fn = nn.CrossEntropyLoss()
    for _epoch in range(epochs):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            with torch.no_grad():
                teacher_logits = teacher_logits_fn(inputs)
            student_logits = student(inputs)

            soft_targets = F.softmax(teacher_logits / T, dim=1)
            student_log_probs = F.log_softmax(student_logits / T, dim=1)
            # T^2 rescales the soft-target gradients to the hard-label scale.
            kd_term = F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * (T * T)
            total_loss = alpha * kd_term + (1 - alpha) * hard_loss_fn(student_logits, targets)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

def kd_from_hetero(global_model, teachers: List[nn.Module], loader, epochs=1, T=2.0, alpha=0.7, lr=1e-3):
    """Distil an ensemble of (possibly heterogeneous) teachers into ``global_model``.

    The teacher signal is the element-wise mean of all teachers' logits.
    Teachers are switched to eval mode whenever they are queried.

    Args:
        global_model: Student module trained in place.
        teachers: Teacher modules; all must produce logits of the same shape.
        loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``loader``.
        T: Distillation temperature.
        alpha: Weight of the distillation term.
        lr: Adam learning rate.
    """
    def ensemble_logits(batch):
        # nn.Module.eval() returns the module itself, so it can be chained.
        with torch.no_grad():
            per_teacher = [teacher.eval()(batch) for teacher in teachers]
        return torch.stack(per_teacher, dim=0).mean(0)

    kd_train(global_model, ensemble_logits, loader, epochs=epochs, lr=lr, T=T, alpha=alpha)

def aligned_average(global_model: nn.Module, client_models: List[nn.Module]):
    """Average matching state-dict entries of ``client_models`` into ``global_model``.

    An entry of a client contributes only when its key exists in the global
    state dict and its tensor has exactly the same shape. Global entries that
    no client matches are left unchanged.

    Args:
        global_model: Module updated in place.
        client_models: Modules whose compatible parameters/buffers are averaged.
    """
    target_state = global_model.state_dict()
    totals = {name: torch.zeros_like(tensor) for name, tensor in target_state.items()}
    hits = dict.fromkeys(target_state, 0)

    for client in client_models:
        for name, tensor in client.state_dict().items():
            running = totals.get(name)
            if running is None or tensor.shape != running.shape:
                continue  # key unknown to the global model, or incompatible shape
            running += tensor  # in-place: mutates the tensor stored in `totals`
            hits[name] += 1

    for name, n_matches in hits.items():
        if n_matches > 0:
            target_state[name] = totals[name] / n_matches

    global_model.load_state_dict(target_state, strict=False)

def project_align_and_avg(global_model, client_models: List[nn.Module], loader=None, epochs=1, lr=1e-3):
    """Distil ``client_models`` into ``global_model`` when a loader is given.

    Despite its name, this function only delegates to ``kd_from_hetero``; it
    is a no-op when ``loader`` is ``None``. Temperature (2.0) and alpha (0.7)
    are fixed here.

    Args:
        global_model: Student module trained in place.
        client_models: Teacher modules.
        loader: Iterable of ``(inputs, labels)`` batches, or ``None`` to skip.
        epochs: Number of distillation passes over ``loader``.
        lr: Adam learning rate.
    """
    if loader is None:
        return
    kd_from_hetero(global_model, client_models, loader, epochs=epochs, T=2.0, alpha=0.7, lr=lr)

def mutate_arch(arch):
    """Return a mutated deep copy of ``arch``; the input is not modified.

    Exactly one mutation kind is drawn uniformly from: replace one layer's
    op, replace one layer's width, replace the activation, or change depth.
    A depth mutation grows by one layer with probability 0.5 when below the
    maximum depth, otherwise it removes the last layer if above the minimum
    depth; it can therefore leave the architecture unchanged.
    """
    child = copy.deepcopy(arch)
    kind = random.choice(["op", "width", "act", "depth"])

    if kind == "act":
        child["act"] = random.choice(ACT_CHOICES)
    elif kind in ("op", "width"):
        position = random.randrange(child["depth"])
        if kind == "op":
            child["ops"][position] = random.choice(OPSET)
        else:
            child["channels"][position] = random.choice(WIDTHS)
    else:
        # The coin is only flipped when growing is possible (short-circuit `and`).
        if child["depth"] < max(DEPTHS) and random.random() < 0.5:
            child["depth"] += 1
            child["channels"].append(random.choice(WIDTHS))
            child["ops"].append(random.choice(OPSET))
        elif child["depth"] > min(DEPTHS):
            child["depth"] -= 1
            child["channels"].pop()
            child["ops"].pop()
    return child

# =========================
# CodeCarbon
# =========================
def start_tracker(name):
    """Create and start a CodeCarbon ``EmissionsTracker``.

    Args:
        name: Project name recorded by the tracker.

    Returns:
        The started tracker; pass it to ``stop_tracker`` when done.
    """
    options = {
        "project_name": name,
        "output_dir": ENERGY_CFG["save_dir"],
        "save_to_file": True,
        "measure_power_secs": 1,
        "log_level": ENERGY_CFG["log_level"],
        "emissions_endpoint": None,
    }
    tracker = EmissionsTracker(**options)
    tracker.start()
    return tracker

def stop_tracker(tracker):
    """Stop an emissions tracker on a best-effort basis.

    Failures while stopping never propagate, so energy tracking cannot abort
    an experiment, but they are logged as warnings instead of being silently
    discarded.

    Args:
        tracker: Object exposing a ``stop()`` method.
    """
    import logging  # local import keeps this block self-contained

    try:
        tracker.stop()
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "Failed to stop emissions tracker: %s", exc
        )

def last_run_energy_Wh():
    """Read the latest energy figure from the most recently modified CSV.

    Looks at ``*.csv`` files in ``ENERGY_CFG["save_dir"]``, takes the one with
    the newest modification time and returns the last ``energy_consumed``
    value multiplied by 1000 (the value is treated as kWh and converted to Wh).

    Returns:
        Energy in Wh, or ``0.0`` when there is no CSV, the column is missing,
        or the file cannot be read/parsed.
    """
    pattern = os.path.join(ENERGY_CFG["save_dir"], "*.csv")
    csv_paths = glob.glob(pattern)
    csv_paths.sort(key=os.path.getmtime)
    if not csv_paths:
        return 0.0

    newest = csv_paths[-1]
    try:
        frame = pd.read_csv(newest)
        if "energy_consumed" not in frame.columns:
            return 0.0
        return float(frame["energy_consumed"].iloc[-1]) * 1000.0  # Wh
    except Exception:
        # Best-effort: energy accounting must not break the experiment.
        return 0.0

# =========================
# Diversity & Client Selection
# =========================
def arch_signature(arch):
    """Summarise an architecture as a hashable tuple.

    Returns:
        ``(depth, mean_width, act, op_types)`` where ``mean_width`` is the
        mean of ``arch["channels"]`` truncated to int and ``op_types`` is the
        tuple of each op's "type" label.
    """
    op_types = tuple(op["type"] for op in arch["ops"])
    mean_width = int(np.mean(arch["channels"]))
    return (arch["depth"], mean_width, arch["act"], op_types)

def arch_distance(sig_a, sig_b):
    """Heuristic distance between two architecture signatures.

    The distance is the sum of: depth difference / 3, mean-width difference
    / 64, 0.5 if the activations differ, and the fraction of op positions
    (relative to the longer op sequence) that do not match.

    Args:
        sig_a: Signature tuple as produced by ``arch_signature``, or ``None``.
        sig_b: Signature tuple as produced by ``arch_signature``, or ``None``.

    Returns:
        Non-negative float; ``0.0`` when either signature is ``None``.
    """
    if sig_a is None or sig_b is None:
        return 0.0

    depth_term = abs(sig_a[0] - sig_b[0]) / 3.0
    width_term = abs(sig_a[1] - sig_b[1]) / 64.0
    act_term = 0.0 if sig_a[2] == sig_b[2] else 0.5
    total = depth_term + width_term + act_term

    ops_a, ops_b = sig_a[3], sig_b[3]
    longest = max(len(ops_a), len(ops_b))
    if longest > 0:
        # Positions beyond the shorter sequence count as mismatches.
        matches = sum(1 for left, right in zip(ops_a, ops_b) if left == right)
        total += (1 - matches / longest)
    return total

def energy_aware_select(client_hist: Dict[int, dict], m: int, num_clients: int) -> List[int]:
    """Greedily pick ``m`` clients by accuracy-per-energy plus diversity.

    Each pick either explores (with probability ``SEL_CFG["epsilon_explore"]``,
    if clients without history remain, a random such client is taken) or
    exploits (the client with history maximising
    ``local_acc / energy_Wh + lambda_div * mean_distance_to_selected``).
    Clients without history fill remaining slots once all scored clients are
    used up.

    Args:
        client_hist: Per-client records with optional keys "local_acc",
            "energy_Wh" and "arch".
        m: Number of clients to select.
        num_clients: Total number of clients; ids are ``0 .. num_clients-1``.

    Returns:
        Sorted list of selected client ids (fewer than ``m`` only when there
        are fewer than ``m`` clients).
    """
    every_id = list(range(num_clients))
    with_history = [cid for cid in every_id if cid in client_hist]
    unexplored = [cid for cid in every_id if cid not in client_hist]
    chosen, chosen_sigs = [], []

    def signature_of(cid):
        # Clients without a recorded architecture get no signature.
        recorded = client_hist[cid].get("arch")
        return arch_signature(recorded) if recorded else None

    def take_unexplored():
        # Unexplored clients carry no signature, so they add no diversity signal.
        pick = random.choice(unexplored)
        chosen.append(pick)
        chosen_sigs.append(None)
        unexplored.remove(pick)

    while len(chosen) < m:
        if unexplored and random.random() < SEL_CFG["epsilon_explore"]:
            take_unexplored()
            continue

        top_id, top_score = None, -1e9
        for cid in with_history:
            if cid in chosen:
                continue
            record = client_hist[cid]
            accuracy = max(0.0, record.get("local_acc", 0.0))
            # Floor the energy so a missing/zero reading cannot divide by zero.
            energy = max(1e-6, record.get("energy_Wh", 0.0))
            sig = signature_of(cid)
            if chosen_sigs:
                diversity = np.mean([arch_distance(sig, other) for other in chosen_sigs])
            else:
                diversity = 0.0
            score = accuracy / energy + SEL_CFG["lambda_div"] * diversity
            if score > top_score:  # strict: the first of equal scores wins
                top_score, top_id = score, cid

        if top_id is not None:
            chosen.append(top_id)
            chosen_sigs.append(signature_of(top_id))
            with_history.remove(top_id)
        elif unexplored:
            take_unexplored()
        else:
            break

    leftovers = [cid for cid in every_id if cid not in chosen]
    if len(chosen) < m and leftovers:
        chosen += random.sample(leftovers, m - len(chosen))
    return sorted(chosen)

# =========================
# Improved Experiment Loop
# =========================
def run_one_experiment(num_clients: int):
    """Run one federated NAS experiment on CIFAR-10.

    Per round: select clients, let each client score random candidate
    architectures with ``mgm_proxy``, train the top ones and keep the best,
    then fold the client models (plus randomly mutated, untrained models)
    into the global model via ``aligned_average`` and distillation.

    Side effects: overwrites ``FL_CFG["num_clients"]``, writes CodeCarbon
    output under ``ENERGY_CFG["save_dir"]`` and prints progress.

    Args:
        num_clients: Number of simulated clients.

    Returns:
        Tuple ``(client_rows, global_rows, client_hist)``: per-client/round
        records, per-round global metrics, and the latest record per client.
    """
    FL_CFG["num_clients"] = num_clients

    Xtr, ytr, Xte, yte = load_keras_data("cifar10")
    test_loader = make_loader(Xte, yte, batch_size=128, shuffle=False)

    idxs = dirichlet_partition(ytr, num_clients, FL_CFG["dirichlet_alpha"])
    client_loaders = [make_loader(Xtr[i], ytr[i], batch_size=FL_CFG["batch_size"], shuffle=True) for i in idxs]

    global_arch = random_arch()
    global_model = SubNet(global_arch, DATA_CFG["num_classes"]).to(DEVICE)

    client_hist = {}
    client_rows, global_rows = [], []

    print(f"\n=== Running experiment with {num_clients} clients ===")

    for rnd in range(1, FL_CFG["rounds"] + 1):
        print(f"\n--- Round {rnd} ---")
        m = max(1, int(FL_CFG["client_participation"] * num_clients))
        if rnd == 1:
            # No history exists yet, so the first round samples uniformly.
            selected = sorted(random.sample(range(num_clients), m))
        else:
            selected = energy_aware_select(client_hist, m, num_clients)

        client_models = []
        for cid in selected:
            dl = client_loaders[cid]
            trk = start_tracker(f"client_{cid}_round_{rnd}")
            try:
                best_acc, best_model, best_arch = _search_and_train_client(dl)
            finally:
                # Always stop the tracker, even if search/training raises,
                # so no tracker is left running.
                stop_tracker(trk)
            energy_Wh = last_run_energy_Wh()

            client_hist[cid] = {
                "local_acc": float(best_acc),
                "energy_Wh": float(energy_Wh),
                "arch": best_arch
            }
            client_models.append(best_model)
            print(f"Client {cid}: acc={best_acc:.3f}, energy={energy_Wh:.3f}Wh")

        # Aggregation + KD
        round_arches = [client_hist[cid]["arch"] for cid in selected]
        mutants = [SubNet(mutate_arch(random.choice(round_arches)), DATA_CFG["num_classes"]).to(DEVICE)
                   for _ in range(FL_CFG["max_mutants"])]
        aligned_average(global_model, client_models + mutants)
        # NOTE(review): distillation uses test_loader (test images AND labels),
        # and the global model is then evaluated on that same loader, so the
        # reported test accuracy is optimistic. Consider a separate public /
        # held-out distillation set.
        project_align_and_avg(global_model, client_models + mutants, loader=test_loader, epochs=1, lr=FL_CFG["lr"])

        g_acc, g_loss = evaluate(global_model, test_loader)
        print(f"Global model test acc: {g_acc:.3f}")

        global_rows.append({"round": rnd, "global_acc": g_acc, "global_loss": g_loss})
        for cid in selected:
            client_rows.append({"round": rnd, "client": cid,
                                "local_acc": client_hist[cid]["local_acc"],
                                "energy_Wh": client_hist[cid]["energy_Wh"]})

    print("\n=== Experiment done ===")
    return client_rows, global_rows, client_hist


def _search_and_train_client(dl):
    """Run one client's architecture search; return (best_acc, best_model, best_arch)."""
    # KNAS candidates: rank random architectures with the training-free proxy.
    candidates = [random_arch() for _ in range(FL_CFG["candidates_per_client"])]
    scored = [(mgm_proxy(SubNet(a, DATA_CFG["num_classes"]).to(DEVICE), dl, num_batches=2), a) for a in candidates]
    scored.sort(key=lambda x: x[0], reverse=True)
    top_arches = [a for _, a in scored[:FL_CFG["top_k"]]]

    best_acc, best_model, best_arch = -1.0, None, None
    for a in top_arches:
        model = SubNet(a, DATA_CFG["num_classes"]).to(DEVICE)
        micro_epochs = random.randint(FL_CFG["micro_epochs_min"], FL_CFG["micro_epochs_max"])
        train_epochs(model, dl, epochs=micro_epochs, lr=FL_CFG["lr"])
        # Accuracy is measured on the client's own training loader.
        acc, _ = evaluate(model, dl)
        if acc > best_acc:
            best_acc, best_model, best_arch = acc, model, a
    return best_acc, best_model, best_arch

def main():
    """Run the experiment once for each configured client count (3, then 5)."""
    for num_clients in (3, 5):
        run_one_experiment(num_clients)


if __name__ == "__main__":
    main()
