In [1]:
import os
from collections import OrderedDict
import torch
import torch.nn.functional as F
from src.util import (
    make_faces,    
)    

def vertex_tris(faces):
    res = [[] for _ in range(faces.max()+1)]
    for fid, face in enumerate(faces):        
        for vid in face:
            res[vid].append(fid)        
    return res

def vertex_tri_maps(faces):
    vts = vertex_tris(faces)
    r, c = len(vts), max([len(x) for  x in vts])
    vert_tri_indices = torch.zeros(r, c, dtype=torch.long)
    vert_tri_weights = torch.zeros(r, c)    
    for r, tris in enumerate(vts):        
        weight = 1. #/ len(tris)
        for c, tri_id in enumerate(tris):
            vert_tri_indices[r, c] = tri_id
            vert_tri_weights[r, c] = weight
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)[None]

In [4]:
side = 2
faces = make_faces(side, side)
print(faces.shape)
faces

(2, 3)


array([[2, 0, 3],
       [1, 3, 0]])

In [6]:
vrt_no = side ** 2
angle_sel = torch.randint(0, vrt_no, (vrt_no, 6, 3))
print(angle_sel.shape)
#angle_sel

torch.Size([4, 6, 3])


In [10]:
angle_vrt_idx = torch.zeros(vrt_no, 6, 3, dtype=torch.long) -1
angle_vrt_idx.shape

torch.Size([4, 6, 3])

In [13]:
for m in angle_vrt_idx[0]:
    print(m[0].item())

tensor(-1)
tensor(-1)
tensor(-1)
tensor(-1)
tensor(-1)
tensor(-1)


In [14]:
angle_vrt_idx[0, 1, 2]

tensor(-1)

In [21]:
angle_vrt_idx = torch.zeros(vrt_no, 6, 3, dtype=torch.long) -1

for face in faces:
    v0, v1, v2 = face
    for i, m in enumerate(angle_vrt_idx[v0]):
        if m[0].item() == -1:
            angle_vrt_idx[v0, i, 0] = v1
            angle_vrt_idx[v0, i, 1] = v0
            angle_vrt_idx[v0, i, 2] = v2
            break
    for i, m in enumerate(angle_vrt_idx[v1]):
        if m[0].item() == -1:
            angle_vrt_idx[v1, i, 0] = v0
            angle_vrt_idx[v1, i, 1] = v1
            angle_vrt_idx[v1, i, 2] = v2
            break
    for i, m in enumerate(angle_vrt_idx[v2]):
        if m[0].item() == -1:
            angle_vrt_idx[v2, i, 0] = v0
            angle_vrt_idx[v2, i, 1] = v2
            angle_vrt_idx[v2, i, 2] = v1
            break
            
angle_vrt_idx = torch.where(angle_vrt_idx > 0, angle_vrt_idx, 0)
print(angle_vrt_idx.shape)
angle_vrt_idx

torch.Size([4, 6, 3])


