In [None]:
'''
This cell loads the model from the config file and initializes the viewer
'''
# %matplotlib widget
import torch
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from pathlib import Path
import numpy as np
from nerfstudio.viewer.viewer import Viewer
from nerfstudio.configs.base_config import ViewerConfig
import cv2
from torchvision.transforms import ToTensor
from PIL import Image
from typing import List,Optional,Literal
from nerfstudio.utils import writer
import time
from threading import Lock
import kornia
from lerf.dig import DiGModel
from lerf.data.utils.dino_dataloader import DinoDataloader
from nerfstudio.cameras.cameras import Cameras
from copy import deepcopy
from torchvision.transforms.functional import resize
from contextlib import nullcontext
from lerf.zed import Zed
from nerfstudio.engine.schedulers import ExponentialDecayScheduler,ExponentialDecaySchedulerConfig
import warp as wp
wp.init()

# config = Path("outputs/nerfgun2/dig/2024-05-03_161203/config.yml")
# config = Path("outputs/nerfgun3/dig/2024-05-03_170424/config.yml")
# config = Path("outputs/nerfgun4/dig/2024-05-07_130351/config.yml")
# config = Path("outputs/painter_sculpture/dig/2024-05-10_132522/config.yml")
# config = Path("outputs/painter_sculpture/dig/2024-05-16_233028/config.yml")#with ruilongs v2
# config = Path("outputs/buddha_balls_poly/dig/2024-05-09_123412/config.yml")
# config = Path("outputs/buddha_balls_poly/dig/2024-05-16_231213/config.yml")#with ruilongs v2
# config = Path("outputs/cal_bear/dig/2024-05-15_155531/config.yml")#this one groups table with bear for some reason
# config = Path("outputs/boops_mug/dig/2024-05-10_223745/config.yml")
# config = Path("outputs/bww_faucet/dig/2024-05-12_215440/config.yml")
# config = Path("outputs/cmk_tpose2/dig/2024-05-14_142439/config.yml")
# config = Path("outputs/cal_bear/dig/2024-05-17_142920/config.yml")#ruilong v2
# config = Path("outputs/mac_charger/dig/2024-05-17_145312/config.yml")
# config = Path("outputs/mac_charger2/dig/2024-05-17_152545/config.yml")
# config = Path("outputs/glue_gun/dig/2024-05-17_161408/config.yml")
# config = Path("outputs/buddha_balls_poly/dig/2024-05-19_122050/config.yml")# reuilong v2, 32-dim gauss
# config = Path("outputs/mac_charger/dig/2024-05-19_125443/config.yml")
config = Path("outputs/mac_charger2/dig/2024-05-19_132100/config.yml")
OUTPUT_FOLDER = Path("renders/mac_charger")

# Guard against writing the renders of one scene into another scene's folder.
assert OUTPUT_FOLDER.stem in str(config), "Output folder name does not match config name"
# parents=True so this also works when the parent `renders/` directory does not exist yet.
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
train_config,pipeline,_,_ = eval_setup(config)
dino_loader = pipeline.datamanager.dino_dataloader
train_config.logging.local_writer.enable = False
# We need to set up the writer to track number of rays, otherwise the viewer will not calculate the resolution correctly
writer.setup_local_writer(train_config.logging, max_iter=train_config.max_num_iterations)
# The Lock created here is reachable as v.train_lock and is later handed to the optimizer as its render lock.
v = Viewer(ViewerConfig(default_composite_depth=False,num_rays_per_chunk=-1),config.parent,pipeline.datamanager.get_datapath(),pipeline,train_lock=Lock())

In [None]:
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from typing import Union
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# Depth-estimation checkpoint shared by the image processor and the model used in get_depth.
_DEPTH_CHECKPOINT = "LiheYoung/depth-anything-small-hf"
da_model = AutoModelForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT)
da_image_processor = AutoImageProcessor.from_pretrained(_DEPTH_CHECKPOINT)
# Inference below feeds CUDA tensors, so the weights have to live on the GPU.
da_model.to('cuda')
def get_depth(img: Union[torch.Tensor,np.ndarray]):
    """
    Runs the `da_model` depth network on a single RGB image.

    img: HxWx3 image as a torch.Tensor or np.ndarray. It is handed to PIL's Image.fromarray,
        so it should hold uint8 values (the caller in this file converts to uint8 first).

    returns: the network prediction, bicubically resized back to HxW, as a tensor on the model's device.
        NOTE(review): the optimizer in this file treats this output as disparity, not metric depth.
    """
    assert img.ndim == 3 and img.shape[2] == 3, "expected an HxWx3 image"
    if isinstance(img,torch.Tensor):
        img = img.cpu().numpy()
    image = Image.fromarray(img)

    # prepare image for the model, placing the pixels on whatever device the model lives on
    inputs = da_image_processor(images=image, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].to(da_model.device)

    with torch.no_grad():
        outputs = da_model(**inputs)
        predicted_depth = outputs.predicted_depth

    # interpolate to original size (PIL reports (W, H), interpolate wants (H, W))
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    return prediction.squeeze()
# Segmentation checkpoint shared by the image processor and the model used in get_hand_mask.
_SEGMENTATION_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"
hand_model = Mask2FormerForUniversalSegmentation.from_pretrained(_SEGMENTATION_CHECKPOINT)
hand_processor = AutoImageProcessor.from_pretrained(_SEGMENTATION_CHECKPOINT)
# Inference below feeds CUDA tensors, so the weights have to live on the GPU.
hand_model.to('cuda')
def get_hand_mask(img: Union[torch.Tensor,np.ndarray]):
    """
    Segments an RGB image with `hand_model` and returns the pixels labelled 'person'.

    img: HxWx3 image as a torch.Tensor or np.ndarray. It is handed to PIL's Image.fromarray,
        so it should hold uint8 values (the caller in this file converts to uint8 first).

    returns: HxW float tensor that is 1.0 where the predicted class is 'person' and 0.0 elsewhere.
        Note that despite the name, the whole 'person' class is selected, not just hands.
    """
    assert img.ndim == 3 and img.shape[2] == 3, "expected an HxWx3 image"
    if isinstance(img,torch.Tensor):
        img = img.cpu().numpy()
    image = Image.fromarray(img)

    # prepare image for the model, placing the pixels on whatever device the model lives on
    inputs = hand_processor(images=image, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].to(hand_model.device)

    with torch.no_grad():
        outputs = hand_model(**inputs)

    # Perform post-processing to get a per-pixel semantic class map at the original resolution
    # (PIL reports (W, H), target_sizes wants (H, W))
    seg_ids = hand_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    hand_mask = (seg_ids == hand_model.config.label2id['person']).float()
    return hand_mask

