In [2]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [3]:
def collect_augment_aggregate(loader, num_of_clients):
    data, labels = next(iter(loader))

    partite_class_data = []
    partite_class_labels = []
    for class_idx in range(10):
        mask = labels == class_idx
        len_mask = torch.count_nonzero(mask)
        remainder = len_mask % num_of_clients
        split_param = len_mask // num_of_clients
        partite_class_data.append(list(torch.split(data[mask][:-remainder], split_param)))
        partite_class_labels.append(list(torch.split(labels[mask][:-remainder], split_param)))
        for remainder_idx in range(remainder):
            partite_class_data[-1][remainder_idx] = torch.concatenate([partite_class_data[-1][remainder_idx], data[mask][-remainder].unsqueeze(0)])
            partite_class_labels[-1][remainder_idx] = torch.concatenate([partite_class_labels[-1][remainder_idx], labels[mask][-remainder].unsqueeze(0)])

    partite_clients_data = []
    partite_clients_labels = []
    for client_idx in range(num_of_clients):
        client_data_buffer = []
        client_labels_buffer = []
        for class_idx in range(10):
            client_data_buffer.append(partite_class_data[class_idx][client_idx])
            client_labels_buffer.append(partite_class_labels[class_idx][client_idx])
        client_data_buffer = torch.concatenate(client_data_buffer)
        client_labels_buffer = torch.concatenate(client_labels_buffer)
        
        #normalize client data
        client_data_buffer = (client_data_buffer - torch.mean(client_data_buffer)) / torch.std(client_data_buffer)
        partite_clients_data.append(client_data_buffer.reshape(client_data_buffer.shape[0], -1))
        partite_clients_labels.append(client_labels_buffer)
    permute_data = torch.randperm(data.shape[0])
    data = torch.concatenate(partite_clients_data)[permute_data]
    labels = torch.concatenate(partite_clients_labels)[permute_data]
    dataset = TensorDataset(data, labels)
    return dataset

In [4]:
num_of_clients = 64
transform = ToTensor()

train_dataset = MNIST('./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=train_dataset.data.shape[0])
train_dataset = collect_augment_aggregate(train_loader, num_of_clients)
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))

test_dataset = MNIST('./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=test_dataset.data.shape[0])
test_dataset = collect_augment_aggregate(test_loader, num_of_clients)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

In [5]:
class SimpleNetwork(nn.Module):
    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lin_01 = nn.Linear(in_channel, hidden_channel, bias=False)
        self.lin_02 = nn.Linear(hidden_channel, out_channel, bias=False)
        self.act = nn.ReLU()
    
    def forward(self, x):
        out = self.act(self.lin_01(x))
        out = self.lin_02(out)
        return out
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleNetwork(784, 128, 10).to(device)
optim = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [6]:
for epoch in range(235):
    model.train()
    for iter_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optim.zero_grad()
        preds = model(data)
        loss = criterion(preds, label)
        print(loss.item())
        loss.backward()
        optim.step()

        if epoch == 0 or (epoch % 10) == 0 or epoch == 234:
            model.eval()
            tp_count = 0
            val_size = 0
            with torch.no_grad():
                for val_iter_idx, (val_data, val_label) in enumerate(test_loader):
                    val_data, val_label = val_data.to(device), val_label.to(device)
                    val_preds = model(val_data)
                    val_preds = torch.argmax(val_preds, dim=1)
                    tp_count += torch.count_nonzero(val_label == val_preds)
                    val_size += val_data.shape[0]
                print('epoch: {}, accuracy: {}'.format(epoch+1, (tp_count / val_size).item()))
            model.train()

2.314077615737915
epoch: 1, accuracy: 0.4007999897003174
2.0720551013946533
1.8705565929412842
1.6868267059326172
1.5117104053497314
1.3476401567459106
1.2000890970230103
1.0719810724258423
0.9629901647567749
0.8716666102409363
0.7954801321029663
epoch: 11, accuracy: 0.8425999879837036
0.7314024567604065
0.6774482131004333
0.6322057247161865
0.5940883755683899
0.5618354082107544
0.5345187187194824
0.5110664963722229
0.49079862236976624
0.47356918454170227
0.45871326327323914
epoch: 21, accuracy: 0.8876000046730042
0.4452778100967407
0.4330521523952484
0.42214709520339966
0.4123145043849945
0.40324607491493225
0.39489027857780457
0.3871544897556305
0.3799211382865906
0.3732284903526306
0.36706581711769104
epoch: 31, accuracy: 0.906499981880188
0.3612806499004364
0.35580405592918396
0.35067427158355713
0.3458377718925476
0.34121036529541016
0.3368082642555237
0.33264192938804626
0.3286502957344055
0.3248051702976227
0.32112884521484375
epoch: 41, accuracy: 0.9162999987602234
0.3176092207