In [2]:
import sys
import time

import igl
import numpy as np
import torch
from meshplot import plot

sys.path.append('../')
import radfoam
device='cuda'

In [4]:
# Sample a random point cloud in [-1, 1)^3 and evaluate a sphere SDF (radius .5) on it.
torch.manual_seed(0)  # make the random point cloud, and everything derived from it, reproducible
num_init_points = 30000
primal_points = (torch.rand(num_init_points, 3, device=device) - .5) * 2.

primal_values = torch.norm(primal_points, dim=1) - .5  # negative inside the sphere
primal_features = primal_points[:, 0]  # x-coordinate used as a toy per-vertex feature

# Reorder every per-point tensor with the triangulation's permutation.
# NOTE(review): presumably triangulation.tets() indexes the permuted order — confirm.
triangulation = radfoam.Triangulation(primal_points)
perm = triangulation.permutation().to(torch.long)
primal_points = primal_points[perm]
primal_values = primal_values[perm]
primal_features = primal_features[perm]

In [108]:
def triangle_case1(tets, values, points, features, alpha_f):
    '''Triangulate tets with exactly one negative-valued corner and three corners >= 0.

    Callers negate `values` to reuse this for the mirrored case (one positive corner,
    three non-positive). Each tet yields one triangle whose corners lie on the three
    edges joining the lone negative corner to the other three.

    Parameters:
        tets: (K,4) vertex indices of the selected tetrahedra
        values: (K,4) per-corner values of those tetrahedra
        points: vertex positions, indexed by `tets`
        features: per-vertex features, indexed by `tets`
        alpha_f: weight given to the negative-side corner when blending features

    Returns:
        new_points: (3K, ...) edge crossing points
        new_tri: (K,3) triangle indices into new_points
        new_features: (3K, ...) blended features
    '''
    neg_mask = values < 0
    nonneg_mask = values >= 0

    # Row-major masking gives one negative corner and three non-negative corners per tet;
    # duplicate the negative corner so both index lists line up edge by edge.
    lone_idx = tets[neg_mask].repeat_interleave(3)
    other_idx = tets[nonneg_mask]
    lone_val = values[neg_mask].repeat_interleave(3)
    other_val = values[nonneg_mask]

    # Linear zero crossing along each edge, written as the weight of the negative end.
    w = (other_val / (other_val - lone_val)).unsqueeze(1)
    new_points = w * points[lone_idx] + (1 - w) * points[other_idx]

    new_features = alpha_f * features[lone_idx] + (1 - alpha_f) * features[other_idx]

    # Consecutive triples of crossings belong to the same tet.
    n_tris = new_points.shape[0] // 3
    new_tri = torch.arange(new_points.shape[0], device=points.device).reshape(n_tris, 3)

    return new_points, new_tri, new_features