tensor([[[2, 0, 3],
         [1, 0, 3],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[3, 1, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 2, 3],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[2, 3, 0],
         [1, 3, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]])

In [29]:
#def get_angles(vrt):
vrt_no = 4
vrt = torch.rand(1, vrt_no, 3)

bs = vrt.size(0)
angle_pts = vrt.index_select(1, angle_vrt_idx.view(-1)).reshape(bs, vrt_no, 6, 3, 3)
print(angle_pts.shape)

torch.Size([1, 4, 6, 3, 3])


In [30]:
a = angle_pts[:, :, :, 0]
b = angle_pts[:, :, :, 1]
c = angle_pts[:, :, :, 2]

ba = a - b
bc = c - b
ba, bc

(tensor([[[[-0.0074,  0.1447,  0.4200],
           [-0.1492,  0.0498, -0.4213],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]],
 
          [[-0.4587,  0.1465,  0.0729],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]],
 
          [[ 0.0074, -0.1447, -0.4200],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]],
 
          [[ 0.6005, -0.0516,  0.7683],
           [ 0.4587, -0.1465, -0.0729],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]]]]),
 tensor([[[[-0.6079,  0.196

In [39]:
eps = 1e-08
baNrm = ba / (torch.norm(ba, dim=-1).unsqueeze(-1) + eps)
bcNrm = bc / (torch.norm(bc, dim=-1).unsqueeze(-1) + eps)
print(baNrm.shape, bcNrm.shape)
baNrm

torch.Size([1, 4, 6, 3]) torch.Size([1, 4, 6, 3])


tensor([[[[-0.0166,  0.3258,  0.9453],
          [-0.3317,  0.1108, -0.9368],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[-0.9419,  0.3008,  0.1497],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0166, -0.3258, -0.9453],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.6149, -0.0528,  0.7868],
          [ 0.9419, -0.3008, -0.1497],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]])

In [42]:
dot_bac = (baNrm * bcNrm).sum(dim=-1).unsqueeze(-1)
print(dot_bac.shape)
dot_bac

torch.Size([1, 4, 6, 1])


tensor([[[[-0.3508],
          [ 0.7556],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000]],

         [[-0.2055],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000]],

         [[ 0.7163],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000]],

         [[ 0.9047],
          [ 0.7963],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000],
          [ 0.0000]]]])

In [43]:
torch.arccos(dot_bac)

tensor([[[[1.9293],
          [0.7142],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[1.7778],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[0.7722],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[0.4401],
          [0.6496],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]]]])

In [47]:
def get_angles(vrt):
    bs = vrt.size(0)
    angle_pts = vrt.index_select(1, angle_vrt_idx.view(-1)).reshape(bs, vrt_no, 6, 3, 3)
    a = angle_pts[:, :, :, 0]
    b = angle_pts[:, :, :, 1]
    c = angle_pts[:, :, :, 2]

    ba = a - b
    bc = c - b
    
    ba_nrm = torch.norm(ba, dim=-1).unsqueeze(-1)
    bc_nrm = torch.norm(bc, dim=-1).unsqueeze(-1)
    ba_nrm = torch.where(ba_nrm > 0, ba_nrm, torch.tensor(1.))
    bc_nrm = torch.where(bc_nrm > 0, bc_nrm, torch.tensor(1.))
    
    ba_normed = ba / ba_nrm
    bc_normed = bc / bc_nrm
    dot_bac = (ba_normed * bc_normed).sum(dim=-1).unsqueeze(-1)
    angles = torch.arccos(dot_bac)
    return angles

angles = get_angles(vrt)
print(angles.shape)
angles

torch.Size([1, 4, 6, 1])


tensor([[[[1.9293],
          [0.7142],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[1.7778],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[0.7722],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]],

         [[0.4401],
          [0.6496],
          [1.5708],
          [1.5708],
          [1.5708],
          [1.5708]]]])

In [50]:
angle_vrt_idx = torch.zeros(vrt_no, 6, 3, dtype=torch.long) -1
#angle_vrt_idx = 
for face in faces:
    v0, v1, v2 = face
    for i, m in enumerate(angle_vrt_idx[v0]):
        if m[0].item() == -1:
            angle_vrt_idx[v0, i, 0] = v1
            angle_vrt_idx[v0, i, 1] = v0
            angle_vrt_idx[v0, i, 2] = v2
            break
    for i, m in enumerate(angle_vrt_idx[v1]):
        if m[0].item() == -1:
            angle_vrt_idx[v1, i, 0] = v0
            angle_vrt_idx[v1, i, 1] = v1
            angle_vrt_idx[v1, i, 2] = v2
            break
    for i, m in enumerate(angle_vrt_idx[v2]):
        if m[0].item() == -1:
            angle_vrt_idx[v2, i, 0] = v0
            angle_vrt_idx[v2, i, 1] = v2
            angle_vrt_idx[v2, i, 2] = v1
            break
            
#angle_vrt_idx = torch.where(angle_vrt_idx > 0, angle_vrt_idx, 0)
print(angle_vrt_idx.shape)
#angle_vrt_idx

