In [7]:
import os
gpu_id = 0
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

import cv2
import sys
import json
import time
import mmcv
import glob
import torch
import shutil
import random
import pickle
import hashlib
import numpy as np
import torch.nn as nn
from torch import optim
import mediapy as media
import matplotlib.pyplot as plt
from argparse import ArgumentParser

import torch.nn.functional as torch_F
from torchvision.ops import roi_align
device = torch.device('cuda:0')

from pytorch3d import io as py3d_io
from pytorch3d import ops as py3d_ops
from pytorch3d import loss as py3d_loss
from pytorch3d import utils as py3d_util
from pytorch3d import structures as py3d_struct
from pytorch3d import renderer as py3d_renderer
from pytorch3d import transforms as py3d_transform
from pytorch3d.vis import plotly_vis as py3d_vis
from pytorch3d.transforms import (matrix_to_euler_angles,
                                  euler_angles_to_matrix, 
                                  matrix_to_rotation_6d, 
                                  rotation_6d_to_matrix)

PROJ_ROOT = os.path.dirname(os.getcwd())
sys.path.append(PROJ_ROOT)

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
L1Loss = torch.nn.L1Loss(reduction='mean')
MSELoss = torch.nn.MSELoss(reduction='mean')
SSIM_METRIC = SSIM(data_range=1, size_average=True, channel=3) # channel=1 for grayscale images
MS_SSIM_METRIC = MS_SSIM(data_range=1, size_average=True, channel=3)


from misc_utils import coseg_utils
from misc_utils.metric_utils import *
from config import inference_cfg as CFG
from misc_utils.warmup_lr import CosineAnnealingWarmupRestarts
from dataset.inference_datasets import datasetCallbacks

from model.network import model_arch as ModelNet
from inference import (gaussian_PipeP, gaussian_BG,
                       perform_segmentation_and_encoding, 
                       multiple_initial_pose_inference)


ckpt_file = os.path.join(PROJ_ROOT, 'checkpoints/model_weights.pth')

model_net = ModelNet().to(device)
model_net.load_state_dict(torch.load(ckpt_file, map_location=device))
model_net.eval()
print('Model weights are loaded!')

sys.path.append(os.path.join(PROJ_ROOT, 'gaussian_splatting'))

from argparse import ArgumentParser, Namespace
from gaussian_splatting.scene.cameras import Camera
from gaussian_splatting.gaussian_renderer import render
from gaussian_splatting.utils.graphics_utils import focal2fov
from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.arguments import ModelParams, PipelineParams, OptimizationParams
from gaussian_splatting.build_3DGaussianObject import training as build_3DGaussianObject



Model weights are loaded!


