# Feature Engineering with SHAP values Experiment 2

## Google Colab

In [None]:
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)

import sys
sys.path.append('/content/drive/My Drive/Colab Notebooks')
sys.path.append('/content/drive/My Drive/Colab Notebooks/federated_learning')

!pip install shap==0.40.0

In [None]:
import sklearn

## Experimental Setup

In [1]:
from federated_learning.utils import SHAPUtil, experiment_util, Visualizer
from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from datetime import datetime

In [None]:
def cos_similarity_values(s_client, s_server):
    import numpy as np
    cos_similarity = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(shap_subtract):
        for img_idx, image in enumerate(row):
                cos_similarity[row_idx].append(round(np.sum(image.flatten()), 3))

    print(np.matrix(cos_similarity))

In [None]:
from scipy import spatial
import numpy

In [None]:
scipy.__version__


In [None]:
numpy.__version__

In [None]:
def cos_similarity_values(s_client, s_server):
    from scipy import spatial
    import numpy as np
    cos_similarity_server = [[] for i in range(10)]
    cos_similarity_client = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(s_server):
        for img_idx, image in enumerate(row):
                cos_similarity_server[row_idx].append(np.sum(image.flatten()))
    for row_idx, row in enumerate(s_client):
        for img_idx, image in enumerate(row):
                cos_similarity_client[row_idx].append(np.sum(image.flatten()))
    spatial.distance.cosine(np.array(cos_similarity_server).flatten(), np.array(cos_similarity_client).flatten())
    return spatial.distance.cosine(np.array(cos_similarity_server).flatten(), np.array(cos_similarity_client).flatten())

In [None]:
def cos_similarity_values(s_client, s_server):
    from scipy import spatial
    import numpy as np
    cos_similarity = [[] for i in range(10)]
    similarity_sum = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(s_client):
        for img_idx, image in enumerate(row):
                cos_similarity[row_idx].append(spatial.distance.cosine(image.flatten(),s_server[row_idx][img_idx].flatten()))
                
    return np.sum(cos_similarity)
    


In [None]:
# Works for MNIST
def cos_similarity_values(s_client, s_server):
    from scipy import spatial
    import numpy as np
    cos_similarity = [[] for i in range(10)]
    differences_sum = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(s_client):
        for img_idx, image in enumerate(row):
                cos_similarity[row_idx].append(spatial.distance.cosine(image.flatten(),s_server[row_idx][img_idx].flatten()))
                differences_sum[row_idx].append(np.sum(shap_subtract[row_idx][img_idx].flatten()))
    
    return np.sum(cos_similarity), np.array(differences_sum).diagonal()[np.argmax(np.abs(np.array(differences_sum).diagonal()))]
    


In [None]:
def cos_similarity_values(s_client, s_server):
    from scipy import spatial
    import numpy as np
    cos_similarity = [[] for i in range(10)]
    similarity_sum = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(s_client):
        for img_idx, image in enumerate(row):
                cos_similarity[row_idx].append(spatial.distance.cosine(image.flatten(),s_server[row_idx][img_idx].flatten()))
                similarity_sum[row_idx].append(np.sum(shap_subtract[row_idx][img_idx].flatten()))
    argmax = np.argmax(np.array(cos_similarity).diagonal())
    print(np.array(cos_similarity).diagonal()[argmax] * np.array(similarity_sum).diagonal()[argmax])
    return np.sum(cos_similarity), np.array(cos_similarity).diagonal().dot(np.array(similarity_sum).diagonal())
    


In [13]:
def cos_similarity_values(s_client, s_server):
    from scipy import spatial
    import numpy as np
    cos_similarity = [[] for i in range(10)]
    differences_sum = [[] for i in range(10)]
    shap_subtract = np.subtract(s_client, s_server)
    for row_idx, row in enumerate(s_client):
        for img_idx, image in enumerate(row):
                cos_similarity[row_idx].append(spatial.distance.cosine(image.flatten(),s_server[row_idx][img_idx].flatten()))
                differences_sum[row_idx].append(np.sum(shap_subtract[row_idx][img_idx].flatten()))
    #print(cos_similarity[5][5], cos_similarity[5][4],cos_similarity[4][5])
    return np.sum(cos_similarity), np.max(cos_similarity), np.array(differences_sum).flatten()[np.argmax(np.abs(np.array(differences_sum).flatten()))]
    


# MNIST