torch.Size([4, 6, 3])


In [53]:
angle_vrt_idx.sum(dim=-1).shape

torch.Size([4, 6])

In [58]:
angle_vrt_wt = torch.where(angle_vrt_idx.sum(dim=-1) != -3, 1., 0.)[None].unsqueeze(-1)
print(angle_vrt_wt.shape)
angle_vrt_wt

torch.Size([1, 4, 6, 1])


tensor([[[[1.],
          [1.],
          [0.],
          [0.],
          [0.],
          [0.]],

         [[1.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]],

         [[1.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]],

         [[1.],
          [1.],
          [0.],
          [0.],
          [0.],
          [0.]]]])

In [61]:
def vertex_angle_maps(faces):
    vrt_no =  faces.max() + 1 
    angle_vrt_idx = torch.zeros(vrt_no, 6, 3, dtype=torch.long) -1
    #angle_vrt_idx = 
    for face in faces:
        v0, v1, v2 = face
        for i, m in enumerate(angle_vrt_idx[v0]):
            if m[0].item() == -1:
                angle_vrt_idx[v0, i, 0] = v1
                angle_vrt_idx[v0, i, 1] = v0
                angle_vrt_idx[v0, i, 2] = v2
                break
        for i, m in enumerate(angle_vrt_idx[v1]):
            if m[0].item() == -1:
                angle_vrt_idx[v1, i, 0] = v0
                angle_vrt_idx[v1, i, 1] = v1
                angle_vrt_idx[v1, i, 2] = v2
                break
        for i, m in enumerate(angle_vrt_idx[v2]):
            if m[0].item() == -1:
                angle_vrt_idx[v2, i, 0] = v0
                angle_vrt_idx[v2, i, 1] = v2
                angle_vrt_idx[v2, i, 2] = v1
                break
    angle_vrt_wt = torch.where(angle_vrt_idx.sum(dim=-1) != -3, 1., 0.)
    angle_vrt_wt = angle_vrt_wt[None].unsqueeze(-1)
    angle_vrt_idx = torch.where(angle_vrt_idx > 0, angle_vrt_idx, 0)
    return angle_vrt_idx, angle_vrt_wt

angle_vrt_idx, angle_vrt_wt = vertex_angle_maps(faces)
angle_vrt_idx.shape, angle_vrt_wt.shape

(torch.Size([4, 6, 3]), torch.Size([1, 4, 6, 1]))

In [63]:
def get_vertex_angles(vrt, angle_vrt_idx, angle_vrt_wt):
    bs = vrt.size(0)
    angle_pts = vrt.index_select(1, angle_vrt_idx.view(-1)).reshape(bs, vrt_no, 6, 3, 3)
    a = angle_pts[:, :, :, 0]
    b = angle_pts[:, :, :, 1]
    c = angle_pts[:, :, :, 2]

    ba = a - b
    bc = c - b
    
    ba_nrm = torch.norm(ba, dim=-1).unsqueeze(-1)
    bc_nrm = torch.norm(bc, dim=-1).unsqueeze(-1)
    ba_nrm = torch.where(ba_nrm > 0, ba_nrm, torch.tensor(1.))
    bc_nrm = torch.where(bc_nrm > 0, bc_nrm, torch.tensor(1.))
    
    ba_normed = ba / ba_nrm
    bc_normed = bc / bc_nrm
    dot_bac = (ba_normed * bc_normed).sum(dim=-1).unsqueeze(-1)
    angles = torch.arccos(dot_bac) * angle_vrt_wt
    return angles

side = 2
vrt = torch.rand(1, side ** 2, 3)
angles = get_vertex_angles(vrt, angle_vrt_idx, angle_vrt_wt)
angles.shape

torch.Size([1, 4, 6, 1])

In [64]:
side = 2
vrt = torch.rand(3, side ** 2, 3)
angles = get_angles(vrt, angle_vrt_idx, angle_vrt_wt)
angles.shape

