In [1]:
import numpy as np
import torch
import torch.nn.functional as  F
import trimesh
import pytorch3d
from pytorch3d.utils import ico_sphere

from src.config import get_parser
from src.util import make_faces


# Show tensors in fixed-point notation (no scientific notation) with 4 digits
# of precision, which keeps the small normal components below easy to compare.
torch.set_printoptions(sci_mode=False, precision=4)

In [2]:
def vertex_tris(faces):
    """Map every vertex id to the ids of the faces that reference it.

    Args:
        faces: array-like of per-face vertex ids (e.g. a ``[F, 3]`` integer
            tensor or ndarray). Anything exposing ``tolist()`` is accepted,
            as is a plain nested sequence.

    Returns:
        A list of length ``max(vertex id) + 1``. Entry ``v`` is the list of
        face ids (ascending) whose vertex list contains ``v``. Vertices that
        no face references get an empty list. An empty ``faces`` yields ``[]``.
    """
    # Convert once to plain Python ints: iterating a tensor element by element
    # (and testing `vid in face`) is very slow and, for CUDA tensors, forces a
    # device synchronisation per comparison.
    face_rows = faces.tolist() if hasattr(faces, "tolist") else [list(f) for f in faces]
    if not face_rows:
        return []

    num_verts = max(max(row) for row in face_rows) + 1
    res = [[] for _ in range(num_verts)]
    # Single pass over the faces: O(F) instead of the O(V * F) nested scan.
    for fid, row in enumerate(face_rows):
        # set(): a degenerate face that repeats a vertex is recorded only once.
        for vid in set(row):
            # Negative ids (e.g. padding) never correspond to a vertex slot.
            if vid >= 0:
                res[vid].append(fid)
    return res

def vertex_tri_maps(faces):
    """Build padded per-vertex face-index and averaging-weight tables.

    Uses ``vertex_tris(faces)`` to find, for each vertex, the faces touching it.

    Args:
        faces: per-face vertex ids, in whatever form ``vertex_tris`` accepts.

    Returns:
        A tuple ``(vert_tri_indices, vert_tri_weights)``:

        * ``vert_tri_indices``: long tensor of shape ``[V, C]`` where ``C`` is
          the largest number of faces incident to any vertex. Row ``v`` holds
          the ids of the faces touching vertex ``v``; unused trailing slots
          are padded with face id 0.
        * ``vert_tri_weights``: float tensor of shape ``[1, V, C, 1]``. Each
          used slot of row ``v`` holds ``1 / (number of faces touching v)``;
          padded slots hold 0 so they contribute nothing to a weighted sum.

        A vertex that no face references gets an all-zero weight row
        (previously this raised ``ZeroDivisionError``).
    """
    vts = vertex_tris(faces)
    num_verts = len(vts)
    # default=0 keeps an empty mesh from raising in max().
    max_degree = max((len(tris) for tris in vts), default=0)

    vert_tri_indices = torch.zeros(num_verts, max_degree, dtype=torch.long)
    vert_tri_weights = torch.zeros(num_verts, max_degree)
    for vid, tris in enumerate(vts):
        degree = len(tris)
        if degree == 0:
            # Isolated vertex: leave the zero padding in place.
            continue
        vert_tri_indices[vid, :degree] = torch.tensor(tris, dtype=torch.long)
        vert_tri_weights[vid, :degree] = 1. / degree

    # Add a leading batch axis and a trailing axis so the weights broadcast
    # against a [B, V, C, 3] tensor of gathered face normals.
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)[None]

