In [1]:
import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from utils_numpy import create_batch_data
from torchvision.models import VGG16_BN_Weights, vgg16_bn

In [2]:
def collect_augment_aggregate_data(dataset, dataloader, num_of_clients, num_of_classes):
    data, labels = next(iter(dataloader))
    data, labels = data.numpy(), labels.numpy()
    targets = dataset.targets
    if not isinstance(targets, np.ndarray):
        targets = np.asarray(targets)

    different_classes_data = []
    different_classes_labels = []
    for class_id in range(num_of_classes):
        different_classes_data.append(np.array_split(data[targets == class_id], num_of_clients))
        different_classes_labels.append(np.array_split(labels[targets == class_id], num_of_clients))
    
    client_data = []
    client_labels = []
    for client_idx in range(num_of_clients):
        client_data_buffer = []
        client_labels_buffer = []
        for class_idx in range(num_of_classes):
            client_data_buffer.append(different_classes_data[class_idx][client_idx])
            client_labels_buffer.append(different_classes_labels[class_idx][client_idx])

        client_data_buffer = np.concatenate(client_data_buffer)
        for channel_idx in range(client_data_buffer.shape[1]):
            client_data_buffer[:, channel_idx, :, :] = (client_data_buffer[:, channel_idx, :, :] - np.mean(client_data_buffer[:, channel_idx, :, :])) / np.std(client_data_buffer[:, channel_idx, :, :])

        client_data.append(client_data_buffer)
        client_labels.append(np.concatenate(client_labels_buffer))
    aggregated_data = np.concatenate(client_data)
    aggregated_labels = np.concatenate(client_labels)
    randomize = np.random.permutation(aggregated_data.shape[0])
    return aggregated_data[randomize], aggregated_labels[randomize]

In [3]:
def load_all_data_apply_vgg_cifar10(num_of_clients: int = 64):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    transform = Compose([
        ToTensor(),
    ])

    # load data
    train_dataset = CIFAR10('./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=train_dataset.data.shape[0], shuffle=False)
    test_dataset = CIFAR10('./data', train=False, transform=transform, download=True)
    test_loader = DataLoader(test_dataset, batch_size=test_dataset.data.shape[0], shuffle=False)


    all_train_data, all_train_labels = collect_augment_aggregate_data(train_dataset, train_loader, num_of_clients, 10)
    all_test_data, all_test_labels = collect_augment_aggregate_data(test_dataset, test_loader, num_of_clients, 10)
    all_train_data, all_train_labels, all_test_data, all_test_labels = create_batch_data(all_train_data, all_train_labels, all_test_data, all_test_labels, 256)

    vgg_backbone = vgg16_bn(weights=VGG16_BN_Weights.DEFAULT).eval()
    vgg_backbone = torch.nn.Sequential(*(list(vgg_backbone.children())[:-1])).to(device)


    with torch.no_grad():
        train_data_all, train_label_all, test_data_all, test_label_all = [], [], [], []
        for train_data, train_label in zip(all_train_data, all_train_labels):
            train_data = torch.tensor(train_data)
            train_data = train_data.to(device)
            train_data = vgg_backbone(train_data).reshape(train_data.size(0), -1).to('cpu').numpy()

            train_data_all.append(train_data)
            train_label_all.append(train_label)

        for test_data, test_label in zip(all_test_data, all_test_labels):
            test_label_all.append(test_label)
            test_data = torch.tensor(test_data)
            test_data = test_data.to(device)
            test_data = vgg_backbone(test_data).reshape(test_data.size(0), -1).to('cpu').numpy()
            test_data_all.append(test_data)

    train_data_all, train_label_all = np.concatenate(train_data_all, axis=0), np.concatenate(train_label_all, axis=0)
    test_data_all, test_label_all = np.concatenate(test_data_all, axis=0), np.concatenate(test_label_all, axis=0)
    train_data_all, train_label_all = torch.tensor(train_data_all), torch.tensor(train_label_all)
    test_data_all, test_label_all = torch.tensor(test_data_all), torch.tensor(test_label_all)
    train_dataset = TensorDataset(train_data_all, train_label_all)
    test_dataset = TensorDataset(test_data_all, test_label_all)
    return train_dataset, test_dataset

In [4]:
train_dataset, test_dataset = load_all_data_apply_vgg_cifar10()
train_loader = DataLoader(train_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class SimpleNetwork(nn.Module):
    def __init__(self, in_channel, hidden_channel, out_channel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lin_01 = nn.Linear(in_channel, hidden_channel, bias=False)
        self.lin_02 = nn.Linear(hidden_channel, out_channel, bias=False)
        self.act = nn.ReLU()
    
    def forward(self, x):
        out = self.act(self.lin_01(x))
        out = self.lin_02(out)
        return out
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleNetwork(25088, 128, 10).to(device)
optim = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [6]:
model.train()
for iter_idx, (data, label) in enumerate(train_loader):
    data, label = data.to(device), label.to(device)
    optim.zero_grad()
    preds = model(data)
    loss = criterion(preds, label)
    print(loss.item())
    loss.backward()
    optim.step()

    if iter_idx == 0 or (iter_idx % 10) == 0 or iter_idx == (len(train_loader) - 1):
        model.eval()
        tp_count = 0
        val_size = 0
        with torch.no_grad():
            for val_iter_idx, (val_data, val_label) in enumerate(test_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                val_preds = model(val_data)
                val_preds = torch.argmax(val_preds, dim=1)
                tp_count += torch.count_nonzero(val_label == val_preds)
                val_size += val_data.shape[0]
            print('accuracy: {}'.format((tp_count / val_size).item()))
        model.train()

2.3148081302642822
accuracy: 0.2500999867916107
2.710838794708252
2.044719696044922
1.8755098581314087
1.5792943239212036
1.3938508033752441
1.5692689418792725
1.2746644020080566
1.2006922960281372
1.2303580045700073
1.1494613885879517
accuracy: 0.5781999826431274
1.2506020069122314
1.1496819257736206
1.196250081062317
1.0836868286132812
1.1545140743255615
1.1584733724594116
1.0421581268310547
1.0416914224624634
1.0890171527862549
1.1498979330062866
accuracy: 0.6280999779701233
0.9315188527107239
0.9759005308151245
1.1548951864242554
0.9148009419441223
1.0554624795913696
1.0986195802688599
0.971846878528595
0.9741944074630737
0.8796263337135315
0.9802176356315613
accuracy: 0.6552000045776367
0.9796242713928223
1.0421265363693237
1.0232962369918823
1.049885869026184
1.0244455337524414
1.0467958450317383
0.8718031048774719
1.0104317665100098
0.9148290157318115
1.0242151021957397
accuracy: 0.6635000109672546
0.9040927290916443
0.9184916615486145
0.8869460821151733
0.9979096055030823
0.927