In [8]:
def GS_Tracker(model_func, ref_database, frame, camK, prev_pose):
    zoom_outp = coseg_utils.zoom_in_and_crop_with_offset(image=frame, K=camK, t=prev_pose[:3, 3],
                                                         radius=ref_database['obj_diameter']/2,
                                                         target_size=CFG.zoom_image_scale, 
                                                         margin=CFG.zoom_image_margin)                
    zoom_camK = zoom_outp['zoom_camK']      
    zoom_image = zoom_outp['zoom_image']
    bbox_scale = zoom_outp['bbox_scale']   
    bbox_center = zoom_outp['bbox_center']
    zoom_offset = zoom_outp['zoom_offset']
    zoom_offsetX = zoom_offset[0]
    zoom_offsetY = zoom_offset[1]
    
    zoom_FovX = focal2fov(zoom_camK[0, 0], CFG.zoom_image_scale)
    zoom_FovY = focal2fov(zoom_camK[1, 1], CFG.zoom_image_scale)
    target_image = zoom_image.permute(2, 0, 1).to(device) # 3xSxS

    fg_trunc_mask = (target_image.sum(dim=0, keepdim=True) > 0).type(torch.float32) # 1xSxS
    
    with torch.no_grad():
        target_mask = model_func.query_cosegmentation(
            model_func.extract_DINOv2_feature(target_image[None]), 
            x_ref=ref_database['obj_fps_feats'], ref_mask=ref_database['obj_fps_masks'],
        ).sigmoid().squeeze(0)

    zoom_image_np = (target_image.detach().cpu().permute(1, 2, 0) * 255).numpy().astype(np.uint8)
    
    obj_gaussians = reference_database['3DGO']
    track_camera = Camera(T=prev_pose[:3, 3],
                          R=prev_pose[:3, :3].T, 
                          FoVx=zoom_FovX, FoVy=zoom_FovY,
                          cx_offset=zoom_offsetX, cy_offset=zoom_offsetY,
                          image=target_image, colmap_id=0, uid=0, image_name='', 
                          gt_alpha_mask=None, data_device=device)

    obj_gaussians.initialize_pose()
    
    optimizer = optim.AdamW([obj_gaussians._delta_R, obj_gaussians._delta_T])

    lr_scheduler = CosineAnnealingWarmupRestarts(optimizer, 
                                                 CFG.MAX_STEPS, 
                                                 warmup_steps=CFG.WARMUP, 
                                                 max_lr=CFG.START_LR, min_lr=CFG.END_LR)
    losses = list()
    target_image *= target_mask
    for iter_step in range(CFG.MAX_STEPS):
        optimizer.zero_grad()
        render_img = render(track_camera, obj_gaussians, gaussian_PipeP, gaussian_BG)['render'] * fg_trunc_mask
        loss = 0
        
        rgb_loss = L1Loss(render_img, target_image)
        loss += rgb_loss
    
        ssim_loss = 1 - SSIM_METRIC(render_img[None], target_image[None])
        loss += ssim_loss
        
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        
        losses.append(loss.item())
        if iter_step >= CFG.EARLY_STOP_MIN_STEPS:
            loss_grads = (torch.as_tensor(losses)[1:] - torch.as_tensor(losses)[:-1]).abs()
            loss_grad = loss_grads[-CFG.EARLY_STOP_MIN_STEPS:].mean() 
            if loss_grad < CFG.EARLY_STOP_LOSS_GRAD_NORM:
                break
    
    gs3d_delta_RT = obj_gaussians.get_delta_pose.squeeze(0).detach().cpu().numpy()
    curr_pose = prev_pose @ gs3d_delta_RT
        
    return{
        'track_pose': curr_pose,
        'render_img': render_img,
        'bbox_scale': bbox_scale,
        'bbox_center': bbox_center,
        'iter_step': iter_step,
    }


In [5]:
mode = 'test'
dataset_name = 'VideoTrack'
obj_name = 'waterbottle'

data_root = datasetCallbacks[dataset_name]['DATAROOT']
datasetObjects = datasetCallbacks[dataset_name]['OBJECTS']
datasetLoader = datasetCallbacks[dataset_name]['DATASETLOADER']
obj_dir_name = datasetObjects[obj_name]

obj_test_dataset = datasetLoader(data_root, obj_name, subset_mode=mode)
obj_train_dataset = datasetLoader(data_root, obj_name, subset_mode='train')

obj_pointcloud = obj_test_dataset.obj_pointcloud
bbox3d_diameter = obj_test_dataset.bbox3d_diameter

obj_camK = torch.as_tensor(obj_test_dataset.camK, dtype=torch.float32) 

database_dir = os.path.join(PROJ_ROOT, 'obj_database', dataset_name, 'objects')

obj_refer_database_dir = os.path.join(database_dir, f'{obj_name}')

ref_database_path = os.path.join(obj_refer_database_dir, 'object_reference_database.pkl')

with open(ref_database_path, 'rb') as df:
    reference_database = pickle.load(df)

for _key, _val in reference_database.items():
    reference_database[_key] = torch.as_tensor(_val, dtype=torch.float32).to(device)
# print('Object dtabase is loaded!')

obj_3D_bbox = torch.as_tensor(obj_train_dataset.obj_bbox3d, dtype=torch.float32).squeeze()
cannon_3D_bbox = obj_3D_bbox.clone()
cannon_3D_bbox[0] = obj_3D_bbox[0]
cannon_3D_bbox[1] = obj_3D_bbox[4]
cannon_3D_bbox[2] = obj_3D_bbox[7]
cannon_3D_bbox[3] = obj_3D_bbox[3]

