In [1]:
import torch
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [2]:
def collect_augment_aggregate(loader, num_of_clients, num_classes=10):
    """Re-partition one batch across clients, normalize per client, and shuffle.

    Only the first batch yielded by ``loader`` is consumed, so the loader must
    deliver every sample that should end up in the result in a single batch.

    Processing steps:
      1. For each class label in ``range(num_classes)`` the matching samples
         are divided into exactly ``num_of_clients`` contiguous shards.
      2. Client ``i`` receives the ``i``-th shard of every class. Its data is
         standardized with the scalar mean and standard deviation computed
         over all of that client's values, then reshaped to
         ``(num_samples, -1)``.
      3. The clients' samples are concatenated and shuffled with a single
         random permutation (uses the global torch RNG).

    Samples whose label lies outside ``range(num_classes)`` are dropped.

    Args:
        loader: Iterable yielding ``(data, labels)`` tensor pairs.
        num_of_clients: Number of clients to spread each class over (>= 1).
        num_classes: Number of class labels to collect; labels are expected
            to be the integers ``0 .. num_classes - 1``.

    Returns:
        TensorDataset holding the normalized, flattened data and the labels.

    Raises:
        ValueError: If ``num_of_clients`` is smaller than 1.
    """
    if num_of_clients < 1:
        raise ValueError('num_of_clients must be a positive integer')

    data, labels = next(iter(loader))

    # torch.tensor_split always returns exactly `num_of_clients` shards whose
    # sizes differ by at most one. Splitting by a ceil-rounded chunk size can
    # return fewer shards than clients when the class size is not divisible,
    # which would make the per-client lookup below fail.
    class_data_shards = []
    class_label_shards = []
    for class_idx in range(num_classes):
        mask = labels == class_idx
        class_data_shards.append(torch.tensor_split(data[mask], num_of_clients))
        class_label_shards.append(torch.tensor_split(labels[mask], num_of_clients))

    clients_data = []
    clients_labels = []
    for client_idx in range(num_of_clients):
        client_data = torch.cat([shards[client_idx] for shards in class_data_shards])
        client_labels = torch.cat([shards[client_idx] for shards in class_label_shards])

        # Normalize with this client's own statistics (scalar mean / std).
        client_data = (client_data - torch.mean(client_data)) / torch.std(client_data)
        clients_data.append(client_data.reshape(client_data.shape[0], -1))
        clients_labels.append(client_labels)

    all_data = torch.cat(clients_data)
    all_labels = torch.cat(clients_labels)

    # Size the permutation from what was actually collected so the index is
    # always valid, even if some samples were dropped above.
    permutation = torch.randperm(all_data.shape[0])
    return TensorDataset(all_data[permutation], all_labels[permutation])

In [3]:
num_of_clients = 10
transform = ToTensor()

# Training split: read every sample as one batch, re-partition it across the
# clients, then serve the resulting dataset in mini-batches of 256.
train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataset = collect_augment_aggregate(
    DataLoader(train_dataset, batch_size=len(train_dataset.data)),
    num_of_clients,
)
train_loader = DataLoader(train_dataset, batch_size=256)

# Test split: same pipeline applied to the held-out images.
test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
test_dataset = collect_augment_aggregate(
    DataLoader(test_dataset, batch_size=len(test_dataset.data)),
    num_of_clients,
)
test_loader = DataLoader(test_dataset, batch_size=256)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



In [4]:
class SimpleNetwork(nn.Module):
    """Two-layer perceptron without bias terms.

    The forward pass is ``lin_02(ReLU(lin_01(x)))``.

    Args:
        in_channel: Size of each input sample.
        hidden_channel: Width of the hidden layer.
        out_channel: Size of each output sample.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Both projections are created with bias disabled.
        self.lin_01 = nn.Linear(in_features=in_channel, out_features=hidden_channel, bias=False)
        self.lin_02 = nn.Linear(in_features=hidden_channel, out_features=out_channel, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the output of the second linear layer for input ``x``."""
        hidden = self.lin_01(x)
        hidden = self.act(hidden)
        return self.lin_02(hidden)
    
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Network with 784 inputs, 128 hidden units and 10 outputs, trained with Adam
# against a cross-entropy objective.
model = SimpleNetwork(784, 128, 10)
model = model.to(device)
optim = Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [5]:
# One pass over the training batches. The loss is printed for every batch;
# accuracy over the whole test loader is printed on every 10th iteration
# (including the first) and on the final iteration.
model.train()
final_iter_idx = len(train_loader) - 1
for iter_idx, (data, label) in enumerate(train_loader):
    data = data.to(device)
    label = label.to(device)

    optim.zero_grad()
    preds = model(data)
    loss = criterion(preds, label)
    print(loss.item())
    loss.backward()
    optim.step()

    # Iteration 0 already satisfies the modulo test, so no separate check.
    if iter_idx % 10 != 0 and iter_idx != final_iter_idx:
        continue

    # Evaluation: switch off training-mode behaviour and gradient tracking.
    model.eval()
    tp_count = 0
    val_size = 0
    with torch.no_grad():
        for val_data, val_label in test_loader:
            val_data = val_data.to(device)
            val_label = val_label.to(device)
            val_preds = model(val_data).argmax(dim=1)
            # Accumulated as a tensor so the printed ratio keeps its format.
            tp_count += (val_preds == val_label).sum()
            val_size += len(val_data)
        print('accuracy: {}'.format((tp_count / val_size).item()))
    model.train()

2.4070942401885986
accuracy: 0.45969998836517334
2.002516269683838
1.7913089990615845
1.5901436805725098
1.4349408149719238
1.3121137619018555
1.0941272974014282
1.106682538986206
0.9620117545127869
0.9954046607017517
0.8277041912078857
accuracy: 0.705299973487854
0.8617054224014282
0.8327913284301758
0.8126285076141357
0.7945008873939514
0.6554155349731445
0.7608761191368103
0.6622105836868286
0.7318330407142639
0.702953577041626
0.6856306195259094
accuracy: 0.7428999543190002
0.7380772233009338
0.7523606419563293
0.6958063244819641
0.6586880087852478
0.6085900664329529
0.7472221255302429
0.5331209897994995
0.5522488951683044
0.5804660320281982
0.5950310826301575
accuracy: 0.774899959564209
0.5883497595787048
0.6424744129180908
0.6575775742530823
0.537688672542572
0.689792275428772
0.6021755337715149
0.5667117834091187
0.6010506749153137
0.660700798034668
0.5945851802825928
accuracy: 0.7888000011444092
0.5673952102661133
0.6250730156898499
0.5214952826499939
0.5491254925727844
0.54485