# Shape from Silhouette With Rays
Compare this to the PyTorch3D `Fit a mesh with texture` sample. We use only a silhouette loss (the PyTorch3D sample uses color, silhouette, edge, normal, and Laplacian loss terms). In our testing, this runs about 1,000x faster than the PyTorch3D example on our CPU hardware, although it does not reconstruct color.

This example differs from the other shape-from-silhouette example in that the camera poses are stored implicitly in the camera-ray data structure, so traditional pose optimization is not possible. In exchange, we can run simple full-batch gradient descent (no minibatching) over the entire dataset at once, which should be faster on large parallel hardware.

In [None]:
import os
import sys

import torch

import numpy as np
import matplotlib.pyplot as plt

from util import image_grid

In [None]:
%matplotlib inline

In [None]:
# io utils
from pytorch3d.io import load_objs_as_meshes

# 3D transformations functions
from pytorch3d.transforms import so3_log_map

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, 
    RasterizationSettings, MeshRenderer, MeshRasterizer,
    HardPhongShader, PointLights, TexturesVertex
)

from tqdm.notebook import tqdm

# Load data and generate views with PyTorch3D
We're using the cow model from Keenan Crane, featured in the PyTorch3D tutorials.

In [None]:
# Prefer the GPU when one is available; uncomment the override below to force CPU.
if torch.cuda.is_available():
    torch_device = torch.device('cuda')
else:
    torch_device = torch.device('cpu')
#torch_device = torch.device("cpu")
mesh = load_objs_as_meshes(['data/cow.obj'], device=torch_device)

# Estimate a characteristic size of the model; camera distance, clipping planes
# and the blob initialization further down are all expressed relative to it.
shape_scale = float(mesh.verts_list()[0].std(0).mean())*3
# Mean bounding-box extent over the three axes. The vertices are copied to host
# memory first: np.array() cannot convert a CUDA tensor, and torch_device may be
# CUDA. NOTE(review): not referenced elsewhere in this notebook as shown.
t_model_scale = np.ptp(mesh.verts_list()[0].detach().cpu().numpy(), 0).mean()
# 1.18 is presumably shape_scale for the reference cow.obj — TODO confirm
print('model is {:.2f}x the size of the cow'.format(shape_scale/1.18))

This is the dataset-generation code, adapted from the PyTorch3D tutorial.

In [None]:
# Dataset: num_views renders of the cow at image_size pixels, seen through a
# perspective camera with a vfov_degrees field of view.
num_views = 20
image_size = (64,64)
vfov_degrees = 60

if False:
    # Get a batch of viewing angles like PyT3D
    elev = torch.linspace(0, 360, num_views)
    azim = torch.linspace(-180, 180, num_views)
else:
    # Get a batch of views that cover the scene well and are all unique
    # Low-discrepancy (Sobol) samples in [0,1)^2 are scaled to elevation in
    # [0, 360) and azimuth in [-180, 180) degrees. The unscrambled Sobol
    # sequence starts at the origin, so one extra point is drawn and the first
    # one dropped. NOTE(review): scipy may warn that Sobol balance properties
    # require a power-of-2 sample count.
    import scipy.stats.qmc
    rand_angles_sample = scipy.stats.qmc.Sobol(2,scramble=False).random(num_views+1)[1:]
    rand_angles = (rand_angles_sample*np.array([360.0,360.0]) + np.array([0,-180.0])).astype(np.float32)
    elev = rand_angles[:,0]
    azim = rand_angles[:,1]

# One point light, plus one camera per view placed 2.7*shape_scale from the
# look-at point; the clipping planes are also expressed in shape_scale units.
lights = PointLights(device=torch_device, location=[[0.0, 0.0, -3.0*shape_scale]])
R, T = look_at_view_transform(dist=2.7*shape_scale, elev=elev, azim=azim)
cameras = FoVPerspectiveCameras(device=torch_device, R=R, T=T,
                                znear=shape_scale, zfar=100*shape_scale, fov=vfov_degrees)
# A single camera (view index 1) used only to construct the renderer; the
# batched `cameras` are passed explicitly to the renderer call below.
camera = FoVPerspectiveCameras(device=torch_device, R=R[None, 1, ...],
                                  T=T[None, 1, ...],
                                  znear=shape_scale, zfar=100*shape_scale, fov=vfov_degrees)
