In [1]:
import torch
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import vgg16_bn, VGG16_BN_Weights
import numpy as np

In [2]:
def _partition_class_across_clients(class_data, class_labels, num_of_clients):
    """Split one class's samples into ``num_of_clients`` contiguous shards.

    Each client gets ``len // num_of_clients`` samples. The leftover
    ``len % num_of_clients`` samples (the tail) are handed out one each to
    clients ``0 .. remainder-1``, so every sample is used exactly once.

    Returns:
        (data_shards, label_shards): two lists of length ``num_of_clients``.
    """
    num_samples = int(class_data.shape[0])
    shard_size, remainder = divmod(num_samples, num_of_clients)
    tail_start = shard_size * num_of_clients
    data_shards, label_shards = [], []
    for client_idx in range(num_of_clients):
        start = client_idx * shard_size
        stop = start + shard_size
        if client_idx < remainder:
            # this client also receives its own leftover sample from the tail
            extra = tail_start + client_idx
            data_shards.append(torch.cat([class_data[start:stop], class_data[extra:extra + 1]]))
            label_shards.append(torch.cat([class_labels[start:stop], class_labels[extra:extra + 1]]))
        else:
            data_shards.append(class_data[start:stop])
            label_shards.append(class_labels[start:stop])
    return data_shards, label_shards


def _standardize_client(client_data):
    """Standardise a client's buffer with the client's own mean / std.

    Buffers with more than 3 dims are treated as channel-at-dim-1 and standardised
    per channel (reducing over every other dim). Otherwise a single mean / std is
    used for the whole buffer. ``torch.std``'s default (unbiased) estimator is used.
    """
    if client_data.dim() > 3:
        reduce_dims = tuple(d for d in range(client_data.dim()) if d != 1)
        mean = client_data.mean(dim=reduce_dims, keepdim=True)
        std = client_data.std(dim=reduce_dims, keepdim=True)
        return (client_data - mean) / std
    return (client_data - torch.mean(client_data)) / torch.std(client_data)


def _sample_train_batches(chunks_data, chunks_labels, batch_size, parallelization_param, number_of_iterations):
    """Build ``number_of_iterations`` batches, each a concatenation of randomly drawn chunks.

    ceil(batch_size / parallelization_param) chunks are drawn per batch. A batch can
    hold fewer than ``batch_size`` samples when it includes a client's short trailing chunk.
    """
    chunks_per_batch, leftover = divmod(batch_size, parallelization_param)
    if leftover:
        chunks_per_batch += 1
    data_batches, labels_batches = [], []
    for _ in range(number_of_iterations):
        # NOTE(review): np.random.choice samples WITH replacement by default, so one
        # chunk can appear twice in the same batch -- confirm this is intended.
        picked = np.random.choice(len(chunks_data), chunks_per_batch)
        data_batches.append(torch.cat([chunks_data[i] for i in picked]))
        labels_batches.append(torch.cat([chunks_labels[i] for i in picked]))
    return data_batches, labels_batches


