In [67]:
import torch.optim as optim
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import prettytable as pt
%run helpers_preproc.ipynb
%matplotlib widget
%run getExamples.ipynb

# Hyper Parameters

In [68]:
origDataDir = 'SHREC11/'
newDataDir = 'SHREC11_demo/'
K = 5
# radiuss = [0.0002,0.002,0.01]
train_size = 400 # * len(radiuss)
test_size = 200 # * len(radiuss)

# obj1 = 15
# obj2 = 17

# 1. expand dataset and scramble

In [69]:
# expand(origDataDir, newDataDir, radiuss)
# pick_copy(origDataDir, newDataDir, obj1, obj2)
scramble(newDataDir, 400, 600)

# 2. read in the dataset

## 2.1 read in v (252 x 3),f (...x...) and cluster

In [70]:
ver_list = []
adj_list = []
Nadj_list = []
gMat_list = []
gLbl_list = []

for i in range(600):
    v, f = get_nodes(newDataDir + 'T' + str(i) + '.obj')
    group_mat, group_labels = K_mean_cluster(v, K)
    
    adj = get_adj_from_f(f)            
    normed_adj = adj / np.reshape(np.sum(adj,axis = 0),[252,1])
    ver_list.append(v)
    adj_list.append(adj)
    Nadj_list.append(normed_adj)
    gMat_list.append(group_mat)
    gLbl_list.append(group_labels)
    
    
vers = np.array(ver_list)
adjs = np.array(adj_list)
Nadjs = np.array(Nadj_list)
gMats = np.array(gMat_list)
gLbls = np.array(gLbl_list)

  normed_adj = adj / np.reshape(np.sum(adj,axis = 0),[252,1])


## 2.2 Read in the labels

In [71]:
label_np = np.array(readLbl(600 ,newDataDir+'labels.txt'))
label_np_train = label_np[:train_size]
label_np_test = label_np[train_size:]

## 2.3 Split into training set and testing set

In [72]:
vers_train = torch.tensor(vers[:train_size]).float()
adjs_train = torch.tensor(adjs[:train_size]).float()
nadjs_train = torch.nan_to_num(torch.tensor(Nadjs[:train_size]).float(),0,0,0)
gMats_train = torch.tensor(gMats[:train_size]).float()
label_train = torch.tensor(label_np_train).float()

vers_test = torch.tensor(vers[train_size:]).float()
adjs_test = torch.tensor(adjs[train_size:]).float()
nadjs_test = torch.nan_to_num(torch.tensor(Nadjs[train_size:]).float(),0,0,0)
gMats_test = torch.tensor(gMats[train_size:]).float()
label_test = torch.tensor(label_np_test).float()

label_mat_train = torch.tensor(np.where(igl.all_pairs_distances(label_np_train,label_np_train,False) > 0.5,0,1)).float()
label_mat_test = torch.tensor(np.where(igl.all_pairs_distances(label_np_test,label_np_test,False) > 0.5,0,1)).float()

# 3. Training

## 3.1 Hyperparameters for training

In [73]:
contraGWs = [10,10,10,10]
contraMWs = [10,10,10,10]
atkGWs = [10,10,10,10]
atkMWs = [10,10,10,K * 3 + 1]

## 3.2 Initialize neural networks

In [74]:
%run NNs.ipynb
%run helpers_preproc.ipynb
loadWeightDir = 'weights/2-21-1/'

In [75]:
load_model = True

In [76]:
contraG = GCN(3,contraGWs)
contraM = MLP(252* contraGWs[len(contraGWs)-1],contraMWs)
atkG = GCN(3,atkGWs)
atkM = MLP_atk(252* atkGWs[len(contraGWs)-1], 0.2, K, atkMWs)

if load_model:
    contraG.weights = torch.load(loadWeightDir + 'contraGw.pt')
    contraM.weights = torch.load(loadWeightDir + 'contraMw.pt')
    atkG.weights = torch.load(loadWeightDir + 'atkGw.pt')
    atkM.weights = torch.load(loadWeightDir + 'atkMw.pt')

## 3.4 Extract node-level features from clean data

In [77]:
feas_clean = extract_node_feature(vers_train,adjs_train)

In [78]:
%run L.ipynb

## 3.5 Train atkNN😈 with contraNN🤠

In [79]:
lr = 0.0001
torch.set_printoptions(precision=10)

Loss Defined to be: 

$\text{-diffTypeMean + sameTypeMean} + 0.02 \times \sqrt{\text{sameTypeStd}}$