cannon_3D_bbox[4] = obj_3D_bbox[1]
cannon_3D_bbox[5] = obj_3D_bbox[5]
cannon_3D_bbox[6] = obj_3D_bbox[6]
cannon_3D_bbox[7] = obj_3D_bbox[2]


gaussians = GaussianModel(sh_degree=3)
ply_dirs = glob.glob(os.path.join(f'{obj_refer_database_dir}/point_cloud/iteration_*'))
gs_ply_dir = sorted(ply_dirs, key=lambda x: int(x.split('_')[-1]), reverse=True)[0]
gs_ply_path = f'{gs_ply_dir}/point_cloud.ply'
gaussians.load_ply(gs_ply_path)

# print('3D Gaussian object model is loaded!')

reference_database['3DGO'] = gaussians
reference_database['obj_diameter'] = bbox3d_diameter


In [9]:
CFG.USE_MS_SSIM = False
CFG.EARLY_STOP_LOSS_GRAD_NORM = 1e-4

start_idx = 0
que_data = obj_test_dataset[start_idx]
gt_pose = que_data['pose']
img_camK = que_data['camK']
que_image = que_data['image']      # HxWx3
que_image_ID = que_data['image_ID']
que_hei, que_wid = que_image.shape[:2]

target_size = CFG.zoom_image_scale
raw_hei, raw_wid = que_image.shape[:2]
raw_long_size = max(raw_hei, raw_wid)
raw_short_size = min(raw_hei, raw_wid)
raw_aspect_ratio = raw_short_size / raw_long_size
if raw_hei < raw_wid:
    new_wid = CFG.query_longside_scale
    new_hei = int(new_wid * raw_aspect_ratio)
else:
    new_hei = CFG.query_longside_scale
    new_wid = int(new_hei * raw_aspect_ratio)
query_rescaling_factor = CFG.query_longside_scale / raw_long_size
que_image = que_image[None, ...].permute(0, 3, 1, 2).to(device)
que_image = torch_F.interpolate(que_image, size=(new_hei, new_wid), mode='bilinear', align_corners=True)

obj_data = perform_segmentation_and_encoding(model_net, 
                                             que_image=que_image, 
                                             ref_database=reference_database, device=device)


obj_data['bbox_scale'] /= query_rescaling_factor  # back to the original image scale
obj_data['bbox_center'] /= query_rescaling_factor # back to the original image scale

obj_data['obj_camK'] = obj_camK
obj_data['img_camK'] = img_camK
obj_data['img_scale'] = max(que_hei, que_wid)

init_pose = multiple_initial_pose_inference(obj_data=obj_data, ref_database=reference_database, device=device)[0]

frame_idxes = list()
video_frames = list()
track_poses = list()
track_pose = init_pose.copy()
num_frames = len(obj_test_dataset)

