# Run CO3D Sequence for Shape from Silhouette

In [None]:
import sys, os
import numpy as np
import pandas as pd

In [None]:
# Dataset root plus the CO3D category and sequence id to reconstruct.
data_dir, type_f, co3d_seq = 'data', 'teddybear', '379_44778_89217'
# Folder holding this sequence's files, and the folder the video frames are written to.
co3d_seq_folder = os.path.join(data_dir, type_f, co3d_seq)
output_folder = '_'.join((type_f, co3d_seq))

In [None]:
# Frame annotations for the whole category, stored as gzip-compressed JSON.
annotation_file = os.path.join(data_dir, type_f, 'frame_annotations.jgz')
df = pd.read_json(annotation_file, compression={'method': 'gzip'})

## Load Data

In [None]:
# NUM_MIXTURE: number of components in the randomly initialised mixture.
# SCALE: integer factor by which image sizes and camera intrinsics are reduced.
NUM_MIXTURE, SCALE = 40, 4

In [None]:
import skimage.io as io
import matplotlib.pyplot as plt
import transforms3d 
import skimage.transform as sktrans
# Restrict the annotations to the selected sequence.
# NOTE(review): comparing against an int presumably relies on pandas having parsed
# sequence names such as '379_44778_89217' as integers -- confirm.
df2 = df[df.sequence_name == int(co3d_seq.replace('_',''))]

# Per-frame accumulators, appended to in frame_number order by the loop below.
images = []
masks = []
rotations = []
rotation_mats = []
translations = []
fls = []
pps = []
sizes = []
ground_truths = []
crays_set = []
ground_images = []
for row in  df2.sort_values('frame_number').itertuples():
    # Fields are read by position; presumably frame number, image record, mask
    # record and viewpoint record -- TODO confirm against the annotation columns.
    fn, imgd, maskd, view = row[2],row[4],row[6],row[7]
    # Drop everything before the sequence id so the paths are relative to data_dir/type_f.
    maskd = maskd['path'][maskd['path'].index(co3d_seq):]
    imgd = imgd['path'][imgd['path'].index(co3d_seq):]
    
    Rmat = np.array(view['R'])
    Tvec = np.array(view['T'])
    # The stored translation is -R @ T rather than the annotated T.
    Tvec = -Rmat @ Tvec
    img = io.imread(os.path.join(data_dir,type_f,imgd))
    mask = io.imread(os.path.join(data_dir,type_f,maskd))
    images.append(img)
    masks.append(mask)
    # Keep the rotation both as one vector (axis scaled by angle) and as a matrix.
    v,s = transforms3d.axangles.mat2axangle(Rmat)
    rotations.append(v*s)
    rotation_mats.append(Rmat)
    translations.append(Tvec)
    
    fl = np.array(view['focal_length'])
    pp = np.array(view['principal_point'])
    
    # Half of the reversed mask shape; assumes the mask is 2-D so this is
    # (width, height) / 2 -- TODO confirm.
    half_image_size_wh_orig = np.array(list(reversed(mask.shape))) / 2.0
    # principal point and focal length in pixels
    principal_point_px = (
        -1.0 * (pp - 1.0) * half_image_size_wh_orig
    )
    focal_length_px = fl * half_image_size_wh_orig

    fls.append(focal_length_px)
    pps.append(principal_point_px)
    
    
    sizeA = np.array(mask.shape)
    PX,PY = reversed(sizeA)
    FLX, FLY = focal_length_px
    CX,CY = principal_point_px
    #print(PX,PY,FLX,FLY,CX,CY)
    
    # Reduce the image size and the intrinsics by SCALE.
    PY = PY//SCALE
    PX = PX//SCALE
    sizes.append((PX,PY))

    cx = CX/SCALE
    cy = CY/SCALE
    fx = FLX/SCALE
    fy = FLY/SCALE
    K = np.array([[fx, 0, cx],[0,fy,cy],[0,0,1]])
    # One row (x, y, 0) per pixel; x counts down from PX-1 to 0, y from PY-1 to 0.
    pixel_list = (np.array(np.meshgrid(PX-np.arange(PX)-1,PY-np.arange(PY)-1,[0]))[:,:,:,0]).reshape((3,-1)).T

    # Subtract (cx, cy, 1), divide by (fx, fy, 1), then force the last component to 1.
    camera_rays = (pixel_list - K[:,2])/np.diag(K)
    camera_rays[:,-1] = 1
    crays_set.append(camera_rays)
    # Mask and image resized to the reduced resolution (PY rows, PX columns).
    ground_truths.append(sktrans.resize(mask,(PY,PX)))
    ground_images.append(sktrans.resize(img,(PY,PX)))

