Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix compiling error on windows #766

Merged
merged 2 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions mmdet3d/ops/iou3d/src/iou3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All Rights Reserved 2019-2020.
#include <torch/extension.h>
#include <torch/serialize/tensor.h>

+ #include <cstdint>
#include <vector>

#define CHECK_CUDA(x) \
Expand Down Expand Up @@ -103,7 +104,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,

int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
- long *keep_data = keep.data_ptr<long>();
+ int64_t *keep_data = keep.data_ptr<int64_t>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -124,8 +125,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,

cudaFree(mask_data);

- unsigned long long remv_cpu[col_blocks];
- memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+ unsigned long long *remv_cpu = new unsigned long long[col_blocks]();

int num_to_keep = 0;

Expand All @@ -141,6 +141,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,
}
}
}
+ delete[] remv_cpu;
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");

return num_to_keep;
Expand All @@ -157,7 +158,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,

int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
- long *keep_data = keep.data_ptr<long>();
+ int64_t *keep_data = keep.data_ptr<int64_t>();

const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

Expand All @@ -178,8 +179,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,

cudaFree(mask_data);

- unsigned long long remv_cpu[col_blocks];
- memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+ unsigned long long *remv_cpu = new unsigned long long[col_blocks]();

int num_to_keep = 0;

Expand All @@ -195,6 +195,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,
}
}
}
+ delete[] remv_cpu;
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");

return num_to_keep;
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/ops/iou3d/src/iou3d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ All Rights Reserved 2019-2020.

//#define DEBUG
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
- const float EPS = 1e-8;
+ __device__ const float EPS = 1e-8;
struct Point {
float x, y;
__device__ Point() {}
Expand Down
11 changes: 6 additions & 5 deletions mmdet3d/ops/paconv/src/assign_score_withk_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stdlib.h>
#include <assert.h>
#include <cmath>
+ #include <cstdint>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -49,7 +50,7 @@ __global__ void assign_score_withk_forward_kernel(const int B, const int N0, con
const float* points,
const float* centers,
const float* scores,
- const long* knn_idx,
+ const int64_t* knn_idx,
float* output) {

// ----- parallel loop for B, N1, K and O ---------
Expand Down Expand Up @@ -82,7 +83,7 @@ __global__ void assign_score_withk_backward_points_kernel(const int B, const int
const int K, const int O, const int aggregate,
const float* grad_out,
const float* scores,
- const long* knn_idx,
+ const int64_t* knn_idx,
float* grad_points,
float* grad_centers) {

Expand Down Expand Up @@ -116,7 +117,7 @@ __global__ void assign_score_withk_backward_scores_kernel(const int B, const int
const float* grad_out,
const float* points,
const float* centers,
- const long* knn_idx,
+ const int64_t* knn_idx,
float* grad_scores) {

// ----- parallel loop for B, N, K, M ---------
Expand Down Expand Up @@ -156,7 +157,7 @@ void assign_score_withk_forward_wrapper(int B, int N0, int N1, int M, int K, int
const float* points_data = points.data_ptr<float>();
const float* centers_data = centers.data_ptr<float>();
const float* scores_data = scores.data_ptr<float>();
- const long* knn_idx_data = knn_idx.data_ptr<long>();
+ const int64_t* knn_idx_data = knn_idx.data_ptr<int64_t>();
float* output_data = output.data_ptr<float>();

dim3 blocks(DIVUP(B*O*N1*K, THREADS_PER_BLOCK));
Expand Down Expand Up @@ -191,7 +192,7 @@ void assign_score_withk_backward_wrapper(int B, int N0, int N1, int M, int K, in
const float* points_data = points.data_ptr<float>();
const float* centers_data = centers.data_ptr<float>();
const float* scores_data = scores.data_ptr<float>();
- const long* knn_idx_data = knn_idx.data_ptr<long>();
+ const int64_t* knn_idx_data = knn_idx.data_ptr<int64_t>();
float* grad_points_data = grad_points.data_ptr<float>();
float* grad_centers_data = grad_centers.data_ptr<float>();
float* grad_scores_data = grad_scores.data_ptr<float>();
Expand Down
4 changes: 3 additions & 1 deletion mmdet3d/ops/voxel/src/voxelization_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
const int NDim) {
const int ndim_minus_1 = NDim - 1;
bool failed = false;
- int coor[NDim];
+ // int coor[NDim];
+ int* coor = new int[NDim]();
int c;

for (int i = 0; i < num_points; ++i) {
Expand All @@ -37,6 +38,7 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
}
}

+ delete[] coor;
return;
}

Expand Down