In [None]:
'''
This cell loads the model from the config file and initializes the viewer
'''
# %matplotlib widget
import torch
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from pathlib import Path
import numpy as np
from nerfstudio.viewer.viewer import Viewer
from nerfstudio.configs.base_config import ViewerConfig
import cv2
from torchvision.transforms import ToTensor
from PIL import Image
from nerfstudio.utils import writer
import time
from threading import Lock
from nerfstudio.cameras.cameras import Cameras
from copy import deepcopy
from torchvision.transforms.functional import resize
from lerf.zed import Zed
import warp as wp
from toad.optimization.rigid_group_optimizer import RigidGroupOptimizer
from toad.optimization.atap_loss import ATAPLoss
from toad.utils import *
wp.init()

# config = Path("outputs/nerfgun2/dig/2024-05-03_161203/config.yml")
# config = Path("outputs/nerfgun3/dig/2024-05-03_170424/config.yml")
# config = Path("outputs/nerfgun4/dig/2024-05-07_130351/config.yml")
# config = Path("outputs/painter_sculpture/dig/2024-05-10_132522/config.yml")
# config = Path("outputs/painter_sculpture/dig/2024-05-16_233028/config.yml")#with ruilongs v2
# config = Path("outputs/buddha_balls_poly/dig/2024-05-09_123412/config.yml")
# config = Path("outputs/buddha_balls_poly/dig/2024-05-16_231213/config.yml")#with ruilongs v2
# config = Path("outputs/cal_bear/dig/2024-05-15_155531/config.yml")#this one groups table with bear for some reason
# config = Path("outputs/bww_faucet/dig/2024-05-12_215440/config.yml")
# config = Path("outputs/cmk_tpose2/dig/2024-05-14_142439/config.yml")
# config = Path("outputs/cal_bear/dig/2024-05-17_142920/config.yml")#ruilong v2
# config = Path("outputs/mac_charger/dig/2024-05-17_145312/config.yml")
# config = Path("outputs/mac_charger2/dig/2024-05-17_152545/config.yml")
# config = Path("outputs/glue_gun/dig/2024-05-17_161408/config.yml")
# config = Path("outputs/buddha_balls_poly/dig/2024-05-19_122050/config.yml")# reuilong v2, 32-dim gauss
# config = Path("outputs/mac_charger/dig/2024-05-19_125443/config.yml")
# config = Path("outputs/mac_charger2/dig/2024-05-19_132100/config.yml")
# config = Path("outputs/mac_charger/dig/2024-05-20_191616/config.yml")#with antialias
# config = Path("outputs/buddha_balls_poly/dig/2024-05-20_192646/config.yml")
# config = Path("outputs/garfield_plushie/dig/2024-05-21_144709/config.yml")
config=Path("outputs/boops_poly/dig/2024-05-22_000924/config.yml")
OUTPUT_FOLDER = Path("renders/boops_poly")

assert OUTPUT_FOLDER.stem in str(config), "Output folder name does not match config name"
OUTPUT_FOLDER.mkdir(exist_ok=True,parents=True)
train_config,pipeline,_,_ = eval_setup(config)
dino_loader = pipeline.datamanager.dino_dataloader
train_config.logging.local_writer.enable = False
# We need to set up the writer to track number of rays, otherwise the viewer will not calculate the resolution correctly
writer.setup_local_writer(train_config.logging, max_iter=train_config.max_num_iterations)
v = Viewer(ViewerConfig(default_composite_depth=False,num_rays_per_chunk=-1),config.parent,pipeline.datamanager.get_datapath(),pipeline,train_lock=Lock())

In [2]:
"""
This cell defines a simple pose optimizer for learning a rigid transform offset given a gaussian model, star pose, and starting view
"""

def get_vid_frame(cap,timestamp):
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    # Calculate the frame number based on the timestamp and fps
    frame_number = min(int(timestamp * fps),int(cap.get(cv2.CAP_PROP_FRAME_COUNT)-1))
    
    # Set the video position to the calculated frame number
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    
    # Read the frame
    success, frame = cap.read()
    # convert BGR to RGB
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame

MATCH_RESOLUTION = 500
camera_input = 'iphone_vertical' # ['train_cam', 'iphone','zed', 'iphone_vertical','zed_svo']
video_path = Path("motion_vids/boops_lift.MOV")
svo_path = Path("motion_vids/buddha_remove_good.svo2")
start_time = 0.3

cam_pose = pipeline.viewer_control.get_camera(200,None,0)
if cam_pose is None:
    cam_pose = torch.eye(4).float().cuda()
if camera_input == 'train_cam':
    init_cam,data = pipeline.datamanager.next_train(0)
    view_cam_pose = pipeline.viewer_control.get_camera(200,None,0)
    init_cam.camera_to_worlds = view_cam_pose.camera_to_worlds
    init_cam.rescale_output_resolution(MATCH_RESOLUTION/max(init_cam.width,init_cam.height))
