In [1]:
from torchvision import datasets, transforms

from sampling import iid, shard, lda, lda_test

import numpy as np
import sys
import pickle

trans_mnist = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])
trans_cifar10_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                               std=[0.229, 0.224, 0.225])])
trans_cifar10_val = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                             std=[0.229, 0.224, 0.225])])



trans_cifar100_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                                               std=[0.267, 0.256, 0.276])])
trans_cifar100_val = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                                              std=[0.267, 0.256, 0.276])])


In [2]:
dataset_train = datasets.CIFAR100('../data/cifar100', train=True, download=True, transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100('../data/cifar100', train=False, download=True, transform=trans_cifar100_val)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def lda_test(dataset_train, dataset_test, num_users, dict_train): 
    all_targets_train=np.array(dataset_train.targets)
    
    net_cls_counts = {}
    
    for net_i, dataidx in dict_train.items():
        unq, unq_cnt = np.unique(all_targets_train[dataidx], return_counts=True)#전체 train data 중에 net_i번째 client가 가지고 있는 data가 어떤 label을 가지고 있는지의 정보가 unq, unq의 각 element가 몇개 들어있는지 기재하는게 unq_count이다!!
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}#tmp에는 unq가 key unq_count가 value가 되게 기재!!
        net_cls_counts[net_i] = tmp

        
    K = len(np.unique(all_targets_train))
    all_targets=np.array(dataset_test.targets)
    
    dict_users = {}
    
    idx_batch = [[] for _ in range(num_users)]
    for i in range(num_users):
        for k in range(K):
            # get a list of batch indexes which are belong to label k
            idx_k = np.where(all_targets == k)[0]
            np.random.shuffle(idx_k)
            if k in net_cls_counts[i].keys():
                data_idxs = idx_k[:net_cls_counts[i][k]]
                idx_batch[i]+=data_idxs.tolist()
    for i in range(num_users):
        np.random.shuffle(idx_batch[i])
        dict_users[i] = idx_batch[i]

    return dict_users


In [4]:
def record_net_data_stats(net_dataidx_map, all_targets):
    net_cls_counts = {}#각 client가 어떤 label을 몇개씩 가지고 있는지 통계량 기재!!

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(all_targets[dataidx], return_counts=True)#전체 train data 중에 net_i번째 client가 가지고 있는 data가 어떤 label을 가지고 있는지의 정보가 unq, unq의 각 element가 몇개 들어있는지 기재하는게 unq_count이다!!
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}#tmp에는 unq가 key unq_count가 value가 되게 기재!!
        net_cls_counts[net_i] = tmp
    return net_cls_counts #각 client가 어떤 label을 몇개씩 가지고 있는지 통계량 기재!!


In [5]:
dict_save_path = 'dict_users_lda_0.3_100.pkl'
with open(dict_save_path, 'rb') as handle:#기존 pretrained되었을 때 쓰였던 클라이언트 구성으로 덮어씌운다.
    dict_users_train, dict_users_test = pickle.load(handle)

In [6]:
print(dict_users_test)

None


In [7]:
print(dict_users_train)

