In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

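# Pose Estimation
Recover the camera pose of a rendered mesh by fitting a fuzzy-metaball render to the target depth and silhouette. Compare with the PyTorch3D `Camera position optimization` tutorial.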

In [None]:
import os
import sys
import pickle
import glob

import numpy as np
import matplotlib.pyplot as plt

In [None]:
from tqdm.notebook import tqdm

In [None]:
import trimesh
import pyrender
import transforms3d

from tqdm.notebook import tqdm
class QuasiRandom():
    """Stateful additive-recurrence quasi-random generator in [0, 1)^dim.

    Each step advances the state by ``phi`` with ``phi[k] = 1 / g**(k+1)``.
    Here ``g`` is the last root, in sorted order, of ``x**(dim+1) = x + 1``
    (the golden ratio when dim=1). The new values are returned modulo 1.

    Args:
        dim: number of coordinates per generated point.
        seed: starting offset (scalar or length-``dim`` array). When None, a
            random offset is drawn with ``np.random.rand``.
    """

    def __init__(self, dim=1, seed=None):
        self.dim = dim
        if seed is None:
            self.x = np.random.rand(dim)
        else:
            self.x = seed
        # polynomial coefficients (highest power first) of x**(dim+1) - x - 1
        poly = [1] + [0] * (dim - 1) + [-1, -1]
        # the last root after sorting has the largest real part; keep only the real component
        self.const = sorted(np.roots(poly))[-1].real
        self.phi = np.array([1 / self.const ** (k + 1) for k in range(dim)])

    def generate(self, n_points=1):
        """Advance the sequence ``n_points`` steps and return the points mod 1 (squeezed)."""
        points = np.zeros((n_points, self.dim))
        for row in range(n_points):
            # the internal state is kept unwrapped; only the returned copy is reduced mod 1
            self.x = self.x + self.phi
            points[row] = self.x
        return np.squeeze(points % 1)
    
mesh_file = 'data/cow.obj'

mesh_tri = trimesh.load(mesh_file)

# Rough size/position statistics of the mesh. Later cells use them for the camera distance,
# the clipping planes, the beta2 scaling and the translation-error normalisation.
# shape_scale: 3x the mean (over x, y, z) standard deviation of the vertex positions.
shape_scale = float(mesh_tri.vertices.std(0).mean())*3
# center: vertex centroid (mean of the vertices, not the bounding-box centre).
center = np.array(mesh_tri.vertices.mean(0))
# t_model_scale: mean bounding-box extent over the three axes.
t_model_scale = np.ptp(mesh_tri.vertices,0).mean()

# NOTE(review): 1.18 is presumably the shape_scale of the reference cow mesh — confirm.
print('model is {:.2f}x the size of the cow'.format(shape_scale/1.18))

In [None]:
# Render the ground-truth target (RGB preview + depth) of the mesh from a
# deterministic pseudo-random camera pose.
image_size = (64,64)   # (height, width) in pixels
vfov_degrees = 45
# pinhole focal length (pixels) that gives the requested vertical field of view
focal_length = 0.5*image_size[0]/np.tan((np.pi/180.0)*vfov_degrees/2)
cx = (image_size[1]-1)/2
cy = (image_size[0]-1)/2
# seed=0 makes this "random" orientation the same on every run
rand_quat = QuasiRandom(dim=4,seed=0).generate(1)
rand_quat = rand_quat/np.linalg.norm(rand_quat)

mesh = pyrender.Mesh.from_trimesh(mesh_tri)

scene = pyrender.Scene()
scene.add(mesh)

R = transforms3d.quaternions.quat2mat(rand_quat)
# camera position: 3*shape_scale along the rotated z axis (row-vector @ R == R.T @ v), around the mesh centre
loc = np.array([0,0,3*shape_scale]) @ R + center
# 4x4 node pose with rotation R.T and translation loc, shared by the light and the camera
pose = np.vstack([np.vstack([R,loc]).T,np.array([0,0,0,1])])

light = pyrender.SpotLight(color=np.ones(3), intensity=50.0,
                            innerConeAngle=np.pi/16.0,
                            outerConeAngle=np.pi/6.0)
scene.add(light, pose=pose)

camera = pyrender.IntrinsicsCamera(focal_length,focal_length,cx,cy,znear=0.1*shape_scale,zfar=100*shape_scale)
scene.add(camera,pose=pose)

r = pyrender.OffscreenRenderer(image_size[1],image_size[0])
try:
    color, target_depth = r.render(scene)
finally:
    # free the renderer's OpenGL resources; otherwise every re-run of this cell leaks a context
    r.delete()
# background pixels come back with depth 0; mark them as missing
target_depth[target_depth ==0] = np.nan

plt.subplot(1,2,1)
plt.imshow(color)
plt.title('target color')
plt.subplot(1,2,2)
plt.imshow(target_depth)
plt.title('target depth')
plt.tight_layout()

# Set up the fuzzy metaball renderer

In [None]:
import os
# Stop JAX from preallocating most of the GPU memory up front; this must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
# Uncomment to force JAX onto the CPU.
#jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render

# show_volume: if True, fit the GMM to points sampled inside the mesh volume; otherwise to surface samples.
# Usually False: color optimization implies surface samples, and the code now defaults to that usage.
show_volume = False

# Number of Gaussian components in the fuzzy-metaball model.
NUM_MIXTURE = 40
# Renderer blending parameters passed through to fm_render; beta2 is divided by the model scale at call time.
# NOTE(review): the exact meaning of beta2/beta3 is defined in fm_render — confirm there.
beta2 = 21.4
beta3 = 2.66

# Multiplier on the fitted GMM mixture weights (added as log(gmm_init_scale) to the log-weights).
gmm_init_scale = 80

# JIT-compiled quaternion-pose renderer. It is called as
# render_jit(means, prec, weights_log, camera_rays, quat, trans, beta2, beta3); its first two outputs are used as depth and alpha.
render_jit = jax.jit(fm_render.render_func_quat)

In [None]:

import trimesh
import sklearn.mixture

# Fit the fuzzy-metaball model: a Gaussian mixture over points sampled from the mesh.
# Seed for the GMM initialisation, so the fitted model is reproducible across runs.
# (The trimesh point sampling below is still unseeded.)
GMM_SEED = 0

if show_volume:
    # points sampled inside the mesh volume
    pts = trimesh.sample.volume_mesh(mesh_tri,10000)
else:
    # points sampled evenly on the surface; [0] keeps the points and drops the face indices
    pts = trimesh.sample.sample_surface_even(mesh_tri,10000)[0]
gmm = sklearn.mixture.GaussianMixture(NUM_MIXTURE, random_state=GMM_SEED)
gmm.fit(pts)
# log mixture weights, scaled up by gmm_init_scale
weights_log = np.log( gmm.weights_) + np.log(gmm_init_scale)
mean = gmm.means_
# NB: Cholesky factors of the precision matrices (not the precisions themselves)
prec = gmm.precisions_cholesky_


In [None]:

height, width = image_size
# Pinhole intrinsics matching the IntrinsicsCamera used to render the target.
K = np.array([[focal_length, 0, cx],
              [0, focal_length, cy],
              [0, 0, 1]])

# Integer [x, y, 0] coordinates of every pixel, row-major starting from the top image row.
# y is flipped (height-1-row) so it increases upward; the third column is a zero placeholder.
row_idx, col_idx = np.indices((height, width))
pixel_list = np.column_stack([
    col_idx.ravel(),
    (height - row_idx - 1).ravel(),
    np.zeros(height * width, dtype=row_idx.dtype),
])
del row_idx, col_idx

# Un-project through K^-1: subtract the principal point, divide by the focal length.
camera_rays = (pixel_list - K[:, 2]) / np.diag(K)
# The last column already evaluates to (0 - 1) / 1; pin it so every ray points down -z.
camera_rays[:, -1] = -1

# The unperturbed rendering pose is the ground truth for the optimisation.
trans_true = loc
quat_true = rand_quat

# Add noise to pose

In [None]:
# Perturbation settings (loop-invariant, so set once).
t_err_cap = 0.5                 # max translation offset, as a fraction of t_model_scale
rad_eps = (np.pi/180.0)*90      # rotation range. so 90 is -45 to +45 degrees
min_trans_pct = 30              # accept only draws with translation offset above this (% of t_model_scale)
min_rot_deg = 30                # accept only draws with rotation above this (degrees)

# The rejection loop below can only terminate if both thresholds are reachable:
# the offset is always < 100*t_err_cap percent and |rotation| < rad_eps/2.
assert 100*t_err_cap > min_trans_pct, 't_err_cap too small for min_trans_pct'
assert 0.5*rad_eps*(180.0/np.pi) > min_rot_deg, 'rad_eps too small for min_rot_deg'

# Transposed ground-truth rotation; it does not change between draws.
R_I = transforms3d.quaternions.quat2mat(quat_true).T

while True:
    # translation offset: uniform magnitude in [0, t_err_cap)*t_model_scale along a random unit direction
    t_err_vec = np.random.randn(3)
    t_err_vec = t_err_vec/np.linalg.norm(t_err_vec)
    t_err_mag = np.random.rand()

    trans_offset = t_err_cap*t_err_mag*t_err_vec*t_model_scale
    trans_shift = trans_true - trans_offset

    # rotation: random unit axis, angle uniform in [-rad_eps/2, rad_eps/2)
    angles = np.random.randn(3)
    angles = angles/np.linalg.norm(angles)
    angle_mag = (np.random.rand()-0.5)*rad_eps
    R_R = transforms3d.axangles.axangle2mat(angles,angle_mag)
    R_C =  R_R @ R_I

    quat_init = transforms3d.quaternions.mat2quat(R_C.T)
    trans_init = R_R@trans_shift

    # perturbation sizes: rotation in degrees, translation in % of t_model_scale
    rand_rot = abs(angle_mag*(180.0/np.pi))
    rand_trans = 100*(t_err_mag*t_err_cap)
    # combined initial error: geometric mean of the two
    init_pose_err = np.sqrt(rand_rot*rand_trans)
    if rand_trans > min_trans_pct and rand_rot > min_rot_deg:
        print('pose error of {:.1f}, random rotation of {:.1f} degrees and translation of {:.1f}%'.format(init_pose_err,rand_rot,rand_trans))
        break

In [None]:
def compute_normals(camera_rays, depth_py_px,image_size):
    """Estimate per-pixel surface normals from a depth map by finite differences.

    Args:
        camera_rays: per-pixel ray directions, reshapeable to (-1, 3).
        depth_py_px: depth per pixel (NaN allowed), flattened to match the rays.
        image_size: (rows, cols) used to lay the back-projected points out as an image.

    Returns:
        (rows*cols, 3) array of unit normals. Degenerate pixels yield (near-)zero vectors.
    """
    n_rows, n_cols = image_size
    depth_flat = depth_py_px.ravel()

    # back-project: scale each ray by its depth, then arrange the 3D points as an image
    points = jnp.array(camera_rays.reshape((-1, 3)) * depth_flat[:, None]).reshape((n_rows, n_cols, 3))

    # backward differences along rows and columns (jnp.roll wraps at the borders);
    # NaN entries (missing depth) are replaced by a tiny constant
    d_row = jnp.nan_to_num(points - jnp.roll(points, 1, 0), nan=1e-9)
    d_col = jnp.nan_to_num(points - jnp.roll(points, 1, 1), nan=1e-9)

    raw = jnp.nan_to_num(jnp.cross(d_col.reshape((-1, 3)), d_row.reshape((-1, 3))), nan=0)
    lengths = jnp.linalg.norm(raw, axis=1, keepdims=True)
    return raw / (1e-20 + lengths)


# Solve for camera pose

In [None]:
def error_func(est_depth,est_alpha,true_depth,depth_weight=50.0,alpha_eps=1e-7):
    """Pose-fitting loss: weighted relative-depth L1 plus silhouette binary cross-entropy.

    Args:
        est_depth: rendered depth per pixel (NaN allowed).
        est_alpha: rendered opacity per pixel; clipped to [alpha_eps, 1 - alpha_eps]
            before taking logs.
        true_depth: target depth per pixel, NaN where the target is background.
        depth_weight: weight of the depth term relative to the mask term
            (default 50, the previously hard-coded value).
        alpha_eps: clipping margin that keeps the logs finite.

    Returns:
        Scalar loss: depth_weight * depth_term + mask_term.
    """
    # Pixels where either depth is missing contribute 0 to the depth term (but still count in the mean).
    cond = jnp.isnan(est_depth) | jnp.isnan(true_depth)
    # Normalise by the mean target depth so the term does not change when the world scale changes.
    err = (est_depth - true_depth)/jnp.nanmean(true_depth)
    depth_loss = jnp.abs(jnp.where(cond,0,err)).mean()

    # Silhouette: the target is foreground wherever its depth is defined.
    true_alpha = ~jnp.isnan(true_depth)
    est_alpha = jnp.clip(est_alpha,alpha_eps,1-alpha_eps)
    mask_loss = -((true_alpha * jnp.log(est_alpha)) + (~true_alpha)*jnp.log(1-est_alpha))

    return depth_weight*depth_loss + mask_loss.mean()

def objective(params,means,prec,weights_log,camera_rays,beta2,beta3,depth):
    """Loss of the fuzzy-metaball render at pose ``params`` against the target ``depth``.

    ``params`` is a pair (quat, trans), the same layout the optimizer state uses.
    All other arguments are forwarded to ``render_jit`` unchanged.
    """
    quat, trans = params
    est_depth, est_alpha = render_jit(means,prec,weights_log,camera_rays,quat,trans,beta2,beta3)[:2]
    return error_func(est_depth, est_alpha, depth)

def objective_simple(params,means,prec,weights_log,camera_rays,beta2,beta3,depth):
    """Same loss as ``objective`` but with the pose packed into one flat vector.

    ``params`` has length 7: a quaternion (4 values) followed by a translation
    (3 values). This is the layout flat-vector optimizers such as
    ``jax.scipy.optimize.minimize`` work with.
    """
    params = jnp.asarray(params)
    # render_jit wraps fm_render.render_func_quat, so the rotation is a 4-component
    # quaternion (the previous [:3] split fed a 3-vector into the quaternion slot)
    quat = params[:4]
    trans = params[4:]
    render_res = render_jit(means,prec,weights_log,camera_rays,quat,trans,beta2,beta3)
    return error_func(render_res[0],render_res[1],depth)
# JIT-compiled (loss, gradient) of `objective`. The gradient is taken w.r.t. its first
# argument, the [quat, trans] pose pair, and has the same structure.
grad_render3 = jax.jit(jax.value_and_grad(objective))

In [None]:
from jax.example_libraries import optimizers
from util import DegradeLR
# Upper bound on optimization steps: typically only a few hundred are needed,
# and DegradeLR signals an early exit (see the break below).
Niter = 2000

loop = tqdm(range(Niter))

# babysit learning rates: DegradeLR (project helper) provides the step-size schedule
# (step_func) and, through add(), the early-exit signal.
# NOTE(review): the meaning of the constructor arguments is defined in util.DegradeLR — confirm there.
adjust_lr = DegradeLR(3e-4,0.1,50,10,-1e-4)
# Momentum optimizer with mass 0.95, using the schedule above as its step size.
opt_init, opt_update, opt_params = optimizers.momentum(adjust_lr.step_func,0.95)

# to test scale invariance: HUHSCALE multiplies every length (translation, GMM means,
# target depth) and divides prec and beta2 accordingly.
HUHSCALE = 1
# should get same result even if world scale changes

# Optimizer parameters are the pose pair [quaternion, scaled translation].
tmp = [quat_init,HUHSCALE*trans_init]
opt_state = opt_init(tmp)

losses = []
jax_tdepth = jnp.array(target_depth.ravel())

for i in loop:
    p = opt_params(opt_state)

    val,g = grad_render3(p,HUHSCALE*mean,prec/HUHSCALE,weights_log,camera_rays,beta2/(HUHSCALE*shape_scale),beta3,HUHSCALE*jax_tdepth)
    
    # Scale the translation gradient by |trans|^2. Presumably this makes the translation
    # step grow with the scene size (the gradient scales ~1/s, |trans|^2 ~ s^2) — confirm intent.
    S = jnp.linalg.norm(p[1])
    S2 = S*S

    g1 = g[0]
    g2 = g[1]*S2

    opt_state = opt_update(i, [g1,g2], opt_state)

    val = float(val)
    losses.append(val)
    # DegradeLR records the loss and returns truthy when optimization should stop.
    if adjust_lr.add(val):
        break
    # Print the losses
    loop.set_description("total_loss = %.3f" % val)

In [None]:
# Read the optimized pose back; the translation is converted back to world units.
# The rotation is a quaternion (render_jit wraps fm_render.render_func_quat), so it is named
# quat_final; mrp_final is kept as an alias for code that still uses the old name.
quat_final, trans_final = opt_params(opt_state)
trans_final = trans_final/HUHSCALE
mrp_final = quat_final

In [None]:
# 2nd-order (BFGS) refinement is also possible; disabled by default.
USE_BFGS = False
if USE_BFGS:
    from jax.scipy.optimize import minimize

    def _objective_flat(x, *args):
        # x packs the 4 quaternion components followed by the 3 translation components
        return objective([x[:4], x[4:]], *args)

    res = minimize(_objective_flat, jnp.hstack([quat_init, trans_init]), method='BFGS',
                   args=(mean, prec, weights_log, camera_rays, beta2/shape_scale, beta3, jax_tdepth))
    quat_final = res.x[:4]
    mrp_final = quat_final
    trans_final = res.x[4:]

In [None]:
# Total loss (weighted depth term + mask BCE) at every optimization step.
# The loss is always positive (the clipped BCE term is > 0), so a log y-axis is safe;
# the label previously said "log loss" while a linear axis was drawn.
plt.title('convergence plot')
plt.plot(losses,marker='.',lw=0,ms=5,alpha=0.5)
plt.yscale('log')
plt.xlabel('iteration')
plt.ylabel('loss (log scale)');

In [None]:
# Final pose from the first-order optimizer, with the translation converted back to world units.
# NOTE(review): this re-reads opt_state, so it overwrites any quat_final/trans_final set by the optional BFGS cell.
quat_final, trans_final = opt_params(opt_state)
trans_final = trans_final/HUHSCALE

# Visualize Results

In [None]:
# Shared colour range for all depth panels, taken from the target depth (NaN background ignored).
vmin,vmax = np.nanmin(target_depth),np.nanmax(target_depth)

In [None]:
def render_masked(quat, trans, alpha_threshold=0.5):
    """Render the GMM at pose (quat, trans) and return (depth, alpha).

    Depth is a NumPy copy whose pixels with alpha below ``alpha_threshold`` are set
    to NaN, so the background is blank in the plots. Alpha is returned as rendered.
    """
    depth, alpha, _, _ = render_jit(mean,prec,weights_log,camera_rays,quat,trans,beta2/shape_scale,beta3)
    depth = np.array(depth)
    depth[alpha < alpha_threshold] = np.nan
    return depth, alpha

est_depth_true, est_alpha_true = render_masked(quat_true, trans_true)
est_depth_init, est_alpha_init = render_masked(quat_init, trans_init)
est_depth, est_alpha = render_masked(quat_final, trans_final)

depth_kw = dict(vmin=vmin, vmax=vmax)
alpha_kw = dict(cmap='Greys')
# (row, col, image, title, imshow kwargs): top row image/alphas, bottom row depths
panels = [
    (0, 0, color, 'image', {}),
    (1, 0, est_depth_true.reshape(image_size), 'true pose', depth_kw),
    (0, 1, np.asarray(est_alpha_init).reshape(image_size), 'init FM alpha', alpha_kw),
    (1, 1, est_depth_init.reshape(image_size), 'init FM depth', depth_kw),
    (0, 2, np.asarray(est_alpha).reshape(image_size), 'final FM alpha', alpha_kw),
    (1, 2, est_depth.reshape(image_size), 'final FM depth', depth_kw),
]

fig, axes = plt.subplots(2, 3)
for row, col, img, title, kw in panels:
    ax = axes[row, col]
    ax.imshow(img, **kw)
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()

In [None]:
# Pose error of the optimized pose relative to the ground truth.
# Rotation error: geodesic angle (degrees) between the unit quaternions. Taking the min over
# q and -q handles the double cover (both represent the same rotation).
q1 = quat_true/np.linalg.norm(quat_true)
q2 = quat_final/np.linalg.norm(quat_final)
e1 = np.arccos(np.clip((q1 * q2).sum(),-1,1))
e2 = np.arccos(np.clip((-q1 * q2).sum(),-1,1))
rot_err = float((180.0/np.pi)*2*min(e1,e2))

# Translation error: distance between R1.T @ trans_true and R2.T @ trans_final, in % of t_model_scale.
# NOTE(review): whether R.T @ trans is the camera position depends on fm_render's pose convention — confirm.
R1 = np.array(transforms3d.quaternions.quat2mat(q1))
R2 = np.array(transforms3d.quaternions.quat2mat(q2))
t_norm = np.linalg.norm(R1.T@np.array(trans_true)-R2.T@np.array(trans_final))
trans_err = 100*t_norm/t_model_scale

# Combined score: geometric mean of rotation error (deg) and translation error (%).
# NOTE: the "init." line reports the sizes of the random perturbation (rand_rot, rand_trans),
# not this metric evaluated at (quat_init, trans_init).
pose_err = np.sqrt(rot_err*trans_err)
print('init. pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(init_pose_err,rand_rot,rand_trans))
print('final pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(pose_err,rot_err,trans_err))