In [23]:
import os
from collections import OrderedDict
import torch
import torch.nn.functional as F
from src.util import (
    make_faces,    
)    

def vertex_tris(faces):
    res = [[] for _ in range(faces.max()+1)]
    for fid, face in enumerate(faces):        
        for vid in face:
            res[vid].append(fid)        
    return res

def vertex_tri_maps(faces):
    vts = vertex_tris(faces)
    r, c = len(vts), max([len(x) for  x in vts])
    vert_tri_indices = torch.zeros(r, c, dtype=torch.long)
    vert_tri_weights = torch.zeros(r, c)    
    for r, tris in enumerate(vts):        
        weight = 1. #/ len(tris)
        for c, tri_id in enumerate(tris):
            vert_tri_indices[r, c] = tri_id
            vert_tri_weights[r, c] = weight
    return vert_tri_indices, vert_tri_weights.unsqueeze(dim=-1)[None]

class VertexNormals(torch.nn.Module):
    
    def __init__(self, opt, load=True):
        super().__init__()
        self.size = opt.data_patch_size
        self.path = os.path.join(opt.data_dir, 
            'trimap_{}.pth'.format(opt.data_patch_size))
        if load and os.path.exists(self.path):
            trimap = torch.load(self.path)
        else:
            trimap = self.make_trimap(opt.data_patch_size)
            torch.save(trimap, self.path)
        self.assign_trimap(trimap)
    
    def assign_trimap(self,  trimap):
        self.register_buffer('faces',  trimap['faces'])
        self.register_buffer('vert_tri_indices', trimap['vert_tri_indices'])
        self.register_buffer('vert_tri_weights', trimap['vert_tri_weights'])

    def vertex_normals_mean(self, vrt):
        face_normals = self.get_face_normals(vrt)
        bs = face_normals.size(0)
        r, c = self.vert_tri_indices.shape
        fn_group = face_normals.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)
        weighted_fn_group = fn_group * self.vert_tri_weights    
        vertex_normals = weighted_fn_group.sum(dim=-2)
        return F.normalize(vertex_normals, p=2, dim=-1)
    
    def vertex_normals_weighted_area(self, vrt):
        face_normals = self.get_face_normals(vrt)
        face_areas = self.get_face_areas(vrt)
        bs = face_normals.size(0)
        r, c = self.vert_tri_indices.shape
        fn_group = face_normals.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)
        
        fa_group = face_areas.index_select(1, 
            self.vert_tri_indices.flatten()).reshape(bs, r, c, 1)
        weighted_fa_group = fa_group * self.vert_tri_weights        
        
        weighted_fn_group = fn_group * fa_group   
        vertex_normals = weighted_fn_group.sum(dim=-2)
        return F.normalize(vertex_normals, p=2, dim=-1)
    
    def get_face_normals(self, vrt):
        faces = self.faces
        v1 = vrt.index_select(1,faces[:, 1]) - vrt.index_select(1, faces[:, 0])
        v2 = vrt.index_select(1,faces[:, 2]) - vrt.index_select(1, faces[:, 0])
        face_normals = F.normalize(v1.cross(v2), p=2, dim=-1)  # [F, 3]
        return face_normals
 
    
    def get_face_areas(self, vrt):
        faces = self.faces

        v0 = vrt.index_select(1, faces[:, 0])
        v1 = vrt.index_select(1, faces[:, 1])
        v2 = vrt.index_select(1, faces[:, 2])

        a = torch.norm(v1 - v0, dim=-1)
        b = torch.norm(v2 - v0, dim=-1)
        c = torch.norm(v2 - v1, dim=-1)

        s = (a + b + c) / 2
        return torch.sqrt(s*(s-a)*(s-b)*(s-c)).unsqueeze(dim=-1)
        
    def __repr__(self):
        return f'VertexNormals: size: {self.size} path: {self.path}'
    
    def make_trimap(self, size):
        faces = torch.tensor(make_faces(size, size))
        vert_tri_indices, vert_tri_weights = vertex_tri_maps(faces)
        return OrderedDict(OrderedDict([
          ('vert_tri_indices', vert_tri_indices),
          ('vert_tri_weights', vert_tri_weights),
          ('faces', faces),
        ]))
    
from argparse import Namespace

opt = Namespace()
opt.data_patch_size = 4
opt.data_dir = './data'
    
vn = VertexNormals(opt, load=False)    
vn

VertexNormals: size: 4 path: ./data/trimap_4.pth

In [20]:
vrt = torch.rand(1, 4**2, 3)
face_normals = vn.get_face_normals(vrt)
face_normals.shape

torch.Size([1, 18, 3])

In [14]:
self = vn


faces = self.faces