In [None]:
"""
This cell defines a simple pose optimizer for learning a rigid transform offset given a gaussian model, star pose, and starting view
"""

def get_vid_frame(cap,timestamp):
    """
    Reads the video frame at a given timestamp and returns it in RGB channel order.

    cap: an opened cv2.VideoCapture
    timestamp: time in seconds from the start of the video; the resulting frame index is
        clamped to [0, frame_count - 1]

    returns: the decoded frame converted from BGR to RGB
    raises: RuntimeError if the frame could not be read
    """
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Calculate the frame number based on the timestamp and fps, clamped to the valid range
    last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
    frame_number = max(0, min(int(timestamp * fps), last_frame))

    # Set the video position to the calculated frame number
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)

    # Read the frame; fail loudly here instead of passing None on to cvtColor
    success, frame = cap.read()
    if not success or frame is None:
        raise RuntimeError(f"Could not read frame {frame_number} (timestamp {timestamp}) from video")
    # convert BGR to RGB
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
def quatmul(q0:torch.Tensor,q1:torch.Tensor):
    """
    Hamilton product q0 * q1 of quaternions stored as (w, x, y, z) along the last dimension.
    Leading dimensions are carried through elementwise.
    """
    w0, x0, y0, z0 = q0.unbind(dim=-1)
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w_out = -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1
    x_out = x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1
    y_out = -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1
    z_out = x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1
    return torch.stack([w_out, x_out, y_out, z_out], dim=-1)

def depth_ranking_loss(rendered_depth, gt_depth):
    """
    Depth ranking loss as described in the SparseNeRF paper
    Assumes that the layout of the batch comes from a PairPixelSampler, so that adjacent samples in the gt_depth
    and rendered_depth are from pixels with a radius of each other

    rendered_depth, gt_depth: (N, C) tensors; rows 2i and 2i+1 form one pair (a trailing odd row is dropped)

    returns: scalar tensor. Pairs whose rendered ordering disagrees with the gt ordering are penalized by
        the magnitude of the rendered difference; violations at or above the 80th percentile are dropped
        before averaging. Returns 0 when no pair violates the ordering, and falls back to the untrimmed
        mean when the trimming would leave nothing (e.g. a single violation), so the result is never
        NaN because of an empty selection.
    """
    m = 1e-4
    if rendered_depth.shape[0] % 2 != 0:
        # chop off one index
        rendered_depth = rendered_depth[:-1, :]
        gt_depth = gt_depth[:-1, :]
    dpt_diff = gt_depth[::2, :] - gt_depth[1::2, :]
    out_diff = rendered_depth[::2, :] - rendered_depth[1::2, :] + m
    differing_signs = torch.sign(dpt_diff) != torch.sign(out_diff)
    loss = out_diff[differing_signs].abs()
    if loss.numel() == 0:
        # Nothing to penalize; quantile() would raise on an empty tensor.
        return out_diff.new_zeros(())
    cutoff = loss.quantile(.8)
    kept = loss[loss < cutoff]
    if kept.numel() == 0:
        # The strict '<' can discard every violation (one violation, or all equal); avoid a NaN mean.
        return loss.mean()
    return kept.mean()

@wp.kernel
def apply_to_model(pose_deltas: wp.array(dtype = float, ndim = 2), means: wp.array(dtype = wp.vec3), quats: wp.array(dtype = float,ndim=2),
                    group_labels: wp.array(dtype = int), centroids: wp.array(dtype = wp.vec3),
                    means_out: wp.array(dtype = wp.vec3), quats_out: wp.array(dtype = float,ndim=2)):
    """
    Takes the current pose_deltas and applies them to each of the group masks

    One thread per gaussian (tid indexes means/quats/group_labels/centroids and both outputs).
    pose_deltas: one row per group laid out as [tx, ty, tz, qw, qx, qy, qz]
    means, quats: input gaussian positions and orientations (quats stored w-first)
    group_labels: group index of each gaussian, used to pick its row of pose_deltas
    centroids: per-gaussian pivot point that the transform is applied around
    means_out, quats_out: transformed positions and orientations (quats stored w-first)
    """
    tid = wp.tid()
    group_id = group_labels[tid]
    position = wp.vector(pose_deltas[group_id,0],pose_deltas[group_id,1],pose_deltas[group_id,2])
    #pose_deltas are in w x y z, we need to flip
    # (the w component, column 3, is passed last to wp.quaternion)
    quaternion = wp.quaternion(pose_deltas[group_id,4],pose_deltas[group_id,5],pose_deltas[group_id,6],pose_deltas[group_id,3])
    transform = wp.transformation(position,quaternion)
    # Shift so the centroid is the origin, apply the transform, then shift back
    means_out[tid] = wp.transform_point(transform,means[tid] - centroids[tid]) + centroids[tid]
    # Same reordering for the gaussian's own quaternion: stored w-first, w passed last
    gauss_quaternion = wp.quaternion(quats[tid,1],quats[tid,2],quats[tid,3],quats[tid,0])
    newquat = quaternion*gauss_quaternion
    # Write back in the w-first storage order: component 3 of newquat goes to column 0
    quats_out[tid,0] = newquat[3]
    quats_out[tid,1] = newquat[0]
    quats_out[tid,2] = newquat[1]
    quats_out[tid,3] = newquat[2]
    
