# Run a CO3D Sequence (2-parameter)

In [None]:
import sys, os, glob
import numpy as np
import pandas as pd
from utils_opt import readFlow

import skimage.io as sio
import matplotlib.pyplot as plt

In [None]:
import os
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
#jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import fm_render

# JIT-compiled renderers from the project-local fm_render module:
#   render_jit     -> render_func_idx_quattrans (unpacked below as depth, alpha, _, weights)
#   render_jit_ray -> render_func_rays (not referenced in the visible cells)
#   jax_flow_rend  -> render_func_idx_quattrans_flow (additionally returns two flow fields)
render_jit = jax.jit(fm_render.render_func_idx_quattrans)
render_jit_ray = jax.jit(fm_render.render_func_rays)
jax_flow_rend = jax.jit(fm_render.render_func_idx_quattrans_flow)

In [None]:
# Input sequence directory; its last path component names all outputs.
dataset_dir = 'rvid/teddybear_34_1479_4753//'
co3d_seq = os.path.split(dataset_dir.rstrip('/').lstrip('/'))[-1]
# Video frames, the .mp4 and the .ply exports are written under tmp_out/<sequence>.
output_folder = os.path.join('tmp_out',co3d_seq)
# Number of Gaussian components in the mixture being fit.
NUM_MIXTURE = 40
# Scene scale used during optimisation (see SCALE_MUL_FACTOR).
shape_scale = 1.8
# Loss weights: c_scale multiplies the colour term, f_scale the two flow terms.
c_scale = 4.5
f_scale = 210
# Random initialisation: precision multiplier, spread of the initial means,
# and total initial mixture weight.
rand_sphere_size = 55
cov_scale = 1.2e-2
weight_scale = 1.1
# Initial learning rate handed to DegradeLR, whose step_func drives Adam.
LR_RATE = 0.08
# Renderer parameters forwarded to fm_render (beta2 is divided by a scale first).
beta2 = 21.4
beta3 = 2.66
# Maximum passes over all rays (training may stop earlier) and rays per step.
Nepoch = 10
batch_size = 50000


## Load Data

In [None]:
import gzip
import pickle
# Per-sequence metadata: depth maps, focal length ('fl'), poses and optionally a mesh.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted datasets.
with gzip.open(os.path.join(dataset_dir,'pose_depth.pkl.gz'),'rb') as fp:
    depth_and_pose = pickle.load(fp)

true_depths = depth_and_pose['depths']
fl = depth_and_pose['fl']
# Each pose row is later split as [:4] (fed to quat_to_rot) and [4:] (translation).
poses = np.array(depth_and_pose['poses'])


In [None]:
# Load the per-frame foreground masks (sorted by filename) and binarise them:
# any non-zero pixel counts as foreground. The image size is taken from the masks.
masks_folder = os.path.join(dataset_dir,'masks','video1')
in_files = sorted(glob.glob(os.path.join(masks_folder, '*.png')))
if not in_files:
    # Without this check the shape lookup below fails with an unrelated NameError.
    raise FileNotFoundError('no mask PNGs found in {}'.format(masks_folder))

masks = np.array([(sio.imread(img_loc) > 0).astype(np.float32) for img_loc in in_files])
PY,PX = masks.shape[1:]
image_size = (PY,PX)

In [None]:
# Load the RGB frames (sorted by filename) as float32 arrays in [0, 255].
images_folder = os.path.join(dataset_dir,'JPEGImages','video1')
in_files = sorted(glob.glob(os.path.join(images_folder, '*.jpg')))
# Colours are flattened and indexed in lock-step with the mask pixels later on;
# JAX clamps out-of-range gather indices, so a frame-count mismatch would only
# show up as silently misaligned targets. Fail loudly instead.
if len(in_files) != len(masks):
    raise ValueError('found {} JPEG frames in {} but {} masks'.format(
        len(in_files), images_folder, len(masks)))

images = np.array([sio.imread(img_loc).astype(np.float32) for img_loc in in_files])

In [None]:
# Load optical-flow (.flo) files in sorted order. Paths containing 'bwd' are
# backward flows; everything else is treated as forward flow.
# NOTE(review): assumes sorted filename order matches frame order -- confirm.
fwd_flows = []
bwd_flows = []

flow_fol = os.path.join(dataset_dir,'Flow','video1','*.flo')

flow_files = sorted(glob.glob(flow_fol))

for flfile in flow_files:
    new_flow = readFlow(flfile)
    # For portrait frames (PY > PX) the two flow channels are swapped.
    # NOTE(review): presumably compensates for how the flow was generated -- confirm.
    if PY > PX:
        new_flow = np.stack([new_flow[:,:,1],new_flow[:,:,0]],axis=2)
    if 'bwd' in flfile:
        bwd_flows.append(new_flow)
    else:
        fwd_flows.append(new_flow)


# The last frame has no forward flow: append a zero field shaped like the last
# loaded flow (raises NameError if no .flo files were found).
fwd_flows = fwd_flows + [new_flow*0]
# The first frame has no backward flow: prepend a zero field.
bwd_flows = [new_flow*0] + bwd_flows

