Skip to content

Commit

Permalink
Single directional chamfer distance and non-absolute cosine similarity
Browse files Browse the repository at this point in the history
Summary: Single directional chamfer distance and option to use non-absolute cosine similarity

Reviewed By: bottler

Differential Revision: D46593980

fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
  • Loading branch information
Norman Mueller authored and facebook-github-bot committed Jun 13, 2023
1 parent 573a42c commit 5ffeb4d
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 105 deletions.
201 changes: 114 additions & 87 deletions pytorch3d/loss/chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,74 +68,28 @@ def _handle_pointcloud_input(
return X, lengths, normals


def chamfer_distance(
def _chamfer_distance_single_direction(
x,
y,
x_lengths=None,
y_lengths=None,
x_normals=None,
y_normals=None,
weights=None,
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
norm: int = 2,
x_lengths,
y_lengths,
x_normals,
y_normals,
weights,
batch_reduction: Union[str, None],
point_reduction: str,
norm: int,
abs_cosine: bool,
):
"""
Chamfer distance between two pointclouds x and y.
Args:
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
a batch of point clouds with at most P1 points in each batch element,
batch size N and feature dimension D.
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
a batch of point clouds with at most P2 points in each batch element,
batch size N and feature dimension D.
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in x.
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in y.
x_normals: Optional FloatTensor of shape (N, P1, D).
y_normals: Optional FloatTensor of shape (N, P2, D).
weights: Optional FloatTensor of shape (N,) giving weights for
batch elements for reduction operation.
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
Returns:
2-element tuple containing
- **loss**: Tensor giving the reduced distance between the pointclouds
in x and the pointclouds in y.
- **loss_normals**: Tensor giving the reduced cosine distance of normals
between pointclouds in x and pointclouds in y. Returns None if
x_normals and y_normals are None.
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")

x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

return_normals = x_normals is not None and y_normals is not None

N, P1, D = x.shape
P2 = y.shape[1]

# Check if inputs are heterogeneous and create a lengths mask.
is_x_heterogeneous = (x_lengths != P1).any()
is_y_heterogeneous = (y_lengths != P2).any()
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
y_mask = (
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
) # shape [N, P2]

if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
Expand All @@ -153,75 +107,148 @@ def chamfer_distance(
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)

cham_norm_x = x.new_zeros(())
cham_norm_y = x.new_zeros(())

x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)

cham_x = x_nn.dists[..., 0] # (N, P1)
cham_y = y_nn.dists[..., 0] # (N, P2)

if is_x_heterogeneous:
cham_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_y[y_mask] = 0.0

if weights is not None:
cham_x *= weights.view(N, 1)
cham_y *= weights.view(N, 1)

if return_normals:
# Gather the normals using the indices and keep only value for k=0
x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]

cham_norm_x = 1 - torch.abs(
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
)
cham_norm_y = 1 - torch.abs(
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
)
cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
# If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)

if is_x_heterogeneous:
cham_norm_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_norm_y[y_mask] = 0.0

if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_y *= weights.view(N, 1)
cham_norm_x = cham_norm_x.sum(1) # (N,)

# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
cham_y = cham_y.sum(1) # (N,)
if return_normals:
cham_norm_x = cham_norm_x.sum(1) # (N,)
cham_norm_y = cham_norm_y.sum(1) # (N,)
if point_reduction == "mean":
x_lengths_clamped = x_lengths.clamp(min=1)
y_lengths_clamped = y_lengths.clamp(min=1)
cham_x /= x_lengths_clamped
cham_y /= y_lengths_clamped
if return_normals:
cham_norm_x /= x_lengths_clamped
cham_norm_y /= y_lengths_clamped

if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
cham_y = cham_y.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
cham_norm_y = cham_norm_y.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
cham_y /= div
if return_normals:
cham_norm_x /= div
cham_norm_y /= div

cham_dist = cham_x + cham_y
cham_normals = cham_norm_x + cham_norm_y if return_normals else None

cham_dist = cham_x
cham_normals = cham_norm_x if return_normals else None
return cham_dist, cham_normals


def chamfer_distance(
x,
y,
x_lengths=None,
y_lengths=None,
x_normals=None,
y_normals=None,
weights=None,
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
norm: int = 2,
single_directional: bool = False,
abs_cosine: bool = True,
):
"""
Chamfer distance between two pointclouds x and y.
Args:
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
a batch of point clouds with at most P1 points in each batch element,
batch size N and feature dimension D.
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
a batch of point clouds with at most P2 points in each batch element,
batch size N and feature dimension D.
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in x.
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in y.
x_normals: Optional FloatTensor of shape (N, P1, D).
y_normals: Optional FloatTensor of shape (N, P2, D).
weights: Optional FloatTensor of shape (N,) giving weights for
batch elements for reduction operation.
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
single_directional: If False (default), loss comes from both the distance between
each point in x and its nearest neighbor in y and each point in y and its nearest
neighbor in x. If True, loss is the distance between each point in x and its
nearest neighbor in y.
abs_cosine: If False, loss_normals is from one minus the cosine similarity.
If True (default), loss_normals is from one minus the absolute value of the
cosine similarity, which means that exactly opposite normals are considered
equivalent to exactly matching normals, i.e. sign does not matter.
Returns:
2-element tuple containing
- **loss**: Tensor giving the reduced distance between the pointclouds
in x and the pointclouds in y.
- **loss_normals**: Tensor giving the reduced cosine distance of normals
between pointclouds in x and pointclouds in y. Returns None if
x_normals and y_normals are None.
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

cham_x, cham_norm_x = _chamfer_distance_single_direction(
x,
y,
x_lengths,
y_lengths,
x_normals,
y_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
if single_directional:
return cham_x, cham_norm_x
else:
cham_y, cham_norm_y = _chamfer_distance_single_direction(
y,
x,
y_lengths,
x_lengths,
y_normals,
x_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
return (
cham_x + cham_y,
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
)
Loading

0 comments on commit 5ffeb4d

Please sign in to comment.