# Pose Estimation
Compare to the PyTorch3D `Camera position optimization` sample.

In [None]:
import os
import sys

import torch

import numpy as np
import matplotlib.pyplot as plt


In [None]:
%matplotlib inline

In [None]:
# io utils
from pytorch3d.io import load_objs_as_meshes

# 3D transformations functions
from pytorch3d.transforms import so3_log_map, axis_angle_to_matrix, axis_angle_to_quaternion, quaternion_to_matrix

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, 
    RasterizationSettings, MeshRenderer, MeshRasterizer,
    HardPhongShader, PointLights, TexturesVertex
)

from tqdm.notebook import tqdm

# Load data and generate views with PyTorch3D
We're using the cow model from Keenan Crane, featured in the PyTorch3D tutorials.

In [None]:
# Run on the GPU when torch can see one, otherwise on the CPU.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mesh = load_objs_as_meshes(['data/cow.obj'], device=torch_device)

# Two size estimates taken from the vertices of the first (only) loaded mesh:
#   shape_scale   - 3x the mean per-axis standard deviation of the vertices
#   t_model_scale - mean per-axis extent (max - min) of the vertices
verts = mesh.verts_list()[0]
shape_scale = float(verts.std(0).mean())*3
t_model_scale = np.ptp(np.array(verts.cpu()),0).mean()
# NOTE(review): 1.18 is presumably the shape_scale of the reference cow mesh - confirm.
print('model is {:.2f}x the size of the cow'.format(shape_scale/1.18))

This is simply the dataset generation code, taken from the PyTorch3D tutorial

In [None]:
image_size = (64,64)
vfov_degrees = 60

lights = PointLights(device=torch_device, location=[[0.0, 0.0, -3.0*shape_scale]])
R, T = look_at_view_transform(dist=2.7*shape_scale, elev=60, azim=20,device=torch_device)
# Store the look-at pose on the camera object itself. Without R/T the camera
# keeps PyTorch3D's default identity pose, so any later code that reads
# `camera.R` / `camera.T` (e.g. to recover the ground-truth pose) would see
# the identity instead of the pose that was actually rendered.
camera = FoVPerspectiveCameras(
    device=torch_device,
    znear=shape_scale,
    zfar=100*shape_scale,
    fov=vfov_degrees,
    R=R,
    T=T,
)
raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Colour renderer (rasterizer + Phong shader) used for the reference image.
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(
        device=torch_device,
        cameras=camera,
        lights=lights
    )
)
# Stand-alone rasterizer, used only for its depth buffer.
depth_raster = MeshRasterizer(
    cameras=camera,
    raster_settings=raster_settings
)

# Render the cow mesh
# `cameras=` is the keyword PyTorch3D looks up; R/T are kept explicit so the
# rendered pose is visible at the call site.
target_image = renderer(mesh, cameras=camera, lights=lights, R=R, T=T).squeeze()
target_depth = depth_raster(mesh, cameras=camera, R=R, T=T).zbuf.squeeze()
# Pixels whose depth buffer value is -1 are marked as missing.
target_depth[target_depth == -1] = np.nan

# Setup Fuzzy Metaball renderer

In [None]:
import os
# Set before jax is imported; per the JAX GPU memory documentation this turns
# off up-front GPU memory preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
#jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render

# expected performance is roughly
# saved (optimized shape!): 1.5, volume: 2.5, surface: 2.5

# Which fuzzy-metaball model to use: the pickled one, or a GMM fitted to
# volume / surface samples of the mesh.
show_saved = True
show_volume = True

# Renderer hyperparameters, stored as logarithms; one set per model source.
if show_saved:
    hyperparams = [2.00, -0.1,  6.4, -5.44]
elif show_volume:
    hyperparams = [2.58, 0.70, 6.64, -5.61]
else:
    hyperparams = [1.84, 0.30, 5.44, -4.33]

NUM_MIXTURE = 40

