In [None]:
%load_ext autoreload
%autoreload 2

# built-in libraries
import numbers
import os 

# third party libraries
import matplotlib.pyplot as plt
import torch

# local modules
import dataset
import models
import train

# Location of the SHREC14 mesh dataset; set the SHREC14_DIR environment variable to point
# elsewhere instead of editing this machine-specific relative path.
SHREC14 = os.environ.get("SHREC14_DIR", "../../Downloads/Mesh-Datasets/MyShrec14")
# Checkpoint file for the shape-classification model (given to the classifier and to train.train_SHREC14).
PARAMS_FILE = "../model_data/SHREC14.pt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (all devices) so that Restart & Run All gives repeatable training runs
# (some CUDA kernels may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

# Training

## Shape Classification using SHREC14
First, train the shape classifier on the SHREC14 training split and evaluate it on the test split:

In [None]:
traindata = dataset.Shrec14Dataset(SHREC14, device=DEVICE, train=True, test=False)
testdata = dataset.Shrec14Dataset(SHREC14, device=DEVICE, train=False, test=True)

# NOTE(review): the classifier receives PARAMS_FILE at construction; presumably it restores
# previously saved weights from there when the file exists — confirm in models.chebynet.
MODEL = models.chebynet.ChebnetClassifier_SHREC14(
    nums_conv_units=[32, 32, 16, 16],
    num_classes=traindata.num_classes,
    parameters_file=PARAMS_FILE)
MODEL.to(DEVICE)

# Train the network; PARAMS_FILE is handed to the training routine as well.
train.train_SHREC14(
    train_data=traindata,
    classifier=MODEL,
    parameters_file=PARAMS_FILE,
    learning_rate=3e-4,
    epoch_number=10)

# Evaluate on the held-out test split.
accuracy, confusion = train.evaluate_SHREC14(
    eval_data=testdata,
    classifier=MODEL,
    epoch_number=2)

print(f"classification accuracy: {accuracy}")
fig, ax = plt.subplots()
image = ax.matshow(confusion)
fig.colorbar(image, ax=ax)
ax.set_title("SHREC14 shape classification: confusion matrix")
plt.show()

## Shape retrieval using SHREC14
The following cell shows how to train a classifier that identifies the subject of each mesh. The main differences from the classification task are the dataset ground truths and the way the data are split into training and test sets.

In [None]:
# A dedicated name keeps the classification PARAMS_FILE intact: rebinding PARAMS_FILE here would make
# a later re-run of the classification cell use the retrieval checkpoint.
RETRIEVAL_PARAMS_FILE = "../model_data/SHREC14_retrival.pt"

# traindata / testdata / MODEL are rebound on purpose: the adversarial-example cells below use them.
traindata = dataset.Shrec14Dataset_retrivial(SHREC14, device=DEVICE, train=True, test=False)
testdata = dataset.Shrec14Dataset_retrivial(SHREC14, device=DEVICE, train=False, test=True)

MODEL = models.chebynet.ChebnetClassifier_SHREC14(
    nums_conv_units=[32, 32, 16, 16],
    num_classes=traindata.num_classes,
    parameters_file=RETRIEVAL_PARAMS_FILE,
    K=20)
MODEL.to(DEVICE)

# Number of training epochs for this cell; 0 presumably skips training, relying on the weights
# passed via RETRIEVAL_PARAMS_FILE (NOTE(review): confirm). Set > 0 to (re)train.
RETRIEVAL_EPOCHS = 0
train.train_SHREC14(
    train_data=traindata,
    classifier=MODEL,
    parameters_file=RETRIEVAL_PARAMS_FILE,
    learning_rate=3e-4,
    epoch_number=RETRIEVAL_EPOCHS)

# Evaluate on the held-out test split; evaluating on traindata would overstate accuracy.
accuracy, confusion = train.evaluate_SHREC14(
    eval_data=testdata,
    classifier=MODEL,
    epoch_number=1)

print(f"retrieval accuracy: {accuracy}")
fig, ax = plt.subplots()
image = ax.matshow(confusion)
fig.colorbar(image, ax=ax)
ax.set_title("SHREC14 subject retrieval: confusion matrix")
plt.show()

# Visualization procedures

In [None]:
import plotly
import plotly.graph_objects as go
import numpy as np