In [80]:
#get clean features
feas_clean = extract_node_feature(vers_train,adjs_train) #grad on nothing
opt = optim.Adam(atkG.weights + atkM.weights + contraG.weights + contraM.weights, lr = lr)

for contraI in range(1):
    trainLog = pt.PrettyTable()
    trainLog.field_names = [" ","Loss", "Same Mean", "Diff Mean", "Same STD"]
    
    
    opt.zero_grad()

    #get poisoned vertex locations
    poisonsByGroups = atkM.forward(atkG.forward(nadjs_train,feas_clean))
    transformation = translate_by_group(gMats_train,
                                        poisonsByGroups)
    poisonedVers = vers_train + transformation
        
        
    #extract poisoned features
    feas_poisoned = extract_node_feature(poisonedVers,adjs_train) #grad on 😈

    
    #get contraNN's performance on poisoned feature with grads on 😈    
    contraM.eval()
    contraG.eval()
    aPerfPoisoned, APP = loss(contraM.forward(contraG.forward(nadjs_train,feas_poisoned)),
                         label_mat_train) #grad on 😈
    
    #get contraNN's performance on clean feature and poisoned feature with grads on 🤠
    contraM.train()
    contraG.train()
    feas_poisoned_nograd = feas_poisoned.detach() #grad on nothing
    
    objFeasFromClean = contraM.forward(contraG.forward(nadjs_train,feas_clean))
    objFeasFromPoisoned = contraM.forward(contraG.forward(nadjs_train,feas_poisoned_nograd))
    
    diffInFeas = torch.sum((objFeasFromClean - objFeasFromPoisoned)**2)
    
    cPerfClean, CPC = loss(objFeasFromClean,
                      label_mat_train) #grad on 🤠
    
    cPerfPoisoned, CPP = loss(objFeasFromPoisoned,
                         label_mat_train) #grad on 🤠
    
    
    trainLog.add_row(['On Clean Data']+CPC)
    trainLog.add_row(['On Poisoned Data']+CPP)

    
    overallLoss = -aPerfPoisoned + cPerfClean + 1 * cPerfPoisoned#  + diffInFeas
    #overallLoss = -aPerfPoisoned
    overallLoss.backward()
    opt.step()
    
    print(trainLog)

+------------------+----------+-----------+-----------+----------+
|                  |   Loss   | Same Mean | Diff Mean | Same STD |
+------------------+----------+-----------+-----------+----------+
|  On Clean Data   | -0.2624  |  0.03267  |  0.85403  | 0.22835  |
| On Poisoned Data | -0.20927 |  0.08198  |  0.84703  | 0.47377  |
+------------------+----------+-----------+-----------+----------+


## 3.7 Seeing how it does on Test Data

In [81]:
feas_clean_test = extract_node_feature(vers_test,adjs_test)

atkTest = atkM.forward(atkG.forward(nadjs_test,feas_clean_test))

poisonedVersTest = vers_test + translate_by_group(gMats_test,atkTest)
        
feas_poisoned_test = extract_node_feature(poisonedVersTest,adjs_test)
        
contraOutPoisonedTest = contraM.forward(contraG.forward(nadjs_test,feas_poisoned_test))
contraLossPoisonedTest,_ = loss(contraOutPoisonedTest,label_mat_test)

contraOutCleanTest = contraM.forward(contraG.forward(nadjs_test,feas_clean_test))
contraLossCleanTest,_ = loss(contraOutCleanTest,label_mat_test)

print(contraLossPoisonedTest)
print(contraOutCleanTest.shape)

tensor(-0.1831341088, grad_fn=<AddBackward0>)
torch.Size([200, 10])


In [82]:
feas_clean_test = extract_node_feature(vers_test,adjs_test)

atkTest = atkM.forward(atkG.forward(nadjs_test,feas_clean_test))

poisonedVersTest = vers_test + translate_by_group(gMats_test,atkTest)

feas_poisoned_test = extract_node_feature(poisonedVersTest,adjs_test)

contraOutPoisonedTest = contraM.forward(contraG.forward(nadjs_test,feas_poisoned_test))

contraOutCleanTest = contraM.forward(contraG.forward(nadjs_test,feas_clean_test))

# #display histograms
# display_hists(contraOutPoisonedTest, label_mat_test, test_size)
# display_hists(contraOutCleanTest, label_mat_test, test_size)

In [88]:
# obj1 = torch.where(label_test == obj1)
# obj2 = torch.where(label_test == obj2)

