Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface for depth in both forward rendering and backward propagation #5

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 19 additions & 2 deletions cuda_rasterizer/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,11 @@ renderCUDA(
const float2* __restrict__ points_xy_image,
const float4* __restrict__ conic_opacity,
const float* __restrict__ colors,
const float* __restrict__ depths,
const float* __restrict__ final_Ts,
const uint32_t* __restrict__ n_contrib,
const float* __restrict__ dL_dpixels,
const float* __restrict__ dL_depths,
float3* __restrict__ dL_dmean2D,
float4* __restrict__ dL_dconic2D,
float* __restrict__ dL_dopacity,
Expand All @@ -435,6 +437,7 @@ renderCUDA(
__shared__ float2 collected_xy[BLOCK_SIZE];
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
__shared__ float collected_colors[C * BLOCK_SIZE];
__shared__ float collected_depths[BLOCK_SIZE];

// In the forward, we stored the final value for T, the
// product of all (1 - alpha) factors.
Expand All @@ -448,12 +451,17 @@ renderCUDA(

float accum_rec[C] = { 0 };
float dL_dpixel[C];
if (inside)
float dL_depth;
float accum_depth_rec = 0;
if (inside){
for (int i = 0; i < C; i++)
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
dL_depth = dL_depths[pix_id];
}

float last_alpha = 0;
float last_color[C] = { 0 };
float last_depth = 0;

// Gradient of pixel coordinate w.r.t. normalized
// screen-space viewport coordinates (-1 to 1)
Expand All @@ -475,6 +483,7 @@ renderCUDA(
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
for (int i = 0; i < C; i++)
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
collected_depths[block.thread_rank()] = depths[coll_id];
}
block.sync();

Expand Down Expand Up @@ -522,6 +531,10 @@ renderCUDA(
// many that were affected by this Gaussian.
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
}
const float c_d = collected_depths[j];
accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec;
last_depth = c_d;
dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
dL_dalpha *= T;
// Update last alpha (to be used in the next iteration)
last_alpha = alpha;
Expand Down Expand Up @@ -630,9 +643,11 @@ void BACKWARD::render(
const float2* means2D,
const float4* conic_opacity,
const float* colors,
const float* depths,
const float* final_Ts,
const uint32_t* n_contrib,
const float* dL_dpixels,
const float* dL_depths,
float3* dL_dmean2D,
float4* dL_dconic2D,
float* dL_dopacity,
Expand All @@ -646,12 +661,14 @@ void BACKWARD::render(
means2D,
conic_opacity,
colors,
depths,
final_Ts,
n_contrib,
dL_dpixels,
dL_depths,
dL_dmean2D,
dL_dconic2D,
dL_dopacity,
dL_dcolors
);
}
}
4 changes: 3 additions & 1 deletion cuda_rasterizer/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ namespace BACKWARD
const float2* means2D,
const float4* conic_opacity,
const float* colors,
const float* depths,
const float* final_Ts,
const uint32_t* n_contrib,
const float* dL_dpixels,
const float* dL_depths,
float3* dL_dmean2D,
float4* dL_dconic2D,
float* dL_dopacity,
Expand Down Expand Up @@ -62,4 +64,4 @@ namespace BACKWARD
glm::vec4* dL_drot);
}

#endif
#endif
17 changes: 13 additions & 4 deletions cuda_rasterizer/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,13 @@ renderCUDA(
int W, int H,
const float2* __restrict__ points_xy_image,
const float* __restrict__ features,
const float* __restrict__ depths,
const float4* __restrict__ conic_opacity,
float* __restrict__ final_T,
uint32_t* __restrict__ n_contrib,
const float* __restrict__ bg_color,
float* __restrict__ out_color)
float* __restrict__ out_color,
float* __restrict__ out_depth)
{
// Identify current tile and associated min/max pixel range.
auto block = cg::this_thread_block();
Expand Down Expand Up @@ -301,6 +303,7 @@ renderCUDA(
uint32_t contributor = 0;
uint32_t last_contributor = 0;
float C[CHANNELS] = { 0 };
float D = { 0 };

// Iterate over batches until all done or range is complete
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
Expand Down Expand Up @@ -353,6 +356,7 @@ renderCUDA(
// Eq. (3) from 3D Gaussian splatting paper.
for (int ch = 0; ch < CHANNELS; ch++)
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
D += depths[collected_id[j]] * alpha * T;

T = test_T;

Expand All @@ -370,6 +374,7 @@ renderCUDA(
n_contrib[pix_id] = last_contributor;
for (int ch = 0; ch < CHANNELS; ch++)
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
out_depth[pix_id] = D;
}
}

Expand All @@ -380,23 +385,27 @@ void FORWARD::render(
int W, int H,
const float2* means2D,
const float* colors,
const float* depths,
const float4* conic_opacity,
float* final_T,
uint32_t* n_contrib,
const float* bg_color,
float* out_color)
float* out_color,
float* out_depth)
{
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
ranges,
point_list,
W, H,
means2D,
colors,
depths,
conic_opacity,
final_T,
n_contrib,
bg_color,
out_color);
out_color,
out_depth);
}