@wp.kernel
def atap_loss(cur_means: wp.array(dtype = wp.vec3), dists: wp.array(dtype = float), ids: wp.array(dtype = int),
               match_ids: wp.array(dtype = int), group_ids1: wp.array(dtype = int), group_ids2: wp.array(dtype=int), 
               connectivity_weights: wp.array(dtype = float,ndim = 2), loss: wp.array(dtype = float)):
    """
    Per-pair distance-preservation penalty; one thread per (ids[tid], match_ids[tid]) pair.

    cur_means: current gaussian positions, indexed by the ids stored in ids/match_ids
    dists: reference distance for each pair
    ids, match_ids: the two gaussian indices of each pair
    group_ids1, group_ids2: group index of each pair's two gaussians, used to look up the pair's weight
    connectivity_weights: 2D weight table indexed [group_ids1, group_ids2]
    loss: output, |current distance - reference distance| * weight for each pair
    """
    tid = wp.tid()
    id1 = ids[tid]
    id2 = match_ids[tid]
    gid1 = group_ids1[tid]
    gid2 = group_ids2[tid]
    con_weight = connectivity_weights[gid1,gid2]
    curdist = wp.length(cur_means[id1] - cur_means[id2])
    loss[tid] = wp.abs(curdist - dists[tid]) * con_weight

class ATAPLoss:
    # Penalizes changes in distance between nearby gaussians that belong to different groups.
    # The neighbor pairs and their reference distances are computed once in __init__;
    # __call__ evaluates the penalty against the model's current means.

    # radius passed to _radius_nn: only neighbors closer than this are kept as pairs
    touch_radius: float = .01
    # number of nearest neighbors queried per gaussian
    N: int = 10
    # scalar multiplier applied to the final summed loss
    loss_mult: float = .05
    def __init__(self, dig_model: DiGModel, group_masks: List[torch.Tensor], group_labels: torch.Tensor):
        """
        Initializes the data structure to compute the loss between groups touching

        dig_model: model whose gaussian means are queried for neighbors here and read again in __call__
        group_masks: one boolean mask over the gaussians per group
        group_labels: stored on the instance; not otherwise used by this class
        """
        self.dig_model = dig_model
        self.group_masks = group_masks
        self.group_labels = group_labels
        # one (dists, ids, match_ids, group ids of ids, group ids of match_ids) tuple per group
        self.nn_info = []
        for grp in self.group_masks:
            with torch.no_grad():
                dists, ids, match_ids, group_ids1, group_ids = self._radius_nn(grp, self.touch_radius)
                self.nn_info.append((dists, ids, match_ids, group_ids1, group_ids))
                print(f"Group {len(self.nn_info)} has {len(ids)} neighbors")
        # Flatten the per-group results into single arrays with one entry per pair
        self.dists = torch.cat([x[0] for x in self.nn_info]).cuda()
        self.ids = torch.cat([x[1] for x in self.nn_info]).cuda().int()
        self.match_ids = torch.cat([x[2] for x in self.nn_info]).cuda().int()
        self.group_ids1 = torch.cat([x[3] for x in self.nn_info]).cuda().int()
        self.group_ids2 = torch.cat([x[4] for x in self.nn_info]).cuda().int()
        # For every pair, the number of pairs found for its source group; __call__ divides by this
        # so each group's pairs are averaged rather than summed
        self.num_pairs = torch.cat([torch.tensor(len(x[1])).repeat(len(x[1])) for x in self.nn_info]).cuda().float()
        

    def __call__(self, connectivity_weights: torch.Tensor):
        """
        Computes the loss between groups touching
        connectivity_weights: a tensor of shape (num_groups,num_groups) representing the weights between each group

        returns: a differentiable loss

        NOTE(review): the kernel launch happens in warp, so gradients with respect to the means
        presumably rely on the caller recording this call on a wp.Tape and running tape.backward()
        (as RigidGroupOptimizer.step does) - confirm.
        """
        assert connectivity_weights.shape == (len(self.group_masks),len(self.group_masks)), "connectivity weights must be a square matrix of size num_groups"
        # per-pair output buffer filled by the atap_loss kernel
        loss = wp.empty(self.dists.shape[0], dtype=wp.float32, requires_grad=True)
        wp.launch(
            dim = self.dists.shape[0],
            kernel = atap_loss,
            inputs = [wp.from_torch(self.dig_model.gauss_params['means'],dtype=wp.vec3),wp.from_torch(self.dists),
                      wp.from_torch(self.ids),wp.from_torch(self.match_ids),wp.from_torch(self.group_ids1),
                      wp.from_torch(self.group_ids2),wp.from_torch(connectivity_weights),loss]
        )
        # Normalize each pair by its group's pair count, sum over all pairs, then scale
        return (wp.to_torch(loss)/self.num_pairs).sum()*self.loss_mult
        

    def _radius_nn(self, group_mask: torch.Tensor, r: float):
        """
        Finds, for the gaussians in group_mask, their nearest neighbors that lie within radius r
        and belong to a different group. Pairs failing either test are removed from the output.

        returns: flattened (dists, ids, match_ids, group index of each id, group index of each match_id),
            one entry per surviving pair

        NOTE(review): the neighbor model is fit on all means and does not depend on group_mask,
        yet it is rebuilt on every call (once per group from __init__).
        """
        # Label every gaussian with the index of the group mask that contains it
        global_group_ids = torch.zeros(self.dig_model.num_points,dtype=torch.long,device='cuda')
        for i,grp in enumerate(self.group_masks):
            global_group_ids[grp] = i
        from cuml.neighbors import NearestNeighbors
        model = NearestNeighbors(n_neighbors=self.N)
        means = self.dig_model.means.detach().cpu().numpy()
        model.fit(means)
        dists, match_ids = model.kneighbors(means)
        dists, match_ids = torch.tensor(dists,dtype=torch.float32,device='cuda'),torch.tensor(match_ids,dtype=torch.long,device='cuda')
        # Keep only the rows for gaussians in this group
        dists, match_ids = dists[group_mask], match_ids[group_mask]
        # filter matches outside the radius
        match_ids[dists>r] = -1
        # filter out ones within same group mask
        # (entries already set to -1 index the last element of group_mask here; they stay -1 either way)
        match_ids[group_mask[match_ids]] = -1
        # Source index of each candidate pair, repeated N times to line up with match_ids
        ids = torch.arange(self.dig_model.num_points,dtype=torch.long,device='cuda')[group_mask].unsqueeze(-1).repeat(1,self.N)
        #flatten all the ids/dists/match_ids
        ids = ids[match_ids!=-1].flatten()
        dists = dists[match_ids!=-1].flatten()
        match_ids = match_ids[match_ids!=-1].flatten()
        return dists, ids, match_ids, global_group_ids[ids], global_group_ids[match_ids]