frame_interval = 1
track_accum_runtime = 0
for view_idx in range(start_idx, num_frames, frame_interval):    
    que_data = obj_test_dataset[view_idx]
    
    image = que_data['image']
    camK = que_data['camK'].clone()
    GT_pose = que_data['pose'].numpy().copy()
    image_hei, image_wid = image.shape[:2]
    
    track_timer = time.time()
    track_outp = GS_Tracker(model_net, 
                            frame=image, prev_pose=track_pose, 
                            camK=camK, ref_database=reference_database)
    frame_cost = time.time() - track_timer
    track_accum_runtime += frame_cost
    
    iter_step = track_outp['iter_step']
    render_img = track_outp['render_img']
    track_pose = track_outp['track_pose']
    bbox_scale = track_outp['bbox_scale']
    bbox_center = track_outp['bbox_center']

    track_poses.append(track_pose)
        
    query_img_np = (image * 255).numpy().astype(np.uint8)
    
    track_bbox3d_img = query_img_np.copy()
    
    
    render_full_img = coseg_utils.zoom_out_and_uncrop_image(render_img, # 3xSxS
                                                            bbox_scale=bbox_scale,
                                                            bbox_center=bbox_center,
                                                            orig_hei=image.shape[0],
                                                            orig_wid=image.shape[1],
                                                            ).detach().cpu().squeeze() # HxWx3
    render_full_img_np = (torch.clamp(render_full_img, 0, 1.0) * 255).numpy().astype(np.uint8)
    track_render_img = render_full_img_np.copy()
    track_render_img = cv2.addWeighted(track_render_img, 0.7, query_img_np, 0.3, 1)
        
    track_RT = torch.as_tensor(track_pose, dtype=torch.float32)
    track_bbox_KRT = torch.einsum('ij,kj->ki', track_RT[:3, :3], cannon_3D_bbox) + track_RT[:3, 3][None, :]
    track_bbox_KRT = torch.einsum('ij,kj->ki', camK, track_bbox_KRT)
    track_bbox_pts = (track_bbox_KRT[:, :2] / track_bbox_KRT[:, 2:3]).type(torch.int64)
    track_bbox_corner = track_bbox_pts.numpy()
    
    track_bbox3d_img = coseg_utils.draw_3d_bounding_box(track_bbox3d_img, track_bbox_corner, 
                                                        color=(0, 255, 0), linewidth=20)

    small_hei = image_hei // 3
    small_wid = image_wid // 3
    query_img_np = cv2.resize(query_img_np, (small_wid, small_hei))
    track_render_img = cv2.resize(track_render_img, (small_wid, small_hei))
    track_bbox3d_img = cv2.resize(track_bbox3d_img, (small_wid, small_hei))
        
    scale = 1.5
    thickness = 3
    color = (0, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(query_img_np,     '  Input Video  ', (60, 50), font, scale, color, thickness=thickness)
    cv2.putText(track_bbox3d_img, 'Tracking Result', (60, 50), font, scale, color, thickness=thickness)
    cv2.putText(track_render_img, 'Rendered Object', (60, 50), font, scale, color, thickness=thickness)

    concat_images = np.concatenate([query_img_np, track_bbox3d_img, track_render_img], axis=1)
    video_frames.append(concat_images)
    
    if view_idx % 24 == 0:
        print('[{}/{}], \t{:.1f} FPS'.format(view_idx, num_frames, (view_idx - start_idx) / track_accum_runtime))


[0/233], 	0.0 FPS
[24/233], 	2.0 FPS
[48/233], 	2.1 FPS
[72/233], 	2.1 FPS
[96/233], 	2.1 FPS
[120/233], 	2.1 FPS
[144/233], 	2.1 FPS
[168/233], 	2.1 FPS
[192/233], 	2.1 FPS
[216/233], 	2.1 FPS


In [10]:
media.show_video(np.stack(video_frames, axis=0), fps=24, width=320*3)


0
This browser does not support the video tag.


# Visualize the coordinates of 3D Gaussians model

In [11]:
so3_sphere = py3d_struct.Pointclouds(
    points=[gaussians.get_xyz.squeeze().detach().cpu()],
    features=[gaussians._features_dc.squeeze().detach().cpu().sigmoid()]
)

fig = py3d_vis.plot_scene(
    {" ": 
        {
            'sphere': so3_sphere,
        }
    },
#     xaxis={"backgroundcolor":"rgb(200, 200, 230)"},
#     yaxis={"backgroundcolor":"rgb(230, 200, 200)"},
#     zaxis={"backgroundcolor":"rgb(200, 230, 200)"}, 
    
    xaxis={"backgroundcolor":"rgb(255, 255, 255)"},
    yaxis={"backgroundcolor":"rgb(255, 255, 255)"},
    zaxis={"backgroundcolor":"rgb(255, 255, 255)"}, 
    
    pointcloud_marker_size=2,
    pointcloud_max_points=30_000,
    axis_args=py3d_vis.AxisArgs(showgrid=True)
)

fig.update_layout(width=800, height=600)
fig.show()