# Exponentiate once; the fourth value enters the renderer with a negative sign.
exp_hyper = [jnp.float32(np.exp(h)) for h in hyperparams]
beta2, beta3, beta4 = exp_hyper[:3]
beta5 = -exp_hyper[3]

render_jit = jax.jit(fm_render.render_func_mrp)

In [None]:
# Obtain the mixture parameters (mean, prec, weight_log) and the matching
# linear-space weights, either from disk or by fitting a GMM to the mesh.
if show_saved:
    import pickle
    # load a model of the cow. feel free to use the one reconstructed from silhouettes!
    # NOTE: pickle.load can execute arbitrary code - only open files you trust.
    with open('fuzzy_cow.pkl','rb') as fp:
        mean,prec,weight_log = pickle.load(fp)
    # Linear-space weights, normalised to sum to one.
    weights = np.exp(weight_log)
    weights /= np.sum(weights)
else:
    import trimesh
    import sklearn.mixture
    # Rebuild the PyTorch3D mesh as a trimesh object so it can be point-sampled.
    tmesh = trimesh.Trimesh(mesh.cpu().verts_packed(),mesh.cpu().faces_packed())
    if show_volume:
        # sample points from the interior of the mesh
        pts = trimesh.sample.volume_mesh(tmesh,10000)
    else:
        # sample points on the surface; only the first returned item (the points) is kept
        pts = trimesh.sample.sample_surface_even(tmesh,10000)[0]
    # Fit a Gaussian mixture with NUM_MIXTURE components to the sampled points.
    gmm = sklearn.mixture.GaussianMixture(NUM_MIXTURE)
    gmm.fit(pts)
    weights = gmm.weights_
    weight_log = np.log(weights)
    mean = gmm.means_
    # Cholesky factors of the component precision matrices.
    prec = gmm.precisions_cholesky_

In [None]:
def axangle2mrp(axangle):
    """Convert an axis-angle rotation vector to modified Rodrigues parameters.

    The direction of ``axangle`` is the rotation axis and its Euclidean norm
    is the rotation angle in radians. The result is ``tan(angle/4) * axis``.

    Args:
        axangle: array of axis-angle components (axis scaled by angle).

    Returns:
        A jax array with the MRP vector. The zero rotation maps to the zero
        vector; dividing by the raw norm would give NaN in that case.
    """
    angle = jnp.linalg.norm(axangle)
    # Swap in a harmless divisor for the zero rotation so 0/0 is never formed.
    safe_angle = jnp.where(angle > 0, angle, 1.0)
    axis = axangle/safe_angle
    return jnp.tan(angle/4)*axis

def rot2mrp(axangle):
    """Convert an axis-angle rotation vector to modified Rodrigues parameters.

    Despite the name, the input is an axis-angle vector (axis scaled by the
    angle in radians), not a rotation matrix; this performs the same
    conversion as ``axangle2mrp``: ``tan(angle/4) * axis``.

    Args:
        axangle: array of axis-angle components.

    Returns:
        A jax array with the MRP vector. The zero rotation maps to the zero
        vector; dividing by the raw norm would give NaN in that case.
    """
    angle = jnp.linalg.norm(axangle)
    # Swap in a harmless divisor for the zero rotation so 0/0 is never formed.
    safe_angle = jnp.where(angle > 0, angle, 1.0)
    axis = axangle/safe_angle
    return jnp.tan(angle/4)*axis
    
