In [1]:
from torchvision import datasets, transforms

from sampling import iid, shard, lda, lda_test

import numpy as np
import sys
import pickle

def _cifar_train_transform(mean, std):
    """Training pipeline: random 32x32 crop (padding 4), random flip, tensor, normalize."""
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


def _cifar_eval_transform(mean, std):
    """Evaluation pipeline: tensor conversion and normalization only (no augmentation)."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# MNIST: single-channel normalization.
trans_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# CIFAR-10 pipelines.
# NOTE(review): these mean/std values look like the commonly quoted ImageNet
# statistics rather than CIFAR-10 ones -- confirm this is intended.
trans_cifar10_train = _cifar_train_transform([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
trans_cifar10_val = _cifar_eval_transform([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# CIFAR-100 pipelines (these are the ones the dataset cell below uses).
trans_cifar100_train = _cifar_train_transform([0.507, 0.487, 0.441], [0.267, 0.256, 0.276])
trans_cifar100_val = _cifar_eval_transform([0.507, 0.487, 0.441], [0.267, 0.256, 0.276])


In [2]:
# CIFAR-100 train and test splits share one root directory; download=True asks
# torchvision to fetch the archive when it is not already present there.
cifar100_root = '../data/cifar100'
dataset_train = datasets.CIFAR100(cifar100_root, train=True, download=True,
                                  transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100(cifar100_root, train=False, download=True,
                                 transform=trans_cifar100_val)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def lda_test(dataset_train, dataset_test, num_users, dict_train):
    """Build a per-client test partition that mirrors each client's train label counts.

    For every client, the number of train samples it holds per label is counted,
    and the same number of test samples is drawn at random for each of those
    labels (using NumPy's global ``np.random`` state).

    Args:
        dataset_train: object whose ``targets`` attribute holds the train labels.
        dataset_test: object whose ``targets`` attribute holds the test labels.
        num_users: number of clients; client ids are ``0 .. num_users - 1``.
        dict_train: mapping ``client id -> list of train-sample indices``.

    Returns:
        dict mapping every client id in ``range(num_users)`` to a shuffled list
        of test-sample indices.

    Notes:
        * Labels are taken from the values actually present in the train
          targets, so they do not have to be the contiguous range ``0..K-1``.
        * A client id with no entry in ``dict_train`` gets an empty list.
        * If a client holds more train samples of a label than the test set
          contains, it simply receives every test sample of that label (the
          slice is truncated, no error is raised).
        * Draws are independent per client, so different clients may receive
          the same test samples.
    """
    all_targets_train = np.array(dataset_train.targets)
    all_targets = np.array(dataset_test.targets)

    # Per-client histogram {label: count} of the train labels that client owns.
    net_cls_counts = {}
    for net_i, dataidx in dict_train.items():
        unq, unq_cnt = np.unique(all_targets_train[dataidx], return_counts=True)
        net_cls_counts[net_i] = dict(zip(unq, unq_cnt))

    # Sorted distinct train labels, and for each one the test indices carrying
    # it. This lookup does not depend on the client, so compute it only once.
    labels = np.unique(all_targets_train)
    test_idx_by_label = [np.where(all_targets == k)[0] for k in labels]

    idx_batch = [[] for _ in range(num_users)]
    for i in range(num_users):
        counts = net_cls_counts.get(i, {})
        for k, label_pool in zip(labels, test_idx_by_label):
            # Shuffle a fresh copy for every (client, label) pair, even when the
            # client does not own the label: this keeps the sequence of RNG
            # calls identical to the earlier implementation, so a seeded run
            # still yields the same partition.
            idx_k = label_pool.copy()
            np.random.shuffle(idx_k)
            if k in counts:
                idx_batch[i] += idx_k[:counts[k]].tolist()

    dict_users = {}
    for i in range(num_users):
        # Mix the labels together inside each client's list.
        np.random.shuffle(idx_batch[i])
        dict_users[i] = idx_batch[i]

    return dict_users


In [4]:
def record_net_data_stats(net_dataidx_map, all_targets):
    """Count, for every client, how many samples of each label it holds.

    Args:
        net_dataidx_map: mapping ``client id -> indices`` into ``all_targets``.
        all_targets: array of labels, indexed with each client's index list.

    Returns:
        dict ``client id -> {label: count}`` covering only the labels that
        actually occur among that client's samples.
    """
    stats_per_client = {}
    for client_id, sample_idxs in net_dataidx_map.items():
        # np.unique gives the distinct labels and, aligned with them, how often
        # each one occurs among this client's samples.
        labels, label_counts = np.unique(all_targets[sample_idxs], return_counts=True)
        stats_per_client[client_id] = dict(zip(labels, label_counts))
    return stats_per_client


In [5]:
# Load the saved client partition: this overwrites whatever is in memory with
# the client configuration that was used for the earlier pretraining run.
dict_save_path = 'dict_users_lda_0.1_100.pkl'
with open(dict_save_path, 'rb') as pkl_file:
    loaded_partitions = pickle.load(pkl_file)
dict_users_train, dict_users_test = loaded_partitions

In [6]:
# Inspect the test partition that came out of the pickle (before it is rebuilt below).
print(dict_users_test)

None


In [7]:
# Inspect the train partition (client id -> list of train-sample indices).
print(dict_users_train)

{0: [2529, 44276, 17722, 28165, 37778, 45271, 34559, 24005, 34068, 31540, 5660, 14662, 21696, 18637, 21444, 38022, 22518, 26535, 30009, 17673, 44508, 44531, 47945, 30320, 20351, 39609, 32556, 5154, 31970, 46205, 20500, 9329, 38605, 46754, 14213, 6229, 13316, 29740, 15880, 49215, 41125, 23655, 1275, 45175, 48513, 49480, 21141, 13919, 25627, 8105, 17346, 19253, 23014, 13691, 28540, 30603, 19226, 37210, 23204, 30371, 29686, 18803, 19904, 28494, 40154, 20121, 38336, 25203, 14497, 8698, 18493, 14121, 36029, 40619, 36580, 49063, 29780, 17548, 33743, 22391, 42710, 1307, 5388, 16833, 19616, 39702, 6001, 30643, 37674, 15891, 16650, 3307, 2795, 44231, 21332, 27786, 34054, 39724, 6451, 25508, 18063, 8126, 19516, 26004, 20956, 15682, 9502, 28288, 11419, 8995, 29398, 49411, 10355, 5133, 13626, 47392, 27977, 7237, 868, 37573, 43769, 19282, 47801, 10093, 5078, 13678, 18767, 3943, 39743, 17371, 30385, 13787, 879, 29412, 3100, 25284, 47113, 17752, 45619, 35953, 2201, 38080, 4167, 21319, 47160, 37847, 2

In [8]:
# Build the per-client test partition from each client's train label counts.
# lda_test samples with np.random.shuffle, so seed NumPy's global RNG first:
# otherwise every run would produce (and later save) a different partition.
np.random.seed(0)
dict_users_test = lda_test(dataset_train, dataset_test, 100, dict_users_train)

In [9]:
# Inspect the freshly generated test partition (client id -> list of test-sample indices).
print(dict_users_test)

{0: [3658, 7474, 6743, 9466, 6861, 8292, 5225, 3671, 1527, 5806, 8311, 3423, 9739, 3139, 1968, 5134, 8271, 4521, 5474, 4117, 7461, 1556, 3170, 1736, 4800, 5110, 2191, 4012, 2487, 3501, 7868, 5095, 9124, 1747, 503, 7487, 1079, 2900, 9605, 5244, 4859, 4905, 129, 3795, 6585, 4507, 4834, 7531, 3803, 1507, 9979, 2567, 2868, 8710, 9500, 7049, 5423, 3491, 7476, 2112, 551, 6345, 745, 2999, 505, 6940, 5434, 37, 3707, 7722, 8224, 1022, 6218, 40, 7808, 8743, 182, 3534, 4772, 6116, 9170, 5376, 1790, 780, 1770, 3068, 2638, 8700, 7863, 6734, 3980, 9621, 3248, 4833, 3939, 2380, 8082, 2106, 3733, 6699, 2971, 7664, 5635, 2471, 2561, 6782, 1732, 6464, 585, 1337, 3805, 9175, 9975, 4532, 9539, 2920, 39, 3060, 5663, 4771, 1126, 3239, 7463, 7224, 9887, 7213, 2177, 9024, 3055, 9181, 7693, 1546, 3426, 7732, 4656, 4038, 9509, 3596, 9010, 9426, 1678, 1239, 7346, 918, 9189, 875, 9635, 6298, 1256, 2120, 2156, 534, 483, 2418, 5238, 1287, 1212, 6165, 154, 729, 4464, 1348, 3825, 1210, 528, 9032, 7452, 7201, 2320, 69

In [10]:
# Per-client label histogram of the train partition.
print(">>> Distributing client train data...")

train_targets = np.array(dataset_train.targets)
traindata_cls_dict = record_net_data_stats(dict_users_train, train_targets)
print('Data statistics: %s' % str(traindata_cls_dict))



>>> Distributing client train data...
Data statistics: {0: {7: 2, 11: 28, 12: 35, 13: 3, 14: 22, 19: 38, 21: 11, 22: 22, 23: 10, 24: 1, 26: 4, 27: 2, 29: 6, 35: 19, 36: 11, 37: 20, 38: 10, 39: 23, 40: 7, 41: 116, 45: 31, 46: 7, 47: 4, 49: 1, 50: 1, 54: 2, 56: 1, 58: 9, 59: 4, 63: 39, 65: 2, 67: 2, 70: 5, 74: 2}, 1: {2: 2, 10: 2, 11: 3, 16: 10, 24: 2, 33: 84, 36: 9, 40: 19, 43: 2, 44: 25, 47: 3, 50: 1, 55: 14, 56: 1, 57: 90, 58: 13, 65: 1, 73: 2, 82: 41, 88: 10, 93: 16, 94: 2, 98: 155}, 2: {2: 1, 3: 7, 10: 8, 13: 1, 14: 20, 20: 1, 21: 4, 24: 1, 25: 37, 29: 4, 31: 26, 37: 4, 39: 1, 46: 2, 47: 1, 48: 11, 49: 5, 53: 1, 55: 19, 59: 96, 60: 4, 62: 14, 63: 1, 64: 20, 67: 1, 68: 17, 69: 51, 73: 32, 75: 1, 76: 1, 79: 2, 84: 6, 85: 1, 90: 27, 93: 6, 96: 2, 98: 9, 99: 10}, 3: {2: 3, 7: 105, 10: 8, 16: 7, 22: 27, 23: 2, 26: 13, 29: 1, 31: 3, 32: 15, 33: 2, 35: 3, 36: 5, 38: 1, 40: 1, 41: 18, 43: 2, 49: 9, 52: 21, 55: 35, 58: 21, 59: 9, 63: 1, 69: 145, 76: 2, 81: 4, 82: 20, 84: 2, 87: 31}, 4: {3: 5

In [11]:
# Per-client label histogram of the test partition, for comparison with the train one.
print(">>> Distributing client test data...")
test_targets = np.array(dataset_test.targets)
testdata_cls_dict = record_net_data_stats(dict_users_test, test_targets)
print('Data statistics: %s' % str(testdata_cls_dict))


>>> Distributing client test data...
Data statistics: {0: {7: 2, 11: 28, 12: 35, 13: 3, 14: 22, 19: 38, 21: 11, 22: 22, 23: 10, 24: 1, 26: 4, 27: 2, 29: 6, 35: 19, 36: 11, 37: 20, 38: 10, 39: 23, 40: 7, 41: 100, 45: 31, 46: 7, 47: 4, 49: 1, 50: 1, 54: 2, 56: 1, 58: 9, 59: 4, 63: 39, 65: 2, 67: 2, 70: 5, 74: 2}, 1: {2: 2, 10: 2, 11: 3, 16: 10, 24: 2, 33: 84, 36: 9, 40: 19, 43: 2, 44: 25, 47: 3, 50: 1, 55: 14, 56: 1, 57: 90, 58: 13, 65: 1, 73: 2, 82: 41, 88: 10, 93: 16, 94: 2, 98: 100}, 2: {2: 1, 3: 7, 10: 8, 13: 1, 14: 20, 20: 1, 21: 4, 24: 1, 25: 37, 29: 4, 31: 26, 37: 4, 39: 1, 46: 2, 47: 1, 48: 11, 49: 5, 53: 1, 55: 19, 59: 96, 60: 4, 62: 14, 63: 1, 64: 20, 67: 1, 68: 17, 69: 51, 73: 32, 75: 1, 76: 1, 79: 2, 84: 6, 85: 1, 90: 27, 93: 6, 96: 2, 98: 9, 99: 10}, 3: {2: 3, 7: 100, 10: 8, 16: 7, 22: 27, 23: 2, 26: 13, 29: 1, 31: 3, 32: 15, 33: 2, 35: 3, 36: 5, 38: 1, 40: 1, 41: 18, 43: 2, 49: 9, 52: 21, 55: 35, 58: 21, 59: 9, 63: 1, 69: 100, 76: 2, 81: 4, 82: 20, 84: 2, 87: 31}, 4: {3: 5,

In [12]:
# Persist both partitions back to the file they were loaded from. Going through
# dict_save_path instead of repeating the file-name literal keeps this write
# pointed at the same file that the load cells read.
with open(dict_save_path, 'wb') as handle:
    pickle.dump((dict_users_train, dict_users_test), handle)

In [13]:
# Read the pickle at dict_save_path again, overwriting the in-memory partitions
# with the client configuration stored in that file.
with open(dict_save_path, 'rb') as pkl_file:
    reloaded_partitions = pickle.load(pkl_file)
dict_users_train, dict_users_test = reloaded_partitions

In [14]:
# Inspect the test partition as read back from the pickle.
print(dict_users_test)

{0: [3658, 7474, 6743, 9466, 6861, 8292, 5225, 3671, 1527, 5806, 8311, 3423, 9739, 3139, 1968, 5134, 8271, 4521, 5474, 4117, 7461, 1556, 3170, 1736, 4800, 5110, 2191, 4012, 2487, 3501, 7868, 5095, 9124, 1747, 503, 7487, 1079, 2900, 9605, 5244, 4859, 4905, 129, 3795, 6585, 4507, 4834, 7531, 3803, 1507, 9979, 2567, 2868, 8710, 9500, 7049, 5423, 3491, 7476, 2112, 551, 6345, 745, 2999, 505, 6940, 5434, 37, 3707, 7722, 8224, 1022, 6218, 40, 7808, 8743, 182, 3534, 4772, 6116, 9170, 5376, 1790, 780, 1770, 3068, 2638, 8700, 7863, 6734, 3980, 9621, 3248, 4833, 3939, 2380, 8082, 2106, 3733, 6699, 2971, 7664, 5635, 2471, 2561, 6782, 1732, 6464, 585, 1337, 3805, 9175, 9975, 4532, 9539, 2920, 39, 3060, 5663, 4771, 1126, 3239, 7463, 7224, 9887, 7213, 2177, 9024, 3055, 9181, 7693, 1546, 3426, 7732, 4656, 4038, 9509, 3596, 9010, 9426, 1678, 1239, 7346, 918, 9189, 875, 9635, 6298, 1256, 2120, 2156, 534, 483, 2418, 5238, 1287, 1212, 6165, 154, 729, 4464, 1348, 3825, 1210, 528, 9032, 7452, 7201, 2320, 69