def visualize(pos, faces, intensity=None):
  cpu = torch.device("cpu")
  if type(pos) != np.ndarray:
    pos = pos.to(cpu).clone().detach().numpy()
  if pos.shape[-1] != 3:
    raise ValueError("Vertices positions must have shape [n,3]")
  if type(faces) != np.ndarray:
    faces = faces.to(cpu).clone().detach().numpy()
  if faces.shape[-1] != 3:
    raise ValueError("Face indices must have shape [m,3]") 
  if intensity is None:
    intensity = np.ones([pos.shape[0]])
  elif type(intensity) != np.ndarray:
    intensity = intensity.to(cpu).clone().detach().numpy()

  x, z, y = pos.T
  i, j, k = faces.T

  mesh = go.Mesh3d(x=x, y=y, z=z,
            color='lightpink',
            intensity=intensity,
            opacity=1,
            colorscale=[[0, 'gold'],[0.5, 'mediumturquoise'],[1, 'magenta']],
            i=i, j=j, k=k,
            showscale=True)
  layout = go.Layout(scene=go.layout.Scene(aspectmode="data")) 

  #pio.renderers.default="plotly_mimetype"
  fig = go.Figure(data=[mesh],
                  layout=layout)
  fig.update_layout(
      autosize=True,
      margin=dict(l=20, r=20, t=20, b=20),
      paper_bgcolor="LightSteelBlue")
  fig.show()
    
def compare(pos1, faces1, pos2, faces2):
    """Show two meshes with the same vertex count in one figure, coloured by vertex displacement.

    Both meshes are merged into a single mesh in the same coordinate frame (no offset is
    applied, so they overlap). Vertices of the first mesh get colour value 0; each vertex of
    the second mesh is coloured by its Euclidean distance to the corresponding vertex of the first.

    Arguments:
     - pos1, pos2: torch tensors of vertex positions, [n, 3] each, on the same device.
     - faces1, faces2: torch tensors of triangle vertex indices for each mesh, [m, 3].

    Raises:
     - ValueError: if the two meshes do not have the same number of vertices.
    """
    n, m = pos1.shape[0], pos2.shape[0]
    if n != m:
        raise ValueError(
            f"compare expects meshes with the same number of vertices, got {n} and {m}")

    merged_pos = torch.cat([pos1, pos2], dim=0)
    # offset the second mesh's indices so they address its vertices in the merged array
    merged_faces = torch.cat([faces1, faces2 + n], dim=0)

    color = torch.zeros([n + m], dtype=pos1.dtype, device=pos1.device)
    color[n:] = (pos1 - pos2).norm(p=2, dim=-1).detach()
    visualize(merged_pos, merged_faces, color)

# Adversarial examples
## Carlini & Wagner


In [None]:
import adversarial.carlini_wagner as cw
from torch_geometric.data import Data

def get_classifier(data, mesh_index):
    """Build a function mapping vertex positions to MODEL's output for one mesh of `data`.

    The C&W builder needs a classifier that takes only vertex positions, so this closure
    freezes everything else MODEL needs for mesh `mesh_index`: its edge index, its label and
    the dataset's precomputed downscale matrices / downscaled edges, all moved to DEVICE.
    IMPORTANT: this is a workaround (it reads the notebook-level MODEL at call time and relies
    on the dataset's precomputed pyramid), not a pattern to copy.

    Arguments:
     - data: dataset exposing per-mesh `downscale_matrices` and `downscaled_edges`.
     - mesh_index: index of the mesh in `data`.

    Returns:
     - a function x -> MODEL(Data(pos=x, edge_index=..., y=...), I, E).
    """
    # Fetch the mesh once here rather than on every classifier call inside the optimisation loop.
    mesh = data[mesh_index]
    edge_index = mesh.edge_index.to(DEVICE)
    label = mesh.y
    # entries presumably are sparse (indices, values, size) triples — only the first two are tensors
    I = [(i.to(DEVICE), v.to(DEVICE), s) for (i, v, s) in data.downscale_matrices[mesh_index]]
    E = [e.to(DEVICE) for e in data.downscaled_edges[mesh_index]]
    return lambda x: MODEL(Data(pos=x, edge_index=edge_index, y=label), I, E)