In [None]:
# gradients can be sensitive to scale: this factor rescales the cameras so that
# the mean norm of their translations becomes 2.7
mean_translation_norm = np.linalg.norm(translations, axis=1).mean()
SCALE_MUL_FACTOR = 2.7 / mean_translation_norm

In [None]:
# get a rough init to make sure everything loaded right
import trimesh
import sklearn.mixture as mixture

# Vertices of the point cloud stored alongside the sequence.
pt_cld = trimesh.load(os.path.join(co3d_seq_folder, 'pointcloud.ply')).vertices

# Fit a 40-component Gaussian mixture to a random subset of at most 10000 points.
idx2 = np.arange(pt_cld.shape[0])
np.random.shuffle(idx2)
fit_points = pt_cld[idx2[:10000]]
clf = mixture.GaussianMixture(40)
clf.fit(fit_points)

# Scale estimate: per-axis standard deviation of the weight-scaled means, averaged.
weighted_means = clf.weights_[:, None] * clf.means_
obj_scale_true = weighted_means.std(0).mean()

In [None]:
import os
# Set ahead of `import jax`; presumably only honoured if present before JAX
# initialises its backend -- keep this ordering.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
#jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import fm_render

hyperparams = fm_render.hyperparams_models

# The first four hyperparameters are exponentiated; beta5 is the negated fourth one.
beta2, beta3, beta4, beta5_abs = (
    jnp.float32(np.exp(hyperparams[k])) for k in range(4)
)
beta5 = -beta5_abs

render_jit = jax.jit(fm_render.render_func)
# more reliable than using pose data for a prior
obj_scale = obj_scale_true

In [None]:
# test render
means, prec, weights = (
    jnp.array(arr) for arr in (clf.means_, clf.precisions_cholesky_, clf.weights_)
)
weights_log = jnp.log(weights+1e-6)

# Render the fitted mixture from every camera and keep the alpha image of each.
test_alphas = []
for i, (axangl, trans, camera_rays, (PX, PY)) in enumerate(
        zip(rotations, translations, crays_set, sizes)):
    res_img, res_p, est_alpha = render_jit(
        means, prec, weights_log, camera_rays, axangl, trans,
        beta2/obj_scale_true, beta3, beta4, beta5)

    # first render output with low-alpha pixels set to NaN
    res_imgA = np.array(res_img)
    res_imgA[est_alpha < 0.5] = np.nan
    test_alphas.append(est_alpha.reshape((PY, PX)))

In [None]:
import matplotlib.pyplot as plt
from util import image_grid
# Show the sanity-check alpha renders as a grid; the two positional numbers are
# presumably rows and columns -- confirm against util.image_grid.
image_grid(test_alphas,8,9,rgb=False)

In [None]:
# random init settings

init_scale = 0.25

rand_sphere_size = 30.0*init_scale # inverse size
rand_sphere_var = 0.012/init_scale # actual distribution
clip_d = 2.0
scale_mul = 0.5
weight_eps = 0.3
prec_eps = 0.02

# these are all unscaled, 0 mean, standard size, etc.
rand_mean_base = np.random.multivariate_normal(mean=[0,0,0],cov=np.identity(3),size=NUM_MIXTURE)
# Samples whose norm exceeds clip_d are shrunk back onto the sphere of radius clip_d.
clipped_mean_base = (np.minimum(clip_d,np.linalg.norm((rand_mean_base),axis=1))/np.linalg.norm((rand_mean_base),axis=1))[:,None] * rand_mean_base
rand_prec_base = np.array([np.identity(3) for _ in range(NUM_MIXTURE)])

