Skip to content

Commit

Permalink
Fix windows build (#1689)
Browse files Browse the repository at this point in the history
Summary:
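Replace `long` with `int64_t` in the tensor accessors and `item<>()` calls to ensure cross-platform compatibility. With MSVC on Windows, `long` is only 32 bits wide and is a different type from `int64_t` (which is `long long` there), so the `long`-typed code does not build on Windows.
long -> int64_t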

<img width="628" alt="image" src="https://github.com/facebookresearch/pytorch3d/assets/6214316/40041f7f-3c09-4571-b9ff-676c625802e9">

Tested on Windows 11 and Ubuntu 22.04 with CUDA 12.1.1 and torch 2.1.1.

Related issues & PR

#9

#1679

Pull Request resolved: #1689

Reviewed By: MichaelRamamonjisoa

Differential Revision: D51521562

Pulled By: bottler

fbshipit-source-id: d8ea81e223c642e0e9fb283f5d7efc9d6ac00d93
  • Loading branch information
eclipse0922 authored and facebook-github-bot committed Dec 5, 2023
1 parent 83bacda commit 7606854
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pytorch3d/csrc/marching_cubes/marching_cubes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
compactedVoxelArray,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
voxelOccupied,
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
voxelOccupiedScan,
uint numVoxels) {
uint id = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -255,7 +255,8 @@ __global__ void GenerateFacesKernel(
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
compactedVoxelArray,
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
numVertsScanned,
const uint activeVoxels,
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
Expand Down Expand Up @@ -471,7 +472,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});

// number of active voxels
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();

const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
Expand All @@ -492,7 +493,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
numVoxels);
AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize();
Expand All @@ -502,7 +504,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});

// total number of vertices
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();

// Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels.
Expand All @@ -522,7 +524,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
Expand Down

0 comments on commit 7606854

Please sign in to comment.