## Initialization

In [1]:
import os
import pathlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import polyscope as ps
import mcubes
import trimesh

from tqdm.notebook import tqdm
from types import SimpleNamespace

try:
    os.chdir(pathlib.Path(initial).parent)
except NameError:
    initial = os.getcwd()
    os.chdir(pathlib.Path(initial).parent)
    
from iarap.model.neural_rtf import NeuralRTF, NeuralRTFConfig
from iarap.model.neural_sdf import NeuralSDF, NeuralSDFConfig
from iarap.utils.meshing import *

np.random.seed(1234567)
torch.manual_seed(1234567)
ps.init()

In [2]:
def fixed_point_invert(g, y, iters=15, verbose=False):
    with torch.no_grad():
        x = y
        dim = x.size(-1)
        for i in range(iters):
            x = y - g(x)
            if verbose:
                err = (y - (x + g(x))).view(-1, dim).norm(dim=-1).mean()
                err = err.detach().cpu().item()
                print("iter:%d err:%s" % (i, err))
    return x

class LipBoundedPosEnc(nn.Module):

    def __init__(self, inp_features, n_freq, cat_inp=True):
        super().__init__()
        self.inp_feat = inp_features
        self.n_freq = n_freq
        self.cat_inp = cat_inp
        self.out_dim = 2 * self.n_freq * self.inp_feat
        if self.cat_inp:
            self.out_dim += self.inp_feat

    def forward(self, x):
        """
        :param x: (bs, npoints, inp_features)
        :return: (bs, npoints, 2 * out_features + inp_features)
        """
        assert len(x.size()) == 3
        bs, npts = x.size(0), x.size(1)
        const = (2 ** torch.arange(self.n_freq) * np.pi).view(1, 1, 1, -1)
        const = const.to(x)

        # Out shape : (bs, npoints, out_feat)
        cos_feat = torch.cos(const * x.unsqueeze(-1)).view(
            bs, npts, self.inp_feat, -1)
        sin_feat = torch.sin(const * x.unsqueeze(-1)).view(
            bs, npts, self.inp_feat, -1)
        out = torch.cat(
            [sin_feat, cos_feat], dim=-1).view(
            bs, npts, 2 * self.inp_feat * self.n_freq)
        const_norm = torch.cat(
            [const, const], dim=-1).view(
            1, 1, 1, self.n_freq * 2).expand(
            -1, -1, self.inp_feat, -1).reshape(
            1, 1, 2 * self.inp_feat * self.n_freq)

        if self.cat_inp:
            out = torch.cat([out, x], dim=-1)
            const_norm = torch.cat(
                [const_norm, torch.ones(1, 1, self.inp_feat).to(x)], dim=-1)

            return out / const_norm / np.sqrt(self.n_freq * 2 + 1)
        else:

            return out / const_norm / np.sqrt(self.n_freq * 2)


class InvertibleResBlockLinear(nn.Module):

    def __init__(self, inp_dim, hid_dim, nblocks=1,
                 nonlin='leaky_relu',
                 pos_enc_freq=None):
        super().__init__()
        self.dim = inp_dim
        self.nblocks = nblocks

        self.pos_enc_freq = pos_enc_freq
        if self.pos_enc_freq is not None:
            inp_dim_af_pe = self.dim * (self.pos_enc_freq * 2 + 1)
            self.pos_enc = LipBoundedPosEnc(self.dim, self.pos_enc_freq)
        else:
            self.pos_enc = lambda x: x
            inp_dim_af_pe = inp_dim

        self.blocks = nn.ModuleList()
        self.blocks.append(nn.utils.spectral_norm(
            nn.Linear(inp_dim_af_pe, hid_dim)))
        for _ in range(self.nblocks):
            self.blocks.append(
                nn.utils.spectral_norm(
                    nn.Linear(hid_dim, hid_dim),
                )
            )
        self.blocks.append(
            nn.utils.spectral_norm(
                nn.Linear(hid_dim, self.dim),
            )
        )

        self.nonlin = nonlin.lower()
        if self.nonlin == 'leaky_relu':
            self.act = nn.LeakyReLU()
        elif self.nonlin == 'relu':
            self.act = nn.ReLU()
        elif self.nonlin == 'elu':
            self.act = nn.ELU()
        elif self.nonlin == 'softplus':
            self.act = nn.Softplus()
        else:
            raise NotImplementedError

    def forward_g(self, x):
        orig_dim = len(x.size())
        if orig_dim == 2:
            x = x.unsqueeze(0)

        y = self.pos_enc(x)
        for block in self.blocks[:-1]:
            y = self.act(block(y))
        y = self.blocks[-1](y)

        if orig_dim == 2:
            y = y.squeeze(0)

        return y

    def forward(self, x):
        return x + self.forward_g(x)

    def invert(self, y, verbose=False, iters=15):
        return fixed_point_invert(
            lambda x: self.forward_g(x), y, iters=iters, verbose=verbose
        )