elif camera_input == 'iphone':
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fx = 1137.0,fy = 1137.0,cx = 1280.0/2,cy = 720/2,width=1280,height=720)
    init_cam.rescale_output_resolution(MATCH_RESOLUTION/max(init_cam.width,init_cam.height))
elif camera_input == 'iphone_vertical':
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fy = 1137.0,fx = 1137.0,cy = 1280/2,cx = 720/2,height=1280,width=720)
    init_cam.rescale_output_resolution(MATCH_RESOLUTION/max(init_cam.width,init_cam.height))
elif camera_input in ['zed','zed_svo']:
    try:
        zed.cam.close()
        del zed
    except:
        pass
    finally:
        zed = Zed(recording_file=str(svo_path.absolute()) if camera_input == 'zed_svo' else None, start_time=start_time)
    fps = 30
    left_rgb,_,_ = zed.get_frame()
    K = zed.get_K()
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fx = K[0,0],fy = K[1,1],cx = K[0,2],cy = K[1,2],width=1920,height=1080)
    init_cam.rescale_output_resolution(MATCH_RESOLUTION/max(init_cam.width,init_cam.height))
outputs = pipeline.model.get_outputs_for_camera(init_cam)
if pipeline.cluster_labels is not None:
    labels = pipeline.cluster_labels.int().cuda()
    group_masks = [(cid == labels).cuda() for cid in range(labels.max() + 1)]
else:
    labels = torch.zeros(pipeline.model.num_points).int().cuda()
    group_masks = [torch.ones(pipeline.model.num_points).bool().cuda()]
optimizer = RigidGroupOptimizer(pipeline.model,dino_loader,init_cam,group_masks, group_labels = labels, render_lock = v.train_lock)
rgb_renders = [] 

AttributeError: 'NoneType' object has no attribute 'camera_to_worlds'

In [None]:
if camera_input in ['zed','zed_svo']:
    left_rgb, right_rgb,depth = zed.get_frame()
    target_frame_rgb = (left_rgb/255)
    right_frame_rgb = (right_rgb/255)
    optimizer.set_frame(target_frame_rgb,depth=depth)
else:
    assert video_path.exists()
    motion_clip = cv2.VideoCapture(str(video_path.absolute()))
    start=1
    end=9
    fps = 30
    frame = get_vid_frame(motion_clip,start)
    target_frame_rgb = ToTensor()(Image.fromarray(frame)).permute(1,2,0).cuda()
    optimizer.set_frame(target_frame_rgb)
_,axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(outputs["rgb"].detach().cpu().numpy())
axs[1].imshow(target_frame_rgb.cpu().numpy())

In [None]:
import moviepy.editor as mpy
xs,ys,outputs,best_pose,renders,best_pix = optimizer.initialize_obj_pose()
_,axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(outputs["rgb"].detach().cpu().numpy())
rescale = max(target_frame_rgb.shape[0],target_frame_rgb.shape[1])/MATCH_RESOLUTION
best_pix = best_pix*rescale
axs[1].scatter(xs.cpu().numpy()*rescale,ys.cpu().numpy()*rescale)
axs[1].imshow(target_frame_rgb.cpu().numpy())
axs[0].imshow(optimizer.rgb_frame.cpu().numpy(),alpha=.3)
renders = [r.detach().cpu().numpy()*255 for r in renders]
#save video as test_camopt.mp4
out_clip = mpy.ImageSequenceClip(renders, fps=30)  
out_clip.write_videofile("test_camopt.mp4")

In [None]:
from nerfstudio.utils.colormaps import apply_depth_colormap
import tqdm
import moviepy.editor as mpy
import plotly.express as px
def plotly_render(frame):
    fig = px.imshow(frame)
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),showlegend=False,yaxis_visible=False, yaxis_showticklabels=False,xaxis_visible=False, xaxis_showticklabels=False
    )
    return fig
fig = plotly_render(outputs['rgb'].detach().cpu().numpy())
try:
    frame_vis.remove()
except:
    pass
frame_vis = pipeline.viewer_control.viser_server.add_gui_plotly(fig, 9/16)
try:
    animate_button.remove()
    frame_slider.remove()
    reset_button.remove()
except:
    pass
def composite_vis_frame(target_frame_rgb,outputs):
    target_vis_frame = resize(target_frame_rgb.permute(2,0,1),(outputs["rgb"].shape[0],outputs["rgb"].shape[1])).permute(1,2,0)
    # composite the outputs['rgb'] on top of target_vis frame
    target_vis_frame = target_vis_frame*0.5 + outputs["rgb"]*0.5
    return target_vis_frame

try:
    render_button.remove()
    filename_input.remove()
    status_mkdown.remove()
except:
    pass
