# Pose Estimation
Camera pose estimation with Fuzzy Metaballs; compare with the PyTorch3D `Camera position optimization` sample.

In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [None]:
import torch

# io utils
from pytorch3d.io import load_objs_as_meshes

# 3D transformations functions
from pytorch3d.transforms import so3_log_map, axis_angle_to_matrix, axis_angle_to_quaternion, quaternion_to_matrix

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, 
    RasterizationSettings, MeshRenderer, MeshRasterizer,
    HardPhongShader, PointLights, TexturesVertex
)

from tqdm.notebook import tqdm

# Load data and generate views with PyTorch3D
We use Keenan Crane's cow model, which is also featured in the PyTorch3D tutorials.

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mesh = load_objs_as_meshes(['data/cow.obj'], device=torch_device)

# Rough scene scale: 3x the mean per-axis std of the vertex positions.
# Later cells scale the camera distance, light position and clipping planes by it.
verts = mesh.verts_list()[0]
shape_scale = float(verts.std(0).mean())*3
# Mean per-axis extent (peak-to-peak) of the vertices, used to report translation
# errors as a percentage. Copy to host first: numpy cannot read CUDA tensors.
t_model_scale = np.ptp(verts.detach().cpu().numpy(),0).mean()
# NOTE(review): 1.18 is presumably shape_scale of the reference cow mesh — TODO confirm
print('model is {:.2f}x the size of the cow'.format(shape_scale/1.18))

This dataset-generation code is taken from the PyTorch3D tutorial: it renders the target RGB image and depth map of the cow from a known camera pose.

In [None]:
image_size = (64,64)
vfov_degrees = 60

lights = PointLights(device=torch_device, location=[[0.0, 0.0, -3.0*shape_scale]])
R, T = look_at_view_transform(dist=2.7*shape_scale, elev=60, azim=20)
# Store the ground-truth pose on the camera itself. convert_pyt3dcamera reads
# camera.R / camera.T to recover the true pose; without R=R, T=T they would be
# PyTorch3D's defaults (identity rotation, zero translation).
camera = FoVPerspectiveCameras(device=torch_device, R=R, T=T, znear=shape_scale, zfar=100*shape_scale, fov=vfov_degrees)
raster_settings = RasterizationSettings(
    image_size=image_size, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(
        device=torch_device, 
        cameras=camera,
        lights=lights
    )
)
# Rasterizer on its own, used only for its z-buffer (the target depth map).
depth_raster = MeshRasterizer(
    cameras=camera, 
    raster_settings=raster_settings
)

# Render the cow mesh from the camera's stored pose and bring the results to
# host memory as numpy arrays (later cells use numpy/jax on them).
target_image = np.squeeze(renderer(mesh, cameras=camera, lights=lights).detach().cpu().numpy())
target_depth = np.squeeze(depth_raster(mesh, cameras=camera).zbuf.detach().cpu().numpy())
# zbuf is -1 where no face covers the pixel: mark that background as NaN.
target_depth[target_depth == -1] = np.nan

# Set up the Fuzzy Metaball renderer

In [None]:
import os
# Disable JAX's up-front GPU memory preallocation (presumably so it can coexist
# with PyTorch on the same device). This must happen before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
# To force JAX onto the CPU instead, uncomment the next line:
# jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render

# Renderer hyperparameters stored in log space:
# beta1..beta4 = exp(h[0..3]), beta5 = -exp(h[4]).
hyperparams = np.array([-13.69159107,  -2.67968404,   0.71023575,   6.36908448,  -5.42242999])
# NOTE(review): not referenced by the cells visible here — presumably the mixture size of the model.
NUM_MIXTURE = 40
beta1, beta2, beta3, beta4 = (jnp.float32(positive_beta) for positive_beta in np.exp(hyperparams[:4]))
beta5 = jnp.float32(-np.exp(hyperparams[4]))

render_jit = jax.jit(fm_render.render_func)

In [None]:
import pickle
# Load a fuzzy-metaball model of the cow (the one reconstructed from
# silhouettes works too).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('fuzzy_cow.pkl','rb') as model_file:
    mean, prec, weight_log = pickle.load(model_file)

# Mixture weights are stored as logs: exponentiate, then normalise to sum to 1.
weights = np.exp(weight_log)
weights = weights / weights.sum()
# Model scale; later cells divide beta1/beta2 by it when rendering.
obj_scale = (weights[:,None] * mean).std(0).mean()