class InvertibleMLP(nn.Module):

    def __init__(self, _, cfg):
        super().__init__()
        self.cfg = cfg
        self.dim = cfg.dim
        self.out_dim = cfg.out_dim
        self.hidden_size = cfg.hidden_size
        self.n_blocks = cfg.n_blocks
        self.n_g_blocks = getattr(cfg, "n_g_blocks", 1)

        # Network modules
        self.blocks = nn.ModuleList()
        for _ in range(self.n_blocks):
            self.blocks.append(
                InvertibleResBlockLinear(
                    self.dim, self.hidden_size,
                    nblocks=self.n_g_blocks, nonlin=cfg.nonlin,
                    pos_enc_freq=getattr(cfg, "pos_enc_freq", None),
                )
            )

    def forward(self, x):
        """
        :param x: (bs, npoints, self.dim) Input coordinate (xyz)
        :return: (bs, npoints, self.dim) Gradient (self.dim dimension)
        """
        out = x
        for block in self.blocks:
            out = block(out)
        return out
    
    def deform(self, x):
        return self(x)

    def inverse(self, y, verbose=False, iters=15):
        x = y
        for block in self.blocks[::-1]:
            x = block.invert(x, verbose=verbose, iters=iters)
        return x

In [18]:
SDF_CKPT = 'assets/weights/sdf/dragon.pt'
RTF_CKPT = 'C:\\Users\\pc\\Documents\\GLADIA\\nfgp-private-fork\\logs\\armadillo-arm_front-nfgp\\checkpoints\\epoch_499_iters_50000.pt'
DEFORM_CLASS = InvertibleMLP
NUM_PATCH_PTS = 30
PATCH_RADIUS = 0.2
SURFACE_SAMPLES = 12611
SPACE_SAMPLES = 0
CHUNK = 300000
MC_RESOLUTION = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [19]:
sdf_model: NeuralSDF = NeuralSDFConfig().setup().to(DEVICE)
if len(SDF_CKPT) > 0:
    sdf_model.load_state_dict(torch.load(SDF_CKPT))

if DEFORM_CLASS is NeuralRTF:
    rtf_model: NeuralRTF = NeuralRTFConfig().setup().to(DEVICE)
    if len(RTF_CKPT) > 0:
        rtf_model.load_state_dict(torch.load(RTF_CKPT))
elif DEFORM_CLASS is InvertibleMLP:
    rtf_model: InvertibleMLP = InvertibleMLP(None, SimpleNamespace(**{
        'dim': 3, 'hidden_size': 256, 'n_blocks': 6,
        'nonlin': 'elu', 'out_dim': 3, 'pos_enc_freq': 5
    })).to(DEVICE)
    if len(RTF_CKPT) > 0:
        state_dict = {k: v for k, v in torch.load(RTF_CKPT)['next_dec'].items() if k.split('.')[0] == 'deform'}
        state_dict = {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}
        rtf_model.load_state_dict(state_dict)

In [72]:
steps = torch.linspace(-1.0, 1.0, MC_RESOLUTION, device=DEVICE)
xx, yy, zz = torch.meshgrid(steps, steps, steps, indexing="ij")
volume = torch.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T.float()
f_eval = []
with torch.no_grad():
    for sample in tqdm(torch.split(volume, CHUNK, dim=0)):
        f_eval.append(sdf_model(sample.contiguous())['dist'].cpu().numpy())
f_volume = np.concatenate(f_eval, axis=0).reshape(*([MC_RESOLUTION] * 3))
shape_verts, shape_faces = mcubes.marching_cubes(f_volume, 0.0)
shape_verts /= MC_RESOLUTION // 2
shape_verts -= 1.0

  0%|          | 0/448 [00:00<?, ?it/s]

## Minor examples

### Surface projection