# these get shifted to the problem at hand
# Means are centred on the point-cloud mean.
rand_mean = np.mean(pt_cld.astype(float),0) + np.sqrt(rand_sphere_var)*clipped_mean_base
# Log of near-uniform weights: (1 + noise)/NUM_MIXTURE, with the noise floored at -0.99.
rand_weight_log = np.log(( np.ones(NUM_MIXTURE) + np.maximum(-0.99,weight_eps*np.random.randn(NUM_MIXTURE)) )/NUM_MIXTURE) 
# Identity matrices plus noise on every entry (floored at -0.99), scaled by rand_sphere_size.
rand_prec = rand_sphere_size*(rand_prec_base  + np.maximum(-0.99,prec_eps*np.random.randn(*rand_prec_base.shape)))


# Render the random initialisation from every camera as a visual check.
init_alphas = []
for i in range(len(rotations)):
    axangl = rotations[i]
    trans = translations[i]
    camera_rays = crays_set[i]
    PX,PY = sizes[i]
    res_img,res_p,est_alpha = render_jit(rand_mean,rand_prec,rand_weight_log,camera_rays,axangl,trans,beta2/obj_scale_true,beta3,beta4,beta5)

    # first render output with low-alpha pixels set to NaN
    res_imgA = np.array(res_img)
    res_imgA[est_alpha < 0.5] = np.nan
    init_alphas.append(est_alpha.reshape((PY,PX)))
image_grid(init_alphas,6,6,rgb=False)

In [None]:
# For every frame, pair each rotated camera ray with that frame's translation.
total_ray_set = []
for i, (rmat, trans, camera_rays) in enumerate(
        zip(rotation_mats, translations, crays_set)):
    # rays multiplied by the transpose of the frame's rotation matrix
    rotated_rays = camera_rays @ np.array(rmat).T
    # the frame's translation repeated once per ray
    ray_origins = np.tile(trans[None], (rotated_rays.shape[0], 1))

    # shape (num_rays, 2, 3): entry 0 is the rotated ray, entry 1 the translation
    rays_trans = np.stack([rotated_rays, ray_origins], 1)
    total_ray_set.append(np.array(rays_trans))

In [None]:
# Stack the per-frame ray arrays along the first axis into a single array of all rays.
all_rays = jnp.vstack(total_ray_set)

In [None]:
render_jit_ray = jax.jit(fm_render.render_func_rays)
# The last frame's rays are the final width*height rows of all_rays.
last_w, last_h = sizes[-1]
last_img_size = last_w * last_h
last_frame_rays = all_rays[-last_img_size:]
res_img, res_p, est_alpha = render_jit_ray(
    means, prec, weights_log, last_frame_rays,
    beta2/obj_scale_true, beta3, beta4, beta5)

In [None]:
# Left: alpha rendered from the stacked rays of the last frame; right: its target mask.
plt.subplot(1,2,1)
plt.imshow(est_alpha.reshape((sizes[-1][1],sizes[-1][0])))
plt.subplot(1,2,2)
plt.imshow(ground_truths[-1])

In [None]:
def objective(params,camera_rays,beta2,beta3,beta4,beta5,true_alpha):
    """Mean binary cross-entropy between the rendered alpha and ``true_alpha``.

    ``params`` is unpacked as ``(means, prec, weights_log)`` and rendered along
    ``camera_rays`` with ``fm_render.render_func_rays``; the third render output
    is taken as the predicted alpha. The prediction is clipped away from 0 and 1
    before the logarithms are taken.
    """
    eps = 1e-6  # clip margin that keeps both logarithms finite
    means, prec, weights_log = params
    rendered = fm_render.render_func_rays(
        means, prec, weights_log, camera_rays, beta2, beta3, beta4, beta5)

    alpha = jnp.clip(rendered[2], eps, 1 - eps)
    inside_term = true_alpha * jnp.log(alpha)
    outside_term = (1 - true_alpha) * jnp.log(1 - alpha)
    return jnp.mean(-(inside_term + outside_term))


# Returns the loss together with its gradient with respect to ``params``.
grad_render3 = jax.value_and_grad(objective)

In [None]:
from jax.example_libraries import optimizers
from tqdm.notebook import tqdm
from util import DegradeLR

Nepoch = 10
batch_size = 24000
OPT_SCALE = SCALE_MUL_FACTOR