In [None]:
def convert_pyt3dcamera(cam, image_size = image_size):
    """Convert a PyTorch3D FoV camera into the ray/pose inputs of fm_render.

    Args:
        cam: a PyTorch3D ``FoVPerspectiveCameras``. Only the first camera is
            read (``cam.fov[0]``, ``cam.R[0]``, ``cam.T[0]``).
        image_size: ``(height, width)`` in pixels.

    Returns:
        camera_rays: ``(height*width, 3)`` jnp array of per-pixel ray
            directions with z fixed to 1, in row-major (y outer, x inner) order.
        axis_angle: ``(3,)`` jnp array, the SO(3) log map (axis-angle) of R.
        trans: ``(3,)`` jnp array, ``-R @ T``. PyTorch3D maps points as
            X_view = X_world @ R + T, so this is the camera centre in world coordinates.
    """
    height, width = image_size
    cx = width/2
    cy = height/2
    # Focal length in pixels: half the image height over tan(half the vertical FoV).
    f = (height/np.tan((np.pi/180)*float(cam.fov[0])/2))*0.5
    K = np.array([[f, 0, cx],[0,f,cy],[0,0,1]])
    # Pixel coordinates counted down from (width, height), i.e. both image axes
    # reversed — presumably to match PyTorch3D's +X-left / +Y-up screen space (TODO confirm).
    pixel_list = (np.array(np.meshgrid(width-np.arange(width),height-np.arange(height),[0]))[:,:,:,0]).reshape((3,-1)).T

    camera_rays = (pixel_list - K[:,2])/np.diag(K)
    camera_rays[:,-1] = 1
    # Copy the pose to host memory first: jnp.array cannot read CUDA tensors.
    R_cam = cam.R.detach().cpu()
    T_cam = cam.T.detach().cpu()
    axis_angle = so3_log_map(R_cam)[0]
    trans = -R_cam[0]@T_cam[0]
    return jnp.array(camera_rays), jnp.array(axis_angle.numpy()), jnp.array(trans.numpy())
# get the real camera
camera_rays, axangl_true, trans_true = convert_pyt3dcamera(camera)

# Add noise to pose

In [None]:
# Optional seed for a reproducible perturbation. None leaves NumPy's global RNG
# untouched, so every run draws a new random pose.
POSE_NOISE_SEED = None
if POSE_NOISE_SEED is not None:
    np.random.seed(POSE_NOISE_SEED)

trans_eps = 0.5                  # maximum translation perturbation (world units)
rad_eps = (np.pi/180.0)*90       # full rotation range: 90 deg means -45 to +45

# Translation noise: uniformly random direction, magnitude uniform in [0, trans_eps).
trans_err = np.random.randn(3)
trans_err = trans_err/np.linalg.norm(trans_err)
trans_err_r = np.random.rand()

trans_err = trans_eps*trans_err*trans_err_r
trans_shift = np.array(trans_true) - trans_err

# Rotation noise: random axis, angle uniform in [-rad_eps/2, rad_eps/2).
angles = np.random.randn(3)
angles = angles/np.linalg.norm(angles)
angle_mag = (np.random.rand()-0.5)*rad_eps
R_I = np.array(axis_angle_to_matrix(torch.tensor(np.array(axangl_true))))
R_R = np.array(axis_angle_to_matrix(torch.tensor(angles*angle_mag))).T
R_C =  R_R @ R_I

axangl_init = np.array(so3_log_map(torch.tensor(R_C[None])))[0]
trans_init = np.array(R_R@np.array(trans_shift))

# Report the perturbation. The combined "pose error" is the geometric mean of the
# rotation (degrees) and the translation (% of the model's mean extent).
rand_rot = abs(angle_mag*(180.0/np.pi))
rand_trans = 100*(trans_err_r*trans_eps/t_model_scale)
init_pose_err = np.sqrt(rand_rot*rand_trans)
print('pose error of {:.1f}, random rotation of {:.1f} degrees and translation of {:.1f}%'.format(init_pose_err,rand_rot,rand_trans))

In [None]:
fig, axes = plt.subplots(2, 2)

# Top row: the PyTorch3D target image and depth.
axes[0, 0].imshow(target_image)
axes[0, 0].set_title('image')
vmin, vmax = np.nanmin(target_depth), np.nanmax(target_depth)
axes[0, 1].imshow(target_depth, vmin=vmin, vmax=vmax)
axes[0, 1].set_title('depth')

