In [None]:
import copy
import os
import pickle
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import DatasetSplit
from models.test import test_img_local, test_img_local_all, test_img_avg_all, test_img_ensemble_all

import pdb
import easydict

# Global/Personalized model analysis (Balanced Dataset)

In [None]:
# Shared settings for the balanced-dataset sweep below.
model = 'mobile'  # cnn, mobile
dataset = 'cifar100'  # cifar10, cifar100
num_classes = 100  # 10, 100
momentum = 0.90
wd = 0.0
personalization_epoch = 5  # fine-tuning epochs for personalization

server_data_ratio = 0.05

# For each configuration: restore the saved user split and the global
# checkpoint, then for every user (1) test the global model on that user's
# test indices, (2) fine-tune a private copy on the user's training indices
# for `personalization_epoch` epochs, and (3) test it again. The
# min/max/mean/std of the before/after accuracies are printed per configuration.
#
# Values used in earlier sweeps:
#   shard_per_user: 10, 5, 2 (cifar10) / 100, 50, 10 (cifar100)
#   frac: 1.0, 0.1
#   local_ep: 1, 4, 10
#   (local_upt_part, aggr_part): ('body', 'body'), ('head', 'head'),
#       ('full', 'body'), ('full', 'head'), ('full', 'full')
for shard_per_user in [10]:
    for frac in [0.1]:
        for local_ep in [10]:
            for local_upt_part, aggr_part in [('full', 'full'), ('body', 'body')]:
                args = easydict.EasyDict({'epochs': local_ep,
                                          'num_users': 100,
                                          'shard_per_user': shard_per_user,
                                          'server_data_ratio': server_data_ratio,
                                          'frac': frac,
                                          'local_ep': local_ep,
                                          'local_bs': 50,
                                          'bs': 128,
                                          'lr': 1e-3,
                                          'momentum': momentum,
                                          'wd': wd,
                                          'split': 'user',
                                          'grad_norm': False,
                                          'local_ep_pretrain': 0,
                                          'lr_decay': 1.0,
                                          'model': model,
                                          'kernul_num': 9,
                                          'kernul_sizes': '3,4,5',
                                          'norm': 'batch_norm',
                                          'num_filters': 32,
                                          'max_pool': 'True',
                                          'num_layers_keep': 1,
                                          'dataset': dataset,
                                          'iid': False,
                                          'num_classes': num_classes,
                                          'num_channels': 3,
                                          'gpu': 1,
                                          'stopping_rounds': 10,
                                          'verbose': False,
                                          'print_freq': 100,
                                          'seed': 1,
                                          'test_freq': 1,
                                          'load_fed': '',
                                          'results_save': 'run1',  # run1, se1, head_se1, body_se1
                                          'start_saving': 0,
                                          'local_upt_part': local_upt_part,
                                          'aggr_part': aggr_part,
                                          'unbalanced': False,
                                          })

                # Use GPU `args.gpu` when CUDA is available (and gpu != -1), else the CPU.
                args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

                base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}_m{}_wd{}/shard{}_sdr{}/{}/'.format(
                    args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep,
                    args.momentum, args.wd, args.shard_per_user, args.server_data_ratio, args.results_save)
                algo_dir = 'local_upt_{}_aggr_{}'.format(args.local_upt_part, args.aggr_part)
                run_dir = os.path.join(base_dir, algo_dir)

                # get_data builds the datasets; its user -> index split is then
                # replaced by the split pickled for this run, so every user is
                # evaluated on the indices saved alongside the checkpoint.
                dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
                dict_save_path = os.path.join(run_dir, 'dict_users.pkl')
                with open(dict_save_path, 'rb') as handle:
                    dict_users_train, dict_users_test = pickle.load(handle)

                # Template model; each user gets a fresh deep copy of it below.
                net_glob = get_model(args)
                net_glob.train()

                # Fine-tuning uses args.lr for both body and head whichever part
                # was updated locally during training; two knobs are kept so they
                # can be tuned independently.
                if args.local_upt_part not in ('body', 'head', 'full'):
                    raise ValueError('unknown local_upt_part: {}'.format(args.local_upt_part))
                body_lr = head_lr = args.lr

                criterion = nn.CrossEntropyLoss()

                # Every user starts from the same global checkpoint, so read it
                # from disk once. map_location lets a checkpoint written on a
                # different device load onto args.device (e.g. a CPU-only host).
                # To start from per-user checkpoints instead, load
                # os.path.join(run_dir, 'best_local_{}.pt'.format(user)) per user.
                model_save_path = os.path.join(run_dir, 'best_model.pt')
                global_state = torch.load(model_save_path, map_location=args.device)

                before_acc_results = []
                after_acc_results = []

                for user in range(args.num_users):
                    # Build each user's model on demand instead of keeping
                    # num_users deep copies alive at once. load_state_dict copies
                    # the tensors, so fine-tuning never alters global_state.
                    net_local = copy.deepcopy(net_glob)
                    net_local.load_state_dict(global_state, strict=True)
                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    before_acc_results.append(acc_test)

                    net_local.train()
                    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)

                    # Parameters whose name contains 'linear' form the head; the rest form the body.
                    body_params = [p for name, p in net_local.named_parameters() if 'linear' not in name]
                    head_params = [p for name, p in net_local.named_parameters() if 'linear' in name]
                    optimizer = torch.optim.SGD([{'params': body_params, 'lr': body_lr},
                                                 {'params': head_params, 'lr': head_lr}],
                                                momentum=args.momentum)

                    for _ in range(personalization_epoch):
                        for images, labels in ldr_train:
                            images, labels = images.to(args.device), labels.to(args.device)
                            net_local.zero_grad()
                            logits = net_local(images)

                            loss = criterion(logits, labels)
                            loss.backward()
                            optimizer.step()

                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    after_acc_results.append(acc_test)

                print("-----------------------------------------------------")
                print("local update part: {}, aggregation part: {}".format(local_upt_part, aggr_part))
                print("shard: {}, frac: {}, local_ep: {}".format(shard_per_user, frac, local_ep))
                print("Before min/max/mean/std of accuracy")
                print(np.min(before_acc_results), np.max(before_acc_results), np.mean(before_acc_results), round(np.std(before_acc_results), 2))
                print("After min/max/mean/std of accuracy")
                print(np.min(after_acc_results), np.max(after_acc_results), np.mean(after_acc_results), round(np.std(after_acc_results), 2))
                print("-----------------------------------------------------")