def convert_pyt3dcamera(cam, image_size = image_size):
    """Turn a PyTorch3D camera into per-pixel rays plus a pose for the FM renderer.

    Args:
        cam: camera object exposing ``fov``, ``R`` and ``T``; only the first
            entry of each is used. ``fov`` is converted from degrees.
        image_size: ``(height, width)`` of the image in pixels.

    Returns:
        Tuple of jax arrays ``(camera_rays, axis_angle, translation)``:
        ``camera_rays`` has one row ``[x, y, 1]`` per pixel,
        ``axis_angle`` is ``so3_log_map(cam.R)[0]`` and
        ``translation`` is ``-cam.R[0] @ cam.T[0]``.
    """
    height, width = image_size
    # Principal point at the image centre; focal length in pixels from the
    # field of view and the image height.
    cx = (width-1)/2
    cy = (height-1)/2
    f = (height/np.tan((np.pi/180)*float(cam.fov[0])/2))*0.5

    # Pixel indices run from the far edge down to zero on both axes.
    cols = np.arange(width)[::-1]
    rows = np.arange(height)[::-1]
    grid_x, grid_y = np.meshgrid(cols, rows)

    # Normalised image-plane coordinates, with the third ray component fixed to 1.
    ray_x = (grid_x.ravel() - cx)/f
    ray_y = (grid_y.ravel() - cy)/f
    camera_rays = np.stack([ray_x, ray_y, np.ones_like(ray_x)], axis=1)

    axis_angle = so3_log_map(cam.R)[0].cpu()
    # NOTE(review): presumably the camera position in world coordinates under
    # PyTorch3D's conventions - confirm.
    translation = -cam.R[0].cpu()@cam.T[0].cpu()
    return jnp.array(camera_rays), jnp.array(axis_angle), jnp.array(translation)
# get the real camera
# Ground-truth rays and pose, later used as the reference for the pose error.
# NOTE(review): this reads the pose stored on `camera` itself (cam.R / cam.T),
# not the R, T handed to the renderer at call time - confirm `camera` carries
# the intended pose.
camera_rays, axangl_true, trans_true = convert_pyt3dcamera(camera)
# Same rotation expressed as modified Rodrigues parameters.
mrp_true = axangle2mrp(axangl_true)

# Add noise to pose

In [None]:
# Draw random pose perturbations until one exceeds both acceptance thresholds
# checked at the bottom of the loop. No random seed is set, so every run
# produces a different initial pose.
while True:
    # upper bound on the length of the translation offset
    trans_eps = 0.5
    rad_eps = (np.pi/180.0)*90 # range. so 90 is -45 to +45

    # random unit direction (normalised Gaussian sample) ...
    trans_err = np.random.randn(3)
    trans_err = trans_err/np.linalg.norm(trans_err)
    # ... scaled by a uniform fraction in [0, 1) of trans_eps
    trans_err_r = np.random.rand()

    trans_err = trans_eps*trans_err*trans_err_r
    trans_shift = np.array(trans_true) - trans_err

    # random rotation: unit axis times an angle uniform in [-rad_eps/2, rad_eps/2)
    angles = np.random.randn(3)
    angles = angles/np.linalg.norm(angles)
    angle_mag = (np.random.rand()-0.5)*rad_eps
    # transposed true rotation and transposed random rotation, then composed
    R_I = np.array(fm_render.mrp_to_rot(mrp_true).T)
    R_R = np.array(axis_angle_to_matrix(torch.tensor(angles*angle_mag))).T
    R_C =  R_R @ R_I

    # perturbed pose in the representations the renderer takes
    axangl_init = np.array(so3_log_map(torch.tensor(R_C[None])))[0]
    trans_init = np.array(R_R@np.array(trans_shift))
    mrp_init = axangle2mrp(axangl_init)

    # rotation error in degrees; translation error as a percentage of t_model_scale
    rand_rot = abs(angle_mag*(180.0/np.pi))
    rand_trans = 100*(trans_err_r*trans_eps/t_model_scale)
    # combined score: geometric mean of the two error figures
    init_pose_err = np.sqrt(rand_rot*rand_trans)
    print('pose error of {:.1f}, random rotation of {:.1f} degrees and translation of {:.1f}%'.format(init_pose_err,rand_rot,rand_trans))
    # accept only perturbations that are large in both rotation and translation
    if rand_trans > 30 and rand_rot > 30:
        break

