In [1]:
from options.test_options import TestOptions
from data import DataLoader
from models import create_model
from models.layers.mesh_prepare import extract_features
from util.writer import Writer
from data.base_dataset import collate_fn
import numpy as np
import random
import torch.nn as nn
import torch
import torch.utils.data

In [2]:
def test_attacked_model(model, dataset, writer):
    writer.reset_counter()
    attacked_model_outputs = []
    
    for i, data in enumerate(dataset):
        model.set_input(data)
        output, ncorrect, nexamples = model.test() 
        attacked_model_outputs.append(output)
        writer.update_counter(ncorrect, nexamples)
        
    writer.print_acc(-1, writer.acc)
    return attacked_model_outputs 
    
def find_new_vertex_index(vertices_edges, edge_index, old_vertex_index):
    for new_vertex_index, new_vertex_edges in enumerate(vertices_edges):
            for new_edge_index in new_vertex_edges:
                if(new_edge_index == edge_index and new_vertex_index != old_vertex_index):
                    return new_vertex_index
                
def get_random_walk(mesh, random_walk_size):
    walk_steps = 0
    random_walk_vertices = []
    random_walk_indices = []
    vertex_index = random.randint(0, len(mesh.vs)-1)
    
    while walk_steps < random_walk_size: 
        random_walk_vertices.append(mesh.vs[vertex_index])
        random_walk_indices.append(vertex_index)        
        walk_steps += 1  
        
        vertex_edges = mesh.ve[vertex_index]
        random_edge_index = random.randint(0, len(vertex_edges)-1)  
        new_edge_index = vertex_edges[random_edge_index]
        vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
        
        #Prevents random walk from crossing over itself
        count_of_vertex_edge_attempts = 0
        walk_steps_backwards = 0
        while(vertex_index in random_walk_indices):
            if(count_of_vertex_edge_attempts >= len(vertex_edges)-1):
                walk_steps_backwards += 1
                go_back_to_index = walk_steps-walk_steps_backwards
                
                if(go_back_to_index == 0):
                    # Trying again, mesh seems to be broken?
                    return get_random_walk(mesh, random_walk_size)
                
                vertex_edges = mesh.ve[random_walk_indices[go_back_to_index]]
                count_of_vertex_edge_attempts = 0
                
            random_edge_index = (random_edge_index + 1) % len(vertex_edges)
            new_edge_index = vertex_edges[random_edge_index]
            vertex_index = find_new_vertex_index(mesh.ve, new_edge_index, random_walk_indices[-1])
            count_of_vertex_edge_attempts += 1        
    
    return torch.FloatTensor(random_walk_vertices), random_walk_indices

In [3]:
def train_imitating_network(imitating_nn, criterion, optimizer, random_walk_vertices, labels):

    output = imitating_nn(random_walk_vertices)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 
    
    _, predictions = torch.max(output.data, 1)
    num_correct = (predictions == labels).sum().item()
    
    return output, loss.item(), num_correct

