# Run CO3D Sequence for Shape from Silhouette

In [None]:
import sys, os, glob
import numpy as np
import pandas as pd

import skimage.io as sio
import matplotlib.pyplot as plt
import pathlib

In [None]:
import os
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
#jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render

# JIT-compiled version of the project's ray renderer (fm_render.render_func_rays).
# NOTE(review): `render_jit_ray` is not referenced anywhere below — the render
# cells use `render_jit`, which jits the same function again. Candidate for removal.
render_jit_ray = jax.jit(fm_render.render_func_rays)



In [None]:
# --- Input / output ----------------------------------------------------------
# CO3D sequence directory. Frames are read from <dataset_dir>/images; the
# frame_annotations.jgz file is looked up in the parent directory.
dataset_dir = 'teddybear/34_1479_4753/'
input_folder = os.path.join(dataset_dir,'images')
# Sequence name = last path component; used to select this sequence's rows in
# the annotations and to name the output frame folder / video.
co3d_seq = os.path.split(dataset_dir.rstrip('/').lstrip('/'))[-1]
output_folder = os.path.join('tmp_out',co3d_seq)
# --- Mixture model and its random initialisation -----------------------------
NUM_MIXTURE = 40        # number of mixture components
shape_scale = 1.8       # target extent; SCALE_MUL_FACTOR = shape_scale / pt_cld_shape_scale
c_scale = 4.5           # weight of the colour term relative to the mask loss in `objective`
rand_sphere_size = 55   # initial isotropic precision (divided by pt_cld_shape_scale)
cov_scale = 1.2e-2      # covariance of the random initial means (scaled by pt_cld_shape_scale)
weight_scale = 1.1      # initial log-weights are log(weight_scale / NUM_MIXTURE)
# --- Optimisation --------------------------------------------------------------
LR_RATE = 0.08          # initial learning rate handed to DegradeLR
#beta2 = 21.4
#beta3 = 2.66
# Renderer hyperparameters shipped with fm_render; beta2 is divided by
# shape_scale at every render call below.
beta2, beta3 = jnp.array(fm_render.hyperparams)
Nepoch = 10             # maximum number of passes over all rays
batch_size = 50000      # rays per optimisation step
target_size = 125000//4 # desired pixel count of the downsampled frames

## Load Data

In [None]:
# Process every frame at a canonical resolution: choose the power-of-two
# downsampling factor whose resulting pixel count is closest, in log space,
# to `target_size` (ties go to the smaller factor).
image_patterns = ('*.jpg', '*.png')
in_files = sorted(
    path
    for pattern in image_patterns
    for path in glob.glob(os.path.join(input_folder, pattern))
)

# The first frame defines the original (height, width).
PYo, PXo = sio.imread(in_files[0]).shape[:2]
init_scale = np.prod([PYo, PXo])  # original pixel count

# downsampling factor -> pixel count after downsampling, factors 1, 2, 4, ..., 512
scales = {2**k: init_scale / (2**k)**2 for k in range(10)}
scale_to_use = min(scales, key=lambda factor: (abs(np.log(scales[factor] / target_size)), factor))

PY = int(round(PYo / scale_to_use))
PX = int(round(PXo / scale_to_use))
scale_to_use, PY, PX

In [None]:
import skimage
import skimage.io as sio
import skimage.transform as strans
# CO3D sequences miss some data: a frame whose image sums to zero (all black)
# is treated as missing.
#   valid_inputs[k] -> whether input file k is usable
#   file_map[k]     -> 1-based position of input file k among the valid frames
#   color_images    -> the valid frames resized to (PY, PX)
valid_inputs = []
color_images = []
file_map = {}

# Running count of valid frames; avoids re-summing `valid_inputs` on every
# frame (which made the loop quadratic in the number of frames).
n_valid = 0
for idx, file in enumerate(in_files):
    img = sio.imread(file)
    is_valid = img.sum() != 0
    valid_inputs.append(is_valid)
    if not is_valid:
        continue
    n_valid += 1
    file_map[idx] = n_valid
    color_images.append(strans.resize(img, (PY, PX)))