# Best-effort removal of a `loss_plt` handle left over from an earlier run of this cell.
try:
    loss_plt.remove()
except Exception:
    # On a fresh kernel loss_plt does not exist (NameError); either way there is nothing to clean up.
    # Catching Exception rather than using a bare except keeps KeyboardInterrupt/SystemExit working.
    pass
class RigidGroupOptimizer:
    # Tracks a set of rigid gaussian groups against incoming camera frames by optimizing one
    # translation + quaternion offset per group so the rendered features match the frame.

    # master switch for the depth term; step() additionally has a per-call use_depth flag
    use_depth: bool = True
    depth_ignore_threshold: float = 0.02 # in meters
    # whether to add the ATAPLoss distance-preservation term between touching groups
    use_atap: bool = True
    # learning rate at the start of every step() call, and the value it decays to by the end
    pose_lr: float = .005
    pose_lr_final: float = .0005
    # when True, pixels segmented as 'person' are excluded from the losses
    mask_hands: bool = False
    def __init__(self, dig_model: DiGModel, dino_loader: DinoDataloader, init_c2o: Cameras, group_masks: List[torch.Tensor], group_labels: torch.Tensor, render_lock = nullcontext()):
        """
        This one takes in a list of gaussian ID masks to optimize local poses for
        Each rigid group can be optimized independently, with no skeletal constraints

        dig_model: gaussian model to pose; its 'means' and 'quats' params are detached and overwritten
        dino_loader: used in set_frame to compute PCA features of the target frame
        init_c2o: camera the model is rendered from during optimization (deep-copied to cuda)
        group_masks: one boolean mask over the gaussians per rigid group
        group_labels: per-gaussian group index, used by the warp kernel to pick each gaussian's pose
        render_lock: context manager held while the model is modified and rendered
        """
        self.tape = None
        self.dig_model = dig_model
        #detach all the params to avoid retain_graph issue
        self.dig_model.gauss_params['means'] = self.dig_model.gauss_params['means'].detach()
        self.dig_model.gauss_params['quats'] = self.dig_model.gauss_params['quats'].detach()
        self.dino_loader = dino_loader
        self.group_labels = group_labels
        self.group_masks = group_masks
        self.init_c2o = deepcopy(init_c2o).to('cuda')
        #store a 7-vec of trans, rotation for each group
        # layout is [tx, ty, tz, qw, qx, qy, qz], initialized to zero translation and identity rotation
        self.pose_deltas = torch.zeros(len(group_masks),7,dtype=torch.float32,device='cuda')
        self.pose_deltas[:,3:] = torch.tensor([1,0,0,0],dtype=torch.float32,device='cuda')
        self.pose_deltas = torch.nn.Parameter(self.pose_deltas)
        # 3x3 gaussian blur applied to the rendered features before comparison (sigma is 0.8 for k = 3)
        k = 3
        s = 0.3 * ((k - 1) * 0.5 - 1) + 0.8
        self.blur = kornia.filters.GaussianBlur2d((k, k), (s, s))
        #NOT USED RN
        # (only its shape is read in step(), to build a constant all-ones weight matrix)
        self.connectivity_weights = torch.nn.Parameter(-torch.ones(len(group_masks),len(group_masks),dtype=torch.float32,device='cuda'))
        self.optimizer = torch.optim.Adam([self.pose_deltas],lr=self.pose_lr)
        # self.weights_optimizer = torch.optim.Adam([self.connectivity_weights],lr=.001)
        # Rest pose that every pose delta is applied on top of
        self.init_means = dig_model.gauss_params['means'].detach().clone()
        self.init_quats = dig_model.gauss_params['quats'].detach().clone()
        self.keyframes = []
        # lock to prevent blocking the render thread if provided
        self.render_lock = render_lock
        if self.use_atap:
            self.atap = ATAPLoss(dig_model,group_masks,group_labels)
        # Per-gaussian pivot: the mean position of the group the gaussian belongs to.
        # Zero-initialized so gaussians not covered by any mask get a defined pivot instead of
        # uninitialized memory.
        self.centroids = torch.zeros((self.dig_model.num_points,3),dtype=torch.float32,device='cuda',requires_grad=False)
        with torch.no_grad():
            for mask in self.group_masks:
                self.centroids[mask] = self.dig_model.gauss_params['means'][mask].mean(dim=0)

    def step(self, niter = 1, use_depth = True, use_rgb = False, metric_depth = False):
        """
        Runs niter optimization iterations of the pose deltas against the frame given to set_frame.

        niter: number of iterations (must be >= 1); the learning rate decays from pose_lr to
            pose_lr_final over these iterations and is reset to pose_lr afterwards
        use_depth: adds the depth term (only if the class-level use_depth is also True)
        use_rgb: adds an L1 term between the rendered rgb and the target frame
        metric_depth: if True the frame depth is compared directly (after dividing the render by the
            dataparser scale); otherwise it is treated as disparity and only a ranking loss is used

        returns: the model outputs of the last iteration
        raises: ValueError if niter < 1; RuntimeError("Lost tracking") if the render has no 'dino' output
        """
        if niter < 1:
            # Without at least one iteration there would be no outputs to return.
            raise ValueError(f"niter must be at least 1, got {niter}")
        scheduler = ExponentialDecayScheduler(ExponentialDecaySchedulerConfig(lr_final = self.pose_lr_final, max_steps=niter)).get_scheduler(self.optimizer, self.pose_lr)
        for _ in range(niter):
            # renormalize rotation representation
            with torch.no_grad():
                self.pose_deltas[:,3:] = self.pose_deltas[:,3:]/self.pose_deltas[:,3:].norm(dim=1,keepdim=True)
            # The warp launches are recorded on this tape; tape.backward() below replays them in reverse
            tape = wp.Tape()
            self.optimizer.zero_grad()
            # self.weights_optimizer.zero_grad()
            with self.render_lock:
                self.dig_model.eval()
                with tape:
                    self.apply_to_model(self.pose_deltas)
                dig_outputs = self.dig_model.get_outputs(self.init_c2o)
            if 'dino' not in dig_outputs:
                self.reset_transforms()
                raise RuntimeError("Lost tracking")
            with torch.no_grad():
                object_mask = dig_outputs['accumulation']>.9
            dino_feats = self.blur(dig_outputs["dino"].permute(2,0,1)[None]).squeeze().permute(1,2,0)
            if self.mask_hands:
                pix_loss = (self.frame_pca_feats - dino_feats)[self.hand_mask]
            else:
                pix_loss = (self.frame_pca_feats - dino_feats)
            # THIS IS BAD WE NEED TO FIX THIS (because resizing makes the image very slightly misaligned)
            loss = pix_loss.norm(dim=-1).mean()
            if use_depth and self.use_depth:
                if metric_depth:
                    # NOTE(review): this reads the notebook-global `pipeline` rather than instance state
                    physical_depth = dig_outputs['depth']/pipeline.datamanager.train_dataset._dataparser_outputs.dataparser_scale
                    valids = object_mask & (~self.frame_depth.isnan())
                    if self.mask_hands:
                        valids = valids & self.hand_mask.unsqueeze(-1)
                    pix_loss = (physical_depth - self.frame_depth)**2
                    # Ignore pixels whose depth error exceeds the threshold
                    pix_loss = pix_loss[valids & (pix_loss<self.depth_ignore_threshold**2)]
                    # Skip the term when no pixel qualifies; the mean of an empty tensor is NaN
                    if pix_loss.numel() > 0:
                        loss = loss + 0.1*pix_loss.mean()
                else:
                    # This is ranking loss for monodepth (which is disparity)
                    disparity = 1.0 / dig_outputs['depth']
                    N = 20000
                    if self.mask_hands:
                        object_mask = object_mask & self.hand_mask.unsqueeze(-1)
                    valid_ids = torch.where(object_mask)
                    # Skip the term when the mask is empty; randint(0, 0, ...) would raise
                    if valid_ids[0].shape[0] > 0:
                        rand_samples = torch.randint(0,valid_ids[0].shape[0],(N,),device='cuda')
                        rand_samples = (valid_ids[0][rand_samples],valid_ids[1][rand_samples])
                        rend_samples = disparity[rand_samples]
                        mono_samples = self.frame_depth[rand_samples]
                        rank_loss = depth_ranking_loss(rend_samples,mono_samples)
                        loss = loss + 0.5*rank_loss
            if use_rgb:
                loss = loss + .05*(dig_outputs['rgb']-self.rgb_frame).abs().mean()
            if self.use_atap:
                # The learned connectivity weights are disabled; every group pair gets weight 1
                null_weights = torch.ones_like(self.connectivity_weights)
                # null_weights = self.connectivity_weights.exp()
                weights = torch.clip(null_weights,0,1)
                with tape:
                    # named atap_term so the module-level `atap_loss` warp kernel is not shadowed
                    atap_term = self.atap(weights)
                rigidity_loss = .02*(1-weights).mean()
                symmetric_loss = (weights - weights.T).abs().mean()
                #maximize the connectivity weights, as well as similarity
                loss = loss + atap_term + symmetric_loss + rigidity_loss
            # torch backward first, then the warp tape, before the optimizer consumes the gradients
            loss.backward()
            tape.backward()
            self.optimizer.step()
            # self.weights_optimizer.step()
            scheduler.step()
        #reset lr
        self.optimizer.param_groups[0]['lr'] = self.pose_lr
        return dig_outputs
    
    def apply_to_model(self,pose_deltas):
        """
        Takes the current pose_deltas and applies them to each of the group masks

        The model is first reset to its initial means/quats, then the apply_to_model warp kernel
        writes the transformed values into fresh tensors which replace the model's params.
        pose_deltas: (num_groups, 7) tensor laid out as [tx, ty, tz, qw, qx, qy, qz]
        """
        self.reset_transforms()
        new_quats = torch.empty_like(self.dig_model.gauss_params['quats'],requires_grad=False)
        new_means = torch.empty_like(self.dig_model.gauss_params['means'],requires_grad=True)
        wp.launch(
            kernel = apply_to_model,
            dim = self.dig_model.num_points,
            inputs = [wp.from_torch(pose_deltas),wp.from_torch(self.dig_model.gauss_params['means'],dtype=wp.vec3),
                    wp.from_torch(self.dig_model.gauss_params['quats']),wp.from_torch(self.group_labels),
                    wp.from_torch(self.centroids,dtype=wp.vec3)],
            outputs = [wp.from_torch(new_means,dtype=wp.vec3),wp.from_torch(new_quats)]
        )
        self.dig_model.gauss_params['quats'] = new_quats
        self.dig_model.gauss_params['means'] = new_means


    def register_keyframe(self):
        """
        Saves the current pose_deltas as a keyframe
        """
        self.keyframes.append(self.pose_deltas.detach().clone())

    def apply_keyframe(self,i):
        """
        Applies the ith keyframe to the pose_deltas
        """
        with torch.no_grad():
            self.apply_to_model(self.keyframes[i])

    def reset_transforms(self):
        """
        Restores the model's means and quats to copies of the values captured at construction
        """
        with torch.no_grad():
            self.dig_model.gauss_params['means'] = self.init_means.clone()
            self.dig_model.gauss_params['quats'] = self.init_quats.clone()

    def set_frame(self, rgb_frame: torch.Tensor, depth: torch.Tensor = None):
        """
        Sets the rgb_frame to optimize the pose for
        rgb_frame: HxWxC tensor image
        depth: optional HxW depth for the frame; when omitted (and use_depth is on) it is predicted
            from the resized rgb frame with get_depth

        Everything stored here is resized to the resolution of the init_c2o camera.
        """
        with torch.no_grad():
            self.rgb_frame = resize(rgb_frame.permute(2,0,1), (self.init_c2o.height,self.init_c2o.width),antialias = True).permute(1,2,0)
            # Features are computed on the full-resolution frame and resized afterwards
            self.frame_pca_feats = self.dino_loader.get_pca_feats(rgb_frame.permute(2,0,1).unsqueeze(0),keep_cuda=True).squeeze()
            self.frame_pca_feats = resize(self.frame_pca_feats.permute(2,0,1), (self.init_c2o.height,self.init_c2o.width),antialias = True).permute(1,2,0)
            if self.use_depth:
                if depth is None:
                    depth = get_depth((self.rgb_frame*255).to(torch.uint8))
                self.frame_depth = resize(depth.unsqueeze(0), (self.init_c2o.height,self.init_c2o.width),antialias = True).squeeze().unsqueeze(-1)
            if self.mask_hands:
                self.hand_mask = get_hand_mask((self.rgb_frame*255).to(torch.uint8))
                # Grow the person mask by one pixel (3x3 max-pool), then invert it: True marks pixels to keep
                self.hand_mask = torch.nn.functional.max_pool2d(self.hand_mask[None,None],3,padding=1,stride=1).squeeze() == 0.0


