Skip to content

Commit

Permalink
Fast SH implementation (#165)
Browse files Browse the repository at this point in the history
* implement and profile new sh

* rewrite fast sh

* clean up

* fix naming of tmp variable

* reformat

* clean and add comments

* minor

* profile sh

* format

* fix typo

* auto diff check for fast sh

* minor

---------

Co-authored-by: Jianbo Ye <jianboye@amazon.com>
Co-authored-by: Ruilong Li <397653553@qq.com>
  • Loading branch information
3 people committed Apr 22, 2024
1 parent c54fe8b commit 0cef46c
Show file tree
Hide file tree
Showing 6 changed files with 458 additions and 38 deletions.
97 changes: 94 additions & 3 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,26 @@
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
from typing import Tuple
from typing import Tuple, Literal


def compute_sh_color(
viewdirs: Float[Tensor, "*batch 3"], sh_coeffs: Float[Tensor, "*batch D C"]
viewdirs: Float[Tensor, "*batch 3"],
sh_coeffs: Float[Tensor, "*batch D C"],
method: Literal["poly", "fast"] = "fast",
):
"""
Compute colors as the sum over the SH dimension of the SH bases evaluated
at the view directions times the per-channel SH coefficients.

:param viewdirs (*, 3) view directions
:param sh_coeffs (*, D, C) sh coefficients for each color channel
:param method "poly" evaluates the bases with eval_sh_bases, "fast" with
eval_sh_bases_fast; any other value raises RuntimeError
return colors (*, C)
"""
*dims, dim_sh, C = sh_coeffs.shape
bases = eval_sh_bases(dim_sh, viewdirs) # (*, dim_sh)
if method == "poly":
bases = eval_sh_bases(dim_sh, viewdirs) # (*, dim_sh)
elif method == "fast":
bases = eval_sh_bases_fast(dim_sh, viewdirs) # (*, dim_sh)
else:
raise RuntimeError(f"Unknown mode: {method} for compute sh color.")
# Broadcast bases over the channel axis, then reduce over the SH axis.
return (bases[..., None] * sh_coeffs).sum(dim=-2)


Expand Down Expand Up @@ -113,6 +120,90 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
return result


def eval_sh_bases_fast(basis_dim: int, dirs: torch.Tensor):
    """
    Evaluate spherical harmonics bases at unit directions for high orders
    using the approach described by
    Efficient Spherical Harmonic Evaluation, Peter-Pike Sloan, JCGT 2013
    https://jcgt.org/published/0002/02/06/
    See reference C++ code in https://jcgt.org/published/0002/02/06/code.zip

    :param basis_dim: int SH basis dim. Only the square numbers 1, 4, 9, 16
        and 25 are supported; other values are not validated here.
    :param dirs: torch.Tensor (..., 3) unit directions
    :return: torch.Tensor (..., basis_dim) with the dtype and device of dirs
    """
    result = torch.empty(
        (*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device
    )

    # Entry 0: constant term.
    result[..., 0] = 0.2820947917738781

    # Every early exit must hand back the filled tensor; a bare `return`
    # would give the caller None.
    if basis_dim <= 1:
        return result

    x, y, z = dirs.unbind(-1)

    # Entries 1..3.
    fTmpA = -0.48860251190292
    result[..., 2] = 0.4886025119029199 * z
    result[..., 3] = fTmpA * x
    result[..., 1] = fTmpA * y

    if basis_dim <= 4:
        return result

    # Entries 4..8.
    z2 = z * z
    fTmpB = -1.092548430592079 * z
    fTmpA = 0.5462742152960395
    fC1 = x * x - y * y
    fS1 = 2 * x * y
    # Kept in a local because entry 20 below reuses this value; reading it
    # back from `result` would depend on a buffer that is still being
    # written in place.
    sh_6 = 0.9461746957575601 * z2 - 0.3153915652525201
    result[..., 6] = sh_6
    result[..., 7] = fTmpB * x
    result[..., 5] = fTmpB * y
    result[..., 8] = fTmpA * fC1
    result[..., 4] = fTmpA * fS1

    if basis_dim <= 9:
        return result

    # Entries 9..15.
    fTmpC = -2.285228997322329 * z2 + 0.4570457994644658
    fTmpB = 1.445305721320277 * z
    fTmpA = -0.5900435899266435
    fC2 = x * fC1 - y * fS1
    fS2 = x * fS1 + y * fC1
    # Also reused by entry 20, so keep it in a local as well.
    sh_12 = z * (1.865881662950577 * z2 - 1.119528997770346)
    result[..., 12] = sh_12
    result[..., 13] = fTmpC * x
    result[..., 11] = fTmpC * y
    result[..., 14] = fTmpB * fC1
    result[..., 10] = fTmpB * fS1
    result[..., 15] = fTmpA * fC2
    result[..., 9] = fTmpA * fS2

    if basis_dim <= 16:
        return result

    # Entries 16..24.
    fTmpD = z * (-4.683325804901025 * z2 + 2.007139630671868)
    fTmpC = 3.31161143515146 * z2 - 0.47308734787878
    fTmpB = -1.770130769779931 * z
    fTmpA = 0.6258357354491763
    fC3 = x * fC2 - y * fS2
    fS3 = x * fS2 + y * fC2
    # Recurrence on entries 12 and 6, using the locals computed above.
    result[..., 20] = 1.984313483298443 * z * sh_12 + -1.006230589874905 * sh_6
    result[..., 21] = fTmpD * x
    result[..., 19] = fTmpD * y
    result[..., 22] = fTmpC * fC1
    result[..., 18] = fTmpC * fS1
    result[..., 23] = fTmpB * fC2
    result[..., 17] = fTmpB * fS2
    result[..., 24] = fTmpA * fC3
    result[..., 16] = fTmpA * fS3
    return result


def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(quat, dim=-1)
Expand Down
74 changes: 53 additions & 21 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &covs2d) {
}

torch::Tensor compute_sh_forward_tensor(
const std::string &method,
const unsigned num_points,
const unsigned degree,
const unsigned degrees_to_use,
Expand All @@ -72,21 +73,37 @@ torch::Tensor compute_sh_forward_tensor(
coeffs.size(1) != num_bases || coeffs.size(2) != 3) {
AT_ERROR("coeffs must have dimensions (N, D, 3)");
}
torch::Tensor colors = torch::empty({num_points, 3}, coeffs.options());
compute_sh_forward_kernel<<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
coeffs.contiguous().data_ptr<float>(),
colors.contiguous().data_ptr<float>()
);
torch::Tensor colors = torch::empty({num_points, 3}, coeffs.options());
if (method == "poly") {
compute_sh_forward_kernel<SHType::Poly><<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
coeffs.contiguous().data_ptr<float>(),
colors.contiguous().data_ptr<float>()
);
} else if (method == "fast") {
compute_sh_forward_kernel<SHType::Fast><<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
coeffs.contiguous().data_ptr<float>(),
colors.contiguous().data_ptr<float>()
);
} else {
AT_ERROR("Invalid method: ", method);
}
return colors;
}