# Bottom row: fuzzy-metaball render from the perturbed initial pose.
est_depth, est_probs = render_jit(mean,prec,weight_log,camera_rays,axangl_init,trans_init,beta1/obj_scale,beta2/obj_scale,beta3)
est_alpha = jnp.tanh(beta4*(jnp.exp(est_probs).sum(0)+beta5) )*0.5 + 0.5
est_depth = np.array(est_depth)
# Hide depth wherever the render is mostly transparent.
est_depth[est_alpha < 0.5] = np.nan
axes[1, 0].imshow(est_alpha.reshape(image_size), cmap='Greys')
axes[1, 0].set_title('FM alpha')
axes[1, 1].imshow(est_depth.reshape(image_size), vmin=vmin, vmax=vmax)
axes[1, 1].set_title('FM depth')

for ax in axes.flat:
    ax.axis('off')
fig.tight_layout()

# Solve for camera pose

In [None]:
def error_func(est_depth,est_alpha,true_depth,b4,b5):
    """Pose-fitting loss: normalised depth error plus silhouette cross-entropy.

    Args:
        est_depth: rendered per-pixel depth (flattened).
        est_alpha: second output of the renderer. Despite the name, these are
            not alphas yet: they are mapped to alpha below via
            ``tanh(b4*(exp(est_alpha).sum(0)+b5))*0.5+0.5`` (summed over axis 0;
            presumably per-component log-weights of shape (components, pixels) — TODO confirm).
        true_depth: target per-pixel depth (flattened), NaN on the background.
        b4, b5: slope and offset of the tanh that turns summed weights into alpha.

    Returns:
        Scalar loss. It is the mean squared depth error plus the mean binary
        cross-entropy between the estimated alpha and the target silhouette.
        The depth error is normalised by the mean valid target depth and is
        zero wherever either depth is NaN.
    """
    cond = jnp.isnan(est_depth) | jnp.isnan(true_depth)
    # Fraction of pixels where both depths are defined. `cond` is already
    # boolean, so negate it directly (jnp.isnan(cond) would always be False).
    valid_depth_frac = (~cond).sum()/cond.shape[0]
    # Mean target depth over the valid pixels only.
    avg_depth = jnp.where(cond,0,true_depth).mean()/valid_depth_frac
    err = (est_depth - true_depth)/avg_depth
    # Mask before squaring so NaN pixels contribute neither loss nor gradient.
    depth_loss = (jnp.where(cond,0,err)**2).mean()

    true_alpha = ~jnp.isnan(true_depth)
    est_alpha = jnp.tanh(b4*(jnp.exp(est_alpha).sum(0)+b5) )*0.5 + 0.5
    # Keep both log() terms finite.
    est_alpha = jnp.clip(est_alpha,1e-6,1-1e-6)
    mask_loss = -((true_alpha * jnp.log(est_alpha)) + (~true_alpha)*jnp.log(1-est_alpha))

    return depth_loss + mask_loss.mean()

def objective(params,depth):
    """Render with `params` and score the result against the target `depth`.

    `params` is the list [means, prec, weights_log, camera_rays, axangl, trans,
    beta1, beta2, beta3, beta4, beta5]; entries 4 and 5 are the camera pose.
    """
    means,prec,weights_log,camera_rays,axangl,trans, beta1,beta2,beta3,beta4,beta5 = params
    render_res = render_jit(means,prec,weights_log,camera_rays,axangl,trans,beta1,beta2,beta3)
    return error_func(render_res[0],render_res[1],depth,beta4,beta5)
# Loss value and gradient w.r.t. every entry of `params`.
grad_render3 = jax.jit(jax.value_and_grad(objective))

In [None]:
# `optimizers` moved from jax.experimental to jax.example_libraries (JAX 0.2.25)
# and was later removed from jax.experimental; support both layouts.
try:
    from jax.example_libraries import optimizers
except ImportError:
    from jax.experimental import optimizers
from util import DegradeLR

# Number of optimization steps. Typically only a few hundred are needed,
# because DegradeLR triggers an early exit.
Niter = 2000

loop = tqdm(range(Niter))

# DegradeLR babysits the learning rate: step_func is the step-size schedule
# handed to the optimizer, and add(loss) returning True ends the loop.
adjust_lr = DegradeLR(1e-3,0.5,50,10,-1e-3)
opt_init, opt_update, opt_params = optimizers.momentum(adjust_lr.step_func,0.9)

# Only the camera pose is optimized: [axis-angle rotation, translation].
opt_state = opt_init([axangl_init,trans_init])

losses = []
jax_tdepth = jnp.array(target_depth.ravel())

