-
Notifications
You must be signed in to change notification settings - Fork 0
/
ray_utils.py
151 lines (116 loc) · 4.16 KB
/
ray_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import math
from typing import List, NamedTuple
import torch
import torch.nn.functional as F
from pytorch3d.renderer.cameras import CamerasBase
import numpy as np
# Convenience class wrapping several ray inputs:
# 1) Origins -- ray origins
# 2) Directions -- ray directions
# 3) Sample points -- sample points along ray direction from ray origin
# 4) Sample lengths -- distance of sample points from ray origin
class RayBundle(object):
    """Lightweight container for a batch of rays and the samples along them.

    Fields (trailing shapes as implied by ``reshape``/``view`` below):
        origins:        (..., 3) ray origins.
        directions:     (..., 3) ray directions.
        sample_points:  (..., S, 3) points sampled along each ray.
        sample_lengths: (..., S, 1) distance of each sample from its ray origin.
    """

    def __init__(
        self,
        origins,
        directions,
        sample_points,
        sample_lengths,
    ):
        self.origins = origins
        self.directions = directions
        self.sample_points = sample_points
        self.sample_lengths = sample_lengths

    def __getitem__(self, idx):
        """Index every field with ``idx`` and wrap the pieces in a new bundle."""
        pieces = [
            field[idx]
            for field in (
                self.origins,
                self.directions,
                self.sample_points,
                self.sample_lengths,
            )
        ]
        return RayBundle(*pieces)

    @property
    def shape(self):
        """Batch shape of the rays: ``origins`` without its trailing xyz axis."""
        return self.origins.shape[:-1]

    @property
    def sample_shape(self):
        """Batch + per-ray sample shape: ``sample_points`` without its xyz axis."""
        return self.sample_points.shape[:-1]

    def _regroup(self, method_name, *batch_dims):
        """Apply ``method_name`` ("reshape" or "view") to every field.

        The leading batch dims become ``batch_dims``; each field keeps its own
        trailing axes (xyz for origins/directions, samples + xyz/length for the
        per-sample fields).
        """

        def apply(tensor, *trailing):
            return getattr(tensor, method_name)(*batch_dims, *trailing)

        return RayBundle(
            apply(self.origins, 3),
            apply(self.directions, 3),
            apply(self.sample_points, self.sample_points.shape[-2], 3),
            apply(self.sample_lengths, self.sample_lengths.shape[-2], 1),
        )

    def reshape(self, *args):
        """Return a new bundle whose batch dims are reshaped to ``args``."""
        return self._regroup("reshape", *args)

    def view(self, *args):
        """Return a new bundle whose batch dims are viewed as ``args``."""
        return self._regroup("view", *args)

    def _replace(self, **kwargs):
        """Overwrite the named fields in place and return ``self``.

        NOTE: unlike ``NamedTuple._replace`` this mutates the bundle rather
        than returning a modified copy.
        """
        for field_name, new_value in kwargs.items():
            setattr(self, field_name, new_value)
        return self
# Sample image colors at the given pixel (NDC) coordinates
def sample_images_at_xy(
    images: torch.Tensor,
    xy_grid: torch.Tensor,
):
    """Bilinearly sample image colors at NDC pixel coordinates.

    Args:
        images: channels-last image batch, (B, H, W, C); it is permuted to
            (B, C, H, W) for ``grid_sample``.
        xy_grid: coordinates whose trailing axis has size 2, reshapeable to
            (B, N, 2). NOTE(review): the coordinates are negated before
            sampling, presumably to undo the sign flip applied by
            get_pixels_from_image (PyTorch3D NDC is +X left / +Y up, while
            grid_sample puts -1 at the left/top edge) -- confirm.

    Returns:
        (B * N, C) tensor of sampled colors; row ``b * N + n`` holds the color
        for coordinate ``n`` of image ``b``.
    """
    batch_size = images.shape[0]
    n_channels = images.shape[-1]

    # grid_sample expects a (B, H_out, W_out, 2) grid; use an N x 1 "image".
    grid = -xy_grid.reshape(batch_size, -1, 1, 2)
    images_sampled = torch.nn.functional.grid_sample(
        images.permute(0, 3, 1, 2),
        grid,
        align_corners=True,
        mode="bilinear",
    )

    # images_sampled is (B, C, N, 1). After the permute it is non-contiguous,
    # so `view` would raise whenever B > 1 and C > 1; `reshape` copies only
    # when it has to.
    return images_sampled.permute(0, 2, 3, 1).reshape(-1, n_channels)