# Hard rasterization (no blur, one face per pixel) gives crisp silhouettes.
raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(
        device=torch_device,
        cameras=camera,
        lights=lights
    )
)

# Create a batch of meshes by repeating the cow mesh and associated textures.
# Meshes has a useful `extend` method which allows us do this very easily.
# This also extends the textures.
meshes = mesh.extend(num_views)

# Render the cow mesh from each viewing angle
# The result is an RGBA batch; its channel 3 (alpha) is used as the reference
# silhouette in the next cell.
target_images = renderer(meshes, cameras=cameras, lights=lights)

# Our multi-view cow dataset will be represented by these 2 lists of tensors,
# each of length num_views.
# (target_rgb is not used later in this notebook as shown; target_cameras is
# converted to ray bundles further down.)
target_rgb = [target_images[i, ..., :3] for i in range(num_views)]
target_cameras = [FoVPerspectiveCameras(device=torch_device, R=R[None, i, ...],
                                           T=T[None, i, ...], znear=shape_scale, zfar=100*shape_scale, fov=vfov_degrees) for i in range(num_views)]


In [None]:
# Copy the rendered RGBA batch to host memory for plotting and loss targets.
np_images = target_images.cpu().numpy()
# Channel 3 (alpha) of every render is that view's reference silhouette.
target_sil = np_images[..., 3]
image_grid(np_images, rows=4, cols=5, rgb=True)

# Set up the Fuzzy Metaballs renderer

In [None]:
import os
# Stop JAX from preallocating most of the GPU memory up front (its default);
# presumably so PyTorch/PyTorch3D can share the same GPU. This must be set
# before JAX initializes its backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
#jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render
# Renderer hyperparameters are stored in log space: beta_i = exp(hyperparams[i]),
# except beta5, which is the negated exponential. beta1/beta2 are divided by
# obj_scale at every render call; beta4/beta5 map the summed per-ray output to a
# soft silhouette via 0.5 + 0.5*tanh(beta4*(sum(exp(probs)) + beta5)).
# when using probs, old settings
#hyperparams = np.array([-13.69159107,  -2.67968404,   0.71023575,   6.36908448,  -5.42242999])
# when using stds, new settings. (Per the original author, beta1 does nothing here.)
hyperparams = np.array([-8.96549948, -2.78121285, -0.08753679,  6.41910729, -5.4442808 ])
NUM_MIXTURE = 40  # number of blobs in the Fuzzy Metaballs mixture
beta1 = jnp.float32(np.exp(hyperparams[0]))
beta2 = jnp.float32(np.exp(hyperparams[1]))
beta3 = jnp.float32(np.exp(hyperparams[2]))
beta4 = jnp.float32(np.exp(hyperparams[3]))
beta5 = jnp.float32(-np.exp(hyperparams[4]))

# JIT-compile the ray renderer once. Callers below unpack its result as
# (depth, per-component log-probs) — NOTE(review): confirm against fm_render.
render_jit = jax.jit(fm_render.render_func_rays)
# Scale divisor applied to beta1/beta2 at render time, tied to the model size.
obj_scale = 1/(120*shape_scale)

In [None]:
# get a reasonable initialization to check
# Optional: fit a Gaussian mixture to points sampled inside the mesh volume,
# to compare a "good" initialization against the random one used below.
show_vgmm_init = False
if show_vgmm_init:
    from sklearn.mixture import GaussianMixture
    vgmm_model = GaussianMixture(NUM_MIXTURE)

    import trimesh
    # trimesh works on host numpy arrays. The mesh may live on CUDA, and CUDA
    # tensors cannot be converted implicitly, so copy vertices/faces to the CPU.
    trimesh_mesh = trimesh.Trimesh(mesh.verts_list()[0].detach().cpu().numpy(),
                                   mesh.faces_list()[0].detach().cpu().numpy())
    vol_samples = trimesh.sample.volume_mesh(trimesh_mesh,10000)

    vgmm_model.fit(vol_samples)

    means = jnp.array(vgmm_model.means_)
    prec = jnp.array(vgmm_model.precisions_cholesky_)
    # Small epsilon keeps log() finite for near-zero component weights.
    weights_log = jnp.log(jnp.array(vgmm_model.weights_) + 1e-6)
    weights = np.exp(weights_log)
    weights /= np.sum(weights)
    # Overrides the default obj_scale with a weighted spread of the fitted means.
    obj_scale = (weights[:,None] * means).std(0).mean()