def collect_augment_aggregate(loader, num_of_clients=64, parallelization_param=8, backbone=None, mode='train',
                              batch_size=256, number_of_iterations=300, num_classes=10):
    """Partition a dataset across simulated clients, preprocess each client, and batch the result.

    Steps:
      1. Read the split from the FIRST batch of ``loader``. The loader must yield
         every sample in one batch, e.g. ``batch_size=len(dataset)``.
      2. For each class in ``range(num_classes)``, split that class's samples evenly
         across ``num_of_clients`` clients. Leftover samples go one each to the
         first clients, so every sample is used exactly once.
      3. Per client:
         - standardise with the client's own statistics;
         - optionally pass the data through a pretrained VGG16-BN (classifier head
           removed, eval mode, no grad) when ``backbone == 'vgg'``;
         - shuffle (train mode only);
         - flatten to ``(n_samples, -1)``.
      4. ``mode == 'train'``: cut each client into chunks of ``parallelization_param``
         samples and draw ``number_of_iterations`` batches of
         ceil(batch_size / parallelization_param) random chunks each.
         Any other mode: return a single batch holding every client's data.

    Randomness comes from the global ``torch`` (``randperm``) and ``numpy``
    (``random.choice``) RNGs; seed both for reproducible output.

    Args:
        loader: iterable whose first item is a ``(data, labels)`` pair.
        num_of_clients: number of simulated clients.
        parallelization_param: chunk size (samples) used to cut each client in train mode.
        backbone: ``None`` or ``'vgg'``.
        mode: ``'train'`` for sampled batches; anything else for one full batch.
        batch_size: target number of samples per train batch.
        number_of_iterations: number of train batches to draw.
        num_classes: labels ``0 .. num_classes-1`` are distributed (default 10).

    Returns:
        (data_batches, labels_batches): lists of CPU tensors. Each data batch has
        shape ``(n, features)``.

    Raises:
        ValueError: if ``backbone`` is neither ``None`` nor ``'vgg'``.
    """
    if backbone not in (None, 'vgg'):
        raise ValueError("unsupported backbone: {!r} (expected None or 'vgg')".format(backbone))

    data, labels = next(iter(loader))

    # per_class_*[class_idx][client_idx] -> that client's shard of that class
    per_class_data, per_class_labels = [], []
    for class_idx in range(num_classes):
        mask = labels == class_idx
        data_shards, label_shards = _partition_class_across_clients(data[mask], labels[mask], num_of_clients)
        per_class_data.append(data_shards)
        per_class_labels.append(label_shards)

    if backbone == 'vgg':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        vgg_backbone = vgg16_bn(weights=VGG16_BN_Weights.DEFAULT).eval()
        # drop the last child (the classifier head), keep the remaining layers
        vgg_backbone = torch.nn.Sequential(*(list(vgg_backbone.children())[:-1])).to(device)

    client_chunks_data, client_chunks_labels = [], []
    for client_idx in range(num_of_clients):
        client_data = torch.cat([per_class_data[c][client_idx] for c in range(num_classes)])
        client_labels = torch.cat([per_class_labels[c][client_idx] for c in range(num_classes)])
        client_data = _standardize_client(client_data)

        if backbone == 'vgg':
            with torch.no_grad():
                client_data = vgg_backbone(client_data.to(device)).to('cpu')
            print('client {}\'s data is passed from backbone - {}'.format(client_idx, mode))

        if mode == 'train':
            shuffle = torch.randperm(client_data.shape[0])
            client_data, client_labels = client_data[shuffle], client_labels[shuffle]
        client_data = client_data.flatten(start_dim=1)

        if mode == 'train':
            # chunks of parallelization_param rows; the last chunk may be shorter
            client_chunks_data.extend(torch.split(client_data, parallelization_param))
            client_chunks_labels.extend(torch.split(client_labels, parallelization_param))
        else:
            client_chunks_data.append(client_data)
            client_chunks_labels.append(client_labels)

    print('data and labels are handled - {}'.format(mode))
    if mode != 'train':
        return [torch.cat(client_chunks_data)], [torch.cat(client_chunks_labels)]
    return _sample_train_batches(client_chunks_data, client_chunks_labels, batch_size,
                                 parallelization_param, number_of_iterations)

In [None]:
# MNIST pipeline without a feature backbone.
# NOTE: backbone=None yields flattened raw pixels, so the classifier's input width
# must equal train_data[0].shape[1] for this configuration.
SEED = 0
torch.manual_seed(SEED)  # seeds torch.randperm used for per-client shuffling
np.random.seed(SEED)     # seeds np.random.choice used for train-batch sampling

num_of_clients = 64
parallelization_param = 8
backbone = None
batch_size = 256
number_of_iterations = 300

transform = ToTensor()

train_dataset = MNIST('./data', train=True, transform=transform, download=True)
# a single batch holding the whole split: collect_augment_aggregate reads only the first batch
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
train_data, train_label = collect_augment_aggregate(train_loader, num_of_clients=num_of_clients, parallelization_param=parallelization_param, backbone=backbone,
                                                    batch_size=batch_size, number_of_iterations=number_of_iterations, mode='train')

