In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

import numpy as np
import os

import json
import gzip

from PIL import Image
import matplotlib.pyplot as plt

In [2]:
from network.model import NerFormer
from network.feature_network import FeatureNet
from positional_embedding import HarmonicEmbedding
from ray_sampling import RaySampler
from co3d_utils import *

In [3]:
from typing import Union, IO, Any, List, Tuple, Optional
import typing

In [6]:
u1 = torch.tensor([[-0.3, 1.0, 1.2], [1.6, -1.0, -1.1]])
u2 = torch.tensor([[1.0, 1.0, -1.0], [0.7, 0.5, 1.0]])
u3 = torch.tensor([[0.5, -0.8, 1.0], [-0.1, 0.3, 0.4]])
u4 = torch.tensor([[1.0, 0.1, -0.2], [-0.2, 1.3, -0.4]])

In [7]:
import torch.nn as nn

In [8]:
def step(x):
    x[x <= 0] = 0
    x[x > 0] = 1

    return x

In [19]:
x = torch.tensor([[1], [0]], dtype=torch.float)
x = torch.cat((torch.tensor([[1]]), x), dim=0)
print(x)

y1 = u1.matmul(x)
print(y1)
y1 = step(y1)
print(y1)

y1 = torch.cat((torch.tensor([[1]]), y1), dim=0)
print(y1)

y2 = u2.matmul(y1)
print(y2)
y2 = step(y2)
print(y2)

y2 = torch.cat((torch.tensor([[1]]), y2), dim=0)
print(y2)

y3 = u3.matmul(y2)
print(y3)
y3 = step(y3)
print(y3)

y3 = torch.cat((torch.tensor([[1]]), y3), dim=0)
print(y3)

y4 = u4.matmul(y3)
print(y4)
y4 = step(y4)
print(y4)

tensor([[1.],
        [1.],
        [0.]])
tensor([[0.7000],
        [0.6000]])
tensor([[1.],
        [1.]])
tensor([[1.],
        [1.],
        [1.]])
tensor([[1.0000],
        [2.2000]])
tensor([[1.],
        [1.]])
tensor([[1.],
        [1.],
        [1.]])
tensor([[0.7000],
        [0.6000]])
tensor([[1.],
        [1.]])
tensor([[1.],
        [1.],
        [1.]])
tensor([[0.9000],
        [0.7000]])
tensor([[1.],
        [1.]])


In [22]:
relu = nn.ReLU()

x = torch.tensor([[1], [0]], dtype=torch.float)
x = torch.cat((torch.tensor([[1]]), x), dim=0)
print(x)

y1 = relu(u1.matmul(x))
y1 = torch.cat((torch.tensor([[1]]), y1), dim=0)
print(y1)

y2 = relu(u1.matmul(y1))
y2 = torch.cat((torch.tensor([[1]]), y2), dim=0)
print(y2)

y3 = relu(u1.matmul(y2))
y3 = torch.cat((torch.tensor([[1]]), y3), dim=0)
print(y3)

y4 = relu(u1.matmul(y3))
print(y4)



tensor([[1.],
        [1.],
        [0.]])
tensor([[1.0000],
        [0.7000],
        [0.6000]])
tensor([[1.0000],
        [1.1200],
        [0.2400]])
tensor([[1.0000],
        [1.1080],
        [0.2160]])
tensor([[1.0672],
        [0.2544]])


In [None]:
u3 = torch.tensor([[0.5, -0.8, 1.0], [-0.1, 0.3, 0.4]])

In [23]:
seq_imgs, seq_masks, seq_c2ws, seq_intrinsics = read_seq_data("./test_dataset/38_1655_5016")

print(len(seq_imgs), len(seq_masks), len(seq_c2ws), len(seq_intrinsics))

print(seq_imgs[0])
print(seq_masks[0])
print(seq_c2ws[0])
print(seq_intrinsics[0])

102 102 102 102
teddybear/38_1655_5016/images/frame000001.jpg
teddybear/38_1655_5016/masks/frame000001.png
tensor([[-0.9966, -0.0065,  0.0825, -0.0459],
        [ 0.0041, -0.9996, -0.0289,  0.2112],
        [ 0.0826, -0.0285,  0.9962, -0.9091],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])
