In [1]:
import torch
import torch.nn.functional as F
import pytorch3d
from pytorch3d.utils import ico_sphere

from src.util import make_faces

# Print tensors in fixed-point notation with 4 decimals so normals are easy to compare by eye.
torch.set_printoptions(precision=4, sci_mode=False)

In [2]:
def vertex_tris(faces):
    """For every vertex index, list the ids of the faces that reference it.

    Args:
        faces: integer tensor of shape [F, 3] holding the vertex indices of each
            triangle (anything whose ``.tolist()`` yields rows of ints works).

    Returns:
        A list of length ``max(faces) + 1``. Entry ``vid`` is the ascending list of
        face ids whose triangle contains ``vid`` (empty for indices no face uses).
        Returns ``[]`` when ``faces`` is empty.
    """
    # A single host copy plus one pass over the faces is O(F). The previous
    # version did O(V * F) per-element tensor membership tests, each of which
    # forces a device sync when faces lives on the GPU.
    face_rows = faces.tolist()
    if not face_rows:
        return []
    num_verts = max(max(row) for row in face_rows) + 1
    res = [[] for _ in range(num_verts)]
    for fid, row in enumerate(face_rows):
        # set() counts a vertex repeated inside a degenerate face only once,
        # matching the membership-test semantics. fid grows monotonically,
        # so every per-vertex list stays sorted.
        for vid in set(row):
            res[vid].append(fid)
    return res

def get_face_normals(vertices, faces):
    """Unit normal of every triangle, oriented by the face winding.

    Args:
        vertices: float tensor [V, 3] of vertex positions.
        faces: long tensor [F, 3] of vertex indices per triangle.

    Returns:
        Float tensor [F, 3] holding normalize((v1 - v0) x (v2 - v0)) per face.
        Degenerate (zero-area) faces yield a zero vector.
    """
    v0 = vertices[faces[:, 0]]
    vec1 = vertices[faces[:, 1]] - v0
    vec2 = vertices[faces[:, 2]] - v0
    # dim=-1 is essential. Without it torch.cross uses the FIRST dimension of
    # size 3, which is the face dimension whenever F == 3, and silently returns
    # wrong normals.
    face_normals = F.normalize(torch.cross(vec1, vec2, dim=-1), p=2, dim=-1)  # [F, 3]
    return face_normals



def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):
    """Per-vertex unit normals as a weighted combination of incident face normals.

    Args:
        face_normals: [F, 3] face normals (e.g. from ``get_face_normals``).
        vert_tri_indices: [V, K] long tensor of incident face ids per vertex.
            Padding slots may hold any valid face id as long as their weight is 0.
        vert_tri_weights: [V, K, 1] weights, broadcast over the xyz axis.

    Returns:
        [V, 3] tensor of unit-length vertex normals (an all-zero sum stays zero).
    """
    weighted_face_normals = face_normals[vert_tri_indices] * vert_tri_weights  # [V, K, 3]
    vertex_normals = weighted_face_normals.sum(dim=-2)  # [V, 3]
    # Averaging unit vectors shrinks the result (|mean| < 1 unless all agree),
    # so renormalize to obtain true unit normals.
    return F.normalize(vertex_normals, p=2, dim=-1)


In [3]:
# Fall back to CPU so the notebook still runs (Restart & Run All) without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Level-0 ico sphere: 12 vertices and 20 triangular faces (see the shapes printed below).
src_mesh = ico_sphere(0, device)
ico_vertices = src_mesh.verts_list()[0]
ico_vertex_normals = src_mesh.verts_normals_list()[0]
ico_faces = src_mesh.faces_list()[0]
ico_face_normals = src_mesh.faces_normals_list()[0]
ico_vertices.shape, ico_vertex_normals.shape, ico_faces.shape, ico_face_normals.shape

(torch.Size([12, 3]),
 torch.Size([12, 3]),
 torch.Size([20, 3]),
 torch.Size([20, 3]))

In [4]:
# Gather the three corner positions of every face: [F, 3 corners, 3 xyz].
tris = ico_vertices[ico_faces.reshape(-1)].reshape(*ico_faces.shape, 3)
tris.shape

torch.Size([20, 3, 3])

In [7]:
face_normals = get_face_normals(ico_vertices, ico_faces)
print(face_normals.shape)
# Enforce the comparison so Restart & Run All fails loudly on a mismatch
# instead of merely displaying False.
normals_match = torch.allclose(face_normals, ico_face_normals)
assert normals_match, "get_face_normals disagrees with PyTorch3D's faces_normals_list()"
normals_match