In [None]:
# Camera intrinsics from the CO3D frame annotations.
# The annotations file sits in the parent directory of the sequence folder.
# pathlib resolves that with or without a trailing '/' and keeps the leading '/'
# of absolute paths; splitting on '/' and dropping the last two parts did not.
annotations_path = pathlib.Path(dataset_dir).parent / 'frame_annotations.jgz'
df = pd.read_json(annotations_path, compression={'method': 'gzip'})
# NOTE(review): sequence names are compared as integers with '_' removed,
# presumably because pandas' type inference parses them as numbers — confirm.
df2 = df[df.sequence_name == int(co3d_seq.replace('_', ''))]
fls = []
pps = []
sizes = []
if len(df2) != len(valid_inputs):
    raise ValueError(
        'annotation rows ({}) do not match input images ({})'.format(len(df2), len(valid_inputs)))
for i, row in enumerate(df2.sort_values('frame_number').itertuples()):
    if not valid_inputs[i]:
        continue
    # Positional itertuples() fields (position 0 is the DataFrame index).
    imgd, view = row[4], row[7]
    fl = np.array(view['focal_length'])
    pp = np.array(view['principal_point'])
    sizeA = list(imgd['size'])
    # `size` is treated as (h, w); the intrinsics are in (x, y) order, hence the reversal.
    half_image_size_wh_orig = np.array(list(reversed(sizeA))) / 2.0

    if 'intrinsics_format' in view and view['intrinsics_format'] == 'ndc_isotropic':
        # Normalised by half of the shorter image side.
        rescale = half_image_size_wh_orig.min()
        principal_point_px = half_image_size_wh_orig - pp * rescale
        focal_length_px = fl * rescale
    else:
        # Normalised per axis by half the image size.
        principal_point_px = -1.0 * (pp - 1.0) * half_image_size_wh_orig
        focal_length_px = fl * half_image_size_wh_orig

    fls.append(focal_length_px)
    pps.append(principal_point_px)
    sizes.append(sizeA)

# All annotated frames must share one size (written so that NaN also fails).
if np.array(sizes).std(0).sum() != 0:
    raise ValueError('annotated image sizes differ between frames')
pp = np.array(pps).mean(0)
# A single focal length: averaged over frames and over the x/y components.
fl = np.array(fls).mean(0).mean()
# Principal point relative to the image size, averaged over both axes.
meanpp = (np.array([pp[1], pp[0]]) / np.array(sizes).mean(0)).mean()
# The ray construction below uses the image centre as principal point, so
# require the annotated one to be (almost) centred. `not <` also rejects NaN.
if not abs(meanpp - 0.5) < 1e-3:
    raise ValueError('principal point is not at the image centre (meanpp={})'.format(meanpp))
# Focal length in pixels of the downsampled (PY, PX) frames.
fl = fl / scale_to_use

In [None]:
import skimage.io as sio
import skimage.transform as sktrans
import transforms3d

# For every valid frame (in frame_number order) collect the target silhouette,
# the camera pose as (quaternion, translation) and the annotated depth map;
# silhouettes and depths are resized to (PY, PX).
masks = []
poses = []
depths = []


def _load_mask(mask_info):
    """Read a mask image, scale it into [0, 1] and resize it to (PY, PX)."""
    raw = np.clip(sio.imread(mask_info['path']) / 253, 0, 1)
    return sktrans.resize(raw, (PY, PX), anti_aliasing=True, order=0)


def _pose_from_view(view_info):
    """Return (quaternion of R^T, -R @ T) built from the annotation's R and T."""
    rot = np.array(view_info['R'])
    trans = -rot @ np.array(view_info['T'])
    return transforms3d.quaternions.mat2quat(rot.T), trans


def _load_depth(depth_info):
    """Decode a depth map, set invalid pixels to NaN and resize it to (PY, PX)."""
    encoded = sio.imread(depth_info['path'])
    valid_mask = sio.imread(depth_info['mask_path']).astype(float)
    # The image's raw bytes are reinterpreted as float16 depth values.
    decoded = np.frombuffer(encoded, dtype=np.float16).astype(np.float32).reshape(encoded.shape)
    depth = depth_info['scale_adjustment'] * decoded
    depth[~(depth > 0)] = np.nan
    depth[~(valid_mask > 0)] = np.nan
    return sktrans.resize(depth, (PY, PX), anti_aliasing=False, order=0)