tensor([[523.4543,   0.0000, 239.5000,   0.0000],
        [  0.0000, 294.0334, 179.5000,   0.0000],
        [  0.0000,   0.0000,   1.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,   1.0000]])


In [24]:
print(seq_imgs[1])
print(seq_c2ws[1])
print(seq_intrinsics[1])

teddybear/38_1655_5016/images/frame000002.jpg
tensor([[-0.9966, -0.0056,  0.0828, -0.0463],
        [ 0.0034, -0.9996, -0.0270,  0.2104],
        [ 0.0829, -0.0266,  0.9962, -0.9069],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])
tensor([[522.2864,   0.0000, 239.5000,   0.0000],
        [  0.0000, 293.3774, 179.5000,   0.0000],
        [  0.0000,   0.0000,   1.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,   1.0000]])


***

# c2w 계산 과정 확인
Rotation과 Translation 값을 가지고 [ R | t ] 형태로 만든 것을 w2c로 정의.

c2w를 계산하기 위해

- w2c를 바로 inverse 시킨 값
- R^-1 = R^T, t^-1 = R^T x (-t)  ------> [ R^(-1) | t^(-1) ]

두 값이 동일한지 확인

In [29]:
with open("./test_dataset/38_1655_5016/frame_annotations_file.json", 'r') as f:
    j = json.load(f)

frame = j[0]

for frame_key in frame.keys():
    print(frame_key, frame[frame_key])

frame_number 0
frame_timestamp -1.0
image {'path': 'teddybear/38_1655_5016/images/frame000001.jpg', 'size': [479, 359]}
depth {'path': 'teddybear/38_1655_5016/depths/frame000001.jpg.geometric.png', 'scale_adjustment': 1.262808918952942, 'mask_path': 'teddybear/38_1655_5016/depth_masks/frame000001.png'}
mask {'path': 'teddybear/38_1655_5016/masks/frame000001.png', 'mass': 35984.0}
viewpoint {'R': [[-0.9965706467628479, 0.004121924750506878, 0.08264364302158356], [-0.006493818014860153, -0.9995740652084351, -0.028452031314373016], [0.08249115943908691, -0.028891131281852722, 0.9961729049682617]], 'T': [0.20043912529945374, 1.2990046739578247, 6.429634094238281], 'focal_length': [2.1856130105871343, 1.6380690413377479], 'principal_point': [0.0, 0.0]}


In [44]:
c2w, k = get_c2w_intrinsic(frame["image"]["size"], frame["viewpoint"])

r = torch.eye(3)
r[:3, :3] = torch.tensor(frame["viewpoint"]["R"], dtype=torch.float)

t = torch.tensor(frame["viewpoint"]["T"], dtype=torch.float).reshape((3, 1))

In [45]:
print(c2w)
print(r)
print(t)

print(k)

