In [1]:
import os
import struct
import socket
import pickle
import time

import h5py
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import Subset
from torch.autograd import Variable
import torch.nn.init as init
from model import ResNet18, ResNet50
from utils import get_metrics_
import copy
import random
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_path = '../cifar10_data'

# Setup cpu
# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = 'cuda:0'

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed) 


# Setup client order
client_order = int(0)
print('Client starts from: ', client_order)

num_train_data = 50000

# Load data
from random import shuffle

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

indices = list(range(50000))

part_tr = indices[num_train_data * client_order : num_train_data * (client_order + 1)]

train_set  = torchvision.datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform)
train_set_sub = Subset(train_set, part_tr)
train_loader = torch.utils.data.DataLoader(train_set_sub, batch_size=64, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

x_train, y_train = next(iter(train_loader))
print(f'Train batch shape x: {x_train.size()} y: {y_train.size()}')
total_batch = len(train_loader)
print(f'Num Batch {total_batch}')


Client starts from:  0
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../cifar10_data/cifar-10-python.tar.gz


170499072it [00:23, 7259710.49it/s]                               


Extracting ../cifar10_data/cifar-10-python.tar.gz to ../cifar10_data
Files already downloaded and verified
Train batch shape x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
Num Batch 782


In [3]:
resnet_model = ResNet50(channel=3, num_classes=10).to(device)
epoch = 20
lr = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet_model.parameters(), lr = lr, momentum = 0.9)

# Start training
print("Start training @ ", time.asctime())

for epc in range(epoch):
    start_time = time.time()    
    for i, data in enumerate(tqdm(train_loader, ncols=100, desc='Centralized training', disable=False)):
        x, label = data
        x = x.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = resnet_model(x)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # if (i + 1) % 100 == 0:
        #     # measure accuracy and record loss
        #     _, predicted = torch.max(output, 1)
        #     correct = (predicted == label).sum().item()
        #     accuracy = correct / len(label)
        #     print(f'Epoch: {epc+1}/{epoch}, Batch: {i+1}/{total_batch}, Train Loss: {round(loss.item(), 2)} Train Accuracy: {round(accuracy, 2)}')

        # if (i + 1) % 1000 == 0:
    test_loss, test_acc, test_auc, test_bal_acc = get_metrics_(resnet_model, test_loader, device)
    print(f'Epoch: {epc+1}/{epoch} => Test Loss: {round(test_loss, 4)} Test Accuracy: {round(test_acc, 4)} Test AUC: {round(test_auc, 4)} Test Balanced Accuracy: {round(test_bal_acc, 4)}')


    print(time.time() - start_time)

Start training @  Tue Nov 21 13:02:50 2023


Centralized training: 100%|███████████████████████████████████████| 782/782 [01:13<00:00, 10.66it/s]


Epoch: 1/20 => Test Loss: 2.0676 Test Accuracy: 0.2394 Test AUC: 0.7291 Test Balanced Accuracy: 0.2394
78.14146280288696


Centralized training:  24%|█████████▎                             | 187/782 [00:16<00:53, 11.16it/s]