In [1]:
import os
import sys
import math
import random
import torch
import numpy as np
import syft as sy
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import ujson as json
import pandas as pd
import re
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class Arguments():
    def __init__(self):
        self.images = 48285
        self.clients = 10
        self.rounds = 5
        self.epochs = 5
        self.local_batches = 64
        self.lr = 0.01
        self.C = 0.9
        self.drop_rate = 0.1
        self.torch_seed = 0
        self.log_interval = 10
        self.iid = 'iid'
        self.split_size = int(self.images / self.clients)
        self.samples = self.split_size / self.images 
        self.use_cuda = False
        self.save_model = False

args = Arguments()

use_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [3]:
hook = sy.TorchHook(torch)
clients = []

for i in range(args.clients):
    clients.append({'hook': sy.VirtualWorker(hook, id="client{}".format(i+1))})

In [4]:
def loadDataset(path, screenName):

    users = [ f.path for f in os.scandir(path) if f.is_dir() ]
    info = pd.DataFrame(columns= ['accelometer_size', 'gyroscope_size', 'timestamp'])

    accelerometer = pd.DataFrame(columns=['x', 'y', 'z', 'screen', 'user', 'magnitude','combine_angle', 'timestamp'])
    gyroscope = pd.DataFrame(columns=['x_gyroscope', 'y_gyroscope', 'z_gyroscope', 'screen_gyroscope', 'user_gyroscope', 'magnitude_gyroscope', 'combine_angle_gyroscope', 'timestamp_gyroscope'])

    # Read sensors data from json file and save them in Dataframes
    for i in range(0, len(users)):

        json_files = [pos_json for pos_json in os.listdir(users[i]) if pos_json.endswith('.json')]

        for index, js in enumerate(json_files):
            with open(os.path.join(users[i], js)) as json_file:
                json_text = json.load(json_file)
                accSize = 0
                gyrSize = 0
                js = js.replace('.json','')
                arr = js.split('_')

                for j in json_text['accelerometer']:
                    if screenName in j['screen']:
                        x = j['x']
                        y = j['y']
                        z = j['z']
                        if x == 0 and y == 0:
                            continue
                        screen = j['screen']
                        user = arr[0]
                        m = x**2 + y**2 + z**2
                        m = np.sqrt(m)
                        ca = np.sqrt(y**2 + z**2)
                        timestamp = arr[1]
                        accSize = accSize + 1
                        df = {'x': x, 'y': y, 'z' : z, 'screen' : screen, 'user': user, 'magnitude' : m, 'combine_angle': ca, 'timestamp': timestamp}
                        accelerometer = accelerometer.append(df, ignore_index=True)
                        
                for j in json_text['gyroscope']:
                    if screenName in j['screen']:
                        x = j['x']
                        y = j['y']
                        z = j['z']
                        if x == 0 and y == 0:
                            continue
                        screen = j['screen']
                        user = arr[0]
                        m = x**2 + y**2 + z**2
                        m = np.sqrt(m)
                        ca = np.sqrt(y**2 + z**2)
                        timestamp = arr[1]
                        gyrSize =  gyrSize + 1
                        df = {'x_gyroscope': x, 'y_gyroscope': y, 'z_gyroscope' : z, 'screen_gyroscope' : screen, 'user_gyroscope': user, 'magnitude_gyroscope' : m, 'combine_angle_gyroscope': ca, 'timestamp_gyroscope': timestamp}
                        gyroscope = gyroscope.append(df, ignore_index=True)
                    
                dframe = {'accelometer_size': accSize, 'gyroscope_size': gyrSize, 'timestamp': arr[1]}
                info = info.append(dframe, ignore_index=True)

    return accelerometer, gyroscope, info, users

In [5]:
def normalize_data(data):
        var_std = data.std(dim=1, keepdim=True)
        var_mean = data.mean(dim=1, keepdim=True)
        var_list = (data - var_mean) / var_std

        return var_list, var_mean, var_std

In [6]:
path = 'C:/Users/SouthSystem/Documents/Pessoal/TCC/Impl/sensors_data'
screen = 'MathisisGame'
accelerometer, gyroscope, info, users = loadDataset(path, screen)

In [7]:
variables = ["x", "y", "z", "magnitude", "combine_angle", "user"]
var_list = []

for part in variables:
       var_list.append(list(accelerometer[part]))

# tensor_list = torch.tensor(var_list)

df = pd.DataFrame(var_list)
dataset_T = df.T
datasetData = dataset_T.iloc[:,[0,1,2,3,4]]
datasetUser = dataset_T.iloc[:,[5]]
# list(sorted(set(dataset_T[5])))

In [None]:
pd.set_option('display.max_rows', None)
print(accelerometer)
pd.reset_option('display.max_rows')

In [None]:
# data_list, mean, std = normalize_data(tensor_list)
print(df)

In [8]:
class FedDataset(Dataset):
    def __init__(self, dataset, indx):
        self.datasetData = dataset.iloc[:,[0,1,2,3,4]]
        self.datasetUser = dataset.iloc[:,[5]]
        self.indx = [int(i) for i in indx]
        self.users = list(datasetUser[5])
        
        self.user_list = list(sorted(set(self.datasetUser[5])))
        
    def __len__(self):
        return len(self.indx)
    
    def get_user_id(self, user):
        return self.user_list.index(user)
    
    def users2list(users):
        user_ids = [get_user_id(
        user) for user in users]
        return user_ids
    
    def __getitem__(self, item):
        data_transposed = self.datasetData.T
        data = data_transposed[self.indx[item]]
        user_ids = [self.user_list.index(user) for user in self.users]
        label = user_ids[item]
        return torch.tensor(data).clone().detach(), torch.tensor(label).clone().detach().unsqueeze(0)

