In [None]:
import torch
import scipy 
import trimesh 
import numpy as np
from potpourri3d import read_mesh
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import pymeshlab

import torchvision

from rna_config import Config
from diffusion_utils import compute_operators, normalize_positions

from model import DiffusionNet
from plot_utils.plot import plot, double_plot
from mesh_utils.mesh import TriMesh


device = 'cuda'
dtype = torch.float32

In [10]:
model_cfg = {
    "inp_feat": Config.inp_feat,
    "p_in": Config.p_in,
    "p_out": 147,
    "N_block": Config.n_block,
    "n_channels": Config.n_channels,
    "outputs_at": Config.outputs_at,
}
model = DiffusionNet(
    C_in=model_cfg['p_in'],
    C_out=model_cfg['p_out'],
    C_width=model_cfg["n_channels"],
    N_block=model_cfg['N_block'],
    outputs_at=model_cfg['outputs_at'],
    with_gradient_features=True,
    diffusion_method="spectral"
)
model.load_state_dict(torch.load("/Data/rna/models/model_final.pth", weights_only=True))

<All keys matched successfully>

In [12]:
off_path = "/Data/rna/RNADataset/off/4V83_BB.off" # mesh from the test set
verts, faces = read_mesh(off_path)
labels = np.loadtxt("/Data/rna/RNADataset/labels/4V83_BB.txt").astype(int) + 1
label_map = np.loadtxt("/Data/rna/RNADataset/label_map", dtype=int)
label_map = {k: v for k, v in label_map}
labels = np.array([label_map[label] for label in labels])
verts = torch.tensor(verts).float()
verts = normalize_positions(verts)
verts.shape, labels.shape

(torch.Size([14910, 3]), (14910,))

In [None]:
mesh = TriMesh(verts, faces)
double_plot(mesh,mesh,labels,labels)
# plot(mesh, cmap=labels)

HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

In [9]:
off_path = "/Data/rna/RNADataset/off/4V83_BB.off" # mesh from the test set
verts, faces = read_mesh(off_path)
labels = np.loadtxt("/Data/rna/RNADataset/labels/4V83_BB.txt").astype(int) + 1
label_map = np.loadtxt("/Data/rna/RNADataset/label_map", dtype=int)
label_map = {k: v for k, v in label_map}
labels = np.array([label_map[label] for label in labels])
verts = torch.tensor(verts).float()
verts = normalize_positions(verts)
faces = torch.tensor(faces)

frames, vertex_area, L, evals, evecs, gradX, gradY = compute_operators(
    torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig
)

model.eval()
with torch.no_grad():
    preds = model(verts, vertex_area, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, L=L)
    pred_labels = torch.argmax(preds, dim=1)

  torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig
  vertex_normals = vertex_normals / np.linalg.norm(vertex_normals,axis=-1,keepdims=True)


In [10]:
(labels == pred_labels.cpu().numpy()).mean()

0.8254862508383635

In [9]:
plot(mesh, pred_labels.cpu().numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-104.2162…

<meshplot.Viewer.Viewer at 0x7fa6783e1e80>

In [30]:
off_path = "/Data/rna/RNADataset/off/4V83_BB.off" # mesh from the test set
verts, faces = read_mesh(off_path)
labels = np.loadtxt("/Data/rna/RNADataset/labels/4V83_BB.txt").astype(int) + 1
label_map = np.loadtxt("/Data/rna/RNADataset/label_map", dtype=int)
label_map = {k: v for k, v in label_map}
labels = np.array([label_map[label] for label in labels])
verts = torch.tensor(verts).float()
verts = normalize_positions(verts)
faces = torch.tensor(faces)

ms = pymeshlab.MeshSet()
verts = np.asarray(verts, dtype=np.float64)
faces = np.asarray(faces, dtype=np.int32)
labels = np.asarray(labels, dtype=np.float64).reshape(-1, 1) 

m = pymeshlab.Mesh(vertex_matrix=verts, face_matrix=faces, v_scalar_array=labels)
ms.add_mesh(m)
ms.meshing_decimation_quadric_edge_collapse(
    targetfacenum=5000,
    preservenormal=True,
    preserveboundary=True,
    preservetopology=False,
    optimalplacement=True,
    planarquadric=True,
    qualitythr=0.3,
    autoclean=True
)                
mesh = ms.current_mesh()
verts = mesh.vertex_matrix()
faces = mesh.face_matrix()
labels = mesh.vertex_scalar_array().astype(int)

verts = torch.tensor(verts).float()
faces = torch.tensor(faces)

frames, vertex_area, L, evals, evecs, gradX, gradY = compute_operators(
    torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig
)

model.eval()
with torch.no_grad():
    preds_downsampled = model(verts, vertex_area, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, L=L)
    pred_labels_downsampled = torch.argmax(preds_downsampled, dim=1)

  torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig


In [31]:
(labels == pred_labels_downsampled.cpu().numpy()).mean()

0.7847695390781563

In [None]:
mesh = TriMesh(verts, faces)
plot(mesh, pred_labels_downsampled.cpu().numpy())

AttributeError: 'pymeshlab.pmeshlab.Mesh' object has no attribute 'vertices'

In [35]:
off_path = "/Data/rna/RNADataset/off/4V83_BB.off" # mesh from the test set
verts, faces = read_mesh(off_path)
labels = np.loadtxt("/Data/rna/RNADataset/labels/4V83_BB.txt").astype(int) + 1
label_map = np.loadtxt("/Data/rna/RNADataset/label_map", dtype=int)
label_map = {k: v for k, v in label_map}
labels = np.array([label_map[label] for label in labels])
verts = torch.tensor(verts).float()
verts = normalize_positions(verts)
faces = torch.zeros((0, 3), dtype=torch.int32)

frames, vertex_area, L, evals, evecs, gradX, gradY = compute_operators(
    torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig
)

model.eval()
with torch.no_grad():
    preds_pc = model(verts, vertex_area, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, L=L)
    pred_labels_pc = torch.argmax(preds_pc, dim=1)

  torch.tensor(verts), torch.tensor(faces), k_eig=Config.num_eig


In [36]:
(labels == pred_labels_pc.cpu().numpy()).mean()

0.4039570757880617

In [37]:
mesh = TriMesh(verts)
plot(mesh, pred_labels_pc.cpu().numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.1241606…

<meshplot.Viewer.Viewer at 0x7f2dbc295b80>