In [None]:
%load_ext autoreload
%autoreload 2

import random

import numpy as np
import tqdm
import torch
import torch.nn.functional as func

import faust
import models

# Paths are relative to the notebook's working directory.
# NOTE(review): FAUST points into a personal Downloads folder -- consider making it configurable.
FAUST = "../../Downloads/Mesh-Datasets/MyFaustDataset"
MODEL_PATH = "../model_data/data.pt"
SEED = 0

# Seed the RNGs used later in the notebook (torch.rand_like for the perturbed
# positions, random.shuffle for vertex sampling) so Restart & Run All reproduces results.
random.seed(SEED)
torch.manual_seed(SEED)

dataset = faust.FaustDataset(FAUST)

model = models.ChebClassifier(
    param_conv_layers=[64,64,32,32],
    E_t=dataset.downscaled_edges,
    D_t=dataset.downscale_matrices,
    num_classes = 10)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only one;
# load_state_dict then copies the weights into the model's own parameters.
# torch.load unpickles the file -- only load checkpoints from trusted sources.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# The model is only used for inference below; eval() makes layers such as
# dropout / batch-norm (if the model has any) behave deterministically.
model.eval();

In [None]:
from torch_geometric.data import Data
from torch_geometric.data import Dataset
from random import shuffle

def avg_pos(dataset: "Dataset") -> torch.Tensor:
    """Return the element-wise mean of ``data.pos`` over every sample in ``dataset``.

    Works with any sized, indexable, iterable collection whose items expose a
    ``pos`` tensor (e.g. a torch_geometric ``Dataset`` or a plain list of ``Data``).
    All ``pos`` tensors must share the shape of ``dataset[0].pos``.

    Args:
        dataset: collection of samples with a ``pos`` tensor attribute.

    Returns:
        Tensor shaped like ``dataset[0].pos`` holding the average positions.
        The accumulator inherits the dtype/device of the first sample's ``pos``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    count = len(dataset)
    if count == 0:
        raise ValueError("avg_pos: dataset is empty")
    # zeros_like keeps dtype and device, so CUDA or float64 positions accumulate correctly.
    total = torch.zeros_like(dataset[0].pos)
    for data in dataset:
        total += data.pos
    return total / count

def compute_impact(x:torch.Tensor, f:torch.nn.Module, perturbed_pos:torch.Tensor, num_samples:int, progress:bool=True):
    """Estimate how strongly individual vertex positions influence the output of ``f``.

    A random subset of ``num_samples`` distinct vertices is probed. For each probed
    vertex ``vi``, row ``vi`` of ``x`` is replaced by ``perturbed_pos[vi]``, ``f`` is
    re-evaluated, and the squared L2 distance between the unperturbed and the
    perturbed outputs is recorded as that vertex's impact.

    Args:
        x: 2-D per-vertex tensor; ``x.shape[0]`` is the vertex count. It is NOT
            modified -- perturbations are applied to a private copy.
        f: callable mapping a tensor shaped like ``x`` to an output tensor
            (a ``torch.nn.Module`` or any function such as a lambda).
        perturbed_pos: replacement rows, indexed the same way as ``x``.
        num_samples: number of vertices to probe. Values above the vertex count
            are clamped to it; values <= 0 probe nothing.
        progress: show a ``tqdm`` progress bar while probing.

    Returns:
        1-D float tensor of length ``x.shape[0]``. Probed vertices hold their impact
        divided by the largest probed impact (left unscaled if that maximum is 0);
        vertices that were not probed hold the sentinel value -1.
    """
    vertex_count = x.shape[0]
    num_samples = max(0, min(num_samples, vertex_count))
    impact = -torch.ones([vertex_count])

    # Probe a random subset of distinct vertices (sampling without replacement).
    vertex_bag = list(range(vertex_count))
    shuffle(vertex_bag)
    sampled = vertex_bag[:num_samples]

    steps = tqdm.trange(num_samples) if progress else range(num_samples)
    # Inference only: skipping autograd keeps memory flat over num_samples + 1 forward passes.
    with torch.no_grad():
        # Work on a copy so the caller's tensor is never left perturbed, even if f raises.
        x_work = x.clone()
        Z = f(x_work)
        for i in steps:
            vi = sampled[i]
            original_row = x_work[vi, :].clone()
            x_work[vi, :] = perturbed_pos[vi, :]
            Z_per = f(x_work)
            x_work[vi, :] = original_row
            impact[vi] = torch.sum((Z - Z_per) ** 2)

    # Normalise only the probed entries so the -1 sentinel stays recognisable,
    # and never divide by a non-positive maximum.
    if num_samples > 0:
        idx = torch.tensor(sampled, dtype=torch.long)
        peak = impact[idx].max()
        if peak > 0:
            impact[idx] = impact[idx] / peak
    return impact

In [None]:
# Probe one mesh: replace sampled vertices with random positions and measure
# how much the classifier output changes for each of them.
MESH_INDEX = 99        # which dataset sample to analyse
PERTURB_SCALE = 30     # replacement positions are drawn uniformly from [0, PERTURB_SCALE)
NUM_SAMPLES = 100      # number of vertices to probe

mesh = dataset[MESH_INDEX]
# NOTE(review): these are absolute replacement positions, not offsets added to mesh.pos -- confirm intended.
per = torch.rand_like(mesh.pos) * PERTURB_SCALE
with torch.no_grad():  # inference only; no autograd graph needed
    # Pass a copy so the dataset's own vertex tensor is never modified.
    impact = compute_impact(mesh.pos.clone(), f=model, perturbed_pos=per, num_samples=NUM_SAMPLES)
intensity = impact.detach().numpy()
# NOTE(review): `visualize` is not defined or imported anywhere in this notebook; it must
# come from an earlier session or a deleted cell -- import it explicitly so Restart & Run All works.
visualize(mesh, intensity=intensity)