void FORWARD::preprocess(int P, int D, int M,
Expand Down Expand Up @@ -452,4 +461,4 @@ void FORWARD::preprocess(int P, int D, int M,
tiles_touched,
prefiltered
);
}
}
6 changes: 4 additions & 2 deletions cuda_rasterizer/forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ namespace FORWARD
int W, int H,
const float2* points_xy_image,
const float* features,
const float* depths,
const float4* conic_opacity,
float* final_T,
uint32_t* n_contrib,
const float* bg_color,
float* out_color);
float* out_color,
float* out_depth);
}


#endif
#endif
4 changes: 3 additions & 1 deletion cuda_rasterizer/rasterizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace CudaRasterizer
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
float* out_depth,
int* radii = nullptr,
bool debug = false);

Expand All @@ -72,6 +73,7 @@ namespace CudaRasterizer
char* binning_buffer,
char* image_buffer,
const float* dL_dpix,
const float* dL_depths,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dopacity,
Expand All @@ -85,4 +87,4 @@ namespace CudaRasterizer
};
};

#endif
#endif
11 changes: 9 additions & 2 deletions cuda_rasterizer/rasterizer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward(
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
float* out_depth,
int* radii,
bool debug)
{
Expand Down Expand Up @@ -326,11 +327,13 @@ int CudaRasterizer::Rasterizer::forward(
width, height,
geomState.means2D,
feature_ptr,
geomState.depths,
geomState.conic_opacity,
imgState.accum_alpha,
imgState.n_contrib,
background,
out_color), debug)
out_color,
out_depth), debug)

return num_rendered;
}
Expand All @@ -357,6 +360,7 @@ void CudaRasterizer::Rasterizer::backward(
char* binning_buffer,
char* img_buffer,
const float* dL_dpix,
const float* dL_depths,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dopacity,
Expand Down Expand Up @@ -387,6 +391,7 @@ void CudaRasterizer::Rasterizer::backward(
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
const float* depth_ptr = geomState.depths;
CHECK_CUDA(BACKWARD::render(
tile_grid,
block,
Expand All @@ -397,9 +402,11 @@ void CudaRasterizer::Rasterizer::backward(
geomState.means2D,
geomState.conic_opacity,
color_ptr,
depth_ptr,
imgState.accum_alpha,
imgState.n_contrib,
dL_dpix,
dL_depths,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
dL_dopacity,
Expand Down Expand Up @@ -431,4 +438,4 @@ void CudaRasterizer::Rasterizer::backward(
dL_dsh,
(glm::vec3*)dL_dscale,
(glm::vec4*)dL_drot), debug)
}
}
11 changes: 6 additions & 5 deletions diff_gaussian_rasterization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,22 @@ def forward(
if raster_settings.debug:
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
try:
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
except Exception as ex:
torch.save(cpu_args, "snapshot_fw.dump")
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
raise ex
else:
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)

# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.num_rendered = num_rendered
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
return color, radii
return color, radii, depth

@staticmethod
def backward(ctx, grad_out_color, _):
def backward(ctx, grad_out_color, grad_radii, grad_depth):

# Restore necessary values from context
num_rendered = ctx.num_rendered
Expand All @@ -118,7 +118,8 @@ def backward(ctx, grad_out_color, _):
raster_settings.projmatrix,
raster_settings.tanfovx,
raster_settings.tanfovy,
grad_out_color,
grad_out_color,
grad_depth,
sh,
raster_settings.sh_degree,
raster_settings.campos,
Expand Down
10 changes: 7 additions & 3 deletions rasterize_points.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
return lambda;
}

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down Expand Up @@ -66,6 +66,7 @@ RasterizeGaussiansCUDA(
auto float_opts = means3D.options().dtype(torch::kFloat32);

torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));

torch::Device device(torch::kCUDA);
Expand Down Expand Up @@ -108,10 +109,11 @@ RasterizeGaussiansCUDA(
tan_fovy,
prefiltered,
out_color.contiguous().data<float>(),
out_depth.contiguous().data<float>(),
radii.contiguous().data<int>(),
debug);
}
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
return std::make_tuple(rendered, out_color, out_depth, radii, geomBuffer, binningBuffer, imgBuffer);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand All @@ -129,6 +131,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const float tan_fovx,
const float tan_fovy,
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_depth,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
Expand Down Expand Up @@ -180,6 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(),
dL_dout_depth.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
dL_dopacity.contiguous().data<float>(),
Expand Down Expand Up @@ -214,4 +218,4 @@ torch::Tensor markVisible(
}

return present;
}
}
5 changes: 3 additions & 2 deletions rasterize_points.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <tuple>
#include <string>

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down Expand Up @@ -52,6 +52,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const float tan_fovx,
const float tan_fovy,
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_depth,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
Expand All @@ -64,4 +65,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
torch::Tensor markVisible(
torch::Tensor& means3D,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix);
torch::Tensor& projmatrix);