In [None]:
# 2x2 figure: top row shows the PyTorch3D target image and depth, bottom row
# shows the fuzzy-metaball alpha and depth rendered at the perturbed pose.
plt.subplot(2,2,1)
plt.imshow(target_image.cpu())
plt.title('image')
plt.axis('off')
plt.subplot(2,2,2)
# colour limits from the target depth (NaNs ignored); reused for every depth panel
vmin,vmax = np.nanmin(target_depth.cpu()),np.nanmax(target_depth.cpu())
plt.imshow(target_depth.cpu(),vmin=vmin,vmax=vmax)
plt.title('depth')
plt.axis('off')
est_depth, est_probs, est_alpha = render_jit(mean,prec,weight_log,camera_rays,mrp_init,trans_init,beta2/shape_scale,beta3,beta4,beta5)
# copy into a writable numpy array, then blank out pixels with alpha below 0.5
est_depth = np.array(est_depth)
est_depth[est_alpha < 0.5] = np.nan
plt.subplot(2,2,3)
plt.imshow(est_alpha.reshape(image_size),cmap='Greys')
plt.title('FM alpha')
plt.axis('off')
plt.subplot(2,2,4)
plt.imshow(est_depth.reshape(image_size),vmin=vmin,vmax=vmax)
plt.title('FM depth')
plt.axis('off')
plt.tight_layout()

# Solve for camera pose

In [None]:
def error_func(est_depth,est_alpha,true_depth,depth_weight=20):
    """Combined depth and silhouette loss between a render and a target depth map.

    Args:
        est_depth: rendered depth per pixel; NaN entries are ignored.
        est_alpha: rendered opacity per pixel, treated as a probability.
        true_depth: target depth per pixel, with NaN marking background.
        depth_weight: multiplier on the depth term relative to the mask term
            (default 20).

    Returns:
        Scalar ``depth_weight * depth_loss + mask_loss``.
    """
    # Pixels where either depth is missing contribute nothing to the depth term.
    invalid = jnp.isnan(est_depth) | jnp.isnan(true_depth)
    #err = (est_depth - true_depth)/jnp.nan_to_num(true_depth,nan=1)
    # Error relative to the mean valid target depth.
    rel_err = (est_depth - true_depth)/jnp.nanmean(true_depth)

    # Mean over all pixels; invalid ones count as zero error.
    depth_loss = (jnp.where(invalid,0,rel_err)**2).mean()

    # Binary cross-entropy between the rendered alpha and the target
    # foreground mask (foreground = target depth is not NaN).
    true_alpha = ~jnp.isnan(true_depth)
    # Clip so neither logarithm sees 0.
    est_alpha = jnp.clip(est_alpha,1e-6,1-1e-6)
    mask_loss = -jnp.where(true_alpha,jnp.log(est_alpha),jnp.log(1-est_alpha)).mean()

    return depth_weight*depth_loss + mask_loss

def objective(params,means,prec,weights_log,camera_rays,beta2,beta3,beta4,beta5,depth):
    """Loss of the fuzzy-metaball render at a given pose against a target depth.

    Args:
        params: pair ``(mrp, trans)`` holding the rotation and translation.
        means, prec, weights_log: mixture parameters passed to the renderer.
        camera_rays: per-pixel rays passed to the renderer.
        beta2, beta3, beta4, beta5: renderer hyperparameters.
        depth: target depth passed to ``error_func``.

    Returns:
        The scalar returned by ``error_func``.
    """
    pose_mrp, pose_trans = params
    # The renderer returns three outputs; the middle one is not needed here.
    rendered_depth, _, rendered_alpha = render_jit(
        means,prec,weights_log,camera_rays,pose_mrp,pose_trans,beta2,beta3,beta4,beta5)
    return error_func(rendered_depth,rendered_alpha,depth)

def objective_simple(params,means,prec,weights_log,camera_rays,beta2,beta3,beta4,beta5,depth):
    """Same loss as ``objective`` but with the pose packed into one flat vector.

    Args:
        params: flat sequence whose first three entries are the rotation
            (MRP) and whose remaining entries are the translation.
        means, prec, weights_log: mixture parameters passed to the renderer.
        camera_rays: per-pixel rays passed to the renderer.
        beta2, beta3, beta4, beta5: renderer hyperparameters.
        depth: target depth passed to ``error_func``.

    Returns:
        The scalar returned by ``error_func``.
    """
    # Split the flat parameter vector into its rotation and translation parts.
    pose_mrp, pose_trans = jnp.array(params[:3]), jnp.array(params[3:])
    rendered_depth, _, rendered_alpha = render_jit(
        means,prec,weights_log,camera_rays,pose_mrp,pose_trans,beta2,beta3,beta4,beta5)
    return error_func(rendered_depth,rendered_alpha,depth)