# Longest side, in pixels, of the camera the optimizer renders and matches at
MATCH_RESOLUTION = 500
camera_input = 'iphone' # ['train_cam', 'iphone','zed', 'iphone_vertical','zed_svo']
video_path = Path("motion_vids/mac_charger_fold3.MOV")
svo_path = Path("motion_vids/buddha_close_good.svo2")
start_time = 0.3


# Build the initial camera. Every branch takes its pose from the current viewer camera and
# differs only in the intrinsics / image size it assumes.
if camera_input == 'train_cam':
    init_cam,data = pipeline.datamanager.next_train(0)
    view_cam_pose = pipeline.viewer_control.get_camera(200,None,0)
    init_cam.camera_to_worlds = view_cam_pose.camera_to_worlds
elif camera_input == 'iphone':
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fx = 1137.0,fy = 1137.0,cx = 1280.0/2,cy = 720/2,width=1280,height=720)
elif camera_input == 'iphone_vertical':
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fy = 1137.0,fx = 1137.0,cy = 1280/2,cx = 720/2,height=1280,width=720)
elif camera_input in ['zed','zed_svo']:
    # Close a camera left open by a previous run of this cell before opening a new one
    try:
        zed.cam.close()
        del zed
    except Exception:
        # No previous camera (NameError on a fresh kernel) or it was already closed
        pass
    zed = Zed(recording_file=str(svo_path.absolute()) if camera_input == 'zed_svo' else None, start_time=start_time)
    fps = 30
    left_rgb,_,_ = zed.get_frame()
    K = zed.get_K()
    init_cam = Cameras(camera_to_worlds=pipeline.viewer_control.get_camera(200,None,0).camera_to_worlds,fx = K[0,0],fy = K[1,1],cx = K[0,2],cy = K[1,2],width=1920,height=1080)
