In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(device)

cuda:2


In [2]:
differrent_num_clients = [10]
num_clients_balanced_our = {}
num_clients_imbalanced_our = {}
num_clients_balanced_their = {}
num_clients_imbalanced_their = {}

In [3]:
from parallel.model import ResNet18

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(-1, 16 * 28 * 28)
        return self.fc(x)

global_model = ResNet18(10).to(device)# SimpleCNN().to(device)

In [4]:
from collections import defaultdict

def get_client_loaders(num_clients=3, train=True):
    transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
            ]
    )
    mnist_train = datasets.CIFAR10(root='./data_cifar10', train=train, download=True, transform=transform)
    client_data = []
    samples_per_client = len(mnist_train) // num_clients
    for i in range(num_clients):
        indices = range(i * samples_per_client, (i+1) * samples_per_client)
        client_data.append(Subset(mnist_train, indices))
    
    client_loaders = [DataLoader(data, batch_size=32, shuffle=True, drop_last=True) for data in client_data]
    return client_loaders

def get_imbalanced(num_clients=12):
    transform = transforms.ToTensor()
    mnist_train = datasets.CIFAR10(root='./data_cifar10', train=True, download=True, transform=transform)
    samples_per_client = len(mnist_train) // num_clients
    common_data, validation_data = Subset(mnist_train, range(0 * samples_per_client, (0+1) * samples_per_client)), Subset(mnist_train, range(1 * samples_per_client, (1+1) * samples_per_client))
    clients_dataset = Subset(mnist_train, range(2 * samples_per_client, (num_clients+1) * samples_per_client))

    labels = torch.tensor([label for _, label in clients_dataset])
    sorted_indices = torch.argsort(labels)
    sorted_dataset = Subset(clients_dataset, sorted_indices)
    client_data = []
    
    for i in range(num_clients - 2):
        indices = range(i * samples_per_client, (i+1) * samples_per_client)
        client_data.append(Subset(sorted_dataset, indices))

    client_data.append(common_data)
    client_data.append(validation_data)
    client_loaders = [DataLoader(data, batch_size=32, shuffle=True, drop_last=True) for data in client_data]
    return client_loaders

def create_probabilistic_client_loaders(
    num_clients=20,
    p=0.7,  # Probability a client gets a class
    batch_size=32,
):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = datasets.CIFAR10(root='./data_cifar10', train=True, download=True, transform=transform)
    
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    for c in class_indices:
        np.random.shuffle(class_indices[c])
    
    client_indices = [[] for _ in range(num_clients)]
    
    for class_idx in range(10):
        # Determine which clients get this class
        client_mask = np.random.rand(num_clients) < p
        selected_clients = np.where(client_mask)[0]
        
        if len(selected_clients) > 0:
            # Split class indices among selected clients
            splits = np.array_split(class_indices[class_idx], len(selected_clients))
            for client_id, split in zip(selected_clients, splits):
                client_indices[client_id].extend(split)
    
    # Create DataLoaders
    client_loaders = []
    for indices in client_indices:
        if not indices:
            indices = [0]
        loader = DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )
        client_loaders.append(loader)
    
    return client_loaders

In [5]:
from tqdm import tqdm, trange

def client_train(model, loader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for _ in trange(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

In [6]:
import itertools
flatten_list = lambda obj: list(itertools.chain.from_iterable(obj))

def server_update(global_model, client_logits, client_inputs, lr=0.01):
    client_logits = flatten_list(client_logits)
    client_inputs = flatten_list(client_inputs)
    optimizer = optim.SGD(global_model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    global_model.train()
    for x, y in zip(client_inputs, client_logits):
        x = x.to(device)
        optimizer.zero_grad()
        server_logits = global_model(x)
        loss = mse_loss(server_logits, y)
        loss.backward()
        optimizer.step()

In [7]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            _, predicted = torch.max(logits, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

In [8]:
differrent_num_clients = [10]
num_clients_balanced_our = {}
num_clients_imbalanced_our = {}
num_clients_balanced_their = {}
num_clients_imbalanced_their = {}

In [9]:
# for num_clients in differrent_num_clients:
#     client_loaders = get_imbalanced(num_clients + 2)
#     common_loader = client_loaders[-2]
#     validation_loder = client_loaders[-1]
#     client_models = [SimpleCNN().to(device) for _ in range(num_clients)]


## Logit averaging

In [10]:
# client_models = [ResNet18(num_clients).to(device) for _ in range(num_clients)]
# node_accuracy = [evaluate(client_models[i], validation_loder) * 100 for i in range(num_clients)]
# sum(node_accuracy) / len(node_accuracy)

In [11]:
# client_train(global_model, client_loaders[0], 10)

In [12]:
# evaluate(global_model, validation_loder)

In [13]:
device

device(type='cuda', index=2)

In [14]:
num_clients = 10
client_loaders = create_probabilistic_client_loaders(num_clients, p=0.85)
common_loader, validation_loder =  get_client_loaders(2, train=False) #client_loaders[-2]
# validation_loder = client_loaders[-1]
client_models = [ResNet18(num_clients).to(device) for _ in range(num_clients)]

In [15]:
class ClientState:
    def __init__(self, model, lmbd = 0.1, gamma = 0.1):
        self.weights = []
        self.state = [torch.zeros_like(p) for p in model.parameters()]
        self.prev_state = [torch.zeros_like(p) for p in model.parameters()]
        self.lmbd = lmbd
        self.gamma = gamma
    
    def local_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] += param.grad - self.prev_state[idx]
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)
    
    def global_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] = param.grad
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)

    def set_weights(self, model):
        self.weights = [p for p in model.parameters()]

    def get_reg_term(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] = param.grad
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)

        l2_regularization = torch.tensor(0., requires_grad=True)
        for idx, (param, x, g) in enumerate(zip(model.parameters(), self.weights, self.state)):
            l2_regularization = l2_regularization + torch.norm(param - (x - self.gamma * self.lmbd * g), p=2)
        return l2_regularization