In [None]:
# indeces = [32771, 6, 16393, 32780, 8206, 8209, 18, 8210, 24598, 22, 24600, 24, 8217, 8220, 8222, 16415, 24609, 24614, 24615, 24624, 8241, 24627, 32820, 16439, 24632, 8249, 24634, 16444, 62, 32831, 16448, 67, 32837, 24647, 16462, 32851, 8276, 32853, 86, 8277, 24664, 83, 32858, 24667, 32860, 32861, 16478, 91, 98, 32871, 32872, 16489, 103, 32875, 32881, 32883, 32885, 16505, 8315, 123, 16515, 16517, 8325, 16520, 24713, 141, 8335, 32914, 24725, 8343, 151, 24730, 32924, 32926, 8352, 32929, 16544, 24741, 32934, 24744, 32936, 32944, 24754, 24755, 16563, 24757, 181, 16565, 16568, 32956, 188, 194, 8388, 16581, 32965, 24775, 24782, 8398, 24784, 209, 24786, 24787, 212, 8405, 16597, 216, 32987, 16605, 16606, 8415, 32990, 32991, 24802, 226, 24806, 24807, 231, 33001, 8422, 8423, 24811, 16621, 24808, 16623, 240, 241, 243, 8436, 251, 33021, 24829, 33023, 8451, 24836, 8457, 24841, 271, 24848, 33042, 8467, 276, 281, 24860, 16670, 16671, 292, 8484, 16679, 8487, 300, 302, 24879, 33072, 8508, 24892, 8511, 24896, 321, 8512, 322, 8516, 324, 24897, 33095, 16711, 24903, 33097, 8522, 8524, 24906, 24910, 335, 8528, 16721, 338, 33110, 345, 350, 33119, 352, 353, 355, 362, 33134, 8560, 16752, 33137, 16755, 33142, 16758, 33145, 16763, 24955, 382, 8577, 387, 24963, 16772, 390, 16776, 33161, 393, 24971, 33164, 33165, 33166, 397, 8594, 24980, 8596, 406, 16788, 8602, 33180, 413, 415, 416, 424, 25001, 427, 8621, 33199, 16819, 16820, 436, 439, 16823, 441, 16833, 25026, 8642, 25027, 33218, 449, 457, 33232, 8659, 25044, 25046, 33238, 33244, 25052, 33245, 487, 16872, 8681, 25066, 8683, 16875, 16886, 33272, 33273, 16891, 16893, 8704, 513, 16899, 25094, 16903, 33293, 33294, 25104, 33300, 16924, 8733, 16928, 545, 16931, 25124, 549, 25127, 8746, 16939, 16942, 16943, 16944, 25134, 25138, 8753, 25135, 16949, 16946, 25140, 33337, 16953, 16955, 570, 33341, 25150, 25153, 25154, 8771, 577, 16966, 583, 8779, 25165, 8783, 8785, 595, 16980, 8793, 25177, 8795, 16988, 16989, 606, 16990, 16994, 16995, 8802, 8806, 33385, 617, 33390, 17008, 17009, 17012, 25205, 17018, 8827, 33404, 33405, 17022, 25212, 17020, 17025, 8834, 25219, 33411, 639, 8839, 649, 25226, 651, 17038, 8850, 17044, 33429, 662, 17046, 33432, 8855, 8858, 8859, 33436, 25245, 666, 8864, 675, 25251, 17065, 8874, 17069, 25262, 25261, 690, 25268, 25270, 695, 25271, 33465, 696, 17085, 25278, 703, 8898, 25286, 17095, 8905, 33483, 17100, 25294, 25295, 33492, 725, 25302, 25305, 17114, 25306, 734, 735, 8928, 17120, 8930, 25312, 25318, 33511, 25319, 8937, 8942, 752, 17137, 8946, 25331, 8944, 25334, 759, 17144, 17146, 17147, 8966, 33545, 17163, 8972, 780, 25361, 33554, 787, 788, 25369, 17178, 17177, 17182, 25378, 25385, 17195, 9004, 9005, 17196, 33589, 821, 828, 829, 9022, 25405, 25407, 834, 33604, 17222, 25414, 33608, 838, 33609, 17230, 33615, 17233, 9044, 17237, 854, 33623, 33627, 17245, 33632, 17249, 9061, 25447, 17258, 875, 9068, 883, 17268, 9080, 33657, 17272, 33658, 25483, 17293, 17296, 33683, 33686, 33688, 9115, 17307, 25500, 9119, 33696, 934, 937, 33706, 9133, 33716, 9142, 33721, 954, 33723, 17341, 17342, 25533, 33726, 961, 9154, 9155, 33732, 33738, 17355, 33742, 25550, 25554, 33748, 33750, 25559, 25562, 25568, 9185, 9188, 17381, 25578, 1003, 17387, 1007, 17395, 25588, 17396, 1011, 17403, 9214, 1025, 1026, 9224, 1041, 25617, 17427, 1044, 25619, 9238, 17431, 25627, 33826, 9250, 17444, 33831, 1067, 9260, 17453, 33839, 9263, 17457, 33841, 25654, 9276, 17469, 9279, 25664, 1089, 17474, 17472, 9285, 9286, 25671, 1096, 17481, 25673, 1108, 9301, 33878, 33880, 17497, 33884, 17506, 9314, 17516, 1132, 25710, 1136, 9331, 17523, 25717, 25715, 9332, 1145, 9341, 9342, 33917, 25728, 25730, 9347, 17543, 33929, 1165, 1167, 33937, 25746, 17555, 17556, 17557, 25745, 33943, 25751, 1178, 25755, 25756, 33947, 17566, 1183, 9377, 25762, 9380, 9381, 9382, 33956, 9384, 9385, 1194, 17579, 9389, 25777, 9393, 17589, 25784, 9401, 1210, 33978, 9408, 33986, 1218, 9413, 33990, 33992, 33994, 1229, 9421, 17615, 9422, 34005, 1239, 25817, 1244, 9439, 1247, 25825, 9443, 34025, 25834, 25837, 17645, 17647, 25840, 25841, 34036, 9461, 25846, 17653, 34038, 1273, 17659, 1278, 9470, 25857, 34050, 9474, 9479, 25863, 25866, 34067, 34070, 17687, 1304, 17692, 9503, 34079, 1312, 1314, 1316, 34084, 34089, 34090, 34091, 9516, 25903, 34098, 25910, 9526, 34102, 1341, 34111, 17728, 9538, 17735, 1352, 1353, 17741, 34126, 25933, 17742, 25934, 34130, 1359, 34132, 1366, 9562, 17761, 17763, 25957, 17771, 17773, 17775, 9584, 25967, 17778, 34163, 25971, 25973, 17782, 1400, 34170, 17787, 34174, 25984, 25985, 34179, 25993, 1418, 34188, 34191, 26000, 25999, 9621, 34200, 34204, 26012, 34206, 26017, 1442, 17827, 1444, 1446, 9638, 17836, 9644, 26030, 9647, 34220, 26028, 26034, 26035, 1460, 26037, 1462, 34231, 34225, 1466, 34234, 1467, 1469, 9663, 17855, 9665, 34242, 1479, 1480, 34249, 34248, 17867, 34254, 34259, 9684, 9685, 34261, 17879, 26074, 34267, 17884, 34269, 26078, 1503, 17893, 9705, 26090, 17902, 9714, 1523, 26100, 17910, 9721, 26106, 17915, 9722, 34304, 17922, 26119, 17928, 34316, 9741, 17934, 26128, 9745, 26132, 26133, 1560, 17945, 26138, 9755, 9753, 1561, 1566, 26143, 26139, 9761, 9762, 26146, 1574, 9768, 26154, 26155, 1580, 26159, 17970, 1589, 17974, 9783, 26170, 9791, 17984, 1603, 17987, 34374, 26183, 17992, 34376, 26185, 34379, 34380, 1613, 34382, 19858, 1616, 26193, 1617, 18004, 26200, 18008, 1624, 9819, 34398, 34400, 18017, 26209, 26216, 34412, 26222, 9840, 26227, 34422, 34423, 34426, 9857, 18051, 34435, 9862, 26249, 34442, 9868, 1678, 34447, 26260, 34452, 1687, 9884, 1694, 9889, 1699, 18084, 1700, 9894, 1704, 9897, 34474, 26283, 9899, 34475, 34479, 18096, 34482, 1717, 1721, 18110, 1727, 1728, 26305, 18114, 18112, 18116, 9925, 9926, 34503, 1741, 26317, 18127, 9939, 26324, 9943, 1759, 1765, 9962, 18155, 26349, 26355, 18165, 26358, 18168, 26361, 18169, 1792, 26371, 34565, 1797, 18183, 1799, 1801, 18184, 34572, 10005, 1814, 26391, 10008, 10009, 34586, 10013, 26399, 10017, 10021, 18213, 10023, 34597, 1834, 18220, 10029, 26415, 34608, 26416, 18227, 10039, 18232, 1854, 26434, 18243, 34628, 18245, 10054, 18244, 26438, 18249, 1868, 10062, 18254, 1872, 34641, 34638, 34639, 10070, 34647, 1878, 34649, 18263, 1888, 26466, 10084, 10085, 18277, 18286, 18287, 34672, 18290, 26488, 1915, 10109, 26495, 26498, 18308, 34693, 1927, 10125, 26510, 34707, 1940, 10133, 18332, 1948, 26526, 18335, 10144, 10143, 34725, 10153, 26539, 10155, 26541, 10156, 34731, 18352, 1963, 1971, 10165, 18358, 34743, 1977, 18361, 18363, 18364, 34749, 34751, 1984, 18369, 1985, 1990, 1992, 26570, 34763, 18378, 26573, 2004, 2005, 10198, 18391, 34772, 26585, 18389, 18393, 10204, 2013, 26588, 18399, 18401, 2018, 34786, 26597, 34791, 18412, 26605, 2028, 18415, 26608, 10223, 2032, 18418, 18417, 34805, 26615, 10233, 26618, 18432, 26624, 34818, 10244, 10246, 18439, 10257, 34835, 18453, 26646, 2077, 2079, 10280, 10282, 26666, 18478, 2094, 10290, 26674, 34866, 2103, 10299, 10303, 26688, 26689, 18499, 10309, 2118, 10316, 2124, 2129, 2132, 2134, 34907, 2140, 26715, 18523, 26721, 2150, 10344, 2154, 18538, 34924, 34928, 2162, 34931, 18548, 10363, 26751, 2177, 2179, 26758, 18572, 18579, 26776, 2201, 2202, 2206, 18591, 34976, 18594, 18600, 10411, 2220, 26797, 34992, 18609, 10418, 18611, 10422, 18618, 35006, 10430, 18626, 26824, 2249, 26825, 18637, 18644, 18647, 10456, 10457, 10458, 18649, 2268, 35038, 35039, 35040, 26849, 10470, 18665, 2282, 18668, 10481, 26869, 35061, 10487, 18680, 10489, 2298, 2299, 18684, 2300, 2302, 18678, 2306, 2307, 2308, 10501, 18694, 26887, 35080, 10507, 26894, 26897, 2325, 18711, 10520, 26904, 35095, 2331, 35100, 2329, 26910, 26903, 26912, 18721, 26923, 35118, 2351, 10544, 18736, 35125, 26934, 18745, 26940, 18750, 2369, 10568, 2377, 18761, 35147, 35148, 10573, 18766, 10575, 18764, 2385, 10581, 35159, 35161, 2393, 18779, 2396, 35165, 26973, 26976, 18785, 2400, 26979, 10596, 2407, 35175, 10603, 2414, 2415, 10607, 18799, 35186, 18804, 27002, 2427, 18813, 2431, 10624, 27010, 10627, 35204, 18820, 2436, 10628, 10634, 10636, 35216, 18832, 18836, 35221, 18838, 10647, 18840, 18841, 35226, 35224, 18845, 18846, 2463, 2464, 10660, 18852, 10665, 35242, 35243, 27057, 2482, 2481, 10679, 35257, 18874, 10683, 27070, 2495, 35266, 18882, 27077, 27078, 27081, 18890, 18889, 2508, 10701, 10700, 27090, 18899, 2516, 27093, 10707, 10708, 10711, 10715, 18908, 2526, 35294, 10718, 18912, 2529, 27102, 2534, 10727, 10734, 27126, 2550, 35320, 27129, 18938, 2552, 27133, 35329, 18946, 18947, 35333, 2567, 2570, 2571, 18956, 18959, 10769, 35348, 2582, 18968, 27161, 18977, 27171, 35365, 10791, 35370, 18990, 10799, 18992, 18995, 2612, 27189, 18997, 27193, 35387, 27196, 2623, 10817, 19010, 35401, 2634, 19019, 2638, 10831, 10833, 10835, 35413, 35414, 35415, 2648, 19033, 10838, 10837, 27226, 10844, 2656, 2657, 19042, 10851, 2663, 27239, 35434, 19051, 10858, 27245, 35438, 10862, 19056, 2669, 27249, 2675, 27252, 2673, 35445, 10864, 35448, 2672, 35452, 35460, 10884, 35463, 35466, 2698, 27276, 10892, 2702, 27275, 27280, 2707, 35476, 19093, 27284, 2711, 10905, 35482, 27292, 10908, 35486, 19100, 35484, 27298, 10916, 35492, 2728, 27307, 10927, 19124, 27316, 27317, 35511, 2749, 35519, 19136, 2753, 10946, 19137, 27329, 10956, 35533, 19150, 10959, 10963, 27352, 10970, 10973, 2782, 35554, 19175, 2793, 2794, 19178, 19179, 2800, 27378, 10995, 2808, 11000, 19192, 11003, 19195, 35579, 35583, 19201, 19204, 35588, 35591, 19207, 11015, 2826, 35593, 35598, 27408, 2834, 19221, 35606, 2838, 19225, 35615, 35616, 35621, 27434, 2861, 35631, 2869, 2870, 35639, 27454, 27460, 27461, 11078, 35657, 19276, 2895, 2896, 19282, 11093, 19289, 35674, 2907, 27482, 11102, 35678, 35679, 11108, 27493, 19303, 2924, 11121, 19314, 11123, 11129, 35706, 35707, 27517, 2943, 2947, 2951, 35721, 11147, 27533, 27535, 35729, 2966, 2970, 2971, 19356, 11164, 27547, 35740, 11169, 35746, 27556, 35754, 27566, 35765, 35766, 11190, 35767, 3001, 11194, 19383, 3004, 35773, 11201, 35777, 3014, 35782, 11208, 11207, 27595, 19404, 19407, 3030, 27608, 27613, 35807, 3042, 35811, 19428, 3046, 11238, 19431, 11239, 11243, 35821, 19438, 27635, 11251, 19445, 35830, 3062, 11260, 27646, 27647, 11265, 19458, 19459, 11266, 3076, 3078, 3079, 19464, 19462, 19465, 11271, 27658, 3085, 11283, 27668, 3093, 11285, 27675, 35868, 27676, 3102, 3103, 27680, 11296, 11299, 3110, 27688, 35885, 35886, 19504, 35891, 19510, 3127, 11320, 11323, 19520, 3137, 3139, 3140, 35908, 11334, 35911, 11336, 19530, 19531, 3147, 3151, 19547, 35931, 35933, 3165, 11360, 19553, 27748, 3174, 27751, 27750, 3179, 11371, 19563, 11374, 27765, 3190, 35960, 35961, 35962, 35963, 3198, 35967, 11392, 19585, 27781, 11405, 19602, 19604, 11420, 3230, 19615, 27806, 3234, 11427, 19618, 36006, 27814, 27816, 11433, 3243, 19628, 27829, 3256, 11450, 19643, 11451, 3261, 36038, 36039, 11464, 19657, 27850, 3272, 27852, 36043, 36046, 19664, 3281, 3286, 11482, 19675, 11484, 19676, 3294, 27871, 19680, 36066, 27875, 11493, 36071, 11496, 19688, 11499, 19692, 27884, 19695, 27887, 36084, 3316, 11510, 36087, 36085, 11514, 27903, 36096, 27905, 3335, 3336, 11527, 27916, 27917, 36111, 27919, 11536, 3344, 3347, 11539, 19733, 3351, 27929, 11545, 36123, 19738, 11549, 11550, 11551, 19742, 3364, 3365, 11563, 36140, 11565, 36143, 19763, 3380, 27957, 19767, 19770, 3387, 3388, 36156, 11581, 3390, 3393, 11589, 27975, 3403, 11596, 27982, 27984, 27987, 27989, 3413, 3415, 19799, 11607, 36186, 3416, 27997, 11614, 27999, 27998, 3429, 19814, 3431, 3432, 11626, 36203, 11627, 28013, 3438, 3439, 36210, 11637, 36214, 11639, 19832, 28026, 28029, 36224, 19843, 19844, 11654, 11655, 3468, 3470, 28047, 19855, 11665, 3474, 11667, 28051, 36245, 3478, 36246, 3473, 11673, 19866, 36251, 19868, 28058, 3486, 28063, 28062, 28064, 11682, 3488, 11684, 19878, 19880, 28073, 36270, 28078, 36276, 36279, 36281, 11706, 3514, 3515, 28097, 11714, 11715, 36292, 3528, 28105, 28109, 3539, 11732, 36311, 28120, 19933, 19936, 28131, 11748, 19942, 19945, 28139, 36336, 28144, 28146, 11767, 36345, 19962, 11771, 19964, 28157, 36350, 36354, 28168, 11785, 19979, 36367, 28176, 11793, 3603, 36371, 28182, 3607, 11800, 36382, 11811, 11812, 20005, 3622, 20007, 20008, 28209, 36405, 20023, 28216, 20025, 36411, 20029, 20037, 20038, 36423, 36427, 20050, 11859, 28244, 28248, 28251, 3675, 28253, 28257, 3682, 36450, 11875, 36453, 11878, 3686, 20067, 11881, 20076, 11886, 20089, 3706, 3710, 36482, 11907, 3714, 28293, 36486, 36487, 28296, 3721, 36484, 20113, 3738, 3740, 36514, 3749, 20134, 20135, 28334, 20143, 28335, 11954, 3763, 36533, 11959, 20152, 3769, 11962, 11966, 11969, 28355, 36548, 20166, 11981, 28367, 36561, 20179, 36566, 36567, 3798, 28386, 36580, 3814, 36583, 20206, 12016, 36592, 3833, 20219, 28414, 36608, 3841, 12033, 20227, 28419, 12040, 3849, 3850, 3852, 3853, 28432, 36627, 12051, 36629, 28438, 12053, 28440, 3865, 20251, 20252, 20253, 20254, 36639, 36640, 36641, 3872, 3876, 36654, 12078, 28464, 36656, 20276, 20279, 28473, 28478, 20286, 28480, 28481, 20292, 12102, 28487, 3913, 28491, 12108, 36686, 28494, 28498, 20307, 3925, 12122, 36699, 36698, 12128, 28514, 20327, 3946, 20333, 12143, 3952, 36721, 12146, 12147, 20345, 28540, 12159, 36735, 28549, 28551, 36743, 36749, 36751, 20370, 20373, 12182, 20374, 36761, 12187, 20382, 12191, 28577, 12194, 4003, 12196, 28580, 20389, 4006, 20388, 36773, 12200, 36774, 12203, 12195, 36782, 4014, 12208, 28593, 28594, 4018, 12212, 28598, 20408, 36795, 36798, 28609, 28612, 20422, 28615, 12231, 4043, 28620, 36811, 4048, 28627, 12243, 20438, 12246, 4056, 28633, 20445, 4062, 36830, 20447, 36833, 20450, 12255, 4066, 28645, 4070, 28646, 20453, 28648, 20458, 4075, 12269, 28654, 28655, 36846, 20465, 12273, 28659, 36845, 4088, 28668, 36865, 28677, 4102, 12294, 4108, 4109, 28688, 4114, 20498, 20500, 28692, 20502, 12312, 36889, 12314, 12319, 36896, 36897, 28707, 28710, 20519, 4137, 4139, 28716, 4141, 12334, 36911, 20528, 20523, 36914, 36917, 4149, 36919, 4152, 28729, 20535, 28727, 28732, 36927, 20545, 28738, 36931, 36932, 28743, 20552, 28747, 4172, 4182, 4183, 12377, 12381, 36957, 4191, 36959, 20576, 12385, 28769, 20580, 12384, 28776, 12394, 28781, 36974, 36983, 20600, 12407, 20605, 28798, 36992, 4225, 12419, 20616, 37001, 4235, 37008, 4244, 4247, 37016, 28826, 20636, 37022, 37025, 28835, 12455, 37033, 28841, 37035, 20654, 20655, 4273, 12467, 12468, 4277, 37045, 12471, 28855, 4280, 20660, 4285, 12480, 28866, 4293, 12489, 20684, 4301, 4302, 12503, 4312, 37079, 20698, 28894, 12510, 12512, 37089, 20707, 4326, 37098, 4334, 37108, 20724, 28918, 20727, 37111, 4347, 28924, 4349, 12544, 28929, 37127, 28936, 37130, 12556, 37132, 20754, 20755, 28949, 12568, 28953, 4391, 28967, 20777, 20775, 37160, 37164, 12583, 12590, 37168, 4402, 20788, 4404, 4407, 20792, 20793, 37177, 4410, 28994, 37186, 12614, 20807, 20806, 4429, 29005, 4431, 12624, 20815, 4432, 20819, 20820, 37205, 29012, 4430, 4441, 12633, 20826, 37214, 20831, 4447, 4449, 37218, 4451, 37216, 37222, 4455, 29032, 20838, 12654, 37233, 37235, 20851, 29047, 4481, 12674, 4487, 37257, 29065, 20882, 29074, 20885, 20887, 20889, 4506, 29083, 37276, 37274, 4510, 20891, 4513, 29090, 20900, 12710, 37287, 12714, 20909, 4528, 12721, 37299, 20921, 37313, 12738, 29124, 4549, 4551, 20936, 29129, 29128, 20940, 20942, 37329, 29139, 29142, 4567, 29144, 12763, 29147, 4572, 20960, 29155, 4580, 20966, 37351, 20969, 20971, 12779, 4589, 29166, 20976, 37365, 20994, 37381, 37383, 29192, 29197, 37391, 12816, 4628, 12821, 12822, 4630, 37403, 21020, 12829, 4637, 37406, 4640, 29220, 21031, 12842, 12846, 12851, 12856, 4664, 12858, 29242, 12863, 4672, 12868, 37447, 21064, 4682, 4685, 12878, 37455, 4688, 37459, 4692, 21075, 12888, 12890, 37467, 12892, 21082, 37469, 21092, 29289, 4714, 37483, 37481, 4718, 29297, 12914, 21109, 4725, 29302, 37498, 37499, 37500, 37502, 37507, 12934, 21128, 21131, 12940, 29325, 21134, 12944, 12945, 37524, 4761, 4762, 21149, 4775, 12968, 37548, 12972, 21166, 4783, 12979, 29366, 21174, 37560, 4793, 21178, 12986, 37565, 4803, 29380, 37573, 12997, 37575, 37581, 37583, 37586, 4819, 29394, 13013, 37593, 37595, 29404, 21213, 13023, 13024, 4831, 37602, 29412, 37605, 29413, 13032, 21224, 21231, 37617, 37618, 37623, 4855, 4857, 4861, 21253, 29446, 4872, 13065, 21260, 4877, 21261, 29458, 37652, 29466, 37661, 37665, 13094, 21288, 29482, 37678, 4914, 37683, 37685, 37686, 37687, 21303, 29494, 29498, 29499, 13114, 21308, 21310, 13118, 4932, 13125, 29511, 29512, 4937, 13130, 21323, 37705, 13140, 4949, 29526, 13143, 37719, 4953, 4955, 4962, 37733, 21351, 29544, 13162, 37739, 21356, 29549, 21357, 4977, 21362, 4981, 13176, 13181, 21376, 13186, 21381, 21382, 21383, 37765, 37769, 29578, 29579, 29581, 13198, 5009, 13202, 37778, 21398, 13206, 21405, 13216, 21408, 37794, 5029, 37798, 13223, 21416, 29609, 21420, 37813, 37814, 13239, 21432, 5047, 21434, 37818, 21436, 21430, 21438, 13247, 37825, 29634, 5061, 5063, 21447, 5065, 37836, 37837, 5070, 5073, 37849, 36831, 13276, 13277, 21468, 37853, 21472, 21473, 5093, 5094, 37863, 5096, 21481, 29674, 5101, 29678, 21489, 29682, 5108, 29685, 29686, 5111, 29687, 37879, 21498, 37883, 37882, 13302, 37878, 5119, 29696, 5116, 13318, 37894, 21512, 37897, 29707, 37900, 5133, 29712, 13329, 13330, 13338, 21531, 13343, 5153, 13348, 13354, 29741, 37933, 5166, 29744, 5169, 21552, 5167, 37941, 37943, 13368, 29752, 5179, 37948, 13371, 37947, 5183, 29760, 5187, 21574, 29767, 5190, 29769, 13385, 29772, 21584, 29777, 37970, 37972, 13399, 13400, 29786, 5211, 21595, 37981, 29792, 21602, 21606, 21608, 37993, 5225, 13416, 29804, 21609, 5224, 21615, 21616, 5233, 5235, 38004, 29816, 13433, 2670, 38011, 13440, 21637, 38024, 38025, 21642, 13451, 21643, 5261, 38029, 5264, 5266, 13458, 5268, 13462, 29847, 13464, 5274, 13469, 13473, 29858, 38049, 21668, 13480, 13482, 21677, 21678, 29869, 29872, 29882, 38075, 21690, 5309, 29885, 38080, 29889, 38082, 21696, 21703, 21704, 38088, 13517, 38094, 21713, 38097, 5336, 29914, 38107, 13538, 5347, 29929, 38122, 29931, 38127, 5360, 21747, 5364, 29941, 13560, 21753, 5374, 29950, 38142, 29953, 38150, 29961, 5389, 13582, 21775, 5396, 13590, 21786, 5402, 38172, 21789, 29981, 5407, 13609, 21801, 5419, 13611, 29997, 21812, 13621, 30006, 13623, 13624, 38201, 21817, 38198, 21821, 30014, 13631, 5437, 5438, 13635, 13637, 38216, 30026, 5459, 38228, 13653, 30039, 30040, 30045, 5469, 21856, 21858, 21859, 13668, 30051, 38246, 21860, 5481, 38251, 38254, 38255, 5488, 30069, 30070, 21877, 38264, 5496, 30075, 5500, 30076, 21883, 13692, 21888, 21886, 30085, 30089, 5514, 38284, 13710, 30095, 21906, 13717, 5530, 21917, 30113, 30114, 30118, 38311, 21930, 38316, 13741, 13743, 5551, 38319, 21939, 38328, 13753, 30138, 13755, 38331, 21949, 13757, 21951, 38338, 30151, 30157, 21966, 5583, 5584, 21969, 38351, 5589, 30165, 38360, 5593, 30170, 21976, 30172, 13790, 5599, 38368, 30178, 5603, 30181, 5605, 30183, 5610, 38379, 13804, 30188, 30192, 22001, 5617, 22003, 22004, 30196, 5627, 22012, 22017, 5639, 38407, 38408, 5642, 5646, 22035, 13845, 38422, 13847, 30235, 38428, 5661, 30238, 5663, 38436, 13861, 30250, 22059, 22061, 38448, 38452, 30261, 30263, 22076, 38461, 30269, 30273, 5706, 5709, 5710, 38478, 22093, 38483, 30293, 38486, 13911, 5722, 38495, 5728, 38497, 30303, 13924, 30315, 5739, 30317, 5741, 13937, 22129, 30325, 30333, 38527, 22143, 30337, 22147, 38531, 30340, 22152, 38539, 13964, 13967, 30351, 38546, 22162, 13970, 30359, 22169, 13977, 30371, 5797, 38569, 5805, 13999, 14001, 30386, 14003, 5809, 22200, 30394, 5819, 14012, 22204, 22207, 22208, 22215, 14033, 38610, 22227, 30418, 30421, 38618, 22236, 38620, 14046, 38624, 22245, 14053, 14055, 30440, 22250, 30443, 14061, 22253, 14064, 22259, 30454, 22264, 30457, 22266, 30460, 5888, 30466, 30467, 30469, 14086, 22289, 22290, 30481, 14100, 22298, 5916, 30492, 30494, 5919, 14112, 30499, 30500, 14117, 22310, 30507, 14124, 22322, 5938, 22325, 22331, 30524, 30526, 22335, 30528, 5954, 5955, 5957, 5959, 22345, 5962, 22349, 5966, 22351, 22358, 22364, 22365, 22366, 22370, 22371, 5988, 30564, 30568, 5993, 22376, 14188, 22381, 30574, 14191, 30575, 30578, 22388, 22395, 6012, 6011, 14206, 14211, 14216, 30600, 30604, 30609, 30610, 14229, 6038, 14232, 14233, 22426, 14235, 14234, 22432, 6052, 22437, 14251, 14255, 14257, 14258, 30641, 14261, 14262, 22455, 30647, 22454, 14276, 14277, 22470, 30662, 6088, 22469, 22468, 6085, 6089, 6093, 14288, 30678, 14295, 30680, 6107, 14304, 14310, 30695, 30699, 22511, 6130, 6135, 22520, 22522, 14335, 30724, 22539, 6155, 14348, 22542, 6158, 22545, 22547, 22548, 14358, 6175, 14367, 14369, 30753, 22567, 30759, 14377, 6185, 30763, 30764, 14382, 22574, 22576, 30768, 6190, 14387, 22581, 14390, 22584, 30777, 14396, 14397, 22593, 22594, 6212, 14409, 6217, 6218, 22608, 6224, 14419, 30804, 6230, 22619, 30812, 14429, 22620, 30820, 6248, 22634, 30828, 6253, 6257, 22644, 22651, 14460, 6269, 22654, 30848, 6274, 22659, 14468, 14474, 30859, 30858, 30861, 22672, 6289, 30871, 30876, 22685, 14497, 14499, 22695, 22697, 30890, 22701, 6323, 14518, 30907, 22715, 6333, 6338, 14530, 30916, 22729, 14537, 22733, 14544, 6354, 30931, 30933, 14557, 14558, 30942, 14563, 30948, 30951, 14569, 14570, 22762, 6378, 30953, 14574, 22772, 30968, 22776, 30972, 22782, 22783, 14592, 22785, 14599, 14604, 6415, 30995, 30997, 14614, 6425, 6427, 31004, 31005, 22812, 31007, 22816, 6433, 6436, 31013, 14632, 6441, 31018, 22832, 14641, 14649, 31035, 31036, 14652, 31040, 22849, 14659, 14660, 6471, 6472, 31050, 14666, 22865, 31057, 14674, 22873, 6492, 22880, 6498, 22884, 14693, 6502, 22890, 31085, 14710, 14711, 31096, 31099, 31100, 22912, 14721, 31110, 6541, 22930, 31125, 14742, 31127, 22933, 22938, 14747, 6557, 22952, 31147, 6573, 31153, 22968, 31161, 14778, 31162, 6588, 6585, 22977, 31169, 6595, 14788, 14797, 14800, 14802, 22995, 6615, 14808, 23003, 6619, 23005, 14821, 31206, 6631, 31213, 23022, 6643, 6644, 6649, 31226, 6654, 31232, 14849, 6662, 6663, 23049, 6666, 14861, 6671, 23057, 23058, 6673, 14868, 23064, 31257, 23070, 14880, 6692, 31268, 14886, 23079, 23082, 14892, 23086, 31282, 14900, 31286, 6711, 6712, 14904, 23099, 23105, 6722, 23107, 6723, 31301, 6730, 23120, 6737, 31313, 23125, 23128, 6744, 14936, 14942, 31335, 23144, 14953, 14954, 23143, 31340, 14956, 6765, 14959, 23153, 31346, 6771, 14964, 6770, 31347, 6775, 6776, 23159, 14970, 14976, 14977, 23169, 14982, 14986, 31371, 31372, 23182, 14990, 6804, 31383, 23192, 15002, 15008, 6817, 15015, 6823, 31401, 6831, 23216, 31410, 15027, 23220, 6838, 15034, 6848, 23233, 23236, 31429, 23242, 15050, 6860, 15055, 6870, 15063, 31451, 31452, 15073, 23266, 23267, 31457, 6889, 31469, 31470, 6895, 23284, 31480, 23288, 15098, 15102, 6912, 31490, 6917, 23301, 15111, 31496, 31498, 31499, 31500, 6925, 6928, 23313, 23315, 31508, 31507, 15126, 31516, 23326, 31521, 6946, 6947, 6948, 31524, 6952, 31530, 31535, 31538, 31539, 6964, 6965, 31545, 6970, 6974, 23359, 15170, 6981, 15177, 15180, 23373, 31567, 15187, 31574, 6999, 23389, 7006, 31583, 7007, 23391, 23396, 7020, 7025, 15217, 23414, 15229, 31617, 15235, 15237, 23430, 31623, 31621, 23434, 7053, 15249, 23444, 15255, 23447, 7066, 31642, 15260, 23452, 7068, 23462, 15274, 7083, 23470, 7088, 7091, 15290, 23483, 7100, 7101, 15292, 15295, 7104, 23489, 15298, 31684, 23493, 23494, 15302, 31689, 31692, 15311, 15314, 31699, 31702, 31706, 15326, 23519, 31712, 23521, 7143, 31721, 15339, 7148, 31725, 7149, 15346, 23544, 7163, 23549, 15358, 7166, 7168, 23550, 23554, 23556, 15365, 23557, 31751, 31755, 31756, 23565, 31759, 7188, 23573, 31766, 15382, 31770, 23578, 31771, 15389, 15397, 31781, 15399, 7210, 23598, 15407, 31792, 15411, 7222, 15414, 23608, 23609, 23607, 15423, 31811, 7238, 7240, 31816, 23628, 7249, 23634, 23635, 7251, 31832, 7256, 15458, 31843, 23653, 31846, 15464, 7273, 23664, 31857, 7282, 23674, 23679, 15489, 7302, 7305, 15497, 7307, 15500, 7311, 31888, 23699, 31892, 15509, 15511, 15513, 23708, 23709, 31903, 23711, 23714, 31906, 23715, 7332, 31908, 31909, 7341, 23728, 7346, 7347, 23732, 23731, 31926, 15539, 7362, 23748, 15557, 7365, 23753, 31947, 15565, 15566, 23758, 15573, 15574, 31957, 23765, 15579, 31964, 15586, 31970, 15589, 7397, 7402, 7405, 23790, 23796, 7413, 7414, 7416, 31995, 15613, 32006, 23816, 23817, 15629, 15630, 23822, 23824, 7441, 15634, 32013, 23828, 23829, 23825, 32028, 23838, 23839, 7456, 15646, 7458, 32032, 15652, 32036, 7465, 15658, 32043, 32044, 15662, 7471, 23860, 7477, 7478, 32053, 23871, 32064, 23878, 23881, 7501, 23886, 23885, 7503, 32086, 7510, 32088, 15705, 7518, 15710, 7519, 32099, 32105, 23915, 15725, 23919, 32114, 32116, 7543, 7546, 32124, 32125, 7551, 15744, 32134, 23942, 32135, 7562, 32139, 15754, 7570, 15762, 7574, 15768, 23961, 7577, 23964, 32157, 7583, 32166, 23982, 32176, 15793, 32183, 15802, 32186, 15805, 23998, 32194, 7622, 24012, 32205, 15821, 7631, 32209, 15830, 15832, 32218, 32220, 7647, 24034, 32227, 15843, 24049, 7670, 7672, 15866, 24059, 7676, 24061, 15868, 15872, 32258, 7683, 15884, 24079, 7698, 32275, 15890, 7701, 24086, 7703, 32277, 24089, 15898, 24091, 32285, 24095, 24098, 15907, 32294, 32295, 7721, 32298, 24117, 32310, 24121, 7743, 7745, 24137, 15945, 7755, 32331, 7757, 32334, 7764, 15962, 7771, 15966, 7776, 32352, 15971, 15976, 24170, 7789, 7791, 15989, 32379, 32381, 7806, 32391, 7817, 17837, 24210, 24211, 7838, 24222, 16032, 7841, 7842, 16039, 7849, 24234, 32427, 24241, 16052, 24245, 32440, 16057, 16067, 7879, 16081, 24273, 16085, 32473, 24282, 16092, 24286, 7903, 24295, 7920, 32501, 32514, 24322, 16134, 16135, 7944, 24330, 32524, 32528, 24337, 16144, 32529, 7958, 7969, 16163, 24356, 32547, 7975, 24361, 32555, 7982, 16175, 16174, 16181, 7994, 7996, 32574, 7999, 24384, 8001, 32578, 8002, 24388, 24386, 8004, 16200, 8009, 16202, 16206, 8015, 32595, 32598, 32600, 32601, 24409, 8028, 8035, 24423, 8046, 8047, 16238, 16244, 32629, 16245, 8056, 16248, 8059, 24454, 32649, 16265, 32651, 24461, 32664, 24473, 16281, 32669, 16286, 24484, 24488, 8107, 32686, 16304, 8115, 24499, 8119, 16313, 8122, 16317, 32703, 16319, 8127, 8130, 24517, 8135, 8143, 24537, 32729, 24540, 32737, 8163, 16356, 8166, 32744, 24556, 8173, 8172, 16367, 8177, 16375, 8184, 8185, 16381, 24575]
indeces = [32771, 6, 16393, 32780, 8206, 8209, 18, 8210, 24598, 22, 24600, 24, 8217, 8220, 8222, 16415, 24609, 24614, 24615, 24624, 8241, 24627, 32820, 16439, 24632, 8249, 24634, 16444, 62, 32831, 16448, 67, 32837, 24647, 16462, 32851, 8276, 32853, 86, 8277, 24664, 83, 32858, 24667, 32860, 32861, 16478, 91, 98, 32871, 32872, 16489, 103, 32875, 32881, 32883, 32885, 16505, 8315, 123, 16515, 16517, 8325, 16520, 24713, 141, 8335, 32914, 24725, 8343, 151, 24730, 32924, 32926, 8352, 32929, 16544, 24741, 32934, 24744, 32936, 32944, 24754, 24755, 16563, 24757, 181, 16565, 16568, 32956, 188, 194, 8388, 16581, 32965, 24775, 24782, 8398, 24784, 209, 24786, 24787, 212, 8405, 16597, 216, 32987, 16605, 16606, 8415, 32990, 32991, 24802, 226, 24806, 24807, 231, 33001, 8422, 8423, 24811, 16621, 24808, 16623, 240, 241, 243, 8436, 251, 33021, 24829, 33023, 8451, 24836, 8457, 24841, 271, 24848, 33042, 8467, 276, 281, 24860, 16670, 16671, 292, 8484, 16679, 8487, 300, 302, 24879, 33072, 8508, 24892, 8511, 24896, 321, 8512, 322, 8516, 324, 24897, 33095, 16711, 24903, 33097, 8522, 8524, 24906, 24910, 335, 8528, 16721, 338, 33110]
# print(list(dataset_T[5]))
# client['set'] = DataLoader(FedDataset(train_dataset_t, indeces), batch_size=64, shuffle=True)
user_list = list(sorted(set(datasetUser[5])))
# print(user_list.index('00qhiaz'))
users = list(datasetUser[5])
# print(user_list)
user_ids = [user_list.index(user) for user in users]
print(user_ids)
# for data in enumerate(client['set']):
#     print(data)

