# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm, trange
import os
import csv

# ------------------------------
# 1. Setup
# ------------------------------
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Image preprocessing pipeline consisting of a single ToTensor step.
transform = transforms.Compose([transforms.ToTensor()])

class SimpleNet(nn.Module):
    """Two-layer classifier head: Linear -> ReLU -> Linear.

    Args:
        latent_dim: Number of input features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.

    The defaults (64, 64, 10) reproduce the original fixed-size network.
    """

    def __init__(self, latent_dim=64, hidden_dim=64, num_classes=10):
        super().__init__()
        # Attribute names are kept unchanged because they define the
        # state_dict keys ("split_server.*", "client.*").
        self.split_server = nn.Linear(latent_dim, hidden_dim)
        self.client = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        # flatten (unlike view) also accepts non-contiguous inputs and
        # empty batches; dim 0 is treated as the batch dimension.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.split_server(x))
        return self.client(x)

class Encoder(nn.Module):
    """Single linear layer that maps a flattened input to a latent vector.

    Args:
        latent_dim: Size of the produced latent vector.
        input_dim: Number of features per sample after flattening; the
            default ``28*28`` matches the original fixed-size layer.
    """

    def __init__(self, latent_dim=64, input_dim=28*28):
        super().__init__()
        # The attribute name defines the state_dict keys ("encoder.*"),
        # so it must stay as-is for saved checkpoints to load.
        self.encoder = nn.Linear(input_dim, latent_dim)

    def forward(self, x):
        """Return latents of shape (batch, latent_dim) for input ``x``."""
        # flatten (unlike view) also accepts non-contiguous inputs and
        # empty batches; dim 0 is treated as the batch dimension.
        x = torch.flatten(x, start_dim=1)
        return self.encoder(x)

# Build the encoder on the target device and restore its saved weights.
encoder = Encoder(latent_dim=64).to(device)
# map_location remaps the checkpoint tensors onto `device`, so a checkpoint
# saved on a GPU still loads on a CPU-only machine.
encoder.load_state_dict(torch.load("./encoder.pth", map_location=device))
# Inference mode: the encoder is only used as a fixed feature extractor below.
encoder.eval()
# ------------------------------
# 2. Train + Eval function
# ------------------------------
def _load_split(data_path, split):
    """Load ``<data_path>/<split>.pt`` and return ``(x, y)`` tensors."""
    # map_location="cpu" keeps the load working on hosts that lack the device
    # the file was saved from; batches are moved to `device` later.
    samples = torch.load(os.path.join(data_path, f"{split}.pt"), map_location="cpu")
    # NOTE(review): the division assumes stored images are raw 0-255 pixel
    # values — confirm; inputs already scaled to [0, 1] would be scaled twice.
    x = torch.stack([sample[0] for sample in samples]).float() / 255.0
    y = torch.tensor([sample[1] for sample in samples], dtype=torch.long)
    return x, y


def _evaluate(model, loader):
    """Return the fraction of samples in ``loader`` classified correctly."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(encoder(data)).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    # Guard the division so an empty loader yields 0.0 instead of raising.
    return correct / total if total else 0.0


def train_and_eval(client_id, num_epochs=100, lr=0.001):
    """Train a fresh ``SimpleNet`` on one client's data and log per-epoch metrics.

    Inputs are passed through the module-level ``encoder`` without gradient
    tracking; only the ``SimpleNet`` parameters are optimized (plain SGD with
    cross-entropy loss). After every epoch the model is evaluated on the
    client's test split.

    Args:
        client_id: Identifier used to locate
            ``../simulate/volumes2/client{client_id}/{train,test}.pt`` and to
            name the output CSV.
        num_epochs: Number of passes over the training split.
        lr: SGD learning rate.

    Side effects:
        Writes ``../results/locals_imbalance/client{client_id}.csv`` with the
        columns ``epoch, loss, accuracy`` (accuracy is a fraction in [0, 1]).
    """
    data_path = f"../simulate/volumes2/client{client_id}"

    train_x, train_y = _load_split(data_path, "train")
    test_x, test_y = _load_split(data_path, "test")

    train_dataset = torch.utils.data.TensorDataset(train_x, train_y)
    test_dataset = torch.utils.data.TensorDataset(test_x, test_y)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

    model = SimpleNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    num_train = len(train_loader.dataset)

    results = []
    for epoch in trange(num_epochs, desc=f"Client {client_id} Training"):
        model.train()
        running_loss = 0.0

        for data, target in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
            data, target = data.to(device), target.to(device)
            # The encoder is a fixed feature extractor: no gradients needed.
            with torch.no_grad():
                z = encoder(data)

            output = model(z)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample even
            # when the last batch is smaller.
            running_loss += loss.item() * data.size(0)

        avg_loss = running_loss / num_train
        accuracy = _evaluate(model, test_loader)
        tqdm.write(f"[Client {client_id}] Epoch {epoch+1} | Loss {avg_loss:.4f} | Acc {accuracy*100:.2f}%")

        results.append([epoch + 1, avg_loss, accuracy])

    # save csv
    out_dir = "../results/locals_imbalance"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/client{client_id}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "loss", "accuracy"])
        writer.writerows(results)

# ------------------------------
# 3. Run all clients
# ------------------------------
# Train one independent local model per client; ids run from 1 to 20 inclusive.
for cid in range(1, 21):
    train_and_eval(client_id=cid, num_epochs=100, lr=0.001)