-
-
Notifications
You must be signed in to change notification settings - Fork 152
/
raymarching.cu
454 lines (377 loc) · 17.5 KB
/
raymarching.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
#include "helper_math.h"
#include "utils.h"
#define SQRT3 1.73205080757f
// Returns +1.0f or -1.0f carrying the sign bit of x (so -0.0f yields -1.0f).
inline __host__ __device__ float signf(const float x)
{
    const float unit = 1.0f;
    return copysignf(unit, x);
}
// Step size to use at ray parameter t.
// The raw step t*exp_step_factor grows with t when exp_step_factor > 0 (coarser sampling
// further along the ray) and is clamped to
//   [SQRT3/max_samples, SQRT3*2*scale/grid_size].
// With exp_step_factor == 0 the result is always the lower bound.
// Per the original author: exp_step_factor defaults to 0 for synthetic scenes and
// 1/256 for real scenes.
inline __host__ __device__ float calc_dt(float t, float exp_step_factor, int max_samples, int grid_size, float scale){
    const float dt_lower = SQRT3/max_samples;
    const float dt_upper = SQRT3*2*scale/grid_size;
    const float dt_raw = t*exp_step_factor;
    return clamp(dt_raw, dt_lower, dt_upper);
}
// Selects the cascade (mip) level from the largest absolute coordinate of a position.
// Example ranges of max(|x|, |y|, |z|) and the returned level:
// [0, 0.5) -> 0
// [0.5, 1) -> 1
// [1, 2) -> 2
// The result is clamped to [0, cascades-1].
inline __device__ int mip_from_pos(const float x, const float y, const float z, const int cascades) {
    const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
    int exponent = 0;
    frexpf(mx, &exponent);
    // frexpf reports exponent 0 for an input of exactly 0, which would otherwise be
    // mapped to level 1; a point at the origin belongs to the finest level 0.
    const int level = (mx > 0.0f) ? exponent+1 : 0;
    return min(cascades-1, max(0, level));
}
// Selects the cascade (mip) level from the step size dt, measured in grid cells.
// Example ranges of dt and the returned level:
// [0, 1/grid_size) -> 0
// [1/grid_size, 2/grid_size) -> 1
// [2/grid_size, 4/grid_size) -> 2
// The result is clamped to [0, cascades-1].
inline __device__ int mip_from_dt(float dt, int grid_size, int cascades) {
    const float dt_in_cells = dt*grid_size;
    int exponent;
    frexpf(dt_in_cells, &exponent);
    const int level = max(0, exponent);
    return min(cascades-1, level);
}
// morton utils
// Spreads the bits of v apart so that input bit i ends up at output bit 3*i, leaving
// two zero bits between consecutive input bits (building block of a 3D Morton code).
// NOTE(review): this is the usual 10-bit spreading sequence; inputs are presumably
// expected to be < 1024 - confirm against callers.
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
{
    // each multiply copies the value to a shifted position, each mask keeps the wanted bits
    v *= 0x00010001u; v &= 0xFF0000FFu;
    v *= 0x00000101u; v &= 0x0F00F00Fu;
    v *= 0x00000011u; v &= 0xC30C30C3u;
    v *= 0x00000005u; v &= 0x49249249u;
    return v;
}
// Interleaves the spread bits of x, y and z into one Morton code:
// x occupies bits 0, 3, 6, ...; y bits 1, 4, 7, ...; z bits 2, 5, 8, ...
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
    const uint32_t x_bits = __expand_bits(x);
    const uint32_t y_bits = __expand_bits(y) << 1;
    const uint32_t z_bits = __expand_bits(z) << 2;
    return x_bits | y_bits | z_bits;
}
// Inverse of the bit spreading: gathers every third bit of x (bits 0, 3, 6, ...)
// into a contiguous low-order value. To recover the second or third coordinate of a
// Morton code, shift the code right by 1 or 2 before calling.
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
{
    x &= 0x49249249;
    // each step folds neighbouring groups together, then masks off the leftovers
    x |= x >> 2;  x &= 0xc30c30c3;
    x |= x >> 4;  x &= 0x0f00f00f;
    x |= x >> 8;  x &= 0xff0000ff;
    x |= x >> 16; x &= 0x0000ffff;
    return x;
}
// One thread per row of `coords`: writes the Morton code of
// (coords[row][0], coords[row][1], coords[row][2]) to indices[row].
// Threads whose index is past the last row do nothing.
__global__ void morton3D_kernel(
    const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> indices
){
    const int row = threadIdx.x + blockIdx.x * blockDim.x;
    if (row < coords.size(0)) {
        const uint32_t cx = coords[row][0];
        const uint32_t cy = coords[row][1];
        const uint32_t cz = coords[row][2];
        indices[row] = __morton3D(cx, cy, cz);
    }
}
// Host wrapper: computes the Morton code of every row of `coords`.
//
// coords: 2-D int32 tensor (the accessors below are instantiated for `int` only);
//         columns 0..2 of each row are interleaved.
// Returns a 1-D tensor with coords.size(0) entries created with the same options
// (dtype/device) as `coords`.
torch::Tensor morton3D_cu(const torch::Tensor coords){
    const int N = coords.size(0);
    auto indices = torch::zeros({N}, coords.options());

    // Nothing to compute; also avoids a zero-block launch, which is an invalid configuration.
    if (N == 0) return indices;

    const int threads = 256, blocks = (N+threads-1)/threads;

    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_INTEGRAL_TYPES(coords.scalar_type(), "morton3D_cu",
    ([&] {
        morton3D_kernel<<<blocks, threads>>>(
            coords.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
            indices.packed_accessor32<int, 1, torch::RestrictPtrTraits>()
        );
    }));

    // Kernel launches do not report errors by themselves; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "morton3D_kernel launch failed: ", cudaGetErrorString(err));

    return indices;
}
// One thread per entry of `indices`: decodes the Morton code indices[row] and stores
// the three recovered coordinates in coords[row][0..2].
// Threads whose index is past the last row of `coords` do nothing.
__global__ void morton3D_invert_kernel(
    const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> indices,
    torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> coords
){
    const int row = threadIdx.x + blockIdx.x * blockDim.x;
    if (row < coords.size(0)) {
        const int code = indices[row];
        // coordinate k lives in bits k, k+3, k+6, ... of the code
        for (int axis = 0; axis < 3; ++axis) {
            coords[row][axis] = __morton3D_invert(code >> axis);
        }
    }
}
// Host wrapper: decodes every Morton code in `indices` back into three coordinates.
//
// indices: 1-D int32 tensor (the accessors below are instantiated for `int` only).
// Returns a tensor of shape (indices.size(0), 3) created with the same options
// (dtype/device) as `indices`.
torch::Tensor morton3D_invert_cu(const torch::Tensor indices){
    const int N = indices.size(0);
    auto coords = torch::zeros({N, 3}, indices.options());

    // Nothing to compute; also avoids a zero-block launch, which is an invalid configuration.
    if (N == 0) return coords;

    const int threads = 256, blocks = (N+threads-1)/threads;

    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "morton3D_invert_cu",
    ([&] {
        morton3D_invert_kernel<<<blocks, threads>>>(
            indices.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
            coords.packed_accessor32<int, 2, torch::RestrictPtrTraits>()
        );
    }));

    // Kernel launches do not report errors by themselves; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "morton3D_invert_kernel launch failed: ", cudaGetErrorString(err));

    return coords;
}
// packbits utils
// Packs a thresholded density grid into a bitfield, one thread per output byte.
// Bit b of density_bitfield[k] is set iff density_grid[8*k+b] > density_threshold.
// N is the number of output bytes; density_grid is read at indices 0 .. 8*N-1.
template <typename scalar_t>
__global__ void packbits_kernel(
    const scalar_t* __restrict__ density_grid,
    const int N,
    const float density_threshold,
    uint8_t* __restrict__ density_bitfield
){
    const int byte_idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (byte_idx >= N) return;

    const int first_cell = 8*byte_idx;
    uint8_t packed = 0;

    #pragma unroll 8
    for (int bit = 0; bit < 8; ++bit) {
        if (density_grid[first_cell+bit]>density_threshold) {
            packed |= (uint8_t)(1 << bit);
        }
    }

    density_bitfield[byte_idx] = packed;
}
// Host wrapper: thresholds `density_grid` and packs the result into `density_bitfield`
// (modified in place), 8 grid values per output byte.
//
// density_grid:      floating-point tensor (float, double or half); read through its raw
//                    data pointer, so it is presumably expected to be contiguous - confirm
//                    against the caller.
// density_threshold: a bit is set where the grid value is strictly greater than this.
// density_bitfield:  uint8 tensor with N = size(0) bytes; requires at least 8*N grid values.
void packbits_cu(
    const torch::Tensor density_grid,
    const float density_threshold,
    torch::Tensor density_bitfield
){
    const int N = density_bitfield.size(0);

    // Nothing to pack; also avoids a zero-block launch, which is an invalid configuration.
    if (N == 0) return;

    // The kernel reads density_grid[0 .. 8*N-1]; reject grids that are too small instead
    // of reading out of bounds.
    TORCH_CHECK(density_grid.numel() >= 8 * static_cast<int64_t>(N),
                "packbits_cu: density_grid has ", density_grid.numel(),
                " elements but density_bitfield needs ", 8 * static_cast<int64_t>(N));

    const int threads = 256, blocks = (N+threads-1)/threads;

    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(density_grid.scalar_type(), "packbits_cu",
    ([&] {
        packbits_kernel<scalar_t><<<blocks, threads>>>(
            density_grid.data_ptr<scalar_t>(),
            N,
            density_threshold,
            density_bitfield.data_ptr<uint8_t>()
        );
    }));

    // Kernel launches do not report errors by themselves; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "packbits_kernel launch failed: ", cudaGetErrorString(err));
}
// ray marching utils
// below code is based on https://github.com/ashawkey/torch-ngp/blob/main/raymarching/src/raymarching.cu
// Training-time ray marching, one thread per ray (thread index r, guarded by rays_o.size(0)).
//
// Each ray is marched twice over the occupancy bitfield:
//   pass 1 counts how many samples fall in occupied cells (at most max_samples);
//   pass 2 reserves a contiguous range in the output buffers and writes the samples.
//
// Inputs
//   rays_o, rays_d   : per-ray origin / direction, components read from columns 0..2
//   hits_t           : per-ray [t1, t2] marching interval; a negative t1 produces no samples
//                      (presumably the near/far hit of the scene bounds - confirm with caller)
//   density_bitfield : occupancy bits, indexed by mip*grid_size^3 + Morton code of the cell
//   cascades         : number of mip levels, passed to mip_from_pos / mip_from_dt
//   grid_size        : cells per axis of each level
//   scale            : upper limit for the half extent of a level; also passed to calc_dt
//   exp_step_factor  : passed to calc_dt
//   noise            : per-ray factor scaling the jitter of the starting t
//   max_samples      : per-ray sample cap; also passed to calc_dt
// Outputs
//   counter          : counter[0] accumulates the total number of samples,
//                      counter[1] the number of rays processed (both via atomicAdd)
//   rays_a           : one row per ray: (ray index, start index into the sample buffers,
//                      number of samples); rows are filled in the order threads obtain
//                      their slot from counter[1], not in ray order
//   xyzs, dirs       : sample positions and ray directions
//   deltas, ts       : step size and ray parameter of each sample
__global__ void raymarching_train_kernel(
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_o,
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_d,
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> hits_t,
const uint8_t* __restrict__ density_bitfield,
const int cascades,
const int grid_size,
const float scale,
const float exp_step_factor,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> noise,
const int max_samples,
int* __restrict__ counter,
torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> xyzs,
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> dirs,
torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> deltas,
torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> ts
){
const int r = blockIdx.x * blockDim.x + threadIdx.x;
if (r >= rays_o.size(0)) return;
const uint32_t grid_size3 = grid_size*grid_size*grid_size;
const float grid_size_inv = 1.0f/grid_size;
const float ox = rays_o[r][0], oy = rays_o[r][1], oz = rays_o[r][2];
const float dx = rays_d[r][0], dy = rays_d[r][1], dz = rays_d[r][2];
// reciprocals are used below to convert a distance along one axis into a ray parameter
const float dx_inv = 1.0f/dx, dy_inv = 1.0f/dy, dz_inv = 1.0f/dz;
float t1 = hits_t[r][0], t2 = hits_t[r][1];
if (t1>=0) { // only perturb the starting t
const float dt = calc_dt(t1, exp_step_factor, max_samples, grid_size, scale);
t1 += dt*noise[r];
}
// first pass: compute the number of samples on the ray
float t = t1; int N_samples = 0;
// if t1 < 0 (no hit) this loop will be skipped (N_samples will be 0)
while (0<=t && t<t2 && N_samples<max_samples){
const float x = ox+t*dx, y = oy+t*dy, z = oz+t*dz;
const float dt = calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
// use the coarser of the level implied by the position and the one implied by the step size
const int mip = max(mip_from_pos(x, y, z, cascades),
mip_from_dt(dt, grid_size, cascades));
// half extent of the selected level: 2^(mip-1), limited to scale
const float mip_bound = fminf(scalbnf(1.0f, mip-1), scale);
const float mip_bound_inv = 1/mip_bound;
// round down to nearest grid position
// ([-mip_bound, mip_bound] is mapped onto [0, grid_size) and clamped to valid cells)
const int nx = clamp(0.5f*(x*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
const int ny = clamp(0.5f*(y*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
const int nz = clamp(0.5f*(z*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
// bit index of the cell: each level owns grid_size3 consecutive bits, Morton-ordered
const uint32_t idx = mip*grid_size3 + __morton3D(nx, ny, nz);
const bool occ = density_bitfield[idx/8] & (1<<(idx%8));
if (occ) {
t += dt; N_samples++;
} else { // skip until the next voxel
// per-axis ray parameter needed to reach the cell face lying in the direction of travel
const float tx = (((nx+0.5f+0.5f*signf(dx))*grid_size_inv*2-1)*mip_bound-x)*dx_inv;
const float ty = (((ny+0.5f+0.5f*signf(dy))*grid_size_inv*2-1)*mip_bound-y)*dy_inv;
const float tz = (((nz+0.5f+0.5f*signf(dz))*grid_size_inv*2-1)*mip_bound-z)*dz_inv;
const float t_target = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
// advance in calc_dt-sized steps (at least one) until t reaches t_target
do {
t += calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
} while (t < t_target);
}
}
// second pass: write to output
// atomicAdd returns the previous value, i.e. the first output slot reserved for this ray
const int start_idx = atomicAdd(counter, N_samples);
const int ray_count = atomicAdd(counter+1, 1);
rays_a[ray_count][0] = r;
rays_a[ray_count][1] = start_idx; rays_a[ray_count][2] = N_samples;
// repeat the same march from t1, this time storing every occupied sample
t = t1; int samples = 0;
while (t<t2 && samples<N_samples){
const float x = ox+t*dx, y = oy+t*dy, z = oz+t*dz;
const float dt = calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
const int mip = max(mip_from_pos(x, y, z, cascades),
mip_from_dt(dt, grid_size, cascades));
const float mip_bound = fminf(scalbnf(1.0f, mip-1), scale);
const float mip_bound_inv = 1/mip_bound;
// round down to nearest grid position
const int nx = clamp(0.5f*(x*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
const int ny = clamp(0.5f*(y*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
const int nz = clamp(0.5f*(z*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
const uint32_t idx = mip*grid_size3 + __morton3D(nx, ny, nz);
const bool occ = density_bitfield[idx/8] & (1<<(idx%8));
if (occ) {
// s is the global slot of this sample inside the range reserved above
const int s = start_idx + samples;
xyzs[s][0] = x; xyzs[s][1] = y; xyzs[s][2] = z;
dirs[s][0] = dx; dirs[s][1] = dy; dirs[s][2] = dz;
ts[s] = t; deltas[s] = dt;
t += dt; samples++;
} else { // skip until the next voxel
const float tx = (((nx+0.5f+0.5f*signf(dx))*grid_size_inv*2-1)*mip_bound-x)*dx_inv;
const float ty = (((ny+0.5f+0.5f*signf(dy))*grid_size_inv*2-1)*mip_bound-y)*dy_inv;
const float tz = (((nz+0.5f+0.5f*signf(dz))*grid_size_inv*2-1)*mip_bound-z)*dz_inv;
const float t_target = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
do {
t += calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
} while (t < t_target);
}
}
}
// Host wrapper for raymarching_train_kernel.
//
// Allocates zero-initialised output buffers sized for the worst case of max_samples
// samples per ray, launches one thread per ray and returns
//   {rays_a, xyzs, dirs, deltas, ts, counter}
// where
//   rays_a  : (N_rays, 3) int64 - ray index, start index, number of samples
//   xyzs    : (N_rays*max_samples, 3), dirs: (N_rays*max_samples, 3)
//   deltas  : (N_rays*max_samples),    ts  : (N_rays*max_samples)
//   counter : (2) int32 - total number of samples written, number of rays processed
// Only the first counter[0] rows of the sample buffers are written by the kernel.
// The ray tensors must be float32 (the accessors below are instantiated for float only).
std::vector<torch::Tensor> raymarching_train_cu(
    const torch::Tensor rays_o,
    const torch::Tensor rays_d,
    const torch::Tensor hits_t,
    const torch::Tensor density_bitfield,
    const int cascades,
    const float scale,
    const float exp_step_factor,
    const torch::Tensor noise,
    const int grid_size,
    const int max_samples
){
    const int N_rays = rays_o.size(0);
    // Worst-case sample count in 64 bits: the product can exceed the range of int.
    const int64_t max_total_samples = static_cast<int64_t>(N_rays) * max_samples;

    // count the number of samples and the number of rays processed
    auto counter = torch::zeros({2}, torch::dtype(torch::kInt32).device(rays_o.device()));
    // ray attributes: ray_idx, start_idx, N_samples
    auto rays_a = torch::zeros({N_rays, 3},
                        torch::dtype(torch::kLong).device(rays_o.device()));

    auto xyzs = torch::zeros({max_total_samples, 3}, rays_o.options());
    auto dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
    auto deltas = torch::zeros({max_total_samples}, rays_o.options());
    auto ts = torch::zeros({max_total_samples}, rays_o.options());

    // No rays: return the empty buffers and avoid a zero-block launch,
    // which is an invalid configuration.
    if (N_rays == 0) return {rays_a, xyzs, dirs, deltas, ts, counter};

    const int threads = 256, blocks = (N_rays+threads-1)/threads;

    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(rays_o.scalar_type(), "raymarching_train_cu",
    ([&] {
        raymarching_train_kernel<<<blocks, threads>>>(
            rays_o.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            rays_d.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            hits_t.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            density_bitfield.data_ptr<uint8_t>(),
            cascades,
            grid_size,
            scale,
            exp_step_factor,
            noise.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
            max_samples,
            counter.data_ptr<int>(),
            rays_a.packed_accessor64<int64_t, 2, torch::RestrictPtrTraits>(),
            xyzs.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            dirs.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            deltas.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
            ts.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
        );
    }));

    // Kernel launches do not report errors by themselves; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "raymarching_train_kernel launch failed: ", cudaGetErrorString(err));

    return {rays_a, xyzs, dirs, deltas, ts, counter};
}
// Test-time ray marching, one thread per entry of alive_indices (thread index n).
//
// Marches ray r = alive_indices[n] from hits_t[r][0] towards hits_t[r][1] and stores up
// to N_samples samples that fall in occupied cells of the density bitfield.
//
// Inputs
//   rays_o, rays_d   : per-ray origin / direction, components read from columns 0..2
//   hits_t           : per-ray [t, t2] marching interval; hits_t[r][0] is overwritten after
//                      every stored sample so that a later call resumes from there
//   alive_indices    : indices of the rays to march
//   density_bitfield : occupancy bits, indexed by mip*grid_size^3 + Morton code of the cell
//   cascades         : number of mip levels, passed to mip_from_pos / mip_from_dt
//   grid_size        : cells per axis of each level
//   scale            : upper limit for the half extent of a level; also passed to calc_dt
//   exp_step_factor  : passed to calc_dt
//   N_samples        : maximum number of samples stored per ray in this call
//   max_samples      : passed to calc_dt
// Outputs (indexed by n, not by r)
//   xyzs, dirs       : sample positions and ray directions, [n][s][0..2]
//   deltas, ts       : step size and ray parameter of each sample, [n][s]
//   N_eff_samples    : number of samples actually stored for entry n (<= N_samples)
__global__ void raymarching_test_kernel(
    const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_o,
    const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_d,
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> hits_t,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> alive_indices,
    const uint8_t* __restrict__ density_bitfield,
    const int cascades,
    const int grid_size,
    const float scale,
    const float exp_step_factor,
    const int N_samples,
    const int max_samples,
    torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> xyzs,
    torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> dirs,
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> deltas,
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> ts,
    torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> N_eff_samples
){
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    if (n >= alive_indices.size(0)) return;

    const size_t r = alive_indices[n]; // ray index
    const uint32_t grid_size3 = grid_size*grid_size*grid_size;
    const float grid_size_inv = 1.0f/grid_size;

    const float ox = rays_o[r][0], oy = rays_o[r][1], oz = rays_o[r][2];
    const float dx = rays_d[r][0], dy = rays_d[r][1], dz = rays_d[r][2];
    // reciprocals are used below to convert a distance along one axis into a ray parameter
    const float dx_inv = 1.0f/dx, dy_inv = 1.0f/dy, dz_inv = 1.0f/dz;

    float t = hits_t[r][0], t2 = hits_t[r][1];
    int s = 0;

    while (t<t2 && s<N_samples){
        const float x = ox+t*dx, y = oy+t*dy, z = oz+t*dz;

        // The last argument of calc_dt is the scene scale (as in the training kernel),
        // not the number of cascades.
        const float dt = calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
        // use the coarser of the level implied by the position and the one implied by the step size
        const int mip = max(mip_from_pos(x, y, z, cascades),
                            mip_from_dt(dt, grid_size, cascades));

        // half extent of the selected level: 2^(mip-1), limited to scale
        const float mip_bound = fminf(scalbnf(1.0f, mip-1), scale);
        const float mip_bound_inv = 1/mip_bound;

        // round down to nearest grid position
        const int nx = clamp(0.5f*(x*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
        const int ny = clamp(0.5f*(y*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
        const int nz = clamp(0.5f*(z*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);

        // bit index of the cell: each level owns grid_size3 consecutive bits, Morton-ordered
        const uint32_t idx = mip*grid_size3 + __morton3D(nx, ny, nz);
        const bool occ = density_bitfield[idx/8] & (1<<(idx%8));

        if (occ) {
            xyzs[n][s][0] = x; xyzs[n][s][1] = y; xyzs[n][s][2] = z;
            dirs[n][s][0] = dx; dirs[n][s][1] = dy; dirs[n][s][2] = dz;
            ts[n][s] = t; deltas[n][s] = dt;
            t += dt;
            hits_t[r][0] = t; // modify the starting point for the next marching
            s++;
        } else { // skip until the next voxel
            // per-axis ray parameter needed to reach the cell face lying in the direction of travel
            const float tx = (((nx+0.5f+0.5f*signf(dx))*grid_size_inv*2-1)*mip_bound-x)*dx_inv;
            const float ty = (((ny+0.5f+0.5f*signf(dy))*grid_size_inv*2-1)*mip_bound-y)*dy_inv;
            const float tz = (((nz+0.5f+0.5f*signf(dz))*grid_size_inv*2-1)*mip_bound-z)*dz_inv;
            const float t_target = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));

            // advance in calc_dt-sized steps (at least one) until t reaches t_target
            do {
                t += calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
            } while (t < t_target);
        }
    }
    N_eff_samples[n] = s; // effective samples that hit occupied region (<=N_samples)
}
// Host wrapper for raymarching_test_kernel.
//
// Marches the rays listed in alive_indices for at most N_samples samples each and returns
//   {xyzs, dirs, deltas, ts, N_eff_samples}
// where, with N_rays = alive_indices.size(0),
//   xyzs, dirs    : (N_rays, N_samples, 3)
//   deltas, ts    : (N_rays, N_samples)
//   N_eff_samples : (N_rays) int32 - number of samples actually stored per ray
// hits_t is modified in place by the kernel (start of the marching interval of each ray).
// The ray tensors must be float32 (the accessors below are instantiated for float only).
std::vector<torch::Tensor> raymarching_test_cu(
    const torch::Tensor rays_o,
    const torch::Tensor rays_d,
    torch::Tensor hits_t,
    const torch::Tensor alive_indices,
    const torch::Tensor density_bitfield,
    const int cascades,
    const float scale,
    const float exp_step_factor,
    const int grid_size,
    const int max_samples,
    const int N_samples
){
    const int N_rays = alive_indices.size(0);

    auto xyzs = torch::zeros({N_rays, N_samples, 3}, rays_o.options());
    auto dirs = torch::zeros({N_rays, N_samples, 3}, rays_o.options());
    auto deltas = torch::zeros({N_rays, N_samples}, rays_o.options());
    auto ts = torch::zeros({N_rays, N_samples}, rays_o.options());
    auto N_eff_samples = torch::zeros({N_rays},
                            torch::dtype(torch::kInt32).device(rays_o.device()));

    // No alive rays: return the empty buffers and avoid a zero-block launch,
    // which is an invalid configuration.
    if (N_rays == 0) return {xyzs, dirs, deltas, ts, N_eff_samples};

    const int threads = 256, blocks = (N_rays+threads-1)/threads;

    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(rays_o.scalar_type(), "raymarching_test_cu",
    ([&] {
        raymarching_test_kernel<<<blocks, threads>>>(
            rays_o.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            rays_d.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            hits_t.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            alive_indices.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            density_bitfield.data_ptr<uint8_t>(),
            cascades,
            grid_size,
            scale,
            exp_step_factor,
            N_samples,
            max_samples,
            xyzs.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
            dirs.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
            deltas.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            ts.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
            N_eff_samples.packed_accessor32<int, 1, torch::RestrictPtrTraits>()
        );
    }));

    // Kernel launches do not report errors by themselves; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "raymarching_test_kernel launch failed: ", cudaGetErrorString(err));

    return {xyzs, dirs, deltas, ts, N_eff_samples};
}