# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os, shutil, csv
from tqdm import tqdm, trange

# ------------------------------
# 1. Global setup
# ------------------------------
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SimpleNet(nn.Module):
    """Two-layer MLP head: a 64->64 ReLU layer followed by a 64->10 linear layer.

    The input is flattened per sample, so any shape whose trailing dimensions
    multiply to 64 is accepted.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys that clients save and that
        # FedAvg aggregates, so they must not be renamed.
        self.split_server = nn.Linear(64, 64)
        self.client = nn.Linear(64, 10)

    def forward(self, x):
        flat = x.view(x.shape[0], -1)
        hidden = nn.functional.relu(self.split_server(flat))
        logits = self.client(hidden)
        return logits

class Encoder(nn.Module):
    """Single linear layer mapping a flattened 28x28 input to ``latent_dim`` features."""

    def __init__(self, latent_dim=64):
        super().__init__()
        in_features = 28 * 28
        # The attribute name "encoder" is part of the state_dict keys loaded
        # from encoder.pth, so it must stay as-is.
        self.encoder = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        batch = x.shape[0]
        return self.encoder(x.view(batch, -1))

# Load the pretrained encoder once at module level; client code runs it
# under torch.no_grad() to produce the inputs for SimpleNet.
encoder = Encoder(latent_dim=64).to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine (and vice versa).
encoder.load_state_dict(torch.load("./encoder.pth", map_location=device))
encoder.eval()
# ------------------------------
# 2. Client local train + eval
# ------------------------------
def _make_loader(path, batch_size, shuffle):
    """Load a sequence of (image, label) pairs from ``path`` into a DataLoader."""
    samples = torch.load(path)
    # NOTE(review): the /255 scaling assumes images are stored as 0-255
    # intensities — confirm against the data export.
    xs = torch.stack([s[0] for s in samples]).float() / 255.0
    ys = torch.tensor([s[1] for s in samples], dtype=torch.long)
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(xs, ys),
        batch_size=batch_size,
        shuffle=shuffle,
    )


def _train_local(model, loader, epochs, lr):
    """Run ``epochs`` SGD passes over ``loader``; return the mean per-sample loss (0.0 if no steps ran)."""
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    running_loss, seen = 0.0, 0
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # Only the local model is trained; the shared encoder gets no gradients.
            with torch.no_grad():
                z = encoder(x)
            loss = criterion(model(z), y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * x.size(0)
            seen += x.size(0)
    # Guard against ZeroDivisionError when nothing was trained (e.g. epochs=0).
    return running_loss / seen if seen else 0.0


def _evaluate(model, loader):
    """Return the classification accuracy of ``model`` on ``loader`` (0.0 if empty)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(encoder(x)).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0


def client_update(client_id, global_model, epochs=1, lr=0.001):
    """Train a local copy of ``global_model`` on one client's data and evaluate it.

    Loads ``train.pt`` and ``test.pt`` from ``../simulate/volumes/client{client_id}``,
    trains a fresh ``SimpleNet`` initialised from ``global_model``'s weights on
    top of the module-level ``encoder``, evaluates it on the client's test
    split, and saves the trained weights to ``model_tmp/client{client_id}.pt``
    for later aggregation. ``global_model`` itself is not modified.

    Args:
        client_id: Client index; selects the data directory and the output file.
        global_model: Model whose state_dict initialises the local copy.
        epochs: Number of local passes over the training data.
        lr: SGD learning rate.

    Returns:
        ``(avg_loss, acc)``:
        - ``avg_loss``: mean per-sample training loss across all local epochs,
          or 0.0 if no training step ran.
        - ``acc``: test accuracy in [0, 1], or 0.0 if there are no test samples.
    """
    data_path = f"../simulate/volumes/client{client_id}"
    train_loader = _make_loader(os.path.join(data_path, "train.pt"), batch_size=64, shuffle=True)
    test_loader = _make_loader(os.path.join(data_path, "test.pt"), batch_size=1000, shuffle=False)

    model = SimpleNet().to(device)
    model.load_state_dict(global_model.state_dict())  # start from the global weights

    avg_loss = _train_local(model, train_loader, epochs, lr)
    acc = _evaluate(model, test_loader)

    # Persist the local weights so average_models can aggregate them.
    os.makedirs("model_tmp", exist_ok=True)
    torch.save(model.state_dict(), f"model_tmp/client{client_id}.pt")

    return avg_loss, acc

