In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tqdm
import torch 
import torch.nn.functional as func

import dataset
import models

FAUST = "../../Downloads/Mesh-Datasets/MyFaustDataset"
MODEL_PATH = "../model_data/data.pt"

data = dataset.FaustDataset(FAUST)

model = models.ChebClassifier(
    param_conv_layers=[64,64,32,32],
    E_t=data.downscaled_edges, 
    D_t=data.downscale_matrices,
    num_classes = 10)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
model.load_state_dict(torch.load(MODEL_PATH))


In [None]:
import plotly
import plotly.io as pio
import plotly.graph_objects as go
'''
def visualize(pos, face, intensity=None):
    """
    visualize input mesh
    """
    
    pos = pos.numpy()
    f = face.numpy()
    
    if intensity is None:
        intensity = np.ones(pos.shape[0])
    
    x,y,z = pos[:,0], pos[:,1], pos[:,2]
    i,j,k = f[0,:], f[1,:], f[2,:]
    mesh = go.Mesh3d(x=x, y=z, z=y,
              color='lightpink',
              intensity=intensity,
              opacity=1,
              i=i, j=j, k=k,
              showscale=True)
    layout = go.Layout(scene=go.layout.Scene(aspectmode="data")) 

    #pio.renderers.default="plotly_mimetype"
    fig = go.Figure(data=[mesh],
                   layout=layout)
    fig.update_layout(
        autosize=True,
        margin=dict(l=20, r=20, t=20, b=20),
        paper_bgcolor="LightSteelBlue")
    fig.show()'''
    
import plotly
import plotly.graph_objects as go
import numpy as np

def visualize(pos, faces, intensity=None):
  cpu = torch.device("cpu")
  if type(pos) != np.ndarray:
    pos = pos.to(cpu).clone().detach().numpy()
  if pos.shape[-1] != 3:
    raise ValueError("Vertices positions must have shape [n,3]")
  if type(faces) != np.ndarray:
    faces = faces.to(cpu).clone().detach().numpy()
  if faces.shape[-1] != 3:
    raise ValueError("Face indices must have shape [m,3]") 
  if intensity is None:
    intensity = np.ones([pos.shape[0]])
  elif type(intensity) != np.ndarray:
    intensity = intensity.to(cpu).clone().detach().numpy()

  x, z, y = pos.T
  i, j, k = faces.T

  mesh = go.Mesh3d(x=x, y=y, z=z,
            color='lightpink',
            intensity=intensity,
            opacity=1,
            colorscale=[[0, 'gold'],[0.5, 'mediumturquoise'],[1, 'magenta']],
            i=i, j=j, k=k,
            showscale=True)
  layout = go.Layout(scene=go.layout.Scene(aspectmode="data")) 

  #pio.renderers.default="plotly_mimetype"
  fig = go.Figure(data=[mesh],
                  layout=layout)
  fig.update_layout(
      autosize=True,
      margin=dict(l=20, r=20, t=20, b=20),
      paper_bgcolor="LightSteelBlue")
  fig.show()


now that we have defined the visualization procedure, we can show some meshes:

In [None]:
import mesh.transforms as mo

x = data[0].pos.clone()
f = data[0].face.clone()
mo.transform_rotation_(x, dims=[1,2,0])
visualize(x,f)

## Eigenvectors and Eigenvalues computation

In [None]:
from mesh.laplacian import laplacebeltrami_FEM
from mesh.laplacian import LB_v2
import scipy
import scipy.sparse.linalg  as slinalg


def eigenpairs(pos:torch.Tensor, face:torch.Tensor, K:int):

    if pos.shape[-1] != 3:
        raise ValueError("Vertices positions must have shape [n,3]")
    if faces.shape[-1] != 3:
        raise ValueError("Face indices must have shape [m,3]") 
  
    stiff, area, lump = laplacebeltrami_FEM(pos, face)
    #stiff, area = LB_v2(pos, face)
    n = pos.shape[0]

    stiff.coalesce()
    area.coalesce()

    si, sv = stiff.indices(), stiff.values()
    ai, av = area.indices(), area.values()

    ri,ci = si
    S = scipy.sparse.csr_matrix( (sv, (ri,ci)), shape=(n,n))

    ri,ci = ai
    A = scipy.sparse.csr_matrix( (av, (ri,ci)), shape=(n,n))

    #A_lumped = scipy.sparse.csr_matrix( (lump, (range(n),range(n))), shape=(n,n))

    eigvals, eigvecs = slinalg.eigsh(S, M=A, k=K, sigma=-1e-6)
    eigvals = torch.tensor(eigvals)
    eigvecs = torch.tensor(eigvecs)
    return eigvals, eigvecs

In [None]:
r = torch.rand(K,3)*0.01
x = pos + eigvecs.matmul(r)

visualize(x, face.t(), intensity=np.linalg.norm(x-pos,axis=-1))