In [6]:
initial = torch.tensor([[-0.7, 0.0, 0.0], 
                        [-0.6, 0.1, 0.0],
                        [-0.5, 0.0, 0.3]], device=DEVICE)

trajectory = [initial]
offsets = []

point = initial
for it in range(5):
    point_new = sdf_model.project_nearest(point)
    trajectory.append(point_new)
    offsets.append(point_new - point)
    point = point_new

offsets.append(torch.zeros_like(initial))
trajectory = torch.cat(trajectory, dim=0).view(-1, 3).cpu().detach().numpy()
offsets = torch.cat(offsets, dim=0).view(-1, 3).cpu().detach().numpy()

In [7]:
ps.register_surface_mesh("Input Shape", shape_verts, shape_faces, enabled=True)
traj_pc = ps.register_point_cloud("Trajectory", trajectory, enabled=True)
traj_pc.add_vector_quantity("Offsets", offsets, vectortype='ambient', enabled=True)
ps.show()
ps.remove_all_structures()

## Local Patch Meshing

### Patch Sampling and Triangulation

In [73]:
patches = {'vert': {}, 'triv': {}}
sampling_methods = {'runif': sphere_random_uniform,
                    'rnorm': sphere_gaussian_radius,
                    'linear': sphere_sunflower,
                    'normal': gaussian_max_norm}

In [74]:
for name, sampler in sampling_methods.items():
    patch = get_patch_mesh(sampler, delaunay, NUM_PATCH_PTS, PATCH_RADIUS, DEVICE)
    patches['vert'][name] = patch[0].cpu().numpy()
    patches['triv'][name] = patch[1].cpu().numpy()

In [10]:
i = 0
for name in sampling_methods.keys():
    offset = np.zeros_like(patches['vert'][name])
    offset[:, 0] = PATCH_RADIUS * (2.2 * i)
    vert = patches['vert'][name] + offset
    ps.register_surface_mesh(f"{name}_patch", vert, patches['triv'][name], enabled=True, edge_width=1)
    i += 1

ps.show()
ps.remove_all_structures()

### Patch Projection

In [75]:
SELECTED_PATCH = 'linear'

plane_coords, triangles = patches['vert'][SELECTED_PATCH], patches['triv'][SELECTED_PATCH]

In [76]:
# surf_sample = torch.tensor([[-0.54,  0.365,  0.08]], device=DEVICE)
# for it in range(5):
#     surf_sample = sdf_model.project_nearest(surf_sample).detach()

surf_sample = sdf_model.sample_zero_level_set(SURFACE_SAMPLES, 0.05, 10000, (-1, 1), 15).detach()
space_sample = torch.rand(SPACE_SAMPLES, 3, device=DEVICE) * 2 - 1
samples = torch.cat([surf_sample, space_sample], dim=0).detach()

sdf_outs = sdf_model(samples, with_grad=True)
sample_dist, patch_normals = sdf_outs['dist'], F.normalize(sdf_outs['grad'], dim=-1)
tangent_planes = sdf_model.tangent_plane(samples).cpu().numpy()

tangent_coords = (np.expand_dims(tangent_planes, 1) @ plane_coords.reshape(1, -1, 3, 1)).squeeze() 
tangent_pts = tangent_coords + samples.unsqueeze(1).detach().cpu().numpy()
triangles_all = np.expand_dims(triangles, 0) + (tangent_pts.shape[1] * np.arange(0, tangent_pts.shape[0]).reshape(-1, 1, 1))

tangent_pts = tangent_pts.reshape(-1, 3)
triangles_all = triangles_all.reshape(-1, 3)

In [13]:
ps.register_surface_mesh("Input Shape", shape_verts, shape_faces, enabled=True, transparency=0.5)
ps.register_surface_mesh("Projected Patches", tangent_pts, triangles_all, enabled=True, edge_width=1.0)
normals_pc = ps.register_point_cloud("Normal Origins", samples.cpu().detach().numpy(), radius=0.0, enabled=True)
normals_pc.add_vector_quantity("Patch Normals", patch_normals.cpu().detach().numpy(), enabled=True)
ps.show()
ps.remove_all_structures()

### Patch Fitting/Deformation

In [14]:
level_set_verts = torch.from_numpy(tangent_pts).to(DEVICE, torch.float).reshape(
    SURFACE_SAMPLES + SPACE_SAMPLES, NUM_PATCH_PTS, 3)
