In [376]:
import torch.optim as optim
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import prettytable as pt
%run helpers_preproc.ipynb
%matplotlib widget
%run getExamples.ipynb

# Hyper Parameters

In [397]:
origDataDir = 'SHREC11/'
newDataDir = 'SHREC11_mini/'
K = 5
# radiuss = [0.0002,0.002,0.01]
train_size = 10 # * len(radiuss)
test_size = 30 # * len(radiuss)

obj1 = 15
obj2 = 17

# 1. expand dataset and scramble

In [398]:
# expand(origDataDir, newDataDir, radiuss)
pick_copy(origDataDir, newDataDir, obj1, obj2)
scramble(newDataDir, 20, 40)

# 2. read in the dataset

## 2.1 read in v (252 x 3),f (...x...) and cluster

In [399]:
ver_list = []
adj_list = []
Nadj_list = []
gMat_list = []
gLbl_list = []

for i in range(40):
    v, f = get_nodes(newDataDir + 'T' + str(i) + '.obj')
    group_mat, group_labels = K_mean_cluster(v, K)
    
    adj = get_adj_from_f(f)            
    normed_adj = adj / np.reshape(np.sum(adj,axis = 0),[252,1])
    ver_list.append(v)
    adj_list.append(adj)
    Nadj_list.append(normed_adj)
    gMat_list.append(group_mat)
    gLbl_list.append(group_labels)
    
    
vers = np.array(ver_list)
adjs = np.array(adj_list)
Nadjs = np.array(Nadj_list)
gMats = np.array(gMat_list)
gLbls = np.array(gLbl_list)

## 2.2 Read in the labels

In [400]:
label_np = np.array(readLbl(40 ,newDataDir+'labels.txt'))
label_np_train = label_np[:train_size]
label_np_test = label_np[train_size:]

## 2.3 Split into training set and testing set

In [401]:
vers_train = torch.tensor(vers[:train_size]).float()
adjs_train = torch.tensor(adjs[:train_size]).float()
nadjs_train = torch.nan_to_num(torch.tensor(Nadjs[:train_size]).float(),0,0,0)
gMats_train = torch.tensor(gMats[:train_size]).float()
label_train = torch.tensor(label_np_train).float()

vers_test = torch.tensor(vers[train_size:]).float()
adjs_test = torch.tensor(adjs[train_size:]).float()
nadjs_test = torch.nan_to_num(torch.tensor(Nadjs[train_size:]).float(),0,0,0)
gMats_test = torch.tensor(gMats[train_size:]).float()
label_test = torch.tensor(label_np_test).float()

label_mat_train = torch.tensor(np.where(igl.all_pairs_distances(label_np_train,label_np_train,False) > 0.5,0,1)).float()
label_mat_test = torch.tensor(np.where(igl.all_pairs_distances(label_np_test,label_np_test,False) > 0.5,0,1)).float()

# 3. Training

## 3.1 Hyperparameters for training

In [402]:
contraGWs = [10,10,10,10]
contraMWs = [10,10,10,10]
atkGWs = [10,10,10,10]
atkMWs = [10,10,10,K * 3 + 1]

## 3.2 Initialize neural networks

In [403]:
%run NNs.ipynb
%run helpers_preproc.ipynb
loadWeightDir = 'weights/2-21-1/'

In [404]:
load_model = True

In [405]:
contraG = GCN(3,contraGWs)
contraM = MLP(252* contraGWs[len(contraGWs)-1],contraMWs)
atkG = GCN(3,atkGWs)
atkM = MLP_atk(252* atkGWs[len(contraGWs)-1], 0.2, K, atkMWs)

if load_model:
    contraG.weights = torch.load(loadWeightDir + 'contraGw.pt')
    contraM.weights = torch.load(loadWeightDir + 'contraMw.pt')
    atkG.weights = torch.load(loadWeightDir + 'atkGw.pt')
    atkM.weights = torch.load(loadWeightDir + 'atkMw.pt')

## 3.4 Extract node-level features from clean data

In [406]:
feas_clean = extract_node_feature(vers_train,adjs_train)

In [407]:
%run L.ipynb

## 3.5 Train atkNNðŸ˜ˆ with contraNNðŸ¤ 

In [408]:
lr = 0.0001
torch.set_printoptions(precision=10)

Loss Defined to be: 

$\text{-diffTypeMean + sameTypeMean} + 0.02 \times \sqrt{\text{sameTypeStd}}$

In [409]:
#get clean features
feas_clean = extract_node_feature(vers_train,adjs_train) #grad on nothing
opt = optim.Adam(atkG.weights + atkM.weights + contraG.weights + contraM.weights, lr = lr)