test_dataset = MNIST('./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
# use the same backbone as for training so train and test features live in the same space
test_data, test_label = collect_augment_aggregate(test_loader, num_of_clients=num_of_clients, backbone=backbone, mode='test')

In [3]:
# CIFAR10 pipeline with the pretrained VGG16-BN feature backbone.
SEED = 0
torch.manual_seed(SEED)  # seeds torch.randperm used for per-client shuffling
np.random.seed(SEED)     # seeds np.random.choice used for train-batch sampling

num_of_clients = 64
parallelization_param = 8
backbone = 'vgg'
batch_size = 256
number_of_iterations = 300

transform = ToTensor()

train_dataset = CIFAR10('./data', train=True, transform=transform, download=True)
# a single batch holding the whole split: collect_augment_aggregate reads only the first batch
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
train_data, train_label = collect_augment_aggregate(train_loader, num_of_clients=num_of_clients, parallelization_param=parallelization_param, backbone=backbone,
                                                    batch_size=batch_size, number_of_iterations=number_of_iterations, mode='train')

test_dataset = CIFAR10('./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
test_data, test_label = collect_augment_aggregate(test_loader, num_of_clients=num_of_clients, mode='test', backbone=backbone)

Files already downloaded and verified
client 0's data is passed from backbone - train
client 1's data is passed from backbone - train
client 2's data is passed from backbone - train
client 3's data is passed from backbone - train
client 4's data is passed from backbone - train
client 5's data is passed from backbone - train
client 6's data is passed from backbone - train
client 7's data is passed from backbone - train
client 8's data is passed from backbone - train
client 9's data is passed from backbone - train
client 10's data is passed from backbone - train
client 11's data is passed from backbone - train
client 12's data is passed from backbone - train
client 13's data is passed from backbone - train
client 14's data is passed from backbone - train
client 15's data is passed from backbone - train
client 16's data is passed from backbone - train
client 17's data is passed from backbone - train
client 18's data is passed from backbone - train
client 19's data is passed from backbone 

In [6]:
class SimpleNetwork(nn.Module):
    """Two-layer MLP without biases: Linear -> ReLU -> Linear.

    Args:
        in_channel: input feature width.
        hidden_channel: width of the hidden layer.
        out_channel: number of output logits.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # registration order (lin_01, lin_02, act) fixes parameter order and state_dict keys
        self.lin_01 = nn.Linear(in_features=in_channel, out_features=hidden_channel, bias=False)
        self.lin_02 = nn.Linear(in_features=hidden_channel, out_features=out_channel, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the logits for input ``x`` (last dim must equal ``in_channel``)."""
        return self.lin_02(self.act(self.lin_01(x)))
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Derive the input width from the prepared features instead of hard-coding it:
# the old constant (25088) presumably matched only the VGG feature path, and the
# raw-pixel (backbone=None) pipeline produces a different width.
in_features = train_data[0].shape[1]
model = SimpleNetwork(in_features, 128, 10).to(device)
optim = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [7]:
def _count_correct(net, batches_x, batches_y):
    """Return (correct argmax predictions, samples seen) over paired validation batches.

    ``net`` is expected to already be in eval mode; gradients are disabled here.
    """
    correct, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in zip(batches_x, batches_y):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted_class = net(batch_x).argmax(dim=1)
            correct += torch.count_nonzero(batch_y == predicted_class)
            seen += batch_x.shape[0]
    return correct, seen


model.train()
for iter_idx, (data, label) in enumerate(zip(train_data, train_label)):
    data = data.to(device)
    label = label.to(device)

    # one optimisation step on this batch; the loss is logged every iteration
    optim.zero_grad()
    preds = model(data)
    loss = criterion(preds, label)
    print(loss.item())
    loss.backward()
    optim.step()

    # evaluate after the first step, every 10th step, and after the final step
    step = iter_idx + 1
    is_checkpoint = step == 1 or step % 10 == 0 or step == len(train_data)
    if not is_checkpoint:
        continue
    model.eval()
    tp_count, val_size = _count_correct(model, test_data, test_label)
    print('iter: {}, accuracy: {}'.format(step, (tp_count / val_size).item()))
    model.train()

2.3127899169921875
iter: 1, accuracy: 0.31219998002052307
2.118112325668335
1.7688369750976562
1.6250340938568115
1.719810962677002
1.473921537399292
1.2691431045532227
1.2311407327651978
1.219807744026184
1.3524233102798462
iter: 10, accuracy: 0.5724999904632568
1.0901464223861694
1.2750219106674194
1.1673383712768555
1.1840304136276245
1.090572476387024
1.0408058166503906
1.057960867881775
1.1159541606903076
1.0559642314910889
0.9705296754837036
iter: 20, accuracy: 0.6312999725341797
1.0131561756134033
1.086552619934082
1.0458146333694458
1.0292799472808838
1.0525169372558594
0.9937734603881836
1.1092076301574707
0.932860791683197
0.9111469984054565
1.0466974973678589
iter: 30, accuracy: 0.6464999914169312
0.9647875428199768
0.8693063855171204
0.9390324950218201
1.0867986679077148
0.997886061668396
0.9652329683303833
0.8694192171096802
0.9365697503089905
0.8729681968688965
1.0148671865463257
iter: 40, accuracy: 0.6567999720573425
1.0224374532699585
0.9648373126983643
0.88242757320404