# Shape (1, 2, 3) multiplier for ray batches: entry 0 of each ray pair is left
# unchanged and entry 1 (the translation) is scaled by OPT_SCALE.
vecM = jnp.array([[1,1,1],[OPT_SCALE,OPT_SCALE,OPT_SCALE]])[None]

train_size = all_rays.shape[0]
# Batches per epoch, rounded to the nearest integer.
Niter_epoch = int(round(train_size/batch_size))

# Round to the nearest integer and return an int.
def irc(x): return int(round(x))

# babysit learning rates
adjust_lr = DegradeLR(1e-1,0.5,irc(Niter_epoch*0.25),irc(Niter_epoch*0.1),-1e-4)

# Adam, with its step size supplied by adjust_lr.step_func.
opt_init, opt_update, opt_params = optimizers.adam(adjust_lr.step_func)
# Optimisation starts from the random initialisation.
tmp = [rand_mean,rand_prec,rand_weight_log]
opt_state = opt_init(tmp)

# All target masks flattened and concatenated frame by frame into one float32 vector.
all_sils = jnp.hstack([_.ravel() for _ in ground_truths]).astype(jnp.float32)

# losses: one value per step; opt_configs: parameter snapshot taken before each step.
losses = []
opt_configs = []
outer_loop = tqdm(range(Nepoch), desc=" epoch", position=0)

# Ray indices; reshuffled in place at the start of every epoch.
rand_idx = np.arange(train_size)

def inner_iter(j_idx,rand_idx_local,opt_state):
    """Run one optimiser step on batch ``j_idx`` of the shuffled ray indices.

    Returns the loss value for that batch and the updated optimiser state.
    """
    # Indices of the j_idx-th run of batch_size entries in rand_idx_local.
    idx = jax.lax.dynamic_slice(rand_idx_local,[j_idx*batch_size],[batch_size])

    p = opt_params(opt_state)
    # Means are multiplied and precisions divided by OPT_SCALE before rendering;
    # vecM applies the same factor to the translation half of each ray pair.
    val,g = grad_render3([p[0]*OPT_SCALE,p[1]/OPT_SCALE,p[2]],vecM*all_rays[idx],beta2/(OPT_SCALE*obj_scale),beta3,beta4,beta5,all_sils[idx])   
    # NOTE(review): `i` is not a parameter of this function; it resolves to the
    # module-level epoch counter. Because the function is wrapped in jax.jit,
    # that value is presumably read when the function is traced rather than on
    # every call -- confirm the step index given to the optimiser is intended.
    opt_state = opt_update(i, g, opt_state)
    return val, opt_state
# Compiled version used by the training loop.
jax_iter = jax.jit(inner_iter)
# Set once adjust_lr.add() returns True, so that both loops are left.
done = False
for i in outer_loop:
    # New random ray order for this epoch.
    np.random.shuffle(rand_idx)
    rand_idx_jnp = jnp.array(rand_idx)

    for j in tqdm(range(Niter_epoch), desc=" iteration", position=1, leave=False):
        # Snapshot the parameters before the step; the snapshots are rendered later.
        opt_configs.append(list(opt_params(opt_state)))
        val,opt_state = jax_iter(j,rand_idx_jnp,opt_state)
        val = float(val)
        losses.append(val)

        # Feed the loss to the learning-rate helper; a True result ends training.
        if adjust_lr.add(val):
            done = True
            break
        outer_loop.set_description(" loss {:.3f}".format(val))
    if done:
        break

In [None]:
# Loss recorded at every optimisation step.
plt.plot(losses)

In [None]:
# Parameters held by the optimiser after training; the OPT_SCALE factors are
# applied separately wherever these are rendered.
final_mean, final_prec, final_weight_log = opt_params(opt_state)

In [None]:
# Re-render every frame with the optimised parameters, applying the same
# OPT_SCALE factors that were used during optimisation.
scaled_mean = final_mean*OPT_SCALE
scaled_prec = final_prec/OPT_SCALE
scaled_beta2 = beta2/(OPT_SCALE*obj_scale)