for contraI in range(1):
    trainLog = pt.PrettyTable()
    trainLog.field_names = [" ","Loss", "Same Mean", "Diff Mean", "Same STD"]
    
    
    opt.zero_grad()

    #get poisoned vertex locations
    poisonsByGroups = atkM.forward(atkG.forward(nadjs_train,feas_clean))
    transformation = translate_by_group(gMats_train,
                                        poisonsByGroups)
    poisonedVers = vers_train + transformation
        
        
    #extract poisoned features
    feas_poisoned = extract_node_feature(poisonedVers,adjs_train) #grad on ðŸ˜ˆ

    
    #get contraNN's performance on poisoned feature with grads on ðŸ˜ˆ    
    contraM.eval()
    contraG.eval()
    aPerfPoisoned, APP = loss(contraM.forward(contraG.forward(nadjs_train,feas_poisoned)),
                         label_mat_train) #grad on ðŸ˜ˆ
    
    #get contraNN's performance on clean feature and poisoned feature with grads on ðŸ¤ 
    contraM.train()
    contraG.train()
    feas_poisoned_nograd = feas_poisoned.detach() #grad on nothing
    
    objFeasFromClean = contraM.forward(contraG.forward(nadjs_train,feas_clean))
    objFeasFromPoisoned = contraM.forward(contraG.forward(nadjs_train,feas_poisoned_nograd))
    
    diffInFeas = torch.sum((objFeasFromClean - objFeasFromPoisoned)**2)
    
    cPerfClean, CPC = loss(objFeasFromClean,
                      label_mat_train) #grad on ðŸ¤ 
    
    cPerfPoisoned, CPP = loss(objFeasFromPoisoned,
                         label_mat_train) #grad on ðŸ¤ 
    
    
    trainLog.add_row(['On Clean Data']+CPC)
    trainLog.add_row(['On Poisoned Data']+CPP)

    
    overallLoss = -aPerfPoisoned + cPerfClean + 1 * cPerfPoisoned#  + diffInFeas
    #overallLoss = -aPerfPoisoned
    overallLoss.backward()
    opt.step()
    
    print(trainLog)

+------------------+----------+-----------+-----------+----------+
|                  |   Loss   | Same Mean | Diff Mean | Same STD |
+------------------+----------+-----------+-----------+----------+
|  On Clean Data   | -0.27957 |  0.00955  |  0.45692  | 0.01354  |
| On Poisoned Data | -0.24454 |  0.06707  |  0.54889  | 0.10432  |
+------------------+----------+-----------+-----------+----------+


## 3.7 Seeing how it does on Test Data

In [410]:
feas_clean_test = extract_node_feature(vers_test,adjs_test)

atkTest = atkM.forward(atkG.forward(nadjs_test,feas_clean_test))

poisonedVersTest = vers_test + translate_by_group(gMats_test,atkTest)
        
feas_poisoned_test = extract_node_feature(poisonedVersTest,adjs_test)
        
contraOutPoisonedTest = contraM.forward(contraG.forward(nadjs_test,feas_poisoned_test))
contraLossPoisonedTest,_ = loss(contraOutPoisonedTest,label_mat_test)

contraOutCleanTest = contraM.forward(contraG.forward(nadjs_test,feas_clean_test))
contraLossCleanTest,_ = loss(contraOutCleanTest,label_mat_test)

print(contraLossPoisonedTest)
print(contraOutCleanTest.shape)

tensor(-0.0362987444, grad_fn=<AddBackward0>)
torch.Size([30, 10])


In [411]:
feas_clean_test = extract_node_feature(vers_test,adjs_test)

atkTest = atkM.forward(atkG.forward(nadjs_test,feas_clean_test))

poisonedVersTest = vers_test + translate_by_group(gMats_test,atkTest)

feas_poisoned_test = extract_node_feature(poisonedVersTest,adjs_test)

contraOutPoisonedTest = contraM.forward(contraG.forward(nadjs_test,feas_poisoned_test))

contraOutCleanTest = contraM.forward(contraG.forward(nadjs_test,feas_clean_test))

# #display histograms
# display_hists(contraOutPoisonedTest, label_mat_test, test_size)
# display_hists(contraOutCleanTest, label_mat_test, test_size)

In [412]:
obj1 = torch.where(label_test == obj1)
obj2 = torch.where(label_test == obj2)

obj = torch.cat((obj1[0], obj2[0]), dim = 0)

data = torch.empty((len(obj),10))

len(obj)

30

In [413]:
for i in range(len(obj)):
    for j in range(10):
        data[i][j] = contraOutCleanTest[obj[i]][j].item()
        
data.shape

torch.Size([30, 10])

In [414]:
# output would be in size(40,5)
centroid, label = scipy.cluster.vq.kmeans2(data, 2, minit='points')
group1 = data[label == 0]
group2 = data[label == 1]

label

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0], dtype=int32)

In [415]:
group1.type

<function Tensor.type>

In [416]:
# get accuracy

c1 = 0
c2 = 0
for i in range(len(label)):
    if i < len(obj)/2 and label[i] == 0:
        c1 += 1
    
    elif i >= len(obj)/2 and label[i] == 1:
        c1 += 1
    
    
acc1 = c1/len(label)

for i in range(len(label)):
    if i < len(obj)/2 and label[i] == 1:
        c2 += 1
    
    elif i >= len(obj)/2 and label[i] == 0:
        c2 += 1
acc2 = c2/len(label)

max(acc1, acc2)

0.9

In [332]:
# ax.axes.xaxis.set_ticklabels([])
# ax.axes.yaxis.set_ticklabels([])
# ax.axes.zaxis.set_ticklabels([])
# ax.set_xlabel('X')
# ax.set_ylabel('Y')
# ax.set_zlabel('Z')