for i, row in enumerate(df2.sort_values('frame_number').itertuples()):
    if not valid_inputs[i]:
        continue
    depthd, maskd, view = row[5], row[6], row[7]
    masks.append(_load_mask(maskd))
    poses.append(_pose_from_view(view))
    depths.append(_load_depth(depthd))

In [None]:
# Estimate the object's centre and spatial extent from the sequence's point
# cloud when it is available; otherwise use a fixed extent of 3.0 centred at
# the origin.
# (A 40-component GaussianMixture used to be fitted here on a 10k-point sample,
# but its result was never used; that dead, expensive fit has been removed.)
pointcloud_path = os.path.join(dataset_dir, 'pointcloud.ply')
if os.path.exists(pointcloud_path):
    import trimesh
    mesh_tri = trimesh.load(pointcloud_path)
    pt_cld = mesh_tri.vertices

    # Extent = 3 x the mean per-axis standard deviation of the points.
    pt_cld_shape_scale = float(pt_cld.std(0).mean()) * 3
    center = pt_cld.mean(0)
else:
    pt_cld_shape_scale = 3.0
    center = np.zeros(3, dtype=np.float32)

In [None]:
# Factor mapping the estimated scene extent (pt_cld_shape_scale) onto the
# target extent `shape_scale`. During training it multiplies the means and ray
# origins and divides the precisions.
SCALE_MUL_FACTOR = shape_scale/pt_cld_shape_scale
SCALE_MUL_FACTOR

In [None]:
height, width = PY, PX
cx = (PX - 1) / 2
cy = (PY - 1) / 2
K = np.array([[fl, 0, cx], [0, fl, cy], [0, 0, 1]])

# Pixel coordinates, one row per pixel in row-major (y, x) order. Both axes are
# mirrored (x runs width-1 .. 0, y runs height-1 .. 0); the third column is 0.
x_coords = np.arange(width)[::-1]
y_coords = np.arange(height)[::-1]
grid_x, grid_y = np.meshgrid(x_coords, y_coords)
pixel_list = np.stack([grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size, dtype=grid_x.dtype)]).T

# Back-project through the pinhole intrinsics K, then fix the z component to 1.
camera_rays = (pixel_list - K[:, 2]) / K.diagonal()
camera_rays[:, 2] = 1


def _rays_for_pose(quat, trans):
    """Return an (n_pixels, 2, 3) array: [rotated ray directions, pose translation per ray]."""
    directions = camera_rays @ transforms3d.quaternions.quat2mat(quat)
    # NOTE(review): presumably the camera centre in world coordinates (see pose loading).
    origins = np.tile(trans[None], (directions.shape[0], 1))
    return np.stack([directions, origins], 1)


cameras_list = [_rays_for_pose(quat, trans) for quat, trans in poses]

In [None]:
from util import image_grid

In [None]:
# Random initial mixture: means scattered around the scene centre, uniform
# (scaled) weights, identical isotropic precisions and random colour logits.
rand_mean = center + pt_cld_shape_scale * np.random.multivariate_normal(
    mean=[0, 0, 0], cov=cov_scale * np.identity(3), size=NUM_MIXTURE)
rand_weight_log = jnp.log(weight_scale * np.ones(NUM_MIXTURE) / NUM_MIXTURE)
rand_prec = jnp.array([np.identity(3) * rand_sphere_size / pt_cld_shape_scale] * NUM_MIXTURE)
rand_color = jnp.array(np.random.randn(NUM_MIXTURE, 3))

render_jit = jax.jit(fm_render.render_func_rays)

# Sanity check: render the first 36 views with the initial mixture and show
# their alphas as a 6x6 grid (depth is NaN where alpha < 0.5).
init_alphas = []
init_depths = []
for rays in cameras_list[:36]:
    depth, alpha, _, _ = render_jit(rand_mean, rand_prec, rand_weight_log, rays, beta2 / shape_scale, beta3)
    depth = np.array(depth)
    depth[alpha < 0.5] = np.nan
    init_alphas.append(alpha.reshape((PY, PX)))
    init_depths.append(depth.reshape((PY, PX)))

image_grid(init_alphas, 6, 6, rgb=False)