torch::Tensor compute_sh_backward_tensor(
const std::string &method,
const unsigned num_points,
const unsigned degree,
const unsigned degrees_to_use,
Expand All @@ -105,16 +122,31 @@ torch::Tensor compute_sh_backward_tensor(
unsigned num_bases = num_sh_bases(degree);
torch::Tensor v_coeffs =
torch::zeros({num_points, num_bases, 3}, v_colors.options());
compute_sh_backward_kernel<<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
v_colors.contiguous().data_ptr<float>(),
v_coeffs.contiguous().data_ptr<float>()
);
if (method == "poly") {
compute_sh_backward_kernel<SHType::Poly><<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
v_colors.contiguous().data_ptr<float>(),
v_coeffs.contiguous().data_ptr<float>()
);
} else if (method == "fast") {
compute_sh_backward_kernel<SHType::Fast><<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
degree,
degrees_to_use,
(float3 *)viewdirs.contiguous().data_ptr<float>(),
v_colors.contiguous().data_ptr<float>(),
v_coeffs.contiguous().data_ptr<float>()
);
} else {
AT_ERROR("Invalid method: ", method);
}
return v_coeffs;
}

Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ std::tuple<
compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &A);

torch::Tensor compute_sh_forward_tensor(
const std::string &method,
unsigned num_points,
unsigned degree,
unsigned degrees_to_use,
Expand All @@ -30,6 +31,7 @@ torch::Tensor compute_sh_forward_tensor(
);

torch::Tensor compute_sh_backward_tensor(
const std::string &method,
unsigned num_points,
unsigned degree,
unsigned degrees_to_use,
Expand Down
Loading

0 comments on commit 0cef46c

Please sign in to comment.