# Feature Engineering with SHAP values Experiment 2

## Google Colab

In [None]:
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)

import sys
sys.path.append('/content/drive/My Drive/Colab Notebooks')
sys.path.append('/content/drive/My Drive/Colab Notebooks/federated_learning')

!pip install shap==0.40.0

In [None]:
import sklearn

## Experimental Setup

In [25]:
from federated_learning.utils import SHAPUtil, experiment_util, Visualizer
from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from datetime import datetime

In [None]:
def cos_similarity_values(s_client, s_server):
    """Print the per-image sums of the client-minus-server SHAP differences.

    Despite its name, this variant computes no cosine similarity: every image
    of ``s_client - s_server`` is collapsed to the sum of its entries, rounded
    to three decimals. The resulting table has one row per input row and one
    column per image.

    NOTE: later cells in this notebook redefine ``cos_similarity_values``; the
    most recently executed definition is the one in effect.

    Args:
        s_client: Array-like indexed as ``[row][image]``; must be compatible
            with ``s_server`` for ``np.subtract``.
        s_server: Array-like with the same layout as ``s_client``.

    Returns:
        None. The table is printed as a ``numpy.matrix``.
    """
    import numpy as np

    shap_subtract = np.subtract(s_client, s_server)
    # Size the table from the data instead of assuming exactly 10 rows, which
    # raised IndexError for more rows and left ragged empty rows for fewer.
    summed_differences = [
        [round(np.sum(image.flatten()), 3) for image in row]
        for row in shap_subtract
    ]

    print(np.matrix(summed_differences))

In [None]:
from scipy import spatial
import numpy

In [None]:
# `from scipy import spatial` does not bind the name `scipy`, so import the
# package itself before reading its version.
import scipy

scipy.__version__


In [None]:
# Display the installed NumPy version (relies on the `import numpy` cell above).
numpy.__version__

In [11]:
def cos_similarity_values(s_client, s_server):
    """Cosine distance between the client's and server's per-image SHAP sums.

    Each image is collapsed to the sum of its entries. The sums are laid out
    row by row into one vector per model, and the cosine distance between the
    two vectors is returned.

    NOTE: other cells in this notebook define ``cos_similarity_values`` as
    well; the most recently executed definition is the one in effect.

    Args:
        s_client: Sequence/array indexed as ``[row][image]`` whose images
            provide ``.flatten()``.
        s_server: Same layout as ``s_client``; must contain the same number of
            images overall.

    Returns:
        The value of ``scipy.spatial.distance.cosine`` for the two sum vectors.
    """
    from scipy import spatial
    import numpy as np

    def _image_sums(shap_values):
        # One scalar per image, in row-major order. The length follows the
        # input rather than a fixed 10 rows.
        return np.array(
            [np.sum(image.flatten()) for row in shap_values for image in row]
        )

    # The original also computed an unused difference array and evaluated the
    # distance twice, discarding the first result; both were dead work.
    return spatial.distance.cosine(_image_sums(s_server), _image_sums(s_client))

In [None]:
def cos_similarity_values(s_client, s_server):
    """Sum of the per-image cosine distances between client and server SHAP values.

    Every client image is flattened and compared with the server image at the
    same ``[row][image]`` position using ``scipy.spatial.distance.cosine``;
    the distances are then added up.

    NOTE: other cells in this notebook define ``cos_similarity_values`` as
    well; the most recently executed definition is the one in effect.

    Args:
        s_client: Sequence/array indexed as ``[row][image]`` whose images
            provide ``.flatten()``.
        s_server: Same layout as ``s_client``.

    Returns:
        The sum of all per-image cosine distances.
    """
    from scipy import spatial
    import numpy as np

    # The table is sized from the input instead of a fixed 10 rows; the unused
    # difference array and the unused `similarity_sum` list were dropped.
    cos_distances = [
        [
            spatial.distance.cosine(
                image.flatten(), s_server[row_idx][img_idx].flatten()
            )
            for img_idx, image in enumerate(row)
        ]
        for row_idx, row in enumerate(s_client)
    ]

    return np.sum(cos_distances)
    