Initialize a Fuzzy Metaballs model from random blobs

In [None]:
# Starting mixture for the optimization: NUM_MIXTURE blobs with uniform weights
# (stored as log-weights) and identical isotropic precision.
rand_weight_log = jnp.log(np.ones(NUM_MIXTURE)/NUM_MIXTURE)

if not show_vgmm_init:
    # Random blob centres around the origin. NOTE(review): the covariance
    # scales linearly with shape_scale rather than with shape_scale**2.
    rand_mean = np.random.multivariate_normal(
        mean=[0,0,0],
        cov=1e-1*np.identity(3)*shape_scale,
        size=NUM_MIXTURE,
    )
    # Hand-picked precision scale, divided by shape_scale for larger models.
    rand_sphere_size = 13
    rand_prec = jnp.array([np.identity(3)*rand_sphere_size/shape_scale] * NUM_MIXTURE)
else:
    # Random centres drawn from a 0.8x-shrunken Gaussian fit to the vGMM means;
    # blob size taken from the mean diagonal of the vGMM precision factors.
    rand_mean = np.random.multivariate_normal(
        mean=means.mean(0),
        cov=0.8*jnp.cov(means,rowvar=False),
        size=NUM_MIXTURE,
    )
    rand_sphere_size = jnp.diag(prec.mean(0)).mean()
    rand_prec = jnp.array([np.identity(3)*rand_sphere_size] * len(prec))

In [None]:
def convert_pyt3dcamera_rays(cam, image_size = image_size):
    """Convert a single PyTorch3D FoV camera into a per-pixel ray bundle.

    Args:
        cam: a PyTorch3D ``FoVPerspectiveCameras`` holding one camera; only
            ``cam.fov[0]``, ``cam.R[0]`` and ``cam.T[0]`` are read. The tensors
            may live on any device; they are copied to host memory here.
        image_size: ``(height, width)`` of the image in pixels.

    Returns:
        A ``jnp`` array of shape ``(height*width, 2, 3)``. For each pixel,
        entry 0 is the world-space ray direction (camera-space z component 1,
        not normalized) and entry 1 is the camera centre in world coordinates.
        Pixels are ordered row-major, so the rays line up with a flattened
        ``(height, width)`` image.
    """
    height, width = image_size
    # Principal point at the image centre; focal length in pixels derived
    # from the field of view along the image height.
    cx = (width-1)/2
    cy = (height-1)/2
    f = (height/np.tan((np.pi/180)*float(cam.fov[0])/2))*0.5
    K = np.array([[f, 0, cx],[0,f,cy],[0,0,1]])
    # (x, y, 0) pixel coordinates for every pixel, row-major. Both axes are
    # reversed, presumably to match PyTorch3D's +X-left / +Y-up camera frame.
    pixel_list = (np.array(np.meshgrid(width-np.arange(width)-1,height-np.arange(height)-1,[0]))[:,:,:,0]).reshape((3,-1)).T

    # Back-project through the (diagonal + offset) intrinsics and place the
    # rays on the camera-space z = 1 plane.
    camera_rays = (pixel_list - K[:,2])/np.diag(K)
    camera_rays[:,-1] = 1

    # Matches PyTorch3D's row-vector convention X_cam = X_world @ R + T: the
    # camera centre is -R @ T, and directions map back to world as d @ R^T.
    # Detach and copy to the CPU first; np.array() raises on CUDA tensors.
    rotation = cam.R[0].detach().cpu().numpy()
    translation = (-cam.R[0]@cam.T[0]).detach().cpu().numpy()

    camera_rays = camera_rays @ rotation.T
    trans = np.tile(translation[None],(camera_rays.shape[0],1))

    rays_trans = np.stack([camera_rays,trans],1)
    return jnp.array(rays_trans)