# Compiled function returning (loss, gradient); jax.value_and_grad differentiates
# with respect to the first positional argument by default, i.e. `params`.
grad_render3 = jax.jit(jax.value_and_grad(objective))

In [None]:
from jax.example_libraries import optimizers
from util import DegradeLR
# Number of optimization steps
# typically only needs a few hundred
# and early exits
Niter = 2000

loop = tqdm(range(Niter))

# babysit learning rates
# NOTE(review): DegradeLR lives in util and is not visible here; from its use
# below it supplies the step-size schedule (`step_func`) and `add(loss)` returns
# a truthy value when optimisation should stop - confirm the argument meanings.
adjust_lr = DegradeLR(1e-3,0.1,50,10,-1e-4)
# SGD with momentum 0.95; `opt_params` extracts the current parameters from the state.
opt_init, opt_update, opt_params = optimizers.momentum(adjust_lr.step_func,0.95)

# to test scale invariance
HUHSCALE = 1
# should get same result even if world scale changes

# initial parameters: [rotation as MRP, translation in the scaled world]
tmp = [mrp_init,HUHSCALE*trans_init]
opt_state = opt_init(tmp)

# loss history, one entry per completed iteration
losses = []
# flattened target depth (NaN = background) as a jax array
jax_tdepth = jnp.array(target_depth.cpu().ravel())

for i in loop:
    p = opt_params(opt_state)

    # loss and gradient at the current pose, with the whole scene rescaled by HUHSCALE
    val,g = grad_render3(p,HUHSCALE*mean,prec/HUHSCALE,weight_log,camera_rays,beta2/(HUHSCALE*shape_scale),beta3,beta4,beta5,HUHSCALE*jax_tdepth)
    
    # squared length of the current translation estimate
    S = jnp.linalg.norm(p[1])
    S2 = S*S

    # rotation gradient is used as-is; translation gradient is multiplied by the
    # squared translation length (presumably to keep the update independent of
    # world scale - see the HUHSCALE test above)
    g1 = g[0]
    g2 = g[1]*S2

    opt_state = opt_update(i, [g1,g2], opt_state)

    val = float(val)
    losses.append(val)
    # stop early once the learning-rate helper says so
    if adjust_lr.add(val):
        break
    # Print the losses
    loop.set_description("total_loss = %.3f" % val)

In [None]:
# Read the optimised pose back out of the optimiser state.
mrp_final, trans_final = opt_params(opt_state)
# Undo the HUHSCALE factor applied to the translation before optimisation.
trans_final = trans_final/HUHSCALE

In [None]:
# 2nd order is also possible
# Off by default; set to True to replace the first-order result with a BFGS solve.
use_second_order = False
if use_second_order:
    from jax.scipy.optimize import minimize
    # Rotation and translation are packed into one flat vector for the solver.
    # beta2 is divided by shape_scale, matching every other render call in
    # this notebook; passing the raw beta2 would optimise a different render.
    res = minimize(
        objective_simple,
        jnp.hstack([mrp_init,trans_init]),
        method='BFGS',
        args=(mean,prec,weight_log,camera_rays,beta2/shape_scale,beta3,beta4,beta5,jax_tdepth,),
    )
    mrp_final = res.x[:3]
    trans_final = res.x[3:]

In [None]:
# Loss recorded at each optimisation step, drawn as faint markers with no connecting line.
plt.plot(losses, marker='.', linewidth=0, markersize=5, alpha=0.5)
plt.xlabel('iteration')
plt.ylabel('log loss')
plt.title('convergence plot')

