Skip to content

Commit

Permalink
updates to apply_pca_colormap (#3086)
Browse files Browse the repository at this point in the history
* improvements to pca_colormap: allow input pca matrix, optional ignore_zeros arg

* typo
  • Loading branch information
kerrj committed Apr 18, 2024
1 parent babf577 commit 45d8bb7
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions nerfstudio/utils/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,35 +171,46 @@ def apply_boolean_colormap(
return colored_image


def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb=3"]:
def apply_pca_colormap(
image: Float[Tensor, "*bs dim"], pca_mat: Optional[Float[Tensor, "dim rgb=3"]] = None, ignore_zeros=True
) -> Float[Tensor, "*bs rgb=3"]:
"""Convert feature image to 3-channel RGB via PCA. The first three principle
components are used for the color channels, with outlier rejection per-channel
Args:
image: image of arbitrary vectors
pca_mat: an optional argument of the PCA matrix, shape (dim, 3)
ignore_zeros: whether to ignore zero values in the input image (they won't affect the PCA computation)
Returns:
Tensor: Colored image
"""
original_shape = image.shape
image = image.view(-1, image.shape[-1])
_, _, v = torch.pca_lowrank(image)
image = torch.matmul(image, v[..., :3])
d = torch.abs(image - torch.median(image, dim=0).values)
if ignore_zeros:
valids = (image.abs().amax(dim=-1)) > 0
else:
valids = torch.ones(image.shape[0], dtype=torch.bool)

if pca_mat is None:
_, _, pca_mat = torch.pca_lowrank(image[valids, :], q=3, niter=20)
assert pca_mat is not None
image = torch.matmul(image, pca_mat[..., :3])
d = torch.abs(image[valids, :] - torch.median(image[valids, :], dim=0).values)
mdev = torch.median(d, dim=0).values
s = d / mdev
m = 3.0 # this is a hyperparam controlling how many std dev outside for outliers
rins = image[s[:, 0] < m, 0]
gins = image[s[:, 1] < m, 1]
bins = image[s[:, 2] < m, 2]

image[:, 0] -= rins.min()
image[:, 1] -= gins.min()
image[:, 2] -= bins.min()

image[:, 0] /= rins.max() - rins.min()
image[:, 1] /= gins.max() - gins.min()
image[:, 2] /= bins.max() - bins.min()
m = 2.0 # this is a hyperparam controlling how many std dev outside for outliers
rins = image[valids, :][s[:, 0] < m, 0]
gins = image[valids, :][s[:, 1] < m, 1]
bins = image[valids, :][s[:, 2] < m, 2]

image[valids, 0] -= rins.min()
image[valids, 1] -= gins.min()
image[valids, 2] -= bins.min()

image[valids, 0] /= rins.max() - rins.min()
image[valids, 1] /= gins.max() - gins.min()
image[valids, 2] /= bins.max() - bins.min()

image = torch.clamp(image, 0, 1)
image_long = (image * 255).long()
Expand Down

0 comments on commit 45d8bb7

Please sign in to comment.