for it in range(15):
    level_set_verts = sdf_model.project_level_sets(level_set_verts, sample_dist).detach()
level_set_verts = level_set_verts.cpu().detach().view(-1, 3).numpy()

In [15]:
ps.register_surface_mesh("Input Shape", shape_verts, shape_faces, enabled=True, transparency=0.5)
ps.register_surface_mesh("Deformed Patches", level_set_verts, triangles_all, enabled=True, edge_width=1.0)
normals_pc = ps.register_point_cloud("Normal Origins", samples.cpu().detach().numpy(), radius=0.0, enabled=True)
normals_pc.add_vector_quantity("Patch Normals", patch_normals.cpu().detach().numpy(), enabled=True)
ps.show()
ps.remove_all_structures()

## LPM Evaluation

In [33]:
ps.register_surface_mesh("marching cubes mesh", shape_verts, shape_faces, enabled=True, transparency=0.5)
ps.show()
ps.remove_all_structures()

In [77]:
mc_mesh_obj = trimesh.Trimesh(shape_verts, shape_faces)
print(shape_verts.shape, shape_faces.shape)
NUM_TESTS = 300000
with torch.no_grad():
    tests = torch.from_numpy(mc_mesh_obj.sample(NUM_TESTS)).to(DEVICE, torch.float)
    dists = sdf_model.distance(tests)
    error_max = (dists.abs()).max()
    error_mean = (dists.abs()).mean()
error_max, error_mean

(378330, 3) (756660, 3)


(tensor(0.0050, device='cuda:0'), tensor(0.0017, device='cuda:0'))

In [78]:
mc_mean_edge = mc_mesh_obj.edges_unique_length.mean()

patch_mesh_obj = trimesh.Trimesh(plane_coords, triangles)
patch_mean_edge = patch_mesh_obj.edges_unique_length.mean()
patch_mesh_obj.vertices *= (mc_mean_edge / patch_mean_edge)
norm_plane_coords = patch_mesh_obj.vertices
print(mc_mean_edge, patch_mean_edge)

tangent_coords = (np.expand_dims(tangent_planes, 1) @ norm_plane_coords.reshape(1, -1, 3, 1)).squeeze() 
tangent_pts = (tangent_coords + samples.unsqueeze(1).detach().cpu().numpy()).reshape(-1, 3)
level_set_verts = torch.from_numpy(tangent_pts).to(DEVICE, torch.float).reshape(
    SURFACE_SAMPLES + SPACE_SAMPLES, NUM_PATCH_PTS, 3)
for it in range(15):
    level_set_verts = sdf_model.project_level_sets(level_set_verts, sample_dist).detach()
level_set_verts = level_set_verts.cpu().detach().view(-1, 3).numpy()
level_set_verts.shape, triangles_all.shape

0.0037595845510896343 0.08020722948562231


((378330, 3), (630550, 3))

In [79]:
ps.register_surface_mesh("local patch mesh", level_set_verts, triangles_all, enabled=True, transparency=0.5)
ps.show()
ps.remove_all_structures()

In [80]:
patch_mesh_obj = trimesh.Trimesh(level_set_verts, triangles_all)
NUM_TESTS = 300000
with torch.no_grad():
    tests = torch.from_numpy(patch_mesh_obj.sample(NUM_TESTS)).to(DEVICE, torch.float)
    dists = sdf_model.distance(tests)
    error_max = (dists.abs()).max()
    error_mean = (dists.abs()).mean()
error_max, error_mean

(tensor(0.0008, device='cuda:0'), tensor(1.7795e-05, device='cuda:0'))

In [None]:
dragon_camera = {"farClipRatio":20.0,"fov":45.0,"nearClipRatio":0.005,"projectionMode":"Perspective","viewMat":[0.825720965862274,6.63567334413528e-09,0.564077973365784,0.168529987335205,0.20917721092701,0.928702771663666,-0.30620214343071,-0.0139824077486992,-0.523861527442932,0.370830506086349,0.766852080821991,-1.84108781814575,0.0,0.0,0.0,1.0],"windowHeight":1200,"windowWidth":1600}

## Metrics Evaluation

### Dense Mesh Deformation

We deform the original input mesh to evaluate all properties.

In [27]:
LOAD_INPUT_MESH = 'assets/mesh/armadillo.ply'
LOAD_DEFORMED_MESH = 'assets\\mesh\\arap_results\\spokes_and_rims\\dino_arap_snout_experiment.off'