# Heat Kernel

In [100]:
def heat_kernel(eigvals:torch.Tensor, eigvecs:torch.Tensor, t:float) -> torch.Tensor:
    #hk = eigvecs.matmul(torch.diag(torch.exp(-t*eigvals)).matmul(eigvecs.t()))
    tmp = torch.exp(-t*eigvals).view(1,-1)
    hk = (tmp*eigvecs).matmul(eigvecs.t())
    return hk

def diffusion_distance(eigvals:torch.Tensor, eigvecs:torch.Tensor, t:float):
    n, k = eigvecs.shape
    print(eigvecs.shape)
    D = torch.zeros([n,n])
    for i in tqdm.trange(k):
        eigvec = eigvecs[:,i].view(-1,1)
        eigval = eigvals[i]
        tmp = eigvec.repeat(1, n)
        tmp = tmp - tmp.t()
        D = D + torch.exp(-2*t*eigval)*(tmp*tmp)
    return D

hk = heat_kernel(eigvals, eigvecs, 0.01)
d = diffusion_distance(eigvals, eigvecs, 0.01)
visualize(pos, faces, intensity=d[:,3000])

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

torch.Size([6890, 10])


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:13<00:00,  1.38s/it]


## Metrics


In [None]:
import os
import pickle
import matplotlib.pyplot as plt

def compute_distance_mse(pos, perturbed_pos, faces, K, t):
    eigvals1, eigvecs1 = eigenpairs(pos, faces, K)
    eigvals2, eigvecs2 = eigenpairs(perturbed_pos, faces, K)
    d1 = diffusion_distance(eigvals1,eigvecs1,t)
    d2 = diffusion_distance(eigvals2,eigvecs2,t)
    return d1,d2

def compute_distance_distance(distance, perbed_pos, faces, K, t):
    eigvals, eigvecs = eigenpairs(perbed_pos, faces, K)
    d = diffusion_distance(eigvals,eigvecs,t)
    return torch.nn.functional.mse_loss(distance, d)

def get_generator_data(adv_data, faces, K=10, t=0.01):
    out_dictionary = {"MSE_diffusion":[], "LB_loss":[], "MCF_loss":[], "Euclidean_loss":[]}
    diff_distances, mesh_eigenpairs = {}, {}
    for (idx, target), data in adv_data.items():
        print("processing ", idx,":",target)

        # get useful properties
        metrics = data["tracking-data"]
        pos = torch.tensor(data["positions"],dtype=torch.double)
        ppos = torch.tensor(data["perturbed-positions"],dtype=torch.double)
        c = data["c-value"]
        
        if idx not in mesh_eigenpairs:
            mesh_eigenpairs[idx] = eigenpairs(pos, faces, K)
            
        if idx not in diff_distances:
            diff_distances[idx] = diffusion_distance(*mesh_eigenpairs[idx], t)

        out_dictionary["MSE_diffusion"] = compute_distance_distance(
            diff_distances[idx], ppos, faces, K, t)
        
        # insert metrics into output dictionary  
        for k in out_dictionary.keys():
            if k in metrics: out_dictionary[k].append(metrics[k][-1])

    # convert into numpy arrays
    for k in out_dictionary.keys():out_dictionary[k] = np.array(out_dictionary[k])
    return out_dictionary


root = "../model_data"
filenames = ["Spectral", "MCF", "Dist"]
faces = data[0].face.t()

generators_data = {}
for fname in filenames:
  absname = os.path.join(root, fname+"AdversarialGenerator.pt")
  with open(absname, "rb") as f:
    generators_data[fname] = pickle.load(f)
    
processed_data = {}
for gname, adv_data in generators_data.items():
    print(gname)
    processed_data[gname] = get_generator_data(adv_data, faces, K=10, t=0.01)
    

In [None]:
data2plot = {}
for k in tmp.keys():
  data2plot[k] = np.array(tmp[k]).T

fig = plt.figure(1, figsize=(9, 6))
ax = fig.add_subplot(111)

mi = 0
tmp = [data2plot["Spectral"][mi], data2plot["MCF"][mi],data2plot["Dist"][mi]]
bp = ax.boxplot(tmp)

In [105]:


pos = torch.Tensor(generators_data["MCF"][88,6]["positions"])
ppos = torch.Tensor(generators_data["MCF"][88,6]["perturbed-positions"])
d1,d2 = compute_distance_mse(pos, ppos, faces, K=30, t=0.01)
torch.nn.functional.mse_loss(d1, d2)

  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

torch.Size([6890, 3])


100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.26it/s]
  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

torch.Size([6890, 3])


100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.10s/it]


tensor(0.0633)

In [106]:
vi=0
visualize(pos,faces, d1[vi]-d2[vi])
visualize(pos,faces, (pos-ppos).norm(dim=-1))