torch.Size([3, 4, 6, 1])

In [83]:
import os
from collections import OrderedDict
import torch
import torch.nn.functional as F
from src.util import (
    make_faces,    
)    

def vertex_tris(faces):
    res = [[] for _ in range(faces.max()+1)]
    for fid, face in enumerate(faces):        
        for vid in face:
            res[vid].append(fid)        
    return res

def vertex_tri_maps(faces):
    vts = vertex_tris(faces)
    r, c = len(vts), max([len(x) for  x in vts])
    vert_tri_indices = torch.zeros(r, c, dtype=torch.long)
    vert_tri_weights = torch.zeros(r, c)    
    for r, tris in enumerate(vts):        
        weight = 1. #/ len(tris)
        for c, tri_id in enumerate(tris):
            vert_tri_indices[r, c] = tri_id
            vert_tri_weights[r, c] = weight
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)[None]

def vertex_angle_maps(faces):
    vrt_no =  faces.max() + 1 
    angle_vrt_idx = torch.zeros(vrt_no, 6, 3, dtype=torch.long) -1
    #angle_vrt_idx = 
    for face in faces:
        v0, v1, v2 = face
        for i, m in enumerate(angle_vrt_idx[v0]):
            if m[0].item() == -1:
                angle_vrt_idx[v0, i, 0] = v1
                angle_vrt_idx[v0, i, 1] = v0
                angle_vrt_idx[v0, i, 2] = v2
                break
        for i, m in enumerate(angle_vrt_idx[v1]):
            if m[0].item() == -1:
                angle_vrt_idx[v1, i, 0] = v0
                angle_vrt_idx[v1, i, 1] = v1
                angle_vrt_idx[v1, i, 2] = v2
                break
        for i, m in enumerate(angle_vrt_idx[v2]):
            if m[0].item() == -1:
                angle_vrt_idx[v2, i, 0] = v0
                angle_vrt_idx[v2, i, 1] = v2
                angle_vrt_idx[v2, i, 2] = v1
                break
    angle_vrt_wt = torch.where(angle_vrt_idx.sum(dim=-1) != -3, 1., 0.)
    angle_vrt_wt = angle_vrt_wt[None].unsqueeze(-1)
    angle_vrt_idx = torch.where(angle_vrt_idx > 0, angle_vrt_idx, 0)
    return angle_vrt_idx, angle_vrt_wt