else:
    # Fail here with a clear message instead of a NameError on init_cam further down
    raise ValueError(f"Unknown camera_input: {camera_input!r}")
# Scale the camera so its longest side is MATCH_RESOLUTION pixels
init_cam.rescale_output_resolution(MATCH_RESOLUTION/max(init_cam.width,init_cam.height))
outputs = pipeline.model.get_outputs_for_camera(init_cam)
if pipeline.cluster_labels is not None:
    # One boolean mask per cluster id
    labels = pipeline.cluster_labels.int().cuda()
    group_masks = [(cid == labels).cuda() for cid in range(int(labels.max()) + 1)]
else:
    # No clustering available: treat the whole model as a single rigid group
    labels = torch.zeros(pipeline.model.num_points).int().cuda()
    group_masks = [torch.ones(pipeline.model.num_points).bool().cuda()]
optimizer = RigidGroupOptimizer(pipeline.model,dino_loader,init_cam,group_masks, group_labels = labels, render_lock = v.train_lock)
rgb_renders = [] 

In [None]:
# Load the first target frame and hand it to the optimizer
if camera_input in ['zed','zed_svo']:
    left_rgb, right_rgb,depth = zed.get_frame()
    target_frame_rgb = (left_rgb/255)
    right_frame_rgb = (right_rgb/255)
    optimizer.set_frame(target_frame_rgb,depth=depth)