In [3]:
from federated_learning.nets import MNISTCNN
from federated_learning.dataset import MNISTDataset
import os
config = Configuration()
config.POISONED_CLIENTS = 0
config.DATA_POISONING_PERCENTAGE = 1
config.DATASET = MNISTDataset
config.MODELNAME = config.MNIST_NAME
config.NETWORK = MNISTCNN
observer_config = ObserverConfiguration()
observer_config.experiment_type = "shap_fl_poisoned"
observer_config.experiment_id = 1
observer_config.test = False
observer_config.datasetObserverConfiguration = "MNIST"
neutral_label = 2

In [None]:
# Google Colab Settigns
config.TEMP = os.path.join('/content/drive/My Drive/Colab Notebooks/temp')
config.FMNIST_DATASET_PATH = os.path.join('/content/data/fmnist')
config.MNIST_DATASET_PATH = os.path.join('/content/data/mnist')
config.CIFAR10_DATASET_PATH = os.path.join('/content/data/cifar10')
config.VM_URL = "none"

In [4]:
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader) 
server = Server(config, observer_config,data.train_dataloader, data.test_dataloader, shap_util)
visualizer = Visualizer(shap_util)

MNIST training data loaded.
MNIST test data loaded.


## Experiment Setup 

In [None]:
import numpy as np
import copy
import torch
import os
for i in range(200):
    if (i+1) in [2, 5,10,75,100,200]:
        file = "./temp/models/ex6/MNIST_round_{}.model".format(i+1)
        if not os.path.exists(os.path.dirname(file)):
                os.makedirs(os.path.dirname(file))
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, i+1)
    experiment_util.run_round(client_plane, server, i+1)

## Experiment

In [14]:
import torch
config.FROM_LABEL = 3
config.TO_LABEL = 8
shap_images = [config.FROM_LABEL ,config.TO_LABEL]
for j in [100]:
    data = config.DATASET(config)
    client_plane = ClientPlane(config, observer_config, data, shap_util)
    model_file = file = "./temp/models/ex6/MNIST_round_{}.model".format(j)
    server.net =  MNISTCNN()
    server.net.load_state_dict(torch.load(model_file))

    server.test()
    recall, precision, accuracy = server.analize_test()
    print("Original", recall, precision, accuracy)
    server_shap = server.get_shap_values()

    config.POISONED_CLIENTS = 100
    experiment_util.update_configs(client_plane, server, config, observer_config)
    print(len(client_plane.clients[0].train_dataloader.dataset.dataset.targets[client_plane.clients[0].train_dataloader.dataset.dataset.targets == 5]))

    client_plane.poison_clients()
    clean_clients = experiment_util.select_random_clean(client_plane, config, 100)
    poisoned_clients = experiment_util.select_poisoned(client_plane, 100)
    clean_distance = []
    poisoned_distance = []
    clean_diff = []
    poisoned_diff = []
    clean_max = []
    poisoned_max = []
    print("Clean")
    print(len(client_plane.clients[0].train_dataloader.dataset.dataset.targets[client_plane.clients[0].train_dataloader.dataset.dataset.targets == 5]))
    for idx, i in enumerate(clean_clients[:100]):
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        clean_client_shap = client_plane.clients[i].get_shap_values()
        distance, distance_max, diag_diff = cos_similarity_values(clean_client_shap, server_shap)
        clean_distance.append(distance)
        clean_max.append(distance_max)
        clean_diff.append(diag_diff)
        if (idx+1)%25 == 0:
            print(clean_distance[idx-25:idx])

    print("Poisoned")
    server.net =  MNISTCNN()
    server.net.load_state_dict(torch.load(model_file))
    for idx, i in enumerate(poisoned_clients[:100]):
        client_plane.update_clients(server.get_nn_parameters())    
        client_plane.clients[i].train(j+1)
        poisoned_client_shap = client_plane.clients[i].get_shap_values()
        distance, distance_max, diag_diff  = cos_similarity_values(poisoned_client_shap, server_shap)
        poisoned_distance.append(distance)
        poisoned_max.append(distance_max)
        poisoned_diff.append(diag_diff)
        if (idx+1)%25 == 0:
            print(poisoned_distance[idx-25:idx])
    print(len(client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets[client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets == 5]))
    client_plane.reset_default_client_nets()
    client_plane.reset_poisoning_attack()

MNIST training data loaded.
MNIST test data loaded.
Create 200 clients with dataset of size 300

Test set: Average loss: 0.0002, Accuracy: 9625/10000 (96%)