# obj = torch.cat((obj1[0], obj2[0]), dim = 0)

ld = {}

data = torch.empty((len(label_test),10))

int(label_test[0].item())

11

In [103]:
for i in range(len(label_test)):
    for j in range(10):
        data[i][j] = contraOutCleanTest[int(label_test[i].item())][j].item()
    
    if ld[int(label_test[i].item())] is not None:
        temp = []
        temp.append(ld[int(label_test[i])])
        temp.append(ld[])
        
    else:    
        ld[int(label_test[i].item())] = data[i]
    
len(ld)

2030

In [85]:

centroid, label = scipy.cluster.vq.kmeans2(data, 30, minit='points')

for i in range(30):
    print("group num:" + str(i))
    print(data[label == i])
    for stuff in data[label == i]:
        
        
        
label

group num:0
tensor([[-0.0098834392, -0.8993330002, -0.9964920878, -0.6840175390,
         -0.8713663816,  0.4307914078,  0.6872954965, -0.9013398886,
         -0.9969114661, -0.4576650858],
        [-0.0385961644, -0.9002763033, -0.9966883659, -0.6810397506,
         -0.8726342320,  0.4255200028,  0.6840685010, -0.9020990133,
         -0.9970453382, -0.4533342421],
        [-0.0098834392, -0.8993330002, -0.9964920878, -0.6840175390,
         -0.8713663816,  0.4307914078,  0.6872954965, -0.9013398886,
         -0.9969114661, -0.4576650858],
        [-0.0098834392, -0.8993330002, -0.9964920878, -0.6840175390,
         -0.8713663816,  0.4307914078,  0.6872954965, -0.9013398886,
         -0.9969114661, -0.4576650858],
        [-0.0098834392, -0.8993330002, -0.9964920878, -0.6840175390,
         -0.8713663816,  0.4307914078,  0.6872954965, -0.9013398886,
         -0.9969114661, -0.4576650858],
        [-0.0385961644, -0.9002763033, -0.9966883659, -0.6810397506,
         -0.8726342320,  0.42

array([ 3, 24, 16, 12, 17, 18,  8,  6, 10, 16, 12,  6,  3, 10, 14,  8,  2,
       12,  7,  2, 14, 15, 28,  2,  2,  2,  8, 12, 11,  6,  3, 10,  4,  8,
       10,  3, 10,  5, 16,  3,  6, 17, 10,  5,  1,  8,  7,  3, 16,  2,  0,
        2, 17, 13,  2, 16,  0, 10,  0, 10, 10, 17,  6,  8, 12, 17,  1, 10,
       10, 19, 19,  4, 13,  0, 12, 10, 10, 22, 10, 17, 12, 14,  6,  8,  0,
       10, 22, 15,  4, 12, 10,  2,  4, 24, 11, 14,  4, 22,  6,  0, 24, 13,
       19, 10,  1,  0,  6,  0, 10, 17, 16,  2,  7, 19, 22, 10, 10,  7,  6,
       12, 17,  8, 19, 12, 19, 16, 11,  4, 18,  1, 19,  0, 19, 14, 28,  2,
        4,  5,  8, 22, 28,  8,  6,  8,  4, 14,  4, 28, 22, 15, 18, 24,  4,
       19,  1, 14,  0, 22, 15,  5, 16,  0,  7, 18,  1, 28,  7, 24, 10, 12,
       19,  8, 12, 11,  7, 16,  3, 24, 14,  3,  6,  8,  8, 24,  1,  7, 28,
       10,  4, 28,  1, 12, 18,  2, 11,  0, 16, 28,  2,  6], dtype=int32)

NameError: name 'group1' is not defined

In [None]:
# get accuracy

c1 = 0
c2 = 0
for i in range(len(label)):
    if i < len(obj)/2 and label[i] == 0:
        c1 += 1
    
    elif i >= len(obj)/2 and label[i] == 1:
        c1 += 1
    
    
acc1 = c1/len(label)

for i in range(len(label)):
    if i < len(obj)/2 and label[i] == 1:
        c2 += 1
    
    elif i >= len(obj)/2 and label[i] == 0:
        c2 += 1
acc2 = c2/len(label)

max(acc1, acc2)

In [332]:
# ax.axes.xaxis.set_ticklabels([])
# ax.axes.yaxis.set_ticklabels([])
# ax.axes.zaxis.set_ticklabels([])
# ax.set_xlabel('X')
# ax.set_ylabel('Y')
# ax.set_zlabel('Z')