cameras_list = [convert_pyt3dcamera_rays(cam) for cam in target_cameras]

In [None]:
# Render every training view and turn the renderer's per-ray log-probabilities
# into a soft silhouette: alpha = 0.5 + 0.5*tanh(beta4*(sum(exp(p)) + beta5)).

# Optional: silhouettes of the vGMM initialization, for comparison.
if show_vgmm_init:
    alpha_results = []
    for camera_rays in cameras_list:
        est_depth, est_probs = render_jit(
            means, prec, weights_log, camera_rays,
            beta1/obj_scale, beta2/obj_scale, beta3,
        )
        est_alpha = 0.5 + 0.5*jnp.tanh(beta4*(jnp.exp(est_probs).sum(0) + beta5))
        alpha_results.append(est_alpha.reshape(image_size))

# Silhouettes and depth maps of the random initialization.
alpha_results_rand, alpha_results_rand_depth = [], []
for camera_rays in cameras_list:
    est_depth, est_probs = render_jit(
        rand_mean, rand_prec, rand_weight_log, camera_rays,
        beta1/obj_scale, beta2/obj_scale, beta3,
    )
    est_alpha = 0.5 + 0.5*jnp.tanh(beta4*(jnp.exp(est_probs).sum(0) + beta5))
    alpha_results_rand.append(est_alpha.reshape(image_size))
    # Writable host copy of the depth; background pixels (alpha < 0.5)
    # become NaN so they plot as empty.
    est_depth = np.array(est_depth)
    est_depth[est_alpha < 0.5] = np.nan
    alpha_results_rand_depth.append(est_depth.reshape(image_size))

In [None]:
def _titled_grid(images, title, **grid_kwargs):
    """Show *images* as a 4x5 grid with *title* above the figure; return the title."""
    image_grid(images, rows=4, cols=5, rgb=False, **grid_kwargs)
    plt.gcf().subplots_adjust(top=0.92)
    return plt.suptitle(title)


_titled_grid(target_sil, 'Reference Masks')
if show_vgmm_init:
    _titled_grid(alpha_results, 'vGMM Masks')
_titled_grid(alpha_results_rand, 'random init masks', cmap='Greys')
_titled_grid(alpha_results_rand_depth, 'SFS Fuzzy Metaball Initialization')
#plt.savefig('sfs_init.pdf',facecolor=plt.gcf().get_facecolor(), edgecolor='none',bbox_inches='tight')

# Optimize from a random cloud to a shape

In [None]:
def objective(params,true_alpha):
    """Mean binary cross-entropy between rendered and reference silhouettes.

    Args:
        params: sequence ``(means, prec, weights_log, camera_rays, beta1,
            beta2, beta3, beta4, beta5)``. The first three are the mixture
            parameters; the rest are passed through to the renderer and the
            alpha mapping.
        true_alpha: reference silhouette value for every ray in camera_rays.

    Returns:
        Scalar mean BCE loss.
    """
    eps = 1e-6
    (means, prec, weights_log, camera_rays,
     beta1, beta2, beta3, beta4, beta5) = params
    _, log_probs = render_jit(means, prec, weights_log, camera_rays, beta1, beta2, beta3)

    # Soft silhouette, clipped away from 0 and 1 so both logs stay finite.
    soft_alpha = jnp.tanh(beta4*(jnp.exp(log_probs).sum(0)+beta5))*0.5 + 0.5
    soft_alpha = jnp.clip(soft_alpha, eps, 1-eps)
    log_in = jnp.log(soft_alpha)
    log_out = jnp.log(1-soft_alpha)
    per_ray_loss = -((true_alpha*log_in) + (1-true_alpha)*log_out)
    return per_ray_loss.mean()
grad_render3 = jax.jit(jax.value_and_grad(objective))

In [None]:
from jax.example_libraries import optimizers

# Number of optimization steps
# (2000 divided by the number of views; presumably to keep the total work
# comparable to a ~2000-step per-image loop — NOTE(review): confirm intent)
Niter = int(round(2000/(len(target_sil))))
# number of images to batch gradients over
# NOTE(review): stale comment — every step below uses all views at once
# (full batch); no batch-size variable is defined here.

