In [2]:
import sys,os,imageio
root = '/host/home/ubuntu/mvsnerf'
os.chdir(root)
sys.path.append(root)

from opt import config_parser
from data import dataset_dict
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path

# # models
# from models import *
# from renderer import *
# from data.ray_utils import get_rays
# from scipy.spatial.transform import Rotation as R
from render_utils import *

from tqdm import tqdm
from skimage.metrics import structural_similarity

# pytorch-lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer, loggers


from data.ray_utils import ray_marcher

%load_ext autoreload
%autoreload 2

# Restrict this process to GPU 0. CUDA_VISIBLE_DEVICES is only read when the CUDA
# runtime is initialised, so it has to be exported before the first torch.cuda call;
# assigning it afterwards (as this cell used to) has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.cuda.set_device(0)

  def _figure_formats_changed(self, name, old, new):


# Rendering video from finetuned ckpts

In [3]:
# Root directory under which the dataset folders used by this notebook are located.
data_dir = Path('/host/data/')

In [4]:
# Per-dataset root folders, both resolved relative to data_dir.
# nerf_root_dir is the parent of the `nerf_synthetic` scenes; zeiss_root_dir holds the ZEISS scenes.
nerf_root_dir = data_dir/'NeRF_Data'
zeiss_root_dir = data_dir/'ZEISS'

# Refactoring code

In [4]:
## Options
scene = 'HeadScan'
i_scene = 0
is_finetuned = False # set False if rendering without finetuning

# Build the argument vector directly instead of whitespace-splitting one long command
# string, so that a data directory containing spaces stays a single argument.
cmd = [
    '--datadir', f'{zeiss_root_dir}/{scene}',
    '--dataset_name', 'llff',
    '--imgScale_test', '1.0',
    '--netwidth', '128',
    '--net_type', 'v0',
    '--use_viewdirs',
    '--N_samples', '128',
    '--chunk', '5120',
]
if is_finetuned:
    cmd += ['--ckpt', f'./runs_fine_tuning/{scene}-ft/ckpts/latest.tar', '--use_disp']
    # cmd += ['--use_color_volume'] # add only if model was finetuned with this option
else:
    cmd += ['--ckpt', './ckpts/mvsnerf-v0.tar']

args = config_parser(cmd)
# options not included in original option set
args.feat_dim = 8+3*4

# Output folder for the rendered videos.
save_dir = 'results/video2'
os.makedirs(save_dir, exist_ok=True)

In [5]:
## Create models
# Networks are (re)built for the first scene, or whenever a finetuned checkpoint is used.
if is_finetuned or i_scene == 0:
    render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(
        args, use_mvs=True, dir_embedder=False, pts_embedder=True)
    filter_keys(render_kwargs_train)

    # Take the MVS network out of the kwargs; it is called on its own and the
    # remaining kwargs are later forwarded to rendering().
    MVSNet = render_kwargs_train.pop('network_mvs')

pad = 24 #the padding value should be same as your finetuning ckpt
datatype = 'train'
datadir = args.datadir

# Dataset class is looked up by name from the parsed options.
dataset = dataset_dict[args.dataset_name](args, split=datatype)
val_idx = dataset.img_idx

MVSNet.train()
MVSNet = MVSNet.cuda()

Found ckpts ['./ckpts/mvsnerf-v0.tar']
Reloading from ./ckpts/mvsnerf-v0.tar
9 9 /host/data/ZEISS/HeadScan
===> training index: [0, 1, 2, 3, 4, 5, 6, 7, 8]


