From 74659aef26db47342d71df99f31b3f63eacd7182 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 15 Jun 2020 10:08:15 -0700 Subject: [PATCH] CPU implementation for point_mesh functions Summary: point_mesh functions were missing CPU implementations. The indices returned are not always matching, possibly due to numerical instability. Reviewed By: gkioxari Differential Revision: D21594264 fbshipit-source-id: 3016930e2a9a0f3cd8b3ac4c94a92c9411c0989d --- pytorch3d/csrc/point_mesh/point_mesh.cpp | 398 ++++++++++++++++++++ pytorch3d/csrc/point_mesh/point_mesh_edge.h | 49 ++- pytorch3d/csrc/point_mesh/point_mesh_face.h | 53 ++- pytorch3d/csrc/utils/geometry_utils.h | 286 ++++++++++++++ pytorch3d/csrc/utils/vec3.h | 5 + tests/test_point_mesh_distance.py | 118 +++++- 6 files changed, 878 insertions(+), 31 deletions(-) create mode 100644 pytorch3d/csrc/point_mesh/point_mesh.cpp diff --git a/pytorch3d/csrc/point_mesh/point_mesh.cpp b/pytorch3d/csrc/point_mesh/point_mesh.cpp new file mode 100644 index 000000000..59e5cefbd --- /dev/null +++ b/pytorch3d/csrc/point_mesh/point_mesh.cpp @@ -0,0 +1,398 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include +#include "utils/geometry_utils.h" +#include "utils/vec3.h" + +// - We start with implementations of simple operations on points, edges and +// faces. The hull of H points is a point if H=1, an edge if H=2, a face if H=3. 
+ +template +vec3 ExtractPoint(const at::TensorAccessor& t) { + return vec3(t[0], t[1], t[2]); +} + +template +struct ExtractHullHelper { + template + static std::array>, H> + get(const Accessor& t); + + template <> + static std::array>, 1> + get<1>(const Accessor& t) { + return {ExtractPoint(t)}; + } + + template <> + static std::array>, 2> + get<2>(const Accessor& t) { + return {ExtractPoint(t[0]), ExtractPoint(t[1])}; + } + + template <> + static std::array>, 3> + get<3>(const Accessor& t) { + return {ExtractPoint(t[0]), ExtractPoint(t[1]), ExtractPoint(t[2])}; + } +}; + +template +std::array>, H> +ExtractHull(const Accessor& t) { + return ExtractHullHelper::template get(t); +} + +template +void IncrementPoint(at::TensorAccessor&& t, const vec3& point) { + t[0] += point.x; + t[1] += point.y; + t[2] += point.z; +} + +// distance between the convex hull of A points and B points +// this could be done in c++17 with tuple_cat and invoke +template +T HullDistance( + const std::array, 1>& a, + const std::array, 2>& b) { + using std::get; + return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b)); +} +template +T HullDistance( + const std::array, 1>& a, + const std::array, 3>& b) { + using std::get; + return PointTriangle3DistanceForward( + get<0>(a), get<0>(b), get<1>(b), get<2>(b)); +} +template +T HullDistance( + const std::array, 2>& a, + const std::array, 1>& b) { + return HullDistance(b, a); +} +template +T HullDistance( + const std::array, 3>& a, + const std::array, 1>& b) { + return HullDistance(b, a); +} + +template +void HullHullDistanceBackward( + const std::array, 1>& a, + const std::array, 2>& b, + T grad_dist, + at::TensorAccessor&& grad_a, + at::TensorAccessor&& grad_b) { + using std::get; + auto res = + PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist); + IncrementPoint(std::move(grad_a), get<0>(res)); + IncrementPoint(grad_b[0], get<1>(res)); + IncrementPoint(grad_b[1], get<2>(res)); +} +template +void 
HullHullDistanceBackward( + const std::array, 1>& a, + const std::array, 3>& b, + T grad_dist, + at::TensorAccessor&& grad_a, + at::TensorAccessor&& grad_b) { + using std::get; + auto res = PointTriangle3DistanceBackward( + get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist); + IncrementPoint(std::move(grad_a), get<0>(res)); + IncrementPoint(grad_b[0], get<1>(res)); + IncrementPoint(grad_b[1], get<2>(res)); + IncrementPoint(grad_b[2], get<3>(res)); +} +template +void HullHullDistanceBackward( + const std::array, 3>& a, + const std::array, 1>& b, + T grad_dist, + at::TensorAccessor&& grad_a, + at::TensorAccessor&& grad_b) { + return HullHullDistanceBackward( + b, a, grad_dist, std::move(grad_b), std::move(grad_a)); +} +template +void HullHullDistanceBackward( + const std::array, 2>& a, + const std::array, 1>& b, + T grad_dist, + at::TensorAccessor&& grad_a, + at::TensorAccessor&& grad_b) { + return HullHullDistanceBackward( + b, a, grad_dist, std::move(grad_b), std::move(grad_a)); +} + +template +void ValidateShape(const at::Tensor& as) { + if (H == 1) { + TORCH_CHECK(as.size(1) == 3); + } else { + TORCH_CHECK(as.size(2) == 3 && as.size(1) == H); + } +} + +// ----------- Here begins the implementation of each top-level +// function using non-type template parameters to +// implement all the cases in one go. ----------- // + +template +std::tuple HullHullDistanceForwardCpu( + const at::Tensor& as, + const at::Tensor& as_first_idx, + const at::Tensor& bs, + const at::Tensor& bs_first_idx) { + const int64_t A_N = as.size(0); + const int64_t B_N = bs.size(0); + const int64_t BATCHES = as_first_idx.size(0); + + ValidateShape

(as); + ValidateShape