else:
    assert video_path.exists(), f"Video file not found: {video_path}"
    motion_clip = cv2.VideoCapture(str(video_path.absolute()))
    # start/end are timestamps in seconds into the clip; fps is the rate frames are sampled at later
    start=1
    end=4
    fps = 30
    frame = get_vid_frame(motion_clip,start)
    target_frame_rgb = ToTensor()(Image.fromarray(frame)).permute(1,2,0).cuda()
    optimizer.set_frame(target_frame_rgb)
# Side-by-side sanity check of the initial render against the first target frame
_,axs = plt.subplots(1,2,figsize=(10,4))
axs[0].imshow(outputs["rgb"].detach().cpu().numpy())
axs[0].set_title("Initial render")
axs[1].imshow(target_frame_rgb.cpu().numpy())
axs[1].set_title("Target frame");

In [None]:
from nerfstudio.utils.colormaps import apply_depth_colormap
import tqdm
import moviepy.editor as mpy
import plotly.express as px
def plotly_render(frame):
    """
    Wraps an image array in a plotly figure with zero margins and the legend and both axes hidden.
    """
    layout_options = dict(
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
        yaxis_visible=False,
        yaxis_showticklabels=False,
        xaxis_visible=False,
        xaxis_showticklabels=False,
    )
    figure = px.imshow(frame)
    figure.update_layout(**layout_options)
    return figure
fig = plotly_render(outputs['rgb'].detach().cpu().numpy())
# Remove the plot left over from a previous run of this cell, if any.
try:
    frame_vis.remove()
except Exception:
    # Not defined on a fresh kernel (NameError), or already removed
    pass
frame_vis = pipeline.viewer_control.viser_server.add_gui_plotly(fig, 9/16)
# Remove the animation controls from a previous run. Each one is handled on its own so that a
# missing or already-removed widget does not stop the others from being cleaned up.
for _stale_name in ("animate_button", "frame_slider", "reset_button"):
    try:
        globals()[_stale_name].remove()
    except Exception:
        pass
def composite_vis_frame(target_frame_rgb,outputs):
    """
    Blends the target frame 50/50 with the rendered rgb for visualization.

    target_frame_rgb: HxWxC image; resized to the render's height and width before blending
    outputs: model outputs dict; only outputs["rgb"] is read
    """
    rendered_rgb = outputs["rgb"]
    out_height, out_width = rendered_rgb.shape[0], rendered_rgb.shape[1]
    # resize works channels-first, so move channels to the front and back again
    target_chw = target_frame_rgb.permute(2,0,1)
    target_resized = resize(target_chw,(out_height,out_width)).permute(1,2,0)
    # composite the outputs['rgb'] on top of the resized target frame
    return target_resized*0.5 + rendered_rgb*0.5

# Remove the render controls from a previous run of this cell. Each one is handled on its own so
# that a missing or already-removed widget does not stop the others from being cleaned up.
for _stale_name in ("render_button", "filename_input", "status_mkdown"):
    try:
        globals()[_stale_name].remove()
    except Exception:
        pass