import viser
filename_input = v.viser_server.add_gui_text("File Name","render")
status_mkdown = v.viser_server.add_gui_markdown(" ")
render_button = v.viser_server.add_gui_button("Render Animation",color='green',icon=viser.Icon.MOVIE)
@render_button.on_click
def render(_):
    render_button.disabled = True
    render_frames = []
    camera = pipeline.viewer_control.get_camera(1080,1920,0)
    for i in tqdm.tqdm(range(len(optimizer.keyframes))):
        status_mkdown.content = f"Rendering...{i/len(optimizer.keyframes):.01f}"
        pipeline.model.eval()
        optimizer.apply_keyframe(i)
        with torch.no_grad():
            outputs = pipeline.model.get_outputs_for_camera(camera)
        render_frames.append(outputs["rgb"].detach().cpu().numpy()*255)
    status_mkdown.content = "Saving..."
    out_clip = mpy.ImageSequenceClip(render_frames, fps=fps)
    fname = filename_input.value
    (OUTPUT_FOLDER / 'posed_renders').mkdir(exist_ok=True)
    render_folder = OUTPUT_FOLDER / 'posed_renders'
    out_clip.write_videofile(f"{render_folder}/{fname}.mp4", fps=fps,codec='libx264')
    out_clip.write_videofile(f"{render_folder}/{fname}_mac.mp4", fps=fps,codec='mpeg4',bitrate='5000k')
    v.viser_server.send_file_download(f"{fname}_mac.mp4",open(f"{render_folder}/{fname}_mac.mp4",'rb').read())
    status_mkdown.content = "Done!"
    render_button.disabled = False


if camera_input in ['zed','zed_svo']:
    if len(rgb_renders)==0:
        for i in tqdm.tqdm(range(10)):
            target_vis_frame = composite_vis_frame(target_frame_rgb,outputs)
            vis_frame = torch.concatenate([outputs["rgb"],target_vis_frame],dim=1).detach().cpu().numpy()
            fig = plotly_render(target_vis_frame.detach().cpu().numpy())
            frame_vis.figure = fig
            rgb_renders.append(vis_frame*255)
            outputs = optimizer.step(50, use_depth=i>7, metric_depth=True)
    while True:
        # If input camera is the zed, just loop it indefinitely until no more frames
        left_rgb, _, depth = zed.get_frame()
        if left_rgb is None:
            break
        target_frame_rgb = left_rgb/255
        optimizer.set_frame(target_frame_rgb,depth=depth)
        outputs = optimizer.step(50, metric_depth=True)
        v._trigger_rerender()
        optimizer.register_keyframe()
        target_vis_frame = composite_vis_frame(target_frame_rgb,outputs)
        vis_frame = torch.concatenate([outputs["rgb"],target_vis_frame],dim=1).detach().cpu().numpy()
        rgb_renders.append(vis_frame*255)
        fig = plotly_render(target_vis_frame.detach().cpu().numpy())
        frame_vis.figure = fig
elif camera_input in ['iphone','iphone_vertical','train_cam']:
    # Otherwise procces the video
    if len(rgb_renders)==0:
        for i in tqdm.tqdm(range(10)):
            target_vis_frame = composite_vis_frame(target_frame_rgb,outputs)
            vis_frame = torch.concatenate([outputs["rgb"],target_vis_frame],dim=1).detach().cpu().numpy()
            fig = plotly_render(target_vis_frame.detach().cpu().numpy())
            frame_vis.figure = fig
            rgb_renders.append(vis_frame*255)
            outputs = optimizer.step(30, use_depth=i>7, metric_depth=False)

    for t in tqdm.tqdm(np.linspace(start,end,int((end-start)*fps))):
        frame = get_vid_frame(motion_clip,t)
        target_frame_rgb = ToTensor()(Image.fromarray(frame)).permute(1,2,0).cuda()
        optimizer.set_frame(target_frame_rgb)
        outputs = optimizer.step(50, metric_depth=False)
        optimizer.register_keyframe()
        v._trigger_rerender()
        target_vis_frame = composite_vis_frame(target_frame_rgb,outputs)
        vis_frame = torch.concatenate([outputs["rgb"],target_vis_frame],dim=1).detach().cpu().numpy()
        fig = plotly_render(target_vis_frame.detach().cpu().numpy())
        frame_vis.figure = fig
        rgb_renders.append(vis_frame*255)
#save as an mp4
out_clip = mpy.ImageSequenceClip(rgb_renders, fps=fps)  

fname = str(OUTPUT_FOLDER / "optimizer_out_antialias_withdepth_highatap.mp4")

out_clip.write_videofile(fname, fps=fps,codec='libx264')
out_clip.write_videofile(fname.replace('.mp4','_mac.mp4'),fps=fps,codec='mpeg4',bitrate='5000k')

# Populate some viewer elements to visualize the animation
animate_button = v.viser_server.add_gui_button("Play Animation")
frame_slider = v.viser_server.add_gui_slider("Frame",0,len(optimizer.keyframes)-1,1,0)
reset_button = v.viser_server.add_gui_button("Reset Transforms")

@animate_button.on_click
def play_animation(_):
    for i in range(len(optimizer.keyframes)):
        optimizer.apply_keyframe(i)
        v._trigger_rerender()
        time.sleep(1/fps)
@frame_slider.on_update
def apply_keyframe(_):
    optimizer.apply_keyframe(frame_slider.value)
    v._trigger_rerender()
@reset_button.on_click
def reset_transforms(_):
    optimizer.reset_transforms()
    v._trigger_rerender()