In [1]:
import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from utils_numpy import create_batch_data
from torchvision.models import VGG16_BN_Weights, vgg16_bn

In [2]:
def collect_augment_aggregate_data(dataset, dataloader, num_of_clients, num_of_classes):
    data, labels = next(iter(dataloader))
    data, labels = data.numpy(), labels.numpy()
    targets = dataset.targets
    if not isinstance(targets, np.ndarray):
        targets = np.asarray(targets)

    different_classes_data = []
    different_classes_labels = []
    for class_id in range(num_of_classes):
        different_classes_data.append(np.array_split(data[targets == class_id], num_of_clients))
        different_classes_labels.append(np.array_split(labels[targets == class_id], num_of_clients))
    
    client_data = []
    client_labels = []
    for client_idx in range(num_of_clients):
        client_data_buffer = []
        client_labels_buffer = []
        for class_idx in range(num_of_classes):
            client_data_buffer.append(different_classes_data[class_idx][client_idx])
            client_labels_buffer.append(different_classes_labels[class_idx][client_idx])

        client_data_buffer = np.concatenate(client_data_buffer)
        for channel_idx in range(client_data_buffer.shape[1]):
            client_data_buffer[:, channel_idx, :, :] = (client_data_buffer[:, channel_idx, :, :] - np.mean(client_data_buffer[:, channel_idx, :, :])) / np.std(client_data_buffer[:, channel_idx, :, :])

        client_data.append(client_data_buffer)
        client_labels.append(np.concatenate(client_labels_buffer))
    aggregated_data = np.concatenate(client_data)
    aggregated_labels = np.concatenate(client_labels)
    randomize = np.random.permutation(aggregated_data.shape[0])
    return aggregated_data[randomize], aggregated_labels[randomize]

In [3]:
def load_all_data_apply_vgg_cifar10(num_of_clients: int = 64):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    transform = Compose([
        ToTensor(),
    ])

    # load data
    train_dataset = CIFAR10('./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=train_dataset.data.shape[0], shuffle=False)
    test_dataset = CIFAR10('./data', train=False, transform=transform, download=True)
    test_loader = DataLoader(test_dataset, batch_size=test_dataset.data.shape[0], shuffle=False)


    all_train_data, all_train_labels = collect_augment_aggregate_data(train_dataset, train_loader, num_of_clients, 10)
    all_test_data, all_test_labels = collect_augment_aggregate_data(test_dataset, test_loader, num_of_clients, 10)
    all_train_data, all_train_labels, all_test_data, all_test_labels = create_batch_data(all_train_data, all_train_labels, all_test_data, all_test_labels, 256)

    vgg_backbone = vgg16_bn(weights=VGG16_BN_Weights.DEFAULT).eval()
    vgg_backbone = torch.nn.Sequential(*(list(vgg_backbone.children())[:-1])).to(device)


    with torch.no_grad():
        train_data_all, train_label_all, test_data_all, test_label_all = [], [], [], []
        for train_data, train_label in zip(all_train_data, all_train_labels):
            train_data = torch.tensor(train_data)
            train_data = train_data.to(device)
            train_data = vgg_backbone(train_data).reshape(train_data.size(0), -1).to('cpu').numpy()

            train_data_all.append(train_data)
            train_label_all.append(train_label)

        for test_data, test_label in zip(all_test_data, all_test_labels):
            test_label_all.append(test_label)
            test_data = torch.tensor(test_data)
            test_data = test_data.to(device)
            test_data = vgg_backbone(test_data).reshape(test_data.size(0), -1).to('cpu').numpy()
            test_data_all.append(test_data)

    train_data_all, train_label_all = np.concatenate(train_data_all, axis=0), np.concatenate(train_label_all, axis=0)
    test_data_all, test_label_all = np.concatenate(test_data_all, axis=0), np.concatenate(test_label_all, axis=0)
    train_data_all, train_label_all = torch.tensor(train_data_all), torch.tensor(train_label_all)
    test_data_all, test_label_all = torch.tensor(test_data_all), torch.tensor(test_label_all)
    train_dataset = TensorDataset(train_data_all, train_label_all)
    test_dataset = TensorDataset(test_data_all, test_label_all)
    return train_dataset, test_dataset

In [4]:
train_dataset, test_dataset = load_all_data_apply_vgg_cifar10()
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class SimpleNetwork(nn.Module):
    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lin_01 = nn.Linear(in_channel, hidden_channel, bias=False)
        self.lin_02 = nn.Linear(hidden_channel, out_channel, bias=False)
        self.act = nn.ReLU()
    
    def forward(self, x):
        out = self.act(self.lin_01(x))
        out = self.lin_02(out)
        return out
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleNetwork(25088, 128, 10).to(device)
optim = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [6]:
for epoch in range(300):
    model.train()
    for iter_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optim.zero_grad()
        preds = model(data)
        loss = criterion(preds, label)
        print(loss.item())
        loss.backward()
        optim.step()

        if epoch == 0 or (epoch % 10) == 0 or epoch == 195:
            model.eval()
            tp_count = 0
            val_size = 0
            with torch.no_grad():
                for val_iter_idx, (val_data, val_label) in enumerate(test_loader):
                    val_data, val_label = val_data.to(device), val_label.to(device)
                    val_preds = model(val_data)
                    val_preds = torch.argmax(val_preds, dim=1)
                    tp_count += torch.count_nonzero(val_label == val_preds)
                    val_size += val_data.shape[0]
                print('epoch: {}, accuracy: {}'.format(epoch + 1, (tp_count / val_size).item()))
            model.train()


2.3137998580932617
epoch: 1, accuracy: 0.4414999783039093
1.7017310857772827
2.028775691986084
2.1806061267852783
1.674379825592041
1.463227391242981
1.2385197877883911
1.1776175498962402
1.227717399597168
1.2208075523376465
1.1735552549362183
epoch: 11, accuracy: 0.609499990940094
1.1273291110992432
1.064523696899414
1.0056588649749756
0.9846765398979187
0.9945262670516968
1.0076960325241089
1.0059475898742676
0.988359272480011
0.9625404477119446
0.9365968704223633
epoch: 21, accuracy: 0.6699999570846558
0.916249692440033
0.9046121835708618
0.9003856182098389
0.8983786106109619
0.8948873281478882
0.889539897441864
0.8823285698890686
0.8728592991828918
0.8618923425674438
0.8518404960632324
epoch: 31, accuracy: 0.6972999572753906
0.8449824452400208
0.8414435386657715
0.839088499546051
0.8353558778762817
0.8292747735977173
0.8221278190612793
0.8155743479728699
0.8099056482315063
0.8055111169815063
0.8026041984558105
epoch: 41, accuracy: 0.7028999924659729
0.7997113466262817
0.79549759626