In [6]:
# Render one frame per camera pose (RGB next to a depth visualisation) and write the
# frames out as an mp4. Relies on `args`, `dataset`, `MVSNet`, `pad`, `scene`,
# `is_finetuned`, `save_dir` and `render_kwargs_train` from the cells above.
with torch.no_grad():
    c2ws_all = dataset.poses

    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)

    if is_finetuned:   # large baseline
        # Reuse the feature volume stored in the finetuned checkpoint.
        volume_feature = torch.load(args.ckpt)['volume']['feat_volume']
        volume_feature = RefVolume(volume_feature.detach()).cuda()
    else:
        # neighboring views with position distance
        volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad, lindisp=args.use_disp)

    # Keep the scaled padding in its own variable instead of `pad *= ...`, so that
    # re-running this cell does not compound the scale on the notebook-level `pad`.
    pad_render = pad * args.imgScale_test
    w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
    pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
    # pair_idx = [i for i in range(9)]
    c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.6, N_views=60)# you can enlarge the rads_scale if you want to render larger baseline
    # NOTE: the spiral path computed above is currently overridden; the dataset's own poses are rendered.
    c2ws_render = c2ws_all # experimental
    c2ws_render = torch.from_numpy(np.stack(c2ws_render)).float().to(device)

    imgs_source = unpreprocess(imgs_source)
    # Frame size and reference extrinsics do not change between rays; bind them once here
    # so they exist even before the first chunk is processed.
    H, W = imgs_source.shape[-2:]
    w2c_ref = pose_source['w2cs'][0]

    # Best-effort reset of leftover tqdm bars from earlier runs; failure here is harmless.
    try:
        tqdm._instances.clear()
    except Exception:
        pass

    frames = []
    img_directions = dataset.directions.to(device)
    for c2w in tqdm(c2ws_render):
        torch.cuda.empty_cache()

        rays_o, rays_d = get_rays(img_directions, c2w)  # both (h*w, 3)
        rays = torch.cat([rays_o, rays_d,
                    near_far_source[0] * torch.ones_like(rays_o[:, :1]),
                    near_far_source[1] * torch.ones_like(rays_o[:, :1])],
                1).to(device)  # per-ray columns: origin (3), direction (3), near (1), far (1)

        N_rays_all = rays.shape[0]
        rgb_rays, depth_rays_preds = [],[]
        # Process the rays in slices of at most args.chunk rays. The loop variable is not
        # called `start`, which would clobber the value returned by create_nerf_mvs.
        for ray_start in range(0, N_rays_all, args.chunk):

            xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[ray_start:ray_start + args.chunk],
                                                N_samples=args.N_samples, lindisp=args.use_disp)

            # Converting world coordinate to ndc coordinate
            # (inv_scale and the cloned intrinsics are rebuilt per chunk, as before, in case
            # get_ndc_coordinate modifies its arguments -- TODO confirm and hoist if it does not)
            inv_scale = torch.tensor([W - 1, H - 1]).to(device)
            intrinsic_ref = pose_source['intrinsics'][0].clone()
            xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                            near=near_far_source[0], far=near_far_source[1], pad=pad_render, lindisp=args.use_disp)

            # rendering
            rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                    xyz_NDC, z_vals, rays_o, rays_d,
                                                                    volume_feature,imgs_source, **render_kwargs_train)

            rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
            rgb_rays.append(rgb)
            depth_rays_preds.append(depth_pred)

        depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
        depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)

        rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
        # Trim 1/20 of the frame on every side. The end index is written as `size - crop`
        # rather than `-crop`, so a crop of 0 (tiny frames) keeps the frame instead of emptying it.
        H_crop, W_crop = H // 20, W // 20
        rgb_rays = rgb_rays[H_crop:H - H_crop, W_crop:W - W_crop]
        depth_rays_preds = depth_rays_preds[H_crop:H - H_crop, W_crop:W - W_crop]
        # RGB (scaled to 0..255) on the left, depth visualisation on the right.
        img_vis = np.concatenate((rgb_rays*255,depth_rays_preds),axis=1)

        frames.append(img_vis.astype('uint8'))
imageio.mimwrite(f'{save_dir}/ft_{scene}_spiral{"" if is_finetuned else "_zeroshot"}.mp4', np.stack(frames), fps=10, quality=10)

====> ref idx: [0, 1, 2]


100%|██████████| 9/9 [05:43<00:00, 38.20s/it]


# Original Code

