In [1]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [2]:
def _standardize_shard(shard, eps=1e-8):
    """Shift a shard to zero mean and scale it to unit std over all its elements."""
    if shard.numel() == 0:
        # Nothing to normalise; mean/std of an empty tensor would be NaN.
        return shard
    centered = shard - torch.mean(shard)
    if shard.numel() < 2:
        # The (unbiased) std of a single element is undefined.
        return centered
    # Clamp so a constant shard yields zeros instead of a division by zero.
    return centered / torch.std(shard).clamp_min(eps)


def collect_augment_aggregate(loader, num_of_clients, generator=None):
    """Split one batch into per-client shards, normalise each shard, and re-merge.

    The first batch yielded by ``loader`` is taken as the whole data set.  The
    samples of every class are divided into ``num_of_clients`` contiguous,
    near-equal chunks; client ``i`` receives chunk ``i`` of every class.  Each
    client's samples are then standardised with that client's own scalar mean
    and std, flattened to 2-D, concatenated over clients and shuffled.

    Args:
        loader: Iterable whose first item is a ``(data, labels)`` pair of
            tensors.  ``labels`` is used as a 1-D mask source over dim 0 of
            ``data``; ``data`` must have a floating dtype for mean/std.
        num_of_clients: Number of shards to create; must be >= 1.
        generator: Optional ``torch.Generator`` driving the final shuffle so
            the result can be reproduced.  ``None`` uses the global RNG.

    Returns:
        ``TensorDataset`` of ``(data, labels)`` where ``data`` has shape
        ``(num_samples, num_features)``.

    Raises:
        ValueError: If ``num_of_clients`` < 1 or the batch holds no samples.
    """
    if num_of_clients < 1:
        raise ValueError("num_of_clients must be a positive integer")

    data, labels = next(iter(loader))
    if data.shape[0] == 0:
        raise ValueError("loader yielded an empty batch")

    # Explicit feature count: reshape(n, -1) is ambiguous when a shard is empty.
    feature_dim = 1
    for dim_size in data.shape[1:]:
        feature_dim *= dim_size

    # tensor_split always returns exactly num_of_clients chunks (some possibly
    # empty), whereas splitting by a ceiling chunk size can return fewer and
    # made the per-client indexing below run out of range.
    per_class_data = []
    per_class_labels = []
    for class_id in torch.unique(labels):
        mask = labels == class_id
        per_class_data.append(torch.tensor_split(data[mask], num_of_clients))
        per_class_labels.append(torch.tensor_split(labels[mask], num_of_clients))

    clients_data = []
    clients_labels = []
    for client_idx in range(num_of_clients):
        client_data = torch.concatenate([chunks[client_idx] for chunks in per_class_data])
        client_labels = torch.concatenate([chunks[client_idx] for chunks in per_class_labels])

        # Normalise with this client's own statistics.
        client_data = _standardize_shard(client_data)
        clients_data.append(client_data.reshape(client_data.shape[0], feature_dim))
        clients_labels.append(client_labels)

    merged_data = torch.concatenate(clients_data)
    merged_labels = torch.concatenate(clients_labels)

    # Shuffle so consecutive mini-batches are not grouped by client or class.
    order = torch.randperm(merged_data.shape[0], generator=generator)
    return TensorDataset(merged_data[order], merged_labels[order])

In [3]:
# Tunable settings for the data pipeline.
SEED = 0
BATCH_SIZE = 256
DATA_ROOT = './data'
num_of_clients = 10
transform = ToTensor()

# Seed the global RNG before any shuffling or weight initialisation so that a
# fresh "Restart & Run All" reproduces the same numbers.
torch.manual_seed(SEED)


def build_split(train):
    """Download one MNIST split, pass it through the client pipeline, and wrap it.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        ``(dataset, loader)`` where ``dataset`` is the output of
        ``collect_augment_aggregate`` and ``loader`` iterates it in
        mini-batches of ``BATCH_SIZE``.
    """
    raw_dataset = MNIST(DATA_ROOT, train=train, transform=transform, download=True)
    # A single batch spanning the whole split, because the aggregation step
    # only reads the first batch of the loader it is given.
    full_batch_loader = DataLoader(raw_dataset, batch_size=len(raw_dataset))
    aggregated = collect_augment_aggregate(full_batch_loader, num_of_clients)
    return aggregated, DataLoader(aggregated, batch_size=BATCH_SIZE)


train_dataset, train_loader = build_split(train=True)
test_dataset, test_loader = build_split(train=False)

In [4]:
class SimpleNetwork(nn.Module):
    """Bias-free two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_channel: Size of each input feature vector.
        hidden_channel: Width of the hidden layer.
        out_channel: Number of output values per sample.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Both projections are created without a bias term.
        self.lin_01 = nn.Linear(in_features=in_channel, out_features=hidden_channel, bias=False)
        self.lin_02 = nn.Linear(in_features=hidden_channel, out_features=out_channel, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the second layer's output for input ``x`` (no softmax applied)."""
        hidden = self.lin_01(x)
        hidden = self.act(hidden)
        return self.lin_02(hidden)
    
# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# 784 inputs, 128 hidden units, 10 outputs.
model = SimpleNetwork(in_channel=784, hidden_channel=128, out_channel=10)
model = model.to(device)

optim = Adam(params=model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [5]:
def _report_accuracy(network, loader):
    """Print the fraction of samples in ``loader`` that ``network`` classifies correctly.

    The network is switched to eval mode for the pass and put back into
    training mode before returning.
    """
    network.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch, targets in loader:
            batch, targets = batch.to(device), targets.to(device)
            guesses = torch.argmax(network(batch), dim=1)
            hits += torch.count_nonzero(targets == guesses)
            seen += batch.shape[0]
        print('accuracy: {}'.format((hits / seen).item()))
    network.train()


# Single pass over the training data, printing the loss of every mini-batch.
model.train()
for iter_idx, (data, label) in enumerate(train_loader):
    data = data.to(device)
    label = label.to(device)

    optim.zero_grad()
    preds = model(data)
    loss = criterion(preds, label)
    print(loss.item())
    loss.backward()
    optim.step()

    # Validate on the first step and every tenth step after it
    # (iter_idx == 0 already satisfies the modulo test).
    if iter_idx % 10 == 0:
        _report_accuracy(model, test_loader)

2.338310480117798
accuracy: 0.2565999925136566
2.1451141834259033
1.9621187448501587
1.8334674835205078
1.7286545038223267
1.5275076627731323
1.3906042575836182
1.31075918674469
1.1577472686767578
1.0699412822723389
0.9675514698028564
accuracy: 0.7976999878883362
0.8969026803970337
0.8791893124580383
0.7743867635726929
0.891302764415741
0.7087811231613159
0.7663559317588806
0.7447514533996582
0.6389551758766174
0.5624587535858154
0.6413402557373047
accuracy: 0.8646000027656555
0.5258104205131531
0.4942413866519928
0.5191867351531982
0.4634852111339569
0.5492941737174988
0.42594635486602783
0.46482303738594055
0.41044968366622925
0.4072994589805603
0.4142930507659912
accuracy: 0.8809999823570251
0.4920409917831421
0.45844316482543945
0.45901286602020264
0.44256994128227234
0.4438053071498871
0.4610747992992401
0.3446796238422394
0.3227398991584778
0.4465421736240387
0.39527902007102966
accuracy: 0.8892999887466431
0.3857544958591461
0.3420315086841583
0.3558228611946106
0.33703279495239