In [80]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
import numpy as np
from admm.agents import FedConsensus
from admm.servers import EventADMM
from admm.models import FCNet, CNN
from admm.utils import average_params, split_dataset
import seaborn as sns
sns.set_theme()

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [81]:
# CIFAR-10 preprocessing: convert to a tensor, then normalise every RGB
# channel with mean 0.5 and std 0.5.
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Both CIFAR-10 splits are read from the local cache only (download=False).
cifar_trainset, cifar_testset = (
    datasets.CIFAR10(
        root='./data/cifar10',
        train=is_train,
        download=False,
        transform=cifar_transform,
    )
    for is_train in (True, False)
)

# MNIST preprocessing: tensor -> normalise (mean 0.1307, std 0.3081) ->
# flatten the image into a 1-D vector.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(torch.flatten),
])

# Both MNIST splits are likewise read from the local cache only.
mnist_trainset, mnist_testset = (
    datasets.MNIST(
        root='./data/mnist_data',
        train=is_train,
        download=False,
        transform=mnist_transform,
    )
    for is_train in (True, False)
)

In [82]:
from datset_preperation import _partition_data

def count_labels(dataset, num_classes=10, batch_size=1024):
    """Count how many samples of each class `dataset` contains.

    Args:
        dataset: map-style dataset yielding ``(data, integer_label)`` pairs.
        num_classes: number of distinct labels; labels must lie in
            ``[0, num_classes)``.
        batch_size: samples read per step (only affects speed, not the result).

    Returns:
        Float ``np.ndarray`` of length ``num_classes`` with per-class counts.
    """
    counts = np.zeros(num_classes)
    for _, target in DataLoader(dataset, batch_size=batch_size):
        # Histogram a whole batch at once instead of one sample per iteration.
        counts += np.bincount(target.numpy(), minlength=num_classes)
    return counts


# Hold out 20% of the MNIST training set for validation.
train_dataset, val_dataset, _ = split_dataset(dataset=mnist_trainset, train_ratio=0.8, val_ratio=0.2)

# NOTE(review): `train_dataset.dataset` is passed to the partitioner, not
# `train_dataset` itself. If `split_dataset` returns `torch.utils.data.Subset`
# objects, `.dataset` is the full un-split training set, so the validation
# samples would also end up in the client partitions. The recorded per-client
# sample counts add up to more than 80% of the training set, which points that
# way. Confirm, and partition only the training split if so.
trainsets = _partition_data(
    num_clients=10,
    iid=False,
    balance=False,
    power_law=True,
    seed=42,
    trainset=train_dataset.dataset
)

# Per-client label histogram.
for i, dataset in enumerate(trainsets):
    labels = count_labels(dataset)
    print(f'Dataset {i} distribution: {labels} - num_samples = {labels.sum()}')

# Label histogram of the validation split (there is only one, so no index).
labels = count_labels(val_dataset)
print(f'Validation dataset distribution: {labels} - num_samples = {labels.sum()}')

# Shape of one validation batch of 10 samples.
loader = DataLoader(val_dataset, batch_size=10)
data, target = next(iter(loader))
print(data.shape)

Dataset 0 distribution: [  21.  273.   72. 1951.   34.  607. 4135.  168.  204.  387.] - num_samples = 7852.0
Dataset 1 distribution: [719. 530. 445. 127.  20.  40.   2.   1.  74. 231.] - num_samples = 2189.0
Dataset 2 distribution: [  84.   32.  288.  266.    5. 2204.   32.  520.  903.  114.] - num_samples = 4448.0
Dataset 3 distribution: [  10. 1275.   77.  646.   51.    9.   11.   54. 1657.  274.] - num_samples = 4064.0
Dataset 4 distribution: [ 189.  226. 3712.   87.    6.  203.  118.   43.  218.   80.] - num_samples = 4882.0
Dataset 5 distribution: [  29. 3156.   57.   87.    9.  505.  649. 3178.  171. 3505.] - num_samples = 11346.0
Dataset 6 distribution: [  28.  357. 1217. 1295. 4325.  325.   40.   20.  180. 1019.] - num_samples = 8806.0
Dataset 7 distribution: [ 462.  136.    5.   14.   14.  137.    5. 1956.  771.  213.] - num_samples = 3713.0
Dataset 8 distribution: [4.309e+03 4.210e+02 3.000e+00 5.440e+02 1.126e+03 6.900e+01 2.000e+00
 7.300e+01 1.638e+03 4.800e+01] - num_samp