def get_face_normals(vrt, faces):
    """Compute unit normals for every triangle of a batch of meshes.

    Args:
        vrt: vertex positions, indexed along dim 1 by vertex id; the notebook
            passes a ``[B, V, 3]`` tensor.
        faces: ``[F, 3]`` integer tensor of vertex ids shared by the batch.

    Returns:
        ``[B, F, 3]`` tensor of L2-normalised face normals, oriented by the
        right-hand rule over the vertex order ``(0, 1, 2)`` of each face.
    """
    origin = vrt.index_select(1, faces[:, 0])
    v1 = vrt.index_select(1, faces[:, 1]) - origin
    v2 = vrt.index_select(1, faces[:, 2]) - origin
    # dim must be explicit: without it torch.cross uses the FIRST dimension of
    # size 3, which is the wrong axis whenever B == 3 or F == 3.
    face_normals = F.normalize(torch.cross(v1, v2, dim=-1), p=2, dim=-1)  # [B, F, 3]
    return face_normals

def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):
    """Blend face normals into unit vertex normals.

    Args:
        face_normals: ``[B, F, 3]`` tensor of per-face normals.
        vert_tri_indices: ``[V, C]`` integer tensor; row ``v`` lists the face
            ids gathered for vertex ``v``.
        vert_tri_weights: weights multiplied onto the gathered ``[B, V, C, 3]``
            normals (the notebook passes shape ``[1, V, C, 1]``).

    Returns:
        ``[B, V, 3]`` tensor: the weighted sum of each vertex's gathered face
        normals, L2-normalised along the last axis.
    """
    batch_size = face_normals.size(0)
    num_verts, slots = vert_tri_indices.shape
    # Gather every (vertex, slot) face normal in one flat lookup, then restore
    # the per-vertex grouping.
    flat_face_ids = vert_tri_indices.flatten()
    gathered = torch.index_select(face_normals, 1, flat_face_ids)
    per_vertex = gathered.reshape(batch_size, num_verts, slots, 3)
    # Collapse the slot axis with the supplied weights.
    blended = torch.sum(per_vertex * vert_tri_weights, dim=-2)
    return F.normalize(blended, p=2, dim=-1)

In [3]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs end-to-end on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Level-0 icosphere from PyTorch3D; its own vertex/face normals are the
# reference values the hand-written helpers are checked against below.
src_mesh = ico_sphere(0, device)
ico_vertices = src_mesh.verts_list()[0]
ico_vertex_normals = src_mesh.verts_normals_list()[0]
ico_faces = src_mesh.faces_list()[0]
ico_face_normals = src_mesh.faces_normals_list()[0]
ico_vertices.shape, ico_vertex_normals.shape, ico_faces.shape, ico_face_normals.shape

(torch.Size([12, 3]),
 torch.Size([12, 3]),
 torch.Size([20, 3]),
 torch.Size([20, 3]))

In [5]:
# Add a batch axis of size 1, recompute the face normals with the local helper,
# and compare the single batch entry against PyTorch3D's face normals.
face_normals = get_face_normals(ico_vertices.unsqueeze(0), ico_faces)
torch.allclose(face_normals.squeeze(0), ico_face_normals)

True

In [6]:
# Build the per-vertex face lookup tables and move both onto the mesh's device.
vert_tri_indices, vert_tri_weights = (
    table.to(device) for table in vertex_tri_maps(ico_faces)
)
vert_tri_indices.shape, vert_tri_weights.shape

(torch.Size([12, 5]), torch.Size([1, 12, 5, 1]))

In [7]:
# Blend the recomputed face normals into per-vertex normals.
vert_normals = get_vertex_normals(
    face_normals=face_normals,
    vert_tri_indices=vert_tri_indices,
    vert_tri_weights=vert_tri_weights,
)
vert_normals.size()

torch.Size([1, 12, 3])

In [13]:
# Compare the single batch entry against PyTorch3D's vertex normals, with the
# absolute tolerance set explicitly to 1e-5.
torch.allclose(vert_normals.squeeze(0), ico_vertex_normals, atol=1e-5)

True

In [11]:
# Largest SIGNED element-wise difference between the two sets of vertex
# normals. Negative deviations are not captured by this figure; take .abs()
# first if the maximum absolute error is wanted.
md = (vert_normals[0] - ico_vertex_normals).max().item()
f"{md:+.6f}"

'+0.000006'