class VertexNormals(torch.nn.Module):
    
    def __init__(self, opt, load=True):
        super().__init__()
        self.size = opt.data_patch_size
        self.path = os.path.join(opt.data_dir, 
            'trimap_{}.pth'.format(opt.data_patch_size))
        if load and os.path.exists(self.path):
            trimap = torch.load(self.path)
        else:
            trimap = self.make_trimap(opt.data_patch_size)
            torch.save(trimap, self.path)
        self.assign_trimap(trimap)
    
    def assign_trimap(self,  trimap):
        self.register_buffer('faces',  trimap['faces'])
        self.register_buffer('vert_tri_indices', trimap['vert_tri_indices'])
        self.register_buffer('vert_tri_weights', trimap['vert_tri_weights'])        
        self.register_buffer('angle_vrt_idx', trimap['angle_vrt_idx'])
        self.register_buffer('angle_vrt_wt', trimap['angle_vrt_wt'])

    def vertex_normals_mean(self, vrt):
        face_normals = self.get_face_normals(vrt)
        bs = face_normals.size(0)
        r, c = self.vert_tri_indices.shape
        fn_group = face_normals.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)
        weighted_fn_group = fn_group * self.vert_tri_weights    
        vertex_normals = weighted_fn_group.sum(dim=-2)
        return F.normalize(vertex_normals, p=2, dim=-1)
    
    def vertex_normals_weighted_area(self, vrt):
        face_normals = self.get_face_normals(vrt)
        face_areas = self.get_face_areas(vrt)
        bs = face_normals.size(0)
        r, c = self.vert_tri_indices.shape
        fn_group = face_normals.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)
        
        fa_group = face_areas.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 1)
        weighted_fa_group = fa_group * self.vert_tri_weights        
        
        weighted_fn_group = fn_group * fa_group   
        vertex_normals = weighted_fn_group.sum(dim=-2)
        return F.normalize(vertex_normals, p=2, dim=-1)
    
    def vertex_normals_weighted_angles(self, vrt):
        face_normals = self.get_face_normals(vrt)
        vertex_angles = self.get_vertex_angles(vrt)
        bs = face_normals.size(0)
        r, c = self.vert_tri_indices.shape
        fn_group = face_normals.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)

        weighted_fn_group = fn_group * vertex_angles   
        vertex_normals = weighted_fn_group.sum(dim=-2)
        return F.normalize(vertex_normals, p=2, dim=-1)
    
    def get_face_normals(self, vrt):
        faces = self.faces
        v1 = vrt.index_select(1,faces[:, 1]) - vrt.index_select(1, faces[:, 0])
        v2 = vrt.index_select(1,faces[:, 2]) - vrt.index_select(1, faces[:, 0])
        face_normals = F.normalize(v1.cross(v2), p=2, dim=-1)  # [F, 3]
        return face_normals
 
    
    def get_face_areas(self, vrt):
        faces = self.faces

        v0 = vrt.index_select(1, faces[:, 0])
        v1 = vrt.index_select(1, faces[:, 1])
        v2 = vrt.index_select(1, faces[:, 2])

        a = torch.norm(v1 - v0, dim=-1)
        b = torch.norm(v2 - v0, dim=-1)
        c = torch.norm(v2 - v1, dim=-1)

        s = (a + b + c) / 2
        return torch.sqrt(s*(s-a)*(s-b)*(s-c)).unsqueeze(dim=-1)
    
    def get_vertex_angles(self, vrt):
        angle_vrt_idx, angle_vrt_wt
        bs = vrt.size(0)
        angle_pts = vrt.index_select(1, self.angle_vrt_idx.view(-1))
        angle_pts = angle_pts.reshape(bs, -1, 6, 3, 3)
        a = angle_pts[:, :, :, 0]
        b = angle_pts[:, :, :, 1]
        c = angle_pts[:, :, :, 2]

        ba = a - b
        bc = c - b

        ba_nrm = torch.norm(ba, dim=-1).unsqueeze(-1)
        bc_nrm = torch.norm(bc, dim=-1).unsqueeze(-1)
        ba_nrm = torch.where(ba_nrm > 0, ba_nrm, torch.tensor(1.))
        bc_nrm = torch.where(bc_nrm > 0, bc_nrm, torch.tensor(1.))

        ba_normed = ba / ba_nrm
        bc_normed = bc / bc_nrm
        dot_bac = (ba_normed * bc_normed).sum(dim=-1).unsqueeze(-1)
        angles = torch.arccos(dot_bac) * self.angle_vrt_wt
        return angles
        
    def __repr__(self):
        return f'VertexNormals: size: {self.size} path: {self.path}'
    
    def make_trimap(self, size):
        faces = torch.tensor(make_faces(size, size))
        vert_tri_indices, vert_tri_weights = vertex_tri_maps(faces)
        angle_vrt_idx, angle_vrt_wt = vertex_angle_maps(faces)
        return OrderedDict(OrderedDict([
          ('vert_tri_indices', vert_tri_indices),
          ('vert_tri_weights', vert_tri_weights),
          ('faces', faces),
          ('angle_vrt_idx', angle_vrt_idx),
          ('angle_vrt_wt', angle_vrt_wt),
        ]))
    

from argparse import Namespace

opt = Namespace()
opt.data_patch_size = 4
opt.data_dir = './data'
    
vn = VertexNormals(opt, load=False)    
vn

VertexNormals: size: 4 path: ./data/trimap_4.pth

In [87]:
vn.vertex_normals_weighted_angles(torch.rand(3, 4**2, 3)).shape