In [None]:
# Estimate the object's scale and centre. With a mesh in the metadata (an object
# exposing .vertices), use its vertices; otherwise fall back to scale 3.0 at the origin.
if 'mesh' in depth_and_pose:
    pt_cld = depth_and_pose['mesh'].vertices
    import sklearn.mixture as mixture

    idx2 = np.arange(pt_cld.shape[0])
    np.random.shuffle(idx2)
    # NOTE(review): this GMM (hard-coded 40 components, not NUM_MIXTURE) is fit on up
    # to 10000 random vertices, but `clf` is not used by any visible cell. The shuffle
    # still advances NumPy's global RNG, which changes the random init further down.
    clf = mixture.GaussianMixture(40)
    clf.fit(pt_cld[idx2[:10000]])

    # Scale = 3x the mean per-axis standard deviation of the vertices.
    pt_cld_shape_scale = float(pt_cld.std(0).mean())*3
    center = pt_cld.mean(0)
else:         
    pt_cld_shape_scale = 3.0
    center = np.zeros(3,dtype=np.float32)

In [None]:
# Ratio between the optimisation-space scale and the data scale. The training
# step multiplies the means (and, via vecM, the pose translations) by this factor.
SCALE_MUL_FACTOR = shape_scale/pt_cld_shape_scale
SCALE_MUL_FACTOR

In [None]:
# Build one ray per pixel as rows [x, y, 1, frame_index]. The shorter image side
# spans [-1, 1] and the longer side [-aspect, aspect]; invF later rescales x and y
# (assumes `fl` is expressed in pixels -- TODO confirm).
img_shape = (PY,PX)
min_size_idx = np.argmin(img_shape)
min_size = img_shape[min_size_idx]
max_size = img_shape[1-min_size_idx]
invF = 0.5*min_size/fl
min_dim = np.linspace(-1,1,min_size)
aspect = max_size/min_size
max_dim = np.linspace(-aspect,aspect,max_size)
# First entry varies along the width (PX), second along the height (PY); both are
# negated. The scalars 1 and 0 become the z and frame-index columns.
grid = [-max_dim,-min_dim,1,0] if min_size_idx == 0 else [-min_dim,-max_dim,1,0]
# meshgrid -> 4 arrays of shape (PX, PY, 1, 1); squeeze gives (4, PX, PY) and the
# transpose gives (PY, PX, 4), i.e. row-major pixel order.
pixel_list = np.transpose(np.squeeze(np.meshgrid(*grid,indexing='ij')),(2,1,0))

# Flatten to (PY*PX, 4). Column 3 is overwritten with a frame index before each render.
pixel_list = pixel_list.reshape((-1,4))

In [None]:
# Convert poses to a JAX array for the jitted renderers.
poses = jnp.array(poses)

In [None]:
from util import image_grid

In [None]:
# random init settings
# random init settings
# Means: centre + scale * N(0, cov_scale*I). Weights: equal, summing to
# weight_scale, stored as logs. Precision factors: scaled identity. Colours:
# unconstrained values, later mapped through tanh(x)*0.5+0.5.
rand_mean = center+pt_cld_shape_scale*np.random.multivariate_normal(mean=[0,0,0],cov=cov_scale*np.identity(3),size=NUM_MIXTURE)
rand_weight_log = jnp.log(weight_scale*np.ones(NUM_MIXTURE)/NUM_MIXTURE)
rand_prec = jnp.array([np.identity(3)*rand_sphere_size/pt_cld_shape_scale for _ in range(NUM_MIXTURE)])
rand_color = jnp.array(np.random.randn(NUM_MIXTURE,3))

# Sanity check: render the random init from up to the first 36 poses and show
# the alpha masks in a 6x6 grid.
init_alphas = []
for i in range(min(36,len(poses))):
    pixel_list[:,3] = i
    res_img,est_alpha,_,_ = render_jit(rand_mean,rand_prec,rand_weight_log,pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)

    # NOTE(review): res_imgA (depth with low-alpha pixels blanked) is not used below.
    res_imgA = np.array(res_img)
    res_imgA[est_alpha < 0.5] = np.nan
    init_alphas.append(est_alpha.reshape((PY,PX)))
image_grid(init_alphas,6,6,rgb=False)

In [None]:
# Stack the per-pixel rays of every frame into one (frames*PY*PX, 4) array, frame-major,
# with column 3 holding the frame index. Leaves pixel_list[:,3] at the last frame index.
total_ray_set = []
for i in range(len(poses)):
    pixel_list[:,3] = i

    total_ray_set.append(pixel_list.copy())
all_rays = jnp.vstack(total_ray_set)

In [None]:
# scaled into ray coord space, vectorized flows
# scaled into ray coord space, vectorized flows
# Flatten to (frames*PY*PX, 2), aligned with all_rays, and divide by half the short
# image side (the same unit the pixel rays use).
fwv_flow = jnp.array(np.array(fwd_flows).reshape((-1,2)))/(min_size/2)
bwv_flow = jnp.array(np.array(bwd_flows).reshape((-1,2)))/(min_size/2)

In [None]:
# NOTE(review): neither value is referenced in the visible cells.
last_img_size = np.prod(img_shape)
v_idx = 40