In [4]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, input, hidden):
        
        combined = torch.cat((input, hidden), -1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

class Imitating_NN(nn.Module):
    def __init__(self, input_size, output_size):
        super(Imitating_NN, self).__init__()
        
        self.scaling_factor = 10
        self.input_size = input_size
        
        self.first_linear = nn.Linear(input_size, 2*self.scaling_factor*input_size)
        self.second_linear = nn.Linear(2*self.scaling_factor*input_size, self.scaling_factor*input_size)
        self.relu = nn.ReLU()
        self.rnn = RNN(self.scaling_factor*input_size, self.scaling_factor*input_size, self.scaling_factor*input_size)
        self.third_linear = nn.Linear(self.scaling_factor*input_size, 2*self.scaling_factor*input_size)
        self.fourth_linear = nn.Linear(2*self.scaling_factor*input_size, output_size)
        
    def forward(self, random_walk_vertices):
        
        output = self.first_linear(random_walk_vertices)
        output = self.second_linear(output)
        output = self.relu(output)
        
        hidden = self.rnn.initHidden()
        for step in output:
            output, hidden = self.rnn(torch.reshape(step, (1, self.input_size*self.scaling_factor)), hidden)
            
        output = self.third_linear(output)
        output = self.fourth_linear(output)
            
        return output
        

# Orchestration

In [5]:
testing_opt = TestOptions().parse()

testing_opt.serial_batches = True
dataloader = DataLoader(testing_opt)
mesh_cnn = create_model(testing_opt)
opt_writer = Writer(testing_opt)

shift_weight = 0.2
num_vertices_to_move = 5
random_walk_size = 50
num_categories = 30
num_vertice_coordinates = 3
imitating_nn = Imitating_NN(num_vertice_coordinates, num_categories)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(imitating_nn.parameters(), lr = 0.0005)

loaded mean / std from cache
loading the model from ./checkpoints/shrec16/latest_net.pth


In [6]:
# Record testing accuracy before attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer)

epoch: -1, TEST ACC: [99.167 %]



In [7]:
# Beginning Random Walk Attack
dataloader = DataLoader(testing_opt)

# Training Imitation Network
for epoch in range(100):
    epoch_loss = 0
    epoch_num_correct = 0
    epoch_num_samples = 0;
    for i, data in enumerate(dataloader):
        mesh = data["mesh"][0]
        label = torch.from_numpy(data["label"])
        
        random_walk_vertices, random_walk_indices = get_random_walk(mesh, random_walk_size)
        
        #TODO: combe back and use random walk here
        imitating_nn_output, loss, num_correct = train_imitating_network(imitating_nn, criterion, optimizer, torch.FloatTensor(mesh.vs), label)
        
        epoch_num_samples += label.size(dim=-1)
        epoch_loss += loss
        epoch_num_correct += num_correct
    
    accuracy = epoch_num_correct / epoch_num_samples
    print("epoch: " + str(epoch) + ", loss: " + str(epoch_loss), ", accuracy: " + str(accuracy))
    

loaded mean / std from cache


  return torch.FloatTensor(random_walk_vertices), random_walk_indices


epoch: 0, loss: 417.96099519729614 , accuracy: 0.016666666666666666
epoch: 1, loss: 412.9377884864807 , accuracy: 0.03333333333333333
epoch: 2, loss: 412.79872155189514 , accuracy: 0.03333333333333333
epoch: 3, loss: 411.5038158893585 , accuracy: 0.0
epoch: 4, loss: 409.6977620124817 , accuracy: 0.008333333333333333
epoch: 5, loss: 402.54399609565735 , accuracy: 0.041666666666666664
epoch: 6, loss: 404.5584237575531 , accuracy: 0.025
epoch: 7, loss: 401.5942211151123 , accuracy: 0.041666666666666664
epoch: 8, loss: 393.63180553913116 , accuracy: 0.05
epoch: 9, loss: 385.1892976760864 , accuracy: 0.05
epoch: 10, loss: 374.5214728116989 , accuracy: 0.05
epoch: 11, loss: 365.6240222454071 , accuracy: 0.041666666666666664
epoch: 12, loss: 355.9369351863861 , accuracy: 0.075
epoch: 13, loss: 351.3061738014221 , accuracy: 0.09166666666666666
epoch: 14, loss: 345.26844650506973 , accuracy: 0.1
epoch: 15, loss: 338.11372262239456 , accuracy: 0.1
epoch: 16, loss: 335.9178787469864 , accuracy: 0

In [8]:
# Moving Vertices
dataloader = DataLoader(testing_opt)

def gradients_magnitude(vertices):
    return vertices[1][0] ** 2 + vertices[1][1] ** 2 + vertices[1][2] ** 2
    
overridden_meshes = []
for i, data in enumerate(dataloader):
        mesh = data["mesh"][0]
        random_walk_vertices, random_walk_indices = get_random_walk(mesh, random_walk_size)
        label = torch.from_numpy(data["label"])
    
        #random_walk_vertices.requires_grad = True
        mesh_vs = torch.FloatTensor(mesh.vs)
        mesh_vs.requires_grad = True
        
        #TODO: combe back and use random walk here
        imitating_nn_output, loss, num_correct = train_imitating_network(imitating_nn, criterion, optimizer, mesh_vs, label)
        
        gradients_dict = {index: gradients for index, gradients in enumerate(mesh_vs.grad)}
        gradients_dict = dict(sorted(gradients_dict.items(), key = gradients_magnitude))
        
        max_grad = torch.max(mesh_vs.grad.flatten(), 0)[0].item()
        min_grad = torch.min(mesh_vs.grad.flatten(), 0)[0].item()
        
        num_vertices_changed = 0
        while(num_vertices_changed < num_vertices_to_move):
            
            gradient_entry = gradients_dict.popitem()
            index = gradient_entry[0]
            gradients = gradient_entry[1]
            
            #mesh.vs[index][0] += ((2 * shift_weight * (gradients[0]-min_grad) / (max_grad-min_grad)) - shift_weight)
            #mesh.vs[index][1] += ((2 * shift_weight * (gradients[1]-min_grad) / (max_grad-min_grad)) - shift_weight)
            #mesh.vs[index][2] += ((2 * shift_weight * (gradients[2]-min_grad) / (max_grad-min_grad)) - shift_weight)
            
            mesh.vs[index][0] += gradients[0]
            mesh.vs[index][1] += gradients[1]
            mesh.vs[index][2] += gradients[2]
            
            num_vertices_changed += 1
            
        mesh.features = extract_features(mesh)
        overridden_meshes.append(mesh)
        new_file_name = "datasets/random_walks/" + mesh.filename
        mesh.export(file=new_file_name)

dataloader.dataloader.dataset.override_meshes(overridden_meshes) 
        

loaded mean / std from cache


In [9]:
# Record testing accuracy after random walk attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer)

epoch: -1, TEST ACC: [38.333 %]



In [10]:
# Random Pertubation Attack
dataloader = DataLoader(testing_opt)
overridden_meshes = []
for i, data in enumerate(dataloader):
    
    mesh = data["mesh"][0]
    num_random_changes = 0
    
    while(num_random_changes < num_vertices_to_move):
        
        random_vertex_index = random.randint(0, len(mesh.vs)-1)    
        mesh.vs[random_vertex_index][0] += random.uniform(-shift_weight, shift_weight)
        mesh.vs[random_vertex_index][1] += random.uniform(-shift_weight, shift_weight)
        mesh.vs[random_vertex_index][2] += random.uniform(-shift_weight, shift_weight)
        num_random_changes += 1
    
    mesh.features = extract_features(mesh)
    overridden_meshes.append(mesh)
    new_file_name = "datasets/random_pertubations/" + mesh.filename
    mesh.export(file=new_file_name)

dataloader.dataloader.dataset.override_meshes(overridden_meshes) 

loaded mean / std from cache


In [11]:
# Record testing accuracy after random pertubation attack
attacked_model_outputs = test_attacked_model(mesh_cnn, dataloader, opt_writer)

epoch: -1, TEST ACC: [96.667 %]