# Visualize Results

In [None]:
# 2x3 figure. Left column: target image and FM depth at the true pose.
# Middle column: FM alpha/depth at the initial (perturbed) pose.
# Right column: FM alpha/depth at the optimised pose.
# Depth panels reuse the vmin/vmax computed from the target depth.
plt.subplot(2,3,1)
plt.imshow(target_image.cpu())
plt.title('image')
plt.axis('off')
plt.subplot(2,3,4)
est_depth_true, est_prob_true, est_alpha_true = render_jit(mean,prec,weight_log,camera_rays,mrp_true,trans_true,beta2/shape_scale,beta3,beta4,beta5)
# writable numpy copy; pixels with alpha below 0.5 are blanked out
est_depth_true = np.array(est_depth_true)
est_depth_true[est_alpha_true < 0.5] = np.nan
plt.imshow(est_depth_true.reshape(image_size),vmin=vmin,vmax=vmax)
plt.title('true pose')
plt.axis('off')
# render at the initial pose
est_depth_init, est_probs,est_alpha = render_jit(mean,prec,weight_log,camera_rays,mrp_init,trans_init,beta2/shape_scale,beta3,beta4,beta5)
est_depth_init = np.array(est_depth_init)
est_depth_init[est_alpha < 0.5] = np.nan
plt.subplot(2,3,2)
plt.imshow(est_alpha.reshape(image_size),cmap='Greys')
plt.title('init FM alpha')
plt.axis('off')
plt.subplot(2,3,5)
plt.imshow(est_depth_init.reshape(image_size),vmin=vmin,vmax=vmax)
plt.title('init FM depth')
plt.axis('off')
# render at the optimised pose (overwrites est_depth / est_probs / est_alpha)
est_depth, est_probs, est_alpha = render_jit(mean,prec,weight_log,camera_rays,mrp_final,trans_final,beta2/shape_scale,beta3,beta4,beta5)
est_depth = np.array(est_depth)
est_depth[est_alpha < 0.5] = np.nan
plt.subplot(2,3,3)
plt.imshow(est_alpha.reshape(image_size),cmap='Greys')
plt.title('final FM alpha')
plt.axis('off')
plt.subplot(2,3,6)
plt.imshow(est_depth.reshape(image_size),vmin=vmin,vmax=vmax)
plt.title('final FM depth')
plt.axis('off')
plt.tight_layout()
#plt.savefig('pose_est.pdf',facecolor=plt.gcf().get_facecolor(), edgecolor='none',bbox_inches='tight')

In [None]:
def mrp2quat(mrp):
    mag = mrp @ mrp
    q = np.array([1-mag] + list(2*mrp))/(1+mag)
    return torch.tensor(q)
# Quaternions of the true and the optimised rotation.
q1 = mrp2quat(np.array(mrp_true))
q2 = mrp2quat(np.array(mrp_final))
# Angle between the quaternions, evaluated for both signs of q1 (the smaller
# one is kept); doubled and converted to degrees to give the rotation error.
e1 = torch.acos(torch.clamp((q1 * q2).sum(),-1,1))
e2 = torch.acos(torch.clamp((-q1 * q2).sum(),-1,1))
rot_err = float((180.0/np.pi)*2*min(e1,e2))

R1 = np.array(quaternion_to_matrix(q1))
R2 = np.array(quaternion_to_matrix(q2))
# Distance between the two translations after applying each pose's transposed rotation.
t_norm = np.linalg.norm(R1.T@np.array(trans_true)-R2.T@np.array(trans_final))
# Translation error as a percentage of t_model_scale.
# NOTE: this reuses the name `trans_err`, which the pose-perturbation cell
# uses for the random translation offset vector.
trans_err = 100*t_norm/t_model_scale

# Combined score: geometric mean of rotation and translation error, as for the initial pose.
pose_err = np.sqrt(rot_err*trans_err)
print('init. pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(init_pose_err,rand_rot,rand_trans))
print('final pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(pose_err,rot_err,trans_err))