def triangle_case2(tets, values, points, features, alpha_f):
    '''Triangulate tets with two negative-valued corners and two corners >= 0.

    The level set crosses the four edges joining a negative corner to a non-negative
    one; the resulting quad is split into two triangles per tet.

    Parameters:
        tets: (K,4) vertex indices of the selected tetrahedra
        values: (K,4) per-corner values of those tetrahedra
        points: vertex positions, indexed by `tets`
        features: per-vertex features, indexed by `tets`
        alpha_f: weight given to the negative-side corner when blending features

    Returns:
        new_points: (4K, ...) crossings laid out as four K-blocks [e00 | e11 | e01 | e10],
            where eij is the crossing on the edge (i-th negative corner, j-th non-negative corner)
        new_tri: (2K,3) triangle indices into new_points, two consecutive rows per tet
        new_features: (4K, ...) blended features in the same layout as new_points
    '''
    in_idx = tets[values < 0]
    out_idx = tets[values >= 0]
    in_val = values[values < 0]
    out_val = values[values >= 0]

    # Each tet contributes two entries to every list in row-major order:
    # even slots hold its first corner of that kind, odd slots its second.
    in_a, in_b = in_idx[::2], in_idx[1::2]
    out_a, out_b = out_idx[::2], out_idx[1::2]
    in_va, in_vb = in_val[::2], in_val[1::2]
    out_va, out_vb = out_val[::2], out_val[1::2]

    def crossing(i_idx, i_val, o_idx, o_val):
        # Linear zero crossing on edge (i_idx, o_idx), as the weight of the negative end.
        w = (o_val / (o_val - i_val))[:, None]
        return w * points[i_idx] + (1 - w) * points[o_idx]

    def blend(i_idx, o_idx):
        return alpha_f * features[i_idx] + (1 - alpha_f) * features[o_idx]

    new_points = torch.cat((
        crossing(in_a, in_va, out_a, out_va),
        crossing(in_b, in_vb, out_b, out_vb),
        crossing(in_a, in_va, out_b, out_vb),
        crossing(in_b, in_vb, out_a, out_va),
    ))
    new_features = torch.cat((
        blend(in_a, out_a),
        blend(in_b, out_b),
        blend(in_a, out_b),
        blend(in_b, out_a),
    ))

    # Around the quad the corners go e00 -> e01 -> e11 -> e10; cut it along e01-e10.
    n_tets = in_a.shape[0]
    quad = torch.tensor([[0, 2 * n_tets, 3 * n_tets],
                         [n_tets, 3 * n_tets, 2 * n_tets]], device=points.device)
    offsets = torch.arange(n_tets, device=points.device).repeat_interleave(2)[:, None]
    new_tri = quad.repeat(n_tets, 1) + offsets

    return new_points, new_tri, new_features

def reverse_triangles(tri, reverse):
    '''Flip the winding of selected triangles in place by swapping their first two corners.

    Parameters:
        tri: (F,3) triangle index tensor; modified in place
        reverse: (F,) boolean mask selecting the triangles to flip
    '''
    # Boolean indexing on the right-hand side already yields a copy, so the
    # column swap cannot read values it has just overwritten.
    tri[reverse] = tri[reverse][:, [1, 0, 2]]

def marching_tetrahedra(tets, sdf_values, points, features, alpha_f=.5):
    """
        Marching tetrahedra on a given tet grid (in our case extracted from a Delaunay triangulation).

        Extracts the zero level set of `sdf_values` as a triangle mesh. A vertex counts as
        outside when its value is > 0 and as inside otherwise, so values exactly equal to 0
        are treated as inside.

        Parameters:
            tets: (N,4) integer tensor of tetrahedra vertex indices into `points`
            sdf_values: (M,) floating-point tensor of signed values at the vertices
                (negative inside, positive outside)
            points: (M,3) tensor of vertices
            features: (M,) or (M,D) tensor of features at the vertices
            alpha_f: float between 0 and 1, interpolation factor for features. .5 is default, 1. is inside vertices only

        Returns:
            vertices: (V,3) level-set vertices interpolated along tet edges
            faces: (F,3) long tensor of triangle indices into `vertices`
            vertex_features: (V,...) interpolated features
            All three are empty when no tetrahedron crosses the level set.
    """
    values = sdf_values[tets]

    # The case split below counts `values > 0` as outside, but the triangle helpers split
    # on `values < 0` / `values >= 0`. An exact 0 would fall on different sides of those
    # two tests, giving a wrong number of inside/outside corners for the tet and
    # misaligned index tensors. Nudge zeros to the smallest negative normal float: they
    # stay "inside" for both tests, and the interpolated crossing lands on that vertex.
    if values.is_floating_point():
        values = values.masked_fill(values == 0, -torch.finfo(values.dtype).tiny)

    pos = (values > 0).sum(1)  # number of outside corners per tet
    new_v, new_f, new_cf = [], [], []
    cur_ind = 0
    for i in [1, 2, 3]:
        sel = pos == i
        if sel.sum() > 0:
            if i == 1:
                # Negate so the single outside corner becomes the helper's lone negative
                # corner; 1 - alpha_f keeps alpha_f as the weight of the inside corners.
                reverse = torch.logical_or(values[:, 1] > 0, values[:, 3] > 0)[sel]
                new_points, new_tri, new_features = triangle_case1(tets[sel], -values[sel], points, features, 1 - alpha_f)
                reverse_triangles(new_tri, reverse)
            elif i == 2:
                f13 = torch.logical_and(values[:, 1] < 0, values[:, 3] < 0)
                f02 = torch.logical_and(values[:, 0] < 0, values[:, 2] < 0)
                reverse = torch.logical_not(torch.logical_or(f13, f02))[sel]
                new_points, new_tri, new_features = triangle_case2(tets[sel], values[sel], points, features, alpha_f)
                # triangle_case2 emits two triangles per tet; both share the tet's flag.
                reverse_triangles(new_tri, reverse.repeat_interleave(2))
            else:
                reverse = torch.logical_or(values[:, 0] < 0, values[:, 2] < 0)[sel]
                new_points, new_tri, new_features = triangle_case1(tets[sel], values[sel], points, features, alpha_f)
                reverse_triangles(new_tri, reverse)
            new_v.append(new_points)
            new_cf.append(new_features)
            # Shift this case's local indices past the vertices emitted by earlier cases.
            new_f.append(cur_ind + new_tri)
            cur_ind += len(new_points)

    if not new_v:
        # torch.cat rejects an empty list; return empty, correctly shaped outputs instead.
        return (points.new_zeros((0,) + tuple(points.shape[1:])),
                torch.zeros((0, 3), dtype=torch.long, device=points.device),
                features.new_zeros((0,) + tuple(features.shape[1:])))
    return torch.cat(new_v), torch.cat(new_f), torch.cat(new_cf)