In [133]:
def cos_similarity_values(s_client, s_server):
    """Compare client and server SHAP values image by image.

    Two tables indexed ``[row][image]`` are built:

    * the cosine distance between the flattened client and server image, and
    * the sum of the entries of ``client image - server image``.

    NOTE: other cells in this notebook define ``cos_similarity_values`` as
    well; the most recently executed definition is the one in effect.

    Args:
        s_client: Array-like indexed as ``[row][image]`` whose images provide
            ``.flatten()``; must be compatible with ``s_server`` for
            ``np.subtract``.
        s_server: Same layout as ``s_client``.

    Returns:
        A tuple ``(total_distance, diag_diff)`` where ``total_distance`` is the
        sum of all cosine distances and ``diag_diff`` is the entry on the
        diagonal (``row == image``) of the difference-sum table that has the
        largest absolute value (sign preserved).
    """
    from scipy import spatial
    import numpy as np

    shap_subtract = np.subtract(s_client, s_server)
    # Tables grow with the input instead of assuming exactly 10 rows.
    cos_distances = []
    difference_sums = []
    for row_idx, row in enumerate(s_client):
        row_distances = []
        row_differences = []
        for img_idx, image in enumerate(row):
            row_distances.append(
                spatial.distance.cosine(
                    image.flatten(), s_server[row_idx][img_idx].flatten()
                )
            )
            row_differences.append(np.sum(shap_subtract[row_idx][img_idx].flatten()))
        cos_distances.append(row_distances)
        difference_sums.append(row_differences)

    # Build the diagonal once and reuse it for both the argmax and the lookup.
    diagonal = np.array(difference_sums).diagonal()
    return np.sum(cos_distances), diagonal[np.argmax(np.abs(diagonal))]
    


In [111]:
def cos_similarity_values(s_client, s_server):
    """Compare client and server SHAP values and weight distances by differences.

    Two tables indexed ``[row][image]`` are built: the cosine distance between
    the flattened client and server image, and the sum of the entries of
    ``client image - server image``. Only the diagonals (``row == image``) of
    both tables enter the second return value.

    As a side effect, the product of the two diagonal entries at the position
    of the largest diagonal cosine distance is printed.

    NOTE: other cells in this notebook define ``cos_similarity_values`` as
    well; the most recently executed definition is the one in effect.

    Args:
        s_client: Array-like indexed as ``[row][image]`` whose images provide
            ``.flatten()``; must be compatible with ``s_server`` for
            ``np.subtract``.
        s_server: Same layout as ``s_client``.

    Returns:
        A tuple ``(total_distance, weighted)`` where ``total_distance`` is the
        sum of all cosine distances and ``weighted`` is the dot product of the
        cosine-distance diagonal with the difference-sum diagonal.
    """
    from scipy import spatial
    import numpy as np

    shap_subtract = np.subtract(s_client, s_server)
    # Tables grow with the input instead of assuming exactly 10 rows.
    cos_distances = []
    difference_sums = []
    for row_idx, row in enumerate(s_client):
        row_distances = []
        row_differences = []
        for img_idx, image in enumerate(row):
            row_distances.append(
                spatial.distance.cosine(
                    image.flatten(), s_server[row_idx][img_idx].flatten()
                )
            )
            row_differences.append(np.sum(shap_subtract[row_idx][img_idx].flatten()))
        cos_distances.append(row_distances)
        difference_sums.append(row_differences)

    # Build each diagonal once instead of re-creating the arrays per use.
    distance_diagonal = np.array(cos_distances).diagonal()
    difference_diagonal = np.array(difference_sums).diagonal()
    argmax = np.argmax(distance_diagonal)
    print(distance_diagonal[argmax] * difference_diagonal[argmax])
    return np.sum(cos_distances), distance_diagonal.dot(difference_diagonal)
    


# MNIST

In [27]:
from federated_learning.nets import MNISTCNN
from federated_learning.dataset import MNISTDataset
import os
# Run configuration: MNIST data set with the matching CNN.
config = Configuration()
config.DATASET = MNISTDataset
config.NETWORK = MNISTCNN
config.MODELNAME = config.MNIST_NAME
# No poisoned clients to begin with; the experiment cell raises this later.
config.POISONED_CLIENTS = 0
config.DATA_POISONING_PERCENTAGE = 1

# Observer (logging/monitoring) configuration for this experiment.
observer_config = ObserverConfiguration()
observer_config.experiment_type = "shap_fl_poisoned"
observer_config.experiment_id = 1
observer_config.test = False
observer_config.datasetObserverConfiguration = "MNIST"

neutral_label = 2

In [None]:
# Google Colab settings: working directory on Drive, data sets on the VM disk.
config.TEMP = '/content/drive/My Drive/Colab Notebooks/temp'
config.FMNIST_DATASET_PATH = '/content/data/fmnist'
config.MNIST_DATASET_PATH = '/content/data/mnist'
config.CIFAR10_DATASET_PATH = '/content/data/cifar10'
config.VM_URL = "none"

