In [97]:
from options.test_options import TestOptions
from data import DataLoader
from models import create_model
from util.writer import Writer
import numpy as np
import random
import torch.nn as nn
import torch

In [98]:
def test_attacked_model(model, dataset, writer):
    writer.reset_counter()
    attacked_model_outputs = []
    
    for i, data in enumerate(dataset):
        model.set_input(data)
        output, ncorrect, nexamples = model.test() 
        attacked_model_outputs.append(output)
        writer.update_counter(ncorrect, nexamples)
        
    writer.print_acc(-1, writer.acc)
    return attacked_model_outputs 
    
def find_new_vertex_index(vertices_edges, edge_index, old_vertex_index):
    for new_vertex_index, new_vertex_edges in enumerate(vertices_edges):
            for new_edge_index in new_vertex_edges:
                if(new_edge_index == edge_index and new_vertex_index != old_vertex_index):
                    return new_vertex_index
                
def get_random_walk(mesh, random_walk_size):
    walk_steps = 0
    random_walk_vertices = []
    random_walk_indices = []
    vertex_index = random.randint(0, len(mesh.vs)-1)
    
    while walk_steps < random_walk_size: 
        random_walk_vertices.append(mesh.vs[vertex_index])
        random_walk_indices.append(vertex_index)        
        walk_steps += 1  
        
        vertex_edges = mesh.ve[vertex_index]
        random_edge_index = random.randint(0, len(vertex_edges)-1)  
        new_edge_index = vertex_edges[random_edge_index]
        vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
        
        #Prevents random walk from crossing over itself
        count_of_vertex_edge_attempts = 0
        walk_steps_backwards = 0
        while(vertex_index in random_walk_indices):
            if(count_of_vertex_edge_attempts >= len(vertex_edges)-1):
                walk_steps_backwards += 1
                go_back_to_index = walk_steps-walk_steps_backwards
                
                if(go_back_to_index == 0):
                    # Trying again, mesh seems to be broken?
                    return get_random_walk(mesh, random_walk_size)
                
                vertex_edges = mesh.ve[random_walk_indices[go_back_to_index]]
                count_of_vertex_edge_attempts = 0
                
            random_edge_index = (random_edge_index + 1) % len(vertex_edges)
            new_edge_index = vertex_edges[random_edge_index]
            vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
            count_of_vertex_edge_attempts += 1        
    
    return torch.FloatTensor(random_walk_vertices), random_walk_indices

In [99]:
def train_imitating_network(imitating_nn, criterion, lr, random_walk_vertices, attacked_nn_output):
    
    hidden = imitating_nn.initHidden()
    imitating_nn.zero_grad()

    for vertice in random_walk_vertices:
        output, hidden = imitating_nn(torch.reshape(vertice, (1, 3)), hidden)
    
    loss = criterion(output, attacked_nn_output)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in imitating_nn.parameters():
        p.data.add_(p.grad.data, alpha=-lr)

    return output, loss.item()