def CW_adversarial_example(
    mesh_index=0,
    target_class=None,
    data=testdata,
    perturbation="lowband",
    distortion="local_euclidean",
    eigenvecs_number=36,
    adversarial_coefficient:float="default",
    regularization_coefficient:float="default",
    learning_rate:float=5e-5,
    minimization_iterations=1000,
    tuning_iterations=3):
    """
    Create an adversarial example using the C&W method.
    
    Arguments:
     - data: the dataset containing the mesh used during the adversarial attack.
     - mesh_index: index (in data) of the mesh to perturb.
     - target_class: class goal for the targeted adversarial attack.
     - perturbation: type of perturbation, assumes values 'lowband' or 'vertex'
     - distortion: type of distortion, assumes values 'L2' or 'local_euclidean'.
     - eigenvecs_number: number of eigenvalues used for the lowband perturbation.
     - adversarial_coefficient: coefficient used by the adversarial term in C&W.
     - regularization_coefficient: coefficient used for the centroid regularization term for 'local_euclidean'
     - learning_rate: learning rate for the gradient descent iterations.
     - miimization_iterations: number of gradient descent iterations.
     - tuning_iterations: number of iterations used to tune the adversarial coefficient.
     
    Returns:
     - adex: aversarial example. 
    """
    
    #check input consistency:
    if perturbation not in  ["lowband","vertex"]:
        raise ValueError("Invalid input for argument 'perturbation'. Must either be 'lowband' or 'vertex'!")
        
    if distortion not in  ["L2","local_euclidean"]:
        raise ValueError("Invalid input for argument 'distortion'. Must either be 'L2' or 'local_euclidean'!")
        
    if adversarial_coefficient != "default" and not isinstance(adversarial_coefficient, float):
        raise ValueError("Invalid input for argument 'adversarial_coefficient'. Must either be the string 'default' or any floating point number!")
    
    if regularization_coefficient != "default" and not isinstance(regularization_coefficient, float):
        raise ValueError("Invalid input for argument 'regularization_coefficient'. Must either be the string 'default' or any floating point number!")

    if adversarial_coefficient == "default":
        if distortion == "L2":
            adversarial_coefficient = 5e-3
        elif distortion == "local_euclidean":
            adversarial_coefficient = 5e-7
    
    if regularization_coefficient == "default":
        if distortion == "L2":
            regularization_coefficient = 0
        elif distortion == "local_euclidean":
            regularization_coefficient = 1e3
    
    i = mesh_index
    x = data[i].pos
    e = data[i].edge_index.t().to(DEVICE) # needs to be transposed
    f = data[i].face.t().to(DEVICE) # needs to be transposed
    y = data[i].y
    t = target_class

    # configure adversarial example components
    builder = cw.CWBuilder(search_iterations=tuning_iterations)
    builder.set_classifier(get_classifier(data, i))
    builder.set_perturbation_type(perturbation, eigs_num=eigenvecs_number)
    print(adversarial_coefficient)
    builder.set_mesh(x, e, f).set_adversarial_coeff(adversarial_coefficient)
    if t is not None: builder.set_target(t)
    
    if distortion=="L2":
        builder.set_distortion_function(cw.L2_distortion)
    elif distortion=="local_euclidean":
        builder.set_distortion_function(cw.LocallyEuclideanDistortion(K=40))
        builder.set_regularization_function(cw.centroid_regularizer)

    builder.set_minimization_iterations(minimization_iterations).set_learning_rate(learning_rate)
    adex = builder.build(usetqdm="standard")
    return adex

In [None]:
# Mesh to attack; reused below to rebuild its face tensor for the comparison plot.
attack_data = traindata
attack_mesh_index = 0

adex = CW_adversarial_example(
    data=attack_data,
    mesh_index=attack_mesh_index,
    target_class=1,
    perturbation="lowband",
    learning_rate=5e-4,
    minimization_iterations=10,
    tuning_iterations=1)

print("adversarial attack: " + ("successful" if adex.is_successful else "unsuccessful"))

# The face tensor used inside CW_adversarial_example is local to that function,
# so rebuild it here in the same (transposed) layout.
faces = attack_data[attack_mesh_index].face.t().to(DEVICE)
compare(adex.pos, faces, adex.perturbed_pos, faces)