# ------------------------------
# 3. FedAvg
# ------------------------------
def average_models(global_model, num_clients):
    """FedAvg aggregation: average the saved client weights into ``global_model``.

    Reads ``model_tmp/client{1..num_clients}.pt``, takes an unweighted
    element-wise mean of every tensor (accumulated in float32, then cast back
    to the tensor's original dtype), loads the result into ``global_model``,
    and deletes the ``model_tmp`` directory.

    Args:
        global_model: Model to receive the averaged weights (updated in place).
        num_clients: Number of client checkpoints to read; must be >= 1.

    Returns:
        ``global_model``, after loading the averaged state.

    Raises:
        ValueError: If ``num_clients < 1``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Load checkpoints onto the model's own device so aggregation happens where
    # the model lives, without relying on a module-level device global.
    first_param = next(global_model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"

    state_dicts = [
        torch.load(f"model_tmp/client{cid}.pt", map_location=map_location)
        for cid in range(1, num_clients + 1)
    ]

    new_state = {}
    for key, ref in state_dicts[0].items():
        stacked = torch.stack([sd[key].to(torch.float32) for sd in state_dicts])
        new_state[key] = stacked.mean(dim=0).to(ref.dtype)

    global_model.load_state_dict(new_state)

    # Client checkpoints are per-round scratch files; clear them for the next round.
    shutil.rmtree("model_tmp")
    return global_model

# ------------------------------
# 4. Federated Training Loop
# ------------------------------
def federated_train(num_clients=5, num_rounds=10, local_epochs=1, lr=0.01,
                    results_path="../results/result.csv"):
    """Run FedAvg for ``num_rounds`` rounds over clients ``1..num_clients``.

    Each round:
    1. Every client trains a copy of the current global model via ``client_update``.
    2. The saved client weights are averaged into the global model via
       ``average_models``.

    Per-client rows (header: round, client_id, loss, accuracy) are written to
    ``results_path`` as CSV. Rows are flushed after every round, so partial
    results survive an interrupted run.

    Args:
        num_clients: Number of clients per round; must be >= 1.
        num_rounds: Number of federated rounds.
        local_epochs: Local epochs each client runs per round.
        lr: Client SGD learning rate.
        results_path: CSV output path; its parent directory is created if missing.

    Returns:
        The aggregated global ``SimpleNet`` after the final round.

    Raises:
        ValueError: If ``num_clients < 1``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    global_model = SimpleNet().to(device)

    # Prepare the output up front so a bad path fails before hours of training.
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    with open(results_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["round", "client_id", "loss", "accuracy"])

        for r in trange(1, num_rounds + 1, desc="Federated Rounds"):
            round_rows = []
            for cid in tqdm(range(1, num_clients + 1), desc=f"Round {r} Clients", leave=False):
                avg_loss, acc = client_update(cid, global_model, epochs=local_epochs, lr=lr)
                round_rows.append([r, cid, avg_loss, acc])

            # aggregate
            global_model = average_models(global_model, num_clients)

            writer.writerows(round_rows)
            f.flush()

            mean_loss = sum(row[2] for row in round_rows) / num_clients
            mean_acc = sum(row[3] for row in round_rows) / num_clients
            # tqdm.write prints above the active progress bars instead of corrupting them.
            tqdm.write(f"[Round {r}] Avg Loss: {mean_loss:.4f} | "
                       f"Avg Acc: {100 * mean_acc:.2f}%")

    return global_model

# ------------------------------
# Run training
# ------------------------------
# 20 clients x 100 rounds, one local epoch per round, client SGD lr = 0.001.
model = federated_train(
    num_clients=20,
    num_rounds=100,
    local_epochs=1,
    lr=0.001,
)
