In [1]:
import os
import struct
import socket
import pickle
import time

import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import Subset
from torch.autograd import Variable
import torch.nn.init as init
from resnet import ResNet18, ResNet50
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score

import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_path = '../../models/cifar10_data'

# Setup cpu
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(777)

# Setup client order
client_order = int(0)
print('Client starts from: ', client_order)

num_train_data = 50000

# Load data
from random import shuffle

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

indices = list(range(50000))

part_tr = indices[num_train_data * client_order : num_train_data * (client_order + 1)]

train_set  = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform)
train_set_sub = Subset(train_set, part_tr)
train_loader = torch.utils.data.DataLoader(train_set_sub, batch_size=8, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)

x_train, y_train = next(iter(train_loader))
print(f'Train batch shape x: {x_train.size()} y: {y_train.size()}')
total_batch = len(train_loader)
print(f'Num Batch {total_batch}')


Client starts from:  0
Files already downloaded and verified
Files already downloaded and verified
Train batch shape x: torch.Size([8, 3, 32, 32]) y: torch.Size([8])
Num Batch 6250


In [3]:
def get_metrics(net, eval_loader, device):
    net.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum').to(device)
    
    with torch.no_grad():
        logits_all, targets_all = torch.tensor([], device='cpu'), torch.tensor([], dtype=torch.int, device='cpu')
        for x, y in eval_loader:
            x, y = x.to(device), y.to(device)
            logits = net(x)
            logits_all = torch.cat((logits_all, logits.detach().cpu()),dim=0)
            targets_all = torch.cat((targets_all, y.cpu()), dim=0)
    
    pred = F.log_softmax(logits_all, dim=1)
    loss = criterion(pred, targets_all)/len(eval_loader.dataset) # validation loss
    
    output = pred.argmax(dim=1) # predicated/output label
    prob = F.softmax(logits_all, dim=1) # probabilities

    acc = accuracy_score(y_pred=output.numpy(), y_true=targets_all.numpy())
    bal_acc = balanced_accuracy_score(y_pred=output.numpy(), y_true=targets_all.numpy())
    auc = roc_auc_score(targets_all.numpy(), prob.numpy(), multi_class='ovr')

    return loss.item(), acc, auc, bal_acc

In [4]:
resnet_model = ResNet50(channel=3, num_classes=10).to(device)
epoch = 1
lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet_model.parameters(), lr = lr, momentum = 0.9)

# Start training
start_time = time.time()
print("Start training @ ", time.asctime())

for epc in range(epoch):

    for i, data in enumerate(tqdm(train_loader, ncols=100, desc='Centralized training', disable=True)):
        x, label = data
        x = x.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = resnet_model(x)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:

            # measure accuracy and record loss
            _, predicted = torch.max(output, 1)
            correct = (predicted == label).sum().item()
            accuracy = correct / len(label)
            print(f'Epoch: {epc+1}/{epoch}, Batch: {i+1}/{total_batch}, Train Loss: {round(loss.item(), 2)} Train Accuracy: {round(accuracy, 2)}')

            test_loss, test_acc, test_auc, test_bal_acc = get_metrics(resnet_model, test_loader, device)
            print(f'                                           Test Loss: {round(test_loss, 2)} Test Accuracy: {round(test_acc, 2)} Test AUC: {round(test_auc, 2)} Test Balanced Accuracy: {round(test_bal_acc, 2)}')



Start training @  Thu Oct  5 09:59:48 2023
Epoch: 1/1, Batch: 1/6250, Train Loss: 2.4 Train Accuracy: 0.0
              Test Loss: 2.41 Test Accuracy: 0.1 Test AUC: 0.49 Test Balanced Accuracy: 0.1
Epoch: 1/1, Batch: 101/6250, Train Loss: 2.56 Train Accuracy: 0.0
              Test Loss: 2.57 Test Accuracy: 0.1 Test AUC: 0.57 Test Balanced Accuracy: 0.1
Epoch: 1/1, Batch: 201/6250, Train Loss: 2.63 Train Accuracy: 0.12
              Test Loss: 2.51 Test Accuracy: 0.1 Test AUC: 0.58 Test Balanced Accuracy: 0.1
Epoch: 1/1, Batch: 301/6250, Train Loss: 2.77 Train Accuracy: 0.0
              Test Loss: 2.59 Test Accuracy: 0.11 Test AUC: 0.58 Test Balanced Accuracy: 0.11
Epoch: 1/1, Batch: 401/6250, Train Loss: 2.05 Train Accuracy: 0.38
              Test Loss: 2.56 Test Accuracy: 0.1 Test AUC: 0.59 Test Balanced Accuracy: 0.1
Epoch: 1/1, Batch: 501/6250, Train Loss: 2.75 Train Accuracy: 0.0
              Test Loss: 2.57 Test Accuracy: 0.14 Test AUC: 0.6 Test Balanced Accuracy: 0.14
Epoch: 