Original tensor([0.9929, 0.9833, 0.9612, 0.9525, 0.9562, 0.9720, 0.9729, 0.9523, 0.9446,
        0.9366]) tensor([0.9615, 0.9867, 0.9350, 0.9649, 0.9812, 0.9527, 0.9739, 0.9459, 0.9664,
        0.9565]) 0.9625
5421
Poison 100/200 clients
Flip 100.0% of the 3 labels to 8
[112  85  73 138 118 134 182 123 162   7  31  57 178  51 121 192 193  56
  49  59   6 141 154  35 107 190  10  41  83 180  50 110  40 194 173 113
 140  17  27  52 163 177  70  24  92 165  62  36 169  84 195  30 120  47
  67 171 157  82 139  63  11  81  58  69  99  77 179   9 147 151 124 127
  76 117  12   3 145  21  53   5  13 148  78 115 129   2   4 161 143 181
  68  55  97  90 186  29 184  66  79  74]
20/100 clients poisoned
40/100 clients poisoned
60/100 clients poisoned
80/100 clients poisoned
100/100 clients poisoned
Clean
5421
[]
[21.53802743052501, 14.7176652

In [21]:
print(clean_distance)

[22.489031099374284, 25.389126715709708, 21.156146163699535, 17.2354870405107, 20.249273288241845, 19.02144938901933, 21.16966604980584, 20.18016334036071, 17.905544322780884, 22.91513698386014, 21.67940917028987, 25.955259338133487, 14.909112519451298, 20.62448134173489, 17.416342316982856, 20.132771917260015, 15.323234149343099, 22.01019951838938, 17.93250889301105, 16.76015926625974, 16.85429805129324, 17.83423937392598, 20.266815299072682, 17.331412308294105, 21.53802743052501, 14.717665282654568, 24.849916971755047, 21.296506387231116, 20.390169737836327, 18.948457979007664, 18.543202275854014, 23.747372249938746, 18.56792489073116, 21.872255216746776, 19.379412021748816, 19.985785039684977, 23.697747653282143, 17.382529968830077, 17.805282352518063, 19.611251980369, 21.592436629411335, 17.81769483692241, 18.28023643400877, 23.07714262299427, 24.304775348021845, 16.933359275361887, 22.139604338804915, 24.530775372174283, 24.12473298919851, 16.774431491986785, 18.95285159027638, 21

In [22]:
print(clean_max)

[1.7037140444704053, 1.8212052853028249, 1.722543421136047, 1.6703132708239545, 1.8244046769882987, 1.7772961270929855, 1.8215329789594272, 1.846137706455229, 1.8436783203747433, 1.8741908383106822, 1.7702468131343236, 1.7329100188855133, 1.6294371566825392, 1.6940465620002005, 1.6718372614872274, 1.7290118717033005, 1.5604934214605677, 1.849269747081061, 1.831255268584905, 1.8335170850104578, 1.6075701313111914, 1.7906538696320071, 1.7860509254431922, 1.7786319192156865, 1.864957162119978, 1.6585525822423088, 1.8632254811000986, 1.7750302433716958, 1.8145748541595104, 1.7660998502446736, 1.8223475661823976, 1.808749286290547, 1.6254270498661536, 1.7455196745372614, 1.5931944589987075, 1.8605244289883371, 1.7642219008659568, 1.7271607008220977, 1.722034170168249, 1.8515276547630943, 1.8238100060097096, 1.7302536349943045, 1.7542582973481458, 1.6833894259533402, 1.773211775555943, 1.6917004129828461, 1.8951202924952208, 1.8132232605546483, 1.6681464259635992, 1.4533184692407943, 1.82505

In [23]:
print(clean_diff)

[-0.2174609984935092, -0.21996760097779688, 0.2126738049142869, 0.10397319479350609, 0.12412082690897108, 0.08246350674035607, 0.1640793020690925, -0.16341254063287636, -0.04978197288676256, -0.20869143725699146, 0.1186135626745024, -0.2778386669482451, 0.10795437987851919, -0.1258420522469894, 0.10176003279366563, -0.10518846367407875, 0.06998720077426857, -0.1467135997068958, -0.13344819569012145, 0.15148675179977822, -0.1645790237838618, -0.22670852441157407, 0.11586790221347615, 0.08875842090764507, 0.1191578836057989, 0.11593880948385038, -0.13372938839793136, 0.06563855657488116, 0.09536874405256701, -0.11640088289032136, -0.037954999759055585, -0.16200940048580215, -0.17584638632495242, -0.14116492699268335, 0.04268532073999154, 0.17506026557051335, -0.0970046814547807, -0.0720857823529415, -0.2615241406742892, 0.12823684750005704, -0.14701387811938682, 0.18805509682969124, 0.1196385418681789, -0.22925938077871777, 0.23490235430428008, 0.12429617251466674, -0.29798060080895894, 

In [24]:
print(poisoned_distance)

[26.602832886891836, 25.064650515857466, 26.568569636145345, 25.679844187111136, 22.91643237067186, 22.468062881321504, 25.642062536804406, 21.248108609973116, 20.18608766111264, 22.759775013860832, 25.731566481932187, 23.521748472987262, 23.521893634215523, 24.586051265852817, 22.687439167154196, 24.560314572571908, 20.287332213278845, 29.442123127300615, 18.67328303476978, 20.12849561176915, 24.448497107748548, 21.144182006146607, 19.225379047406477, 25.976056458911597, 23.083101693307505, 21.82774555101824, 27.330097156354626, 25.87259208482654, 25.139267029663976, 26.4014933203969, 23.43837911048436, 25.70393730388087, 22.6358558860246, 22.605726061445804, 24.803048381944983, 23.951535675679335, 25.344558231599372, 25.029682720081663, 27.473386648511624, 24.323786519357952, 22.10297175893764, 22.22706703815827, 24.264385454799317, 23.482146231217207, 27.781233964347706, 27.12747596633376, 21.30138265076058, 28.073422887454306, 18.760802528510684, 21.70254137175454, 18.1732186192342

In [25]:
print(poisoned_max)

[1.5790410054289084, 1.6909148075291718, 1.801635350167539, 1.726261664632104, 1.7473397351512878, 1.7072757807141756, 1.7582304898752135, 1.8393777045804631, 1.6621946972122401, 1.5145610176503517, 1.671755396639302, 1.6627959035875475, 1.7224261125719291, 1.799177824707015, 1.695652173359938, 1.7028394154982398, 1.5966445180411268, 1.8002510544179464, 1.766099610296949, 1.7880002890541489, 1.8771853836004362, 1.8087211197147686, 1.7752918242327387, 1.6834448700850195, 1.6194553510309277, 1.6408066815613458, 1.8371179601373264, 1.6692071388167324, 1.6372854847910043, 1.8402449726947903, 1.8337995265416742, 1.8139074542474678, 1.7621321660262113, 1.9136778126457537, 1.7074115690057123, 1.9451462955457992, 1.6709482000294682, 1.7729744988605258, 1.6904985109338386, 1.8169866389098324, 1.823519777803484, 1.6899448495077112, 1.670847457434856, 1.7170162011598753, 1.6302284584726545, 1.9770377205777208, 1.7588119688991823, 1.8782836093039748, 1.8438394705333763, 1.7655442091302347, 1.80064

In [26]:
print(poisoned_diff)

[-0.7563950436988235, -0.8048668864204741, -0.8077063644849067, -0.7800255312476436, -0.8035965592474724, -0.8075231086937655, -0.7861927711763764, 0.8130452063735374, -0.7894938928461284, -0.8043363051973857, -0.8266077902385406, -0.8039124141336071, -0.8007651375288453, -0.8172352505984338, -0.7993900370719538, -0.7848594483811269, -0.8243533155194019, -0.7602903487291468, -0.8065425397725959, -0.7730344354428611, -0.8258196306736092, -0.8032267966409186, -0.7925244275507001, 0.8283074971884405, -0.800611811034365, -0.8076797592152283, -0.800073904514862, -0.8045526169528417, -0.8218668827811351, -0.7877063743339053, -0.8028870309604104, -0.7948654926099865, 0.8289515245770251, -0.8109219561049412, -0.7995666184683567, -0.8085030180951216, -0.8316123851477427, -0.8160876977650187, -0.8244079708158717, -0.8057572653045603, -0.8186728669899999, 0.8132849125896204, -0.8200696298043069, -0.7558944658687872, -0.8013238231449524, -0.8241910213183122, -0.803166509170354, -0.787161513181263,

# Fashion MNIST

In [None]:
from federated_learning.nets import FMNISTCNN
from federated_learning.dataset import FMNISTDataset
import os
config = Configuration()
config.POISONED_CLIENTS = 0
config.DATA_POISONING_PERCENTAGE = 1
config.DATASET = FMNISTDataset
config.MODELNAME = config.FMNIST_NAME
config.NETWORK = FMNISTCNN
observer_config = ObserverConfiguration()
observer_config.experiment_type = "shap_fl_poisoned"
observer_config.experiment_id = 1
observer_config.test = False
observer_config.datasetObserverConfiguration = "MNIST"
neutral_label = 2

In [None]:
# Google Colab Settigns
config.TEMP = os.path.join('/content/drive/My Drive/Colab Notebooks/temp')
config.FMNIST_DATASET_PATH = os.path.join('/content/data/fmnist')
config.MNIST_DATASET_PATH = os.path.join('/content/data/mnist')
config.CIFAR10_DATASET_PATH = os.path.join('/content/data/cifar10')
config.VM_URL = "none"

In [None]:
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader) 
server = Server(config, observer_config,data.train_dataloader, data.test_dataloader, shap_util)
client_plane = ClientPlane(config, observer_config, data, shap_util)
visualizer = Visualizer(shap_util)

In [None]:
import numpy as np
import copy
import torch
import os
for i in range(200):
    if (i+1) in [2, 5,10,75,100,200]:
        file = "/content/drive/My Drive/Colab Notebooks/temp/models/ex6/FMNIST_round_{}.model".format(i+1)
        if not os.path.exists(os.path.dirname(file)):
                os.makedirs(os.path.dirname(file))
        torch.save(server.net.state_dict(), file)
    experiment_util.set_rounds(client_plane, server, i+1)
    experiment_util.run_round(client_plane, server, i+1)

In [None]:
import torch
config.FROM_LABEL = 5
config.TO_LABEL = 4
shap_images = [config.FROM_LABEL ,config.TO_LABEL]
for j in [100]:
    data = config.DATASET(config)
    client_plane = ClientPlane(config, observer_config, data, shap_util)
    model_file = file = "/content/drive/My Drive/Colab Notebooks/temp/models/ex6/FMNIST_round_{}.model".format(j)
    server.net =  FMNISTCNN()
    server.net.load_state_dict(torch.load(model_file))

    server.test()
    recall, precision, accuracy = server.analize_test()
    print("Original", recall, precision, accuracy)
    server_shap = server.get_shap_values()

    config.POISONED_CLIENTS = 100
    experiment_util.update_configs(client_plane, server, config, observer_config)
    print(len(client_plane.clients[0].train_dataloader.dataset.dataset.targets[client_plane.clients[0].train_dataloader.dataset.dataset.targets == 5]))

    client_plane.poison_clients()
    clean_clients = experiment_util.select_random_clean(client_plane, config, 100)
    poisoned_clients = experiment_util.select_poisoned(client_plane, 100)
    clean_distance = []
    poisoned_distance = []
    clean_diff = []
    poisoned_diff = []
    print("Clean")
    print(len(client_plane.clients[0].train_dataloader.dataset.dataset.targets[client_plane.clients[0].train_dataloader.dataset.dataset.targets == 5]))
    for idx, i in enumerate(clean_clients[:100]):
        client_plane.update_clients(server.get_nn_parameters())
        client_plane.clients[i].train(j+1)
        clean_client_shap = client_plane.clients[i].get_shap_values()
        distance, diag = cos_similarity_values(clean_client_shap, server_shap)
        clean_distance.append(distance)
        clean_diff.append(diag)
        if (idx+1)%25 == 0:
            print(clean_distance[idx-25:idx])

    print("Poisoned")
    server.net =  FMNISTCNN()
    server.net.load_state_dict(torch.load(model_file))
    for idx, i in enumerate(poisoned_clients[:100]):
        client_plane.update_clients(server.get_nn_parameters())    
        client_plane.clients[i].train(j+1)
        poisoned_client_shap = client_plane.clients[i].get_shap_values()
        distance, diag = cos_similarity_values(poisoned_client_shap, server_shap)
        poisoned_distance.append(distance)
        poisoned_diff.append(diag)
        if (idx+1)%25 == 0:
            print(poisoned_distance[idx-25:idx])
    print(len(client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets[client_plane.clients[poisoned_clients[0]].train_dataloader.dataset.dataset.targets == 5]))
    client_plane.reset_default_client_nets()
    client_plane.reset_poisoning_attack()

In [None]:
print(clean_distance)

In [None]:
print(poisoned_distance)

In [None]:
print(clean_diff)

In [None]:
print(poisoned_diff)