In [4]:
for i_scene, scene in enumerate(['HeadScan']):#'horns','flower','orchids', 'room','leaves','fern','trex','fortress'
    # add --use_color_volume if the ckpts are finetuned with this flag
    cmd = f'--datadir {zeiss_root_dir}/{scene}  \
     --dataset_name llff --imgScale_test {1.0}  --netwidth 128 --net_type v0 '

    is_finetuned = False # set False if rendering without finetuning
    # Checkpoint to load: the per-scene finetuned one, or the pretrained weights.
    cmd += (f'--ckpt ./runs_fine_tuning/{scene}-ft/ckpts/latest.tar --use_disp '
            if is_finetuned else '--ckpt ./ckpts/mvsnerf-v0.tar')

    args = config_parser(cmd.split())
    # Options set programmatically on top of the parsed command line.
    args.use_viewdirs = True
    args.N_samples = 128
    args.feat_dim =  8+3*4
    # args.use_color_volume = False if not is_finetuned else args.use_color_volume

    # create models
    if is_finetuned or i_scene == 0:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(
            args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        # The MVS network is called on its own, so it is removed from the rendering kwargs.
        MVSNet = render_kwargs_train.pop('network_mvs')

    datadir = args.datadir
    datatype = 'train'
    pad = 24 #the padding value should be same as your finetuning ckpt
    args.chunk = 5120

    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx

    save_dir = 'results/video2'
    os.makedirs(save_dir, exist_ok=True)
    MVSNet.train()
    MVSNet = MVSNet.cuda()

    with torch.no_grad():

        c2ws_all = dataset.poses

        # Both modes read the same source views first; only the origin of the feature volume differs.
        imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
        if is_finetuned:
            # large baseline
            volume_feature = torch.load(args.ckpt)['volume']['feat_volume']
            volume_feature = RefVolume(volume_feature.detach()).cuda()
        else:
            # neighboring views with position distance
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad, lindisp=args.use_disp)

        # Steps shared by both modes (previously duplicated in each branch).
        pad *= args.imgScale_test
        w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
        pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
        c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.6, N_views=60)# you can enlarge the rads_scale if you want to render larger baseline
        # The spiral path is overridden here: the dataset's own poses are rendered.
        c2ws_render = c2ws_all
        c2ws_render = torch.from_numpy(np.stack(c2ws_render)).float().to(device)

        imgs_source = unpreprocess(imgs_source)

        # Best-effort reset of leftover tqdm bars.
        try:
            tqdm._instances.clear()
        except Exception:
            pass

        frames = []
        img_directions = dataset.directions.to(device)
        for c2w in tqdm(c2ws_render):
            torch.cuda.empty_cache()

            rays_o, rays_d = get_rays(img_directions, c2w)  # both (h*w, 3)
            near_col = near_far_source[0] * torch.ones_like(rays_o[:, :1])
            far_col = near_far_source[1] * torch.ones_like(rays_o[:, :1])
            # per-ray columns: origin, direction, near, far
            rays = torch.cat([rays_o, rays_d, near_col, far_col], 1).to(device)

            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [], []
            # Walk over the rays in slices of at most args.chunk rays.
            for ray_start in range(0, N_rays_all, args.chunk):
                ray_chunk = rays[ray_start:ray_start + args.chunk]
                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                    ray_chunk, N_samples=args.N_samples, lindisp=args.use_disp)

                # Converting world coordinate to ndc coordinate
                H, W = imgs_source.shape[-2:]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref = pose_source['w2cs'][0]
                intrinsic_ref = pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1],
                                             pad=pad, lindisp=args.use_disp)

                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(
                    args, pose_source, xyz_coarse_sampled, xyz_NDC, z_vals, rays_o, rays_d,
                    volume_feature, imgs_source, **render_kwargs_train)

                rgb = torch.clamp(rgb.cpu(), 0, 1.0).numpy()
                depth_pred = depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)

            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)

            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            # Trim 1/20 of the frame on every side of both images.
            H_crop, W_crop = np.array(rgb_rays.shape[:2])//20
            rgb_rays = rgb_rays[H_crop:-H_crop, W_crop:-W_crop]
            depth_rays_preds = depth_rays_preds[H_crop:-H_crop, W_crop:-W_crop]
            # RGB (scaled to 0..255) on the left, depth visualisation on the right.
            img_vis = np.concatenate((rgb_rays*255, depth_rays_preds), axis=1)

            frames.append(img_vis.astype('uint8'))

    imageio.mimwrite(f'{save_dir}/ft_{scene}_spiral{"" if is_finetuned else "_zeroshot"}.mp4', np.stack(frames), fps=10, quality=10)


Found ckpts ['./ckpts/mvsnerf-v0.tar']
Reloading from ./ckpts/mvsnerf-v0.tar
9 9 /host/data/ZEISS/HeadScan
===> training index: [0, 1, 2, 3, 4, 5, 6, 7, 8]
====> ref idx: [0, 1, 2]


100%|██████████| 9/9 [05:45<00:00, 38.36s/it]


# Blender Code

In [6]:
# Render a video for each Blender (nerf_synthetic) scene: RGB next to a depth visualisation.
for i_scene, scene in enumerate(['hotdog','lego', 'mic', 'ship', 'drums', 'chair']):#'ship','drums','ficus','materials',

    cmd = f'--datadir {nerf_root_dir}/nerf_synthetic/{scene}\
     --dataset_name blender --white_bkgd --imgScale_test {1.0} '

    is_finetuned = True # set True if rendering with finetuning
    if is_finetuned:
        cmd += f'--ckpt ./runs_fine_tuning/{scene}-ft/ckpts//latest.tar'
        pad = 0 #the padding value should be same as your finetuning ckpt
    else:
        cmd += '--ckpt ./ckpts//mvsnerf-v0.tar'
        pad = 24 #the padding value should be same as your finetuning ckpt

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim = 8+3*4
#     args.use_color_volume = False if not is_finetuned else args.use_color_volume

    # create models
    if i_scene==0 or is_finetuned:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    args.chunk = 5120
    # Number of poses on the render path. Kept separate from `frames`, which holds the
    # rendered images (one name used to serve both purposes).
    n_frames = 60


    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx

    save_dir = 'results/video2'
    os.makedirs(save_dir, exist_ok=True)
    MVSNet.train()
    MVSNet = MVSNet.cuda()

    with torch.no_grad():

        if is_finetuned:
            # large baseline
            imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
            # Reuse the feature volume stored in the finetuned checkpoint.
            volume_feature = torch.load(args.ckpt)['volume']['feat_volume']
            volume_feature = RefVolume(volume_feature.detach()).cuda()
            c2ws_render = nerf_video_path(pose_source['c2ws'].cpu(), N_views=n_frames)
        else:
            # neighboring views with angle distance
            # NOTE(review): the view selection below draws from torch's global RNG without a
            # seed, so zero-shot renders differ between runs.
            c2ws_all = dataset.load_poses_all()
            random_selete = torch.randint(0,len(c2ws_all),(1,))     #!!!!!!!!!! you may change this line if rendering a specify view 
            dis = np.sum(c2ws_all[:,:3,2] * c2ws_all[[random_selete],:3,2], axis=-1)
            pair_idx = np.argsort(dis)[::-1][torch.randperm(5)[:3]]
            imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)#pair_idx=pair_idx, 
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)

            #####
            c2ws_render = gen_render_path(c2ws_all[pair_idx], N_views=n_frames)
            c2ws_render = torch.from_numpy(np.stack(c2ws_render)).float().to(device)


        imgs_source = unpreprocess(imgs_source)
        # Frame size and reference extrinsics are the same for every ray; bind them once.
        H, W = imgs_source.shape[-2:]
        w2c_ref = pose_source['w2cs'][0]

        # Best-effort reset of leftover tqdm bars; failure here is harmless.
        try:
            tqdm._instances.clear()
        except Exception:
            pass

        frames = []
        img_directions = dataset.directions.to(device)
        for c2w in tqdm(c2ws_render):
            torch.cuda.empty_cache()

            rays_o, rays_d = get_rays(img_directions, c2w)  # both (h*w, 3)
            rays = torch.cat([rays_o, rays_d,
                     near_far_source[0] * torch.ones_like(rays_o[:, :1]),
                     near_far_source[1] * torch.ones_like(rays_o[:, :1])],
                    1).to(device)  # per-ray columns: origin (3), direction (3), near (1), far (1)


            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            # Process the rays in slices of at most args.chunk rays (loop variable is not
            # named `start`, which would clobber the value returned by create_nerf_mvs).
            for ray_start in range(0, N_rays_all, args.chunk):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[ray_start:ray_start + args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                intrinsic_ref = pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)


                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)


            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)

            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            # Border cropping is disabled for this dataset; uncomment to re-enable.