# Global/Personalized model analysis (Unbalanced Dataset / IID)

In [None]:
# Shared settings for the unbalanced / IID sweep below.
model = 'mobile'  # mobile
dataset = 'cifar100'  # cifar100
num_classes = 100  # 100

# Root directory holding the saved runs for this analysis. Set it to './save'
# to read runs saved relative to the working directory instead.
save_root = '/home/osilab7/hdd/jhoon_backup/FL_local_upt_aggr/save'

# For each configuration: restore the saved user split, then for every user
# (1) test that user's saved local checkpoint on the user's test indices,
# (2) fine-tune it on the user's training indices, and (3) test it again.
# The mean before/after accuracies are printed per configuration.
# Earlier sweeps used shard_per_user in 100, 50, 10 (cifar100) and frac 1.0, 0.1.
for shard_per_user in [2]:
    for frac in [0.1]:
        for local_ep in [4]:
            for local_upt_part, aggr_part in [('body', 'body'), ('full', 'full')]:
                args = easydict.EasyDict({'epochs': local_ep,
                                          'num_users': 100,
                                          'shard_per_user': shard_per_user,
                                          'frac': frac,
                                          'local_ep': local_ep,
                                          'local_bs': 50,
                                          'bs': 128,
                                          'lr': 1e-3,
                                          'momentum': 0.9,
                                          'split': 'user',
                                          'grad_norm': False,
                                          'local_ep_pretrain': 0,
                                          'lr_decay': 1.0,
                                          'model': model,
                                          'kernul_num': 9,
                                          'kernul_sizes': '3,4,5',
                                          'norm': 'batch_norm',
                                          'num_filters': 32,
                                          'max_pool': 'True',
                                          'num_layers_keep': 1,
                                          'dataset': dataset,
                                          'iid': True,
                                          'num_classes': num_classes,
                                          'num_channels': 3,
                                          'gpu': 1,
                                          'stopping_rounds': 10,
                                          'verbose': False,
                                          'print_freq': 100,
                                          'seed': 1,
                                          'test_freq': 1,
                                          'load_fed': '',
                                          'results_save': 'run1',
                                          'start_saving': 0,
                                          'local_upt_part': local_upt_part,
                                          'aggr_part': aggr_part,
                                          'unbalanced': True,
                                          'num_batch_users': 25,
                                          'moved_data_size': 200,
                                          })

                # Use GPU `args.gpu` when CUDA is available (and gpu != -1), else the CPU.
                args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

                base_dir = '{}/{}/{}_iid{}_num{}_C{}_le{}/shard{}_unbalanced_bu{}_md{}/{}/'.format(
                    save_root, args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep,
                    args.shard_per_user, args.num_batch_users, args.moved_data_size, args.results_save)
                algo_dir = 'local_upt_{}_aggr_{}'.format(args.local_upt_part, args.aggr_part)
                run_dir = os.path.join(base_dir, algo_dir)

                # get_data builds the datasets; its user -> index split is then
                # replaced by the split pickled for this run.
                dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
                dict_save_path = os.path.join(run_dir, 'dict_users.pkl')
                with open(dict_save_path, 'rb') as handle:
                    dict_users_train, dict_users_test = pickle.load(handle)

                # Template model; each user gets a fresh deep copy of it below.
                net_glob = get_model(args)
                net_glob.train()

                # Fine-tuning uses args.lr for both body and head whichever part
                # was updated locally; the two knobs allow tuning them separately.
                if args.local_upt_part not in ('body', 'head', 'full'):
                    raise ValueError('unknown local_upt_part: {}'.format(args.local_upt_part))
                body_lr = head_lr = args.lr

                criterion = nn.CrossEntropyLoss()

                # NOTE(review): this cell fine-tunes for args.epochs (== local_ep)
                # epochs, whereas the balanced-data analysis uses
                # personalization_epoch — confirm this difference is intended.
                finetune_epochs = args.epochs

                before_acc_results = []
                after_acc_results = []

                for user in range(args.num_users):
                    # Build each user's model on demand instead of keeping
                    # num_users deep copies alive at once.
                    net_local = copy.deepcopy(net_glob)
                    model_save_path = os.path.join(run_dir, 'best_local_{}.pt'.format(user))
                    # map_location lets a checkpoint written on another device
                    # load onto args.device.
                    net_local.load_state_dict(torch.load(model_save_path, map_location=args.device), strict=True)
                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    before_acc_results.append(acc_test)

                    net_local.train()
                    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users_train[user]), batch_size=args.local_bs, shuffle=True)

                    # Parameters whose name contains 'linear' form the head; the rest form the body.
                    body_params = [p for name, p in net_local.named_parameters() if 'linear' not in name]
                    head_params = [p for name, p in net_local.named_parameters() if 'linear' in name]
                    optimizer = torch.optim.SGD([{'params': body_params, 'lr': body_lr},
                                                 {'params': head_params, 'lr': head_lr}],
                                                momentum=args.momentum)

                    for _ in range(finetune_epochs):
                        for images, labels in ldr_train:
                            images, labels = images.to(args.device), labels.to(args.device)
                            net_local.zero_grad()
                            logits = net_local(images)

                            loss = criterion(logits, labels)
                            loss.backward()
                            optimizer.step()

                    acc_test, _ = test_img_local(net_local, dataset_test, args, user_idx=user, idxs=dict_users_test[user])
                    after_acc_results.append(acc_test)

                print("-----------------------------------------------------")
                print("shard: {}, frac: {}, local_ep: {}".format(shard_per_user, frac, local_ep))
                print(np.mean(before_acc_results), np.mean(after_acc_results))
                print("-----------------------------------------------------")