In [None]:
def objective(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_fwd,true_bwd,true_color):
    """Training loss for one batch of rays.

    The loss is the sum of:
      * mean binary cross-entropy between the rendered alpha and ``true_alpha``;
      * ``c_scale`` times the mean absolute colour error, masked by ``true_alpha``;
      * ``f_scale`` times the mean absolute forward plus backward flow errors,
        masked by ``true_alpha``.

    ``params`` is ``[means, prec, weights_log, colors]``. Component colours are
    squashed into [0, 1] with ``tanh(x)*0.5+0.5`` and blended by the renderer's
    per-ray component weights.
    """
    eps = 1e-6  # keeps both logs finite when alpha saturates at 0 or 1
    means, prec, weights_log, colors = params
    rendered = fm_render.render_func_idx_quattrans_flow(
        means, prec, weights_log, camera_rays, invF, poses, beta2, beta3)
    _, alpha_hat, _, comp_w, fwd_hat, bwd_hat = rendered

    color_hat = comp_w.T @ (jnp.tanh(colors)*0.5+0.5)
    alpha_hat = jnp.clip(alpha_hat, eps, 1-eps)

    bce = - ((true_alpha * jnp.log(alpha_hat)) + (1-true_alpha)*jnp.log(1-alpha_hat))

    fg = true_alpha[:, None]
    fwd_err = jnp.abs(fg*true_fwd - fg*fwd_hat)
    bwd_err = jnp.abs(fg*true_bwd - fg*bwd_hat)
    color_err = jnp.abs((true_color - color_hat)*fg)

    flow_term = fwd_err.mean() + bwd_err.mean()
    return bce.mean() + c_scale*color_err.mean() + f_scale*flow_term
grad_render3 = jax.value_and_grad(objective)

In [None]:
import optax
from tqdm.notebook import tqdm
from util import DegradeLR

# Per-pose multiplier: the first four pose entries are unchanged; the last three
# (the translation part, poses[...][4:]) are scaled by SCALE_MUL_FACTOR.
vecM = jnp.array([1,1,1,1,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR])[None]

train_size = all_rays.shape[0]
# Batches per epoch. Because this rounds, either the last few rays are skipped in
# an epoch, or the last batch runs past the end. In the latter case
# jax.lax.dynamic_slice clamps its start, so that batch overlaps the previous one.
# NOTE(review): if train_size < batch_size the slice is larger than the data and fails.
Niter_epoch = int(round(train_size/batch_size))

def irc(x): return int(round(x))

# babysit learning rates
# DegradeLR (project-local) supplies Adam's step size through step_func; when its
# add(loss) returns True, training stops (see the loop below).
adjust_lr = DegradeLR(LR_RATE,0.5,irc(Niter_epoch*0.25),irc(Niter_epoch*0.1),-1e-4)

optimizer = optax.adam(adjust_lr.step_func)

tmp = [rand_mean,rand_prec,rand_weight_log,rand_color]
#tmp = [means,prec,weights_log]

opt_state = optimizer.init(tmp)

# Flattened per-ray targets in the same frame-major pixel order as all_rays.
all_sils = jnp.hstack([_.ravel() for _ in masks]).astype(jnp.float32)
all_colors = jnp.hstack([_.ravel()/255.0 for _ in images]).astype(jnp.float32).reshape((-1,3))
# Targets are raised to 1/2.2; the renders further down undo this with **2.2.
all_colors = all_colors**(1/2.2)

losses = []
# Parameter snapshot taken before every step (used later for the progress video).
opt_configs = []
outer_loop = tqdm(range(Nepoch), desc=" epoch", position=0)

rand_idx = np.arange(train_size)
params = tmp
def inner_iter(j_idx,rand_idx_local,opt_state,p):
    # One Adam step on batch j_idx of the shuffled ray indices. Means are multiplied
    # and precision factors divided by SCALE_MUL_FACTOR, matching vecM*poses and
    # beta2/shape_scale; gradients flow back to the unscaled parameters in p.
    idx = jax.lax.dynamic_slice(rand_idx_local,[j_idx*batch_size],[batch_size])

    val,g = grad_render3([p[0]*SCALE_MUL_FACTOR,p[1]/SCALE_MUL_FACTOR,p[2],p[3]],all_rays[idx],invF,vecM*poses,
                         beta2/(shape_scale),beta3,all_sils[idx],fwv_flow[idx],bwv_flow[idx],all_colors[idx])   
    updates, opt_state = optimizer.update(g, opt_state,p)
    p = optax.apply_updates(p, updates)
    return val, opt_state, p 
jax_iter = jax.jit(inner_iter)
done = False
for i in outer_loop:
    # Reshuffle all rays at the start of every epoch.
    np.random.shuffle(rand_idx)
    rand_idx_jnp = jnp.array(rand_idx)

    for j in tqdm(range(Niter_epoch), desc=" iteration", position=1, leave=False):
        opt_configs.append(list(params))
        val,opt_state,params = jax_iter(j,rand_idx_jnp,opt_state,params)
        val = float(val)
        losses.append(val)

        # Leave both loops as soon as adjust_lr.add() returns True.
        if adjust_lr.add(val):
            done = True
            break
        outer_loop.set_description(" loss {:.3f}".format(val))
    if done:
        break