#             H_crop, W_crop = np.array(rgb_rays.shape[:2])//20
#             rgb_rays = rgb_rays[H_crop:-H_crop,W_crop:-W_crop]
#             depth_rays_preds = depth_rays_preds[H_crop:-H_crop,W_crop:-W_crop]
            # RGB (scaled to 0..255) on the left, depth visualisation on the right.
            img_vis = np.concatenate((rgb_rays*255,depth_rays_preds),axis=1)

            frames.append(img_vis.astype('uint8'))
    # Same naming scheme as the LLFF cells: zero-shot renders get a `_zeroshot` suffix so
    # they cannot overwrite (or be mistaken for) the finetuned video of the same scene.
    imageio.mimwrite(f'{save_dir}/ft_{scene}_spiral{"" if is_finetuned else "_zeroshot"}.mp4', np.stack(frames), fps=10, quality=10)
# plt.imshow(rgb_rays)

Found ckpts ['./runs_fine_tuning/hotdog-ft/ckpts//latest.tar']
Reloading from ./runs_fine_tuning/hotdog-ft/ckpts//latest.tar
===> valing index: [26 60 13 47]
====> ref idx: [48 61  0]


100%|██████████| 60/60 [1:02:32<00:00, 62.54s/it]


Found ckpts ['./runs_fine_tuning/lego-ft/ckpts//latest.tar']
Reloading from ./runs_fine_tuning/lego-ft/ckpts//latest.tar
===> valing index: [63, 70, 18, 28]
====> ref idx: [6, 43, 33]


100%|██████████| 60/60 [1:02:27<00:00, 62.45s/it]


Found ckpts ['./runs_fine_tuning/mic-ft/ckpts//latest.tar']
Reloading from ./runs_fine_tuning/mic-ft/ckpts//latest.tar
===> valing index: [20 49 55 72]
====> ref idx: [61 80 64]


 80%|████████  | 48/60 [49:43<11:13, 56.09s/it]

# DTU