c2w : 
 tensor([[-9.9657e-01, -6.4938e-03,  8.2491e-02, -3.2220e-01],
        [ 4.1219e-03, -9.9957e-01, -2.8891e-02,  1.4834e+00],
        [ 8.2644e-02, -2.8452e-02,  9.9617e-01, -6.3846e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])

rotation : 
 tensor([[-0.9966,  0.0041,  0.0826],
        [-0.0065, -0.9996, -0.0285],
        [ 0.0825, -0.0289,  0.9962]])

translation : 
 tensor([[0.2004],
        [1.2990],
        [6.4296]])

intrinsic : 
 tensor([[523.4543,   0.0000, 239.5000,   0.0000],
        [  0.0000, 294.0334, 179.5000,   0.0000],
        [  0.0000,   0.0000,   1.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,   1.0000]])


In [35]:
r_T = r.T
print("inverse rotation (R^T) : \n", r_T)

t_T = r_T.matmul(-t)
print("\ninverse translation (R^T x -t) : \n", t_T, t_T.shape)

c2w_2 = torch.eye(4, dtype=torch.float)
c2w_2[:3, :3] = r_T
c2w_2[:3, 3] = t_T.T
print(c2w_2)

print("c2w ([R^T | t]): \n", c2w)

inverse rotation (R^T) : 
 tensor([[-0.9966, -0.0065,  0.0825],
        [ 0.0041, -0.9996, -0.0289],
        [ 0.0826, -0.0285,  0.9962]])

inverse translation (R^T x -t) : 
 tensor([[-0.3222],
        [ 1.4834],
        [-6.3846]]) torch.Size([3, 1])
tensor([[-9.9657e-01, -6.4938e-03,  8.2491e-02, -3.2220e-01],
        [ 4.1219e-03, -9.9957e-01, -2.8891e-02,  1.4834e+00],
        [ 8.2644e-02, -2.8452e-02,  9.9617e-01, -6.3846e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])
c2w ([R^T | t]): 
 tensor([[-9.9657e-01, -6.4938e-03,  8.2491e-02, -3.2220e-01],
        [ 4.1219e-03, -9.9957e-01, -2.8891e-02,  1.4834e+00],
        [ 8.2644e-02, -2.8452e-02,  9.9617e-01, -6.3846e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])


In [36]:
w2c = torch.eye(4, dtype=torch.float)

w2c[:3, :3] = r
w2c[:3, 3] = t.T

print("w2c (= [R | t]) : \n", w2c)
print("\nc2w (= [R | t]^-1) : \n", torch.inverse(w2c))

w2c (= [R | t]) : 
 tensor([[-9.9657e-01,  4.1219e-03,  8.2644e-02,  2.0044e-01],
        [-6.4938e-03, -9.9957e-01, -2.8452e-02,  1.2990e+00],
        [ 8.2491e-02, -2.8891e-02,  9.9617e-01,  6.4296e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])

c2w (= [R | t]^-1) : 
 tensor([[-9.9657e-01, -6.4938e-03,  8.2491e-02, -3.2220e-01],
        [ 4.1219e-03, -9.9957e-01, -2.8891e-02,  1.4834e+00],
        [ 8.2644e-02, -2.8452e-02,  9.9617e-01, -6.3846e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])


In [40]:
P_w = torch.tensor([10, 20, 30, 1], dtype=torch.float)
P_c = w2c.matmul(P_w)
P_c_2_w = c2w.matmul(P_c)

print(f"P_world : \n{P_w} \n\nP_camera = w2c(P_world) : \n{P_c} \n\nP_camera convert back to wrold = c2w(P_camera) : \n{P_w}")

P_world : 
tensor([10., 20., 30.,  1.]) 

P_camera = w2c(P_world) : 
tensor([ -7.2035, -19.6110,  36.5619,   1.0000]) 

P_camera convert back to wrold = c2w(P_camera) : 
tensor([10., 20., 30.,  1.])


In [41]:
nerf_llff_bds = np.load("./poses_bounds.npy")

***

# NeRF llff 데이터셋에서의 bounds 값

In [42]:
len(nerf_llff_bds)
print(nerf_llff_bds[0])

llff_poses = nerf_llff_bds[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
print(llff_poses.shape)
llff_bds = nerf_llff_bds[:, -2:].transpose([1, 0])
print(llff_bds.shape)
print(llff_poses[..., 0], '\n')
print(llff_bds[..., 0], '\n')

[ 1.04872613e-02  9.98137190e-01 -6.01013048e-02 -3.37871546e+00
  3.02400000e+03  9.99658714e-01 -1.19034697e-02 -2.32542981e-02
 -3.09885180e+00  4.03200000e+03 -2.39263938e-02 -5.98369192e-02
 -9.97921375e-01  4.17063527e-02  3.32986996e+03  2.94571964e+01
  1.14827880e+02]
(3, 5, 55)
(2, 55)
[[ 1.04872613e-02  9.98137190e-01 -6.01013048e-02 -3.37871546e+00
   3.02400000e+03]
 [ 9.99658714e-01 -1.19034697e-02 -2.32542981e-02 -3.09885180e+00
   4.03200000e+03]
 [-2.39263938e-02 -5.98369192e-02 -9.97921375e-01  4.17063527e-02
   3.32986996e+03]] 

[ 29.45719638 114.82787963] 



***

# Ray sampling 테스트

In [57]:
# 배치 사이즈를 1로 설정하여 테스트
c2w_ = seq_c2ws[0].unsqueeze(0)
intrinsic_ = seq_intrinsics[0].unsqueeze(0)

In [58]:
# 테스트 이미지의 크기 지정
W = 10
H = 20
u, v = np.meshgrid(np.arange(W), np.arange(H))

# 이미지의 각 row들이 한 줄로 이어붙은 형태로 변환
# (H, W) --> (H*W)
u = u.reshape(-1).astype(dtype=np.float32)  # + 0.5    # add half pixel
v = v.reshape(-1).astype(dtype=np.float32)  # + 0.5

pixels = np.stack((u, v, np.ones_like(u)), axis=0)  # [3(x+y+z), H*W]
pixels = torch.from_numpy(pixels)
batched_pixels = pixels.unsqueeze(0).repeat(1, 1, 1)

# 각 픽셀로 향하는 rays의 방향을 구함
rays_d = (c2w_[:, :3, :3].bmm(torch.inverse(intrinsic_[:, :3, :3])).bmm(batched_pixels)).transpose(1, 2)
rays_d = rays_d.reshape(-1, 3)

# 각 픽셀로 향하는 rays의 원점을 구함
rays_o = c2w_[:, :3, 3].unsqueeze(1).repeat(1, rays_d.shape[0], 1).reshape(-1, 3)  # B x HW x 3

In [53]:
pixels[:, :10]

tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [59]:
print(f"rays_d : \n{rays_d[:10]} \nrays_o : \n{rays_o[:10]}")

rays_d : 
tensor([[0.5424, 0.5794, 0.9757],
        [0.5405, 0.5794, 0.9759],
        [0.5386, 0.5795, 0.9760],
        [0.5367, 0.5795, 0.9762],
        [0.5348, 0.5795, 0.9764],
        [0.5329, 0.5795, 0.9765],
        [0.5310, 0.5795, 0.9767],
        [0.5291, 0.5795, 0.9768],
        [0.5272, 0.5795, 0.9770],
        [0.5253, 0.5795, 0.9772]]) 
rays_o : 
tensor([[-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091],
        [-0.0459,  0.2112, -0.9091]])


***

# Positional Embedding 결과 확인

In [60]:
# (x, y, z) ---> 3(x, y, z) * 2(sin, cos) * 10 = 60으로 변환
pe = HarmonicEmbedding(10)

pe_test = pe(torch.tensor([0, 0, 1]))
print(pe_test)
print(pe_test.shape)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0998,  0.1987,  0.3894,  0.7174,
         0.9996, -0.0584,  0.1165,  0.2315,  0.4504,  0.8043,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
         1.0000,  1.0000,  0.9950,  0.9801,  0.9211,  0.6967, -0.0292, -0.9983,
         0.9932,  0.9728,  0.8928,  0.5942])
torch.Size([60])


In [None]:
from torch.utils.data import Dataset
import imageio


class CO3Ddataset(Dataset):
    def __init__(self, args, mode, categories=[], **kwargs):
        self.folder_path = os.path.join(args.rootdir, 'CO3D/')
        self.rectify_inplane_rotation = args.rectify_inplane_rotation

        if mode == 'validation':
            mode = 'val'
        assert mode in ['train', 'val', 'test']
        self.mode = mode  # train / test / val w

        self.num_source_views = args.num_source_views

        total_category = os.listdir(self.folder_path)

        if len(categories) > 0:
            if isinstance(categories, str):
                categories = [categories]
        else:
            categories = total_category

        print("loading {} for {}".format(categories, mode))
        
        self.tgt_imgs = []
        self.tgt_poses = []
        self.tgt_intrinsics = []

        for category in categories:
            self.category_path = os.path.join(self.folder_path, category)     # ".../CO3D/teddybear"

            rgb_files, c2ws, intrinsics  = read_category_data(self.category_path)
            
            if self.mode != 'train':
                rgb_files = rgb_files[::self.testskip]
                intrinsics = intrinsics[::self.testskip]
                c2ws = c2ws[::self.testskip]
            self.tgt_imgs.extend(rgb_files)
            self.tgt_poses.extend(c2ws)
            self.tgt_intrinsics.extend(intrinsics)


    def __len__(self):
        return len(self.tgt_imgs)


    def __getitem__(self, idx):
        tgt_img = self.tgt_imgs[idx]
        tgt_pose = self.tgt_poses[idx]
        tgt_intrinsic = self.tgt_intrinsics[idx]
        
        # 선택한 데이터가 속한 category, sequence 이름을 가져온다.
        category, seq_name = tgt_img.split('/')[:-2]   # "teddybear", "38_1655_5016"
        seq_file_path = os.path.join(self.folder_path, category, seq_name, "frame_annotations_file.json")
        # 해당 sequence (=오브젝트)의 모든 img, c2w, intrinsic 정보를 읽어온다. --> 소스뷰로 이용
        src_imgs, src_c2ws, src_intrinsics = read_seq_data(seq_file_path)

        if self.mode == 'train':
            id_render = int(os.path.basename(tgt_img)[:-4].split('_')[1])
            subsample_factor = np.random.choice(np.arange(1, 4), p=[0.3, 0.5, 0.2])
        else:
            id_render = -1
            subsample_factor = 1

        rgb = imageio.imread(tgt_img).astype(np.float32) / 255.
        rgb = rgb[..., [-1]] * rgb[..., :3] + 1 - rgb[..., [-1]]
        img_size = rgb.shape[:2]
        camera = np.concatenate((list(img_size), tgt_intrinsic.flatten(),
                                 tgt_pose.flatten())).astype(np.float32)

        nearest_src_ids = get_nearest_src(tgt_pose,
                                                src_c2ws,
                                                int(self.num_source_views*subsample_factor),
                                                tar_id=id_render,
                                                angular_dist_method='vector')
        nearest_src_ids = np.random.choice(nearest_src_ids, self.num_source_views, replace=False)

        assert id_render not in nearest_src_ids
        # occasionally include input image
        if np.random.choice([0, 1], p=[0.995, 0.005]) and self.mode == 'train':
            nearest_src_ids[np.random.choice(len(nearest_src_ids))] = id_render

        src_rgbs = []
        src_cameras = []
        for id in nearest_src_ids:
            src_rgb = imageio.imread(src_imgs[id]).astype(np.float32) / 255.
            src_rgb = src_rgb[..., [-1]] * src_rgb[..., :3] + 1 - src_rgb[..., [-1]]
            train_pose = src_c2ws[id]
            src_intrinsics_ = src_intrinsics[id]
            if self.rectify_inplane_rotation:
                train_pose, src_rgb = rectify_inplane_rotation(train_pose, tgt_pose, src_rgb)

            src_rgbs.append(src_rgb)
            img_size = src_rgb.shape[:2]
            src_camera = np.concatenate((list(img_size), src_intrinsics_.flatten(),
                                              train_pose.flatten())).astype(np.float32)
            src_cameras.append(src_camera)

        src_rgbs = np.stack(src_rgbs, axis=0)
        src_cameras = np.stack(src_cameras, axis=0)

        near_depth = 2.
        far_depth = 6.

        depth_range = torch.tensor([near_depth, far_depth])

        return {'rgb': torch.from_numpy(rgb[..., :3]),
                'camera': torch.from_numpy(camera),
                'rgb_path': tgt_img,
                'src_rgbs': torch.from_numpy(src_rgbs[..., :3]),
                'src_cameras': torch.from_numpy(src_cameras),
                'depth_range': depth_range,
                }

In [None]:
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn.functional as F


class Projector():
    def __init__(self, device):
        self.device = device

    def inbound(self, pixel_locations, h, w):
        '''
        픽셀이 이미지 상의 유효한 위치에 놓여 있는지 확인하는 함수.
        param pixel_locations: [..., 2]
        param h: height
        param w: weight
        
        return: mask, bool, [...]
        '''
        return (pixel_locations[..., 0] <= w - 1.) & \
               (pixel_locations[..., 0] >= 0) & \
               (pixel_locations[..., 1] <= h - 1.) &\
               (pixel_locations[..., 1] >= 0)


    def normalize(self, pixel_locations, h, w):
        resize_factor = torch.tensor([w-1., h-1.]).to(pixel_locations.device)[None, None, :]
        normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1.  # [n_views, n_points, 2]
        
        return normalized_pixel_locations


    def compute_projections(self, xyz, train_cameras):
        '''
        3D 포인트를 train_cameras(소스 카메라)의 이미지 스페이스로 프로젝션하는 함수
        param xyz: [..., 3]
        param train_cameras: [n_views, 34], 34 = img_size(2) + intrinsics(16) + extrinsics(16)
        
        return: pixel locations [..., 2], mask [...]
        '''
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        num_views = len(train_cameras)

        # 타겟 카메라의 intrinsics
        train_intrinsics = train_cameras[:, 2:18].reshape(-1, 4, 4)  # [n_views, 4, 4]
        # 타겟 카메라의 extrinsics
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)  # [n_views, 4, 4]
        
        # (x, y, z) ---> (x, y, z, 1) : Homogeneous Coordinate로 변환.
        xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)  # [n_points, 4]

        # P_proj = K x [R|t]_inv x P
        projections = train_intrinsics.bmm(torch.inverse(train_poses)) \
            .bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1))  # [n_views, 4, n_points]

        projections = projections.permute(0, 2, 1)  # [n_views, n_points, 4]

        # 픽셀 위치 (x, y)만 추출
        pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8)  # [n_views, n_points, 2]
        pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)

        # 카메라의 뒤로 projection된 포인트는 invalid로 판단.
        mask = projections[..., 2] > 0   # a point is invalid if behind the camera
        
        return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \
               mask.reshape((num_views, ) + original_shape)


    def compute_angle(self, xyz, query_camera, train_cameras):
        '''
        param xyz: [..., 3]
        param query_camera: 타겟 카메라 [34, ] : img_size(2) + intrinsics(16) + extrinsics(16)
        param train_cameras: 소스 카메라 [n_views, 34]  : img_size(2) + intrinsics(16) + extrinsics(16)
        
        return: [n_views, ..., 4] : 앞 3개 채널 = query ray와 target ray 사이의 direction 차이를 나타내는 단위 벡터 / 마지막 채널 = 두 방향 벡터의 내적값
        '''
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)  # [n_views, 4, 4]
        
        num_views = len(train_poses)

        # target 카메라의 pose를 소스 개수만큼 복사.
        query_pose = query_camera[-16:].reshape(-1, 4, 4).repeat(num_views, 1, 1)  # [n_views, 4, 4]

        ray2tar_pose = (query_pose[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)     # 단위벡터화

        ray2train_pose = (train_poses[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2train_pose /= (torch.norm(ray2train_pose, dim=-1, keepdim=True) + 1e-6)     # 단위벡터화

        # 두 단위 벡터의 차이값을 계산
        ray_diff = ray2tar_pose - ray2train_pose
        ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)  # 단위벡터화
        ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)

        # 두 단위 벡터를 내적
        ray_diff_dot = torch.sum(ray2tar_pose * ray2train_pose, dim=-1, keepdim=True)

        # [두 ray의 차이를 나타내는 방향 벡터(3채널), 두 ray의 내적값(1채널)]
        ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, ) + original_shape + (4, ))

        return ray_diff


    # 메인으로 실행되는 함수
    def compute(self,  xyz, query_camera, train_imgs, train_cameras, featmaps):
        '''
        :param xyz: [n_rays, n_samples, 3]
        :param query_camera: [1, 34], 34 = img_size(2) + intrinsics(16) + extrinsics(16)
        :param train_imgs: [1, n_views, h, w, 3] = 소스 이미지 n개
        :param train_cameras: [1, n_views, 34] = 소스 카메라 n개 / img_size(2) + intrinsics(16) + extrinsics(16)
        :param featmaps: [n_views, d, h, w] = 소스 이미지 n개의 resnet feature map

        :return: rgb_feat_sampled: [n_rays, n_samples, 3+n_feat],
                 ray_diff: [n_rays, n_samples, 4],
                 mask: [n_rays, n_samples, 1]
        '''
        assert (train_imgs.shape[0] == 1) \
               and (train_cameras.shape[0] == 1) \
               and (query_camera.shape[0] == 1), 'only support batch_size=1 for now'

        # 소스 이미지, 소스 카메라의 1이었던 batch 차원을 삭제
        train_imgs = train_imgs.squeeze(0)  # [n_views, h, w, 3]
        train_cameras = train_cameras.squeeze(0)  # [n_views, 34]
        # 타겟 이미지의 1이었던 batch 차원을 삭제
        query_camera = query_camera.squeeze(0)  # [34, ]

        train_imgs = train_imgs.permute(0, 3, 1, 2)  # 채널 순서를 변경 ---> [n_views, 3, h, w]

        h, w = train_cameras[0][:2]

        # 쿼리 포인트 (x, y, z)가 각 소스 뷰들로 projection 되는 위치를 계산
        pixel_locations, mask_in_front = self.compute_projections(xyz, train_cameras)
        normalized_pixel_locations = self.normalize(pixel_locations, h, w)   # [n_views, n_rays, n_samples, 2]

        # 각 소스 이미지들로부터 프로젝션 된 위치의 RGB 컬러값을 샘플링
        rgbs_sampled = F.grid_sample(train_imgs, normalized_pixel_locations, align_corners=True)
        rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1)  # [n_rays, n_samples, n_views, 3]

        # 각 소스 이미지의 faeture map들로부터 프로젝션 된 위치의 resnet feature 값을 샘플링
        feat_sampled = F.grid_sample(featmaps, normalized_pixel_locations, align_corners=True)
        feat_sampled = feat_sampled.permute(2, 3, 0, 1)  # [n_rays, n_samples, n_views, d]
        
        # [RGB 샘플링 값, resnet feature 샘플링 값]
        rgb_feat_sampled = torch.cat([rgb_sampled, feat_sampled], dim=-1)   # [n_rays, n_samples, n_views, d+3]

        # mask
        inbound = self.inbound(pixel_locations, h, w)
        ray_diff = self.compute_angle(xyz, query_camera, train_cameras)
        ray_diff = ray_diff.permute(1, 2, 0, 3)
        mask = (inbound * mask_in_front).float().permute(1, 2, 0)[..., None]   # [n_rays, n_samples, n_views, 1]

        return rgb_feat_sampled, ray_diff, mask