In [None]:
# Training loss per optimisation step.
plt.plot(losses)

In [None]:
final_mean, final_prec, final_weight_log,final_color = params

In [None]:
# Save the optimised mixture in data scale (not multiplied by SCALE_MUL_FACTOR)
# to output.pkl in the current working directory.
dump_out = {
    'mean': np.array(final_mean),
    'prec': np.array(final_prec),
    'wlog': np.array(final_weight_log),
    'color': np.array(final_color)
}
import pickle
with open('output.pkl','wb') as fp:
    pickle.dump(dump_out,fp)

In [None]:
# Render every frame with the final data-scale parameters and unscaled poses.
result_depths = []
result_alphas = []
results_colors = []

for i in range(len(poses)):
    pixel_list[:,3] = i
    res_img,est_alpha,_,w = render_jit(final_mean, final_prec, final_weight_log,pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)
    # Blend component colours by the per-pixel weights, then undo the 1/2.2 applied to the targets.
    est_color = np.array(w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)

    res_imgA = np.array(res_img)
    est_alpha = np.array(est_alpha)
    # Blank depth and colour where the rendered alpha is below 0.5.
    res_imgA[est_alpha < 0.5] = np.nan
    est_color[est_alpha < 0.5] = np.nan

    result_depths.append(res_imgA.reshape((PY,PX)))
    result_alphas.append(est_alpha.reshape((PY,PX)))
    results_colors.append(est_color.reshape((PY,PX,3)))

In [None]:
# Alpha, depth and colour of the last frame, side by side.
plt.subplot(1,3,1)
plt.imshow(result_alphas[-1])
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(result_depths[-1])
plt.axis('off')
plt.subplot(1,3,3)
# est_color still holds the last loop iteration, i.e. the same data as results_colors[-1].
plt.imshow(est_color.reshape((PY,PX,3)),interpolation='nearest')
plt.axis('off')


In [None]:
def per_gaussian_error(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std):
    """Score each Gaussian component for pruning and splitting.

    Renders ``camera_rays`` with ``render_jit``. The per-ray mask cross-entropy and
    colour error are then attributed to mixture components through the renderer's
    per-ray component weights ``est_w``.

    Args:
        params: ``[means, prec, weights_log, colors]`` for the mixture.
        camera_rays, invF, poses, beta2, beta3: forwarded unchanged to ``render_jit``.
        true_alpha: target mask value per ray.
        true_color: target colour per ray, in the same space as the training targets.
        lower_std: a component is kept only if its mean weight is above
            ``mean - lower_std * std`` of all components' mean weights.
        upper_std: a component is flagged for splitting if its error is at least
            ``mean + upper_std * std`` of all components' errors.

    Returns:
        ``(split_idx, keep_idx, c_var)``. The first two are boolean flags per
        component. ``c_var`` is the per-component, per-channel standard deviation
        of the weight-scaled target colours.
    """
    CLIP_ALPHA = 1e-6  # keeps log() finite in the cross-entropy below
    means,prec,weights_log,colors = params
    est_depth, est_alpha, est_norm, est_w =render_jit(means,prec,weights_log,camera_rays,invF,poses,beta2,beta3)
    est_color = est_w.T @ (jnp.tanh(colors)*0.5+0.5)
    est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)
    mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))
    cdiff = jnp.abs( (true_color-est_color)*true_alpha[:,None] )

    # Per component: mean over rays of (mask loss x weight), plus c_scale times
    # the mean over rays of (mean colour error x weight).
    per_err = ((mask_loss*est_w).mean(axis=1) + c_scale*(cdiff.mean(axis=1) * est_w).mean(axis=1) )
    avg_w = est_w.mean(axis=1)
    keep_idx = (avg_w > (avg_w.mean() - lower_std*avg_w.std()))
    split_idx = (per_err >= (per_err.mean() + upper_std*per_err.std()))
    c_var = (true_color[:,None,:] *est_w.T[:,:,None]).std(axis=0)
    return split_idx, keep_idx, c_var