for i in loop:
    axangl_cur, trans_cur = opt_params(opt_state)

    val,g = grad_render3([mean,prec,weight_log,camera_rays,axangl_cur,trans_cur,beta1/obj_scale,beta2/obj_scale,beta3,beta4,beta5],jax_tdepth)

    # g[4] / g[5] are the gradients w.r.t. rotation / translation. The
    # translation gradient is scaled by |trans|^2, so its step grows with the
    # distance from the origin; the rotation gradient is used as-is.
    trans_norm = jnp.linalg.norm(trans_cur)
    pose_grads = [g[4], g[5]*trans_norm*trans_norm]

    opt_state = opt_update(i, pose_grads, opt_state)

    val = float(val)
    losses.append(val)
    if adjust_lr.add(val):
        break
    # Print the losses
    loop.set_description("total_loss = %.3f" % val)

In [None]:
# Pull the optimised [axis-angle, translation] pair out of the optimizer state.
final_pose = opt_params(opt_state)
axangl_final, trans_final = final_pose

In [None]:
fig, ax = plt.subplots()
ax.set_title('convergence plot')
# One dot per iteration, no connecting lines.
ax.plot(losses, marker='.', lw=0, ms=5, alpha=0.5)
ax.set_xlabel('iteration')
ax.set_ylabel('log loss')

# Visualize Results

In [None]:
def render_fm_pose(axangl, trans):
    """Render the fuzzy-metaball model from the given camera pose.

    Returns:
        alpha: per-pixel alpha (flattened jnp array) from the tanh mapping of
            the renderer's second output.
        depth: per-pixel depth (flattened numpy array), NaN where alpha < 0.5.
    """
    depth, probs = render_jit(mean,prec,weight_log,camera_rays,axangl,trans,beta1/obj_scale,beta2/obj_scale,beta3)
    alpha = jnp.tanh(beta4*(jnp.exp(probs).sum(0)+beta5) )*0.5 + 0.5
    depth = np.array(depth)
    depth[alpha < 0.5] = np.nan
    return alpha, depth

vmin,vmax = np.nanmin(target_depth),np.nanmax(target_depth)
_, depth_true = render_fm_pose(axangl_true, trans_true)
alpha_init, depth_init = render_fm_pose(axangl_init, trans_init)
alpha_final, depth_final = render_fm_pose(axangl_final, trans_final)

fig, axes = plt.subplots(2, 3)
# Column 0: the target image, and the FM depth rendered at the true pose.
axes[0, 0].imshow(target_image)
axes[0, 0].set_title('image')
axes[1, 0].imshow(depth_true.reshape(image_size),vmin=vmin,vmax=vmax)
axes[1, 0].set_title('true pose')
# Columns 1-2: FM alpha (top) and depth (bottom) at the initial and final poses.
for col, (label, alpha, depth) in enumerate([('init', alpha_init, depth_init), ('final', alpha_final, depth_final)], start=1):
    axes[0, col].imshow(alpha.reshape(image_size),cmap='Greys')
    axes[0, col].set_title(label + ' FM alpha')
    axes[1, col].imshow(depth.reshape(image_size),vmin=vmin,vmax=vmax)
    axes[1, col].set_title(label + ' FM depth')
for ax in axes.flat:
    ax.axis('off')
fig.tight_layout()
# To export the figure:
#fig.savefig('pose_est.pdf',facecolor=fig.get_facecolor(), edgecolor='none',bbox_inches='tight')

In [None]:
# Compare the recovered pose with the ground truth.
q1 = axis_angle_to_quaternion(torch.tensor(np.array(axangl_true))[None])[0]
q2 = axis_angle_to_quaternion(torch.tensor(np.array(axangl_final))[None])[0]
# Rotation error = angle of the relative rotation, 2*acos(|<q1,q2>|). The abs()
# accounts for q and -q encoding the same rotation (equivalent to taking the
# smaller of acos(d) and acos(-d)).
rot_err = float((180.0/np.pi)*2*torch.acos(torch.clamp((q1 * q2).sum().abs(),-1,1)))

R1 = np.array(quaternion_to_matrix(q1))
R2 = np.array(quaternion_to_matrix(q2))
t_norm = np.linalg.norm(R1.T@np.array(trans_true)-R2.T@np.array(trans_final))
# Translation error as a percentage of the model's mean extent. It gets its own
# name so it does not overwrite `trans_err`, the perturbation vector drawn earlier.
trans_err_pct = 100*t_norm/t_model_scale

# Geometric mean of rotation (degrees) and translation (%), as for the initial error.
pose_err = np.sqrt(rot_err*trans_err_pct)
print('init. pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(init_pose_err,rand_rot,rand_trans))
print('final pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(pose_err,rot_err,trans_err_pct))