In [None]:
ray_sampler = RaySampler(train_data, device)

In [None]:
random_pixels = sample_random_pixel(500, 500, 800, "center")
print(len(random_pixels), '\n', random_pixels[:10])

In [None]:
rays_o, rays_d = get_rays(img1.height, img1.width, intrinsic.unsqueeze(0), c2w.unsqueeze(0))

print(rays_o.shape, rays_d.shape)

In [145]:
ray_d = rays_d.clone()
ray_o = rays_o.clone()

print(ray_d.shape, ray_o.shape)

torch.Size([171961, 3]) torch.Size([171961, 3])


In [154]:
depth_range = torch.tensor([[0.1, 3.0]])
N_samples = 20

pts, z_vals = sample_along_camera_ray(ray_o, ray_d, depth_range, N_samples)

In [155]:
print(pts.shape)
print(z_vals.shape)

torch.Size([171961, 20, 3])
torch.Size([171961, 20])


In [159]:
# 400번째 ray에서의 samples
pts[400]

tensor([[ -1.4153,  -0.0613,   0.1650],
        [ -1.5882,  -0.0688,   0.1851],
        [ -3.4004,  -0.1472,   0.3963],
        [ -5.2240,  -0.2261,   0.6089],
        [ -6.6318,  -0.2871,   0.7730],
        [ -7.8613,  -0.3403,   0.9163],
        [ -8.7938,  -0.3807,   1.0250],
        [-10.1739,  -0.4404,   1.1858],
        [-10.7546,  -0.4656,   1.2535],
        [-12.8424,  -0.5560,   1.4969],
        [-13.9064,  -0.6020,   1.6209],
        [-15.5691,  -0.6740,   1.8147],
        [-16.1829,  -0.7006,   1.8862],
        [-17.7843,  -0.7699,   2.0729],
        [-18.9395,  -0.8199,   2.2075],
        [-20.2334,  -0.8759,   2.3583],
        [-21.3119,  -0.9226,   2.4840],
        [-22.5580,  -0.9765,   2.6293],
        [-24.7361,  -1.0708,   2.8831],
        [-25.7066,  -1.1128,   2.9963]])