def get_split_gaussian(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std):
    """Prune low-weight Gaussians and split high-error ones in two.

    Uses ``per_gaussian_error`` to flag components. Each flagged component is
    replaced by two children placed at +/- sqrt(2/pi) standard deviations along the
    component's principal axis. Along that axis the variance shrinks by (1 - 2/pi),
    the mean offset and variance of a half-normal. The children get the parent's
    log-weight and colour, each perturbed by 0.1 * N(0, 1) noise. Components that
    are neither kept nor split are dropped.

    Args:
        See ``per_gaussian_error``; all arguments are forwarded to it.

    Returns:
        ``[means, prec, weights_log, colors]`` as float32 JAX arrays, with kept
        components first and the new children after them.
    """
    split_idx,keep_idx,c_var = per_gaussian_error(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std)
    split_idx = np.asarray(split_idx)
    t_keep_idx = np.asarray(keep_idx) & (~split_idx)

    means,prec,weights_log,colors = params

    # Components that survive unchanged (computed once, not per split).
    kept_means, kept_prec, kept_wlog, kept_col = [np.asarray(_)[t_keep_idx] for _ in params]

    new_means, new_prec, new_weights, new_colors = [],[],[],[]
    for i in np.flatnonzero(split_idx):
        mu, preco = np.asarray(means[i]), np.asarray(prec[i])
        wlog, col = np.asarray(weights_log[i]), np.asarray(colors[i])
        # prec stores a factor R of the precision matrix (precision = R^T R).
        covar = np.linalg.pinv(preco.T @ preco)
        u,s,vt = np.linalg.svd(covar)
        # The singular vectors are the COLUMNS of u; u[:, 0] is the principal axis.
        # (u[0] would be the first row, which is not a direction of the ellipsoid.)
        axis = u[:, 0]
        s2 = s.copy()
        s2[0] = s2[0] * np.sqrt(1-2/np.pi)
        covar2 = u@np.diag(s2)@vt
        offset = axis * np.sqrt(s[0]) * np.sqrt(2/np.pi)
        precn = np.linalg.cholesky(np.linalg.pinv(covar2)).T

        new_means.append(mu + offset)
        new_means.append(mu - offset)
        new_prec.append(precn)
        new_prec.append(precn)
        new_weights.append(wlog+ 0.1*np.random.randn())
        new_weights.append(wlog+ 0.1*np.random.randn())
        new_colors.append(col + 0.1*np.random.randn(3))
        new_colors.append(col + 0.1*np.random.randn(3))

    if new_means:
        m2 = np.vstack([kept_means,new_means])
        p2 = np.vstack([kept_prec,new_prec])
        w2 = np.hstack([kept_wlog,new_weights])
        c2 = np.vstack([kept_col,new_colors])
    else:
        # Nothing flagged for splitting: only pruning applies. Previously this
        # case hit undefined names at the return statement.
        m2, p2, w2, c2 = kept_means, kept_prec, kept_wlog, kept_col
    return [jnp.array(_).astype(jnp.float32) for _ in [m2,p2,w2,c2]]
# Try one prune/split pass on a subset of the shuffled rays. `params` are in data
# scale, so render them with the unscaled poses and beta2/pt_cld_shape_scale, as the
# final render does. vecM*poses and beta2/shape_scale belong to the optimisation
# space used inside the training step, where the means are rescaled too.
idx = rand_idx_jnp[:10*batch_size]
params2 = get_split_gaussian(params,all_rays[idx],invF,poses,beta2/pt_cld_shape_scale,beta3,all_sils[idx],all_colors[idx],2,1)
print(params2[0].shape)

In [None]:
from scipy.stats import trim_mean
# 10%-trimmed mean absolute depth error between the dataset depths and the final
# render. Only pixels that are in the target mask, finite in both depth maps and
# non-zero in the dataset depth are counted.
# NOTE(review): assumes true_depths and result_depths list the same frames in the
# same order at the same resolution -- TODO confirm.
errs = []
d1f = np.hstack([_.ravel() for _ in  true_depths]).ravel()
d2f = np.hstack([_.ravel() for _ in result_depths]).ravel()

# NOTE(review): rebinds the global `mask`; `errs` above is never used.
mask = (all_sils !=0 ) & (~np.isnan(d1f)) & (~np.isnan(d2f)) & (d1f !=0) 

trim_mean(abs(d1f[mask]-d2f[mask]),0.1)

In [None]:
# Show the target masks in a 3x5 grid.
image_grid(masks,rows=3,cols=5,rgb=False)

In [None]:
# Choose which stored parameter snapshots to show in the progress video. The video
# sweeps through all frames forwards and backwards FWD_BCK_TIMES times. Snapshots
# are spread evenly over training up to the first step whose loss is within 2% of
# the minimum loss.
max_frame = len(poses)
FWD_BCK_TIMES = 4
THRESH_IDX = np.where(np.array(losses)/min(losses) < 1.02)[0][0]
USE_FIRST_N_FRAC = THRESH_IDX/len(losses)
N_FRAMES = max_frame*FWD_BCK_TIMES
# Clamp at 0: if THRESH_IDX is 0 the unclamped end index is -1, and those -1
# entries would index the *last* snapshot instead of the first.
last_opt_idx = max(0, int(np.floor(len(opt_configs)*USE_FIRST_N_FRAC-1)))
opt_to_use = np.round(np.linspace(0,last_opt_idx,N_FRAMES)).astype(int)
# Loss-weighted alternative schedule; used only by the commented-out lines below.
loss_v = np.log(losses)
loss_v -= loss_v.min()
loss_v /= loss_v.max()
loss_v = np.cumsum(loss_v)
loss_v -= loss_v.min()
loss_v /= loss_v.max()
tv = np.stack([N_FRAMES*loss_v,(len(opt_configs)-1)*np.linspace(0,1,len(losses))]).T
#plt.plot(tv[:,0],tv[:,1])
#opt_to_use = np.round(np.interp(np.arange(N_FRAMES),tv[:,0],tv[:,1])).astype(int)

In [None]:
# Fraction of the training run covered by the progress video.
THRESH_IDX/len(losses)

In [None]:
# Loss curve up to the video cut-off.
plt.plot(losses[:THRESH_IDX])