In [None]:
datasetData = dataset_T.iloc[:,[0,1,2,3,4]]
datasetUser = dataset_T.iloc[:,[5]]
indx = [int(i) for i in indeces]
users = list(datasetUser[5])
user_list = list(sorted(set(datasetUser[5])))
data_transposed = datasetData.T
data = data_transposed[indx[0]]
type(data)


In [None]:
train_loader = DataLoader(FedDataset(dataset_T, indeces), batch_size=64, shuffle=True)
fed = FedDataset(dataset_T, indeces)
print(fed[0][0].shape)
# for data, target in train_loader:
#    print(data, target)

In [9]:
def mnistIID(dataIID, num_users, ind):
    images = int(len(dataIID)/num_users)
    # [i for i in range(len(dataIID))]
    users_dict, indeces = {}, ind
    for i in range(num_users):
        np.random.seed(i)
        users_dict[i] = set(np.random.choice(indeces, images, replace=False))
        indeces = list(set(indeces) - users_dict[i])
    return users_dict

In [10]:
def getActualImgs(datasetImg, indeces, batch_size):
    return DataLoader(FedDataset(datasetImg, indeces), batch_size=batch_size, shuffle=True)

In [11]:
# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_dataset = dataset_T.sample(frac=0.8, random_state=25)
test_dataset = dataset_T.drop(train_dataset.index)