mesh_input = trimesh.load(LOAD_INPUT_MESH, force='mesh')

mesh_input.vertices -= np.expand_dims(mesh_input.centroid, axis=0)
mesh_input.vertices /= np.abs(mesh_input.vertices).max()
mesh_input.vertices *= 0.8

if len(LOAD_DEFORMED_MESH) > 0:
    mesh_deform = trimesh.load(LOAD_DEFORMED_MESH, force='mesh')
else:
    in_verts = torch.from_numpy(mesh_input.vertices).float()
    def_verts = []
    with torch.no_grad():
        for sample in torch.split(in_verts, CHUNK, dim=0):
            def_verts.append(rtf_model.inverse(sample.to(DEVICE)).cpu().numpy())
    def_verts = np.concatenate(def_verts, axis=0)
    mesh_deform = trimesh.Trimesh(def_verts, mesh_input.faces)

In [28]:
MOVING_HANDLES = ["assets/constraints/dino/transforms/snout_point_down.txt"]
STATIC_HANDLES = ["assets/constraints/dino/parts/left_foot.txt",
                  "assets/constraints/dino/parts/right_foot.txt"]

moving_pts = np.concatenate([np.loadtxt(f) for f in MOVING_HANDLES], axis=0)
if len(moving_pts.shape) < 2:
    moving_pts = moving_pts.reshape(1, 3)
static_pts = np.concatenate([np.loadtxt(f) for f in STATIC_HANDLES], axis=0)

ps.register_surface_mesh("Deformed Mesh", mesh_deform.vertices, mesh_deform.faces)
ps.register_point_cloud("Moving Points", moving_pts)
ps.register_point_cloud("Static Points", static_pts)
ps.show()
ps.remove_all_structures()

#### Global properties

In [43]:
source_global_metrics = {
    'area': mesh_input.area,
    'volume': mesh_input.volume
}

In [44]:
deform_global_metrics = {
    'area': mesh_deform.area,
    'volume': mesh_deform.volume
}

In [45]:
for k in source_global_metrics.keys():
    perc = (source_global_metrics[k] / deform_global_metrics[k]) * 100.
    print(f"{k} preservation: {perc:.2f}%")

area preservation: 101.69%
volume preservation: 103.75%


#### Topology properties

In [46]:
edges = mesh_input.edges
source_local_metrics = {
    'edge_lengths': np.linalg.norm(
        mesh_input.vertices[edges[:, 0]] - mesh_input.vertices[edges[:, 1]], axis=-1)
}

In [47]:
deform_local_metrics = {
    'edge_lengths': np.linalg.norm(
        mesh_deform.vertices[edges[:, 0]] - mesh_deform.vertices[edges[:, 1]], axis=-1)
}

In [48]:
for k in source_local_metrics.keys():
    error = source_local_metrics[k] - deform_local_metrics[k]
    print(f"{k} preservation: {error.mean()}%")

edge_lengths preservation: 7.596860775419095e-06%


### Neural Fields Deformation

Properties are evaluated on the marching cubes mesh obtained from the inverted deformation field.

In [49]:
mesh_source = trimesh.Trimesh(shape_verts, shape_faces)
mesh_source.fix_normals()

In [None]:
f_eval = []
with torch.no_grad():
    for sample in tqdm(torch.split(volume, CHUNK, dim=0)):
        f_eval.append(sdf_model(rtf_model.inverse(sample.contiguous()))['dist'].cpu().numpy())
f_volume = np.concatenate(f_eval, axis=0).reshape(*([MC_RESOLUTION] * 3))
def_verts, def_faces = mcubes.marching_cubes(f_volume, 0.0)
def_verts /= MC_RESOLUTION // 2
def_verts -= 1.0

mesh_deform = trimesh.Trimesh(def_verts, def_faces)

#### Global properties

In [32]:
source_global_metrics = {
    'area': mesh_source.area,
    'volume': mesh_source.volume
}

In [33]:
deform_global_metrics = {
    'area': mesh_deform.area,
    'volume': mesh_deform.volume
}

In [34]:
for k in source_global_metrics.keys():
    perc = (source_global_metrics[k] / deform_global_metrics[k]) * 100.
    print(f"{k} preservation: {perc:.2f}%")

area preservation: 101.15%
volume preservation: -100.31%