loop = tqdm(range(Niter))

# babysit learning rates
# adjust_lr = DegradeLR(1e-3,0.5,train_size//2,train_size//10,-1e-3)
# (disabled: a fixed Adam learning rate of 3e-2 is used instead)

# Only the mixture parameters (means, prec, log-weights) are optimized.
opt_init, opt_update, opt_params = optimizers.adam(3e-2)
tmp = [rand_mean,rand_prec,rand_weight_log]
opt_state = opt_init(tmp)

# Flatten all views into one long ray list: each cameras_list entry is
# (H*W, 2, 3), so this is (num_views*H*W, 2, 3). The silhouettes are raveled
# in the same view-major, row-major order, so each ray lines up with its pixel.
all_cameras = jnp.array(cameras_list).reshape((-1,2,3))
all_sils = jnp.array(target_sil.ravel()).astype(jnp.float32)

losses = []
# NOTE(review): accum_grad and grad_counter are not used in this cell.
accum_grad = None
grad_counter = 0

for i in loop:
    p = opt_params(opt_state)
    # Full-batch loss and gradients for all nine entries of the params list;
    # only the first three (the mixture parameters) are passed to Adam.
    val,g = grad_render3([p[0],p[1],p[2],all_cameras,beta1/obj_scale,beta2/obj_scale,beta3,beta4,beta5],all_sils)
    opt_state = opt_update(i, g[:3], opt_state)

    val = float(val)
    losses.append(val)
    loop.set_description("total_loss = %.3f" % val)

In [None]:
# Read the optimized mixture parameters back out of the Adam state.
(final_mean,
 final_prec,
 final_weight_log) = opt_params(opt_state)

In [None]:
# Mean binary cross-entropy ("log loss") recorded at every optimization step.
plt.plot(losses, marker='.', lw=0, ms=5, alpha=0.5)
plt.title('convergence plot')
plt.xlabel('iteration')
plt.ylabel('log loss')

# Visualize Results

In [None]:
# Re-render every view with the optimized mixture. Keep the soft silhouette
# and a depth map whose background pixels (alpha < 0.5) are NaN.
alpha_results_final, alpha_results_depth = [], []
for camera_rays in cameras_list:
    est_depth, est_probs = render_jit(
        final_mean, final_prec, final_weight_log, camera_rays,
        beta1/obj_scale, beta2/obj_scale, beta3,
    )
    est_alpha = 0.5 + 0.5*jnp.tanh(beta4*(jnp.exp(est_probs).sum(0) + beta5))
    alpha_results_final.append(est_alpha.reshape(image_size))

    est_depth = np.array(est_depth)
    est_depth[est_alpha < 0.5] = np.nan
    alpha_results_depth.append(est_depth.reshape(image_size))

# Reference silhouettes, then the optimized ones.
image_grid(target_sil, rows=4, cols=5, rgb=False)
plt.gcf().subplots_adjust(top=0.92)
plt.suptitle('Reference Masks')

image_grid(alpha_results_final, rows=4, cols=5, rgb=False)
plt.gcf().subplots_adjust(top=0.92)
plt.suptitle('Final masks')

In [None]:
# Depth maps of the optimized model. The color range [2, 3]*shape_scale
# brackets the 2.7*shape_scale camera distance used to render the dataset.
image_grid(
    alpha_results_depth,
    rows=4, cols=5, rgb=False,
    vmin=2.*shape_scale, vmax=3.*shape_scale,
)
fig = plt.gcf()
fig.subplots_adjust(top=0.92)
fig.suptitle('SFS results')
plt.tight_layout()
#plt.savefig('sfs_res.pdf',facecolor=plt.gcf().get_facecolor(), edgecolor='none',bbox_inches='tight')

In [None]:
import pickle

# Save the optimized mixture (means, prec, log-weights) so it can be reloaded
# without re-running the optimization. Only unpickle this file from a trusted source.
fuzzy_cow_params = [final_mean, final_prec, final_weight_log]
with open('fuzzy_cow_shape_rays.pkl', mode='wb') as fp:
    pickle.dump(fuzzy_cow_params, fp)