In [1]:
import os
import numpy as np
import torch
from torch import nn
import mcubes
import trimesh

from model import *
from load_llff import *

In [None]:
"""
Hyperparameters
"""

# Model
d_filter = 256          # Dimensions of the coarse network's linear layer filters
n_layers = 8            # Number of layers in the coarse network bottleneck
skip = [4]              # Layers at which to re-inject the encoded input (residual/skip)
use_fine_model = True   # If set, creates a fine model
d_filter_fine = 256     # Dimensions of linear layer filters of fine network
n_layers_fine = 8       # Number of layers in fine network bottleneck

# Stratified sampling
n_samples = 64          # Number of spatial samples per ray
perturb = False         # If set, applies noise to sample positions
inverse_depth = False   # If set, samples points linearly in inverse depth

# Hierarchical sampling
N_importance = 64             # Number of importance samples per ray (fine model)
perturb_hierarchical = False  # If set, applies noise to hierarchical sample positions

# Encoder
d_input = 3           # Number of input dimensions
n_freqs = 10          # Number of encoding functions for samples
log_sampling = True   # If set, frequencies scale in log space
use_viewdirs = True   # If set, use view direction as input
n_freqs_views = 4     # Number of encoding functions for views

# Training
chunksize = 1024            # Rows per network call; modify as needed to fit in GPU memory

# LLFF & Dataloading
# os.path.join keeps these paths valid on both Windows and POSIX systems.
basedir = os.path.join(".", "logs")                          # Base directory for logs and ckpts
expname = "ParrotnPlate"                                     # Custom experiment name
data_dir = os.path.join(".", "nerf_sample_parrotnPlate")     # Input data directory
dataset_type = "llff"       # Dataset type (only "llff" is handled below)
factor = 4                  # Load images downscaled by this factor.
                            # The downscaled folder NEEDS to be created manually inside data_dir
                            # (e.g. image_4 / images_4 — confirm the name load_llff expects).
spherify = True             # Set for 360-degree inward-facing scenes
llff_hold = 8               # If > 0, hold out every N-th image (0, N, 2N, ...) as the test/val set
use_ndc = False             # Use NDC; only suitable for forward-facing scenes
render_test = False         # Render test set instead of custom poses

# Function kwargs
stratified_sampling_kwargs = {
    'n_samples': n_samples,
    'perturb': perturb,
    'inverse_depth': inverse_depth
}
hierarchical_sampling_kwargs = {
    # Hierarchical sampling has its own perturbation switch.
    'perturb': perturb_hierarchical
}

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU when visible to PyTorch, else CPU

# Build the NeRF models and load the latest checkpoint

In [4]:
def batchify(fn, chunk):
    """
    Constructs a version of 'fn' that applies to smaller batches.

    Args:
        fn: callable invoked as ``fn(inputs_chunk, viewdirs_chunk)``; its
            results are concatenated along dim 0.
        chunk: maximum number of rows passed to ``fn`` per call. ``None``
            disables chunking and returns ``fn`` itself.

    Returns:
        A callable ``ret(inputs, viewdirs=None)``. It slices ``inputs`` (and
        ``viewdirs`` when it is not None) along dim 0 into pieces of at most
        ``chunk`` rows, calls ``fn`` on each piece, and concatenates the
        outputs along dim 0. When ``viewdirs`` is None, ``fn`` receives None.
    """
    if chunk is None:
        return fn

    def ret(inputs, viewdirs=None):
        n_rows = inputs.shape[0]
        if n_rows == 0:
            # torch.cat rejects an empty list, so hand the empty batch to fn directly.
            return fn(inputs, viewdirs)
        outputs = []
        for i in range(0, n_rows, chunk):
            dirs_chunk = None if viewdirs is None else viewdirs[i:i + chunk]
            outputs.append(fn(inputs[i:i + chunk], dirs_chunk))
        return torch.cat(outputs, 0)

    return ret

In [5]:
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Prepares inputs and applies network 'fn'.

    The function:
    - flattens ``inputs`` to 2-D and encodes the points with ``embed_fn``;
    - when ``viewdirs`` is given, broadcasts the directions over the sample
      dimension and encodes them with ``embeddirs_fn``;
    - evaluates ``fn`` in chunks of ``netchunk`` rows via ``batchify``;
    - restores the leading dimensions of ``inputs``.

    Args:
        inputs: point tensor whose last dimension holds coordinates. Presumably
            (n_rays, n_samples, 3). The broadcast below requires it to be 3-D
            with rays along dim 0 — TODO confirm with callers.
        viewdirs: per-ray directions, expanded via
            ``viewdirs[:, None].expand(inputs.shape)``, or ``None`` to skip
            direction encoding (``fn`` then receives ``None`` for directions).
        fn: network called as ``fn(embedded, embedded_dirs)``.
        embed_fn: encoder applied to the flattened points.
        embeddirs_fn: encoder for the directions; only used when ``viewdirs``
            is not None.
        netchunk: rows per network call; ``None`` disables chunking.

    Returns:
        Tensor shaped ``inputs.shape[:-1] + (d_out,)``, where ``d_out`` is the
        last dimension of the network output.
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    # Keep embedded_dirs defined when no view directions are supplied.
    embedded_dirs = None
    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)

    outputs_flat = batchify(fn, netchunk)(embedded, viewdirs=embedded_dirs)
    return torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])