train_dataset_t = train_dataset.T
test_dataset_t = test_dataset.T

In [None]:
print(dataset_T)

In [12]:
index_train = []
for col in train_dataset_t.columns:
    index_train.append(col)

index_test = []
for col in test_dataset_t.columns:
    index_test.append(col)

In [13]:
global_group = mnistIID(dataset_T, 10, [i for i in range(len(dataset_T))])
train_group = mnistIID(train_dataset, 10, index_train)
test_group = mnistIID(test_dataset, 10, index_test)

In [29]:
print(train_dataset.iloc[:,[0,1,2,3,4]])

              0         1         2         3         4
14397   0.01355  0.931519  0.060913  0.933606  0.933508
5634   0.167947  0.964855  0.009789  0.979412  0.964905
13880  0.028198   0.91687  0.215698  0.942322    0.9419
12346  0.066384   0.53117  0.838921  0.995156   0.99294
3087  -0.043458  -0.05298  0.993194  0.995555  0.994606
...         ...       ...       ...       ...       ...
13977  0.004272  0.911011  0.216187   0.93632   0.93631
16562 -0.269971  0.922315 -0.142272  0.971489  0.933223
17977 -0.194212  0.713387  0.676226  1.001959  0.982957
7106   0.062256  0.819702  0.576294  1.003943  1.002011
2051   0.086914  0.755859  0.662109  1.008596  1.004844