# Extract the zero level set of the sphere SDF on the Delaunay tetrahedra.
v, f, feat = marching_tetrahedra(
    tets=triangulation.tets().long(),
    sdf_values=primal_values,
    points=primal_points,
    features=primal_features,
)


In [109]:


v, f, feat = marching_tetrahedra(triangulation.tets().long(), primal_values, primal_points, primal_features)

# Move the extracted mesh to numpy for igl / meshplot.
v_np = v.detach().cpu().numpy()
f_np = f.detach().cpu().numpy()

# Color each face by its normal, remapped from [-1, 1] to [0, 1].
# The third argument is the normal igl assigns to degenerate faces; its dtype is matched
# to the vertex array.
face_normals = igl.per_face_normals(v_np, f_np, np.array([1.0, 1.0, 1.0], dtype=v_np.dtype))
face_colors = (face_normals + 1) / 2.
plot(v_np, f_np, face_colors)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.000262…

<meshplot.Viewer.Viewer at 0x7cb35d773a00>

In [105]:
# Wall-clock time of one full extraction. CUDA kernels run asynchronously, so synchronize
# before each clock read; otherwise GPU work still queued when the clock stops is not counted.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
marching_tetrahedra(triangulation.tets().long(), primal_values, primal_points, primal_features)
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(time.perf_counter() - start)

0.014841556549072266


In [110]:
def export_obj(nv, nf, name: str, nvn=None):
    """Write a mesh to a Wavefront OBJ file, overwriting any existing file.

    Parameters:
        nv: iterable of vertices, each with three coordinates (written as ``v x y z``)
        nf: iterable of faces, each a sequence of 0-based vertex indices
            (written 1-based as ``f i j k ...``, as OBJ requires)
        name: output path; ``.obj`` is appended if it is missing
        nvn: optional iterable of vertex normals (written as ``vn x y z``).
            Faces do not reference them (no ``v//vn`` syntax is emitted).
    """
    if not name.endswith(".obj"):
        name += ".obj"

    # "w" creates the file when absent and truncates it otherwise, so no "x"-mode
    # fallback is needed. The context manager guarantees the file is flushed and closed.
    with open(name, "w") as file:
        for v in nv:
            file.write("v {} {} {}\n".format(*v))
        file.write("\n")

        if nvn is not None:
            for vn in nvn:
                file.write("vn {} {} {}\n".format(*vn))
        file.write("\n")

        for face in nf:
            file.write("f " + " ".join(str(fi + 1) for fi in face) + "\n")
        file.write("\n")

In [111]:
# Save the extracted mesh (vertices and faces only) next to the notebook.
export_obj(
    nv=v.cpu().detach().numpy(),
    nf=f.cpu().detach().numpy(),
    name='test.obj',
)