In [None]:
def objective(params, camera_rays, beta2, beta3, true_alpha, true_color):
    """Training loss for one batch of rays.

    The loss is the mean per-ray binary cross-entropy between rendered alpha
    and target silhouette, plus ``c_scale`` times the mean absolute colour
    error weighted by the target alpha.

    Args:
        params: [means, precisions, log-weights, colour logits].
        camera_rays: ray batch, passed unchanged to fm_render.render_func_rays.
        beta2, beta3: renderer hyperparameters.
        true_alpha: target silhouette value per ray (1-D).
        true_color: target colour per ray — presumably (n_rays, 3); TODO confirm.

    Returns:
        Scalar loss.
    """
    eps = 1e-7  # keeps both log() terms finite
    means, prec, weights_log, colors = params
    _, alpha, _, weights = fm_render.render_func_rays(means, prec, weights_log, camera_rays, beta2, beta3)
    # colour logits -> (0, 1), then blended per ray by the rendering weights
    pred_color = weights.T @ (jnp.tanh(colors) * 0.5 + 0.5)
    alpha = jnp.clip(alpha, eps, 1 - eps)
    bce = -((true_alpha * jnp.log(alpha)) + (1 - true_alpha) * jnp.log(1 - alpha))
    color_err = jnp.abs((true_color - pred_color) * true_alpha[:, None])
    return bce.mean() + c_scale * color_err.mean()
grad_render3 = jax.value_and_grad(objective)

In [None]:
import optax
from tqdm.notebook import tqdm
from util import DegradeLR

# Per-ray scaling used during training: index 0 of each ray (direction) is left
# unchanged, index 1 (origin) is multiplied by SCALE_MUL_FACTOR.
vecM = jnp.array([[1, SCALE_MUL_FACTOR]])[:, :, None]

# Rays of all valid frames stacked together: one training sample per ray.
all_rays = jnp.vstack(cameras_list)
train_size = all_rays.shape[0]
if train_size < batch_size:
    raise ValueError(
        'batch_size ({}) exceeds the number of training rays ({})'.format(batch_size, train_size))
# Whole batches only. `inner_iter` slices the shuffled indices with
# jax.lax.dynamic_slice, which clamps an out-of-range start index. A rounded-up
# final batch (as int(round(...)) could produce) would therefore silently reuse
# samples from the previous batch instead of covering new ones.
Niter_epoch = train_size // batch_size

def irc(x):
    """Return ``x`` rounded to the nearest integer, as a Python ``int``.

    Uses the built-in ``round``, so exact halves go to the even neighbour
    (``irc(2.5) == 2``, ``irc(3.5) == 4``).
    """
    nearest = round(x)
    return int(nearest)

# Learning-rate schedule from util.DegradeLR: its step_func drives Adam, and its
# add(loss) result is used in the training loop as the stop signal.
adjust_lr = DegradeLR(LR_RATE, 0.5, irc(Niter_epoch * 0.4), irc(Niter_epoch * 0.1), -1e-4)
optimizer = optax.adam(adjust_lr.step_func)

# Optimised parameters: [means, precisions, log-weights, colour logits].
tmp = [rand_mean, rand_prec, rand_weight_log, rand_color]
opt_state = optimizer.init(tmp)
params = tmp

# Training targets, flattened frame by frame (row-major within a frame), in the
# same way all_rays stacks the per-frame rays.
all_sils = jnp.hstack([sil.ravel() for sil in masks]).astype(jnp.float32)
# Target colours are gamma-encoded (** (1/2.2)); renders undo it with ** 2.2.
all_colors = jnp.hstack([img.ravel() for img in color_images]).astype(jnp.float32).reshape((-1, 3)) ** (1 / 2.2)