In [None]:
# Number of stored parameter snapshots (one per optimisation step).
len(opt_configs)

In [None]:
# Start from an empty output folder: remove frames from any previous run.
if os.path.exists(output_folder):
    import shutil
    shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Video frame order: sweep through all frames forwards on even sweeps and
# backwards on odd sweeps, FWD_BCK_TIMES sweeps in total (a ping-pong loop).
frame_list = list(range(max_frame))
frame_idxs = [
    frame
    for sweep in range(FWD_BCK_TIMES)
    for frame in (frame_list if sweep % 2 == 0 else frame_list[::-1])
]

In [None]:
# Render each video frame: camera r_idx with the parameter snapshot c_idx, so the
# reconstruction evolves as the camera moves.
full_res_alpha = []
full_res_depth = []
full_res_color = []

for r_idx,c_idx in zip(frame_idxs,opt_to_use):
    p = opt_configs[c_idx]

    pixel_list[:,3] = r_idx
    est_depth,est_alpha,_,w = render_jit(p[0],p[1],p[2],pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)
    # Colour composited over black: alpha * (blended colour ** 2.2).
    est_color = (1-est_alpha[:,None])*0 + est_alpha[:,None] * np.array(w.T @ (jnp.tanh(p[3])*0.5+0.5))**(2.2)

    est_alpha = np.array(est_alpha)
    est_depth = np.array(est_depth)
    # np.percentile takes percent, so 0.99 here is the 0.99th percentile (a
    # near-minimum alpha), and in practice the threshold is 0.5.
    # NOTE(review): if the 99th percentile was intended, this should be 99 -- confirm.
    est_depth[est_alpha < max(0.5,np.percentile(est_alpha,0.99))] = np.nan
    #est_color[est_alpha < 0.5] = np.nan

    full_res_alpha.append(est_alpha.reshape((PY,PX)))
    full_res_depth.append(est_depth.reshape((PY,PX)))
    full_res_color.append(est_color.reshape((PY,PX,3)))

    # progress indicator
    print('.',end='')

In [None]:
# Colour range for the depth panels: 5th to 95th percentile of all finite
# rendered depths across every video frame.
vecr = np.concatenate([frame_depth.ravel() for frame_depth in full_res_depth])
vecr = vecr[~np.isnan(vecr)]
vmin, vmax = np.percentile(vecr, 5), np.percentile(vecr, 95)
vscale = vmax - vmin

In [None]:
import matplotlib
from PIL import Image, ImageDraw, ImageFont
start_f = 0
avg_size = np.array([PX,PY])
fsize = irc(96/4)

font = ImageFont.truetype('Roboto-Regular.ttf', size=irc(avg_size[0]/16))
# matplotlib.cm.get_cmap is deprecated and removed in matplotlib 3.9; the colormap
# registry (available since 3.5) is the supported lookup.
cmap = matplotlib.colormaps['viridis']
cmap2 = matplotlib.colormaps['magma']

# Compose one JPEG per video frame. Panels from left to right: input image,
# target mask, estimated colour, estimated depth. The per-frame mask-loss panel
# is disabled, so its (previously unused) computation is no longer done here.
for i,mask_res in enumerate(full_res_alpha):
    r_idx = frame_idxs[i]
    img_gt_mask = np.tile(masks[r_idx][:,:,None],(1,1,3))
    depth = cmap((full_res_depth[i]-vmin)/vscale)[:,:,:3]
    img2 = np.concatenate((images[r_idx]/255.0,img_gt_mask,full_res_color[i], depth), axis=1)
    int_img = np.round(img2*255).astype(np.uint8)
    pil_img = Image.fromarray(int_img)
    d1 = ImageDraw.Draw(pil_img)
    # `ha` is a matplotlib keyword, not a Pillow ImageDraw.text parameter, so it is
    # omitted; `align` only affects multi-line text.
    d1.text((avg_size[0]*1.1, irc(fsize*0.1)), "Iteration: {:3d}\nEpoch: {:.1f}".format(opt_to_use[i],opt_to_use[i]/Niter_epoch), font=font,fill=(180, 180, 180))
    d1.text((avg_size[0]*1.3, irc(avg_size[1]-fsize*2.5)), "Target Mask", font=font,fill=(255, 255, 255))
    d1.text((avg_size[0]*2.3, irc(avg_size[1]-fsize*3.5)), "Estimated\nColor", font=font,fill=(255, 255, 255),align='center')
    d1.text((avg_size[0]*3.3, irc(avg_size[1]-fsize*3.5)), "Estimated\nDepth", font=font,fill=(255, 255, 255),align='center')

    img3 = np.array(pil_img)

    sio.imsave('{}/{:03d}.jpg'.format(output_folder,i),img3,quality=95)

In [None]:
# Preview the last composed video frame.
plt.imshow(img3)

In [None]:
# Scratch check of a candidate label position against the frame size.
(avg_size[0]*1.3, irc(avg_size[1]-fsize*1.5)),avg_size

In [None]:
# Larger preview of the last composed video frame.
plt.figure(figsize=(18,8))
plt.imshow(img3)
plt.axis('off')

In [None]:
import subprocess
import shutil