(bs); + + TORCH_CHECK(bs_first_idx.size(0) == BATCHES); + + // clang-format off + at::Tensor dists = at::zeros({A_N,}, as.options()); + at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options()); + // clang-format on + + auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > (); + auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > (); + auto as_first_idx_a = as_first_idx.accessor(); + auto bs_first_idx_a = bs_first_idx.accessor(); + auto dists_a = dists.accessor(); + auto idxs_a = idxs.accessor(); + int64_t a_batch_end = 0; + int64_t b_batch_start = 0, b_batch_end = 0; + int64_t batch_idx = 0; + for (int64_t a_n = 0; a_n < A_N; ++a_n) { + if (a_n == a_batch_end) { + ++batch_idx; + b_batch_start = b_batch_end; + if (batch_idx == BATCHES) { + a_batch_end = std::numeric_limits::max(); + b_batch_end = B_N; + } else { + a_batch_end = as_first_idx_a[batch_idx]; + b_batch_end = bs_first_idx_a[batch_idx]; + } + } + float min_dist = std::numeric_limits::max(); + size_t min_idx = 0; + auto a = ExtractHull

(as_a[a_n]); + for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) { + float dist = HullDistance(a, ExtractHull

(bs_a[b_n])); + if (dist <= min_dist) { + min_dist = dist; + min_idx = b_n; + } + } + dists_a[a_n] = min_dist; + idxs_a[a_n] = min_idx; + } + + return std::make_tuple(dists, idxs); +} + +template +std::tuple HullHullDistanceBackwardCpu( + const at::Tensor& as, + const at::Tensor& bs, + const at::Tensor& idx_bs, + const at::Tensor& grad_dists) { + const int64_t A_N = as.size(0); + + TORCH_CHECK(idx_bs.size(0) == A_N); + TORCH_CHECK(grad_dists.size(0) == A_N); + ValidateShape

(as); + ValidateShape

(bs); + + at::Tensor grad_as = at::zeros_like(as); + at::Tensor grad_bs = at::zeros_like(bs); + + auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > (); + auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > (); + auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > (); + auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > (); + auto idx_bs_a = idx_bs.accessor(); + auto grad_dists_a = grad_dists.accessor(); + + for (int64_t a_n = 0; a_n < A_N; ++a_n) { + auto a = ExtractHull

(as_a[a_n]); + auto b = ExtractHull

(bs_a[idx_bs_a[a_n]]); + HullHullDistanceBackward( + a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]); + } + return std::make_tuple(grad_as, grad_bs); +} + +template +torch::Tensor PointHullArrayDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& bs) { + const int64_t P = points.size(0); + const int64_t B_N = bs.size(0); + + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + ValidateShape(bs); + + at::Tensor dists = at::zeros({P, B_N}, points.options()); + auto points_a = points.accessor(); + auto bs_a = bs.accessor(); + auto dists_a = dists.accessor(); + for (int64_t p = 0; p < P; ++p) { + auto point = ExtractHull<1>(points_a[p]); + auto dest = dists_a[p]; + for (int64_t b_n = 0; b_n < B_N; ++b_n) { + auto b = ExtractHull(bs_a[b_n]); + dest[b_n] = HullDistance(point, b); + } + } + return dists; +} + +template +std::tuple PointHullArrayDistanceBackwardCpu( + const at::Tensor& points, + const at::Tensor& bs, + const at::Tensor& grad_dists) { + const int64_t P = points.size(0); + const int64_t B_N = bs.size(0); + + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + ValidateShape(bs); + TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == B_N)); + + at::Tensor grad_points = at::zeros({P, 3}, points.options()); + at::Tensor grad_bs = at::zeros({B_N, H, 3}, bs.options()); + + auto points_a = points.accessor(); + auto bs_a = bs.accessor(); + auto grad_dists_a = grad_dists.accessor(); + auto grad_points_a = grad_points.accessor(); + auto grad_bs_a = grad_bs.accessor(); + for (int64_t p = 0; p < P; ++p) { + auto point = ExtractHull<1>(points_a[p]); + auto grad_point = grad_points_a[p]; + auto grad_dist = grad_dists_a[p]; + for (int64_t b_n = 0; b_n < B_N; ++b_n) { + auto b = ExtractHull(bs_a[b_n]); + HullHullDistanceBackward( + point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]); + } + } + return std::make_tuple(grad_points, grad_bs); +} + +// ---------- Here begin the exported 
functions ------------ // + +std::tuple PointFaceDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& tris, + const torch::Tensor& tris_first_idx) { + return HullHullDistanceForwardCpu<1, 3>( + points, points_first_idx, tris, tris_first_idx); +} + +std::tuple PointFaceDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris, + const torch::Tensor& idx_points, + const torch::Tensor& grad_dists) { + return HullHullDistanceBackwardCpu<1, 3>( + points, tris, idx_points, grad_dists); +} + +std::tuple FacePointDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& tris, + const torch::Tensor& tris_first_idx) { + return HullHullDistanceForwardCpu<3, 1>( + tris, tris_first_idx, points, points_first_idx); +} + +std::tuple FacePointDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris, + const torch::Tensor& idx_tris, + const torch::Tensor& grad_dists) { + auto res = + HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists); + return std::make_tuple(std::get<1>(res), std::get<0>(res)); +} + +torch::Tensor PointEdgeArrayDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms) { + return PointHullArrayDistanceForwardCpu<2>(points, segms); +} + +std::tuple PointFaceArrayDistanceBackwardCpu( + const at::Tensor& points, + const at::Tensor& tris, + const at::Tensor& grad_dists) { + return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists); +} + +torch::Tensor PointFaceArrayDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris) { + return PointHullArrayDistanceForwardCpu<3>(points, tris); +} + +std::tuple PointEdgeArrayDistanceBackwardCpu( + const at::Tensor& points, + const at::Tensor& segms, + const at::Tensor& grad_dists) { + return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists); +} + +std::tuple 
PointEdgeDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& segms, + const torch::Tensor& segms_first_idx, + const int64_t /*max_points*/) { + return HullHullDistanceForwardCpu<1, 2>( + points, points_first_idx, segms, segms_first_idx); +} + +std::tuple PointEdgeDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms, + const torch::Tensor& idx_points, + const torch::Tensor& grad_dists) { + return HullHullDistanceBackwardCpu<1, 2>( + points, segms, idx_points, grad_dists); +} + +std::tuple EdgePointDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& segms, + const torch::Tensor& segms_first_idx, + const int64_t /*max_segms*/) { + return HullHullDistanceForwardCpu<2, 1>( + segms, segms_first_idx, points, points_first_idx); +} + +std::tuple EdgePointDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms, + const torch::Tensor& idx_segms, + const torch::Tensor& grad_dists) { + auto res = + HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists); + return std::make_tuple(std::get<1>(res), std::get<0>(res)); +} diff --git a/pytorch3d/csrc/point_mesh/point_mesh_edge.h b/pytorch3d/csrc/point_mesh/point_mesh_edge.h index 963820173..b775d2e09 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_edge.h +++ b/pytorch3d/csrc/point_mesh/point_mesh_edge.h @@ -46,6 +46,13 @@ std::tuple PointEdgeDistanceForwardCuda( const int64_t max_points); #endif +std::tuple PointEdgeDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& segms, + const torch::Tensor& segms_first_idx, + const int64_t max_points); + std::tuple PointEdgeDistanceForward( const torch::Tensor& points, const torch::Tensor& points_first_idx, @@ -64,7 +71,8 @@ std::tuple PointEdgeDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU 
implementation."); + return PointEdgeDistanceForwardCpu( + points, points_first_idx, segms, segms_first_idx, max_points); } // Backward pass for PointEdgeDistance. @@ -91,6 +99,12 @@ std::tuple PointEdgeDistanceBackwardCuda( const torch::Tensor& grad_dists); #endif +std::tuple PointEdgeDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms, + const torch::Tensor& idx_points, + const torch::Tensor& grad_dists); + std::tuple PointEdgeDistanceBackward( const torch::Tensor& points, const torch::Tensor& segms, @@ -107,7 +121,7 @@ std::tuple PointEdgeDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists); } // **************************************************************************** @@ -150,6 +164,13 @@ std::tuple EdgePointDistanceForwardCuda( const int64_t max_segms); #endif +std::tuple EdgePointDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& segms, + const torch::Tensor& segms_first_idx, + const int64_t max_segms); + std::tuple EdgePointDistanceForward( const torch::Tensor& points, const torch::Tensor& points_first_idx, @@ -168,7 +189,8 @@ std::tuple EdgePointDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return EdgePointDistanceForwardCpu( + points, points_first_idx, segms, segms_first_idx, max_segms); } // Backward pass for EdgePointDistance. 
@@ -195,6 +217,12 @@ std::tuple EdgePointDistanceBackwardCuda( const torch::Tensor& grad_dists); #endif +std::tuple EdgePointDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms, + const torch::Tensor& idx_segms, + const torch::Tensor& grad_dists); + std::tuple EdgePointDistanceBackward( const torch::Tensor& points, const torch::Tensor& segms, @@ -211,7 +239,7 @@ std::tuple EdgePointDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists); } // **************************************************************************** @@ -242,6 +270,10 @@ torch::Tensor PointEdgeArrayDistanceForwardCuda( const torch::Tensor& segms); #endif +torch::Tensor PointEdgeArrayDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms); + torch::Tensor PointEdgeArrayDistanceForward( const torch::Tensor& points, const torch::Tensor& segms) { @@ -254,7 +286,7 @@ torch::Tensor PointEdgeArrayDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointEdgeArrayDistanceForwardCpu(points, segms); } // Backward pass for PointEdgeArrayDistance. 
@@ -277,6 +309,11 @@ std::tuple PointEdgeArrayDistanceBackwardCuda( const torch::Tensor& grad_dists); #endif +std::tuple PointEdgeArrayDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& segms, + const torch::Tensor& grad_dists); + std::tuple PointEdgeArrayDistanceBackward( const torch::Tensor& points, const torch::Tensor& segms, @@ -291,5 +328,5 @@ std::tuple PointEdgeArrayDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists); } diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.h b/pytorch3d/csrc/point_mesh/point_mesh_face.h index 00f5eb0a3..ec4bd3442 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_face.h +++ b/pytorch3d/csrc/point_mesh/point_mesh_face.h @@ -19,7 +19,7 @@ // points_first_idx: LongTensor of shape (N,) indicating the first point // index for each example in the batch // tris: FloatTensor of shape (T, 3, 3) of the triangular faces. 
The t-th -// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) +// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) // tris_first_idx: LongTensor of shape (N,) indicating the first face // index for each example in the batch // max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing @@ -48,6 +48,12 @@ std::tuple PointFaceDistanceForwardCuda( const int64_t max_points); #endif +std::tuple PointFaceDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& tris, + const torch::Tensor& tris_first_idx); + std::tuple PointFaceDistanceForward( const torch::Tensor& points, const torch::Tensor& points_first_idx, @@ -66,7 +72,8 @@ std::tuple PointFaceDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointFaceDistanceForwardCpu( + points, points_first_idx, tris, tris_first_idx); } // Backward pass for PointFaceDistance. @@ -92,6 +99,11 @@ std::tuple PointFaceDistanceBackwardCuda( const torch::Tensor& idx_points, const torch::Tensor& grad_dists); #endif +std::tuple PointFaceDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris, + const torch::Tensor& idx_points, + const torch::Tensor& grad_dists); std::tuple PointFaceDistanceBackward( const torch::Tensor& points, @@ -109,7 +121,7 @@ std::tuple PointFaceDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointFaceDistanceBackwardCpu(points, tris, idx_points, grad_dists); } // **************************************************************************** @@ -124,7 +136,7 @@ std::tuple PointFaceDistanceBackward( // points_first_idx: LongTensor of shape (N,) indicating the first point // index for each example in the batch // tris: FloatTensor of shape (T, 3, 3) of the triangular faces. 
The t-th -// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) +// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) // tris_first_idx: LongTensor of shape (N,) indicating the first face // index for each example in the batch // max_tris: Scalar equal to max(T_i) for i in [0, N - 1] containing @@ -149,9 +161,15 @@ std::tuple FacePointDistanceForwardCuda( const torch::Tensor& points_first_idx, const torch::Tensor& tris, const torch::Tensor& tris_first_idx, - const int64_t max_tros); + const int64_t max_tris); #endif +std::tuple FacePointDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& points_first_idx, + const torch::Tensor& tris, + const torch::Tensor& tris_first_idx); + std::tuple FacePointDistanceForward( const torch::Tensor& points, const torch::Tensor& points_first_idx, @@ -170,7 +188,8 @@ std::tuple FacePointDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return FacePointDistanceForwardCpu( + points, points_first_idx, tris, tris_first_idx); } // Backward pass for FacePointDistance. @@ -197,6 +216,12 @@ std::tuple FacePointDistanceBackwardCuda( const torch::Tensor& grad_dists); #endif +std::tuple FacePointDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris, + const torch::Tensor& idx_tris, + const torch::Tensor& grad_dists); + std::tuple FacePointDistanceBackward( const torch::Tensor& points, const torch::Tensor& tris, @@ -213,7 +238,7 @@ std::tuple FacePointDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists); } // **************************************************************************** @@ -226,7 +251,7 @@ std::tuple FacePointDistanceBackward( // Args: // points: FloatTensor of shape (P, 3) // tris: FloatTensor of shape (T, 3, 3) of the triangular faces. 
The t-th -// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) +// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2]) // // Returns: // dists: FloatTensor of shape (P, T), where dists[p, t] is the squared @@ -245,6 +270,10 @@ torch::Tensor PointFaceArrayDistanceForwardCuda( const torch::Tensor& tris); #endif +torch::Tensor PointFaceArrayDistanceForwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris); + torch::Tensor PointFaceArrayDistanceForward( const torch::Tensor& points, const torch::Tensor& tris) { @@ -257,7 +286,7 @@ torch::Tensor PointFaceArrayDistanceForward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointFaceArrayDistanceForwardCpu(points, tris); } // Backward pass for PointFaceArrayDistance. @@ -278,6 +307,10 @@ std::tuple PointFaceArrayDistanceBackwardCuda( const torch::Tensor& tris, const torch::Tensor& grad_dists); #endif +std::tuple PointFaceArrayDistanceBackwardCpu( + const torch::Tensor& points, + const torch::Tensor& tris, + const torch::Tensor& grad_dists); std::tuple PointFaceArrayDistanceBackward( const torch::Tensor& points, @@ -293,5 +326,5 @@ std::tuple PointFaceArrayDistanceBackward( AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("No CPU implementation."); + return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists); } diff --git a/pytorch3d/csrc/utils/geometry_utils.h b/pytorch3d/csrc/utils/geometry_utils.h index ff603a66a..e283c8f42 100644 --- a/pytorch3d/csrc/utils/geometry_utils.h +++ b/pytorch3d/csrc/utils/geometry_utils.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "vec2.h" #include "vec3.h" @@ -281,6 +282,23 @@ T PointLineDistanceForward( return dot(p - p_proj, p - p_proj); } +template +T PointLine3DistanceForward( + const vec3& p, + const vec3& v0, + const vec3& v1) { + const vec3 v1v0 = v1 - v0; + const T l2 = dot(v1v0, v1v0); + if (l2 <= kEpsilon) { + return dot(p - v1, p - v1); + } 
+ + const T t = dot(v1v0, p - v0) / l2; + const T tt = std::min(std::max(t, 0.00f), 1.00f); + const vec3 p_proj = v0 + tt * v1v0; + return dot(p - p_proj, p - p_proj); +} + // Backward pass for point to line distance in 2D. // // Args: @@ -314,6 +332,51 @@ inline std::tuple, vec2, vec2> PointLineDistanceBackward( return std::make_tuple(grad_p, grad_v0, grad_v1); } +template +std::tuple, vec3, vec3> PointLine3DistanceBackward( + const vec3& p, + const vec3& v0, + const vec3& v1, + const T& grad_dist) { + const vec3 v1v0 = v1 - v0; + const vec3 pv0 = p - v0; + const T t_bot = dot(v1v0, v1v0); + const T t_top = dot(v1v0, pv0); + + vec3 grad_p{0.0f, 0.0f, 0.0f}; + vec3 grad_v0{0.0f, 0.0f, 0.0f}; + vec3 grad_v1{0.0f, 0.0f, 0.0f}; + + const T tt = t_top / t_bot; + + if (t_bot < kEpsilon) { + // if t_bot small, then v0 == v1, + // and dist = 0.5 * dot(pv0, pv0) + 0.5 * dot(pv1, pv1) + grad_p = grad_dist * 2.0f * pv0; + grad_v0 = -0.5f * grad_p; + grad_v1 = grad_v0; + } else if (tt < 0.0f) { + grad_p = grad_dist * 2.0f * pv0; + grad_v0 = -1.0f * grad_p; + // no gradients wrt v1 + } else if (tt > 1.0f) { + grad_p = grad_dist * 2.0f * (p - v1); + grad_v1 = -1.0f * grad_p; + // no gradients wrt v0 + } else { + const vec3 p_proj = v0 + tt * v1v0; + const vec3 diff = p - p_proj; + const vec3 grad_base = grad_dist * 2.0f * diff; + grad_p = grad_base - dot(grad_base, v1v0) * v1v0 / t_bot; + const vec3 dtt_v0 = (-1.0f * v1v0 - pv0 + 2.0f * tt * v1v0) / t_bot; + grad_v0 = (-1.0f + tt) * grad_base - dot(grad_base, v1v0) * dtt_v0; + const vec3 dtt_v1 = (pv0 - 2.0f * tt * v1v0) / t_bot; + grad_v1 = -dot(grad_base, v1v0) * dtt_v1 - tt * grad_base; + } + + return std::make_tuple(grad_p, grad_v0, grad_v1); +} + // The forward pass for calculating the shortest distance between a point // and a triangle. 
// Ref: https://www.randygaul.net/2014/07/23/distance-point-to-line-segment/ @@ -396,3 +459,226 @@ PointTriangleDistanceBackward( return std::make_tuple(grad_p, grad_v0, grad_v1, grad_v2); } + +// Computes the squared distance of a point p relative to a triangle (v0, v1, +// v2). If the point's projection p0 on the plane spanned by (v0, v1, v2) is +// inside the triangle with vertices (v0, v1, v2), then the returned value is +// the squared distance of p to its projection p0. Otherwise, the returned value +// is the smallest squared distance of p from the line segments (v0, v1), (v0, +// v2) and (v1, v2). +// +// Args: +// p: vec3 coordinates of a point +// v0, v1, v2: vec3 coordinates of the triangle vertices +// +// Returns: +// dist: Float of the squared distance +// + +const float vEpsilon = 1e-8; + +template +vec3 BarycentricCoords3Forward( + const vec3& p, + const vec3& v0, + const vec3& v1, + const vec3& v2) { + vec3 p0 = v1 - v0; + vec3 p1 = v2 - v0; + vec3 p2 = p - v0; + + const T d00 = dot(p0, p0); + const T d01 = dot(p0, p1); + const T d11 = dot(p1, p1); + const T d20 = dot(p2, p0); + const T d21 = dot(p2, p1); + + const T denom = d00 * d11 - d01 * d01 + kEpsilon; + const T w1 = (d11 * d20 - d01 * d21) / denom; + const T w2 = (d00 * d21 - d01 * d20) / denom; + const T w0 = 1.0f - w1 - w2; + + return vec3(w0, w1, w2); +} + +// Checks whether the point p is inside the triangle (v0, v1, v2). +// A point is inside the triangle, if all barycentric coordinates +// wrt the triangle are >= 0 & <= 1. +// +// NOTE that this function assumes that p lives on the space spanned +// by (v0, v1, v2). 
+// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
+// and throw an error if check fails
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1, v2: vec3 coordinates of the triangle vertices
+//
+// Returns:
+// inside: bool indicating whether p is inside triangle
+//
+template <typename T>
+static bool IsInsideTriangle(
+    const vec3<T>& p,
+    const vec3<T>& v0,
+    const vec3<T>& v1,
+    const vec3<T>& v2) {
+  vec3<T> bary = BarycentricCoords3Forward(p, v0, v1, v2);
+  bool x_in = 0.0f <= bary.x && bary.x <= 1.0f;
+  bool y_in = 0.0f <= bary.y && bary.y <= 1.0f;
+  bool z_in = 0.0f <= bary.z && bary.z <= 1.0f;
+  bool inside = x_in && y_in && z_in;
+  return inside;
+}
+// Squared distance of the point p to the triangle (v0, v1, v2).
+template <typename T>
+T PointTriangle3DistanceForward(
+    const vec3<T>& p,
+    const vec3<T>& v0,
+    const vec3<T>& v1,
+    const vec3<T>& v2) {
+  vec3<T> normal = cross(v2 - v0, v1 - v0);
+  const T norm_normal = norm(normal);
+  normal = normal / (norm_normal + vEpsilon);
+
+  // p0 is the projection of p on the plane spanned by (v0, v1, v2)
+  // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
+  const T t = dot(v0 - p, normal);
+  const vec3<T> p0 = p + t * normal;
+
+  bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
+  T dist = 0.0f;
+
+  if ((is_inside) && (norm_normal > kEpsilon)) {
+    // if projection p0 is inside triangle spanned by (v0, v1, v2)
+    // then distance is equal to norm(p0 - p)^2
+    dist = t * t;
+  } else {
+    const T e01 = PointLine3DistanceForward(p, v0, v1);
+    const T e02 = PointLine3DistanceForward(p, v0, v2);
+    const T e12 = PointLine3DistanceForward(p, v1, v2);
+
+    dist = (e01 > e02) ? e02 : e01;
+    dist = (dist > e12) ? e12 : dist;
+  }
+
+  return dist;
+}
+// Backward pass of cross(a, b): returns (grad_a, grad_b) given grad_cross.
+template <typename T>
+std::tuple<vec3<T>, vec3<T>>
+cross_backward(const vec3<T>& a, const vec3<T>& b, const vec3<T>& grad_cross) {
+  const T grad_ax = -grad_cross.y * b.z + grad_cross.z * b.y;
+  const T grad_ay = grad_cross.x * b.z - grad_cross.z * b.x;
+  const T grad_az = -grad_cross.x * b.y + grad_cross.y * b.x;
+  const vec3<T> grad_a = vec3<T>(grad_ax, grad_ay, grad_az);
+
+  const T grad_bx = grad_cross.y * a.z - grad_cross.z * a.y;
+  const T grad_by = -grad_cross.x * a.z + grad_cross.z * a.x;
+  const T grad_bz = grad_cross.x * a.y - grad_cross.y * a.x;
+  const vec3<T> grad_b = vec3<T>(grad_bx, grad_by, grad_bz);
+
+  return std::make_tuple(grad_a, grad_b);
+}
+// Backward pass of a / (norm(a) + vEpsilon): returns the gradient wrt a.
+template <typename T>
+vec3<T> normalize_backward(const vec3<T>& a, const vec3<T>& grad_normz) {
+  const T a_norm = norm(a) + vEpsilon;
+  const vec3<T> out = a / a_norm;
+
+  const T grad_ax = grad_normz.x * (1.0f - out.x * out.x) / a_norm +
+      grad_normz.y * (-out.x * out.y) / a_norm +
+      grad_normz.z * (-out.x * out.z) / a_norm;
+  const T grad_ay = grad_normz.x * (-out.x * out.y) / a_norm +
+      grad_normz.y * (1.0f - out.y * out.y) / a_norm +
+      grad_normz.z * (-out.y * out.z) / a_norm;
+  const T grad_az = grad_normz.x * (-out.x * out.z) / a_norm +
+      grad_normz.y * (-out.y * out.z) / a_norm +
+      grad_normz.z * (1.0f - out.z * out.z) / a_norm;
+  return vec3<T>(grad_ax, grad_ay, grad_az);
+}
+
+// The backward pass for computing the squared distance of a point
+// to the triangle (v0, v1, v2).
+//
+// Args:
+// p: xyz coordinates of a point
+// v0, v1, v2: xyz coordinates of the triangle vertices
+// grad_dist: Float of the gradient wrt dist
+//
+// Returns:
+// tuple of gradients for the point and triangle:
+// (vec3 grad_p, vec3 grad_v0, vec3 grad_v1, vec3 grad_v2)
+//
+
+template <typename T>
+static std::tuple<vec3<T>, vec3<T>, vec3<T>, vec3<T>>
+PointTriangle3DistanceBackward(
+    const vec3<T>& p,
+    const vec3<T>& v0,
+    const vec3<T>& v1,
+    const vec3<T>& v2,
+    const T& grad_dist) {
+  const vec3<T> v2v0 = v2 - v0;
+  const vec3<T> v1v0 = v1 - v0;
+  const vec3<T> v0p = v0 - p;
+  vec3<T> raw_normal = cross(v2v0, v1v0);
+  const T norm_normal = norm(raw_normal);
+  vec3<T> normal = raw_normal / (norm_normal + vEpsilon);
+
+  // p0 is the projection of p on the plane spanned by (v0, v1, v2)
+  // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
+  const T t = dot(v0 - p, normal);
+  const vec3<T> p0 = p + t * normal;
+  const vec3<T> diff = t * normal;
+
+  bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
+
+  vec3<T> grad_p(0.0f, 0.0f, 0.0f);
+  vec3<T> grad_v0(0.0f, 0.0f, 0.0f);
+  vec3<T> grad_v1(0.0f, 0.0f, 0.0f);
+  vec3<T> grad_v2(0.0f, 0.0f, 0.0f);
+
+  if ((is_inside) && (norm_normal > kEpsilon)) {
+    // derivative of dist wrt p
+    grad_p = -2.0f * grad_dist * t * normal;
+    // derivative of dist wrt normal
+    const vec3<T> grad_normal = 2.0f * grad_dist * t * (v0p + diff);
+    // derivative of dist wrt raw_normal
+    const vec3<T> grad_raw_normal = normalize_backward(raw_normal, grad_normal);
+    // derivative of dist wrt v2v0 and v1v0
+    const auto grad_cross = cross_backward(v2v0, v1v0, grad_raw_normal);
+    const vec3<T> grad_cross_v2v0 = std::get<0>(grad_cross);
+    const vec3<T> grad_cross_v1v0 = std::get<1>(grad_cross);
+    grad_v0 =
+        grad_dist * 2.0f * t * normal - (grad_cross_v2v0 + grad_cross_v1v0);
+    grad_v1 = grad_cross_v1v0;
+    grad_v2 = grad_cross_v2v0;
+  } else {
+    const T e01 = PointLine3DistanceForward(p, v0, v1);
+    const T e02 = PointLine3DistanceForward(p, v0, v2);
+    const T e12 = PointLine3DistanceForward(p, v1, v2);
+
+    if ((e01 <= e02) && (e01 <= e12)) {
+      // e01 is smallest
+      const auto grads = PointLine3DistanceBackward(p, v0, v1, grad_dist);
+      grad_p = std::get<0>(grads);
+      grad_v0 = std::get<1>(grads);
+      grad_v1 = std::get<2>(grads);
+    } else if ((e02 <= e01) && (e02 <= e12)) {
+      // e02 is smallest
+      const auto grads = PointLine3DistanceBackward(p, v0, v2, grad_dist);
+      grad_p = std::get<0>(grads);
+      grad_v0 = std::get<1>(grads);
+      grad_v2 = std::get<2>(grads);
+    } else if ((e12 <= e01) && (e12 <= e02)) {
+      // e12 is smallest
+      const auto grads = PointLine3DistanceBackward(p, v1, v2, grad_dist);
+      grad_p = std::get<0>(grads);
+      grad_v1 = std::get<1>(grads);
+      grad_v2 = std::get<2>(grads);
+    }
+  }
+
+  return std::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
+}
diff --git a/pytorch3d/csrc/utils/vec3.h b/pytorch3d/csrc/utils/vec3.h
index 6ab43d131..7c696f7a9 100644
--- a/pytorch3d/csrc/utils/vec3.h
+++ b/pytorch3d/csrc/utils/vec3.h
@@ -56,6 +56,11 @@ inline vec3<T> cross(const vec3<T>& a, const vec3<T>& b) {
       a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
 }
 
