In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
differrent_num_clients = [10]
num_clients_balanced_our = {}
num_clients_imbalanced_our = {}
num_clients_balanced_their = {}
num_clients_imbalanced_their = {}

In [3]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

global_model = SimpleCNN().to(device)

In [4]:
def get_client_loaders(num_clients=3):
    transform = transforms.ToTensor()
    mnist_train = datasets.MNIST(root='./data_mnist', train=True, download=True, transform=transform)
    client_data = []
    samples_per_client = len(mnist_train) // num_clients
    for i in range(num_clients):
        indices = range(i * samples_per_client, (i+1) * samples_per_client)
        client_data.append(Subset(mnist_train, indices))
    
    client_loaders = [DataLoader(data, batch_size=32, shuffle=True, drop_last=True) for data in client_data]
    return client_loaders

def get_imbalanced(num_clients=12):
    transform = transforms.ToTensor()
    mnist_train = datasets.MNIST(root='./data_mnist', train=True, download=True, transform=transform)
    samples_per_client = len(mnist_train) // num_clients
    common_data, validation_data = Subset(mnist_train, range(0 * samples_per_client, (0+1) * samples_per_client)), Subset(mnist_train, range(1 * samples_per_client, (1+1) * samples_per_client))
    clients_dataset = Subset(mnist_train, range(2 * samples_per_client, (num_clients+1) * samples_per_client))

    labels = torch.tensor([label for _, label in clients_dataset])
    sorted_indices = torch.argsort(labels)
    sorted_dataset = Subset(clients_dataset, sorted_indices)
    client_data = []
    
    for i in range(num_clients - 2):
        indices = range(i * samples_per_client, (i+1) * samples_per_client)
        client_data.append(Subset(sorted_dataset, indices))

    client_data.append(common_data)
    client_data.append(validation_data)
    client_loaders = [DataLoader(data, batch_size=32, shuffle=True, drop_last=True) for data in client_data]
    return client_loaders



In [5]:
def client_train(model, loader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

In [6]:
import itertools
flatten_list = lambda obj: list(itertools.chain.from_iterable(obj))

def server_update(global_model, client_logits, client_inputs, lr=0.01):
    client_logits = flatten_list(client_logits)
    client_inputs = flatten_list(client_inputs)
    optimizer = optim.SGD(global_model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    global_model.train()
    for x, y in zip(client_inputs, client_logits):
        x = x.to(device)
        optimizer.zero_grad()
        server_logits = global_model(x)
        loss = mse_loss(server_logits, y)
        loss.backward()
        optimizer.step()

In [7]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            _, predicted = torch.max(logits, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

In [8]:
differrent_num_clients = [10]
num_clients_balanced_our = {}
num_clients_imbalanced_our = {}
num_clients_balanced_their = {}
num_clients_imbalanced_their = {}

In [9]:
for num_clients in differrent_num_clients:
    client_loaders = get_imbalanced(num_clients + 2)
    common_loader = client_loaders[-2]
    validation_loder = client_loaders[-1]
    client_models = [SimpleCNN().to(device) for _ in range(num_clients)]


## Logit averaging

In [10]:
client_models = [SimpleCNN().to(device) for _ in range(num_clients)]
node_accuracy = [evaluate(client_models[i], validation_loder) * 100 for i in range(num_clients)]
sum(node_accuracy) / len(node_accuracy)

9.963942307692307

In [11]:
client_loaders = get_client_loaders(num_clients + 2)
common_loader = client_loaders[-2]
validation_loder = client_loaders[-1]
client_models = [SimpleCNN().to(device) for _ in range(num_clients)]

In [None]:
def comp_rand_k(logits, k = 3):
    seed = random.randint(0, 1000)
    torch.manual_seed(seed)
    size = logits.size()
    new_size = size[:-1] + (k,)
    sparse_logits = torch.zeros(new_size, dtype= logits.dtype)
    for i in range(size[0]):
        for j in range(size[1]):
            indices = torch.randperm(size[-1], device=logits.device)[:k]
            indices, _ = torch.sort(indices)
            sparse_logits[i, j] = logits[i, j, indices]
    sparse_logits *= size[-1] / k
    return sparse_logits, seed, size

def decomp_rand_k(sparse_logits, seed, size, k = 3):
    torch.manual_seed(seed)
    decomp_logits  = torch.full(size, float('-inf') ,dtype= sparse_logits.dtype)
    for i in range(size[0]):
        for j in range(size[1]):
            indices = torch.randperm(size[-1], device=sparse_logits.device)[:k]
            indices, _ = torch.sort(indices)
            decomp_logits[i, j, indices] = sparse_logits[i, j]
    return decomp_logits

In [None]:
def comp_top_k(logits, k = 3):
    size = logits.size()
    _, indices = torch.topk(logits, k, dim= -1)
    sparse_logits = torch.gather(logits, dim=-1, index=indices).to(logits.device)
    return sparse_logits, indices, size


def decomp_top_k(sparse_logits, indices, size):
    decomp_logits  = torch.full(size, float('-inf'), dtype= sparse_logits.dtype, device= sparse_logits.device)
    decomp_logits.scatter_(-1, indices, sparse_logits)
    return decomp_logits

In [None]:
def dithering(logits, b = 2, s = 8, p = 2):
    norm_logits = torch.norm(logits.float(), p, dim=-1, keepdim=True)
    normalized_logits = logits.abs() / norm_logits
    levels = b ** -torch.arange(1, s + 1, dtype=logits.dtype, device=logits.device)  
    levels = torch.cat([torch.tensor([1.0], device=logits.device, dtype=logits.dtype), levels, torch.tensor([0.0], device=logits.device, dtype=logits.dtype)])
    u = torch.sum(normalized_logits.unsqueeze(-1) <= levels.unsqueeze(0), dim=-1, dtype= float) 
    lower = b ** -u
    upper = b ** -(u - 1)
    lower = torch.where(u == 10, 0.0, lower)
    probs_upper = (normalized_logits - lower) / (upper - lower)
    bern = torch.bernoulli(probs_upper)
    quantized = torch.where(bern == 1, upper, lower)
    result = norm_logits * torch.sign(logits) * quantized
    return result

In [12]:
class ClientState:
    def __init__(self, model, lmbd = 0.1, gamma = 0.1):
        self.weights = []
        self.state = [torch.zeros_like(p) for p in model.parameters()]
        self.prev_state = [torch.zeros_like(p) for p in model.parameters()]
        self.lmbd = lmbd
        self.gamma = gamma
    
    def local_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] += param.grad - self.prev_state[idx]
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)
    
    def global_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] = param.grad
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)

    def set_weights(self, model):
        self.weights = [p for p in model.parameters()]

    def get_reg_term(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] = param.grad
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)

        l2_regularization = torch.tensor(0., requires_grad=True)
        for idx, (param, x, g) in enumerate(zip(model.parameters(), self.weights, self.state)):
            l2_regularization = l2_regularization + torch.norm(param - (x - self.gamma * self.lmbd * g), p=2)
        return l2_regularization


