In [11]:
import copy
import pickle
import tqdm

import torch
from torch import nn
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset

import numpy as np
import random

from models import ResNet50
from models_CCVR import feature_extractor, classifier
from update import LocalUpdate, test_inference

In [12]:
# parameters
iid = 1 # if the data is i.i.d or not
unbalanced = 0 # in non i.i.d. setting split the data between clients equally or not
num_users = 100 # number of client
frac = 0.1 # fraction of the clients to be used for federated updates
n_epochs = 100
gpu = 0
optimizer = "sgd" #sgd or adam
local_batch_size = 10 # batch size of local updates in each user
lr = 0.001 # learning rate
local_epochs = 10
loss_function = "CrossEntropyLoss"

# percentage = 90  # percentage of strugglers

num_groups = 0  # 0 for BatchNorm, > 0 for GroupNorm
if num_groups == 0:
    normalization_type = "BatchNorm"
else:
    normalization_type = "GroupNorm"

In [13]:
if iid:
    from utils_v2 import get_dataset, average_weights, weighted_average_weights, exp_details
else:
    from utils import get_dataset, average_weights, weighted_average_weights, exp_details

In [14]:
# for REPRODUCIBILITY https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(0)

generator = torch.Generator()
generator.manual_seed(0)

np.random.seed(0)

In [15]:
class DatasetSplit(Dataset):
    """
    An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)

In [16]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [17]:
def get_dataloader(dataset, idxs):
    trainloader = DataLoader(DatasetSplit(dataset, idxs),
                             batch_size=None, shuffle=True, generator=generator,
                             worker_init_fn=seed_worker)

    return trainloader

In [18]:
train_dataset, test_dataset, user_groups = get_dataset(iid=iid, unbalanced=unbalanced,
                                                       num_users=num_users)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
model = ResNet50(n_type=normalization_type)
# model = CNNCifar()

train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0
else:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1

model.to(device)

CUDA is available!  Training on GPU ...


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (

In [20]:
# filename_pt = "fedAVG_results/weighted_average/ResNet50_100_sgd_lr_[0.001]_C[0.1]_iid[0]_unbalanced[1]_E[1]_B[10]_BatchNorm_numGroups[0].pt"

# filename_pt = "fedAVG_results/ResNet50_100_sgd_lr_[0.001]_C[0.1]_iid[0]_unbalanced[0]_E[1]_B[10]_BatchNorm_numGroups[0].pt"

filename_pt = "fedAVG_results/ResNet50_100_sgd_lr_[0.001]_C[0.1]_iid[1]_unbalanced[0]_E[1]_B[10]_BatchNorm_numGroups[0].pt"

# load saved model (i.e. the one with the smallest validation loss)
model.load_state_dict(torch.load(filename_pt))

<All keys matched successfully>

In [21]:
f = feature_extractor()
f.to(device)

f.load_state_dict(torch.load(filename_pt))

<All keys matched successfully>

In [22]:
# def test_inference(model, test_dataset, gpu=1, local_batch_size=10, loss_function="NLLLoss"):
#     """
#     Returns the test accuracy and loss.
#     """
#
#     model.eval()
#     test_loss = 0.0
#     class_correct = list(0. for i in range(10))
#     class_total = list(0. for i in range(10))
#
#     device = 'cuda' if gpu else 'cpu'
#     if loss_function == "NLLLoss":
#         criterion = nn.NLLLoss()
#     if loss_function == "CrossEntropyLoss":
#         criterion = nn.CrossEntropyLoss()
#
#     testloader = DataLoader(test_dataset, batch_size=local_batch_size,
#                             shuffle=False, generator=generator)
#
#     for images, labels in testloader:
#         images, labels = images.to(device), labels.to(device)
#
#         # Inference
#         output = model(images)
#         loss = criterion(output, labels)
#         test_loss += (loss.data.item() * images.shape[0])
#
#         # Prediction
#         # convert output probabilities to predicted class
#         _, pred = torch.max(output, 1)
#         # compare predictions to true label
#         correct_tensor = pred.eq(labels.data.view_as(pred))
#         correct = np.squeeze(correct_tensor.numpy()) if not gpu else np.squeeze(correct_tensor.cpu().numpy())
#
#         for i in range(len(images)):
#             label = labels.data[i]
#             class_correct[label] += correct[i].item()
#             class_total[label] += 1
#
#     # average test loss
#     test_loss = test_loss / len(testloader.dataset)
#
#     accuracy = np.sum(class_correct) / np.sum(class_total)
#
#     return accuracy, test_loss
#
#
# # test the trained model
#
# test_acc, test_loss = test_inference(model=model, test_dataset=test_dataset, gpu=gpu,
#                                      loss_function=loss_function)
#
# print(f'\nResults after {n_epochs} global rounds of training:')
# print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

