In [1]:
import numpy as np
import torch
import torch.nn.functional as  F
import trimesh
import pytorch3d
from pytorch3d.utils import ico_sphere

from src.config import get_parser
from src.util import make_faces


torch.set_printoptions(sci_mode=False, precision=4)

In [35]:
def vertex_tris(faces):
    res = []
    for vid in range(faces.max()+1):
        vertex_faces = []
        for fid, face in enumerate(faces):
            if vid in face:
                vertex_faces.append(fid)
        res.append(vertex_faces)
    return res

def vertex_tri_maps(faces):
    vts = vertex_tris(faces)
    r, c = len(vts), max([len(x) for  x in vts])
    vert_tri_indices = torch.zeros(r, c, dtype=torch.long)
    vert_tri_weights = torch.zeros(r, c)    
    for r, tris in enumerate(vts):
        weight = 1. / len(tris)
        for c, tri_id in enumerate(tris):
            vert_tri_indices[r, c] = tri_id
            vert_tri_weights[r, c] = weight
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)[None]

def get_face_normals(vertices, faces):
    vec1 = vertices[faces[:, 1]] - vertices[faces[:, 0]]
    vec2 = vertices[faces[:, 2]] - vertices[faces[:, 0]]
    face_normals = F.normalize(vec1.cross(vec2), p=2, dim=-1)  # [F, 3]
    return face_normals

def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):    
    weighted_face_normals = face_normals[vert_tri_indices] * vert_tri_weights
    vertex_normals = weighted_face_normals.sum(dim=-2)
    return  F.normalize(vertex_normals, p=2, dim=-1)

In [3]:
device = torch.device("cuda:0")

src_mesh = ico_sphere(0, device)
ico_vertices = src_mesh.verts_list()[0]
ico_vertex_normals = src_mesh.verts_normals_list()[0]
ico_faces = src_mesh.faces_list()[0]
ico_face_normals = src_mesh.faces_normals_list()[0]
ico_vertices.shape, ico_vertex_normals.shape, ico_faces.shape, ico_face_normals.shape

(torch.Size([12, 3]),
 torch.Size([12, 3]),
 torch.Size([20, 3]),
 torch.Size([20, 3]))

In [4]:
tris = ico_vertices[ico_faces]
tris.shape

torch.Size([20, 3, 3])

In [5]:
faces_batch = ico_faces[None]
faces_batch.shape

torch.Size([1, 20, 3])

In [6]:
vertices_batch = ico_vertices[None].clone()
vertices_batch.shape

torch.Size([1, 12, 3])

In [7]:
vertices_batch

tensor([[[-0.5257,  0.8507,  0.0000],
         [ 0.5257,  0.8507,  0.0000],
         [-0.5257, -0.8507,  0.0000],
         [ 0.5257, -0.8507,  0.0000],
         [ 0.0000, -0.5257,  0.8507],
         [ 0.0000,  0.5257,  0.8507],
         [ 0.0000, -0.5257, -0.8507],
         [ 0.0000,  0.5257, -0.8507],
         [ 0.8507,  0.0000, -0.5257],
         [ 0.8507,  0.0000,  0.5257],
         [-0.8507,  0.0000, -0.5257],
         [-0.8507,  0.0000,  0.5257]]], device='cuda:0')

In [8]:
faces_batch.shape, faces_batch[:, :, 1].shape, faces_batch[0][:, 1].shape

(torch.Size([1, 20, 3]), torch.Size([1, 20]), torch.Size([20]))

In [9]:
vertices_batch.shape

torch.Size([1, 12, 3])

In [10]:
vt_bch = torch.randn((4, 12, 3), device=device)
vt_bch.shape              

torch.Size([4, 12, 3])

In [15]:
fc_bch= torch.randint(0, 12, (2, 20, 3), device=device)
fc_bch.shape

torch.Size([2, 20, 3])

In [16]:
fc_bch[:, :, 1].shape, fc_bch[:, 1].shape

(torch.Size([2, 20]), torch.Size([2, 3]))

In [17]:
torch.index_select(vt_bch, 1, ico_faces[:, 1]).shape

torch.Size([4, 20, 3])

In [19]:
vt_bch.index_select(1, ico_faces[:, 1]).shape

torch.Size([4, 20, 3])

In [20]:
ico_faces[:, 1].shape

torch.Size([20])

In [21]:
def get_face_normals(vrt, faces):    
    v1 = vrt.index_select(1,faces[:, 1]) - vrt.index_select(1, faces[:, 0])
    v2 = vrt.index_select(1,faces[:, 2]) - vrt.index_select(1, faces[:, 0])
    face_normals = F.normalize(v1.cross(v2), p=2, dim=-1)  # [F, 3]
    return face_normals