In [None]:
# Render a video for each DTU scan: RGB next to a depth visualisation.
for i_scene, scene in enumerate([1]):# any scene index, like 1,2,3...,,8,21,103,114

    cmd = f'--datadir /mnt/data/new_disk/sungx/data/mvs_dataset/DTU/mvs_training/dtu/scan{scene} \
     --dataset_name dtu_ft --imgScale_test {1.0} ' #--use_color_volume

    is_finetuned = True # set False if rendering without finetuning
    if is_finetuned:
        cmd += f'--ckpt ./runs_fine_tuning/dtu_scan{scene}_2h/ckpts//latest.tar'
    else:
        cmd += '--ckpt ./ckpts/mvsnerf-v0.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True

    args.N_samples = 128
    args.feat_dim =  8+3*4
    # The colour volume is switched off for zero-shot rendering; a finetuned run keeps
    # whatever the command line selected.
    if not is_finetuned:
        args.use_color_volume = False


    # create models
    if i_scene==0 or is_finetuned:
        render_kwargs_train, render_kwargs_test, start, grad_vars = create_nerf_mvs(args, use_mvs=True, dir_embedder=False, pts_embedder=True)
        filter_keys(render_kwargs_train)

        MVSNet = render_kwargs_train['network_mvs']
        render_kwargs_train.pop('network_mvs')


    datadir = args.datadir
    datatype = 'val'
    pad = 24 #the padding value should be same as your finetuning ckpt
    args.chunk = 5120
    # Number of poses on the render path. Kept separate from `frames`, which holds the
    # rendered images (one name used to serve both purposes).
    n_frames = 60


    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx

    save_dir = 'results/video2'
    os.makedirs(save_dir, exist_ok=True)
    MVSNet.train()
    MVSNet = MVSNet.cuda()

    with torch.no_grad():

        if is_finetuned:
            # large baseline
            imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
            # Reuse the feature volume stored in the finetuned checkpoint.
            volume_feature = torch.load(args.ckpt)['volume']['feat_volume']
            volume_feature = RefVolume(volume_feature.detach()).cuda()
        else:
            # neighboring views with angle distance
            # NOTE(review): the reference view is drawn from torch's global RNG without a
            # seed, so zero-shot renders differ between runs.
            c2ws_all = dataset.load_poses_all()
            random_selete = torch.randint(0,len(c2ws_all),(1,)) #!!!!!!!!!! you may change this line if rendering a specify view 
            dis = np.sum(c2ws_all[:,:3,2] * c2ws_all[[random_selete],:3,2], axis=-1)
            pair_idx = np.argsort(dis)[::-1][:3]#[25, 21, 33]#[14,15,24]#
            imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(pair_idx=pair_idx, device=device)
            volume_feature, _, _ = MVSNet(imgs_source, proj_mats, near_far_source, pad=pad)

        imgs_source = unpreprocess(imgs_source)
        # Frame size and reference extrinsics are the same for every ray; bind them once.
        H, W = imgs_source.shape[-2:]
        w2c_ref = pose_source['w2cs'][0]

        c2ws_render = gen_render_path(pose_source['c2ws'].cpu().numpy(), N_views=n_frames)
        c2ws_render = torch.from_numpy(np.stack(c2ws_render)).float().to(device)


        # Best-effort reset of leftover tqdm bars; failure here is harmless.
        try:
            tqdm._instances.clear()
        except Exception:
            pass

        frames = []
        img_directions = dataset.directions.to(device)
        for c2w in tqdm(c2ws_render):
            torch.cuda.empty_cache()

            rays_o, rays_d = get_rays(img_directions, c2w)  # both (h*w, 3)
            rays = torch.cat([rays_o, rays_d,
                     near_far_source[0] * torch.ones_like(rays_o[:, :1]),
                     near_far_source[1] * torch.ones_like(rays_o[:, :1])],
                    1).to(device)  # per-ray columns: origin (3), direction (3), near (1), far (1)


            N_rays_all = rays.shape[0]
            rgb_rays, depth_rays_preds = [],[]
            # Process the rays in slices of at most args.chunk rays (loop variable is not
            # named `start`, which would clobber the value returned by create_nerf_mvs).
            for ray_start in range(0, N_rays_all, args.chunk):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(rays[ray_start:ray_start + args.chunk],
                                                    N_samples=args.N_samples)

                # Converting world coordinate to ndc coordinate
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                intrinsic_ref = pose_source['intrinsics'][0].clone()
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0], far=near_far_source[1], pad=pad*args.imgScale_test)


                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(args, pose_source, xyz_coarse_sampled,
                                                                       xyz_NDC, z_vals, rays_o, rays_d,
                                                                       volume_feature,imgs_source, **render_kwargs_train)


                rgb, depth_pred = torch.clamp(rgb.cpu(),0,1.0).numpy(), depth_pred.cpu().numpy()
                rgb_rays.append(rgb)
                depth_rays_preds.append(depth_pred)


            depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
            depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)

            rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
            # Border cropping is disabled for this dataset; uncomment to re-enable.
#             H_crop, W_crop = np.array(rgb_rays.shape[:2])//20
#             rgb_rays = rgb_rays[H_crop:-H_crop,W_crop:-W_crop]
#             depth_rays_preds = depth_rays_preds[H_crop:-H_crop,W_crop:-W_crop]
            # RGB (scaled to 0..255) on the left, depth visualisation on the right.
            img_vis = np.concatenate((rgb_rays*255,depth_rays_preds),axis=1)
            frames.append(img_vis.astype('uint8'))

    # Same naming scheme as the LLFF cells: zero-shot renders get a `_zeroshot` suffix so
    # they cannot overwrite (or be mistaken for) the finetuned video of the same scan.
    imageio.mimwrite(f'{save_dir}/ft_scan{scene}_spiral2{"" if is_finetuned else "_zeroshot"}.mp4', np.stack(frames), fps=20, quality=10)
# plt.imshow(rgb_rays)

# Render path generation

In [None]:
# Collect the camera paths and reference intrinsics for every dataset into one dict
# (`render_poses`) and save it under ./configs in both torch and NumPy format.
render_poses = {}
datatype = 'val'