In [6]:
def init_models():
    """
    Initialize encoders and models, and build the render kwargs.

    All settings come from the module-level hyperparameters. The function:
    - builds the positional encoder and, if ``use_viewdirs``, the
      view-direction encoder;
    - builds the coarse NeRF and, if ``use_fine_model``, the fine NeRF (sized
      by ``n_layers_fine``/``d_filter_fine``);
    - loads weights from the newest ``*.tar`` checkpoint in ``basedir/expname``,
      if one exists.

    Returns:
        (render_kwargs_train, render_kwargs_test, model_params): the render
        keyword dicts (the test copy has ``perturb`` and ``raw_noise_std``
        disabled) and the list of all coarse + fine model parameters.
    """
    # Encoders
    encoder = PositionalEncoder(d_input, n_freqs, log_sampling=log_sampling)
    encode = encoder

    # View direction encoders
    if use_viewdirs:
        encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,
                                             log_sampling=log_sampling)
        encode_viewdirs = encoder_viewdirs
        d_viewdirs = encoder_viewdirs.d_output
    else:
        encode_viewdirs = None
        d_viewdirs = None

    # Models
    model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,
                 d_viewdirs=d_viewdirs)
    model.to(device)
    model_params = list(model.parameters())
    if use_fine_model:
        # The fine network uses its own size hyperparameters.
        fine_model = NeRF(encoder.d_output, n_layers=n_layers_fine, d_filter=d_filter_fine,
                          skip=skip, d_viewdirs=d_viewdirs)
        fine_model.to(device)
        model_params = model_params + list(fine_model.parameters())
    else:
        fine_model = None

    def network_query_fn(inputs, viewdirs, network_fn):
        # Bind the encoders and chunk size so callers only pass points, dirs and a network.
        return run_network(inputs, viewdirs, network_fn,
                           embed_fn=encode,
                           embeddirs_fn=encode_viewdirs,
                           netchunk=chunksize)

    # Checkpoints loading
    ckpt_dir = os.path.join(basedir, expname)
    os.makedirs(ckpt_dir, exist_ok=True)
    # Sorted by filename; the last entry is used, so names are assumed to be
    # zero-padded step counts (e.g. 001000.tar).
    ckpts = [os.path.join(ckpt_dir, f) for f in sorted(os.listdir(ckpt_dir)) if f.endswith('.tar')]
    print("Found ckpts", ckpts)

    if ckpts:
        ckpt_path = ckpts[-1]
        print('testing from', ckpt_path)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        ckpt = torch.load(ckpt_path, map_location=device)

        # Load model
        model.load_state_dict(ckpt['coarse_model_dict'])
        if fine_model is not None:
            fine_model.load_state_dict(ckpt['fine_model_dict'])

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': perturb,
        'N_importance': N_importance,
        'network_fine': fine_model,
        'N_samples': n_samples,
        'network_fn': model,
        'use_viewdirs': use_viewdirs,
        'white_bkgd': False,
        'raw_noise_std': 0.,
    }

    # NDC only good for LLFF-style forward facing data.
    # Always record the choice so the renderer never falls back to its own default.
    render_kwargs_train['use_ndc'] = use_ndc
    if not use_ndc:
        print('Not ndc!')
        render_kwargs_train['lindisp'] = inverse_depth

    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, model_params

In [7]:
K = None

if dataset_type == "llff":
    images, poses, bds, render_poses, i_test = load_llff_data(basedir=data_dir, factor=factor, spherify=spherify)
    # Column 4 of each pose holds (height, width, focal); columns 0-3 form the 3x4 camera matrix.
    focal = poses[0, 2, 4]
    height, width = int(poses[0, 0, -1]), int(poses[0, 1, -1])
    hwf = [height, width, focal]
    poses = poses[:, :3, :4]

    print(f'Images shape: {images.shape}')
    print(f'Poses shape: {poses.shape}')
    print(f'Focal: {focal}')

    if not isinstance(i_test, list):
        i_test = [i_test]

    if llff_hold > 0:
        print(f'Set hold out frequency: {llff_hold}')
        i_test = np.arange(images.shape[0])[::llff_hold]
    i_val = i_test
    # Training views: every view not held out for test/validation (sorted indices).
    n_training = np.setdiff1d(np.arange(int(images.shape[0])), np.union1d(i_test, i_val))

    if use_ndc:
        near, far = 0., 1.
    else:
        near = bds.min() * .9
        far = bds.max() * 1.
else:
    # Without this, near/far/focal would be undefined and fail later with a NameError.
    raise ValueError(f'Unsupported dataset_type {dataset_type!r}; only "llff" is implemented')

print(f'Define near far: {near}, {far}')

print(f'Train views are: {n_training}')
print(f'Val views are: {i_val}')
print(f'Test views are: {i_test}')