torch.Size([20, 3])


True

In [19]:
def vertex_tri_maps(faces):
    """Build padded vertex -> incident-face lookup tables for get_vertex_normals.

    Args:
        faces: face index tensor, forwarded to ``vertex_tris``.

    Returns:
        (vert_tri_indices, vert_tri_weights), both created on torch's default device:
            vert_tri_indices: [V, K] long tensor, where K is the largest number of
                faces touching any vertex. Row v lists v's faces, padded with 0.
            vert_tri_weights: [V, K, 1] float tensor with 1/deg(v) on real slots and
                0 on padding, so a weighted sum over K is the mean of the incident
                faces. Rows of vertices referenced by no face are all zero.
    """
    vts = vertex_tris(faces)
    num_verts = len(vts)
    # default=0 keeps an empty mesh from crashing max().
    max_valence = max((len(vert_faces) for vert_faces in vts), default=0)
    vert_tri_indices = torch.zeros(num_verts, max_valence, dtype=torch.long)
    vert_tri_weights = torch.zeros(num_verts, max_valence)
    for vid, vert_faces in enumerate(vts):
        if not vert_faces:
            # Vertex index used by no face (a gap in the indexing): leave its row
            # zero-weighted instead of dividing by zero.
            continue
        degree = len(vert_faces)
        vert_tri_indices[vid, :degree] = torch.tensor(vert_faces, dtype=torch.long)
        vert_tri_weights[vid, :degree] = 1. / degree
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)

# vertex_tri_maps allocates on torch's default device; move both tables next to the mesh.
vert_tri_indices, vert_tri_weights = (t.to(device) for t in vertex_tri_maps(ico_faces))

vert_tri_indices.shape, vert_tri_weights.shape

(torch.Size([12, 5]), torch.Size([12, 5, 1]))

In [22]:
def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):
    """Unit vertex normals from a weighted sum of the incident face normals.

    face_normals: [F, 3]; vert_tri_indices: [V, K] face ids per vertex;
    vert_tri_weights: [V, K, 1], zero on padding slots (as built by vertex_tri_maps).
    Returns a [V, 3] tensor normalized along the last axis.
    """
    incident = face_normals[vert_tri_indices]            # [V, K, 3]
    accumulated = (incident * vert_tri_weights).sum(-2)  # [V, 3]
    return F.normalize(accumulated, dim=-1, p=2)

vertex_normals = get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights)
# Check against PyTorch3D's reference normals instead of eyeballing two printed tensors.
assert torch.allclose(vertex_normals, ico_vertex_normals, atol=1e-5), \
    "vertex normals disagree with PyTorch3D's verts_normals_list()"
vertex_normals

tensor([[    -0.5257,      0.8507,     -0.0000],
        [     0.5257,      0.8507,      0.0000],
        [    -0.5257,     -0.8507,      0.0000],
        [     0.5257,     -0.8507,     -0.0000],
        [     0.0000,     -0.5257,      0.8507],
        [    -0.0000,      0.5257,      0.8507],
        [     0.0000,     -0.5257,     -0.8507],
        [     0.0000,      0.5257,     -0.8507],
        [     0.8507,      0.0000,     -0.5257],
        [     0.8507,      0.0000,      0.5257],
        [    -0.8507,     -0.0000,     -0.5257],
        [    -0.8507,      0.0000,      0.5257]], device='cuda:0')

In [21]:
# Reference: PyTorch3D's own vertex normals for the same mesh. Compare row by row
# with the get_vertex_normals output; the printed values differ at most in the
# sign of (near-)zero entries.
ico_vertex_normals

tensor([[    -0.5257,      0.8507,     -0.0000],
        [     0.5257,      0.8507,      0.0000],
        [    -0.5257,     -0.8507,      0.0000],
        [     0.5257,     -0.8507,     -0.0000],
        [    -0.0000,     -0.5257,      0.8507],
        [     0.0000,      0.5257,      0.8507],
        [     0.0000,     -0.5257,     -0.8507],
        [    -0.0000,      0.5257,     -0.8507],
        [     0.8507,     -0.0000,     -0.5257],
        [     0.8507,      0.0000,      0.5257],
        [    -0.8507,      0.0000,     -0.5257],
        [    -0.8507,     -0.0000,      0.5257]], device='cuda:0')