v0 = vrt.index_select(1, faces[:, 0])
v1 = vrt.index_select(1, faces[:, 1])
v2 = vrt.index_select(1, faces[:, 2])

a = torch.norm(v1 - v0, dim=-1)
b = torch.norm(v2 - v0, dim=-1)
c = torch.norm(v2 - v1, dim=-1)

a.shape, b.shape, c.shape
# face_normals = F.normalize(v1.cross(v2), p=2, dim=-1)  # [F, 3]
# face_normals.shape

(torch.Size([1, 18]), torch.Size([1, 18]), torch.Size([1, 18]))

In [10]:
torch.norm(d0, dim=-1).shape

torch.Size([1, 18])

In [15]:
s = (a + b + c) / 2
s.shape

torch.Size([1, 18])

In [18]:
area = torch.sqrt(s*(s-a)*(s-b)*(s-c)).unsqueeze(dim=-1)
area.shape

torch.Size([1, 18, 1])

In [19]:
area

tensor([[[0.0457],
         [0.0563],
         [0.0753],
         [0.3535],
         [0.3572],
         [0.3478],
         [0.0457],
         [0.0694],
         [0.1011],
         [0.1260],
         [0.0270],
         [0.0473],
         [0.1203],
         [0.1110],
         [0.2059],
         [0.1934],
         [0.1958],
         [0.1200]]])

In [24]:
self = vn

face_normals = self.get_face_normals(vrt)
face_areas = self.get_face_areas(vrt)

bs = face_normals.size(0)
r, c = self.vert_tri_indices.shape
fn_group = face_normals.index_select(1, 
    self.vert_tri_indices.flatten()).reshape(bs, r, c, 3)
weighted_fn_group = fn_group * self.vert_tri_weights    
vertex_normals = weighted_fn_group.sum(dim=-2)
F.normalize(vertex_normals, p=2, dim=-1)

tensor([[[-0.8071,  0.3850, -0.4477],
         [-0.7340,  0.2182,  0.6431],
         [-0.3662,  0.5688,  0.7365],
         [-0.3709,  0.5662,  0.7361],
         [-0.9351,  0.3503,  0.0530],
         [-0.5659, -0.7684, -0.2989],
         [ 0.7011,  0.6292, -0.3356],
         [ 0.8337,  0.5522, -0.0052],
         [-0.4385,  0.8386,  0.3232],
         [-0.2773,  0.7977,  0.5355],
         [ 0.3311, -0.8590, -0.3904],
         [ 0.8275,  0.2958, -0.4772],
         [-0.5680,  0.7756,  0.2753],
         [ 0.3528,  0.9039, -0.2419],
         [-0.6222,  0.3325,  0.7087],
         [ 0.1378, -0.9344, -0.3284]]])

In [25]:
face_areas

tensor([[[0.1571],
         [0.0609],
         [0.1026],
         [0.2294],
         [0.2132],
         [0.1625],
         [0.2804],
         [0.1887],
         [0.0686],
         [0.0893],
         [0.1664],
         [0.2833],
         [0.0880],
         [0.3470],
         [0.0678],
         [0.0467],
         [0.1112],
         [0.0753]]])

In [30]:
fa_group = face_areas.index_select(1, 
    self.vert_tri_indices.flatten()).reshape(bs, r, c, 1)
weighted_fa_group = fa_group * self.vert_tri_weights    
weighted_fa_group

tensor([[[[0.1571],
          [0.0609],
          [0.0000],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.0609],
          [0.1026],
          [0.2294],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.2294],
          [0.2132],
          [0.1625],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.1625],
          [0.0000],
          [0.0000],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.1571],
          [0.2804],
          [0.1887],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.1571],
          [0.0609],
          [0.1026],
          [0.1887],
          [0.0686],
          [0.0893]],

         [[0.1026],
          [0.2294],
          [0.2132],
          [0.0893],
          [0.1664],
          [0.2833]],

         [[0.2132],
          [0.1625],
          [0.2833],
          [0.0000],
          [0.0000],
          [0.0000]],

         [[0.2804],
    

In [31]:
weighted_fa_group.shape

torch.Size([1, 16, 6, 1])

In [32]:
weighted_fn_group * weighted_fa_group

tensor([[[[ 0.0329, -0.0296, -0.1507],
          [-0.0312,  0.0203,  0.0482],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[-0.0312,  0.0203,  0.0482],
          [-0.0950,  0.0270, -0.0281],
          [-0.0578, -0.0214,  0.2209],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[-0.0578, -0.0214,  0.2209],
          [ 0.0546,  0.0206, -0.2050],
          [-0.0603,  0.0920,  0.1196],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[-0.0603,  0.0920,  0.1196],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[ 0.0329, -0.0296, -0.1507],
          [-0.086