Skip to content

Commit

Permalink
Fixed windows MSVC build compatibility (#9)
Browse files Browse the repository at this point in the history
Summary:
Fixed a few MSVC compiler (Visual Studio 2019, MSVC 19.16.27034) compatibility issues:
1. Replaced long with int64_t. aten::data_ptr<long> is not supported in MSVC (on Windows, long is only 32 bits wide).
2. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp: the inline function PixToNdc is not correctly recognized by MSVC, so it is now declared static.
3. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh
const auto kEpsilon = 1e-30;
MSVC does not compile this const into both host and device code, so it is changed to a macro when building with MSVC.
4. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh,
const float area2 = pow(area, 2.0);
MSVC treats the literal 2.0 as a double and raises an error, so it is changed to 2.0f.
5. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp
The std::tuple<torch::Tensor, torch::Tensor> return type of RasterizePointsCoarseCpu() did not match the declaration in rasterize_points_cpu.h; the function now returns a single torch::Tensor.
Pull Request resolved: #9

Reviewed By: nikhilaravi

Differential Revision: D19986567

Pulled By: yuanluxu

fbshipit-source-id: f4d98525d088c99c513b85193db6f0fc69c7f017
  • Loading branch information
yuanluxu authored and facebook-github-bot committed Feb 21, 2020
1 parent a3baa36 commit 9e21659
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
40 changes: 39 additions & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

The core library is written in PyTorch. Several components have underlying implementations in CUDA for improved performance. A subset of these components have CPU implementations in C++/PyTorch. It is advised to use PyTorch3D with GPU support in order to use all the features.

- Linux or macOS
- Linux or macOS or Windows
- Python ≥ 3.6
- PyTorch 1.4
- torchvision that matches the PyTorch installation. You can install them together at pytorch.org to make sure of this.
Expand Down Expand Up @@ -72,3 +72,41 @@ To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then
```
MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e .
```

**Install from local clone on Windows:**

If you are using pre-compiled PyTorch 1.4 and torchvision 0.5, you need to make the following changes to the PyTorch header files installed under site-packages in order to compile successfully with Visual Studio 2019 (MSVC 19.16.27034) and CUDA 10.1.

Change python/Lib/site-packages/torch/include/csrc/jit/script/module.h

On lines 466, 476, 493, 506 and 536:
```
-static constexpr *
+static const *
```
Change python/Lib/site-packages/torch/include/csrc/jit/argument_spec.h

On line 190:
```
-static constexpr size_t DEPTH_LIMIT = 128;
+static const size_t DEPTH_LIMIT = 128;
```

Change python/Lib/site-packages/torch/include/pybind11/cast.h

On line 1449:
```
-explicit operator type&() { return *(this->value); }
+explicit operator type& () { return *((type*)(this->value)); }
```

After patching, open the "x64 Native Tools Command Prompt for VS 2019" and run the following to compile and install:
```
cd pytorch3d
python3 setup.py install
```
After installing, run the unit tests to verify that they all pass:
```
cd tests
python3 -m unittest discover -p *.py
```
8 changes: 4 additions & 4 deletions pytorch3d/csrc/gather_scatter/gather_scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// TODO(T47953967) to make this cuda kernel support all datatypes.
__global__ void gather_scatter_kernel(
const float* __restrict__ input,
const long* __restrict__ edges,
const int64_t* __restrict__ edges,
float* __restrict__ output,
bool directed,
bool backward,
Expand All @@ -21,8 +21,8 @@ __global__ void gather_scatter_kernel(
// Edges are split evenly across the blocks.
for (int e = blockIdx.x; e < E; e += gridDim.x) {
// Get indices of vertices which form the edge.
const long v0 = edges[2 * e + v0_idx];
const long v1 = edges[2 * e + v1_idx];
const int64_t v0 = edges[2 * e + v0_idx];
const int64_t v1 = edges[2 * e + v1_idx];

// Split vertex features evenly across threads.
// This implementation will be quite wasteful when D<128 since there will be
Expand Down Expand Up @@ -57,7 +57,7 @@ at::Tensor gather_scatter_cuda(

gather_scatter_kernel<<<blocks, threads>>>(
input.data_ptr<float>(),
edges.data_ptr<long>(),
edges.data_ptr<int64_t>(),
output.data_ptr<float>(),
directed,
backward,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
template <typename scalar_t>
__device__ void WarpReduce(
volatile scalar_t* min_dists,
volatile long* min_idxs,
volatile int64_t* min_idxs,
const size_t tid) {
// s = 32
if (min_dists[tid] > min_dists[tid + 32]) {
Expand Down Expand Up @@ -57,7 +57,7 @@ template <typename scalar_t>
__global__ void NearestNeighborKernel(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
long* __restrict__ idx,
int64_t* __restrict__ idx,
const size_t N,
const size_t P1,
const size_t P2,
Expand All @@ -74,7 +74,7 @@ __global__ void NearestNeighborKernel(
extern __shared__ char shared_buf[];
scalar_t* x = (scalar_t*)shared_buf; // scalar_t[DD]
scalar_t* min_dists = &x[D_2]; // scalar_t[NUM_THREADS]
long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

const size_t n = blockIdx.y; // index of batch element.
const size_t i = blockIdx.x; // index of point within batch element.
Expand Down Expand Up @@ -147,14 +147,14 @@ template <typename scalar_t>
__global__ void NearestNeighborKernelD3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
long* __restrict__ idx,
int64_t* __restrict__ idx,
const size_t N,
const size_t P1,
const size_t P2) {
// Single shared memory buffer which is split and cast to different types.
extern __shared__ char shared_buf[];
scalar_t* min_dists = (scalar_t*)shared_buf; // scalar_t[NUM_THREADS]
long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

const size_t D = 3;
const size_t n = blockIdx.y; // index of batch element.
Expand Down Expand Up @@ -230,12 +230,12 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
// Use the specialized kernel for D=3.
AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
size_t shared_size = threads * sizeof(size_t) +
threads * sizeof(long);
threads * sizeof(int64_t);
NearestNeighborKernelD3<scalar_t>
<<<blocks, threads, shared_size>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
idx.data_ptr<long>(),
idx.data_ptr<int64_t>(),
N,
P1,
P2);
Expand All @@ -248,11 +248,11 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
// need to be rounded to the next even size.
size_t D_2 = D + (D % 2);
size_t shared_size = (D_2 + threads) * sizeof(size_t);
shared_size += threads * sizeof(long);
shared_size += threads * sizeof(int64_t);
NearestNeighborKernel<scalar_t><<<blocks, threads, shared_size>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
idx.data_ptr<long>(),
idx.data_ptr<int64_t>(),
N,
P1,
P2,
Expand Down
6 changes: 5 additions & 1 deletion pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include "float_math.cuh"

// Set epsilon for preventing floating point errors and division by 0.
// MSVC does not compile this const into both host and device code, so when
// building with MSVC kEpsilon is provided as a macro instead of a constant.
// NOTE(review): the MSVC macro is a float literal (1e-30f) while the non-MSVC
// constant is deduced from a double literal (1e-30), so expressions that use
// kEpsilon are evaluated in different precisions on the two paths -- confirm
// that this difference is intended.
#ifdef _MSC_VER
#define kEpsilon 1e-30f
#else
const auto kEpsilon = 1e-30;
#endif

// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
Expand Down Expand Up @@ -93,7 +97,7 @@ BarycentricCoordsBackward(
const float2& v2,
const float3& grad_bary_upstream) {
const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
const float area2 = pow(area, 2.0);
const float area2 = pow(area, 2.0f);
const float e0 = EdgeFunctionForward(p, v1, v2);
const float e1 = EdgeFunctionForward(p, v2, v0);
const float e2 = EdgeFunctionForward(p, v0, v1);
Expand Down
6 changes: 3 additions & 3 deletions pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
inline float PixToNdc(const int i, const int S) {
static float PixToNdc(const int i, const int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S;
}
Expand Down Expand Up @@ -74,7 +74,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
return std::make_tuple(point_idxs, zbuf, pix_dists);
}

std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points,
const int image_size,
const float radius,
Expand Down Expand Up @@ -140,7 +140,7 @@ std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
bin_y_max = bin_y_min + bin_width;
}
}
return std::make_tuple(points_per_bin, bin_points);
return bin_points;
}

torch::Tensor RasterizePointsBackwardCpu(
Expand Down

0 comments on commit 9e21659

Please sign in to comment.