In [12]:
from options.test_options import TestOptions
from data import DataLoader
from models import create_model
from models.layers.mesh_prepare import extract_features
from util.writer import Writer
from data.base_dataset import collate_fn
import numpy as np
import random
import torch.nn as nn
import torch
import torch.utils.data
import torch.nn.functional as F

In [13]:
def test_attacked_model(model, dataset, writer, print_results):
    writer.reset_counter()
    attacked_model_outputs = []
    
    for i, data in enumerate(dataset):
        model.set_input(data)
        output, ncorrect, nexamples = model.test() 
        attacked_model_outputs.append(output)
        writer.update_counter(ncorrect, nexamples)
        
        if(print_results and ncorrect == 0):
            print("Label:  " + str(data['label']))
            print("pred: " + str(torch.max(output, 1)[1]))
            print("files_name: " + str(data['mesh'][0].filename))
        
    writer.print_acc(-1, writer.acc)
    return attacked_model_outputs 
    
def find_new_vertex_index(vertices_edges, edge_index, old_vertex_index):
    for new_vertex_index, new_vertex_edges in enumerate(vertices_edges):
            for new_edge_index in new_vertex_edges:
                if(new_edge_index == edge_index and new_vertex_index != old_vertex_index):
                    return new_vertex_index
                
def get_random_walk(mesh, random_walk_size):
    walk_steps = 0
    random_walk_vertices = []
    random_walk_indices = []
    random.seed(walk_steps)
    vertex_index = random.randint(0, len(mesh.vs)-1)
    
    while walk_steps < random_walk_size: 
        random_walk_vertices.append(mesh.vs[vertex_index])
        random_walk_indices.append(vertex_index)        
        walk_steps += 1  
        
        vertex_edges = mesh.ve[vertex_index]
        random.seed(walk_steps+1)
        random_edge_index = random.randint(0, len(vertex_edges)-1)  
        new_edge_index = vertex_edges[random_edge_index]
        vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
        
        #Prevents random walk from crossing over itself
        count_of_vertex_edge_attempts = 0
        walk_steps_backwards = 0
        while(vertex_index in random_walk_indices):
            if(count_of_vertex_edge_attempts >= len(vertex_edges)-1):
                walk_steps_backwards += 1
                go_back_to_index = walk_steps-walk_steps_backwards
                
                if(go_back_to_index == 0):
                    # Trying again, mesh seems to be broken?
                    return get_random_walk(mesh, random_walk_size)
                
                vertex_edges = mesh.ve[random_walk_indices[go_back_to_index]]
                count_of_vertex_edge_attempts = 0
                
            random_edge_index = (random_edge_index + 1) % len(vertex_edges)
            new_edge_index = vertex_edges[random_edge_index]
            vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
            count_of_vertex_edge_attempts += 1        
    
    return torch.FloatTensor(random_walk_vertices), random_walk_indices

In [14]:
def train_imitating_network(imitating_nn, criterion, optimizer, random_walk_vertices, labels, attacked_model_outputs):

    output = imitating_nn(random_walk_vertices)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 
    
    _, predictions = torch.max(output.data, 1)
    num_correct = (predictions == labels).sum().item()
    
    return output, loss.item(), num_correct