# Training-loop bookkeeping.
losses = []
opt_configs = []
outer_loop = tqdm(range(Nepoch), desc=" epoch", position=0)
rand_idx = np.arange(train_size)
def inner_iter(j_idx,rand_idx_local,opt_state,p):
    """One Adam step on the j_idx-th minibatch of this epoch's shuffled rays.

    Args:
        j_idx: minibatch number within the epoch.
        rand_idx_local: shuffled permutation of range(train_size) for this epoch.
        opt_state: current optax optimizer state.
        p: parameter list [means, precisions, log-weights, colour logits].

    Returns:
        (loss value, new opt_state, new parameter list).

    Reads module-level batch_size, SCALE_MUL_FACTOR, vecM, all_rays, all_sils,
    all_colors, beta2, beta3, shape_scale and optimizer. Under jax.jit
    (`jax_iter`) these are captured when the function is traced.
    """
    # batch_size consecutive entries of the permutation; dynamic_slice clamps
    # the start index if the slice would run past the end of the array.
    idx = jax.lax.dynamic_slice(rand_idx_local,[j_idx*batch_size],[batch_size])

    # The loss is evaluated in the rescaled space: means * S, precisions / S,
    # ray origins * S (via vecM), with S = SCALE_MUL_FACTOR.
    # NOTE(review): `g` is therefore the gradient w.r.t. the *scaled* means and
    # precisions, yet it is applied to the unscaled `p` below without a
    # chain-rule factor. Adam largely normalises away a constant positive
    # gradient scale, so this is probably benign — confirm it is intentional.
    val,g = grad_render3([p[0]*SCALE_MUL_FACTOR,p[1]/SCALE_MUL_FACTOR,p[2],p[3]],vecM*all_rays[idx],
                         beta2/(shape_scale),beta3,all_sils[idx],all_colors[idx])   
    updates, opt_state = optimizer.update(g, opt_state,p)
    p = optax.apply_updates(p, updates)
    return val, opt_state, p 
jax_iter = jax.jit(inner_iter)
# Main optimisation loop: up to Nepoch passes over all rays, stopping early
# when the learning-rate schedule says so.
done = False  # set once adjust_lr.add() signals a stop
for i in outer_loop:
    # Fresh permutation of all ray indices every epoch.
    np.random.shuffle(rand_idx)
    rand_idx_jnp = jnp.array(rand_idx)

    for j in tqdm(range(Niter_epoch), desc=" iteration", position=1, leave=False):
        # Snapshot of the parameters *before* this step; the snapshots are
        # replayed later to render the optimisation video.
        opt_configs.append(list(params))
        val,opt_state,params = jax_iter(j,rand_idx_jnp,opt_state,params)
        val = float(val)
        losses.append(val)

        # The stopping criterion lives in util.DegradeLR (not visible here).
        if adjust_lr.add(val):
            done = True
            break
        outer_loop.set_description(" loss {:.3f}".format(val))
    if done:
        break

In [None]:
# Training loss per optimisation step.
plt.plot(losses)

In [None]:
# Final optimised parameters. These are the unscaled values held in `params`;
# SCALE_MUL_FACTOR is applied only inside inner_iter when evaluating the loss.
final_mean, final_prec, final_weight_log,final_color = params

In [None]:
# Render every training view with the final parameters. Depth and colour are
# NaN wherever the rendered alpha is below 0.5. Colours are mapped back with
# ** 2.2, the inverse of the ** (1/2.2) applied to the training targets.
result_depths = []
result_alphas = []
results_colors = []

# Colour logits -> (0, 1); identical for every view, so computed once.
final_rgb = jnp.tanh(final_color)*0.5+0.5
for ray_trans in cameras_list:
    est_depth, est_alpha, est_norm, est_w = render_jit(final_mean, final_prec, final_weight_log, ray_trans, beta2/shape_scale, beta3)
    est_color = np.array(est_w.T @ final_rgb) ** 2.2
    est_depth = np.array(est_depth)
    est_alpha = np.array(est_alpha)

    background = est_alpha < 0.5
    est_depth[background] = np.nan
    est_color[background] = np.nan

    result_depths.append(est_depth.reshape((PY, PX)))
    result_alphas.append(est_alpha.reshape((PY, PX)))
    results_colors.append(est_color.reshape((PY, PX, 3)))

In [None]:
# Final fit for the last training view: alpha, depth and colour side by side.
plt.subplot(1, 3, 1)
plt.imshow(result_alphas[-1])
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(result_depths[-1])
plt.axis('off')

plt.subplot(1, 3, 3)
# Use the stored result rather than the loop variable `est_color`: later cells
# overwrite `est_color`, so re-running this cell out of order showed a
# different (video-frame) image.
plt.imshow(results_colors[-1], interpolation='nearest')
plt.axis('off')