In [23]:
# import torch.nn.functional as F
#
# f = nn.Sequential(
#     # stop at conv4
#     *list(model.children())[:-1]
# )

In [24]:
summary(f, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]           4,096
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
            Conv2d-7          [-1, 256, 32, 32]          16,384
       BatchNorm2d-8          [-1, 256, 32, 32]             512
            Conv2d-9          [-1, 256, 32, 32]          16,384
      BatchNorm2d-10          [-1, 256, 32, 32]             512
       Bottleneck-11          [-1, 256, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          16,384
      BatchNorm2d-13           [-1, 64, 32, 32]             128
           Conv2d-14           [-1, 64,

In [25]:
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]           4,096
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
            Conv2d-7          [-1, 256, 32, 32]          16,384
       BatchNorm2d-8          [-1, 256, 32, 32]             512
            Conv2d-9          [-1, 256, 32, 32]          16,384
      BatchNorm2d-10          [-1, 256, 32, 32]             512
       Bottleneck-11          [-1, 256, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          16,384
      BatchNorm2d-13           [-1, 64, 32, 32]             128
           Conv2d-14           [-1, 64,

In [26]:
# from torchvision.models.feature_extraction import get_graph_node_names
#
# train_nodes, eval_nodes = get_graph_node_names(model)

In [27]:
# return_nodes_featureExtractor = {
#     # node_name: user-specified key for output dict
#     'layer1.2.relu_2': 'layer1',
#     'layer2.3.relu_2': 'layer2',
#     'layer3.5.relu_2': 'layer3',
#     'layer4.2.relu_2': 'layer4'
# }
#
# return_nodes_Classifier = {
#     'avg_pool2d' : 'avg_pool',
#     'size' : 'size',
#     'view' : 'view',
#     'linear' : 'linear'
# }

In [28]:
# from torchvision.models.feature_extraction import create_feature_extractor
#
# f = create_feature_extractor(model, train_return_nodes=train_nodes[:-3], eval_return_nodes=eval_nodes[:-3])
# # f = create_feature_extractor(model, return_nodes_featureExtractor)
# g = create_feature_extractor(model, return_nodes_Classifier)

In [29]:
# f(torch.rand((1, 3, 32, 32)).to(device))["avg_pool2d"].reshape(-1).shape

In [30]:
# g(torch.rand((1, 3, 32, 32)).to(device))

In [31]:
def clientUpdate(f, dataset, idxs, device):

    trainloader = get_dataloader(dataset, idxs)

    d = {}
    sum_ = 0

    # extract features by category
    for batch_idx, (image, label) in enumerate(trainloader):
        sum_ += 1

        image = image.to(device)
        label = int(label.cpu())

        # feature = (f(image.reshape(1, 3, 32, 32))["avg_pool2d"].reshape(-1)).cpu().detach()
        # feature = (f(image.reshape(1, 3, 32, 32))["layer4.2.relu_2"].reshape(-1)).cpu().detach()
        feature = (f(image.reshape(1, 3, 32, 32)).reshape(-1)).cpu().detach()

        if label in d.keys():
            d[label].append(feature)
        else:
            d[label] = [feature]

    # mu, sigma
    upload_d = {}

    # for k, v in tqdm.tqdm(d.items()):
    #     v_item = torch.stack(v).detach().cpu()
    #
    #     # consider the case where the sample size is too small to upload
    #     if len(v_item) < 10:
    #         continue
    #
    #     mu, sigma = v_item.mean(dim=0), v_item.var(dim=0)
    #     upload_d[k] = {"mu": mu, "sigma": sigma, "N": len(v)}

    for k, v in tqdm.tqdm(d.items()):
        v_item = torch.stack(v).detach().cpu()
        # consider the case where the sample size is too small to upload
        if len(v_item) < 10:
            continue

        N = len(v)

        mu = v_item.mean(dim=0)

        sigma = torch.zeros((2048, 2048))
        for t in v_item:
            x = (t - mu).reshape(1, 2048)
            sigma = torch.add(torch.mul(x, torch.transpose(x, 1, 0)), sigma)
        sigma *= 1/(N-1)

        upload_d[k] = {"mu": mu, "sigma": sigma, "N": N}

    return upload_d

In [32]:
upload_d = clientUpdate(f, train_dataset, user_groups[0], device)

  return torch.tensor(image), torch.tensor(label)
100%|██████████| 10/10 [00:03<00:00,  3.26it/s]


In [33]:
upload_d.keys()

dict_keys([9, 5, 3, 4, 8, 7, 6, 0, 1, 2])

In [34]:
upload_d[0]["mu"].shape

torch.Size([2048])

In [35]:
upload_d[0]["sigma"].shape

torch.Size([2048, 2048])

In [36]:
m = max(int(frac * num_users), 1) # number of users to be used for federated updates, at least 1
# idxs_users = np.random.choice(range(num_users), m, replace=False) # choose randomly m users

idxs_users = range(num_users)

upload_d_list = [ clientUpdate(f, train_dataset, user_groups[idx], device) for idx in idxs_users ]

  return torch.tensor(image), torch.tensor(label)
100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
100%|██████████| 10/10 [00:02<00:00,  3.53it/s]
100%|██████████| 10/10 [00:02<00:00,  3.50it/s]
100%|██████████| 10/10 [00:03<00:00,  3.20it/s]
100%|██████████| 10/10 [00:02<00:00,  3.46it/s]
100%|██████████| 10/10 [00:02<00:00,  3.52it/s]
100%|██████████| 10/10 [00:03<00:00,  3.20it/s]
100%|██████████| 10/10 [00:03<00:00,  3.29it/s]
100%|██████████| 10/10 [00:02<00:00,  3.48it/s]
100%|██████████| 10/10 [00:03<00:00,  3.27it/s]
100%|██████████| 10/10 [00:03<00:00,  3.22it/s]
100%|██████████| 10/10 [00:03<00:00,  2.97it/s]
100%|██████████| 10/10 [00:03<00:00,  3.04it/s]
100%|██████████| 10/10 [00:03<00:00,  3.22it/s]
100%|██████████| 10/10 [00:03<00:00,  3.30it/s]
100%|██████████| 10/10 [00:03<00:00,  3.16it/s]
100%|██████████| 10/10 [00:03<00:00,  3.21it/s]
100%|██████████| 10/10 [00:03<00:00,  3.30it/s]
100%|██████████| 10/10 [00:03<00:00,  3.06it/s]
100%|██████████| 10/10 [00:03<00:00,  

In [37]:
print(len(upload_d_list))
upload_d_list[5]

100


{5: {'mu': tensor([0.8084, 0.8075, 0.8563,  ..., 0.8329, 0.9784, 0.8318]),
  'sigma': tensor([[ 0.0117,  0.0036,  0.0019,  ...,  0.0023,  0.0024,  0.0064],
          [ 0.0036,  0.0065,  0.0024,  ...,  0.0020,  0.0025,  0.0032],
          [ 0.0019,  0.0024,  0.0112,  ..., -0.0017,  0.0030,  0.0040],
          ...,
          [ 0.0023,  0.0020, -0.0017,  ...,  0.0124,  0.0002,  0.0019],
          [ 0.0024,  0.0025,  0.0030,  ...,  0.0002,  0.0084,  0.0036],
          [ 0.0064,  0.0032,  0.0040,  ...,  0.0019,  0.0036,  0.0169]]),
  'N': 50},
 6: {'mu': tensor([0.8137, 0.8196, 0.8324,  ..., 0.8494, 0.9741, 0.8529]),
  'sigma': tensor([[ 0.0127,  0.0034,  0.0010,  ...,  0.0010,  0.0043,  0.0022],
          [ 0.0034,  0.0069,  0.0011,  ...,  0.0032,  0.0030,  0.0028],
          [ 0.0010,  0.0011,  0.0115,  ..., -0.0016,  0.0005,  0.0049],
          ...,
          [ 0.0010,  0.0032, -0.0016,  ...,  0.0110,  0.0042,  0.0033],
          [ 0.0043,  0.0030,  0.0005,  ...,  0.0042,  0.0107,  0.005

In [38]:
def server_aggregate_stat(upload_d_list):
    # statistical feature distribution for each label
    fd_d = {}

    for l in range(10):  # for each label

        print("label : ", l)

        # clients do not necessarily have all tags, heterogeneous
        labeled_fd_lst = [ x for x in upload_d_list if l in x.keys() ]
        sum_n = sum(x[l]["N"] for x in labeled_fd_lst)

        mu_lst = [fd[l]["mu"] * fd[l]["N"] / sum_n for fd in labeled_fd_lst]
        mu = torch.stack(mu_lst).sum(dim=0)
        # print(mu.shape)
        #
        # sigma1 = torch.stack(
        #     [fd[l]["mu"] * (fd[l]["N"] - 1) / (sum_n - 1) for fd in labeled_fd_lst]
        # ).sum(dim=0)
        #
        # sigma2 = torch.stack(
        #     [
        #         fd[l]["mu"] * fd[l]["mu"].T * fd[l]["N"] / (sum_n - 1)
        #         for fd in labeled_fd_lst
        #     ]
        # ).sum(dim=0)
        #
        # sigma = sigma1 + sigma2 - sum_n / (sum_n - 1) * mu * mu.T

        sigma1 = torch.stack(
            [ (fd[l]["N"] - 1) / (sum_n - 1) * fd[l]["sigma"] for fd in labeled_fd_lst ]
        ).sum(dim=0)
        # print(sigma1.shape)

        sigma2 = torch.stack(
            [
                fd[l]["mu"].reshape(1, 2048) * torch.transpose(fd[l]["mu"].reshape(1, 2048), 1, 0) * fd[l]["N"] / (sum_n - 1)
                for fd in labeled_fd_lst
            ]
        ).sum(dim=0)
        # print(sigma2.shape)

        sigma3 = sum_n / (sum_n - 1) * mu.reshape(1, 2048) * torch.transpose(mu.reshape(1, 2048), 1, 0)
        # print(sigma3.shape)

        sigma = sigma1 + sigma2 - sigma3
        # print(sigma.shape)


        virtual_samples = np.random.default_rng().multivariate_normal(mu, sigma, check_valid='ignore', size=1000, tol=1e-6, method='eigh')
        fd_d[l] = torch.tensor(virtual_samples)


        # generate data samples with batchsize of 1k, there are 10 categories in total, so it is 10k samples
        # dist_c = np.random.normal(mu, sigma, size=(1000, mu.size()[0]))
        # fd_d[l] = torch.tensor(dist_c)

    return fd_d

In [39]:
# fd_d = {}
# for l in range(10):  # for each label
#     print("label : ", l)
#
#     # clients do not necessarily have all tags, heterogeneous
#     labeled_fd_lst = [ x for x in upload_d_list if l in x.keys() ]
#     sum_n = sum(x[l]["N"] for x in labeled_fd_lst)
#
#     mu_lst = [fd[l]["mu"] * fd[l]["N"] / sum_n for fd in labeled_fd_lst]
#     mu = torch.stack(mu_lst).sum(dim=0)
#     # print(mu.shape)
#
#     sigma1 = torch.stack(
#         [ (fd[l]["N"] - 1) / (sum_n - 1) * fd[l]["sigma"] for fd in labeled_fd_lst ]
#     ).sum(dim=0)
#     # print(sigma1.shape)
#
#     sigma2 = torch.stack(
#         [
#             fd[l]["mu"].reshape(1, 2048) * torch.transpose(fd[l]["mu"].reshape(1, 2048), 1, 0) * fd[l]["N"] / (sum_n - 1)
#             for fd in labeled_fd_lst
#         ]
#     ).sum(dim=0)
#     # print(sigma2.shape)
#
#     sigma3 = sum_n / (sum_n - 1) * mu.reshape(1, 2048) * torch.transpose(mu.reshape(1, 2048), 1, 0)
#     # print(sigma3.shape)
#
#     sigma = sigma1 + sigma2 - sum_n / (sum_n - 1) * mu * mu.T
#     # print(sigma.shape)
#
#
#     virtual_samples = np.random.default_rng().multivariate_normal(mu, sigma, check_valid='ignore', size=100, tol=1e-6, method='eigh')
#     fd_d[l] = torch.tensor(virtual_samples)


In [None]:
fd_d = server_aggregate_stat(upload_d_list)

label :  0
label :  1
label :  2
label :  3


In [None]:
# for k, v in fd_d.items():
#     print("label", k)
#     print(v.shape)

In [None]:
filename_pikle = "CCVR_results/extracted_features_iid[{}]_unbalanced[{}]_.pickle".format(iid, unbalanced)

with open(filename_pikle, 'wb') as handle:
    pickle.dump(fd_d, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
with open(filename_pikle, 'rb') as handle:
    fd_d = pickle.load(handle)

In [None]:
class DictDataset(Dataset):
    def __init__(self, label_data_d):
        "Initialization"
        self.data, self.labels = [], []
        for label, data in label_data_d.items():
            self.data.append(data)
            self.labels.append(torch.tensor([label] * len(data)))

        self.data = torch.cat(self.data).type(torch.float32)
        self.labels = torch.cat(self.labels).type(torch.long)

    def __len__(self):
        "Denotes the total number of samples"
        return len(self.data)

    def __getitem__(self, index):
        "Generates one sample of data"
        return self.data[index], self.labels[index]

In [None]:
def get_dataloader(trainset, testset, batch_size, num_workers=0, pin_memory=False):
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    testloader = DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return trainloader, testloader

In [None]:
# import torch.nn.functional as F
#
# class classifier(nn.Module):
#     def __init__(self):
#         super(classifier, self).__init__()
#
#         # network ends a 10-way fully-connected layer
#         expansion = 4
#         num_classes = 10
#
#         self.linear = nn.Linear(512 * expansion, num_classes)
#
#     def forward(self, x):
#         # out = F.avg_pool2d(x, 4)  # average pooling before fully connected layer
#         out = x.view(x.size(0), -1)
#         out = self.linear(out)
#         return out

In [None]:
g = classifier()
g.to(device)
summary(g, (2048, 1, 1))

In [None]:
# fd_d[0][0].shape

In [None]:
# fd_d[0][0].reshape((2048, 1, 1)).shape

In [None]:
# x = fd_d[0].reshape((1000, 2048, 1, 1)).type(torch.float).to(device)
# x.shape

In [None]:
# g(x)

In [None]:
from torch.optim import lr_scheduler

# optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.SGD(g.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

criterion = nn.CrossEntropyLoss()

# Decay LR by a factor of 0.1 every 7 epochs
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# prepare datasets, models, optimizers, and more
trainset = DictDataset(fd_d)
train_loader, _ = get_dataloader(trainset, trainset, batch_size=64)

In [None]:
# for epoch in range(1, n_epochs+1):
for epoch in range(1, 100):
    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    correct_train = 0.0
    correct_valid = 0.0

    ###################
    # train the model #
    ###################
    g.train()

    for k, v in fd_d.items():

        data, target = v.type(torch.float).cuda(), torch.full((1000,), k).cuda()

        # clear the gradients of all optimized variables
        optimizer.zero_grad()

        # forward pass: compute predicted outputs by passing inputs to the model
        output = g(data.reshape((1000, 2048, 1, 1)))

        # calculate the batch loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += (loss.data.item() * data.shape[0])
        # print('outputs on which to apply torch.max ', prediction)
        # find the maximum along the rows, use dim=1 to torch.max()
        _, predicted_outputs = torch.max(output.data, 1)
        # Update the running corrects
        correct_train += (predicted_outputs == target).float().sum().item()

    scheduler.step()

    # calculate average losses
    train_loss = train_loss / len(train_loader.sampler)
    # calculate accuracies
    train_acc =  correct_train / len(train_loader.sampler)

    print('Epoch: {} \tTraining Loss: {:.3f} \tTraining Accuracy: {:.3f}'.format(
        epoch, train_loss, train_acc))

In [None]:
def test_inference(f, g, test_dataset, gpu=1, local_batch_size=10, loss_function="CrossEntropyLoss"):
    """
    Returns the test accuracy and loss.
    """

    f.eval()
    g.eval()

    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    device = 'cuda' if gpu else 'cpu'
    if loss_function == "NLLLoss":
        criterion = nn.NLLLoss()
    if loss_function == "CrossEntropyLoss":
        criterion = nn.CrossEntropyLoss()

    testloader = DataLoader(test_dataset, batch_size=10,
                            shuffle=False, generator=generator)

    for images, labels in testloader:
    # for k, v in fd_d.items():
        images, labels = images.to(device), labels.to(device)
        # images, labels = v.type(torch.float).cuda().reshape(1000, 2048, 1, 1), torch.full((1000,), k).cuda()

        # Inference
        # output_f = f(images)["avg_pool2d"]
        output_f = f(images.cuda())
        output_g = g(output_f)

        loss = criterion(output_g, labels)
        test_loss += (loss.data.item() * images.shape[0])

        # Prediction
        # convert output probabilities to predicted class
        _, pred = torch.max(output_g, 1)
        # compare predictions to true label
        correct_tensor = pred.eq(labels.data.view_as(pred))
        correct = np.squeeze(correct_tensor.numpy()) if not gpu else np.squeeze(correct_tensor.cpu().numpy())

        for i in range(len(images)):
            label = labels.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # average test loss
    test_loss = test_loss / len(testloader.dataset)

    accuracy = np.sum(class_correct) / np.sum(class_total)

    return accuracy, test_loss

In [None]:
# test the trained model

test_acc, test_loss = test_inference(f, g, test_dataset=test_dataset, gpu=gpu,
                                     loss_function=loss_function)

print(f'\nResults after {n_epochs} global rounds of training:')
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))