{0: [33844, 22155, 26626, 4629, 30255, 25420, 31114, 8868, 9726, 10557, 22923, 31102, 8456, 22853, 17392, 26381, 29903, 9038, 31915, 33557, 4127, 6229, 47400, 48541, 32634, 7659, 19905, 10564, 16115, 44569, 26550, 43331, 41214, 8768, 6570, 28393, 24291, 15997, 37868, 32169, 15182, 39434, 25559, 16705, 45986, 42062, 46997, 33354, 9871, 18718, 5177, 42937, 31919, 46243, 32685, 48420, 21707, 38824, 47148, 39975, 34817, 41326, 14719, 14950, 14217, 34077, 18787, 15844, 28569, 39287, 17156, 35099, 27101, 36199, 19401, 39045, 44903, 34928, 30238, 23496, 24909, 12170, 2070, 11493, 265, 34839, 37277, 1080, 47130, 32492, 34916, 32206, 12211, 8423, 28712, 32275, 49775, 42551, 15804, 45847, 24736, 39040, 31566, 814, 20698, 35978, 697, 5414, 3697, 27574, 40176, 6913, 16224, 40488, 8765, 6440, 17101, 37789, 6674, 39855, 29947, 4682, 3464, 34308, 42979, 16727, 47420, 32931, 42428, 27197, 1557, 30031, 24327, 22188, 41718, 10665, 19422, 6010, 44171, 40892, 44142, 44639, 26310, 34359, 40874, 43006, 3454

In [8]:
dict_users_test = lda_test(dataset_train, dataset_test, 100, dict_users_train)

In [9]:
print(dict_users_test)

{0: [113, 991, 4520, 6298, 4618, 906, 9929, 4454, 3163, 424, 9190, 1912, 5133, 7569, 2016, 3662, 7187, 354, 9852, 1561, 3438, 1000, 6707, 9561, 6456, 318, 3051, 6355, 26, 5236, 685, 882, 361, 3765, 4234, 7487, 6948, 4077, 6117, 4152, 37, 173, 6433, 9635, 696, 765, 9730, 7051, 6656, 5873, 1405, 4147, 2701, 7845, 6794, 77, 7161, 482, 260, 7455, 3801, 7153, 1784, 2481, 5910, 5716, 2680, 3111, 2289, 5780, 3297, 5579, 4284, 2084, 583, 5703, 6392, 5417, 6286, 9249, 9198, 7464, 923, 2377, 3997, 6686, 6772, 740, 9727, 7405, 2223, 2305, 375, 4948, 6073, 3855, 9491, 511, 997, 9183, 9240, 3728, 3369, 5609, 8985, 9325, 5051, 4986, 1595, 6559, 3054, 6247, 8254, 8238, 9248, 1726, 5127, 9844, 4842, 4544, 8590, 5220, 8472, 8761, 876, 4769, 7765, 6968, 5021, 752, 1666, 6419, 84, 2416, 1296, 2747, 9517, 2819, 5382, 2510, 3669, 2919, 6831, 7374, 5144, 337, 9728, 8577, 9563, 4386, 7249, 9703, 5656, 8798, 9165, 999, 327, 2789, 2449, 8754, 6543, 5364, 3816, 8626, 902, 8422, 4058, 9492, 5996, 5100, 5303, 337

In [10]:
print(">>> Distributing client train data...")

traindata_cls_dict = record_net_data_stats(dict_users_train, np.array(dataset_train.targets))
print('Data statistics: %s' % str(traindata_cls_dict))



>>> Distributing client train data...
Data statistics: {0: {0: 1, 2: 5, 3: 7, 4: 2, 5: 1, 9: 7, 10: 41, 11: 27, 14: 1, 15: 5, 17: 1, 20: 3, 22: 16, 24: 3, 26: 3, 29: 14, 31: 13, 32: 1, 36: 6, 37: 17, 39: 18, 40: 2, 41: 4, 45: 16, 48: 22, 49: 1, 50: 5, 51: 6, 52: 1, 55: 1, 57: 5, 58: 1, 62: 7, 65: 8, 66: 5, 67: 3, 68: 21, 70: 1, 71: 12, 72: 24, 73: 1, 74: 4, 76: 21, 78: 4, 79: 1, 80: 25, 82: 4, 83: 8, 85: 3, 86: 11, 87: 12, 91: 32, 92: 1, 93: 8, 95: 1, 96: 8, 97: 2, 98: 2}, 1: {0: 2, 1: 1, 3: 1, 4: 8, 5: 1, 8: 46, 11: 1, 13: 7, 18: 23, 19: 2, 21: 12, 22: 21, 23: 1, 26: 4, 27: 1, 28: 3, 30: 47, 31: 9, 32: 3, 33: 1, 35: 20, 36: 2, 38: 1, 39: 3, 44: 16, 45: 6, 48: 4, 49: 22, 51: 2, 57: 1, 58: 41, 59: 24, 60: 8, 62: 3, 63: 5, 64: 4, 66: 19, 67: 1, 70: 4, 71: 1, 72: 1, 73: 1, 75: 1, 76: 7, 77: 10, 78: 9, 81: 7, 82: 3, 85: 3, 86: 2, 87: 40, 88: 4, 89: 6, 90: 11, 92: 7, 93: 26}, 2: {0: 1, 1: 1, 5: 13, 6: 1, 14: 10, 16: 3, 17: 7, 18: 34, 19: 8, 20: 4, 21: 6, 22: 4, 24: 1, 25: 1, 26: 1, 28: 10, 

In [11]:
print(">>> Distributing client test data...")    
testdata_cls_dict = record_net_data_stats(dict_users_test, np.array(dataset_test.targets))
print('Data statistics: %s' % str(testdata_cls_dict))


>>> Distributing client test data...
Data statistics: {0: {0: 1, 2: 5, 3: 7, 4: 2, 5: 1, 9: 7, 10: 41, 11: 27, 14: 1, 15: 5, 17: 1, 20: 3, 22: 16, 24: 3, 26: 3, 29: 14, 31: 13, 32: 1, 36: 6, 37: 17, 39: 18, 40: 2, 41: 4, 45: 16, 48: 22, 49: 1, 50: 5, 51: 6, 52: 1, 55: 1, 57: 5, 58: 1, 62: 7, 65: 8, 66: 5, 67: 3, 68: 21, 70: 1, 71: 12, 72: 24, 73: 1, 74: 4, 76: 21, 78: 4, 79: 1, 80: 25, 82: 4, 83: 8, 85: 3, 86: 11, 87: 12, 91: 32, 92: 1, 93: 8, 95: 1, 96: 8, 97: 2, 98: 2}, 1: {0: 2, 1: 1, 3: 1, 4: 8, 5: 1, 8: 46, 11: 1, 13: 7, 18: 23, 19: 2, 21: 12, 22: 21, 23: 1, 26: 4, 27: 1, 28: 3, 30: 47, 31: 9, 32: 3, 33: 1, 35: 20, 36: 2, 38: 1, 39: 3, 44: 16, 45: 6, 48: 4, 49: 22, 51: 2, 57: 1, 58: 41, 59: 24, 60: 8, 62: 3, 63: 5, 64: 4, 66: 19, 67: 1, 70: 4, 71: 1, 72: 1, 73: 1, 75: 1, 76: 7, 77: 10, 78: 9, 81: 7, 82: 3, 85: 3, 86: 2, 87: 40, 88: 4, 89: 6, 90: 11, 92: 7, 93: 26}, 2: {0: 1, 1: 1, 5: 13, 6: 1, 14: 10, 16: 3, 17: 7, 18: 34, 19: 8, 20: 4, 21: 6, 22: 4, 24: 1, 25: 1, 26: 1, 28: 10, 2

In [12]:
with open('dict_users_lda_0.3_100.pkl', 'wb') as handle:
    pickle.dump((dict_users_train, dict_users_test), handle)

In [13]:
with open(dict_save_path, 'rb') as handle:#기존 pretrained되었을 때 쓰였던 클라이언트 구성으로 덮어씌운다.
    dict_users_train, dict_users_test = pickle.load(handle)

In [14]:
print(dict_users_test)

{0: [113, 991, 4520, 6298, 4618, 906, 9929, 4454, 3163, 424, 9190, 1912, 5133, 7569, 2016, 3662, 7187, 354, 9852, 1561, 3438, 1000, 6707, 9561, 6456, 318, 3051, 6355, 26, 5236, 685, 882, 361, 3765, 4234, 7487, 6948, 4077, 6117, 4152, 37, 173, 6433, 9635, 696, 765, 9730, 7051, 6656, 5873, 1405, 4147, 2701, 7845, 6794, 77, 7161, 482, 260, 7455, 3801, 7153, 1784, 2481, 5910, 5716, 2680, 3111, 2289, 5780, 3297, 5579, 4284, 2084, 583, 5703, 6392, 5417, 6286, 9249, 9198, 7464, 923, 2377, 3997, 6686, 6772, 740, 9727, 7405, 2223, 2305, 375, 4948, 6073, 3855, 9491, 511, 997, 9183, 9240, 3728, 3369, 5609, 8985, 9325, 5051, 4986, 1595, 6559, 3054, 6247, 8254, 8238, 9248, 1726, 5127, 9844, 4842, 4544, 8590, 5220, 8472, 8761, 876, 4769, 7765, 6968, 5021, 752, 1666, 6419, 84, 2416, 1296, 2747, 9517, 2819, 5382, 2510, 3669, 2919, 6831, 7374, 5144, 337, 9728, 8577, 9563, 4386, 7249, 9703, 5656, 8798, 9165, 999, 327, 2789, 2449, 8754, 6543, 5364, 3816, 8626, 902, 8422, 4058, 9492, 5996, 5100, 5303, 337