def use_imitating_network_for_attack(imitating_nn, criterion, optimizer, random_walk_vertices, labels, attacked_model_outputs):

    output = imitating_nn(random_walk_vertices)
    loss = criterion(F.softmax(output, dim=1), F.softmax(attacked_model_outputs, dim=1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 
    
    _, predictions = torch.max(output.data, 1)
    num_correct = (predictions == labels).sum().item()
    
    return output, loss.item(), num_correct, predictions

In [15]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, input, hidden):
        
        combined = torch.cat((input, hidden), -1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

class Imitating_NN(nn.Module):
    def __init__(self, input_size, output_size):
        super(Imitating_NN, self).__init__()
        
        self.scaling_factor = 10
        self.input_size = input_size
        
        self.first_linear = nn.Linear(input_size, 2*self.scaling_factor*input_size)
        self.second_linear = nn.Linear(2*self.scaling_factor*input_size, self.scaling_factor*input_size)
        self.relu = nn.ReLU()
        self.rnn = RNN(self.scaling_factor*input_size, self.scaling_factor*input_size, self.scaling_factor*input_size)
        self.third_linear = nn.Linear(self.scaling_factor*input_size, 2*self.scaling_factor*input_size)
        self.fourth_linear = nn.Linear(2*self.scaling_factor*input_size, output_size)
        
    def forward(self, random_walk_vertices):
        
        output = self.first_linear(random_walk_vertices)
        output = self.second_linear(output)
        output = self.relu(output)
        
        hidden = self.rnn.initHidden()
        for step in output:
            output, hidden = self.rnn(torch.reshape(step, (1, self.input_size*self.scaling_factor)), hidden)
            
        output = self.third_linear(output)
        output = self.fourth_linear(output)
            
        return output

# Orchestration

In [16]:
testing_opt = TestOptions().parse()

testing_opt.serial_batches = True
dataloader = DataLoader(testing_opt)
mesh_cnn = create_model(testing_opt)
opt_writer = Writer(testing_opt)

shift_weight = 0.3
num_vertices_to_move = 5
random_walk_size = 50
num_categories = 30
num_vertice_coordinates = 3
imitating_nn = Imitating_NN(num_vertice_coordinates, num_categories)
kld_criterion = nn.KLDivLoss(reduction="batchmean")
ce_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(imitating_nn.parameters(), lr = 0.0005)

loaded mean / std from cache
loading the model from ./checkpoints/shrec16/latest_net.pth


In [17]:
# Record testing accuracy before attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer, False)

epoch: -1, TEST ACC: [99.167 %]



In [18]:
# Beginning Random Walk Attack
dataloader = DataLoader(testing_opt)

# Training Imitation Network
for epoch in range(100):
    epoch_loss = 0
    epoch_num_correct = 0
    epoch_num_samples = 0;
    for i, data in enumerate(dataloader):
        mesh = data["mesh"][0]
        label = torch.from_numpy(data["label"])
        
        random_walk_vertices, random_walk_indices = get_random_walk(mesh, random_walk_size)
        
        imitating_nn_output, loss, num_correct = train_imitating_network(imitating_nn, ce_criterion, optimizer, random_walk_vertices, label, attacked_model_outputs[i])
        
        epoch_num_samples += label.size(dim=-1)
        epoch_loss += loss
        epoch_num_correct += num_correct
    
    accuracy = epoch_num_correct / epoch_num_samples
    print("epoch: " + str(epoch) + ", loss: " + str(epoch_loss), ", accuracy: " + str(accuracy))
    

loaded mean / std from cache


  return torch.FloatTensor(random_walk_vertices), random_walk_indices


epoch: 0, loss: 418.0003640651703 , accuracy: 0.0
epoch: 1, loss: 412.8466935157776 , accuracy: 0.0
epoch: 2, loss: 412.38605999946594 , accuracy: 0.0
epoch: 3, loss: 410.9580202102661 , accuracy: 0.0
epoch: 4, loss: 409.9135329723358 , accuracy: 0.016666666666666666
epoch: 5, loss: 406.6867673397064 , accuracy: 0.016666666666666666
epoch: 6, loss: 406.7157185077667 , accuracy: 0.0
epoch: 7, loss: 410.45795035362244 , accuracy: 0.016666666666666666
epoch: 8, loss: 410.97475814819336 , accuracy: 0.016666666666666666
epoch: 9, loss: 405.3257279396057 , accuracy: 0.05
epoch: 10, loss: 401.6131776571274 , accuracy: 0.06666666666666667
epoch: 11, loss: 398.68185555934906 , accuracy: 0.041666666666666664
epoch: 12, loss: 393.52218449115753 , accuracy: 0.05
epoch: 13, loss: 388.47274899482727 , accuracy: 0.058333333333333334
epoch: 14, loss: 384.17047357559204 , accuracy: 0.075
epoch: 15, loss: 380.5140552520752 , accuracy: 0.1
epoch: 16, loss: 377.54541516304016 , accuracy: 0.091666666666666

In [19]:
# Moving Vertices
dataloader = DataLoader(testing_opt)

def gradients_magnitude(vertices):
    return vertices[1][0] ** 2 + vertices[1][1] ** 2 + vertices[1][2] ** 2
    
overridden_meshes = []
for i, data in enumerate(dataloader):
        mesh = data["mesh"][0]
        random_walk_vertices, random_walk_indices = get_random_walk(mesh, random_walk_size)
        label = torch.from_numpy(data["label"])
    
        random_walk_vertices.requires_grad = True
        imitating_nn_output, loss, num_correct, pred = use_imitating_network_for_attack(imitating_nn, kld_criterion, optimizer, random_walk_vertices, label, attacked_model_outputs[i])
        
        gradients_dict = dict(zip(random_walk_indices, random_walk_vertices.grad))
        gradients_dict = dict(sorted(gradients_dict.items(), key = gradients_magnitude))
        
        max_grad = torch.max(random_walk_vertices.grad.flatten(), 0)[0].item()
        min_grad = torch.min(random_walk_vertices.grad.flatten(), 0)[0].item()
        
        num_vertices_changed = 0
        while(num_vertices_changed < num_vertices_to_move):
            
            gradient_entry = gradients_dict.popitem()
            index = gradient_entry[0]
            gradients = gradient_entry[1]
            
            mesh.vs[index][0] += ((2 * shift_weight * (gradients[0]-min_grad) / (max_grad-min_grad)) - shift_weight)
            mesh.vs[index][1] += ((2 * shift_weight * (gradients[1]-min_grad) / (max_grad-min_grad)) - shift_weight)
            mesh.vs[index][2] += ((2 * shift_weight * (gradients[2]-min_grad) / (max_grad-min_grad)) - shift_weight)
            
            #mesh.vs[index][0] += gradients[0]
            #mesh.vs[index][1] += gradients[1]
            #mesh.vs[index][2] += gradients[2]
    
            num_vertices_changed += 1
                
        mesh.features = extract_features(mesh)
        overridden_meshes.append(mesh)
        new_file_name = "datasets/random_walks/" + mesh.filename
        mesh.export(file=new_file_name)

dataloader.dataloader.dataset.override_meshes(overridden_meshes) 
        

loaded mean / std from cache


In [20]:
# Record testing accuracy after random walk attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer, True)

Label:  [7]
pred: tensor([5])
files_name: T156.obj
Label:  [7]
pred: tensor([5])
files_name: T576.obj
Label:  [8]
pred: tensor([2])
files_name: T434.obj
Label:  [8]
pred: tensor([2])
files_name: T598.obj
Label:  [10]
pred: tensor([23])
files_name: T476.obj
Label:  [11]
pred: tensor([5])
files_name: T504.obj
Label:  [11]
pred: tensor([6])
files_name: T582.obj
Label:  [13]
pred: tensor([17])
files_name: T102.obj
Label:  [13]
pred: tensor([17])
files_name: T530.obj
Label:  [14]
pred: tensor([9])
files_name: T105.obj
Label:  [14]
pred: tensor([9])
files_name: T4.obj
Label:  [14]
pred: tensor([9])
files_name: T471.obj
Label:  [14]
pred: tensor([15])
files_name: T478.obj
Label:  [16]
pred: tensor([5])
files_name: T343.obj
Label:  [19]
pred: tensor([9])
files_name: T21.obj
Label:  [19]
pred: tensor([6])
files_name: T404.obj
Label:  [19]
pred: tensor([15])
files_name: T519.obj
Label:  [19]
pred: tensor([9])
files_name: T542.obj
Label:  [22]
pred: tensor([9])
files_name: T505.obj
Label:  [24]
p

In [21]:
# Random Pertubation Attack
dataloader = DataLoader(testing_opt)
overridden_meshes = []
for i, data in enumerate(dataloader):
    
    mesh = data["mesh"][0]
    num_random_changes = 0
    
    while(num_random_changes < num_vertices_to_move):
        
        random_vertex_index = random.randint(0, len(mesh.vs)-1)    
        mesh.vs[random_vertex_index][0] += random.uniform(-shift_weight, shift_weight)
        mesh.vs[random_vertex_index][1] += random.uniform(-shift_weight, shift_weight)
        mesh.vs[random_vertex_index][2] += random.uniform(-shift_weight, shift_weight)
        num_random_changes += 1
    
    mesh.features = extract_features(mesh)
    overridden_meshes.append(mesh)
    new_file_name = "datasets/random_pertubations/" + mesh.filename
    mesh.export(file=new_file_name)

dataloader.dataloader.dataset.override_meshes(overridden_meshes) 

loaded mean / std from cache


In [22]:
# Record testing accuracy after random pertubation attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer, True)

Label:  [7]
pred: tensor([5])
files_name: T576.obj
Label:  [8]
pred: tensor([2])
files_name: T598.obj
Label:  [14]
pred: tensor([9])
files_name: T105.obj
Label:  [19]
pred: tensor([22])
files_name: T404.obj
Label:  [24]
pred: tensor([9])
files_name: T435.obj
Label:  [25]
pred: tensor([4])
files_name: T492.obj
Label:  [28]
pred: tensor([5])
files_name: T540.obj
Label:  [29]
pred: tensor([19])
files_name: T461.obj
Label:  [29]
pred: tensor([19])
files_name: T581.obj
Label:  [29]
pred: tensor([9])
files_name: T586.obj
epoch: -1, TEST ACC: [91.667 %]