# --- LLFF scenes, spiral with rads_scale = 0.5, stored under the `*_no_ft` keys ---
for i_scene, scene in enumerate(['flower','orchids', 'room','leaves','fern','horns','trex','fortress']):
    # add --use_color_volume if the ckpts are finetuned with this flag
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
     --dataset_name llff --imgScale_test {1.0} \
    --ckpt ./runs_new/runs_fine_tuning/{scene}/ckpts//latest.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    

    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)

    c2ws_all = dataset.poses
    w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
    # View indices for this scene, read from the pairs file under the '<scene>_train' key.
    pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
    c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.5, N_views=60) 
    
    render_poses[f'{scene}_near_far_source'] = near_far_source
    render_poses[f'{scene}_c2ws_no_ft'] = c2ws_render
    render_poses[f'{scene}_intrinsic_no_ft'] = pose_source['intrinsics'][0].cpu().numpy()

# --- LLFF scenes again, spiral with rads_scale = 0.6, stored under the plain keys ---
# This loop repeats the dataset loading of the loop above; only rads_scale and the keys differ.
for i_scene, scene in enumerate(['flower','orchids', 'room','leaves','fern','horns','trex','fortress']):#'flower','orchids', 'room','leaves','fern','horns','trex','fortress'
    # add --use_color_volume if the ckpts are finetuned with this flag
    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
     --dataset_name llff --imgScale_test {1.0} \
    --ckpt ./runs_new/runs_fine_tuning/{scene}/ckpts//latest.tar'

    args = config_parser(cmd.split())
    args.use_viewdirs = True


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx
    

    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)

    c2ws_all = dataset.poses
    w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
    pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
    c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.6, N_views=60) 
    
    render_poses[f'{scene}_c2ws'] = c2ws_render
    render_poses[f'{scene}_intrinsic'] = pose_source['intrinsics'][0].cpu().numpy()
#######################################
# --- Blender (nerf_synthetic) scenes: path from nerf_video_path over the source-view poses ---
for i_scene, scene in enumerate(['ship','mic','chair','lego','drums','ficus','materials','hotdog']):#

    cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/nerf_synthetic/{scene}  \
     --dataset_name blender --white_bkgd --imgScale_test {1.0}  \
    --ckpt /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_fine_tuning/{scene}/ckpts//latest.tar '

    args = config_parser(cmd.split())


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)


    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
    c2ws_render = nerf_video_path(pose_source['c2ws'].cpu(), N_views=60)
    
    # Stored as a NumPy array, like the other entries.
    render_poses[f'{scene}_c2ws'] = c2ws_render.cpu().numpy()
    render_poses[f'{scene}_intrinsic'] = pose_source['intrinsics'][0].cpu().numpy()
    
##################################################
# --- DTU: path from gen_render_path over the source-view poses, stored under fixed `dtu_*` keys ---
for i_scene, scene in enumerate([1]):

    cmd = f'--datadir /mnt/data/new_disk/sungx/data/mvs_dataset/DTU/mvs_training/dtu/scan{scene}  \
     --dataset_name dtu_ft --imgScale_test {1.0}   \
    --ckpt /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_fine_tuning/dtu_scan{scene}/ckpts//latest.tar --netwidth 256 --net_type v0 --use_color_volume'

    args = config_parser(cmd.split())
    args.use_viewdirs = True


    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)
    val_idx = dataset.img_idx

    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
    # NOTE(review): the un-preprocessed images are not used again in this cell.
    imgs_source = unpreprocess(imgs_source)

    c2ws_render = gen_render_path(pose_source['c2ws'].cpu().numpy(), N_views=60)
    # The keys do not include the scan index, so with more than one scan the last one wins.
    render_poses[f'dtu_c2ws'] = c2ws_render
    render_poses[f'dtu_intrinsic'] = pose_source['intrinsics'][0].cpu().numpy()

    
# Persist the same dict twice: once with torch.save and once with np.save.
# NOTE(review): np.save on a dict stores a pickled object array, so loading it back
# presumably requires np.load(..., allow_pickle=True) -- confirm against the consumer.
torch.save(render_poses, './configs/video_path.th')
np.save('./configs/video_path.npy',render_poses)

In [None]:
# Seeded NumPy generator (legacy RandomState API); not referenced by the helper
# functions defined in this cell.
rng = np.random.RandomState(234)
# Four times the machine epsilon of Python's float type; not referenced by the helpers in this cell.
_EPS = np.finfo(float).eps * 4.0
# Small constant used by the helpers below to guard divisions and to keep arccos
# arguments strictly inside (-1, 1).
TINY_NUMBER = 1e-6      # float32 only has 7 decimal digits precision

def angular_dist_between_2_vectors(vec1, vec2):
    '''
    calculate the angle between corresponding vectors of two arrays
    :param vec1: the first array of vectors [..., D]
    :param vec2: the second array of vectors, broadcastable against vec1 [..., D]
    :return: angular distance in radians, in [0, pi], with the last axis reduced [...]
    '''
    # Normalise along the last axis -- the same axis the dot product below reduces --
    # so inputs with extra leading dimensions are handled consistently (the norm used
    # to be taken over axis 1, which only matches for 2-D inputs).
    # TINY_NUMBER keeps the division finite for zero-length vectors.
    vec1_unit = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + TINY_NUMBER)
    vec2_unit = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + TINY_NUMBER)
    cos_angle = np.sum(vec1_unit * vec2_unit, axis=-1)
    # Clip before arccos: rounding can push the cosine marginally outside [-1, 1].
    angular_dists = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    return angular_dists