result_alphas = []
for i, (axangl, trans, camera_rays, (PX, PY)) in enumerate(
        zip(rotations, translations, crays_set, sizes)):
    res_img, res_p, est_alpha = render_jit(
        scaled_mean, scaled_prec, final_weight_log, camera_rays, axangl,
        trans*OPT_SCALE, scaled_beta2, beta3, beta4, beta5)

    # first render output with low-alpha pixels set to NaN
    res_imgA = np.array(res_img)
    res_imgA[est_alpha < 0.5] = np.nan
    result_alphas.append(est_alpha.reshape((PY, PX)))

In [None]:
# Left: alpha of the last re-rendered frame; right: the NaN-masked first render
# output of that same frame (res_imgA, PY and PX are left over from the loop).
plt.subplot(1,2,1)
plt.imshow(result_alphas[-1])
plt.subplot(1,2,2)
plt.imshow(res_imgA.reshape((PY,PX)))
plt.colorbar()

In [None]:
# Grid of the alpha images rendered from the optimised parameters.
image_grid(result_alphas,rows=3,cols=5,rgb=False)

In [None]:
# Grid of the resized target masks, using the same layout for comparison.
image_grid(ground_truths,rows=3,cols=5,rgb=False)

In [None]:
max_frame = len(rotations)
# Number of alternating forward/backward sweeps over the frames in the video.
FWD_BCK_TIMES = 4
# First iteration whose loss is within 10% of the best loss recorded.
THRESH_IDX = np.where(np.array(losses)/min(losses) < 1.1)[0][0]
USE_FIRST_N_FRAC = THRESH_IDX/len(losses)
N_FRAMES = max_frame*FWD_BCK_TIMES
# Index of the last optimiser snapshot shown in the video. Clamped at 0: when
# THRESH_IDX is 0 the unclamped value is -1, which would make linspace emit
# negative indices that silently wrap around to the end of opt_configs.
last_config_idx = max(0, int(np.floor(len(opt_configs)*USE_FIRST_N_FRAC-1)))
# One snapshot index per video frame, spread evenly over the retained iterations.
opt_to_use = np.round(np.linspace(0,last_config_idx,N_FRAMES)).astype(int)

In [None]:
# Fraction of the recorded iterations that precede THRESH_IDX (displayed as the cell output).
THRESH_IDX/len(losses)

In [None]:
# Loss curve up to, but not including, THRESH_IDX.
plt.plot(losses[:THRESH_IDX])

In [None]:
# Ping-pong frame order: forward on even sweeps, reversed on odd sweeps.
frame_list = list(range(max_frame))
frame_idxs = []
for i in range(FWD_BCK_TIMES):
    frame_idxs.extend(frame_list if i % 2 == 0 else reversed(frame_list))

In [None]:
# Render one video frame per (camera, optimiser snapshot) pair.
full_res_alpha, full_res_depth = [], []
video_beta2 = beta2/(OPT_SCALE*obj_scale)
for r_idx, c_idx in zip(frame_idxs, opt_to_use):
    snap_mean, snap_prec, snap_weight_log = opt_configs[c_idx]
    PX_F, PY_F = sizes[r_idx]
    frame_shape = (PY_F, PX_F)

    est_depth, res_p, est_alpha = render_jit(
        OPT_SCALE*snap_mean, snap_prec/OPT_SCALE, snap_weight_log,
        crays_set[r_idx], rotations[r_idx], OPT_SCALE*translations[r_idx],
        video_beta2, beta3, beta4, beta5)
    # copies as numpy arrays so the depth can be masked in place
    est_alpha, est_depth = np.array(est_alpha), np.array(est_depth)

    # blank the depth wherever the rendered alpha is below 0.5
    est_depth[est_alpha < 0.5] = np.nan
    full_res_alpha.append(est_alpha.reshape(frame_shape))
    full_res_depth.append(est_depth.reshape(frame_shape))
    print('.',end='')

In [None]:
# Start from an empty output folder: delete any previous one, then recreate it.
import shutil

stale_output = os.path.exists(output_folder)
if stale_output:
    shutil.rmtree(output_folder)
os.mkdir(output_folder)

In [None]:
# Colour range for the depth panels: 5th to 95th percentile of every non-NaN
# depth value across all rendered video frames.
vecr = np.concatenate([depth.ravel() for depth in full_res_depth])
vecr = vecr[np.logical_not(np.isnan(vecr))]
vmin, vmax = np.percentile(vecr, [5, 95])
vscale = vmax - vmin

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
start_f = 0
avg_size = np.mean(sizes,axis=0)
fsize = irc(96/SCALE)