[17958 rows x 5 columns]


In [14]:
for inx, client in enumerate(clients):
    trainset_ind_list = list(train_group[inx])
    client['trainset'] = getActualImgs(train_dataset, trainset_ind_list, args.local_batches)
    client['testset'] = getActualImgs(test_dataset, list(test_group[inx]), args.local_batches)
    client['samples'] = len(trainset_ind_list) / args.images

In [34]:
x_data = train_dataset.iloc[:,[0,1,2,3,4]]
data_array = np.asarray(x_data)
x_data = data_array
print(x_data[0, :])

[0.013549805618822575 0.9315185546875 0.0609130859375 0.9336063511983121
 0.9335080191222346]


In [21]:
X = None
x_data = train_dataset.iloc[:,[0,1,2,3,4]]
data_array = np.asarray(x_data)
x_data = data_array
for i in range(len(x_data)):
    row = np.asarray(x_data[i, :])
    row = row.reshape(5, 1).T
    if X is None:
        X = np.zeros((len(x_data), 1, 5))
    X[i] = row
print(X.shape)


(17958, 1, 5)


In [15]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [16]:
def ClientUpdate(args, device, client):
    client['model'].train()
    client['model'].send(client['hook'])
    
    for epoch in range(1, args.epochs + 1):
        for batch_idx, (data, target) in enumerate(client['trainset']):
            data = data.send(client['hook'])
            target = target.send(client['hook'])
            
            data, target = data.to(device), target.to(device)
            client['optim'].zero_grad()
            output = client['model'](data)
            loss = F.nll_loss(output, target)
            loss.backward()
            client['optim'].step()
            
            if batch_idx % args.log_interval == 0:
                loss = loss.get() 
                print('Model {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    client['hook'].id,
                    epoch, batch_idx * args.local_batches, len(client['trainset']) * args.local_batches, 
                    100. * batch_idx / len(client['trainset']), loss))
                
    client['model'].get()