# Encode the saved frames into <output_folder>.mp4 at 60 fps with H.264.
video_path = '{}.mp4'.format(output_folder)
if os.path.exists(video_path):
    os.remove(video_path)
# Use an argument list instead of a shell string, so paths containing spaces or
# shell metacharacters are passed through safely. The pad filter rounds width and
# height up to even values, which yuv420p H.264 output requires.
ffmpeg_bin = shutil.which('ffmpeg') or '/usr/bin/ffmpeg'
subprocess.call([ffmpeg_bin,
                 '-framerate','60',
                 '-i','{}/%03d.jpg'.format(output_folder),
                 '-vf','pad=ceil(iw/2)*2:ceil(ih/2)*2',
                 '-c:v','h264',
                 '-pix_fmt','yuv420p',
                 video_path])

In [None]:
# Uncomment the `raise` below to stop a "run all" at this point.
#raise

In [None]:
output_folder

In [None]:
# Last stored snapshot (the parameters *before* the final optimisation step).
p = opt_configs[-1]

In [None]:
# Render frame 28 (or the last frame of shorter sequences) with the flow renderer
# and convert both predicted flows from ray units back to pixels.
base_idx = min(len(masks)-1,28)
pixel_list[:,3] = base_idx
est_depth,est_alpha,_,_,flowp,flowm = jax_flow_rend(p[0],p[1],p[2],pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)

flowp = (min_size/2)*np.array(flowp)
flowm = (min_size/2)*np.array(flowm)
# Hide flow where the rendered alpha is below 0.5.
flowp[est_alpha < 0.5] = np.nan
flowm[est_alpha < 0.5] = np.nan

In [None]:
# Forward flow of frame base_idx, masked to pixels with rendered alpha >= 0.5:
# dataset flow in the first figure, rendered flow in the second, with both flow
# channels side by side on a shared +/-6 px scale.
tmp_f = np.copy(fwd_flows[base_idx])
tmp_f[est_alpha.reshape((PY,PX)) < 0.5] = np.nan

for comp in range(2):
    plt.subplot(1,2,comp+1)
    plt.imshow(tmp_f[:,:,comp],vmin=-6,vmax=6,cmap='RdBu' )
    plt.axis('off')
    plt.colorbar()

plt.figure()
for comp in range(2):
    plt.subplot(1,2,comp+1)
    plt.imshow(flowp[:,comp].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )
    plt.axis('off')
    plt.colorbar()

In [None]:
# Backward flow of frame base_idx, masked to pixels with rendered alpha >= 0.5:
# dataset flow in the first figure, rendered flow in the second, with both flow
# channels side by side on a shared +/-6 px scale.
tmp_b = np.copy(bwd_flows[base_idx])
tmp_b[est_alpha.reshape((PY,PX)) < 0.5] = np.nan

for comp in range(2):
    plt.subplot(1,2,comp+1)
    plt.imshow(tmp_b[:,:,comp],vmin=-6,vmax=6,cmap='RdBu' )
    plt.axis('off')
    plt.colorbar()

plt.figure()
for comp in range(2):
    plt.subplot(1,2,comp+1)
    plt.imshow(flowm[:,comp].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )
    plt.axis('off')
    plt.colorbar()

In [None]:
import zpfm_render
# JIT-compiled renderer from the project-local zpfm_render module; the calls
# below pass no beta2/beta3 arguments to it.
render_jit2 = jax.jit(zpfm_render.render_func_idx_quattrans_flow)


In [None]:
# Export a point cloud. Every confidently rendered pixel of every frame is
# back-projected to 3D (depth along its rotated ray plus the pose translation),
# and per-point colours and normals are collected.
points_export = []
colors_export = []
colors_export_plain = []

normals_export = []

# Sharpening for export: precision factors are multiplied by scaleE and the
# log-weights by scaleE**2 before rendering.
scaleE = 2

# A pixel is exported only if both its alpha and its largest component weight exceed this.
thesh_min = 0.9

# Pose rotations do not depend on the frame being exported, so compute them once
# rather than once per loop iteration.
rot_mats = jax.vmap(fm_render.quat_to_rot)(poses[:,:4])

def rot_ray_t(rayi):
    """Map one pixel row [x, y, 1, frame] to [rotated ray direction, pose translation]."""
    ray = rayi[:3] * jnp.array([invF,invF,1])
    pose_idx = rayi[3].astype(int)
    return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]])