font = ImageFont.truetype('Roboto-Regular.ttf', size=fsize)
# plt.get_cmap is used instead of matplotlib.cm.get_cmap, which newer
# Matplotlib releases no longer provide.
cmap = plt.get_cmap('viridis')
cmap2 = plt.get_cmap('magma')

white = (255, 255, 255)
# vertical positions of the one-line and two-line captions near the bottom edge
caption_y1 = irc(avg_size[1]-fsize*1.5)
caption_y2 = irc(avg_size[1]-fsize*2.5)

# Each video frame is five panels side by side: image, target mask, per-pixel
# loss, estimated mask and estimated depth.
for i,mask_res in enumerate(full_res_alpha):
    r_idx = frame_idxs[i]
    est_mask_rgb = np.tile(mask_res[:,:,None],(1,1,3))
    img_gt_mask = np.tile(ground_truths[r_idx][:,:,None],(1,1,3))

    true_alpha = ground_truths[r_idx]

    # same clipped cross-entropy as the training objective, kept per pixel
    est_alpha = jnp.clip(mask_res,1e-6,1-1e-6)
    mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))
    loss_viz = cmap2(0.25*mask_loss)[:,:,:3]

    # depth normalised with the precomputed vmin/vscale range
    depth = cmap((full_res_depth[i]-vmin)/vscale)[:,:,:3]
    img2 = np.concatenate((ground_images[r_idx],img_gt_mask,loss_viz,est_mask_rgb, depth), axis=1)
    int_img = np.round(img2*255).astype(np.uint8)
    pil_img = Image.fromarray(int_img)
    d1 = ImageDraw.Draw(pil_img)
    # The former `ha=` keyword was dropped: it is a matplotlib argument, not a
    # documented ImageDraw.text parameter.
    d1.text((avg_size[0]*1.1, irc(fsize*0.1)), "Iteration: {:3d}".format(opt_to_use[i]), font=font, fill=white)
    d1.text((avg_size[0]*1.3, caption_y1), "Target Mask", font=font, fill=white)
    d1.text((avg_size[0]*2.4, caption_y1), "Loss", font=font, fill=white, align='center')
    d1.text((avg_size[0]*3.3, caption_y2), "Estimated\nMask", font=font, fill=white, align='center')
    d1.text((avg_size[0]*4.3, caption_y2), "Estimated\nDepth", font=font, fill=white, align='center')

    img3 = np.array(pil_img)

    io.imsave('{}/{:03d}.jpg'.format(output_folder,i),img3,quality=95)

In [None]:
# Show the last composed video frame without axes.
plt.figure(figsize=(18,8))
plt.imshow(img3)
plt.axis('off')

In [None]:
import shutil
import subprocess

# Remove a previous video so ffmpeg does not stop to ask about overwriting it.
video_path = '{}.mp4'.format(output_folder)
if os.path.exists(video_path):
    os.remove(video_path)

# Keep using the original hard-coded binary when it exists; otherwise fall back
# to whichever ffmpeg is found on PATH.
ffmpeg_bin = '/usr/local/bin/ffmpeg'
if not os.path.exists(ffmpeg_bin):
    ffmpeg_bin = shutil.which('ffmpeg') or ffmpeg_bin

# Passed as an argument list and run without a shell, so neither the folder
# name nor the filter expression needs shell quoting.
ffmpeg_cmd = [
    ffmpeg_bin,
    '-framerate', '24',
    '-i', '{}/%03d.jpg'.format(output_folder),
    # pad width and height up to the next even number
    '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
    '-c:v', 'h264',
    '-pix_fmt', 'yuv420p',
    video_path,
]
try:
    ffmpeg_status = subprocess.call(ffmpeg_cmd)
except OSError as err:
    # binary missing or not executable: report it instead of aborting the notebook
    ffmpeg_status = None
    print('could not run ffmpeg: {}'.format(err))
else:
    if ffmpeg_status != 0:
        print('ffmpeg exited with status {}'.format(ffmpeg_status))