In [1]:
from torchvision import datasets, transforms

from sampling import iid, shard, lda, lda_test

import numpy as np
import sys
import pickle

# MNIST: convert to tensor, then normalize the single channel.
trans_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# CIFAR-10 training pipeline: padded random crop + horizontal flip augmentation,
# followed by tensor conversion and per-channel normalization.
# NOTE(review): these mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-10 ones — confirm this is intentional.
trans_cifar10_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# CIFAR-10 evaluation pipeline: no augmentation, same normalization as training.
trans_cifar10_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# CIFAR-100 training pipeline: same augmentation as CIFAR-10, with its own
# normalization constants.
trans_cifar100_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.507, 0.487, 0.441],
                         std=[0.267, 0.256, 0.276]),
])

# CIFAR-100 evaluation pipeline: no augmentation, same normalization as training.
trans_cifar100_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.507, 0.487, 0.441],
                         std=[0.267, 0.256, 0.276]),
])


In [2]:
# CIFAR-100 train split (augmented transform); downloaded on first use.
dataset_train = datasets.CIFAR100(
    root='../data/cifar100',
    train=True,
    download=True,
    transform=trans_cifar100_train,
)
# CIFAR-100 test split (evaluation transform, no augmentation).
dataset_test = datasets.CIFAR100(
    root='../data/cifar100',
    train=False,
    download=True,
    transform=trans_cifar100_val,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def lda_test(dataset_train, dataset_test, num_users, dict_train):
    """Build a per-client test split that mirrors each client's train label histogram.

    For every client, the number of train samples it holds per label is
    counted, and the same number of test-set indices is drawn at random
    (without replacement within that label) from the test samples carrying
    the label.

    Args:
        dataset_train: Object exposing ``targets``, the sequence of train labels.
        dataset_test: Object exposing ``targets``, the sequence of test labels.
        num_users: Number of clients; ids ``0 .. num_users - 1`` are looked up
            in ``dict_train``.
        dict_train: Mapping ``client id -> indices into dataset_train`` held by
            that client.

    Returns:
        dict: ``client id -> shuffled list of int indices into dataset_test``.

    Raises:
        KeyError: If an id in ``range(num_users)`` is missing from ``dict_train``.

    Notes:
        * Clients are sampled independently of each other, so the same test
          index may be assigned to several clients.
        * If the test set holds fewer samples of a label than the client has in
          its train split, every available test sample of that label is used
          and the shortfall is silently accepted.
        * Draws use the global ``np.random`` state; seed it beforehand for
          reproducible splits.
    """
    train_targets = np.array(dataset_train.targets)
    test_targets = np.array(dataset_test.targets)

    # Per-client label histogram of the train split: {client: {label: count}}.
    net_cls_counts = {}
    for net_i, dataidx in dict_train.items():
        unq, unq_cnt = np.unique(train_targets[dataidx], return_counts=True)
        net_cls_counts[net_i] = dict(zip(unq.tolist(), unq_cnt.tolist()))

    # Test-set positions of each label, computed once up front instead of once
    # per (client, label) pair.
    test_idx_by_label = {
        label: np.where(test_targets == label)[0]
        for label in np.unique(test_targets).tolist()
    }

    dict_users = {}
    for i in range(num_users):
        picked = []
        # Walk the labels this client actually owns rather than range(K), so
        # label sets that are not exactly 0..K-1 are not dropped.
        for label, count in net_cls_counts[i].items():
            pool = test_idx_by_label.get(label)
            if pool is None:
                # Label never occurs in the test set; nothing to draw.
                continue
            # permutation() returns a shuffled copy, leaving the cached pool intact.
            picked += np.random.permutation(pool)[:count].tolist()
        # Mix the labels so the client's list is not grouped by class.
        np.random.shuffle(picked)
        dict_users[i] = picked

    return dict_users


In [4]:
def record_net_data_stats(net_dataidx_map, all_targets):
    """Count, for every client, how many samples of each label it holds.

    Args:
        net_dataidx_map: Mapping ``client id -> indices into all_targets``.
        all_targets: NumPy array of labels, indexable by an index list.

    Returns:
        dict: ``client id -> {label: count}`` where the labels are the distinct
        values found among that client's samples and ``count`` is how often
        each one occurs.
    """
    return {
        client_id: dict(zip(*np.unique(all_targets[sample_idxs], return_counts=True)))
        for client_id, sample_idxs in net_dataidx_map.items()
    }


In [5]:
# Path of the pickled client partition; the file holds a (train split, test split) pair.
dict_save_path = 'dict_users_lda_0.5_100.pkl'
# Overwrite the in-memory partition with the client configuration that was used
# when the model was pretrained.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(dict_save_path, 'rb') as handle:
    dict_users_train, dict_users_test = pickle.load(handle)

In [6]:
# Inspect the test split as loaded from the pickle (the saved output shows None,
# i.e. no test split had been stored yet).
print(dict_users_test)

None


In [7]:
# Dump the train split: a dict mapping client id -> list of sample indices.
# The full dump is very long.
print(dict_users_train)

{0: [33314, 43201, 17156, 31405, 16637, 21216, 42959, 31306, 17383, 35197, 23023, 31618, 23116, 32652, 8803, 24256, 49976, 18992, 17201, 8080, 17314, 6598, 29559, 30318, 35883, 24681, 20168, 28890, 26384, 37534, 1590, 46876, 45842, 49296, 12838, 20342, 8584, 10502, 16483, 212, 28712, 45154, 25241, 11274, 7273, 38465, 46926, 7464, 33330, 24696, 12734, 28901, 43009, 28837, 25584, 35178, 33462, 19146, 7032, 30131, 15494, 7007, 40869, 27614, 39179, 18021, 11066, 1657, 42471, 36097, 30718, 33468, 45369, 38549, 16203, 29163, 1891, 667, 34942, 20534, 43082, 30866, 21709, 9732, 32960, 47053, 6456, 29743, 26649, 36509, 46532, 14477, 22236, 49201, 42179, 38746, 6926, 3509, 30247, 38320, 311, 26224, 42634, 7894, 49703, 4330, 23167, 45002, 12213, 38073, 27789, 24203, 48030, 30047, 9581, 22880, 17760, 15678, 32365, 36697, 16399, 32337, 7981, 15478, 48040, 22937, 38392, 9445, 40568, 1561, 42275, 14174, 38858, 20110, 15475, 1080, 22798, 5229, 2586, 30041, 11242, 13336, 1492, 26162, 9112, 23055, 7881,

In [8]:
# Generate the per-client test split for 100 clients from the loaded train split.
# `lda_test` here is the function defined in this notebook, which shadows the
# one imported from `sampling`.
dict_users_test = lda_test(
    dataset_train=dataset_train,
    dataset_test=dataset_test,
    num_users=100,
    dict_train=dict_users_train,
)

In [9]:
# Inspect the freshly generated test split (client id -> list of test indices).
print(dict_users_test)

{0: [1171, 4578, 8244, 7296, 6823, 5480, 6381, 4527, 5095, 2664, 5622, 3297, 8439, 4713, 474, 8360, 4549, 3525, 8928, 4382, 2752, 515, 4501, 6184, 8206, 7865, 7591, 2915, 3688, 4558, 2235, 4784, 7449, 6811, 7792, 3066, 9725, 3411, 8305, 9535, 9959, 2666, 2495, 3216, 9712, 584, 2464, 9219, 6801, 7171, 5536, 2206, 2115, 4300, 9361, 8197, 4481, 3259, 9310, 4522, 5965, 4755, 6526, 7092, 2158, 9041, 6485, 8585, 1945, 4859, 6948, 3246, 4499, 1560, 9375, 6957, 9073, 3565, 1611, 9283, 7665, 4207, 2197, 8236, 2904, 222, 7089, 3015, 6451, 8182, 6738, 9086, 42, 3805, 9573, 7694, 1845, 1243, 1633, 7056, 9879, 1626, 4708, 8380, 7718, 5124, 2287, 5610, 8377, 8619, 220, 3202, 3750, 9714, 8813, 6730, 5080, 2025, 4013, 2764, 7730, 1180, 1376, 4947, 6704, 5828, 5281, 5524, 355, 6921, 4215, 5910, 9945, 5816, 8025, 8752, 7828, 4466, 3492, 7488, 2964, 6855, 3027, 2557, 6914, 8239, 5701, 1096, 4390, 8192, 4910, 8879, 9721, 9523, 9787, 4712, 9232, 5609, 8767, 1042, 2787, 8300, 3670, 2564, 8443, 133, 5954, 18

In [10]:
# Report how many train samples of each label every client holds.
print(">>> Distributing client train data...")
traindata_cls_dict = record_net_data_stats(
    net_dataidx_map=dict_users_train,
    all_targets=np.array(dataset_train.targets),
)
print('Data statistics: %s' % str(traindata_cls_dict))



>>> Distributing client train data...
Data statistics: {0: {0: 9, 1: 3, 2: 6, 3: 4, 4: 2, 5: 5, 6: 1, 7: 3, 9: 8, 10: 1, 11: 2, 12: 1, 13: 3, 14: 1, 17: 2, 20: 2, 21: 4, 22: 2, 23: 7, 24: 1, 25: 6, 26: 2, 27: 2, 28: 11, 29: 5, 30: 1, 32: 5, 33: 11, 34: 1, 36: 15, 37: 14, 40: 6, 41: 10, 42: 3, 43: 6, 44: 1, 45: 4, 46: 2, 48: 19, 49: 6, 52: 3, 53: 5, 54: 3, 56: 1, 57: 10, 58: 1, 60: 21, 61: 2, 62: 23, 63: 12, 64: 13, 66: 25, 70: 2, 72: 12, 73: 13, 74: 8, 75: 9, 76: 1, 78: 1, 79: 1, 80: 21, 81: 2, 83: 12, 84: 18, 87: 3, 89: 1, 91: 13, 92: 10, 93: 1, 94: 41, 95: 8, 96: 17}, 1: {0: 2, 1: 8, 2: 1, 3: 2, 4: 5, 6: 3, 7: 5, 8: 1, 10: 9, 12: 2, 13: 1, 15: 6, 18: 20, 19: 3, 20: 22, 21: 1, 22: 1, 23: 6, 24: 23, 26: 1, 28: 1, 29: 15, 30: 16, 31: 1, 32: 8, 34: 3, 35: 2, 36: 12, 37: 2, 38: 18, 41: 1, 42: 1, 43: 1, 45: 4, 47: 1, 48: 10, 49: 1, 50: 3, 52: 12, 53: 14, 57: 1, 58: 15, 59: 3, 60: 1, 61: 8, 62: 1, 63: 10, 65: 2, 66: 2, 67: 3, 71: 8, 72: 3, 74: 6, 77: 1, 79: 8, 81: 1, 83: 53, 84: 2, 85: 5, 8

In [11]:
# Report how many test samples of each label every client holds, for comparison
# with the train statistics.
print(">>> Distributing client test data...")
testdata_cls_dict = record_net_data_stats(
    net_dataidx_map=dict_users_test,
    all_targets=np.array(dataset_test.targets),
)
print('Data statistics: %s' % str(testdata_cls_dict))


>>> Distributing client test data...
Data statistics: {0: {0: 9, 1: 3, 2: 6, 3: 4, 4: 2, 5: 5, 6: 1, 7: 3, 9: 8, 10: 1, 11: 2, 12: 1, 13: 3, 14: 1, 17: 2, 20: 2, 21: 4, 22: 2, 23: 7, 24: 1, 25: 6, 26: 2, 27: 2, 28: 11, 29: 5, 30: 1, 32: 5, 33: 11, 34: 1, 36: 15, 37: 14, 40: 6, 41: 10, 42: 3, 43: 6, 44: 1, 45: 4, 46: 2, 48: 19, 49: 6, 52: 3, 53: 5, 54: 3, 56: 1, 57: 10, 58: 1, 60: 21, 61: 2, 62: 23, 63: 12, 64: 13, 66: 25, 70: 2, 72: 12, 73: 13, 74: 8, 75: 9, 76: 1, 78: 1, 79: 1, 80: 21, 81: 2, 83: 12, 84: 18, 87: 3, 89: 1, 91: 13, 92: 10, 93: 1, 94: 41, 95: 8, 96: 17}, 1: {0: 2, 1: 8, 2: 1, 3: 2, 4: 5, 6: 3, 7: 5, 8: 1, 10: 9, 12: 2, 13: 1, 15: 6, 18: 20, 19: 3, 20: 22, 21: 1, 22: 1, 23: 6, 24: 23, 26: 1, 28: 1, 29: 15, 30: 16, 31: 1, 32: 8, 34: 3, 35: 2, 36: 12, 37: 2, 38: 18, 41: 1, 42: 1, 43: 1, 45: 4, 47: 1, 48: 10, 49: 1, 50: 3, 52: 12, 53: 14, 57: 1, 58: 15, 59: 3, 60: 1, 61: 8, 62: 1, 63: 10, 65: 2, 66: 2, 67: 3, 71: 8, 72: 3, 74: 6, 77: 1, 79: 8, 81: 1, 83: 53, 84: 2, 85: 5, 86

In [12]:
# Persist the partition, now including the generated test split, back to the
# file it was loaded from. Reusing dict_save_path (instead of repeating the
# literal) keeps this write and the reads of the same file pointed at one path.
# NOTE: this overwrites the input file in place.
with open(dict_save_path, 'wb') as handle:
    pickle.dump((dict_users_train, dict_users_test), handle)

In [13]:
# Overwrite the in-memory partition with the client configuration stored at
# dict_save_path (translated original note: the configuration that was used
# when the model was pretrained).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(dict_save_path, 'rb') as handle:
    dict_users_train, dict_users_test = pickle.load(handle)

In [14]:
# Inspect the test split after reloading it from the pickle file.
print(dict_users_test)

{0: [1171, 4578, 8244, 7296, 6823, 5480, 6381, 4527, 5095, 2664, 5622, 3297, 8439, 4713, 474, 8360, 4549, 3525, 8928, 4382, 2752, 515, 4501, 6184, 8206, 7865, 7591, 2915, 3688, 4558, 2235, 4784, 7449, 6811, 7792, 3066, 9725, 3411, 8305, 9535, 9959, 2666, 2495, 3216, 9712, 584, 2464, 9219, 6801, 7171, 5536, 2206, 2115, 4300, 9361, 8197, 4481, 3259, 9310, 4522, 5965, 4755, 6526, 7092, 2158, 9041, 6485, 8585, 1945, 4859, 6948, 3246, 4499, 1560, 9375, 6957, 9073, 3565, 1611, 9283, 7665, 4207, 2197, 8236, 2904, 222, 7089, 3015, 6451, 8182, 6738, 9086, 42, 3805, 9573, 7694, 1845, 1243, 1633, 7056, 9879, 1626, 4708, 8380, 7718, 5124, 2287, 5610, 8377, 8619, 220, 3202, 3750, 9714, 8813, 6730, 5080, 2025, 4013, 2764, 7730, 1180, 1376, 4947, 6704, 5828, 5281, 5524, 355, 6921, 4215, 5910, 9945, 5816, 8025, 8752, 7828, 4466, 3492, 7488, 2964, 6855, 3027, 2557, 6914, 8239, 5701, 1096, 4390, 8192, 4910, 8879, 9721, 9523, 9787, 4712, 9232, 5609, 8767, 1042, 2787, 8300, 3670, 2564, 8443, 133, 5954, 18