for i in range(len(poses)):
    pixel_list[:,3] = i
    camera_rays_start = jax.vmap(rot_ray_t)(pixel_list)
    est_depth,est_alpha,est_norm,est_w,flowp,flowm = render_jit2(final_mean, final_prec*scaleE,(scaleE**2)*final_weight_log,pixel_list,invF,poses)

    est_color = np.array(est_w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)

    # Opaque RGBA uint8 colours: export_c from the rendered model colour,
    # export_c2 from the input photograph.
    export_c = np.round(np.clip(est_color,0,1)*255).astype(np.uint8)
    alpha_c = (np.ones(export_c.shape[:-1])*255).astype(np.uint8)
    export_c = np.hstack([export_c.reshape((-1,3)),alpha_c.reshape((-1,1))]).reshape((-1,4))

    export_c2 = np.round(images[i]).astype(np.uint8)
    export_c2 = np.hstack([export_c2.reshape((-1,3)),alpha_c.reshape((-1,1))]).reshape((-1,4))

    # 3D point = depth * rotated ray + pose translation.
    est_3d = est_depth[:,None]*camera_rays_start[:,0]+camera_rays_start[:,1]

    est_3d = np.array(est_3d)
    est_alpha = np.array(est_alpha)

    export_cond = (est_alpha > thesh_min) & (est_w.max(axis=0) > thesh_min)

    points_export.append(est_3d[export_cond])
    colors_export.append(export_c2[export_cond])
    normals_export.append(est_norm[export_cond])
    colors_export_plain.append(export_c[export_cond])


points_export = np.concatenate(points_export)
colors_export = np.concatenate(colors_export)
colors_export_plain = np.concatenate(colors_export_plain)
normals_export = np.concatenate(normals_export)

In [None]:
# Largest rendered colour value in the last exported frame (export clips to [0, 1]).
est_color.max()

In [None]:
import open3d as o3d
# Build an Open3D cloud from the exported points and normals and write it twice:
# coloured from the input photos (<output_folder>.ply), then with the rendered
# model colours (<output_folder>_plain.ply).
o3d_cld = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_export))
o3d_cld.colors = o3d.utility.Vector3dVector(colors_export[:,:3].astype(float)/255.0)
o3d_cld.normals = o3d.utility.Vector3dVector(normals_export)
o3d.io.write_point_cloud("{}.ply".format(output_folder), o3d_cld)

o3d_cld.colors = o3d.utility.Vector3dVector(colors_export_plain[:,:3].astype(float)/255.0)
o3d.io.write_point_cloud("{}_plain.ply".format(output_folder), o3d_cld)

In [None]:
output_folder

In [None]:
# Re-render with render_jit2 at scaleE=1 (no export sharpening) for the comparison
# plots below. The trailing `break` means only frame 0 is rendered.
result_depths2 = []
result_alphas2 = []
results_colors2 = []
scaleE=1

for i in range(len(poses)):
    pixel_list[:,3] = i
    est_depth,est_alpha,est_norm,est_w,flowp,flowm = render_jit2(final_mean, final_prec*scaleE,(scaleE**2)*final_weight_log,pixel_list,invF,poses)
    # Blend with this render's own component weights. The global `w` used here
    # before came from an earlier cell's render of a different frame with a
    # different renderer.
    est_color = np.array(est_w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)

    est_depth = np.array(est_depth)
    est_alpha = np.array(est_alpha)
    est_depth[est_alpha < thesh_min] = np.nan
    est_color[est_alpha < thesh_min] = np.nan

    result_depths2.append(est_depth.reshape((PY,PX)))
    result_alphas2.append(est_alpha.reshape((PY,PX)))
    results_colors2.append(est_color.reshape((PY,PX,3)))
    break

In [None]:
# Per-pixel weight of mixture component 6 (requires NUM_MIXTURE > 6).
plt.imshow(est_w.T[:,6].reshape((PY,PX)))

In [None]:
# Component-weight profiles (over the mixture) for one pixel just past the middle
# of the flattened image and for pixel 0.
plt.plot(est_w[:,est_w.shape[1]//2+100])
plt.plot(est_w[:,0])


In [None]:
plt.imshow(result_alphas2[-1])

In [None]:
# NOTE(review): result_depths2 holds frame 0 only, whereas result_depths[-1] is the
# LAST frame; result_depths[0] would give a like-for-like comparison.
plt.imshow(result_depths2[-1])
plt.colorbar()
plt.figure()
plt.imshow(result_depths[-1])
plt.colorbar()

In [None]:
plt.imshow(results_colors2[-1])

In [None]:
import transforms3d
# Visualise the frame-0 rendered normals as RGB after negating them and multiplying
# by the pose-0 rotation matrix; pixels with alpha below 0.25 are hidden.
# NOTE(review): the sign and the `@ Rr` convention are taken as-is -- confirm
# which coordinate frame the result is in.
Rr = transforms3d.quaternions.quat2mat(poses[0][:4])
est_norm2 = -np.array(est_norm) @ Rr
est_norm2[est_alpha < 0.25] = np.nan
plt.imshow(est_norm2.reshape((image_size[0],image_size[1],3))*0.5+0.5)

In [None]:
from util import compute_normals
# Normals estimated from the rendered depth map. The rays must come from the same
# frame as est_depth; pixel_list[:,3] still holds that frame index from the
# preceding render. camera_rays_start from the export cell belongs to the LAST
# exported frame, so rebuild the rotated rays here (same maths as rot_ray_t).
normal_frame = int(pixel_list[0,3])
frame_rays = (jnp.asarray(pixel_list[:,:3]) * jnp.array([invF,invF,1])) @ fm_render.quat_to_rot(poses[normal_frame,:4])
est_norms3 = compute_normals(frame_rays,est_depth.reshape((PY,PX)))
plt.imshow(est_norms3.reshape((image_size[0],image_size[1],3))*0.5+0.5)