import viser
filename_input = v.viser_server.add_gui_text("File Name","render")
status_mkdown = v.viser_server.add_gui_markdown(" ")
render_button = v.viser_server.add_gui_button("Render Animation",color='green',icon=viser.Icon.MOVIE)
@render_button.on_click
def render(_):
    """
    Button callback: renders every stored keyframe from the current viewer camera and saves the
    result as two videos under OUTPUT_FOLDER/posed_renders, named after the "File Name" text box.
    The '_mac' copy is also sent to the viewer as a download.

    The button is disabled while rendering and is re-enabled even if rendering fails.
    """
    render_button.disabled = True
    try:
        num_keyframes = len(optimizer.keyframes)
        if num_keyframes == 0:
            # Nothing has been tracked yet; an empty clip cannot be written
            status_mkdown.content = "No keyframes to render"
            return
        render_frames = []
        camera = pipeline.viewer_control.get_camera(1080,1920,0)
        for i in tqdm.tqdm(range(num_keyframes)):
            status_mkdown.content = f"Rendering...{i/num_keyframes:.01f}"
            pipeline.model.eval()
            optimizer.apply_keyframe(i)
            with torch.no_grad():
                outputs = pipeline.model.get_outputs_for_camera(camera)
            render_frames.append(outputs["rgb"].detach().cpu().numpy()*255)
        status_mkdown.content = "Saving..."
        out_clip = mpy.ImageSequenceClip(render_frames, fps=fps)
        fname = filename_input.value
        render_folder = OUTPUT_FOLDER / 'posed_renders'
        render_folder.mkdir(parents=True, exist_ok=True)
        mac_path = render_folder / f"{fname}_mac.mp4"
        out_clip.write_videofile(f"{render_folder}/{fname}.mp4", fps=fps,codec='libx264')
        out_clip.write_videofile(str(mac_path), fps=fps,codec='mpeg4',bitrate='5000k')
        # read_bytes opens and closes the file, unlike a bare open(...).read()
        v.viser_server.send_file_download(f"{fname}_mac.mp4",mac_path.read_bytes())
        status_mkdown.content = "Done!"
    except Exception as err:
        status_mkdown.content = f"Render failed: {err}"
        raise
    finally:
        render_button.disabled = False


def _record_vis_frame(target_frame_rgb, outputs):
    """
    Records one visualization frame: appends [render | render blended with target] to rgb_renders
    and shows the blended image in the viewer plot.
    """
    target_vis_frame = composite_vis_frame(target_frame_rgb,outputs)
    vis_frame = torch.concatenate([outputs["rgb"],target_vis_frame],dim=1).detach().cpu().numpy()
    rgb_renders.append(vis_frame*255)
    frame_vis.figure = plotly_render(target_vis_frame.detach().cpu().numpy())


if camera_input in ['zed','zed_svo']:
    if len(rgb_renders)==0:
        # Initial alignment on the first frame; depth is only switched on for the last two rounds
        for i in tqdm.tqdm(range(10)):
            _record_vis_frame(target_frame_rgb, outputs)
            outputs = optimizer.step(50, use_depth=i>7, metric_depth=True)
    while True:
        # If input camera is the zed, just loop it indefinitely until no more frames
        left_rgb, _, depth = zed.get_frame()
        if left_rgb is None:
            break
        target_frame_rgb = left_rgb/255
        optimizer.set_frame(target_frame_rgb,depth=depth)
        outputs = optimizer.step(50, metric_depth=True)
        v._trigger_rerender()
        optimizer.register_keyframe()
        _record_vis_frame(target_frame_rgb, outputs)
elif camera_input in ['iphone','iphone_vertical','train_cam']:
    # Otherwise process the video
    if len(rgb_renders)==0:
        # Initial alignment on the first frame; depth is only switched on for the last two rounds
        for i in tqdm.tqdm(range(10)):
            _record_vis_frame(target_frame_rgb, outputs)
            outputs = optimizer.step(30, use_depth=i>7, metric_depth=False)

    # Track frame by frame between the start and end timestamps, storing one keyframe per frame
    for t in tqdm.tqdm(np.linspace(start,end,int((end-start)*fps))):
        frame = get_vid_frame(motion_clip,t)
        target_frame_rgb = ToTensor()(Image.fromarray(frame)).permute(1,2,0).cuda()
        optimizer.set_frame(target_frame_rgb)
        outputs = optimizer.step(50, metric_depth=False)
        optimizer.register_keyframe()
        v._trigger_rerender()
        _record_vis_frame(target_frame_rgb, outputs)
# Write the recorded visualization frames out as an mp4, plus an mpeg4-encoded '_mac' copy
fname = str(OUTPUT_FOLDER / "optimizer_out_dance.mp4")
mac_fname = fname.replace('.mp4','_mac.mp4')
out_clip = mpy.ImageSequenceClip(rgb_renders, fps=fps)
out_clip.write_videofile(fname, fps=fps,codec='libx264')
out_clip.write_videofile(mac_fname,fps=fps,codec='mpeg4',bitrate='5000k')

# Viewer controls for playing back / scrubbing through the stored keyframes
last_keyframe = len(optimizer.keyframes)-1
animate_button = v.viser_server.add_gui_button("Play Animation")
frame_slider = v.viser_server.add_gui_slider("Frame",0,last_keyframe,1,0)
reset_button = v.viser_server.add_gui_button("Reset Transforms")

@animate_button.on_click
def play_animation(_):
    """Button callback: applies every stored keyframe in order, sleeping 1/fps seconds after each."""
    num_keyframes = len(optimizer.keyframes)
    for keyframe_idx in range(num_keyframes):
        optimizer.apply_keyframe(keyframe_idx)
        v._trigger_rerender()
        time.sleep(1/fps)
@frame_slider.on_update
def apply_keyframe(_):
    """Slider callback: poses the model with the keyframe currently selected on the slider."""
    selected_keyframe = frame_slider.value
    optimizer.apply_keyframe(selected_keyframe)
    v._trigger_rerender()
@reset_button.on_click
def reset_transforms(_):
    """Button callback: restores the model to its initial (un-posed) state and redraws the viewer."""
    optimizer.reset_transforms()
    v._trigger_rerender()