In [1]:
import copy
import os
import pickle
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import DatasetSplit
from models.test import test_img_local, test_img_local_all_avg

import pdb
import easydict

import sys

import random



# Seed the Python, NumPy and PyTorch RNGs so fine-tuning runs are repeatable.
# NOTE(review): this seed (1) differs from `args.seed` (0) set in the experiment
# cells, where args.seed is used to build the checkpoint path; confirm the
# mismatch is intended.
SEED = 1
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [2]:
!pip install easydict


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


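# Initial vs. Personalized Accuracy (Full Fine-Tuning) — Shard Partition

For each client, the pretrained FedAvg global model is first evaluated on that client's test split ("Before"). A copy is then fine-tuned on all layers with the client's training split for `personalization_epoch` epochs and evaluated again ("After"). Clients are partitioned by shards, with `shard_per_user` ∈ {10, 50, 100}.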

In [3]:
# Shard-partition experiment: evaluate the pretrained FedAvg global model on each
# client's test split ("Before"), then fine-tune every layer on that client's
# training split and evaluate again ("After").
model = 'mobile' # cnn, mobile
dataset = 'cifar100' # cifar10, cifar100 
num_classes = 100 # 10, 100
momentum = 0.90
wd = 1e-05
personalization_epoch = 5 # fine-tuning epochs for personalization

server_data_ratio = 0.00  # not read in this cell

for shard_per_user in [10, 50, 100]:
    for frac in [0.1]:
        for local_ep in [5]:
            for local_upt_part in ['full']:
                args = easydict.EasyDict({'epochs': 320,
                                          'num_users': 100,
                                          'hetero_option': 'shard',
                                          'shard_per_user': shard_per_user,
                                          'frac': frac,
                                          'local_ep': local_ep,
                                          'local_bs': 50,
                                          'bs': 128,
                                          'lr': 5e-3,
                                          'momentum': momentum,
                                          'wd': wd,
                                          'lr_decay': 0.1,
                                          'model': model,
                                          'dataset': dataset,
                                          'iid': False,
                                          'num_classes': num_classes,
                                          'gpu': 0,
                                          'local_upt_part': local_upt_part,
                                          'seed': 0,
                                          'fn': True,
                                          'verbose': False,
                                          'feature_norm' : 1
                                          })

                # parse args
                args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
                base_dir = './save/full_and_body/{}_iid{}_num{}_C{}_le{}_m{}_wd{}/shard{}/FedAvg/'.format(
                    args.model, args.iid, args.num_users, args.frac, args.local_ep, args.momentum, args.wd, args.shard_per_user)
                # NOTE(review): 'lr_0.5' is hard-coded and does not follow args.lr (5e-3);
                # confirm it names the pretraining run that should be loaded.
                algo_dir = 'fn_{}/seed_{}/local_upt_{}_lr_0.5'.format(args.fn, args.seed, args.local_upt_part)

                dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
                # Replace the freshly generated client split with the one saved at
                # pretraining time, so every client keeps the data it had then.
                # pickle.load can execute arbitrary code: only open trusted files here.
                dict_save_path = 'dict_users_100_{}.pkl'.format(args.shard_per_user)
                with open(dict_save_path, 'rb') as handle:
                    dict_users_train, dict_users_test = pickle.load(handle)

                # build model
                net_glob = get_model(args)
                net_glob.train()

                # The pretrained global checkpoint is the same for every client, so
                # read it from disk once per configuration, not once per client.
                model_save_path = os.path.join(base_dir, algo_dir, 'best_model.pt')
                pretrained_state = torch.load(model_save_path, map_location=args.device)

                criterion = nn.CrossEntropyLoss()

                before_acc_results = []  # per-client test accuracy of the pretrained global model
                after_acc_results = []   # per-client test accuracy after local fine-tuning

                for user in range(args.num_users):
                    # Build one client model at a time instead of keeping num_users
                    # copies alive; net_glob is never trained, so every copy starts
                    # from the same weights.
                    net_local = copy.deepcopy(net_glob)
                    # strict=False: parameters absent from the checkpoint keep net_glob's values.
                    net_local.load_state_dict(pretrained_state, strict=False)
                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    before_acc_results.append(acc_test)
                    net_local.train()
                    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)

                    # Full update: "head" = parameters whose name contains 'linear',
                    # "body" = everything else; both use the same learning rate.
                    # NOTE(review): args.wd is not passed as weight_decay; confirm intended.
                    body_params = [p for name, p in net_local.named_parameters() if 'linear' not in name]
                    head_params = [p for name, p in net_local.named_parameters() if 'linear' in name]
                    optimizer = torch.optim.SGD([{'params': body_params, 'lr': args.lr},
                                                 {'params': head_params, 'lr': args.lr}],
                                                momentum=args.momentum)

                    # The loop variable must not be called `iter`: that would rebind the
                    # builtin iter() for the rest of the kernel session.
                    for _epoch in range(personalization_epoch):
                        for images, labels in ldr_train:
                            images, labels = images.to(args.device), labels.to(args.device)
                            net_local.zero_grad()
                            logits = net_local(images)

                            loss = criterion(logits, labels)
                            loss.backward()
                            optimizer.step()

                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    after_acc_results.append(acc_test)
                print("-----------------------------------------------------")
                print("local update part: {}".format(local_upt_part))
                print("shard: {}, frac: {}, local_ep: {}".format(shard_per_user, frac, local_ep))
                print("Before min/max/mean/std of accuracy")
                print(np.min(before_acc_results), np.max(before_acc_results), np.mean(before_acc_results), round(np.std(before_acc_results), 2))
                print("After min/max/mean/std of accuracy")
                print(np.min(after_acc_results), np.max(after_acc_results), np.mean(after_acc_results), round(np.std(after_acc_results), 2))
                print("-----------------------------------------------------")
                print ("-----------------------------------------------------")

Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
local update part: full
shard: 10, frac: 0.1, local_ep: 5
Before min/max/mean/std of accuracy
27.0 72.0 47.23 8.32
After min/max/mean/std of accuracy
63.0 96.0 82.02 6.04
-----------------------------------------------------
Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
local update part: full
shard: 50, frac: 0.1, local_ep: 5
Before min/max/mean/std of accuracy
39.0 62.0 49.72 4.71
After min/max/mean/std of accuracy
44.0 66.0 54.72 4.96
-----------------------------------------------------
Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
local update part: full
shard: 100, frac: 0.1, local_ep: 5
Before min/max/mean/std of accuracy
36.0 65.0 51.33 5.37
After min/max/mean/std of accuracy
39.0 66.0

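# Initial vs. Personalized Accuracy (Full Fine-Tuning) — LDA Partition

This is the same protocol as the shard experiment above, but clients are partitioned with the LDA option, using `alpha` ∈ {0.1, 0.5, 1.0}. For each client, the pretrained FedAvg global model is evaluated on the client's test split ("Before"). It is then fine-tuned on all layers with the client's training split and evaluated again ("After").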

In [4]:
# LDA-partition experiment: evaluate the pretrained FedAvg global model on each
# client's test split ("Before"), then fine-tune every layer on that client's
# training split and evaluate again ("After").
model = 'mobile' # cnn, mobile
dataset = 'cifar100' # cifar10, cifar100 
num_classes = 100 # 10, 100
momentum = 0.90
wd = 1e-05
personalization_epoch = 5 # fine-tuning epochs for personalization

server_data_ratio = 0.00  # not read in this cell

for alpha in [0.1, 0.5, 1.0]:
    for frac in [0.1]:
        for local_ep in [5]:
            for local_upt_part in ['full']:
                args = easydict.EasyDict({'epochs': 320,
                                          'num_users': 100,
                                          'hetero_option': 'lda',
                                          'alpha': alpha,
                                          'frac': frac,
                                          'local_ep': local_ep,
                                          'local_bs': 50,
                                          'bs': 128,
                                          'lr': 5e-3,
                                          'momentum': momentum,
                                          'wd': wd,
                                          'lr_decay': 0.1,
                                          'model': model,
                                          'dataset': dataset,
                                          'iid': False,
                                          'num_classes': num_classes,
                                          'gpu': 0,
                                          'local_upt_part': local_upt_part,
                                          'seed': 0,
                                          'fn': True,
                                          'verbose': False,
                                          'feature_norm' : 1
                                          })

                # parse args
                args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
                base_dir = './save/full_and_body/{}_iid{}_num{}_C{}_le{}_m{}_wd{}/alpha{}/FedAvg/'.format(
                    args.model, args.iid, args.num_users, args.frac, args.local_ep, args.momentum, args.wd, args.alpha)
                # NOTE(review): 'lr_0.5' is hard-coded and does not follow args.lr (5e-3);
                # confirm it names the pretraining run that should be loaded.
                algo_dir = 'fn_{}/seed_{}/local_upt_{}_lr_0.5'.format(args.fn, args.seed, args.local_upt_part)

                dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
                # Replace the freshly generated client split with the one saved at
                # pretraining time, so every client keeps the data it had then.
                # pickle.load can execute arbitrary code: only open trusted files here.
                dict_save_path = 'dict_users_lda_{}_100_pfl.pkl'.format(args.alpha)
                with open(dict_save_path, 'rb') as handle:
                    dict_users_train, dict_users_test = pickle.load(handle)

                # build model
                net_glob = get_model(args)
                net_glob.train()

                # The pretrained global checkpoint is the same for every client, so
                # read it from disk once per configuration, not once per client.
                model_save_path = os.path.join(base_dir, algo_dir, 'best_model.pt')
                pretrained_state = torch.load(model_save_path, map_location=args.device)

                criterion = nn.CrossEntropyLoss()

                before_acc_results = []  # per-client test accuracy of the pretrained global model
                after_acc_results = []   # per-client test accuracy after local fine-tuning

                for user in range(args.num_users):
                    # Build one client model at a time instead of keeping num_users
                    # copies alive; net_glob is never trained, so every copy starts
                    # from the same weights.
                    net_local = copy.deepcopy(net_glob)
                    # strict=False: parameters absent from the checkpoint keep net_glob's values.
                    net_local.load_state_dict(pretrained_state, strict=False)
                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    before_acc_results.append(acc_test)
                    net_local.train()
                    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)

                    # Full update: "head" = parameters whose name contains 'linear',
                    # "body" = everything else; both use the same learning rate.
                    # NOTE(review): args.wd is not passed as weight_decay; confirm intended.
                    body_params = [p for name, p in net_local.named_parameters() if 'linear' not in name]
                    head_params = [p for name, p in net_local.named_parameters() if 'linear' in name]
                    optimizer = torch.optim.SGD([{'params': body_params, 'lr': args.lr},
                                                 {'params': head_params, 'lr': args.lr}],
                                                momentum=args.momentum)

                    # The loop variable must not be called `iter`: that would rebind the
                    # builtin iter() for the rest of the kernel session.
                    for _epoch in range(personalization_epoch):
                        for images, labels in ldr_train:
                            images, labels = images.to(args.device), labels.to(args.device)
                            net_local.zero_grad()
                            logits = net_local(images)

                            loss = criterion(logits, labels)
                            loss.backward()
                            optimizer.step()

                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    after_acc_results.append(acc_test)
                print("-----------------------------------------------------")
                print("local update part: {}".format(local_upt_part))
                print("LDA: {}, frac: {}, local_ep: {}".format(alpha, frac, local_ep))
                print("Before min/max/mean/std of accuracy")
                print(np.min(before_acc_results), np.max(before_acc_results), np.mean(before_acc_results), round(np.std(before_acc_results), 2))
                print("After min/max/mean/std of accuracy")
                print(np.min(after_acc_results), np.max(after_acc_results), np.mean(after_acc_results), round(np.std(after_acc_results), 2))
                print("-----------------------------------------------------")
                print ("-----------------------------------------------------")

Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
local update part: full
LDA: 0.1, frac: 0.1, local_ep: 5
Before min/max/mean/std of accuracy
27.751196172248804 53.04740406320542 41.30332542022291 5.93
After min/max/mean/std of accuracy
58.65384615384615 80.75396825396825 69.50300894340913 5.28
-----------------------------------------------------
Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
local update part: full
LDA: 0.5, frac: 0.1, local_ep: 5
Before min/max/mean/std of accuracy
39.321357285429144 51.663405088062625 45.730643443517465 3.03
After min/max/mean/std of accuracy
45.211581291759465 57.95677799607073 51.759796349072296 3.14
-----------------------------------------------------
Files already downloaded and verified
Files already downloaded and verified
-----------------------------------------------------
loca