torch.Size([3, 16, 3])

In [69]:
vn.vertex_normals_mean(torch.rand(1, 4**2, 3))

tensor([[[-0.7848,  0.0928,  0.6127],
         [-0.4388,  0.7942, -0.4203],
         [ 0.1824, -0.9828,  0.0285],
         [-0.2433,  0.1500, -0.9583],
         [ 0.9133, -0.2668,  0.3077],
         [-0.7852,  0.6184, -0.0310],
         [-0.2432, -0.8824, -0.4028],
         [-0.2963,  0.0643, -0.9529],
         [-0.4481, -0.3672,  0.8151],
         [-0.5630, -0.6324,  0.5320],
         [-0.7874, -0.2311, -0.5716],
         [-0.7202,  0.4793, -0.5016],
         [-0.9542, -0.0856,  0.2868],
         [-0.6985, -0.3931,  0.5980],
         [ 0.4125, -0.2790, -0.8672],
         [ 0.0870,  0.9650,  0.2474]]])

In [70]:
vn.vertex_normals_weighted_area(torch.rand(1, 4**2, 3))

tensor([[[ 0.7990, -0.5796,  0.1599],
         [ 0.7537, -0.6504,  0.0945],
         [ 0.5777, -0.7052,  0.4112],
         [ 0.7763, -0.6028,  0.1842],
         [ 0.5251, -0.8342,  0.1684],
         [ 0.8467, -0.3847,  0.3676],
         [ 0.0080,  0.2465,  0.9691],
         [ 0.6881, -0.5961,  0.4137],
         [ 0.8520, -0.5104,  0.1165],
         [-0.9481,  0.3160,  0.0364],
         [ 0.8459,  0.1381, -0.5152],
         [ 0.7885, -0.5873,  0.1823],
         [ 0.8295, -0.5222,  0.1981],
         [ 0.8374, -0.4616,  0.2929],
         [ 0.6515, -0.7496,  0.1165],
         [ 0.7376, -0.6750,  0.0184]]])

In [71]:
self = vn
vrt = torch.rand(1, 4**2, 3)

face_normals = self.get_face_normals(vrt)
vertex_angles = self.get_vertex_angles(vrt)
bs = face_normals.size(0)
r, c = self.vert_tri_indices.shape
fn_group = face_normals.index_select(1, 
    self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)

weighted_fn_group = fn_group * vrt_angles   
vertex_normals = weighted_fn_group.sum(dim=-2)
F.normalize(vertex_normals, p=2, dim=-1)

RuntimeError: shape '[1, 4, 6, 3, 3]' is invalid for input of size 864

In [80]:
self = vn
vrt = torch.rand(1, 4**2, 3)

angle_vrt_idx, angle_vrt_wt
bs = vrt.size(0)
angle_pts = vrt.index_select(1, self.angle_vrt_idx.view(-1))
angle_pts = angle_pts.reshape(bs, -1, 6, 3, 3)
a = angle_pts[:, :, :, 0]
b = angle_pts[:, :, :, 1]
c = angle_pts[:, :, :, 2]

ba = a - b
bc = c - b

ba_nrm = torch.norm(ba, dim=-1).unsqueeze(-1)
bc_nrm = torch.norm(bc, dim=-1).unsqueeze(-1)
ba_nrm = torch.where(ba_nrm > 0, ba_nrm, torch.tensor(1.))
bc_nrm = torch.where(bc_nrm > 0, bc_nrm, torch.tensor(1.))

ba_normed = ba / ba_nrm
bc_normed = bc / bc_nrm
dot_bac = (ba_normed * bc_normed).sum(dim=-1).unsqueeze(-1)
angles = torch.arccos(dot_bac) * self.angle_vrt_wt
angles.shape

torch.Size([1, 16, 6, 1])

In [75]:
angle_pts.shape

torch.Size([1, 288, 3])

In [77]:
angle_pts.reshape(bs, -1, 6, 3, 3).shape

torch.Size([1, 16, 6, 3, 3])