Skip to content

Commit

Permalink
Some updates
Browse files: browse the repository at this point in the history
  • Loading branch information
erikwijmans committed Jan 30, 2018
1 parent c4ddd6b commit 8bce353
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ __pycache__
runs
build
checkpoints
*.prof
.lvimrc
.vimtags
11 changes: 4 additions & 7 deletions models/Pointnet2Cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def __init__(self, num_classes, input_channels=9):
npoint=512,
radii=[0.1, 0.2, 0.4],
nsamples=[32, 64, 128],
mlps=[[input_channels, 32, 32,
64], [input_channels, 64, 64, 128],
[input_channels, 64, 96, 128]]
mlps=[[input_channels, 64], [input_channels, 128],
[input_channels, 128]]
)
)

Expand All @@ -100,9 +99,8 @@ def __init__(self, num_classes, input_channels=9):
npoint=128,
radii=[0.2, 0.4, 0.8],
nsamples=[16, 32, 64],
mlps=[[input_channels, 64, 64,
128], [input_channels, 128, 128, 256],
[input_channels, 128, 128, 256]]
mlps=[[input_channels, 128], [input_channels, 256],
[input_channels, 256]]
)
)
self.SA_modules.append(
Expand Down Expand Up @@ -136,7 +134,6 @@ def forward(self, xyz, points=None):
model = Pointnet2MSG(3)
model.cuda()


optimizer = optim.Adam(model.parameters(), lr=1e-2)

model_fn = model_fn_decorator(nn.CrossEntropyLoss())
Expand Down
18 changes: 3 additions & 15 deletions utils/cinclude/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#ifdef __cplusplus
extern "C" {
#endif
#include <cmath>

inline int opt_n_threads(int work_size) {
unsigned int n_threads = work_size;
n_threads--;
n_threads |= n_threads >> 1;
n_threads |= n_threads >> 2;
n_threads |= n_threads >> 4;
n_threads |= n_threads >> 8;
n_threads |= n_threads >> 16;
n_threads++;
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

return max(min(n_threads / 2, 512), 2);
return max(min(1 << pow_2, 512), 32);
}

#ifdef __cplusplus
}
#endif
#endif
9 changes: 5 additions & 4 deletions utils/csrc/sampling_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ __global__ void gather_points_kernel(int b, int n, int c, int m,
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int j = blockIdx.y * blockDim.x + threadIdx.x; j < m;
j += blockDim.x * gridDim.y) {
int a = idx[i * m + j];
memcpy(out + (i * m + j) * c, points + (i * n + a) * c,
sizeof(float) * c);
const int jj = idx[i * m + j];
for (int l = 0; l < c; ++l) {
out[(i * m + j) * c + l] = points[(i * n + jj) * c + l];
}
}
}
}
Expand All @@ -25,7 +26,7 @@ void gather_points_kernel_wrapper(int b, int n, int c, int npoints,
float *out, cudaStream_t stream) {

cudaError_t err;
gather_points_kernel<<<dim3(2, 8, 1), opt_n_threads(npoints) / 4, 0,
gather_points_kernel<<<dim3(b, 8, 1), opt_n_threads(npoints), 0,
stream>>>(b, n, c, npoints, points, idx, out);

err = cudaGetLastError();
Expand Down

0 comments on commit 8bce353

Please sign in to comment.