In [13]:
from torch.utils.tensorboard import SummaryWriter
import torch.utils.tensorboard as tensorboard

writer = SummaryWriter()

2025-08-06 20:19:05.919815: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-06 20:19:05.938882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754511545.957599  516159 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754511545.963474  516159 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754511545.978085  516159 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [14]:
import random

clients_states = [ClientState(model) for model in client_models]
client_optimizers = [
    optim.Adam(model.parameters())
    for model in client_models
]
criterion = nn.CrossEntropyLoss()

def sync_clients(batch):
    inputs, target = batch
    inputs = inputs.to(device)
    target = target.to(device)
    logits = []
    for model in client_models:
        logits.append(model(inputs))
    # print(logits[0].shape)
    logits = torch.stack(logits)
    sparse_logits = comp_rand_k(logits)
    sparse_logits = decomp_top_k(*sparse_logits)
    logits = sparse_logits.mean(0)
    loss = criterion(logits, target)
    loss.backward()

def update_state():
    for model_idx, model in enumerate(client_models):
        for idx, param in enumerate(model.parameters):
            clients_states[model_idx][idx] += param.grad

def compute_accuracy():
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in validation_loder:
            x, y = x.to(device), y.to(device)
            logits = torch.stack([model(x) for model in client_models]).mean(0)
            _, predicted = torch.max(logits, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

def run_step(p, iter_idx):
    if random.random() > p:
        # TODO: 
        batch = next(iter(common_loader))
        sync_clients(batch)
        for model, state in zip(client_models, clients_states):
            state.global_step(model)
    else:
        for model, state, loader in zip(client_models, clients_states, client_loaders):
            inputs, target = next(iter(loader))
            inputs = inputs.to(device)
            target = target.to(device)
            # print(inputs)
            logits = model(inputs)
            loss = criterion(logits, target)
            loss.backward()
            state.local_step(model)
    
    for model, state, loader, optimizer in zip(client_models, clients_states, client_loaders, client_optimizers):
        state.set_weights(model)
        for _ in range(3):
            inputs, target = next(iter(loader))
            inputs = inputs.to(device)
            target = target.to(device)
            # for inputs, target in loader:
            logits = model(inputs)
            loss = criterion(logits, target)
            reg = state.get_reg_term(model)
            loss += reg
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    global_acc = compute_accuracy()
    local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
    writer.add_scalars(
        main_tag='Accuracy',
        tag_scalar_dict={'local': local_accuracy.mean(), 'ensembled': global_acc},
        global_step=iter_idx
    )
    print(f"Accuracy = {global_acc} Local accuracy = {local_accuracy.mean()} +- {local_accuracy.std()}")

In [15]:
for i in range(100):
    run_step(0.2, i)

Accuracy = 0.6137820512820513 Local accuracy = 0.2643830180168152 +- 0.08397138118743896
Accuracy = 0.5308493589743589 Local accuracy = 0.3551081717014313 +- 0.12408051639795303
Accuracy = 0.7684294871794872 Local accuracy = 0.4624399244785309 +- 0.13731791079044342
Accuracy = 0.7419871794871795 Local accuracy = 0.5452324151992798 +- 0.07956598699092865
Accuracy = 0.7672275641025641 Local accuracy = 0.6132612228393555 +- 0.08989808708429337
Accuracy = 0.84375 Local accuracy = 0.6686698198318481 +- 0.07870286703109741
Accuracy = 0.8439503205128205 Local accuracy = 0.7227363586425781 +- 0.0580935962498188
Accuracy = 0.8537660256410257 Local accuracy = 0.7209335565567017 +- 0.08218414336442947
Accuracy = 0.8699919871794872 Local accuracy = 0.7552484273910522 +- 0.043463435024023056
Accuracy = 0.8635817307692307 Local accuracy = 0.7574719190597534 +- 0.04214155673980713
Accuracy = 0.8840144230769231 Local accuracy = 0.7871193885803223 +- 0.04737617075443268
Accuracy = 0.8860176282051282 Lo

In [16]:
writer.close()