In [1]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [2]:
def collect_augment_aggregate(loader, num_of_clients, num_classes=10):
    """Partition one batch class-wise across clients, standardize per client, and re-merge.

    The first batch yielded by ``loader`` is taken as the whole dataset. For every
    class label in ``range(num_classes)`` the matching samples are dealt out to
    ``num_of_clients`` clients as evenly as possible (when the class size is not
    divisible, the first ``size % num_of_clients`` clients receive one extra
    sample). Each client's samples are then standardized with that client's own
    scalar mean and standard deviation, flattened to ``(n_samples, -1)``, and all
    clients are concatenated and randomly shuffled.

    Args:
        loader: Iterable whose first item is a ``(data, labels)`` pair of tensors
            with matching first dimension (e.g. a full-batch ``DataLoader``).
        num_of_clients: Number of clients to split every class across (>= 1).
        num_classes: Number of distinct integer labels, ``0 .. num_classes - 1``.
            Samples whose label lies outside that range are not included.

    Returns:
        TensorDataset holding the shuffled ``(flattened_data, labels)`` tensors.

    Raises:
        ValueError: If ``num_of_clients`` is smaller than 1.

    Note:
        A client that ends up with fewer than two values has an undefined
        standard deviation, so its standardized data will be NaN.
    """
    if num_of_clients < 1:
        raise ValueError("num_of_clients must be at least 1")

    data, labels = next(iter(loader))

    # per_class_*[class_idx][client_idx] -> that client's share of that class.
    per_class_data = []
    per_class_labels = []
    for class_idx in range(num_classes):
        mask = labels == class_idx
        # Boolean indexing copies, so select each class exactly once.
        class_data = data[mask]
        class_labels = labels[mask]
        # tensor_split always returns num_of_clients chunks and gives the first
        # (len % num_of_clients) chunks one extra element. This also covers the
        # evenly-divisible case and never duplicates or drops a sample.
        per_class_data.append(torch.tensor_split(class_data, num_of_clients))
        per_class_labels.append(torch.tensor_split(class_labels, num_of_clients))

    clients_data = []
    clients_labels = []
    for client_idx in range(num_of_clients):
        client_data = torch.concatenate([chunks[client_idx] for chunks in per_class_data])
        client_labels = torch.concatenate([chunks[client_idx] for chunks in per_class_labels])

        # Standardize with this client's own scalar statistics.
        client_data = (client_data - torch.mean(client_data)) / torch.std(client_data)
        clients_data.append(client_data.reshape(client_data.shape[0], -1))
        clients_labels.append(client_labels)

    merged_data = torch.concatenate(clients_data)
    merged_labels = torch.concatenate(clients_labels)
    # Size the permutation from what was actually merged so indexing is always valid.
    permutation = torch.randperm(merged_data.shape[0])
    return TensorDataset(merged_data[permutation], merged_labels[permutation])

In [3]:
num_of_clients = 64
transform = ToTensor()


def _full_batch_loader(dataset):
    """Return a DataLoader that yields all of ``dataset`` as a single batch."""
    return DataLoader(dataset, batch_size=len(dataset))


# Training split: read MNIST as one batch, partition/standardize it per client,
# then wrap the resulting TensorDataset in a full-batch loader again.
train_dataset = MNIST('./data', train=True, transform=transform, download=True)
train_loader = _full_batch_loader(train_dataset)
train_dataset = collect_augment_aggregate(train_loader, num_of_clients)
train_loader = _full_batch_loader(train_dataset)

# Test split: identical pipeline applied to the held-out images.
test_dataset = MNIST('./data', train=False, transform=transform, download=True)
test_loader = _full_batch_loader(test_dataset)
test_dataset = collect_augment_aggregate(test_loader, num_of_clients)
test_loader = _full_batch_loader(test_dataset)

In [4]:
class SimpleNetwork(nn.Module):
    """Two-layer, bias-free perceptron: Linear -> ReLU -> Linear.

    Args:
        in_channel: Number of input features per sample.
        hidden_channel: Width of the hidden layer.
        out_channel: Number of output values (logits) per sample.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Attribute names double as state-dict keys, so they are kept stable.
        self.lin_01 = nn.Linear(in_features=in_channel, out_features=hidden_channel, bias=False)
        self.lin_02 = nn.Linear(in_features=hidden_channel, out_features=out_channel, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_channel)`` to raw outputs of shape ``(..., out_channel)``."""
        hidden = self.lin_01(x)
        hidden = self.act(hidden)
        return self.lin_02(hidden)
    
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# 784 input features -> 128 hidden units -> 10 output logits.
model = SimpleNetwork(784, 128, 10)
model = model.to(device)
optim = Adam(params=model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [5]:
# Training loop: one optimizer step per batch of train_loader, printing the loss
# of every step, with a periodic accuracy check on test_loader.
for epoch in range(300):
    model.train()
    for iter_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optim.zero_grad()
        preds = model(data)
        loss = criterion(preds, label)
        print(loss.item())
        loss.backward()
        optim.step()

    # Evaluate once per epoch, after all training batches. Previously this block
    # sat inside the batch loop, which would repeat the full evaluation after
    # every mini-batch if train_loader yielded more than one batch.
    # NOTE(review): epoch 234 is an extra hand-picked checkpoint; confirm it is still wanted.
    if (epoch % 10) == 0 or epoch == 234:
        model.eval()
        tp_count = 0
        val_size = 0
        with torch.no_grad():
            for val_iter_idx, (val_data, val_label) in enumerate(test_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                val_preds = model(val_data)
                # Predicted class = index of the largest logit.
                val_preds = torch.argmax(val_preds, dim=1)
                tp_count += torch.count_nonzero(val_label == val_preds)
                val_size += val_data.shape[0]
            print('epoch: {}, accuracy: {}'.format(epoch+1, (tp_count / val_size).item()))
        model.train()

2.337559938430786
epoch: 1, accuracy: 0.41429999470710754
2.0882413387298584
1.8813409805297852
1.6887352466583252
1.5054283142089844
1.3369925022125244
1.1879539489746094
1.0588058233261108
0.9487942457199097
0.8575913310050964
0.7834738492965698
epoch: 11, accuracy: 0.8463000059127808
0.722058117389679
0.6693609356880188
0.6244329810142517
0.5872538685798645
0.5566978454589844
0.5309362411499023
0.5084362030029297
0.4884193539619446
0.4707747995853424
0.4555896818637848
epoch: 21, accuracy: 0.88919997215271
0.44261205196380615
0.4310884475708008
0.4203794002532959
0.4104326069355011
0.4014269709587097
0.3933083713054657
0.3858363926410675
0.37886616587638855
0.3723679482936859
0.3662857115268707
epoch: 31, accuracy: 0.9060999751091003
0.3605315089225769
0.35508760809898376
0.34997889399528503
0.34518036246299744
0.34063535928726196
0.3363131582736969
0.3321903347969055
0.32822850346565247
0.3244074583053589
0.3207532465457916
epoch: 41, accuracy: 0.9172999858856201
0.3172879219055176