In [None]:
from scipy.stats import trim_mean
errs = []
# Compare annotated and rendered depth over all views. Only pixels inside the
# target silhouette count, where both depths are defined and the annotated
# depth is non-zero.
d1f = np.hstack([gt.ravel() for gt in depths]).ravel()
d2f = np.hstack([est.ravel() for est in result_depths]).ravel()

both_defined = ~(np.isnan(d1f) | np.isnan(d2f))
mask = (all_sils != 0) & both_defined & (d1f != 0)

# 10%-trimmed mean of the absolute depth error.
trim_mean(np.abs(d1f[mask] - d2f[mask]), 0.1)

In [None]:
# Grid of the target silhouettes (util.image_grid, 3 rows x 5 columns; how many
# masks are shown depends on image_grid — presumably the first 15).
image_grid(masks,rows=3,cols=5,rgb=False)

In [None]:
max_frame = len(poses)
FWD_BCK_TIMES = 4
# First iteration whose loss is within 2% of the best loss seen.
THRESH_IDX = np.where(np.array(losses)/min(losses) < 1.02)[0][0]
USE_FIRST_N_FRAC = THRESH_IDX/len(losses)
N_FRAMES = max_frame*FWD_BCK_TIMES
# Spread N_FRAMES parameter snapshots evenly over opt_configs[0 .. last_cfg].
# The bound is computed with integer arithmetic. The former
# floor(len(opt_configs) * (THRESH_IDX / len(losses)) - 1) suffered float
# round-off (e.g. 49 * (1/49) == 0.999...), giving a bound one too small. It
# also gave -1 for small THRESH_IDX, which negative indexing turned into the
# *final* configuration. Clamp to 0 as well.
last_cfg = max(0, len(opt_configs) * THRESH_IDX // len(losses) - 1)
opt_to_use = np.round(np.linspace(0, last_cfg, N_FRAMES)).astype(int)

In [None]:
# Fraction of training iterations before the loss first came within 2% of its minimum.
THRESH_IDX/len(losses)

In [None]:
# Loss curve up to that threshold iteration.
plt.plot(losses[:THRESH_IDX])

In [None]:
# View order for the video: forward through all frames, then backward, then
# forward again, ... (FWD_BCK_TIMES passes in total).
frame_list = list(range(max_frame))
frame_idxs = []
for i in range(FWD_BCK_TIMES):
    frame_idxs.extend(frame_list if i % 2 == 0 else reversed(frame_list))

In [None]:
# Render the video frames: frame k shows view frame_idxs[k] rendered with the
# parameter snapshot opt_configs[opt_to_use[k]]. Colours are premultiplied by
# the rendered alpha; depth is NaN where alpha < 0.5.
full_res_alpha = []
full_res_depth = []
full_res_color = []

for r_idx, c_idx in zip(frame_idxs, opt_to_use):
    snap_mean, snap_prec, snap_weight_log, snap_color = opt_configs[c_idx]
    ray_trans = cameras_list[r_idx]
    est_depth, est_alpha, est_norm, est_w = render_jit(snap_mean, snap_prec, snap_weight_log, ray_trans, beta2/shape_scale, beta3)

    linear_rgb = np.array(est_w.T @ (jnp.tanh(snap_color)*0.5+0.5))**(2.2)
    est_color = est_alpha[:, None] * linear_rgb

    est_alpha = np.array(est_alpha)
    est_depth = np.array(est_depth)
    est_depth[est_alpha < 0.5] = np.nan

    full_res_alpha.append(est_alpha.reshape((PY, PX)))
    full_res_depth.append(est_depth.reshape((PY, PX)))
    full_res_color.append(est_color.reshape((PY, PX, 3)))

    print('.', end='')

In [None]:
# Start from an empty output directory: an existing output_folder and all of
# its contents are deleted first.
# NOTE(review): os.path.exists is also true for a regular file, for which
# shutil.rmtree raises — fine as long as output_folder is always a directory.
if os.path.exists(output_folder):
    import shutil
    shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Colour-map range for the depth panels: the 5th..95th percentile of all
# non-NaN rendered depths across every video frame.
all_depth_values = np.hstack([frame_depth.ravel() for frame_depth in full_res_depth])
vecr = all_depth_values[~np.isnan(all_depth_values)]
vmin = np.percentile(vecr, 5)
vmax = np.percentile(vecr, 95)
vscale = vmax - vmin

In [None]:
import matplotlib
from PIL import Image, ImageDraw, ImageFont
avg_size = np.array([PX, PY])  # (width, height) of one panel
fsize = irc(96/4)

try:
    font = ImageFont.truetype('Roboto-Regular.ttf', size=irc(avg_size[0]/8))
except OSError:
    # Roboto-Regular.ttf not found: fall back to Pillow's built-in font
    # (older Pillow versions ignore the requested size for it).
    font = ImageFont.load_default()

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works
# on both old and new versions.
cmap = plt.get_cmap('viridis')
cmap2 = plt.get_cmap('magma')

for i, mask_res in enumerate(full_res_alpha):
    r_idx = frame_idxs[i]
    img_gt_mask = np.tile(masks[r_idx][:, :, None], (1, 1, 3))

    # Per-pixel cross-entropy between target silhouette and rendered alpha
    # (same form as the training mask loss, clipped at 1e-6).
    true_alpha = masks[r_idx]
    est_alpha = jnp.clip(mask_res, 1e-6, 1 - 1e-6)
    mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1 - true_alpha) * jnp.log(1 - est_alpha))
    loss_viz = cmap2(0.25 * mask_loss)[:, :, :3]

    depth = cmap((full_res_depth[i] - vmin) / vscale)[:, :, :3]
    # Panels left to right: input image, target mask, loss, estimated colour, estimated depth.
    img2 = np.concatenate((color_images[r_idx], img_gt_mask, loss_viz, full_res_color[i], depth), axis=1)
    int_img = np.round(img2 * 255).astype(np.uint8)
    pil_img = Image.fromarray(int_img)
    d1 = ImageDraw.Draw(pil_img)
    # Pillow's ImageDraw.text has no `ha` option (that is a Matplotlib kwarg),
    # so the former ha='center' arguments had no effect and are dropped.
    # `align` only affects multi-line strings.
    d1.text((avg_size[0]*1.1, irc(fsize*0.1)), "Iteration: {:3d}\nEpoch: {:.1f}".format(opt_to_use[i], opt_to_use[i]/Niter_epoch), font=font, fill=(180, 180, 180))
    d1.text((avg_size[0]*1.3, irc(avg_size[1]-fsize*1.5)), "Target Mask", font=font, fill=(255, 255, 255))
    d1.text((avg_size[0]*2.4, irc(avg_size[1]-fsize*1.5)), "Loss", font=font, fill=(255, 255, 255), align='center')
    d1.text((avg_size[0]*3.3, irc(avg_size[1]-fsize*2.5)), "Estimated\nColor", font=font, fill=(255, 255, 255), align='center')
    d1.text((avg_size[0]*4.3, irc(avg_size[1]-fsize*2.5)), "Estimated\nDepth", font=font, fill=(255, 255, 255), align='center')

    img3 = np.array(pil_img)
    sio.imsave('{}/{:03d}.jpg'.format(output_folder, i), img3, quality=95)

In [None]:
# Preview the last composited video frame (img3 from the frame-writing loop).
plt.figure(figsize=(18,8))
plt.imshow(img3)
plt.axis('off')

In [None]:
import shutil
import subprocess

# Encode the saved frames <output_folder>/000.jpg, 001.jpg, ... into
# <output_folder>.mp4 at 60 fps.
video_path = '{}.mp4'.format(output_folder)
if os.path.exists(video_path):
    os.remove(video_path)

# Keep the original hard-coded binary when present; otherwise use ffmpeg from PATH.
ffmpeg_bin = '/usr/local/bin/ffmpeg'
if not os.path.exists(ffmpeg_bin):
    ffmpeg_bin = shutil.which('ffmpeg') or ffmpeg_bin

# An argument list with no shell: paths containing spaces or shell
# metacharacters are passed verbatim, and the filter needs no shell quoting.
subprocess.call([
    ffmpeg_bin,
    '-framerate', '60',
    '-i', '{}/%03d.jpg'.format(output_folder),
    # Pad width/height up to even values (yuv420p needs even dimensions).
    '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
    '-c:v', 'h264',
    '-pix_fmt', 'yuv420p',
    video_path,
])