# Pinhole intrinsics with the principal point at the image center.
if K is None:
    K = np.array([
        [focal, 0, 0.5*width],
        [0, focal, 0.5*height],
        [0, 0, 1]
    ])

if render_test:
    render_poses = np.array(poses[i_test])

os.makedirs(os.path.join(basedir, expname), exist_ok=True)

# Create nerf model
render_kwargs_train, render_kwargs_test, grad_vars = init_models()

bds_dict = {
    'near': near,
    'far': far,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)

# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)

Loaded image data (29, 1008, 756, 3) [1008.          756.          792.18388554]
Loaded .\nerf_sample_parrotnPlate 2.487472993908299 37.963128464167646
Data:
(29, 3, 5) (29, 1008, 756, 3) (29, 2)
HOLDOUT view is 14
Images shape: (29, 1008, 756, 3)
Poses shape: (29, 3, 4)
Focal: 792.1838989257812
Set hold out frequency: 8
Define near far: 0.569020128250122, 9.649142265319824
Train views are: [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19 20 21 22 23 25 26 27
 28]
Val views are: [ 0  8 16 24]
Test views are: [ 0  8 16 24]
Found ckpts ['.\\logs\\ParrotnPlate\\001000.tar', '.\\logs\\ParrotnPlate\\002000.tar', '.\\logs\\ParrotnPlate\\003000.tar', '.\\logs\\ParrotnPlate\\004000.tar', '.\\logs\\ParrotnPlate\\005000.tar', '.\\logs\\ParrotnPlate\\006000.tar', '.\\logs\\ParrotnPlate\\007000.tar', '.\\logs\\ParrotnPlate\\008000.tar', '.\\logs\\ParrotnPlate\\009000.tar', '.\\logs\\ParrotnPlate\\010000.tar', '.\\logs\\ParrotnPlate\\011000.tar', '.\\logs\\ParrotnPlate\\012000.tar', '.\\logs\\Pa

# Find tight bounds around the object for the marching cubes algorithm

In [51]:
network_query_fn = render_kwargs_test['network_query_fn']
network_fine = render_kwargs_test['network_fine']
# Query the fine network when it exists; otherwise fall back to the coarse one.
network = network_fine if network_fine is not None else render_kwargs_test['network_fn']
device = next(network.parameters()).device

# Tune these hyperparameters until your main object is tightly in range with little noise
N = 128  # Grid points per axis (>= 2). Set small when finding ranges; set high for the final reconstruction.

# Axis ranges in world coordinates. They need not have equal lengths:
# vertices are rescaled per axis after marching cubes.
xmin, xmax = -0.4, 0.4    # left/right range
ymin, ymax = -0.4, 0.4    # forward/backward range
zmin, zmax = -0.35, 0.45  # up/down range

sigma_threshold = 20.     # Density iso-level (lower: maybe more noise; higher: some mesh might be missing)
query_chunk = 64 * 1024   # Grid points per network call; lower it if GPU memory runs out
############################################################################################

# Query points. indexing='ij' makes grid axes 0/1/2 correspond to x/y/z.
# The default 'xy' indexing would swap x and y in the extracted mesh.
x = np.linspace(xmin, xmax, N)
y = np.linspace(ymin, ymax, N)
z = np.linspace(zmin, zmax, N)
grid_pts = np.stack(np.meshgrid(x, y, z, indexing='ij'), -1).reshape(-1, 3)

sigma_chunks = []
with torch.no_grad():
    # Chunking bounds peak GPU memory: run_network embeds all of its input at once.
    for start in range(0, grid_pts.shape[0], query_chunk):
        # (n, 1, 3): one sample per "ray", matching run_network's viewdirs broadcast.
        xyz_pts = torch.as_tensor(grid_pts[start:start + query_chunk],
                                  dtype=torch.float32, device=device)[:, None, :]
        # Placeholder directions; assumes density does not depend on the view
        # direction (standard NeRF) — confirm in model.NeRF.
        viewdirs = torch.zeros(xyz_pts.shape[0], 3, device=device)
        rgbsigma = network_query_fn(xyz_pts, viewdirs, network)
        sigma_chunks.append(rgbsigma[..., -1].reshape(-1).cpu().numpy())

sigma = np.maximum(np.concatenate(sigma_chunks), 0).reshape(N, N, N)

# Visualization
vertices, triangles = mcubes.marching_cubes(sigma, sigma_threshold)
# marching_cubes returns vertices in grid-index units (0..N-1); map them back to world coordinates.
grid_lo = np.array([xmin, ymin, zmin])
grid_hi = np.array([xmax, ymax, zmax])
vertices = vertices * (grid_hi - grid_lo) / (N - 1) + grid_lo
mesh = trimesh.Trimesh(vertices, triangles)
mesh.show()

In [None]:
# Finally, export your work if needed: set export_mesh = True to write the mesh to disk.
export_mesh = False
scene_name = expname + "_mesh"
if export_mesh:
    # trimesh picks the exporter from the file extension (.ply here).
    mesh_path = os.path.join(basedir, expname, f"{scene_name}.ply")
    mesh.export(mesh_path)
    print(f'Exported mesh to {mesh_path}')