In [100]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        
        combined = torch.cat((input, hidden), -1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

# Orchestration

In [101]:
#TODO: test accuracy impact from random movement of vertices

testing_opt = TestOptions().parse()

testing_opt.serial_batches = True
dataloader = DataLoader(testing_opt)
mesh_cnn = create_model(testing_opt)
opt_writer = Writer(testing_opt)

lr = 0.005
random_walk_size = 20
num_categories = 30
num_vertice_coordinates = 3
imitating_nn = RNN(num_vertice_coordinates, num_vertice_coordinates, num_categories)
criterion = nn.KLDivLoss(reduction="batchmean")

loaded mean / std from cache
loading the model from ./checkpoints/shrec16/latest_net.pth


In [102]:
# Record testing accuracy before attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer)

epoch: -1, TEST ACC: [99.167 %]



In [103]:
# Stage attack
for i, data in enumerate(dataloader):
    print("i: " + str(i))
    print("Mesh labels: " + str(data["label"]))
    
    mesh = data["mesh"][0]
    random_walk_vertices, random_walk_indices = get_random_walk(mesh, random_walk_size)
    
    imitating_nn_output, loss = train_imitating_network(imitating_nn, criterion, lr, random_walk_vertices, attacked_model_outputs[i])
    
    print(attacked_model_outputs[i])
    print(imitating_nn_output)
        

i: 0
Mesh labels: [0]
tensor([[  5.4885, -19.8533, -21.4682,  -9.6839,  -8.4958,  -3.6247,  -4.4009,
          -6.7783, -23.9704, -19.1216, -15.8959,  -6.0566,  -4.8584, -11.2263,
         -16.3660, -13.7719,  -8.6655, -10.7375,  -5.2502, -18.5240,  -0.8722,
         -19.7822, -22.9073,  -3.3643, -25.6530,  -7.8275, -25.9369, -18.4495,
          -3.2837, -13.6820]])
tensor([[-3.4542, -3.0707, -3.2943, -3.5866, -3.5335, -3.2485, -3.5848, -3.4698,
         -3.5041, -3.1847, -3.7291, -3.1679, -3.3466, -3.8176, -3.3832, -3.6203,
         -3.3114, -3.1424, -3.0506, -3.7467, -3.0802, -3.0646, -3.4375, -3.6509,
         -3.3954, -3.8525, -3.6231, -3.5640, -3.4031, -3.4906]],
       grad_fn=<LogSoftmaxBackward0>)
i: 1
Mesh labels: [0]
tensor([[  6.3456, -20.1829, -23.7263, -11.1970, -10.4479,  -3.5765,  -4.1582,
          -5.3284, -26.1639, -20.5226, -15.8104,  -4.7861,  -4.6728, -11.7994,
         -17.7693, -14.5876,  -7.8231, -13.3061,  -7.3532, -18.4909,  -1.7529,
         -21.3954, -25.230

i: 16
Mesh labels: [4]
tensor([[-14.4666, -25.2881, -23.2693,  -4.2469,   4.1617,  -7.2387,  -6.6712,
         -23.6178, -14.9484,  -9.2918, -26.8329, -19.5727,  -7.9810,  -5.1441,
         -13.2365, -17.1941, -26.8171,  -7.3547,  -8.5970, -17.0876,  -2.1206,
         -11.8050,  -8.0532, -19.1102, -22.9791,  -4.3415, -12.6852, -25.8596,
         -14.1197, -16.0344]])
tensor([[-3.1199, -2.6806, -2.6678, -3.4307, -3.2096, -3.4178, -3.5871, -3.2733,
         -3.6096, -3.4306, -3.8453, -3.2101, -3.5710, -3.8926, -3.1465, -4.0628,
         -3.7846, -3.4892, -3.2241, -3.7052, -3.2363, -3.1513, -3.8802, -3.5096,
         -3.3454, -3.2630, -3.7501, -3.9730, -3.7384, -3.6119]],
       grad_fn=<LogSoftmaxBackward0>)
i: 17
Mesh labels: [4]
tensor([[-12.8144, -24.9859, -23.9821,  -5.1918,   4.0042,  -5.5493,  -6.6703,
         -22.1579, -16.3636, -10.5309, -26.6521, -18.2376,  -5.8620,  -4.8374,
         -14.1308, -17.7268, -24.5397,  -5.5886,  -8.3652, -17.1923,  -1.0984,
         -11.4467,  -8.6

tensor([[ -4.8895,  -2.3995, -24.5613, -17.8141, -22.8905,  -1.8075,  -2.1679,
           7.4466, -26.3032, -12.3963,  -8.0546,  -1.4604, -11.1244,  -9.0692,
         -17.2034,  -9.4489,  -1.9529, -25.8494, -23.8965,  -4.3958, -21.3874,
         -15.2626, -26.1622,  -8.3694, -19.8813, -17.8226, -23.9910,  -2.0411,
         -12.3425, -10.9291]])
tensor([[-3.0040, -2.6296, -2.4970, -3.3712, -2.6430, -3.3922, -3.7482, -2.9433,
         -3.6359, -3.8034, -3.9551, -3.1671, -3.8878, -3.9969, -3.1058, -4.2960,
         -4.2628, -3.7646, -3.4171, -3.7485, -3.4780, -3.3524, -4.3352, -3.3982,
         -3.2497, -2.9879, -3.7176, -4.2650, -4.0417, -3.8046]],
       grad_fn=<LogSoftmaxBackward0>)
i: 30
Mesh labels: [7]
tensor([[ -5.9805,  -3.0573, -24.6134, -14.8289, -19.8778,  -0.1675,  -3.1060,
           5.3697, -27.4725, -11.5344, -13.2346,  -6.3060, -11.9175,  -6.4258,
         -19.0744,  -8.6457,  -6.2993, -21.3891, -21.0928,  -7.9301, -20.3652,
         -10.0025, -26.6814,  -8.8228, -23.4032

i: 47
Mesh labels: [11]
tensor([[ -4.4743, -16.6764, -28.8617, -20.4170, -17.7206,  -8.2828,  -1.4638,
          -3.3905, -24.2498, -17.8774,  -6.6131,   5.0031,  -3.0654, -15.8612,
         -11.2393, -16.5554,  -4.1831, -28.1277, -22.4059,  -4.0396, -14.3031,
         -30.4099, -19.6010, -11.8734, -16.3253,  -8.2380, -24.1908, -24.4493,
          -5.0211,  -0.9363]])
tensor([[-3.3323, -2.7978, -3.2470, -3.4802, -3.2983, -3.2669, -3.8805, -3.0984,
         -2.7975, -3.4326, -3.5198, -3.2057, -3.4742, -4.0416, -3.2377, -3.5869,
         -3.6291, -3.4862, -3.0489, -3.8859, -3.4553, -3.2736, -3.7567, -3.6111,
         -3.0764, -3.8218, -3.4689, -3.7241, -3.5653, -4.0504]],
       grad_fn=<LogSoftmaxBackward0>)
i: 48
Mesh labels: [12]
tensor([[ -2.5661, -21.2394, -26.3129, -14.7674,  -3.7861,  -3.7560,  -5.4643,
          -9.4294, -18.3040, -16.3390, -16.4038,  -1.9755,   8.2599,  -6.0303,
         -12.2086, -19.6346,  -8.9211,  -9.9889, -12.6845,  -8.9282,   0.3199,
         -19.3350,  -8

i: 59
Mesh labels: [14]
tensor([[-20.5121, -10.8599,  -7.5975, -11.9631, -13.1542, -18.6488,  -6.4280,
         -20.0937,  -4.2401,  -6.0960,  -2.7362, -11.5782, -19.2985, -22.1444,
           4.0644,  -8.5918, -17.2663, -21.8911, -12.9144,  -3.9298, -19.6721,
         -27.1996,  -7.6258, -14.0386,   0.8437,  -8.9811, -12.5892, -27.7664,
          -3.5568,  -2.5695]])
tensor([[-3.2471, -2.9717, -2.9834, -3.3302, -3.1088, -3.3783, -3.7743, -3.1433,
         -3.1786, -3.4968, -3.5607, -2.9044, -3.4623, -3.7337, -3.3299, -3.7239,
         -3.7101, -3.3119, -3.2720, -3.7460, -3.3775, -3.2841, -3.8481, -3.5629,
         -3.2406, -3.6394, -3.6291, -3.8240, -3.7149, -3.6340]],
       grad_fn=<LogSoftmaxBackward0>)
i: 60
Mesh labels: [15]
tensor([[-13.5967,  -2.3520,  -6.0668,  -3.6582, -14.0064,  -8.8274,  -3.0108,
          -9.4713, -15.9220,  -2.6821,  -7.0240, -15.6165, -25.0987, -13.4810,
          -5.9365,   4.0143, -15.2033, -18.3239,  -9.3359, -11.3389, -23.5738,
          -8.4257, -23

tensor([[-3.1615, -3.3601, -2.8217, -3.2577, -3.0940, -3.2593, -3.5307, -3.2150,
         -3.6027, -3.4434, -3.6474, -2.7488, -3.4002, -3.4899, -3.5977, -3.9856,
         -3.4001, -3.0973, -3.4851, -3.8008, -3.2939, -3.4197, -3.8396, -3.5370,
         -3.7327, -3.5305, -3.8574, -4.0536, -3.6075, -3.1738]],
       grad_fn=<LogSoftmaxBackward0>)
i: 71
Mesh labels: [17]
tensor([[-11.8748, -12.4198, -10.1863,  -7.0423,  -2.3060,   0.9893,  -9.2045,
         -18.1749, -14.9915,  -8.3809, -21.1256, -19.9540, -11.8130,  -9.8971,
         -14.5896, -10.9697, -13.8232,   9.1062,  -0.7848, -19.7840,  -9.8210,
          -2.3611, -12.4612,  -7.9398, -18.0150, -17.6356, -17.1692, -11.3991,
          -9.7322, -20.6211]])
tensor([[-3.2316, -3.2438, -3.0196, -3.3510, -3.2313, -3.1925, -3.8154, -3.4423,
         -3.3812, -3.1925, -3.5597, -2.7936, -3.4423, -3.6413, -3.4072, -3.7166,
         -3.0654, -2.9146, -3.4329, -3.7068, -3.4965, -3.3949, -3.6837, -3.7857,
         -3.5352, -3.6000, -3.9006, -3.9

tensor([[ -2.9784,  -4.4604,  -1.8961,  -9.3151, -17.1384,  -6.3612,  -4.1701,
          -7.2187, -20.3217, -11.6831,  -3.7824, -12.5407, -23.2622, -21.8246,
         -10.2428,  -0.1535,  -3.3852, -14.7988,  -3.6948, -18.7315, -21.2843,
         -15.8201, -33.0464,   7.7941, -18.8109, -19.5735, -29.5446, -10.3507,
          -2.9813, -21.2939]])
tensor([[-3.4818, -3.1847, -3.3001, -3.5524, -3.6373, -3.2478, -3.6074, -3.5228,
         -3.5663, -3.2667, -3.6730, -3.1814, -3.2184, -3.7131, -3.5249, -3.6071,
         -3.1891, -3.0384, -2.9797, -3.8459, -2.9849, -3.0175, -3.2689, -3.6275,
         -3.5478, -4.0234, -3.7146, -3.6347, -3.3964, -3.5118]],
       grad_fn=<LogSoftmaxBackward0>)
i: 95
Mesh labels: [23]
tensor([[ -2.1636,  -5.2587,  -0.9725,  -8.3864, -15.1310,  -5.6785,  -4.3806,
          -7.5165, -18.6564, -11.3702,  -4.0861, -11.9357, -21.2485, -20.6205,
          -9.2940,  -0.4356,  -3.2191, -11.8320,  -1.3545, -18.3466, -18.7689,
         -14.9522, -30.3434,   7.8995, -17.540

i: 107
Mesh labels: [26]
tensor([[-20.4858, -14.5195, -17.9894,  -9.6578,  -5.7690, -12.8269,  -5.4880,
         -15.2435,  -5.2535,  -0.3668, -18.8117, -14.1828, -12.4093,   0.9912,
          -3.8652,  -9.0219, -24.8032, -17.0661, -15.1116,  -0.9258, -12.9558,
         -10.5589,  -1.9612, -19.6294, -11.6979,  -1.1449,   8.4558, -15.2649,
         -14.4261,  -5.6530]])
tensor([[-3.4432, -3.1094, -3.1487, -3.4541, -3.4204, -3.4276, -3.6030, -3.3298,
         -3.4899, -3.5449, -3.6234, -3.1287, -3.4076, -3.7529, -3.4608, -3.6442,
         -3.5595, -3.1388, -3.0603, -3.8013, -2.9983, -3.0165, -3.5205, -3.4059,
         -3.2240, -3.8188, -3.5352, -3.7109, -3.5715, -3.4857]],
       grad_fn=<LogSoftmaxBackward0>)
i: 108
Mesh labels: [27]
tensor([[-1.3613e+01,  4.2414e+00, -1.5571e+01, -8.6816e+00, -1.8457e+01,
         -2.1858e+00, -4.5799e+00, -1.5762e-02, -2.0048e+01, -3.5250e+00,
         -1.3950e+01, -1.5606e+01, -2.1535e+01, -5.3531e+00, -1.5352e+01,
         -1.9129e+00, -1.2371e+01, 

In [104]:
# Record testing accuracy after attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer)

epoch: -1, TEST ACC: [99.167 %]