In [84]:
batch_size = 32

# Every loader uses the same mini-batch size and has shuffling enabled.
loader_kwargs = dict(batch_size=batch_size, shuffle=True)

# One training loader per client partition, plus shared test/validation loaders.
train_loaders = [DataLoader(client_set, **loader_kwargs) for client_set in trainsets]
test_loader = DataLoader(mnist_testset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

In [86]:
# --- Experiment configuration ---
deltas = [0.001, 0.01, 0.1, 1, 10]   # values swept; one full run per entry
rho = 0.01
t_max = 100                          # passed to EventADMM; also the width of the metric rows below
loaders = train_loaders
device = 'cpu'

# Row i holds the per-step metrics recorded during the run for deltas[i].
acc_per_delta = np.zeros((len(deltas), t_max))
rate_per_delta = np.zeros((len(deltas), t_max))
loads = []
test_accs = []

print(f'N = {len(loaders)}')
for i, delta in enumerate(deltas):

    # NOTE(review): the swept value is passed as `rho` while `delta` is pinned
    # to 0, so the module-level `rho` above is never used. Confirm that a rho
    # sweep (rather than a delta sweep) is what this experiment intends.
    agents = [
        FedConsensus(
            N=len(loaders),
            delta=0,
            rho=delta,
            model=FCNet(in_channels=784, hidden1=200, hidden2=None, out_channels=10),
            loss=nn.CrossEntropyLoss(),
            train_loader=loader,
            classification=True,
            epochs=2,
            device=device
        ) for loader in loaders
    ]

    # Broadcast the average of all agents' initial parameters to every agent.
    # The average is recomputed per agent rather than shared; presumably so
    # agents do not alias each other's tensors (verify against average_params).
    for agent in agents:
        agent.primal_avg = average_params([peer.model.parameters() for peer in agents])

    # Sanity check: every agent must hold the same average as the first one.
    # Comparing against all agents also works when there is only one client.
    reference_avg = agents[0].primal_avg
    for other in agents[1:]:
        for ref_param, other_param in zip(reference_avg, other.primal_avg):
            if not torch.equal(ref_param, other_param):
                raise ValueError("Averaged params aren't equal")

    # Run consensus algorithm, validating on the held-out validation loader.
    server = EventADMM(clients=agents, t_max=t_max)
    server.spin(loader=val_loader)

    # Collect metrics for the plots further down.
    acc_per_delta[i, :] = server.val_accs
    rate_per_delta[i, :] = server.rates
    loads.append(server.comm)
    # Mean of the values returned by validate() on the test loader.
    accs = server.validate(loader=test_loader)
    test_accs.append(sum(accs) / len(accs))

N = 10


Comm frequency: 1.000, agent 0: 0.65, agent 1: 0.69, agent 2: 0.70, agent 3: 0.69, agent 4: 0.68, agent 5: 0.69, agent 6: 0.69, agent 7: 0.69, agent 8: 0.67, agent 9: 0.68:  27%|██▋       | 27/100 [09:48<24:30, 20.15s/it] 

In [None]:
# Shared x-axis for the per-step plots.
T = range(t_max)

# One validation-accuracy curve per swept value.
fig, ax = plt.subplots()
for sweep_value, acc_curve in zip(deltas, acc_per_delta):
    ax.plot(T, acc_curve, label=f'rho={sweep_value}')
# Anchor the legend outside the axes, to the right of the plot area.
ax.legend(loc='center right', bbox_to_anchor=(1.3, 0.5))
ax.set_xlabel('Time Step')
ax.set_ylabel('Accuracy')
ax.set_title('Validation Set Accuracy - Fully Connected - Learning Rate = 0.001')
plt.show()

In [None]:
# One communication-rate curve per swept value. The experiment loop passes the
# swept values to the agents as `rho`, so the legend names them accordingly
# (matching the accuracy plot).
fig, ax = plt.subplots()
for rate, sweep_value in zip(rate_per_delta, deltas):
    ax.plot(T, rate, label=f'rho={sweep_value}')
# Anchor the legend outside the axes, to the right of the plot area.
ax.legend(loc='center right', bbox_to_anchor=(1.3, 0.5))
ax.set_xlabel('Time Step')
ax.set_ylabel('Rate')
ax.set_title('Communication Rate - Fully Connected')
plt.show()

In [None]:
# One marker per run: communication load against the mean test-set value
# collected in `test_accs`. The swept values are passed to the agents as
# `rho`, so the legend names them accordingly.
fig, ax = plt.subplots()
for load, acc, sweep_value in zip(loads, test_accs, deltas):
    ax.plot(acc, load, label=f'rho={sweep_value}', marker='o')
# Anchor the legend outside the axes, to the right of the plot area.
ax.legend(loc='center right', bbox_to_anchor=(1.3, 0.5))
# The x values come from `test_accs` (accuracies), not error rates.
ax.set_xlabel('Test Accuracy')
ax.set_ylabel('Communication Load')
ax.set_title('Fully Connected')
plt.show()

In [40]:
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import ExponentialPartitioner, NaturalIdPartitioner, LinearPartitioner

nodes = 3
# Split CIFAR-10's train split across `nodes` partitions; keep the test split whole.
fds = FederatedDataset(dataset='cifar10', partitioners={'train': nodes, 'test': 1})
partitions = [fds.load_partition(node_id=node, split='train') for node in range(nodes)]

# Named `to_tensor` so it does not rebind `transforms`, which the dataset cell
# relies on being the torchvision.transforms module.
to_tensor = ToTensor()


def apply_transforms(batch):
    """Replace every image in ``batch["img"]`` with its tensor form.

    Args:
        batch: dict-like batch with an ``"img"`` entry holding the images.

    Returns:
        The same ``batch`` object, with ``"img"`` converted via ``ToTensor``.
    """
    batch["img"] = [to_tensor(img) for img in batch["img"]]
    return batch


def count_partition_labels(partition, num_classes=10):
    """Return a list with the number of samples of each class in ``partition``.

    Args:
        partition: object whose ``'label'`` entry is an iterable of integer labels.
        num_classes: number of classes; labels outside ``[0, num_classes)``
            are ignored.

    Returns:
        List of ``num_classes`` integer counts, indexed by label.
    """
    counts = [0] * num_classes
    # Read the label column once and tally in a single pass.
    for label in partition['label']:
        if 0 <= label < num_classes:
            counts[label] += 1
    return counts


partitions_torch = [partition.with_transform(apply_transforms) for partition in partitions]

print('Training partitions')
for i, partition in enumerate(partitions):
    print(f'\nPartition {i}')
    print(count_partition_labels(partition))

print('\nTest partition\n')
partition = fds.load_partition(node_id=0, split='test')
print(count_partition_labels(partition))

Training partitions

Partition 0
[1705, 1605, 1682, 1628, 1669, 1683, 1674, 1692, 1643, 1686]

Partition 1
[1599, 1723, 1659, 1729, 1670, 1647, 1679, 1644, 1653, 1664]

Partition 2
[1696, 1672, 1659, 1643, 1661, 1670, 1647, 1664, 1704, 1650]

Test partition

[1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]

Validation partition



ValueError: The given split: 'val' is not present in the dataset's splits: '['train', 'test']'.