In [28]:
# Load the data set and build the objects shared by the following cells.
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
server = Server(config, observer_config, data.train_dataloader, data.test_dataloader, shap_util)
# The training loop in the next code cell uses `client_plane`; create it here
# so the notebook also runs top to bottom on a fresh kernel.
client_plane = ClientPlane(config, observer_config, data, shap_util)
visualizer = Visualizer(shap_util)

MNIST training data loaded.
MNIST test data loaded.


## Train and Checkpoint the Global Model

In [129]:
import numpy as np
import copy
import torch
import os
# Rounds (1-based) at which the server model is written to disk.
CHECKPOINT_ROUNDS = (2, 5, 10, 75, 100, 200)

# Run 200 federated rounds, checkpointing the server model along the way.
for i in range(200):
    round_number = i + 1
    if round_number in CHECKPOINT_ROUNDS:
        file = "./temp/models/ex6/MNIST_round_{}.model".format(round_number)
        # exist_ok avoids the separate exists() check and its race window.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        # NOTE(review): the checkpoint is written before this round runs, so
        # the file named round_N holds the weights produced by rounds 1..N-1.
        # Confirm this is the intended meaning of the file name.
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, round_number)
    experiment_util.run_round(client_plane, server, round_number)


KeyboardInterrupt



## Experiment

In [140]:
import torch
# Label-flipping attack under study: samples of FROM_LABEL are relabelled TO_LABEL.
config.FROM_LABEL = 5
config.TO_LABEL = 4
shap_images = [config.FROM_LABEL, config.TO_LABEL]
for j in [100]:
    # Fresh data and clients, server restored from the round-j checkpoint.
    data = config.DATASET(config)
    client_plane = ClientPlane(config, observer_config, data, shap_util)
    model_file = file = "./temp/models/ex5/MNIST_round_{}.model".format(j)
    server.net = MNISTCNN()
    server.net.load_state_dict(torch.load(model_file))
    client_plane.reset_default_client_nets()
    client_plane.reset_poisoning_attack()

    # Baseline metrics and SHAP values of the unpoisoned server model.
    server.test()
    recall, precision, accuracy = server.analize_test()
    print("Original", recall, precision, accuracy)
    server_shap = server.get_shap_values()

    config.POISONED_CLIENTS = 100
    experiment_util.update_configs(client_plane, server, config, observer_config)
    # Number of FROM_LABEL samples before poisoning (was a hard-coded 5).
    targets = client_plane.clients[0].train_dataloader.dataset.dataset.targets
    print(len(targets[targets == config.FROM_LABEL]))

    client_plane.poison_clients()
    clean_clients = experiment_util.select_random_clean(client_plane, config, 100)
    poisoned_clients = experiment_util.select_poisoned(client_plane, 100)
    clean_distance = []
    poisoned_distance = []
    clean_diff = []
    poisoned_diff = []
    print("Clean")
    for idx, i in enumerate(clean_clients[:100]):
        # Every client starts from the server weights before its local training.
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        clean_client_shap = client_plane.clients[i].get_shap_values()
        distance, diag_diff = cos_similarity_values(clean_client_shap, server_shap)
        clean_distance.append(distance)
        clean_diff.append(diag_diff)
        # `idx+1%100` parsed as `idx + (1 % 100)` and was never 0, so this
        # progress report never ran. Print the latest (up to 50) distances.
        if (idx + 1) % 100 == 0:
            print(clean_distance[-50:])

    print("Poisoned")
    # Restore the checkpoint again before evaluating the poisoned clients.
    server.net = MNISTCNN()
    server.net.load_state_dict(torch.load(model_file))
    for idx, i in enumerate(poisoned_clients[:100]):
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        poisoned_client_shap = client_plane.clients[i].get_shap_values()
        distance, diag_diff = cos_similarity_values(poisoned_client_shap, server_shap)
        poisoned_distance.append(distance)
        poisoned_diff.append(diag_diff)
        # Same precedence fix as in the clean loop above.
        if (idx + 1) % 50 == 0:
            print(poisoned_distance[-50:])
    # Number of FROM_LABEL samples seen by the first poisoned client.
    poisoned_targets = client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets
    print(len(poisoned_targets[poisoned_targets == config.FROM_LABEL]))
    client_plane.reset_default_client_nets()
    client_plane.reset_poisoning_attack()

MNIST training data loaded.
MNIST test data loaded.
Create 200 clients with dataset of size 300
Load default model successfully
20/200 clients cleaned
40/200 clients cleaned
60/200 clients cleaned
80/200 clients cleaned
100/200 clients cleaned
120/200 clients cleaned
140/200 clients cleaned
160/200 clients cleaned
180/200 clients cleaned
200/200 clients cleaned
Cleaning successfully