def batched_angular_dist_rot_matrix(R1, R2):
    '''
    rotation angle separating paired rotation matrices (batched)
    :param R1: first batch of rotation matrices [N, 3, 3]
    :param R2: second batch of rotation matrices [N, 3, 3]
    :return: angular distance in radians [N, ]
    '''
    for rot in (R1, R2):
        assert rot.shape[-1] == 3 and rot.shape[-2] == 3
    # Relative rotation R2^T @ R1; its trace equals 1 + 2*cos(angle).
    relative_rot = np.matmul(np.transpose(R2, (0, 2, 1)), R1)
    cos_angle = (np.trace(relative_rot, axis1=1, axis2=2) - 1) / 2.
    # Keep the cosine strictly inside (-1, 1) before taking arccos.
    return np.arccos(np.clip(cos_angle, -1 + TINY_NUMBER, 1 - TINY_NUMBER))


def get_nearest_pose_ids(tar_pose, ref_poses, num_select, tar_id=-1, angular_dist_method='vector',
                         scene_center=(0, 0, 0)):
    '''
    Select the reference poses closest to a target pose.

    Args:
        tar_pose: target camera pose (NumPy array). The 'matrix' method reads the
            upper-left 3x3 rotation; 'vector' and 'dist' read the translation column
            [:3, 3], so they need a [3, 4] or [4, 4] pose.
        ref_poses: reference poses [N, ...] with the same per-pose layout as tar_pose
        num_select: the number of nearest views to select
        tar_id: index of the target inside ref_poses; when >= 0 that entry is never selected
        angular_dist_method: 'matrix' (angle between rotations), 'vector' (angle between
            the camera positions as seen from scene_center) or 'dist' (Euclidean
            distance between camera positions)
        scene_center: reference point used by the 'vector' method
    Returns: indices of the selected poses, nearest first (at most num_select of them)
    Raises:
        ValueError: if angular_dist_method is not one of the supported names
    '''
    num_cams = len(ref_poses)
    # num_select = min(num_select, num_cams-1)
    # Repeat the target so it can be compared pair-wise with every reference pose.
    batched_tar_pose = tar_pose[None].repeat(num_cams,axis=0)

    if angular_dist_method == 'matrix':
        dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3])
    elif angular_dist_method == 'vector':
        tar_cam_locs = batched_tar_pose[:, :3, 3]
        ref_cam_locs = ref_poses[:, :3, 3]
        scene_center = np.array(scene_center)[None, ...]
        tar_vectors = tar_cam_locs - scene_center
        ref_vectors = ref_cam_locs - scene_center
        dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
    elif angular_dist_method == 'dist':
        tar_cam_locs = batched_tar_pose[:, :3, 3]
        ref_cam_locs = ref_poses[:, :3, 3]
        dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
    else:
        # ValueError is a subclass of Exception, so existing `except Exception` callers still work.
        raise ValueError('unknown angular distance calculation method!')

    if tar_id >= 0:
        assert tar_id < num_cams
        # make sure not to select the target id itself. Infinity is used rather than a
        # fixed large number: Euclidean distances in 'dist' mode can exceed any constant.
        dists[tar_id] = np.inf

    sorted_ids = np.argsort(dists)
    selected_ids = sorted_ids[:num_select]
    # print(angular_dists[selected_ids] * 180 / np.pi)
    return selected_ids


In [None]:
import glob
# Container for per-scene render poses; only the commented-out experiments in this cell fill it.
render_poses = dict()
# Dataset split handed to the dataset constructor as `split=datatype`.
datatype = 'val'
# Four-character codec code 'mp4v'; not referenced again in the active code of this cell.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# for i_scene, scene in enumerate(['flower']):#
#     # add --use_color_volume if the ckpts are finetuned with this flag
#     cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
#      --dataset_name llff --imgScale_test {1.0} \
#     --ckpt ./runs_new/runs_fine_tuning/{scene}/ckpts//latest.tar'

#     args = config_parser(cmd.split())
#     args.use_viewdirs = True


#     print('============> rendering dataset <===================')
#     dataset = dataset_dict[args.dataset_name](args, split=datatype)
#     val_idx = dataset.img_idx
    

#     imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)

#     c2ws_all = dataset.poses
#     w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
#     pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
#     c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.6, N_views=60) 
    
#     images = []
#     for i, c2w in enumerate(c2ws_render):
#         nearest_pose_ids = get_nearest_pose_ids(c2w,
#                                                 c2ws_all[pair_idx],
#                                                 3,
#                                                 angular_dist_method='vector')  
#         idxs = pair_idx[nearest_pose_ids]
        
