In [9]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [20]:
def _split_among_clients(items, num_of_clients):
    """Split ``items`` along dim 0 into ``num_of_clients`` chunks whose sizes differ by at most one."""
    base_size, remainder = divmod(items.shape[0], num_of_clients)
    body_len = base_size * num_of_clients
    # Every client first receives a contiguous slice of ``base_size`` samples
    # (an empty slice when there are fewer samples than clients).
    chunks = [
        items[client_idx * base_size:(client_idx + 1) * base_size]
        for client_idx in range(num_of_clients)
    ]
    # The leftover tail samples are handed out one each to the first
    # ``remainder`` clients, so no sample is dropped or duplicated.
    tail = items[body_len:]
    for extra_idx in range(remainder):
        chunks[extra_idx] = torch.cat([chunks[extra_idx], tail[extra_idx:extra_idx + 1]])
    return chunks


def collect_augment_aggregate(loader, num_of_clients, num_classes=10):
    """Partition one batch class-wise across clients, normalize per client, and re-aggregate.

    The first batch yielded by ``loader`` is taken as the whole dataset. For
    each class label in ``range(num_classes)`` the matching samples are split
    as evenly as possible among ``num_of_clients`` clients. Every client's
    samples are then standardized with that client's own scalar mean and
    standard deviation, flattened to ``(n_samples, -1)``, concatenated over
    all clients and randomly permuted.

    Args:
        loader: Iterable whose first item is a ``(data, labels)`` pair of
            tensors with matching first dimension. ``data`` must be a
            floating-point tensor (``torch.mean``/``torch.std`` are applied).
        num_of_clients: Number of client partitions; must be at least 1.
        num_classes: Number of class labels ``0 .. num_classes - 1`` to
            partition. Samples with labels outside this range are not
            included in the result. Defaults to 10.

    Returns:
        TensorDataset holding the flattened, per-client-normalized data and
        the corresponding labels in a shared random order.

    Raises:
        ValueError: If ``num_of_clients`` is smaller than 1.

    Note:
        The division by the per-client standard deviation is unguarded, so a
        client whose data is constant (or has fewer than two values) produces
        non-finite entries.
    """
    if num_of_clients < 1:
        raise ValueError('num_of_clients must be at least 1, got {}'.format(num_of_clients))

    data, labels = next(iter(loader))

    # partite_class_*[class_idx][client_idx] -> that client's share of the class.
    partite_class_data = []
    partite_class_labels = []
    for class_idx in range(num_classes):
        mask = labels == class_idx
        partite_class_data.append(_split_among_clients(data[mask], num_of_clients))
        partite_class_labels.append(_split_among_clients(labels[mask], num_of_clients))

    partite_clients_data = []
    partite_clients_labels = []
    for client_idx in range(num_of_clients):
        client_data = torch.cat([chunks[client_idx] for chunks in partite_class_data])
        client_labels = torch.cat([chunks[client_idx] for chunks in partite_class_labels])

        # Standardize with this client's own statistics (scalar mean / std over all of its values).
        client_data = (client_data - torch.mean(client_data)) / torch.std(client_data)
        partite_clients_data.append(client_data.reshape(client_data.shape[0], -1))
        partite_clients_labels.append(client_labels)

    all_data = torch.cat(partite_clients_data)
    all_labels = torch.cat(partite_clients_labels)
    # Permute over the aggregated length so data and labels always stay aligned.
    permutation = torch.randperm(all_data.shape[0])
    return TensorDataset(all_data[permutation], all_labels[permutation])

In [21]:
# Number of client partitions each split is divided into.
num_of_clients = 64
transform = ToTensor()

# Training split: load MNIST, pull the whole split through one full-size batch,
# rebuild it with collect_augment_aggregate, then serve it in batches of 256.
mnist_train = MNIST(root='./data', train=True, download=True, transform=transform)
train_full_batch_loader = DataLoader(mnist_train, batch_size=len(mnist_train.data))
train_dataset = collect_augment_aggregate(train_full_batch_loader, num_of_clients)
train_loader = DataLoader(train_dataset, batch_size=256)