Test set: Average loss: 0.0001, Accuracy: 9645/10000 (96%)

Original tensor([0.9929, 0.9833, 0.9554, 0.9436, 0.9674, 0.9664, 0.9760, 0.9582, 0.9415,
        0.9594]) tensor([0.9567, 0.9876, 0.9527, 0.9744, 0.9754, 0.9610, 0.9709, 0.9535, 0.9704,
        0.9416]) 0.9645
5421
Poison 100/200 clients
Flip 100.0% of the 5 labels to 4
[ 33  81 137 195 147 128  82  89  75  30  58  99 104 132  19 153 166 182
  17  64 140 108 129 194 133  62  91  12 191  97 124  61  60  73 145  95
  32 109 181  11  86  67 173  35 139 148 127  21  41 149 193 163  34 184
 159  50 141   7 198 138 168 179  76  70  71  77   0  79 155  1

In [141]:
# Summed cosine distances collected for the clean clients in the experiment cell.
print(clean_distance)

[18.33311352494669, 23.201678343879784, 16.184654211062686, 15.753102699421635, 17.237769835446684, 16.41960548315498, 14.297770719013418, 18.082085301588556, 18.165814276865433, 19.333698369169255, 14.290950886508444, 16.616772819004872, 16.62331876906565, 18.488916534262152, 16.237408904996627, 16.31826953899634, 12.971845955789176, 18.657121515868454, 17.752374864399496, 16.157832538230505, 19.614810299816487, 14.056635626478528, 19.204473060330663, 17.84103346477388, 17.46101353847109, 19.125957146284346, 18.436619259397887, 13.925163162366982, 17.421578389406687, 14.910086208149771, 16.32204796829452, 26.46779754019328, 18.11025503962577, 22.1195285426384, 14.707992002033606, 17.03319440257377, 13.651158673716207, 16.453761864108053, 13.887323484648766, 17.33445269838168, 15.999753891134795, 17.554021593557373, 18.404467291406576, 15.930888234851805, 18.461599735481517, 15.464536751911217, 11.433566015365667, 14.138210196727838, 14.089079386558211, 17.619592747772888, 17.782902130

In [142]:
# Second value returned by cos_similarity_values for each clean client.
print(clean_diff)

[-0.1486991124824284, -0.09459700397036475, -0.25965698778079727, -0.12642251307241992, -0.09902670330619046, -0.13587759386584963, 0.09182233302874598, 0.1290072166706988, -0.2664530125514446, 0.12443237856291978, 0.10966485969050699, 0.12555956736753826, -0.07346741377246158, -0.10069260680532821, -0.1621928865354203, 0.09359394104269003, 0.11083005834775861, -0.1311725171908975, 0.12013834046898442, 0.12561838441053075, 0.07483562668162591, 0.13117683252161738, -0.26076982479037625, -0.14952023221607647, 0.07340427908507663, -0.2305527574564401, -0.22171318711740895, 0.05787796575308235, -0.07682500404840198, 0.11486983960705066, 0.09649084180521061, -0.16695019813241363, 0.12655395347809806, -0.24953359489396476, 0.11406822177285703, 0.12162475848954957, -0.03889586702668202, -0.25722693447841394, -0.12044171266454218, -0.2031417731858891, -0.1318321368650911, 0.08241019193192489, -0.19807406868077287, -0.14477308975543768, 0.07131651112911097, 0.10765262723114644, 0.11196431947141

In [143]:
# Summed cosine distances collected for the poisoned clients in the experiment cell.
print(poisoned_distance)

[28.04329177855942, 22.93405699120005, 26.2990541616805, 31.190730349238464, 26.1395801053027, 26.02826795305603, 26.610442662843923, 29.948690987415024, 26.57743309560643, 20.14462525056864, 27.08486952423729, 27.39696382431955, 29.72890185947621, 23.985954493735797, 31.521471338803664, 24.037655784426835, 25.043033724471538, 27.81954652029426, 24.373356532018732, 28.484436636716737, 26.460630421096056, 25.709237106342368, 25.562342074149626, 30.342497501863292, 25.26216911372805, 24.720436185632092, 31.14751259727102, 31.626622404521026, 23.222494050134436, 28.567032546362622, 25.830699558957537, 23.498115835029374, 27.916143060260914, 27.43902306878096, 27.762799546549353, 28.18545516207597, 30.564485656712534, 27.093576917495763, 28.03259257940968, 26.103799786100293, 27.28949037104967, 25.732275693576465, 27.686556415286557, 22.558289562065987, 28.452092466335383, 25.05285428305678, 25.688484486136694, 23.4268288797131, 24.045595845195916, 23.417017701155448, 25.03938557523374, 28

In [144]:
# Second value returned by cos_similarity_values for each poisoned client.
print(poisoned_diff)

[-0.8583076200646338, -0.569406840438363, -0.6887880949712275, -0.8248883188208039, -0.7909074141930741, -0.8487678296163182, -0.8257749522394977, -0.7392819164821521, -0.6675609808983773, -0.78294910934676, -0.8685446707718462, -0.812936969756171, -0.5400314670506599, -0.8249583526042539, -0.7890450542218508, -0.7734220995465548, -0.8987628415913675, -0.7896936976664086, -0.8097004714397675, -0.8879299197159284, -0.7321554427470038, -0.8696349906937266, -0.7240002458072565, -0.7786036951565891, -0.8098786359367081, -0.7989976901228182, -0.8387975333634969, -0.7879335318114127, -0.7126637285200674, -0.7716509459939065, -0.8262713273915361, -0.8152598042895183, -0.8615041474496457, -0.8598192093446346, -0.7563358493446029, -0.7543082148027889, -0.8562372364832118, -0.6934838069804892, -0.5192433323732462, -0.6585372632992632, -0.8059805205846072, -0.8542938496379078, -0.6438503504821602, -0.6331999060895657, -0.8474947175094347, -0.9047619906487256, -0.826799505745771, -0.74679008790264

In [None]:
# Hard-coded snapshot that overwrites any `poisoned_distance` computed above.
# NOTE(review): presumably pasted from the output of an earlier run; confirm provenance.
poisoned_distance = [26.304913895554467, 28.768744721294805, 24.496953324151264, 28.929839974703384, 30.041622105106505, 28.88000855001611, 22.01042583325992, 24.93373246362499, 30.829588181152587, 26.573307706218213, 29.11585749038093, 30.518763749735392, 30.470510815776617, 24.65846549799402, 27.435934338719495, 23.53978293493098, 28.903363247072992, 25.151103611455238, 24.19006292306411, 29.854129167928154, 24.24058826902404, 29.686370685796913, 23.114331351251575, 27.16563310375412, 28.005433684227697, 28.341367657991025, 20.87087301188547, 24.944001244921502, 23.613396812544753, 25.84197931185814, 25.42082457463699, 27.064802409278073, 27.169415787641523, 26.095332051218293, 23.527566277750083, 26.925788665094586, 29.32743750213389, 28.499650131001374, 23.99977229019223, 29.652332464932528, 25.86096832558339, 30.10597346446337, 27.9014589668157, 30.272807680104407, 24.57113404545461, 24.738084647096215, 28.760206459305433, 30.677700740328405, 24.90229486896232, 33.75485653580947, 22.540291472205496, 29.77253778093523, 27.579933265488734, 25.320884654836256, 26.085117157528366, 25.850617909167696, 21.02256957099018, 26.806294758842782, 28.504632803916568, 30.833328173032047, 27.27665372717368, 27.110192746325577, 31.178069113699802, 23.708229359061615, 24.98079209861254, 34.166631664311424, 24.558344975098933, 28.96271542587465, 26.53594528203352, 22.04241262903232, 24.97811138148062, 31.191584406990856, 27.94326613376075, 27.58328636428841, 33.778879817737675, 27.37516151125356, 23.45526477047413, 27.19223959911988, 24.79286884779608, 27.142371164110298, 25.502275680758405, 24.608368838030934, 30.77356991015325, 24.139080907858062, 27.66715793410768, 31.880198185299843, 28.310216677922924, 32.52280544078636, 26.41893615572775, 22.312492240637773, 25.812189401328244, 30.20124998256374, 30.233796252016315, 24.24945537877363, 28.21021020831755, 22.957231846731986, 26.68451811504838, 28.508423343006406, 26.234117161682775, 30.15992370832739]

In [53]:
import numpy as np
# Row-wise sums (axis=1) of `poisoned_max`.
# NOTE(review): `poisoned_max` is not assigned in any cell visible in this
# notebook; this cell depends on leftover kernel state.
print(np.sum(np.array(poisoned_max), axis=1))

[4.96583783 4.61425022 5.15367397 3.70152706 2.54571858 1.75864037
 2.59281241 2.81507922 5.23826165 4.34160674 5.03724042 2.70055827
 3.57846082 3.63228413 2.45018664 3.13208237 2.24805678 2.61931077
 5.74107246 3.36515621 2.16078415 5.64317652 3.01233409 2.34380852
 4.84338039 3.41656894 2.07963981 2.45371543 2.60570177 2.29001258
 2.02774796 3.99754964 3.30733384 2.75574305 2.521787   2.22768301
 2.76707482 1.72522182 3.52160862 2.58059689 3.00327736 3.19242359
 2.1239961  2.99383676 2.33431952 4.17840191 4.64131956 2.5299296
 4.21550609 4.09566487 5.39436675 3.40356152 3.43110258 4.50041964
 7.15375384 2.98014552 6.4832904  2.08811485 3.52475133 2.97818815
 4.84765422 3.52755497 3.83821563 3.77678892 2.10025973 1.89496203
 2.10579291 1.77406834 3.8480871  5.08742284 2.0213579  2.10324355
 3.20411096 2.20759407 3.63113669 2.87623086 3.96098894 2.83212078
 3.00828264 4.38292435 2.08481165 2.0033856  3.40813707 2.45342042
 5.42144955 2.43159906 3.94711871 2.24582474 3.84685947 3.11461

In [56]:
# Row-wise sums (axis=1) of `clean_max`.
# NOTE(review): `clean_max` is not assigned in any cell visible in this
# notebook; this cell depends on leftover kernel state.
print(np.sum(np.array(clean_max), axis=1))

[4.61987687 2.87527493 4.34872613 2.3255856  4.34142862 4.62431781
 4.39405453 4.67186702 4.63992103 2.4605668  3.22147007 2.84640077
 3.34296018 3.09998645 2.61574937 2.9956768  2.06389271 0.6960543
 3.01784519 1.88149864 4.48640393 2.60496103 1.90513327 2.6050237
 3.10644562 1.37041067 4.97972843 2.08663756 2.09207863 4.99201792
 4.31917687 2.73667933 5.12519207 4.8635525  2.47924368 4.39613882
 4.13080474 4.30494006 3.53309687 3.50285118 5.64702299 2.50741676
 2.43210012 0.81117507 2.26143381 3.12050934 2.68081032 4.03967061
 6.38536298 3.81770291 4.72964121 2.31097171 2.46208599 5.05134857
 3.53949554 2.88105052 3.99258742 4.50805574 1.98121158 2.94409352
 2.06488533 2.53904899 3.83569395 1.18956072 3.97865149 2.49318799
 2.61769346 2.83892953 1.37637701 5.91223576 2.65464378 2.10875774
 3.42484292 3.4002331  5.15981937 3.98070208 2.08431699 5.07535474
 1.86347556 2.15129874 1.80239022 4.04653238 3.42562101 2.2507038
 3.55772528 1.76769577 5.42315649 3.00588771 2.69420352 3.0771183

In [42]:
# Smallest element of `clean_min` (the recorded output is itself a list, so the
# elements are presumably lists compared lexicographically; verify).
# NOTE(review): `clean_min` is not assigned in any cell visible in this notebook.
min(clean_min)

[0.001277748437136239,
 0.37562921319451104,
 0.024410925266441397,
 0.017548063231149436,
 0.0013775012023081734,
 0.0009688085656576195,
 0.0012834730540204342,
 0.0046157339802048725,
 1.102438259910724,
 0.189220115997923]

In [40]:
# Largest element of `poisoned_max` (the recorded output is itself a list, so the
# elements are presumably lists compared lexicographically; verify).
# NOTE(review): `poisoned_max` is not assigned in any cell visible in this notebook.
max(poisoned_max)

[1.5471883774845747,
 1.143223064492457,
 0.13190825255036587,
 0.05471759821022815,
 0.0831212692487796,
 0.05179054200449862,
 0.057788786175582696,
 0.7154917580150293,
 0.030907432766011156,
 1.031517143332881]

In [36]:
# NOTE(review): inspection cell left over from an earlier run; its execution
# count is lower than that of the experiment cell and its recorded output
# differs from the print of `poisoned_distance` further up.
print(poisoned_distance)


[30.755455253172272, 25.032426809456844, 30.139292061492117, 23.115211217099095, 24.155482546976998, 26.49177395325203, 24.04393426092334, 24.06141420773503, 23.42248388128856, 27.929969598658953, 25.29847049968215, 27.378213281060386, 25.650364409837888, 22.108527414659036, 24.88305289963074, 25.61555750654892, 23.514214071239884, 26.218776111140745, 18.689749127117214, 25.446594608229432, 22.934166023952542, 20.69154779561124, 21.117118541926956, 22.79689225667117, 28.25574048326921, 20.807160639199594, 24.39117699226789, 26.200097994789104, 25.98545534431143, 22.391670982395496, 24.2531107247435, 27.2166562311953, 16.188746168600527, 24.611823576150066, 24.772874681199337, 22.5107108227827, 26.758785422101038, 22.307644265229367, 20.52478105908785, 25.721204920067258, 25.174638904521274, 25.410079023881117, 28.584298505819646, 21.32697883080845, 27.85354628196658, 25.500484846525246, 24.22724135281804, 23.125231425804746, 21.79859987507048, 23.339960722091643, 25.704528498756584, 20

In [37]:
# NOTE(review): inspection cell left over from an earlier run; its execution
# count is lower than that of the experiment cell and its recorded output
# differs from the print of `clean_distance` further up.
print(clean_distance)

[22.116639855287918, 17.774297429467715, 19.35761610162297, 18.330716405062322, 20.674600970747644, 17.47788222261395, 20.625261240838718, 17.846328787417406, 15.10062720731522, 13.736695388660236, 14.38184339517803, 16.117004489371546, 18.900725716989083, 14.424026077485538, 16.37803588326923, 17.145778382069693, 16.45641617889178, 13.53036707169336, 13.462994937275468, 15.63340457004331, 16.889812464787234, 20.599922771785085, 18.342640450421786, 12.097556153110016, 14.621062456211236, 16.290222234013616, 19.842198931578075, 17.805880036438303, 17.401179721982018, 16.03788557560193, 20.366592390888293, 17.608049777609747, 14.500922777031363, 19.215913578984267, 17.788930416082216, 15.940645952173657, 19.65821206999173, 16.378655643285583, 14.547973749189502, 16.121043469077204, 18.828194573186984, 17.817576917246985, 14.197214396406235, 16.266961717424337, 17.11472634160036, 12.60610394623252, 15.716909293700272, 16.828630370785074, 22.218944435719408, 15.480600329114425, 20.98729837

In [24]:
# NOTE(review): another leftover inspection cell with an even lower execution
# count; its recorded output matches none of the other `poisoned_distance` prints.
print(poisoned_distance)

[24.051489495607896, 28.831008810765415, 25.632336981977335, 24.229990744870594, 23.133240916840386, 22.22363146163848, 28.73751030669847, 28.95507892990102, 27.745664489933553, 24.903737713953863, 25.027305738934977, 23.563431077466078, 24.868272206672213, 25.68136861587653, 21.965448228250217, 28.1829270040029, 27.765861896010946, 25.532401264237375, 22.700614732774035, 26.107819808253137, 20.366026567071454, 24.62883696510425, 29.50028982703323, 28.021064801743215, 28.999130073342315, 28.649752876317088, 27.204765189545064, 30.646384915713956, 24.55192724405848, 23.042745919181844, 22.468438231293433, 25.455418391059258, 20.873184739433604, 19.203837120913466, 25.638018330352995, 26.935628402720848, 24.18904194962456, 18.88016544836806, 23.717556935998314, 29.3407845948178, 30.154655846739143, 21.47844368557781, 27.383255139936125, 23.39665299612276, 26.703127682016635, 25.608637485195928, 23.127303984439013, 27.920089754701095, 23.133168816137143, 23.81959287251483, 23.138536061813

# Fashion MNIST

In [15]:
from federated_learning.nets import FMNISTCNN
from federated_learning.dataset import FMNISTDataset
import os
# Run configuration: Fashion-MNIST data set with the matching CNN.
config = Configuration()
config.DATASET = FMNISTDataset
config.NETWORK = FMNISTCNN
config.MODELNAME = config.FMNIST_NAME
# No poisoned clients to begin with; the experiment cell raises this later.
config.POISONED_CLIENTS = 0
config.DATA_POISONING_PERCENTAGE = 1

# Observer (logging/monitoring) configuration for this experiment.
observer_config = ObserverConfiguration()
observer_config.experiment_type = "shap_fl_poisoned"
observer_config.experiment_id = 1
observer_config.test = False
# NOTE(review): still tagged "MNIST" although this section configures
# Fashion-MNIST; confirm whether the observer expects a different value here.
observer_config.datasetObserverConfiguration = "MNIST"

neutral_label = 2

In [None]:
# Google Colab settings: working directory on Drive, data sets on the VM disk.
config.TEMP = '/content/drive/My Drive/Colab Notebooks/temp'
config.FMNIST_DATASET_PATH = '/content/data/fmnist'
config.MNIST_DATASET_PATH = '/content/data/mnist'
config.CIFAR10_DATASET_PATH = '/content/data/cifar10'
# The closing quote was missing here, which made the whole cell a SyntaxError.
config.VM_URL = "none"

In [14]:
# Load the configured data set and build the objects shared by the cells below.
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
server = Server(
    config,
    observer_config,
    data.train_dataloader,
    data.test_dataloader,
    shap_util,
)
client_plane = ClientPlane(config, observer_config, data, shap_util)
visualizer = Visualizer(shap_util)

MNIST training data loaded.
MNIST test data loaded.
Create 200 clients with dataset of size 300


In [None]:
import numpy as np
import copy
import torch
import os
# Rounds (1-based) at which the server model is written to disk.
CHECKPOINT_ROUNDS = (2, 5, 10, 75, 100, 200)

# Run 200 federated rounds, checkpointing the server model along the way.
for i in range(200):
    round_number = i + 1
    if round_number in CHECKPOINT_ROUNDS:
        file = "/content/drive/My Drive/Colab Notebooks/temp/models/ex6/FMNIST_round_{}.model".format(round_number)
        # exist_ok avoids the separate exists() check and its race window.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        # NOTE(review): the checkpoint is written before this round runs, so
        # the file named round_N holds the weights produced by rounds 1..N-1.
        # Confirm this is the intended meaning of the file name.
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, round_number)
    experiment_util.run_round(client_plane, server, round_number)

In [1]:
import torch
# Label-flipping attack under study: samples of FROM_LABEL are relabelled TO_LABEL.
config.FROM_LABEL = 5
config.TO_LABEL = 4
shap_images = [config.FROM_LABEL, config.TO_LABEL]

for j in [100]:
    # Fresh data and clients, server restored from the round-j checkpoint.
    data = config.DATASET(config)
    client_plane = ClientPlane(config, observer_config, data, shap_util)
    model_file = file = "./temp/models/ex5/FMNIST_round_{}.model".format(j)
    server.net = FMNISTCNN()
    server.net.load_state_dict(torch.load(model_file))

    # Baseline metrics and SHAP values of the unpoisoned server model.
    server.test()
    recall, precision, accuracy = server.analize_test()
    print("Original", recall, precision, accuracy)
    server_shap = server.get_shap_values()
    # `poisoned_clients` is only selected further down, so report the
    # FROM_LABEL count via client 0 here (as the MNIST experiment cell does).
    targets = client_plane.clients[0].train_dataloader.dataset.dataset.targets
    print(len(targets[targets == config.FROM_LABEL]))

    config.POISONED_CLIENTS = 100
    experiment_util.update_configs(client_plane, server, config, observer_config)
    client_plane.poison_clients()
    clean_clients = experiment_util.select_random_clean(client_plane, config, 100)
    poisoned_clients = experiment_util.select_poisoned(client_plane, 100)
    clean_distance = []
    poisoned_distance = []
    clean_diff = []
    poisoned_diff = []
    print("Clean")
    # Only the first two clients of each group are evaluated in this cell.
    for i in clean_clients[:2]:
        # Every client starts from the server weights before its local training.
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        clean_client_shap = client_plane.clients[i].get_shap_values()
        # The last cos_similarity_values definitions in this notebook return a
        # pair; unpack it the same way the MNIST experiment cell does.
        distance, diag_diff = cos_similarity_values(clean_client_shap, server_shap)
        clean_distance.append(distance)
        clean_diff.append(diag_diff)
    poisoned_targets = client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets
    print(len(poisoned_targets[poisoned_targets == config.FROM_LABEL]))

    print("Poisoned")
    # Restore the checkpoint again before evaluating the poisoned clients.
    server.net = FMNISTCNN()
    server.net.load_state_dict(torch.load(model_file))
    for i in poisoned_clients[:2]:
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        poisoned_client_shap = client_plane.clients[i].get_shap_values()
        # Was `poisoned_distance.distance = ...`, which fails because a list
        # accepts no new attributes; collect the results instead.
        distance, diag_diff = cos_similarity_values(poisoned_client_shap, server_shap)
        poisoned_distance.append(distance)
        poisoned_diff.append(diag_diff)
    client_plane.reset_default_client_nets()
    client_plane.reset_poisoning_attack()
    # Inspect the first poisoned client after the poisoning attack was reset.
    first_poisoned = client_plane.clients[poisoned_clients[0]]
    poisoned_targets = first_poisoned.train_dataloader.dataset.dataset.targets
    print(poisoned_targets[first_poisoned.poisoning_indices])
    print(len(poisoned_targets[poisoned_targets == config.FROM_LABEL]))

NameError: name 'config' is not defined