+template <typename T>
+inline T norm(const vec3<T>& a) {
+  return sqrt(dot(a, a));
+}
+
 template <typename T>
 std::ostream& operator<<(std::ostream& os, const vec3<T>& v) {
   os << "vec3(" << v.x << ", " << v.y << ", " << v.z << ")";
diff --git a/tests/test_point_mesh_distance.py b/tests/test_point_mesh_distance.py
index d914dcb87..bef5d9af7 100644
--- a/tests/test_point_mesh_distance.py
+++ b/tests/test_point_mesh_distance.py
@@ -211,6 +211,9 @@ def test_point_edge_array_distance(self):
         same = torch.rand((E,), dtype=torch.float32, device=device) > 0.5
         edges[same, 1] = edges[same, 0].clone().detach()
 
+        points_cpu = points.clone().cpu()
+        edges_cpu = edges.clone().cpu()
+
         points.requires_grad = True
         edges.requires_grad = True
         grad_dists = torch.rand((P, E), dtype=torch.float32, device=device)
@@ -224,22 +227,29 @@ def test_point_edge_array_distance(self):
 
         # Cuda Forward Implementation
         dists_cuda = _C.point_edge_array_dist_forward(points,
edges) + dists_cpu = _C.point_edge_array_dist_forward(points_cpu, edges_cpu) # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # CUDA Bacwkard Implementation grad_points_cuda, grad_edges_cuda = _C.point_edge_array_dist_backward( points, edges, grad_dists ) + grad_points_cpu, grad_edges_cpu = _C.point_edge_array_dist_backward( + points_cpu, edges_cpu, grad_dists.cpu() + ) dists_naive.backward(grad_dists) - grad_points_naive = points.grad - grad_edges_naive = edges.grad + grad_points_naive = points.grad.cpu() + grad_edges_naive = edges.grad.cpu() # Compare - self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu()) - self.assertClose(grad_edges_naive.cpu(), grad_edges_cuda.cpu()) + self.assertClose(grad_points_naive, grad_points_cuda.cpu()) + self.assertClose(grad_edges_naive, grad_edges_cuda.cpu()) + self.assertClose(grad_points_naive, grad_points_cpu) + self.assertClose(grad_edges_naive, grad_edges_cpu) def test_point_edge_distance(self): """ @@ -270,7 +280,7 @@ def test_point_edge_distance(self): (points_packed.shape[0],), dtype=torch.float32, device=device ) - # Cuda Implementation: forrward + # Cuda Implementation: forward dists_cuda, idx_cuda = _C.point_edge_dist_forward( points_packed, points_first_idx, edges_packed, edges_first_idx, max_p ) @@ -278,6 +288,20 @@ def test_point_edge_distance(self): grad_points_cuda, grad_edges_cuda = _C.point_edge_dist_backward( points_packed, edges_packed, idx_cuda, grad_dists ) + # Cpu Implementation: forward + dists_cpu, idx_cpu = _C.point_edge_dist_forward( + points_packed.cpu(), + points_first_idx.cpu(), + edges_packed.cpu(), + edges_first_idx.cpu(), + max_p, + ) + + # Cpu Implementation: backward + # Note that using idx_cpu doesn't pass - there seems to be a problem with tied results. 
+ grad_points_cpu, grad_edges_cpu = _C.point_edge_dist_backward( + points_packed.cpu(), edges_packed.cpu(), idx_cuda.cpu(), grad_dists.cpu() + ) # Naive Implementation: forward edges_list = packed_to_list(edges_packed, meshes.num_edges_per_mesh().tolist()) @@ -312,15 +336,18 @@ def test_point_edge_distance(self): # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # Naive Implementation: backward dists_naive.backward(grad_dists) grad_points_naive = torch.cat([cloud.grad for cloud in pcls.points_list()]) - grad_edges_naive = edges_packed.grad + grad_edges_naive = edges_packed.grad.cpu() # Compare self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu(), atol=1e-7) - self.assertClose(grad_edges_naive.cpu(), grad_edges_cuda.cpu(), atol=5e-7) + self.assertClose(grad_edges_naive, grad_edges_cuda.cpu(), atol=5e-7) + self.assertClose(grad_points_naive.cpu(), grad_points_cpu, atol=1e-7) + self.assertClose(grad_edges_naive, grad_edges_cpu, atol=5e-7) def test_edge_point_distance(self): """ @@ -361,6 +388,20 @@ def test_edge_point_distance(self): points_packed, edges_packed, idx_cuda, grad_dists ) + # Cpu Implementation: forward + dists_cpu, idx_cpu = _C.edge_point_dist_forward( + points_packed.cpu(), + points_first_idx.cpu(), + edges_packed.cpu(), + edges_first_idx.cpu(), + max_e, + ) + + # Cpu Implementation: backward + grad_points_cpu, grad_edges_cpu = _C.edge_point_dist_backward( + points_packed.cpu(), edges_packed.cpu(), idx_cpu, grad_dists.cpu() + ) + # Naive Implementation: forward edges_list = packed_to_list(edges_packed, meshes.num_edges_per_mesh().tolist()) dists_naive = [] @@ -395,15 +436,18 @@ def test_edge_point_distance(self): # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # Naive Implementation: backward dists_naive.backward(grad_dists) grad_points_naive = torch.cat([cloud.grad for cloud in pcls.points_list()]) - 
grad_edges_naive = edges_packed.grad + grad_edges_naive = edges_packed.grad.cpu() # Compare self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu(), atol=1e-7) - self.assertClose(grad_edges_naive.cpu(), grad_edges_cuda.cpu(), atol=5e-7) + self.assertClose(grad_edges_naive, grad_edges_cuda.cpu(), atol=5e-7) + self.assertClose(grad_points_naive.cpu(), grad_points_cpu, atol=1e-7) + self.assertClose(grad_edges_naive, grad_edges_cpu, atol=5e-7) def test_point_mesh_edge_distance(self): """ @@ -483,6 +527,8 @@ def test_point_face_array_distance(self): device = get_random_cuda_device() points = torch.rand((P, 3), dtype=torch.float32, device=device) tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device) + points_cpu = points.clone().cpu() + tris_cpu = tris.clone().cpu() points.requires_grad = True tris.requires_grad = True @@ -502,23 +548,30 @@ def test_point_face_array_distance(self): # Naive Backward dists_naive.backward(grad_dists) - grad_points_naive = points.grad - grad_tris_naive = tris.grad + grad_points_naive = points.grad.cpu() + grad_tris_naive = tris.grad.cpu() # Cuda Forward Implementation dists_cuda = _C.point_face_array_dist_forward(points, tris) + dists_cpu = _C.point_face_array_dist_forward(points_cpu, tris_cpu) # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # CUDA Backward Implementation grad_points_cuda, grad_tris_cuda = _C.point_face_array_dist_backward( points, tris, grad_dists ) + grad_points_cpu, grad_tris_cpu = _C.point_face_array_dist_backward( + points_cpu, tris_cpu, grad_dists.cpu() + ) # Compare - self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu()) - self.assertClose(grad_tris_naive.cpu(), grad_tris_cuda.cpu(), atol=5e-6) + self.assertClose(grad_points_naive, grad_points_cuda.cpu()) + self.assertClose(grad_tris_naive, grad_tris_cuda.cpu(), atol=5e-6) + self.assertClose(grad_points_naive, grad_points_cpu) + self.assertClose(grad_tris_naive, 
grad_tris_cpu, atol=5e-6) def test_point_face_distance(self): """ @@ -559,6 +612,21 @@ def test_point_face_distance(self): points_packed, faces_packed, idx_cuda, grad_dists ) + # Cpu Implementation: forward + dists_cpu, idx_cpu = _C.point_face_dist_forward( + points_packed.cpu(), + points_first_idx.cpu(), + faces_packed.cpu(), + faces_first_idx.cpu(), + max_p, + ) + + # Cpu Implementation: backward + # Note that using idx_cpu doesn't pass - there seems to be a problem with tied results. + grad_points_cpu, grad_faces_cpu = _C.point_face_dist_backward( + points_packed.cpu(), faces_packed.cpu(), idx_cuda.cpu(), grad_dists.cpu() + ) + # Naive Implementation: forward faces_list = packed_to_list(faces_packed, meshes.num_faces_per_mesh().tolist()) dists_naive = [] @@ -593,15 +661,18 @@ def test_point_face_distance(self): # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # Naive Implementation: backward dists_naive.backward(grad_dists) grad_points_naive = torch.cat([cloud.grad for cloud in pcls.points_list()]) - grad_faces_naive = faces_packed.grad + grad_faces_naive = faces_packed.grad.cpu() # Compare self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu(), atol=1e-7) - self.assertClose(grad_faces_naive.cpu(), grad_faces_cuda.cpu(), atol=5e-7) + self.assertClose(grad_faces_naive, grad_faces_cuda.cpu(), atol=5e-7) + self.assertClose(grad_points_naive.cpu(), grad_points_cpu, atol=1e-7) + self.assertClose(grad_faces_naive, grad_faces_cpu, atol=5e-7) def test_face_point_distance(self): """ @@ -642,6 +713,20 @@ def test_face_point_distance(self): points_packed, faces_packed, idx_cuda, grad_dists ) + # Cpu Implementation: forward + dists_cpu, idx_cpu = _C.face_point_dist_forward( + points_packed.cpu(), + points_first_idx.cpu(), + faces_packed.cpu(), + faces_first_idx.cpu(), + max_f, + ) + + # Cpu Implementation: backward + grad_points_cpu, grad_faces_cpu = _C.face_point_dist_backward( + points_packed.cpu(), 
faces_packed.cpu(), idx_cpu, grad_dists.cpu() + ) + # Naive Implementation: forward faces_list = packed_to_list(faces_packed, meshes.num_faces_per_mesh().tolist()) dists_naive = [] @@ -676,6 +761,7 @@ def test_face_point_distance(self): # Compare self.assertClose(dists_naive.cpu(), dists_cuda.cpu()) + self.assertClose(dists_naive.cpu(), dists_cpu) # Naive Implementation: backward dists_naive.backward(grad_dists) @@ -685,6 +771,8 @@ def test_face_point_distance(self): # Compare self.assertClose(grad_points_naive.cpu(), grad_points_cuda.cpu(), atol=1e-7) self.assertClose(grad_faces_naive.cpu(), grad_faces_cuda.cpu(), atol=5e-7) + self.assertClose(grad_points_naive.cpu(), grad_points_cpu, atol=1e-7) + self.assertClose(grad_faces_naive.cpu(), grad_faces_cpu, atol=5e-7) def test_point_mesh_face_distance(self): """