# Test split: exactly the same pipeline applied to the held-out images.
mnist_test = MNIST(root='./data', train=False, download=True, transform=transform)
test_full_batch_loader = DataLoader(mnist_test, batch_size=len(mnist_test.data))
test_dataset = collect_augment_aggregate(test_full_batch_loader, num_of_clients)
test_loader = DataLoader(test_dataset, batch_size=256)

In [22]:
class SimpleNetwork(nn.Module):
    """Two bias-free linear layers with a ReLU between them."""

    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        """Build the layers.

        Args:
            in_channel: Number of input features of the first linear layer.
            hidden_channel: Number of features between the two linear layers.
            out_channel: Number of output features of the second linear layer.
            *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Created input side first; both layers omit the bias term.
        self.lin_01 = nn.Linear(in_features=in_channel, out_features=hidden_channel, bias=False)
        self.lin_02 = nn.Linear(in_features=hidden_channel, out_features=out_channel, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return ``lin_02(ReLU(lin_01(x)))``."""
        hidden = self.lin_01(x)
        hidden = self.act(hidden)
        return self.lin_02(hidden)
    
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Layer sizes: 784 input features, 128 hidden features, 10 outputs.
model = SimpleNetwork(in_channel=784, hidden_channel=128, out_channel=10)
model = model.to(device)
optim = Adam(params=model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [23]:
# Single pass over the training loader, printing the loss of every batch and
# the accuracy on the full test loader at regular intervals.
model.train()
final_iter_idx = len(train_loader) - 1
for iter_idx, (data, label) in enumerate(train_loader):
    data = data.to(device)
    label = label.to(device)

    # One optimisation step: reset gradients, forward pass, loss, backward pass, update.
    optim.zero_grad()
    preds = model(data)
    loss = criterion(preds, label)
    print(loss.item())
    loss.backward()
    optim.step()

    # Evaluate only on every 10th iteration (which includes the first) and on the last one.
    if iter_idx % 10 != 0 and iter_idx != final_iter_idx:
        continue

    model.eval()
    tp_count = 0
    val_size = 0
    with torch.no_grad():
        for val_data, val_label in test_loader:
            val_data = val_data.to(device)
            val_label = val_label.to(device)
            # The predicted class is the index of the largest output along dim 1.
            val_preds = model(val_data).argmax(dim=1)
            tp_count += torch.count_nonzero(val_label == val_preds)
            val_size += val_data.shape[0]
        # tp_count is a tensor here, so the ratio is computed with tensor arithmetic.
        print('accuracy: {}'.format((tp_count / val_size).item()))
    # Return to training mode before the next optimisation step.
    model.train()

2.3138599395751953
accuracy: 0.3732999861240387
2.1472933292388916
1.9934217929840088
1.8154860734939575
1.6793084144592285
1.5445636510849
1.4322694540023804
1.241387128829956
1.1057384014129639
1.045878529548645
0.9724445343017578
accuracy: 0.802899956703186
0.9450538754463196
0.8578817844390869
0.7757722735404968
0.7151757478713989
0.6618653535842896
0.6178776025772095
0.5667965412139893
0.6775199174880981
0.5910530090332031
0.5643151998519897
accuracy: 0.8718000054359436
0.566311776638031
0.4710420072078705
0.5314793586730957
0.4957160949707031
0.5266109108924866
0.5244624018669128
0.49615031480789185
0.3854353427886963
0.43622592091560364
0.47242966294288635
accuracy: 0.8848999738693237
0.4319853186607361
0.4906245768070221
0.3570694625377655
0.4795605540275574
0.4790980815887451
0.3911655843257904
0.481953501701355
0.37130844593048096
0.4266935884952545
0.44516521692276
accuracy: 0.8956999778747559
0.4594241976737976
0.3859485685825348
0.400579035282135
0.3426133990287781
0.37199