In [17]:
def test(args, model, device, test_loader, name):
    model.eval()   
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss for {} model: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        name, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [18]:
def averageModels(global_model, clients):
    client_models = [clients[i]['model'] for i in range(len(clients))]
    samples = [clients[i]['samples'] for i in range(len(clients))]
    global_dict = global_model.state_dict()
    
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() * samples[i] for i in range(len(client_models))], 0).sum(0)
            
    global_model.load_state_dict(global_dict)
    return global_model

In [20]:
torch.manual_seed(args.torch_seed)
global_model = Net()

In [19]:
torch.manual_seed(args.torch_seed)
global_model = Net()

for client in clients:
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = optim.SGD(client['model'].parameters(), lr=args.lr)

for fed_round in range(args.rounds):
    
#     uncomment if you want a randome fraction for C every round
#     args.C = float(format(np.random.random(), '.1f'))
    
    # number of selected clients
    m = int(max(args.C * args.clients, 1))

    # Selected devices
    np.random.seed(fed_round)
    selected_clients_inds = np.random.choice(range(len(clients)), m, replace=False)
    selected_clients = [clients[i] for i in selected_clients_inds]
    
    # Active devices
    np.random.seed(fed_round)
    active_clients_inds = np.random.choice(selected_clients_inds, int((1-args.drop_rate) * m), replace=False)
    active_clients = [clients[i] for i in active_clients_inds]
    
    # Training 
    for client in active_clients:
        ClientUpdate(args, device, client)
    
#     # Testing 
#     for client in active_clients:
#         test(args, client['model'], device, client['testset'], client['hook'].id)
    
    # Averaging 
    global_model = averageModels(global_model, active_clients)
    
    # Testing the average model
    test(args, global_model, device, global_test_loader, 'Global')
            
    # Share the global model with the clients
    for client in clients:
        client['model'].load_state_dict(global_model.state_dict())

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [20, 1, 5, 5], but got 2-dimensional input of size [64, 5] instead

In [None]:
trainset_ind_list = list(train_group[inx])
set = FedDataset(train_dataset_t, trainset_ind_list)
train_loader = DataLoader(dataset=set, batch_size=64, shuffle=True)
for data, target in train_loader:
    print(data, target)

In [23]:
for batch_idx, (data, target) in enumerate(client['testset']):
    print(data.shape)

torch.Size([64, 5])
torch.Size([64, 5])
torch.Size([64, 5])


KeyboardInterrupt: 