face_normals_b =  get_face_normals(vertices_batch, ico_faces)
face_normals_b

tensor([[[    -0.5774,      0.5774,      0.5774],
         [    -0.0000,      0.9341,      0.3569],
         [     0.0000,      0.9341,     -0.3569],
         [    -0.5774,      0.5774,     -0.5774],
         [    -0.9341,      0.3569,      0.0000],
         [     0.5774,      0.5774,      0.5774],
         [    -0.3569,      0.0000,      0.9341],
         [    -0.9341,     -0.3569,     -0.0000],
         [    -0.3569,     -0.0000,     -0.9341],
         [     0.5774,      0.5774,     -0.5774],
         [     0.5774,     -0.5774,      0.5774],
         [     0.0000,     -0.9341,      0.3569],
         [    -0.0000,     -0.9341,     -0.3569],
         [     0.5774,     -0.5774,     -0.5774],
         [     0.9341,     -0.3569,      0.0000],
         [     0.3569,     -0.0000,      0.9341],
         [    -0.5774,     -0.5774,      0.5774],
         [    -0.5774,     -0.5774,     -0.5774],
         [     0.3569,      0.0000,     -0.9341],
         [     0.9341,      0.3569,      0.0000]]]

In [22]:
face_normals_b.shape

torch.Size([1, 20, 3])

In [23]:
torch.allclose(ico_face_normals, face_normals_b[0])

True

In [36]:
vert_tri_indices, vert_tri_weights = vertex_tri_maps(ico_faces)
vert_tri_indices = vert_tri_indices.to(device)
vert_tri_weights = vert_tri_weights.to(device)
vert_tri_indices.shape, vert_tri_weights.shape

(torch.Size([12, 5]), torch.Size([1, 12, 5, 1]))

In [37]:
def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):    
    weighted_face_normals = face_normals[vert_tri_indices] * vert_tri_weights
    vertex_normals = weighted_face_normals.sum(dim=-2)
    return  F.normalize(vertex_normals, p=2, dim=-1)

vert_normals_b = get_vertex_normals(face_normals_b, vert_tri_indices, vert_tri_weights)
vert_normals_b.shape

RuntimeError: The size of tensor a (20) must match the size of tensor b (5) at non-singleton dimension 2

In [30]:
face_normals_b.shape, vert_tri_indices.shape

(torch.Size([1, 20, 3]), torch.Size([12, 5]))

In [32]:
face_normals_b.index_select(1, vert_tri_indices.flatten()).shape

torch.Size([1, 60, 3])

In [27]:
face_normals_b.device, vert_tri_indices.device

(device(type='cuda', index=0), device(type='cuda', index=0))

In [37]:
def get_vertex_normals(face_normals, vert_tri_indices, vert_tri_weights):
    bs = face_normals.size(0)
    r, c = vert_tri_indices.shape
    fn_group = face_normals.index_select(1, 
        vert_tri_indices.flatten()).reshape(bs, r, c, 3)
    weighted_fn_group = fn_group * vert_tri_weights    
    vertex_normals = weighted_fn_group.sum(dim=-2)
    return F.normalize(vertex_normals, p=2, dim=-1)

vert_normals_b = get_vertex_normals(face_normals_b, vert_tri_indices, vert_tri_weights)
vert_normals_b.shape

torch.Size([1, 12, 3])

In [38]:
vert_normals_b

tensor([[[    -0.5257,      0.8507,     -0.0000],
         [     0.5257,      0.8507,      0.0000],
         [    -0.5257,     -0.8507,      0.0000],
         [     0.5257,     -0.8507,     -0.0000],
         [     0.0000,     -0.5257,      0.8507],
         [    -0.0000,      0.5257,      0.8507],
         [     0.0000,     -0.5257,     -0.8507],
         [     0.0000,      0.5257,     -0.8507],
         [     0.8507,      0.0000,     -0.5257],
         [     0.8507,      0.0000,      0.5257],
         [    -0.8507,     -0.0000,     -0.5257],
         [    -0.8507,      0.0000,      0.5257]]], device='cuda:0')

In [39]:
ico_vertex_normals

tensor([[    -0.5257,      0.8507,     -0.0000],
        [     0.5257,      0.8507,      0.0000],
        [    -0.5257,     -0.8507,      0.0000],
        [     0.5257,     -0.8507,     -0.0000],
        [    -0.0000,     -0.5257,      0.8507],
        [     0.0000,      0.5257,      0.8507],
        [     0.0000,     -0.5257,     -0.8507],
        [    -0.0000,      0.5257,     -0.8507],
        [     0.8507,     -0.0000,     -0.5257],
        [     0.8507,      0.0000,      0.5257],
        [    -0.8507,      0.0000,     -0.5257],
        [    -0.8507,     -0.0000,      0.5257]], device='cuda:0')