In [16]:
from torch.utils.tensorboard import SummaryWriter
import torch.utils.tensorboard as tensorboard

writer = SummaryWriter()

2025-08-08 06:48:35.627159: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-08 06:48:35.640220: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754635715.654593 2987764 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754635715.658661 2987764 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754635715.669695 2987764 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [None]:
import random

gamma = 1e-3
lmbd = 0.1

clients_states = [ClientState(model, lmbd, gamma) for model in client_models]
client_optimizers = [
    optim.Adam(model.parameters(), lr=1e-5)
    for model in client_models
]


# for state in clients_states:
#     state.gamma = gamma
#     state.lmbd = lmbd


def sync_clients(batch):
    inputs, target = batch
    inputs = inputs.to(device)
    target = target.to(device)
    logits = []
    for model in client_models:
        logits.append(model(inputs))

    # with torch.no_grad():
    #     check_params_grad = [torch.tensor(param.grad) for param in client_models[0].parameters()]
        # print(f"Before {check_params_grad=}")
    # print(logits[0].shape)
    logits = torch.stack(logits)
    # print(logits.shape)
    logits = logits.mean(0)
    loss = criterion(logits, target)
    loss.backward()
    # with torch.no_grad():
        # check_params_grad = [(torch.tensor(param.grad) - check).mean() for param, check in zip(client_models[0].parameters(), check_params_grad)]
    # for param in client_models[0].parameters():
    #     assert param.grad is not None

    #     check_params_grad = [(torch.tensor(param.grad) - check).mean() for param, check in zip(client_models[0].parameters(), check_params_grad)]
    #     check_params_grad = torch.tensor(check_params_grad) #.mean()
    #     print(f"Diff {check_params_grad=}")