#         im=[]
#         List = sorted(glob.glob(f'/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}/images_4/*'))
#         for idx in idxs:
#             im.append(cv2.resize(cv2.imread(List[idx]),None,fx=0.25,fy=0.25))
#         im = np.concatenate(im,axis=1)
#         images.append(im[...,::-1])
    
#     imageio.mimwrite(f'./results/test4/{scene}.mp4', np.stack(images), fps=20, quality=10)

# for i_scene, scene in enumerate(['flower','orchids', 'room','leaves','fern','horns','trex','fortress']):#'flower','orchids', 'room','leaves','fern','horns','trex','fortress'
#     # add --use_color_volume if the ckpts are finetuned with this flag
#     cmd = f'--datadir /mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/{scene}  \
#      --dataset_name llff --imgScale_test {1.0} \
#     --ckpt ./runs_new/runs_fine_tuning/{scene}/ckpts//latest.tar'

#     args = config_parser(cmd.split())
#     args.use_viewdirs = True


#     print('============> rendering dataset <===================')
#     dataset = dataset_dict[args.dataset_name](args, split=datatype)
#     val_idx = dataset.img_idx
    

#     imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)

#     c2ws_all = dataset.poses
#     w2cs, c2ws = pose_source['w2cs'], pose_source['c2ws']
#     pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
#     c2ws_render = get_spiral(c2ws_all[pair_idx], near_far_source, rads_scale = 0.6, N_views=60) 
    
#     render_poses[f'{scene}_c2ws'] = c2ws_render
#     render_poses[f'{scene}_intrinsic'] = pose_source['intrinsics'][0].cpu().numpy()
# #######################################
# For every pose on the rendered camera path, look up the 3 nearest training views and
# write them side by side into one video per scene.
for i_scene, scene in enumerate(['mic']):
    # Single source for the scene folder. The original read the images from
    # '/mnt/new_disk2/...' while the dataset was loaded from '/mnt/new_disk_2/...';
    # both now use the --datadir location so frames and poses come from the same data.
    scene_dir = f'/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/{scene}'
    cmd = (f'--datadir {scene_dir} '
           f'--dataset_name blender --white_bkgd --imgScale_test {1.0} '
           f'--ckpt /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_fine_tuning/{scene}/ckpts//latest.tar ')

    args = config_parser(cmd.split())

    print('============> rendering dataset <===================')
    dataset = dataset_dict[args.dataset_name](args, split=datatype)

    c2ws_all = dataset.load_poses_all()
    pair_idx = torch.load('configs/pairs.th')[f'{scene}_train']
    imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
    c2ws_render = nerf_video_path(pose_source['c2ws'].cpu(), N_views=60).cpu().numpy()

    images = []
    for i, c2w in enumerate(c2ws_render):
        nearest_pose_ids = get_nearest_pose_ids(c2w,
                                                c2ws_all[pair_idx],
                                                3,
                                                angular_dist_method='vector')
        idxs = pair_idx[nearest_pose_ids]
        im = []
        for idx in idxs:
            # int() keeps the file name correct whether idx is a numpy integer or a 0-d tensor.
            img_path = os.path.join(scene_dir, 'train', f'r_{int(idx)}.png')
            # IMREAD_UNCHANGED (-1) keeps the alpha channel that the BGRA -> RGBA reorder below needs.
            temp = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if temp is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f'could not read training image: {img_path}')
            im.append(cv2.resize(temp, None, fx=0.25, fy=0.25))
        im = np.concatenate(im, axis=1)
        images.append(im[..., [2, 1, 0, 3]])

    # Make sure the output folder exists before writing the video.
    out_dir = './results/test4'
    os.makedirs(out_dir, exist_ok=True)
    imageio.mimwrite(f'{out_dir}/{scene}.mp4', np.stack(images), fps=20, quality=10)
    
# ##################################################
# for i_scene, scene in enumerate([1]):

#     cmd = f'--datadir /mnt/data/new_disk/sungx/data/mvs_dataset/DTU/mvs_training/dtu/scan{scene}  \
#      --dataset_name dtu_ft --imgScale_test {1.0}   \
#     --ckpt /mnt/new_disk_2/anpei/code/MVS-NeRF/runs_fine_tuning/dtu_scan{scene}/ckpts//latest.tar --netwidth 256 --net_type v0 --use_color_volume'

#     args = config_parser(cmd.split())
#     args.use_viewdirs = True


#     print('============> rendering dataset <===================')
#     dataset = dataset_dict[args.dataset_name](args, split=datatype)
#     val_idx = dataset.img_idx

#     imgs_source, proj_mats, near_far_source, pose_source = dataset.read_source_views(device=device)
#     imgs_source = unpreprocess(imgs_source)

#     c2ws_render = gen_render_path(pose_source['c2ws'].cpu().numpy(), N_views=60)
#     render_poses[f'dtu_c2ws'] = c2ws_render
#     render_poses[f'dtu_intrinsic'] = pose_source['intrinsics'][0].cpu().numpy()

