Skip to content

Commit

Permalink
remove torch from cuda
Browse files Browse the repository at this point in the history
Summary: Keep using at:: instead of torch:: so we don't need torch/extension.h and can keep other compilers happy.

Reviewed By: patricklabatut

Differential Revision: D31688436

fbshipit-source-id: 1825503da0104acaf1558d17300c02ef663bf538
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 18, 2021
1 parent 1a7442a commit 3953de4
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions pytorch3d/csrc/points_to_volumes/points_to_volumes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

using torch::PackedTensorAccessor64;
using torch::RestrictPtrTraits;
using at::PackedTensorAccessor64;
using at::RestrictPtrTraits;

// A chunk of work is blocksize-many points.
// There are N clouds in the batch, and P points in each cloud.
Expand Down Expand Up @@ -117,12 +116,12 @@ __global__ void PointsToVolumesForwardKernel(
}

void PointsToVolumesForwardCuda(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& volume_densities,
const torch::Tensor& volume_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const at::Tensor& points_3d,
const at::Tensor& points_features,
const at::Tensor& volume_densities,
const at::Tensor& volume_features,
const at::Tensor& grid_sizes,
const at::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat) {
Expand Down Expand Up @@ -285,17 +284,17 @@ __global__ void PointsToVolumesBackwardKernel(
}

void PointsToVolumesBackwardCuda(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const at::Tensor& points_3d,
const at::Tensor& points_features,
const at::Tensor& grid_sizes,
const at::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat,
const torch::Tensor& grad_volume_densities,
const torch::Tensor& grad_volume_features,
const torch::Tensor& grad_points_3d,
const torch::Tensor& grad_points_features) {
const at::Tensor& grad_volume_densities,
const at::Tensor& grad_volume_features,
const at::Tensor& grad_points_3d,
const at::Tensor& grad_points_features) {
// Check inputs are on the same device
at::TensorArg points_3d_t{points_3d, "points_3d", 1},
points_features_t{points_features, "points_features", 2},
Expand Down

0 comments on commit 3953de4

Please sign in to comment.