# Generate pixel coordinates in NDC space (in the range [-1, 1])
def get_pixels_from_image(image_size, camera):
    """Return NDC coordinates for every pixel of an image, flattened row-major.

    Args:
        image_size: indexed as ``image_size[0] = W`` (width) and
            ``image_size[1] = H`` (height).
        camera: unused; kept for interface compatibility.

    Returns:
        (W * H, 2) float tensor. Row ``k = i * W + j`` corresponds to pixel row
        ``i``, column ``j`` and holds ``-(x_j, y_i)``, where ``x`` and ``y``
        run evenly from -1 to +1 inclusive. NOTE(review): the final negation
        presumably matches PyTorch3D's NDC convention (+X left, +Y up) --
        confirm.
    """
    W, H = image_size[0], image_size[1]

    # Evenly spaced over [-1, 1] inclusive, so the first/last pixels land
    # exactly on -1/+1. The previous `i / W * 2 - 1` stopped at 1 - 2/W, which
    # left the grid off-centre and did not match align_corners=True sampling,
    # where +/-1 are the centres of the edge pixels.
    x = torch.linspace(-1.0, 1.0, W)
    y = torch.linspace(-1.0, 1.0, H)

    # Equivalent to torch.meshgrid(y, x) with "ij" indexing (x varies fastest),
    # without the indexing-argument warning on newer PyTorch.
    xs = x.view(1, W).expand(H, W)
    ys = y.view(H, 1).expand(H, W)
    xy_grid = torch.stack((xs, ys), dim=-1).reshape(W * H, 2)

    return -xy_grid
# Random subsampling of pixels from an image
def get_random_pixels_from_image(n_pixels, image_size, camera, device="cuda"):
    """Randomly choose ``n_pixels`` NDC pixel coordinates from an image.

    Args:
        n_pixels: number of coordinates to return.
        image_size: forwarded to get_pixels_from_image.
        camera: forwarded to get_pixels_from_image.
        device: device for the returned tensor; defaults to "cuda" as before.

    Returns:
        (n_pixels, 2) tensor whose rows are drawn from
        get_pixels_from_image(image_size, camera).
    """
    xy_grid = get_pixels_from_image(image_size, camera)
    n_total = xy_grid.shape[0]

    # Draw distinct pixels so a batch never contains duplicate rays. Fall back
    # to sampling with replacement only when more pixels are requested than
    # the image has. Uses NumPy's global RNG, as before, so np.random.seed
    # still controls it.
    idx = np.random.choice(n_total, n_pixels, replace=n_pixels > n_total)

    return xy_grid[torch.as_tensor(idx, dtype=torch.long)].to(device)
# Generate rays from pixel coordinates
def get_rays_from_pixels(xy_grid, image_size, camera):
    """Cast one ray per NDC pixel coordinate through ``camera``.

    Args:
        xy_grid: (N, 2) tensor of NDC (x, y) coordinates, e.g. from
            get_pixels_from_image or get_random_pixels_from_image.
        image_size: unused here; kept for interface compatibility.
        camera: PyTorch3D camera used to unproject the points.
            NOTE(review): ``expand(N, -1)`` on the camera centre assumes a
            single camera -- confirm.

    Returns:
        RayBundle with origins (N, 3) at the camera centre, unit-length
        directions (N, 3), and zero-filled (N, 1, 3) placeholders for
        sample_points and sample_lengths.
    """
    # Follow the camera's device instead of hard-coding CUDA. This is
    # identical for CUDA cameras and also works for CPU cameras.
    ndc_points = xy_grid.to(camera.device)

    # Append a constant third coordinate of 1 (the "Z=1" image plane).
    ndc_points = torch.cat(
        [
            ndc_points,
            torch.ones_like(ndc_points[..., -1:]),
        ],
        dim=-1,
    )

    # NDC -> world-space points on that plane.
    world_points = camera.unproject_points(ndc_points, from_ndc=True)

    # Every ray starts at the camera centre.
    n_rays = world_points.shape[0]
    rays_o = camera.get_camera_center().expand(n_rays, -1)

    # Unit directions from the centre towards each unprojected point.
    rays_d = F.normalize(world_points - rays_o, dim=-1)

    # Samples along the rays are filled in later; use zero placeholders.
    return RayBundle(
        rays_o,
        rays_d,
        torch.zeros_like(rays_o).unsqueeze(1),
        torch.zeros_like(rays_o).unsqueeze(1),
    )