In [10]:
import open3d as o3d
import torch
import tqdm
import numpy as np

from models import PixelNeRFNet
from train import PixelNeRFTrainer


In [11]:

device = torch.device("cuda:0")

# Trained network weights; the loaded object is passed to load_state_dict, so it is a state dict.
# NOTE(review): absolute local path — consider making this configurable.
ckpt_path = '/home/liqi/code/AutoRF-pytorch/output/200.ckpt'

net = PixelNeRFNet().to(device=device)
# map_location places the stored tensors on `device` directly, so a checkpoint saved on
# a different GPU index (or on CPU) still loads on this machine.
net.load_state_dict(torch.load(ckpt_path, map_location=device))
net.eval()  # switch layers such as dropout / batch-norm to inference behavior



In [26]:
chunk = 1024  # number of grid points sent through the network per forward pass
obj = 0       # index of the object (into `latents` / `objs`) to reconstruct

### Tune these parameters until the whole object lies tightly in range with little noise ###
N = 100  # grid resolution per axis. Keep it small while searching for good ranges;
         # raise it for the final mesh reconstruction.
xmin, xmax = -1, 1  # left/right range
ymin, ymax = -1, 1  # forward/backward range
zmin, zmax = -1, 1  # up/down range
## Attention! Keep the three ranges the same length so the voxels stay cubic.
sigma_threshold = 10  # marching-cubes iso-level (lower = maybe more noise; higher = some mesh might be missing)
############################################################################################

xs = np.linspace(xmin, xmax, N)
ys = np.linspace(ymin, ymax, N)
zs = np.linspace(zmin, zmax, N)

# np.meshgrid defaults to indexing='xy', so the grid has shape (len(ys), len(xs), len(zs)).
# Once the queried densities are reshaped to (N, N, N), axis 0 runs along y, axis 1 along x
# and axis 2 along z. Each row of `grid` is still an (x, y, z) point.
grid = np.stack(np.meshgrid(xs, ys, zs), -1).reshape(-1, 3)
xyz_ = torch.as_tensor(grid, dtype=torch.float32, device=device)
# Zero view direction for every query point; zeros_like already inherits xyz_'s device.
dir_ = torch.zeros_like(xyz_)

In [27]:
from kitti import KITTI

# Load one scene from the KITTI training split together with its per-object data.
scene_idx = 10
dataset = KITTI('train.txt')
image, all_rays, rois, intersect, objs = dataset.__getscene__(
    scene_idx,
    distinct_intersect=True,
)

In [28]:
# Move the object crops to the model's device, then encode them into latent codes.
# Autograd is switched off: only inference is needed for mesh extraction.
src_images = rois.to(device)
with torch.set_grad_enabled(False):
    latents = net.encode(src_images)

In [29]:
# Query the network's density at every grid point, `chunk` points at a time to bound memory.
with torch.no_grad():
    n_points = xyz_.shape[0]
    obj_latent = latents[obj][None]  # latent of the selected object with a leading dim of 1
    sigma_chunks = []
    for start in range(0, n_points, chunk):
        xyz = xyz_[None, start:start + chunk]
        d = dir_[None, start:start + chunk]
        out = net(xyz, d, obj_latent)
        # The last output channel is read as the density (sigma).
        # reshape(-1) rather than squeeze(): if the final chunk holds exactly one point,
        # squeeze() returns a 0-d tensor and torch.cat below would fail.
        sigma_chunks.append(out[..., -1].reshape(-1))

    sigma = torch.cat(sigma_chunks, 0)

sigma = sigma.cpu().numpy().reshape(N, N, N)

In [30]:
# Export voxels whose density exceeds a threshold as a point cloud (quick visual check).
# NOTE(review): kept at the original 7, independent of `sigma_threshold` used for the mesh —
# confirm whether the two should match.
pcd_sigma_threshold = 7

# (M, 3) voxel indices. `sigma` was built from np.meshgrid's default 'xy' indexing,
# so axis 0 runs along y, axis 1 along x and axis 2 along z.
voxel_idx = np.argwhere(sigma > pcd_sigma_threshold)

# Map indices back to the normalized coordinates the network was queried at. Without this
# step, the exported size would scale with the grid resolution N.
grid_lo = np.array([ymin, xmin, zmin], dtype=np.float64)
grid_hi = np.array([ymax, xmax, zmax], dtype=np.float64)
grid_step = (grid_hi - grid_lo) / max(N - 1, 1)  # spacing of np.linspace(lo, hi, N)
points_norm = grid_lo + voxel_idx * grid_step

# Per-axis scale from the object's entries 4, 3, 5 (same axis mapping as the original cell).
# NOTE(review): assumes these columns of `objs` hold per-axis object dimensions — confirm.
axis_scale = np.array([objs[obj, 4].item(), objs[obj, 3].item(), objs[obj, 5].item()])
points = points_norm * axis_scale

pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))

In [31]:
# Save the preview point cloud. write_point_cloud reports failure only through its bool
# return value, so check it rather than letting a failed write pass silently.
pcd_written = o3d.io.write_point_cloud("pc.pcd", pcd)
if not pcd_written:
    raise RuntimeError("failed to write point cloud to pc.pcd")
pcd_written

True

In [32]:
import mcubes
import trimesh

# Extract the iso-surface of the density grid. The vertices come back in (fractional)
# voxel-index units along the three axes of `sigma`.
vertices_idx, triangles = mcubes.marching_cubes(sigma, sigma_threshold)

# `sigma` was built from np.meshgrid's default 'xy' indexing: axis 0 runs along y, axis 1
# along x and axis 2 along z. Map indices back to the normalized coordinates that were
# queried, so the mesh size does not change when the resolution N is raised.
grid_lo = np.array([ymin, xmin, zmin], dtype=np.float64)
grid_hi = np.array([ymax, xmax, zmax], dtype=np.float64)
grid_step = (grid_hi - grid_lo) / max(N - 1, 1)  # spacing of np.linspace(lo, hi, N)

# Per-axis scale from the object's entries 4, 3, 5 (same axis mapping as the original cell).
# NOTE(review): assumes these columns of `objs` hold per-axis object dimensions — confirm.
axis_scale = np.array([objs[obj, 4].item(), objs[obj, 3].item(), objs[obj, 5].item()])
vertices = (grid_lo + vertices_idx * grid_step) * axis_scale

mesh = trimesh.Trimesh(vertices, triangles)
mesh.show()