def compute_accuracy():
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in validation_loder:
            x, y = x.to(device), y.to(device)
            logits = torch.stack([model(x) for model in client_models]).mean(0)
            _, predicted = torch.max(logits, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

def run_step(p, iter_idx):
    # if random.random() > p:
        # TODO: 
    batch = next(iter(common_loader))
    sync_clients(batch)
    for model, state in zip(client_models, clients_states):
        state.global_step(model)
    # else:
    #     for model, state, loader in zip(client_models, clients_states, client_loaders):
    #         inputs, target = next(iter(loader))
    #         inputs = inputs.to(device)
    #         target = target.to(device)
    #         # print(inputs)
    #         logits = model(inputs)
    #         loss = criterion(logits, target)
    #         loss.backward()
    #         state.local_step(model)
    
    for model, state, loader, optimizer in zip(client_models, clients_states, client_loaders, client_optimizers):
        state.set_weights(model)
        for _ in range(10):
            inputs, target = next(iter(loader))
            inputs = inputs.to(device)
            target = target.to(device)
            # for inputs, target in loader:
            logits = model(inputs)
            loss = criterion(logits, target)
            reg = state.get_reg_term(model)
            loss = loss * gamma + reg
            # loss = gamma * loss + reg
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # global_acc = compute_accuracy()
    # # local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
    # local_accuracy = torch.tensor([evaluate(model, validation_loder) for model, loader in zip(client_models, client_loaders)])
    # writer.add_scalars(
    #     main_tag='Accuracy',
    #     tag_scalar_dict={'local': local_accuracy.mean(), 'ensembled': global_acc},
    #     global_step=iter_idx
    # )
    # print(f"Accuracy = {global_acc} Local accuracy = {local_accuracy.mean()} +- {local_accuracy.std()}")

In [22]:
for i in range(1000):
    run_step(0.8, i)
    if i % 10 == 0:
        global_acc = compute_accuracy()
        # local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        single_accuracy = torch.tensor([evaluate(model, validation_loder) for model, loader in zip(client_models, client_loaders)])
        # writer.add_scalars(
        #     main_tag='Accuracy',
        #     tag_scalar_dict={'local': local_accuracy.mean(), 'ensembled': global_acc},
        #     global_step=i // 10,
        # )
        print(f"Accuracy = {global_acc:.3f} Single accuracy = {single_accuracy.mean():.3f} +- {single_accuracy.std():.3f} Local accuracy = {local_accuracy.mean():.3f} +- {local_accuracy.std():.3f}")

Accuracy = 0.156 Single accuracy = 0.128 +- 0.017 Local accuracy = 0.598 +- 0.067


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.163 Single accuracy = 0.143 +- 0.013 Local accuracy = 0.642 +- 0.053
Accuracy = 0.131 Single accuracy = 0.134 +- 0.022 Local accuracy = 0.664 +- 0.066
Accuracy = 0.144 Single accuracy = 0.133 +- 0.022 Local accuracy = 0.684 +- 0.092


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.166 Single accuracy = 0.133 +- 0.010 Local accuracy = 0.709 +- 0.099
Accuracy = 0.163 Single accuracy = 0.127 +- 0.010 Local accuracy = 0.743 +- 0.079
Accuracy = 0.141 Single accuracy = 0.131 +- 0.014 Local accuracy = 0.769 +- 0.075


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.155 Single accuracy = 0.134 +- 0.009 Local accuracy = 0.791 +- 0.081
Accuracy = 0.153 Single accuracy = 0.130 +- 0.019 Local accuracy = 0.812 +- 0.078
Accuracy = 0.166 Single accuracy = 0.134 +- 0.016 Local accuracy = 0.840 +- 0.075


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.148 Single accuracy = 0.140 +- 0.015 Local accuracy = 0.871 +- 0.054
Accuracy = 0.162 Single accuracy = 0.128 +- 0.008 Local accuracy = 0.865 +- 0.067
Accuracy = 0.137 Single accuracy = 0.135 +- 0.011 Local accuracy = 0.895 +- 0.051
Accuracy = 0.158 Single accuracy = 0.139 +- 0.023 Local accuracy = 0.922 +- 0.044
Accuracy = 0.164 Single accuracy = 0.133 +- 0.011 Local accuracy = 0.932 +- 0.033
Accuracy = 0.155 Single accuracy = 0.136 +- 0.010 Local accuracy = 0.930 +- 0.040


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.164 Single accuracy = 0.142 +- 0.012 Local accuracy = 0.941 +- 0.032
Accuracy = 0.146 Single accuracy = 0.136 +- 0.013 Local accuracy = 0.962 +- 0.027


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.158 Single accuracy = 0.135 +- 0.012 Local accuracy = 0.965 +- 0.027
Accuracy = 0.164 Single accuracy = 0.136 +- 0.016 Local accuracy = 0.969 +- 0.032
Accuracy = 0.150 Single accuracy = 0.132 +- 0.010 Local accuracy = 0.956 +- 0.048


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.147 Single accuracy = 0.141 +- 0.015 Local accuracy = 0.988 +- 0.008
Accuracy = 0.162 Single accuracy = 0.144 +- 0.014 Local accuracy = 0.973 +- 0.030
Accuracy = 0.150 Single accuracy = 0.134 +- 0.010 Local accuracy = 0.986 +- 0.022
Accuracy = 0.149 Single accuracy = 0.141 +- 0.018 Local accuracy = 0.988 +- 0.013


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.146 Single accuracy = 0.137 +- 0.014 Local accuracy = 0.994 +- 0.007
Accuracy = 0.158 Single accuracy = 0.138 +- 0.014 Local accuracy = 0.996 +- 0.005
Accuracy = 0.160 Single accuracy = 0.140 +- 0.013 Local accuracy = 0.997 +- 0.004


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.157 Single accuracy = 0.134 +- 0.011 Local accuracy = 0.994 +- 0.011
Accuracy = 0.147 Single accuracy = 0.135 +- 0.016 Local accuracy = 0.997 +- 0.004
Accuracy = 0.158 Single accuracy = 0.136 +- 0.014 Local accuracy = 0.974 +- 0.060


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.150 Single accuracy = 0.133 +- 0.013 Local accuracy = 0.999 +- 0.002


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.147 Single accuracy = 0.135 +- 0.016 Local accuracy = 0.999 +- 0.002


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.146 Single accuracy = 0.137 +- 0.014 Local accuracy = 1.000 +- 0.001
Accuracy = 0.146 Single accuracy = 0.135 +- 0.013 Local accuracy = 0.996 +- 0.014


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.139 Single accuracy = 0.134 +- 0.015 Local accuracy = 0.995 +- 0.017


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.147 Single accuracy = 0.136 +- 0.014 Local accuracy = 1.000 +- 0.001


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.142 Single accuracy = 0.136 +- 0.014 Local accuracy = 1.000 +- 0.000


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.149 Single accuracy = 0.137 +- 0.017 Local accuracy = 0.999 +- 0.001


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.145 Single accuracy = 0.136 +- 0.017 Local accuracy = 0.999 +- 0.003


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.148 Single accuracy = 0.136 +- 0.014 Local accuracy = 1.000 +- 0.001


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.152 Single accuracy = 0.133 +- 0.012 Local accuracy = 0.996 +- 0.012
Accuracy = 0.143 Single accuracy = 0.136 +- 0.014 Local accuracy = 1.000 +- 0.001


KeyboardInterrupt: 

In [None]:
gamma = 1e-3
lmbd = 0.1
for state in clients_states:
    state.gamma = gamma
    state.lmbd = lmbd

In [None]:
for i in range(1000):
    run_step(0.5, i)
    if i % 10 == 0:
        global_acc = compute_accuracy()
        # local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        single_accuracy = torch.tensor([evaluate(model, validation_loder) for model, loader in zip(client_models, client_loaders)])
        # writer.add_scalars(
        #     main_tag='Accuracy',
        #     tag_scalar_dict={'local': local_accuracy.mean(), 'ensembled': global_acc},
        #     global_step=i // 10,
        # )
        print(f"Accuracy = {global_acc:.3f} Single accuracy = {single_accuracy.mean():.3f} +- {single_accuracy.std():.3f} Local accuracy = {local_accuracy.mean():.3f} +- {local_accuracy.std():.3f}")

Accuracy = 0.118 Single accuracy = 0.112 +- 0.009 Local accuracy = 0.402 +- 0.044


Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empt

Accuracy = 0.109 Single accuracy = 0.115 +- 0.013 Local accuracy = 0.445 +- 0.041
Accuracy = 0.104 Single accuracy = 0.116 +- 0.016 Local accuracy = 0.483 +- 0.054
Accuracy = 0.116 Single accuracy = 0.127 +- 0.023 Local accuracy = 0.536 +- 0.056
Accuracy = 0.120 Single accuracy = 0.125 +- 0.014 Local accuracy = 0.567 +- 0.069
Accuracy = 0.161 Single accuracy = 0.134 +- 0.017 Local accuracy = 0.574 +- 0.063


Exception in thread Thread-7978 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/uskovev/FL/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/uskovev/FL/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/home/uskovev/FL/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uskovev/.local/share/uv/python/cpython-3.11.11-linux

KeyboardInterrupt: 

In [None]:
for i in range(1000):
    run_step(0.5, i)
    if i % 10 == 0:
        global_acc = compute_accuracy()
        # local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
        single_accuracy = torch.tensor([evaluate(model, validation_loder) for model, loader in zip(client_models, client_loaders)])
        writer.add_scalars(
            main_tag='Accuracy',
            tag_scalar_dict={'local': local_accuracy.mean(), 'ensembled': global_acc},
            global_step=i // 10,
        )
        print(f"Accuracy = {global_acc:.3f} Single accuracy = {single_accuracy.mean():.3f} +- {single_accuracy.std():.3f} Local accuracy = {local_accuracy.mean():.3f} +- {local_accuracy.std():.3f}")

Accuracy = 0.231 Single accuracy = 0.124 +- 0.009 Local accuracy = 0.905 +- 0.021
Accuracy = 0.237 Single accuracy = 0.123 +- 0.007 Local accuracy = 0.911 +- 0.012
Accuracy = 0.231 Single accuracy = 0.124 +- 0.008 Local accuracy = 0.920 +- 0.014
Accuracy = 0.229 Single accuracy = 0.122 +- 0.009 Local accuracy = 0.919 +- 0.016
Accuracy = 0.241 Single accuracy = 0.124 +- 0.009 Local accuracy = 0.933 +- 0.014
Accuracy = 0.238 Single accuracy = 0.125 +- 0.011 Local accuracy = 0.936 +- 0.014
Accuracy = 0.241 Single accuracy = 0.124 +- 0.010 Local accuracy = 0.934 +- 0.014
Accuracy = 0.231 Single accuracy = 0.122 +- 0.009 Local accuracy = 0.927 +- 0.012


KeyboardInterrupt: 

In [None]:
writer.close()