From 0faa9261926221de2c1b05e60dde19885c0d95d3 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 17:06:08 -0400 Subject: [PATCH 01/12] Reformat codebase and add `pre-commit` --- .github/workflows/formatter.yml | 34 -- .github/workflows/pre-commit.yaml | 17 + .gitignore | 2 +- .pre-commit-config.yaml | 63 ++++ LICENSE | 30 +- README.md | 151 +++----- examples/example.py | 140 ++++---- examples/performance.py | 40 ++- setup.cfg | 15 + setup.py | 58 +--- torchsparse/__init__.py | 21 +- torchsparse/{src => backend}/common/gpu.cuh | 94 +++-- .../backend/convolution/convolution_cpu.cpp | 183 ++++++++++ .../backend/convolution/convolution_cpu.h | 15 + .../backend/convolution/convolution_cuda.cu | 278 +++++++++++++++ .../backend/convolution/convolution_cuda.h | 16 + .../backend/devoxelize/devoxelize_cpu.cpp | 59 ++++ .../backend/devoxelize/devoxelize_cpu.h | 14 + .../backend/devoxelize/devoxelize_cuda.cu | 98 ++++++ .../backend/devoxelize/devoxelize_cuda.h | 14 + torchsparse/backend/hash/hash_cpu.cpp | 58 ++++ torchsparse/backend/hash/hash_cpu.h | 11 + torchsparse/backend/hash/hash_cuda.cu | 84 +++++ torchsparse/backend/hash/hash_cuda.h | 11 + torchsparse/backend/hashmap/hashmap_cpu.cpp | 28 ++ torchsparse/backend/hashmap/hashmap_cpu.hpp | 27 ++ torchsparse/backend/hashmap/hashmap_cuda.cu | 214 ++++++++++++ torchsparse/backend/hashmap/hashmap_cuda.cuh | 146 ++++++++ torchsparse/backend/others/count_cpu.cpp | 23 ++ torchsparse/backend/others/count_cpu.h | 8 + torchsparse/backend/others/count_cuda.cu | 31 ++ torchsparse/backend/others/count_cuda.h | 8 + .../{src => backend}/others/query_cpu.cpp | 37 +- torchsparse/backend/others/query_cpu.h | 9 + torchsparse/backend/others/query_cuda.cu | 58 ++++ torchsparse/backend/others/query_cuda.h | 9 + torchsparse/backend/pybind_cpu.cpp | 23 ++ torchsparse/backend/pybind_cuda.cpp | 39 +++ torchsparse/backend/voxelize/voxelize_cpu.cpp | 43 +++ torchsparse/backend/voxelize/voxelize_cpu.h | 13 + torchsparse/backend/voxelize/voxelize_cuda.cu | 80 +++++ torchsparse/backend/voxelize/voxelize_cuda.h | 13 + torchsparse/nn/functional/__init__.py | 3 +- torchsparse/nn/functional/activation.py | 34 +- torchsparse/nn/functional/conv.py | 307 +++++++---------- torchsparse/nn/functional/count.py | 26 +- torchsparse/nn/functional/crop.py | 2 +- torchsparse/nn/functional/devox.py | 100 ------ torchsparse/nn/functional/devoxelize.py | 99 ++++++ torchsparse/nn/functional/downsample.py | 99 +++--- torchsparse/nn/functional/hash.py | 64 ++-- torchsparse/nn/functional/pooling.py | 6 +- torchsparse/nn/functional/query.py | 63 ++-- torchsparse/nn/functional/squeeze_nmap.py | 12 - torchsparse/nn/functional/voxelize.py | 60 +++- torchsparse/nn/modules/__init__.py | 1 + torchsparse/nn/modules/activation.py | 45 +-- torchsparse/nn/modules/bev.py | 225 ++++++++++++ torchsparse/nn/modules/conv.py | 323 +++--------------- torchsparse/nn/modules/crop.py | 7 +- torchsparse/nn/modules/norm.py | 59 ++-- torchsparse/nn/modules/pooling.py | 15 +- torchsparse/nn/utils/__init__.py | 2 + torchsparse/nn/utils/apply.py | 16 + torchsparse/nn/utils/kernel.py | 32 ++ torchsparse/operators.py | 16 + torchsparse/point_tensor.py | 41 --- torchsparse/sparse_tensor.py | 54 --- torchsparse/src/convolution/convolution.cu | 283 --------------- .../src/convolution/convolution_cpu.cpp | 195 ----------- .../src/convolution/convolution_cpu_header.h | 24 -- .../src/convolution/convolution_gpu.cu | 318 ----------------- torchsparse/src/convolution/convolution_gpu.h | 43 --- torchsparse/src/hash/hash.cpp | 36 -- torchsparse/src/hash/hash_cpu.cpp | 71 ---- torchsparse/src/hash/hash_cpu_header.h | 16 - torchsparse/src/hash/hash_gpu.cu | 61 ---- torchsparse/src/hash/hash_gpu.h | 16 - torchsparse/src/hashmap/hashmap.cu | 245 ------------- torchsparse/src/hashmap/hashmap.cuh | 178 ---------- torchsparse/src/hashmap/hashmap_cpu.cpp | 41 --- .../src/hashmap/hashmap_cpu_header.hpp | 30 -- torchsparse/src/interpolation/devox_cpu.cpp | 64 ---- .../src/interpolation/devox_cpu_header.h | 16 - .../src/interpolation/devox_deterministic.cpp | 35 -- .../interpolation/devox_deterministic_gpu.cu | 61 ---- torchsparse/src/interpolation/devox_gpu.cu | 108 ------ torchsparse/src/interpolation/devox_gpu.h | 27 -- torchsparse/src/others/count.cpp | 26 -- torchsparse/src/others/count_cpu.cpp | 25 -- torchsparse/src/others/count_cpu_header.h | 10 - torchsparse/src/others/count_gpu.cu | 17 - torchsparse/src/others/count_gpu.h | 12 - torchsparse/src/others/insertion_cpu.cpp | 52 --- torchsparse/src/others/insertion_cpu_header.h | 18 - torchsparse/src/others/insertion_gpu.cu | 85 ----- torchsparse/src/others/insertion_gpu.h | 20 -- torchsparse/src/others/query.cpp | 49 --- torchsparse/src/others/query_cpu_header.h | 13 - torchsparse/src/others/query_gpu.h | 13 - torchsparse/src/torchsparse_bindings.cpp | 25 -- torchsparse/src/torchsparse_bindings_gpu.cpp | 44 --- torchsparse/tensors.py | 99 ++++++ torchsparse/utils/__init__.py | 3 +- torchsparse/utils/collate.py | 59 ++++ torchsparse/utils/helpers.py | 265 -------------- torchsparse/utils/kernel.py | 82 ----- torchsparse/utils/quantize.py | 48 +++ torchsparse/utils/utils.py | 19 ++ torchsparse/version.py | 1 + 110 files changed, 2933 insertions(+), 3946 deletions(-) delete mode 100644 .github/workflows/formatter.yml create mode 100644 .github/workflows/pre-commit.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 setup.cfg rename torchsparse/{src => backend}/common/gpu.cuh (69%) create mode 100644 torchsparse/backend/convolution/convolution_cpu.cpp create mode 100644 torchsparse/backend/convolution/convolution_cpu.h create mode 100644 torchsparse/backend/convolution/convolution_cuda.cu create mode 100644 torchsparse/backend/convolution/convolution_cuda.h create mode 100644 torchsparse/backend/devoxelize/devoxelize_cpu.cpp create mode 100644 torchsparse/backend/devoxelize/devoxelize_cpu.h create mode 100644 torchsparse/backend/devoxelize/devoxelize_cuda.cu create mode 100644 torchsparse/backend/devoxelize/devoxelize_cuda.h create mode 100644 torchsparse/backend/hash/hash_cpu.cpp create mode 100644 torchsparse/backend/hash/hash_cpu.h create mode 100644 torchsparse/backend/hash/hash_cuda.cu create mode 100644 torchsparse/backend/hash/hash_cuda.h create mode 100644 torchsparse/backend/hashmap/hashmap_cpu.cpp create mode 100644 torchsparse/backend/hashmap/hashmap_cpu.hpp create mode 100644 torchsparse/backend/hashmap/hashmap_cuda.cu create mode 100644 torchsparse/backend/hashmap/hashmap_cuda.cuh create mode 100644 torchsparse/backend/others/count_cpu.cpp create mode 100644 torchsparse/backend/others/count_cpu.h create mode 100644 torchsparse/backend/others/count_cuda.cu create mode 100644 torchsparse/backend/others/count_cuda.h rename torchsparse/{src => backend}/others/query_cpu.cpp (50%) create mode 100644 torchsparse/backend/others/query_cpu.h create mode 100644 torchsparse/backend/others/query_cuda.cu create mode 100644 torchsparse/backend/others/query_cuda.h create mode 100644 torchsparse/backend/pybind_cpu.cpp create mode 100644 torchsparse/backend/pybind_cuda.cpp create mode 100644 torchsparse/backend/voxelize/voxelize_cpu.cpp create mode 100644 torchsparse/backend/voxelize/voxelize_cpu.h create mode 100644 torchsparse/backend/voxelize/voxelize_cuda.cu create mode 100644 torchsparse/backend/voxelize/voxelize_cuda.h delete mode 100644 torchsparse/nn/functional/devox.py create mode 100644 torchsparse/nn/functional/devoxelize.py delete mode 100644 torchsparse/nn/functional/squeeze_nmap.py create mode 100644 torchsparse/nn/modules/bev.py create mode 100644 torchsparse/nn/utils/__init__.py create mode 100644 torchsparse/nn/utils/apply.py create mode 100644 torchsparse/nn/utils/kernel.py create mode 100644 torchsparse/operators.py delete mode 100644 torchsparse/point_tensor.py delete mode 100644 torchsparse/sparse_tensor.py delete mode 100644 torchsparse/src/convolution/convolution.cu delete mode 100644 torchsparse/src/convolution/convolution_cpu.cpp delete mode 100644 torchsparse/src/convolution/convolution_cpu_header.h delete mode 100644 torchsparse/src/convolution/convolution_gpu.cu delete mode 100644 torchsparse/src/convolution/convolution_gpu.h delete mode 100644 torchsparse/src/hash/hash.cpp delete mode 100644 torchsparse/src/hash/hash_cpu.cpp delete mode 100644 torchsparse/src/hash/hash_cpu_header.h delete mode 100644 torchsparse/src/hash/hash_gpu.cu delete mode 100644 torchsparse/src/hash/hash_gpu.h delete mode 100644 torchsparse/src/hashmap/hashmap.cu delete mode 100644 torchsparse/src/hashmap/hashmap.cuh delete mode 100644 torchsparse/src/hashmap/hashmap_cpu.cpp delete mode 100644 torchsparse/src/hashmap/hashmap_cpu_header.hpp delete mode 100644 torchsparse/src/interpolation/devox_cpu.cpp delete mode 100644 torchsparse/src/interpolation/devox_cpu_header.h delete mode 100644 torchsparse/src/interpolation/devox_deterministic.cpp delete mode 100644 torchsparse/src/interpolation/devox_deterministic_gpu.cu delete mode 100644 torchsparse/src/interpolation/devox_gpu.cu delete mode 100644 torchsparse/src/interpolation/devox_gpu.h delete mode 100644 torchsparse/src/others/count.cpp delete mode 100644 torchsparse/src/others/count_cpu.cpp delete mode 100644 torchsparse/src/others/count_cpu_header.h delete mode 100644 torchsparse/src/others/count_gpu.cu delete mode 100644 torchsparse/src/others/count_gpu.h delete mode 100644 torchsparse/src/others/insertion_cpu.cpp delete mode 100644 torchsparse/src/others/insertion_cpu_header.h delete mode 100644 torchsparse/src/others/insertion_gpu.cu delete mode 100644 torchsparse/src/others/insertion_gpu.h delete mode 100644 torchsparse/src/others/query.cpp delete mode 100644 torchsparse/src/others/query_cpu_header.h delete mode 100644 torchsparse/src/others/query_gpu.h delete mode 100644 torchsparse/src/torchsparse_bindings.cpp delete mode 100644 torchsparse/src/torchsparse_bindings_gpu.cpp create mode 100644 torchsparse/tensors.py create mode 100644 torchsparse/utils/collate.py delete mode 100644 torchsparse/utils/helpers.py delete mode 100644 torchsparse/utils/kernel.py create mode 100644 torchsparse/utils/quantize.py create mode 100644 torchsparse/utils/utils.py create mode 100644 torchsparse/version.py diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml deleted file mode 100644 index 225bb62..0000000 --- a/.github/workflows/formatter.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Formatter - -on: - push: - branches: - -master - pull_request: - -jobs: - build: - name: Formatter - runs-on: ubuntu-latest - steps: - - name: Checkout repo - uses: actions/checkout@v2.3.4 - with: - repository: ${{ github.repository }} - token: ${{ github.token }} - ref: ${{ github.event.pull_request.head.ref }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade yapf - echo $(yapf --version) - - name: Format with YAPF - run: | - yapf --verbose --recursive --in-place --parallel --style '{SPACES_AROUND_POWER_OPERATOR: True}' . - - name: Push commit - run: | - git config user.name github-actions - git config user.email github-actions@github.com - git add . - git commit -m "Automation: Formatter" --all | exit 0 - git push diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 0000000..14e7d69 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,17 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends clang-format + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: pre-commit/action@v2.0.3 diff --git a/.gitignore b/.gitignore index 51d4ba0..e01e346 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .vscode/ build/ -*.pyc \ No newline at end of file +*.pyc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8e45a40 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,63 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + name: (Common) Remove trailing whitespaces + - id: mixed-line-ending + name: (Common) Fix mixed line ending + args: ['--fix=lf'] + - id: end-of-file-fixer + name: (Common) Remove extra EOF newlines + - id: check-merge-conflict + name: (Common) Check for merge conflicts + - id: requirements-txt-fixer + name: (Common) Sort "requirements.txt" + - id: fix-encoding-pragma + name: (Python) Remove encoding pragmas + args: ['--remove'] + - id: double-quote-string-fixer + name: (Python) Fix double-quoted strings + - id: debug-statements + name: (Python) Check for debugger imports + - id: check-json + name: (JSON) Check syntax + - id: check-yaml + name: (YAML) Check syntax + - id: check-toml + name: (TOML) Check syntax + - repo: https://github.com/asottile/pyupgrade + rev: v2.19.4 + hooks: + - id: pyupgrade + name: (Python) Update syntax for newer versions + args: ['--py36-plus'] + - repo: https://github.com/google/yapf + rev: v0.31.0 + hooks: + - id: yapf + name: (Python) Format with yapf + - repo: https://github.com/pycqa/isort + rev: 5.8.0 + hooks: + - id: isort + name: (Python) Sort imports with isort + - repo: https://github.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + name: (Python) Check with flake8 + additional_dependencies: [flake8-bugbear, flake8-comprehensions, flake8-docstrings, flake8-executable, flake8-quotes] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.902 + hooks: + - id: mypy + name: (Python) Check with mypy + additional_dependencies: [tokenize-rt] + - repo: local + hooks: + - id: clang-format + name: (C/C++/CUDA) Format with clang-format + entry: clang-format -style=google -i + language: system + files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ diff --git a/LICENSE b/LICENSE index 1f23cf0..b6edbf4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 Haotian Tang, Zhijian Liu, Song Han +Copyright (c) 2020-2021 TorchSparse Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -19,31 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ---------------------------- LICENSE FOR MinkowskiEngine -------------------------------- -MIT License - -Copyright (c) 2020 NVIDIA CORPORATION. -Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu) - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural -Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part -of the code. \ No newline at end of file diff --git a/README.md b/README.md index a9c9789..929b49d 100644 --- a/README.md +++ b/README.md @@ -1,141 +1,92 @@ # TorchSparse -## News +TorchSparse is a high-performance neural network library for point cloud processing. -2020/09/20: We released `torchsparse` v1.1, which is significantly faster than our `torchsparse` v1.0 and is also achieves **1.9x** speedup over [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) v0.5 alpha when running MinkUNet18C! - -2020/08/30: We released `torchsparse` v1.0. - -## Overview +## Installation -We release `torchsparse`, a high-performance computing library for efficient 3D sparse convolution. This library aims at accelerating sparse computation in 3D, in particular the Sparse Convolution operation. +TorchSparse depends on the [Google Sparse Hash](https://github.com/sparsehash/sparsehash) library. - +* On Ubuntu, it can be installed by -The major advantage of this library is that we support all computation on the GPU, especially the kernel map construction (which is done on the CPU in latest [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) V0.4.3). + ```bash + sudo apt-get install libsparsehash-dev + ``` -## Installation +* On Mac OS, it can be installed by -You may run the following command to install torchsparse. + ```bash + brew install google-sparsehash + ``` -```bash -pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git -``` +* You can also compile the library locally (if you do not have the sudo permission) and add the library path to the environment variable `CPLUS_INCLUDE_PATH`. -Note that this library depends on Google's [sparse hash map project](https://github.com/sparsehash/sparsehash). In order to install this library, you may run +The latest released TorchSparse (v1.3.0) can then be installed by ```bash -sudo apt-get install libsparsehash-dev +pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.3.0 ``` -on Ubuntu servers. If you are not sudo, please clone Google's codebase, compile it and install locally. Finally, add the path to this library to your `CPLUS_INCLUDE_PATH` environmental variable. - -For GPU server users, we currently support PyTorch 1.6.0 + CUDA 10.2 + CUDNN 7.6.2. For CPU users, we support PyTorch 1.6.0 (CPU version), MKLDNN backend is optional. - -## Usage - -Our [SPVNAS](https://github.com/mit-han-lab/e3d) project (ECCV2020) is built with torchsparse. You may navigate to this project and follow the instructions in that codebase to play around. - -Here, we also provide a walk-through on some important concepts in torchsparse. - -### Sparse Tensor and Point Tensor - -In torchsparse, we have two data structures for point cloud storage, namely `torchsparse.SparseTensor` and `torchsparse.PointTensor`. Both structures has two data fields `C` (coordinates) and `F` (features). In `SparseTensor`, we assume that all coordinates are **integer** and **do not duplicate**. However, in `PointTensor`, all coordinates are **floating-point** and can duplicate. - -### Sparse Quantize and Sparse Collate - -The way to convert a point cloud to `SparseTensor` so that it can be consumed by networks built with Sparse Convolution or Sparse Point-Voxel Convolution is to use the function `torchsparse.utils.sparse_quantize`. An example is given here: - -```python -inds, labels, inverse_map = sparse_quantize(pc, feat, labels, return_index=True, return_invs=True) -``` +If you use TorchSparse in your code, please remember to specify the exact version as your dependencies. -where `pc`, `feat`, `labels` corresponds to point cloud (coordinates, should be integer), feature and ground-truth. The `inds` denotes unique indices in the point cloud coordinates, and `inverse_map` denotes the unique index each point is corresponding to. The `inverse map` is used to restore full point cloud prediction from downsampled prediction. +## Benchmark -To combine a list of `SparseTensor`s to a batch, you may want to use the `torchsparse.utils.sparse_collate_fn` function. +We compare TorchSparse with [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) (where the latency is measured on NVIDIA GTX 1080Ti): -Detailed results are given in [SemanticKITTI dataset preprocessing code](https://github.com/mit-han-lab/e3d/blob/master/spvnas/core/datasets/semantic_kitti.py) in our [SPVNAS](https://github.com/mit-han-lab/e3d) project. +| | MinkowskiEngine v0.4.3 | TorchSparse v1.0.0 | +| :----------------------- | :--------------------: | :----------------: | +| MinkUNet18C (MACs / 10) | 224.7 ms | 124.3 ms | +| MinkUNet18C (MACs / 4) | 244.3 ms | 160.9 ms | +| MinkUNet18C (MACs / 2.5) | 269.6 ms | 214.3 ms | +| MinkUNet18C | 323.5 ms | 294.0 ms | -### Computation API +## Getting Started -The computation interface in torchsparse is straightforward and very similar to original PyTorch. An example here defines a basic convolution block: +### Sparse Tensor -```python -class BasicConvolutionBlock(nn.Module): - def __init__(self, inc, outc, ks=3, stride=1, dilation=1): - super().__init__() - self.net = nn.Sequential( - spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride), - spnn.BatchNorm(outc), - spnn.ReLU(True) - ) - - def forward(self, x): - out = self.net(x) - return out -``` +Sparse tensor (`SparseTensor`) is the main data structure for point cloud, which has two data fields: +* Coordinates (`coords`): a 2D integer tensor with a shape of N x 4, where the first three dimensions correspond to quantized x, y, z coordinates, and the last dimension denotes the batch index. +* Features (`feats`): a 2D tensor with a shape of N x C, where C is the number of feature channels. -where `spnn`denotes `torchsparse.nn`, and `spnn.Conv3d` means 3D sparse convolution operation, `spnn.BatchNorm` and `spnn.ReLU` denotes 3D sparse tensor batchnorm and activations, respectively. We also support direct convolution kernel call via `torchsparse.nn.functional`, for example: +Most existing datasets provide raw point cloud data with float coordinates. We can use `sparse_quantize` (provided in `torchsparse.utils.quantize`) to voxelize x, y, z coordinates and remove duplicates: ```python -outputs = torchsparse.nn.functional.conv3d(inputs, kernel, stride=1, dilation=1, transpose=False) +coords -= np.min(coords, axis=0, keepdims=True) +coords, indices = sparse_quantize(coords, voxel_size, return_index=True) +coords = torch.tensor(coords, dtype=torch.int) +feats = torch.tensor(feats[indices], dtype=torch.float) +tensor = SparseTensor(coords=coords, feats=feats) ``` -where we need to define `inputs`(SparseTensor), `kernel` (of shape k^3 x OC x IC when k > 1, or OC x IC when k = 1, where k denotes the kernel size and IC, OC means input / output channels). The `outputs` is still a SparseTensor. +We can then use `sparse_collate_fn` (provided in `torchsparse.utils.collate`) to assemble a batch of `SparseTensor`'s (and add the batch dimension to `coords`). Please refer to [this example](https://github.com/mit-han-lab/torchsparse/blob/dev/pre-commit/examples/example.py) for more details. -Detailed examples are given in [here](https://github.com/mit-han-lab/e3d/blob/master/spvnas/core/modules/dynamic_sparseop.py), where we use the `torchsparse.nn.functional` interfaces to implement weight-shared 3D-NAS modules. +### Sparse Neural Network -### Sparse Hashmap API - -Sparse hash map query is important in 3D sparse computation. It is mainly used to infer a point's memory location (*i.e.* index) given its coordinates. For example, we use this operation in kernel map construction part of 3D sparse convolution, and also sparse voxelization / devoxelization in [Sparse Point-Voxel Convolution](https://arxiv.org/abs/2007.16100). Here, we provide the following example for hash map API: +The neural network interface in TorchSparse is very similar to PyTorch: ```python -source_hash = torchsparse.nn.functional.sphash(torch.floor(source_coords).int()) -target_hash = torchsparse.nn.functional.sphash(torch.floor(target_coords).int()) -idx_query = torchsparse.nn.functional.sphashquery(source_hash, target_hash) +from torch import nn +from torchsparse import nn as spnn + +model = nn.Sequential( + spnn.Conv3d(in_channels, out_channels, kernel_size), + spnn.BatchNorm(out_channels), + spnn.ReLU(True), +) ``` -In this example, `sphash` is the function converting integer coordinates to hashing. The `sphashquery(source_hash, target_hash)` performs the hash table lookup. Here, the hash map has key `target_hash` and value corresponding to point indices in the target point cloud tensor. For each point in the `source_coords`, we find the point index in `target_coords` which has the same coordinate as it. - -### Dummy Training Example - -We here provides an entire training example with dummy input [here](examples/example.py). In this example, we cover - -- How we start from point cloud data and convert it to SparseTensor format; -- How we can implement SparseTensor batching; -- How to train a semantic segmentation SparseConvNet. - -You are also welcomed to check out our [SPVNAS](https://github.com/mit-han-lab/e3d) project to implement training / inference with real data. - -### Mixed Precision (float16) Support - -Mixed precision training is supported via `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler`. Enabling mixed precision training can speed up training and reduce GPU memory usage. By wrapping your training code in a `torch.cuda.amp.autocast` block, feature tensors will automatically be converted to float16 if possible. See [here](examples/example.py) for a complete example. - -## Speed Comparison Between torchsparse and MinkowskiEngine - -We benchmark the performance of our torchsparse and latest [MinkowskiEngine V0.4.3](https://github.com/NVIDIA/MinkowskiEngine) here, latency is measured on NVIDIA GTX 1080Ti GPU: - -| Network | Latency (ME V0.4.3) | Latency (torchsparse V1.0.0) | -| :----------------------: | :-----------------: | :--------------------------: | -| MinkUNet18C (MACs / 10) | 224.7 | 124.3 | -| MinkUNet18C (MACs / 4) | 244.3 | 160.9 | -| MinkUNet18C (MACs / 2.5) | 269.6 | 214.3 | -| MinkUNet18C | 323.5 | 294.0 | - ## Citation -If you find this code useful, please consider citing: +If you use TorchSparse in your research, please use the following BibTeX entry: ```bibtex -@inproceedings{ - tang2020searching, - title = {Searching Efficient 3D Architectures with Sparse Point-Voxel Convolution}, - author = {Tang, Haotian* and Liu, Zhijian* and Zhao, Shengyu and Lin, Yujun and Lin, Ji and Wang, Hanrui and Han, Song}, - booktitle = {European Conference on Computer Vision}, +@inproceedings{tang2020searching, + title = {{Searching Efficient 3D Architectures with Sparse Point-Voxel Convolution}}, + author = {Tang, Haotian and Liu, Zhijian and Zhao, Shengyu and Lin, Yujun and Lin, Ji and Wang, Hanrui and Han, Song}, + booktitle = {European Conference on Computer Vision (ECCV)}, year = {2020} } ``` ## Acknowledgements -This library is inspired by [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine), [SECOND](https://github.com/traveller59/second.pytorch) and [SparseConvNet](https://github.com/facebookresearch/SparseConvNet). +TorchSparse is inspired by many existing open-source libraries, including (but not limited to) [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine), [SECOND](https://github.com/traveller59/second.pytorch) and [SparseConvNet](https://github.com/facebookresearch/SparseConvNet). diff --git a/examples/example.py b/examples/example.py index 2bc0ce1..6a4e475 100644 --- a/examples/example.py +++ b/examples/example.py @@ -1,82 +1,94 @@ -import numpy as np -import torch -import torch.nn as nn -import torchsparse -import torchsparse.nn as spnn -from torchsparse import SparseTensor -from torchsparse.utils import sparse_collate_fn, sparse_quantize import argparse +import random +from typing import Any, Dict +import numpy as np +import torch +import torch.utils.data +from torch import nn +from torch.cuda import amp -def generate_random_point_cloud(size=100000, voxel_size=0.2): - pc = np.random.randn(size, 4) - pc[:, :3] = pc[:, :3] * 10 - rounded_pc = np.round(pc[:, :3] / voxel_size).astype(np.int32) - labels = np.random.choice(10, size) - inds, _, inverse_map = sparse_quantize(rounded_pc, - pc, - labels, - return_index=True, - return_invs=True) +from torchsparse import SparseTensor +from torchsparse import nn as spnn +from torchsparse.utils.collate import sparse_collate_fn +from torchsparse.utils.quantize import sparse_quantize - voxel_pc = rounded_pc[inds] - voxel_feat = pc[inds] - voxel_labels = labels[inds] - sparse_tensor = SparseTensor(voxel_feat, voxel_pc) - label_tensor = SparseTensor(voxel_labels, voxel_pc) +class RandomDataset: - feed_dict = {'lidar': sparse_tensor, 'targets': label_tensor} + def __init__(self, input_size: int, voxel_size: float) -> None: + self.input_size = input_size + self.voxel_size = voxel_size - return feed_dict + def __getitem__(self, _: int) -> Dict[str, Any]: + lidar = np.random.uniform(-100, 100, size=(self.input_size, 4)) + labels = np.random.choice(10, size=self.input_size) + coords, feats = lidar[:, :3], lidar + coords -= np.min(coords, axis=0, keepdims=True) + coords, indices = sparse_quantize(coords, + self.voxel_size, + return_index=True) -def generate_batched_random_point_clouds(size=100000, - voxel_size=0.2, - batch_size=2): - batch = [] - for i in range(batch_size): - batch.append(generate_random_point_cloud(size, voxel_size)) - return sparse_collate_fn(batch) + coords = torch.tensor(coords, dtype=torch.int) + feats = torch.tensor(feats[indices], dtype=torch.float) + labels = torch.tensor(labels[indices], dtype=torch.long) + input = SparseTensor(coords=coords, feats=feats) + label = SparseTensor(coords=coords, feats=labels) + return {'input': input, 'label': label} -def dummy_train(device, mixed=False): - model = nn.Sequential( - spnn.Conv3d(4, 32, kernel_size=3, stride=1), spnn.BatchNorm(32), - spnn.ReLU(True), spnn.Conv3d(32, 64, kernel_size=2, stride=2), - spnn.BatchNorm(64), spnn.ReLU(True), - spnn.Conv3d(64, 64, kernel_size=2, stride=2, transpose=True), - spnn.BatchNorm(64), spnn.ReLU(True), - spnn.Conv3d(64, 32, kernel_size=3, stride=1), spnn.BatchNorm(32), - spnn.ReLU(True), spnn.Conv3d(32, 10, kernel_size=1)).to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss().to(device) - scaler = torch.cuda.amp.GradScaler(enabled=mixed) - - print('Starting dummy training...') - for i in range(10): - optimizer.zero_grad() - feed_dict = generate_batched_random_point_clouds() - inputs = feed_dict['lidar'].to(device) - targets = feed_dict['targets'].F.to(device).long() - with torch.cuda.amp.autocast(enabled=mixed): - outputs = model(inputs) - loss = criterion(outputs.F, targets) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - print('[step %d] loss = %f.' % (i, loss.item())) - print('Finished dummy training!') + def __len__(self): + return 100 if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--mixed", action="store_true") + parser.add_argument('--amp_enabled', action='store_true') args = parser.parse_args() - # set seeds for reproducibility - np.random.seed(2021) - torch.manual_seed(2021) + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) - device = 'cuda:0' if torch.cuda.is_available() else 'cpu' - dummy_train(device, args.mixed) \ No newline at end of file + dataset = RandomDataset(input_size=10000, voxel_size=0.2) + dataflow = torch.utils.data.DataLoader( + dataset, + batch_size=2, + collate_fn=sparse_collate_fn, + ) + + model = nn.Sequential( + spnn.Conv3d(4, 32, 3), + spnn.BatchNorm(32), + spnn.ReLU(True), + spnn.Conv3d(32, 64, 2, stride=2), + spnn.BatchNorm(64), + spnn.ReLU(True), + spnn.Conv3d(64, 64, 2, stride=2, transposed=True), + spnn.BatchNorm(64), + spnn.ReLU(True), + spnn.Conv3d(64, 32, 3), + spnn.BatchNorm(32), + spnn.ReLU(True), + spnn.Conv3d(32, 10, 1), + ).cuda() + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + scaler = amp.GradScaler(enabled=args.amp_enabled) + + for k, feed_dict in enumerate(dataflow): + inputs = feed_dict['input'].cuda() + targets = feed_dict['label'].cuda() + + with amp.autocast(enabled=args.amp_enabled): + outputs = model(inputs) + loss = criterion(outputs.F, targets.F) + + print(f'[step {k + 1}] loss = {loss.item()}.') + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() diff --git a/examples/performance.py b/examples/performance.py index 250c569..1b52ccf 100644 --- a/examples/performance.py +++ b/examples/performance.py @@ -3,10 +3,14 @@ import numpy as np import torch import torch.autograd.profiler as profiler +import torch.cuda import torch.nn as nn +import torch.optim + import torchsparse.nn as spnn from torchsparse import SparseTensor -from torchsparse.utils import sparse_collate_fn, sparse_quantize +from torchsparse.utils.collate import sparse_collate_fn +from torchsparse.utils.quantize import sparse_quantize def generate_random_point_cloud(size=100000, voxel_size=0.2): @@ -36,7 +40,7 @@ def generate_batched_random_point_clouds(size=100000, voxel_size=0.2, batch_size=2): batch = [] - for i in range(batch_size): + for _ in range(batch_size): batch.append(generate_random_point_cloud(size, voxel_size)) return sparse_collate_fn(batch) @@ -47,10 +51,10 @@ def dummy_train_3x3(device): spnn.Conv3d(32, 64, kernel_size=3, stride=1), spnn.Conv3d(64, 128, kernel_size=3, stride=1), spnn.Conv3d(128, 256, kernel_size=3, stride=1), - spnn.Conv3d(256, 128, kernel_size=3, stride=1, transpose=True), - spnn.Conv3d(128, 64, kernel_size=3, stride=1, transpose=True), - spnn.Conv3d(64, 32, kernel_size=3, stride=1, transpose=True), - spnn.Conv3d(32, 10, kernel_size=3, stride=1, transpose=True), + spnn.Conv3d(256, 128, kernel_size=3, stride=1, transposed=True), + spnn.Conv3d(128, 64, kernel_size=3, stride=1, transposed=True), + spnn.Conv3d(64, 32, kernel_size=3, stride=1, transposed=True), + spnn.Conv3d(32, 10, kernel_size=3, stride=1, transposed=True), ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss().to(device) @@ -58,8 +62,8 @@ def dummy_train_3x3(device): print('Starting dummy_train_3x3...') time = datetime.now() with profiler.profile(profile_memory=True, use_cuda=True) as prof: - with profiler.record_function("model_inference"): - for i in range(10): + with profiler.record_function('model_inference'): + for _ in range(10): feed_dict = generate_batched_random_point_clouds() inputs = feed_dict['lidar'].to(device) targets = feed_dict['targets'].F.to(device).long() @@ -69,8 +73,8 @@ def dummy_train_3x3(device): loss.backward() optimizer.step() # print('[step %d] loss = %f.'%(i, loss.item())) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - prof.export_chrome_trace("trace_dummy_3x3.json") + print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10)) + prof.export_chrome_trace('trace_dummy_3x3.json') time = datetime.now() - time print('Finished dummy_train_3x3 in ', time) @@ -82,10 +86,10 @@ def dummy_train_3x1(device): spnn.Conv3d(32, 64, kernel_size=(1, 3, 3), stride=1), spnn.Conv3d(64, 128, kernel_size=(3, 1, 3), stride=1), spnn.Conv3d(128, 256, kernel_size=(1, 3, 3), stride=1), - spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transpose=True), - spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transpose=True), - spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transpose=True), - spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transpose=True), + spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transposed=True), + spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transposed=True), + spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transposed=True), + spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transposed=True), ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss().to(device) @@ -93,8 +97,8 @@ def dummy_train_3x1(device): print('Starting dummy_train_3x1 ...') time = datetime.now() with profiler.profile(profile_memory=True, use_cuda=True) as prof: - with profiler.record_function("model_inference"): - for i in range(10): + with profiler.record_function('model_inference'): + for _ in range(10): feed_dict = generate_batched_random_point_clouds() inputs = feed_dict['lidar'].to(device) targets = feed_dict['targets'].F.to(device).long() @@ -104,8 +108,8 @@ def dummy_train_3x1(device): loss.backward() optimizer.step() # print('[step %d] loss = %f.'%(i, loss.item())) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - prof.export_chrome_trace("trace_dummy_3x1.json") + print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10)) + prof.export_chrome_trace('trace_dummy_3x1.json') time = datetime.now() - time print('Finished dummy_train_3x1 in ', time) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ecaad42 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[yapf] +based_on_style = google +spaces_around_power_operator = true +split_before_arithmetic_operator = true +split_before_logical_operator = true +split_before_bitwise_operator = true + +[isort] +known_first_party = torchsparse + +[flake8] +select = B, C, E, F, P, T4, W, B9 +ignore = E501, E722, W503 +per-file-ignores = + __init__.py: F401, F403 diff --git a/setup.py b/setup.py index 6177a69..aff6ee4 100644 --- a/setup.py +++ b/setup.py @@ -1,66 +1,40 @@ +import glob import os import torch +import torch.cuda from setuptools import find_packages, setup from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, CUDAExtension) -has_cuda = (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv( - 'FORCE_CUDA', '0') == '1' - from torchsparse import __version__ -# Notice that CUDA files, header files should not share names with CPP files. -# Otherwise, there will be "ninja: warning: multiple rules generate xxx.o", which leads to -# multiple definitions error! +if ((torch.cuda.is_available() and CUDA_HOME is not None) + or (os.getenv('FORCE_CUDA', '0') == '1')): + device = 'cuda' +else: + device = 'cpu' -file_lis = [ - 'torchsparse/src/torchsparse_bindings_gpu.cpp', - 'torchsparse/src/convolution/convolution_cpu.cpp', - 'torchsparse/src/convolution/convolution.cu', - 'torchsparse/src/convolution/convolution_gpu.cu', - 'torchsparse/src/hash/hash_cpu.cpp', - 'torchsparse/src/hash/hash.cpp', - 'torchsparse/src/hash/hash_gpu.cu', - 'torchsparse/src/hashmap/hashmap.cu', - 'torchsparse/src/hashmap/hashmap_cpu.cpp', - 'torchsparse/src/interpolation/devox_gpu.cu', - 'torchsparse/src/interpolation/devox_deterministic.cpp', - 'torchsparse/src/interpolation/devox_deterministic_gpu.cu', - 'torchsparse/src/interpolation/devox_cpu.cpp', - 'torchsparse/src/others/count.cpp', - 'torchsparse/src/others/count_gpu.cu', - 'torchsparse/src/others/count_cpu.cpp', - 'torchsparse/src/others/insertion_gpu.cu', - 'torchsparse/src/others/insertion_cpu.cpp', - 'torchsparse/src/others/query.cpp', - 'torchsparse/src/others/query_cpu.cpp', -] if has_cuda else [ - 'torchsparse/src/torchsparse_bindings.cpp', - 'torchsparse/src/convolution/convolution_cpu.cpp', - 'torchsparse/src/hash/hash_cpu.cpp', - 'torchsparse/src/hashmap/hashmap_cpu.cpp', - 'torchsparse/src/interpolation/devox_cpu.cpp', - 'torchsparse/src/others/insertion_cpu.cpp', - 'torchsparse/src/others/query_cpu.cpp', - 'torchsparse/src/others/count_cpu.cpp' -] +sources = [os.path.join('torchsparse', 'backend', f'pybind_{device}.cpp')] +for fpath in glob.glob(os.path.join('torchsparse', 'backend', '**', '*')): + if fpath.endswith('_cpu.cpp') and device in ['cpu', 'cuda']: + sources.append(fpath) + elif fpath.endswith('_cuda.cu') and device == 'cuda': + sources.append(fpath) +extension_type = CUDAExtension if device == 'cuda' else CppExtension extra_compile_args = { 'cxx': ['-g', '-O3', '-fopenmp', '-lgomp'], 'nvcc': ['-O3'] -} if has_cuda else { - 'cxx': ['-g', '-O3', '-fopenmp', '-lgomp'] } -extension_type = CUDAExtension if has_cuda else CppExtension setup( name='torchsparse', version=__version__, packages=find_packages(), ext_modules=[ - extension_type('torchsparse_backend', - file_lis, + extension_type('torchsparse.backend', + sources, extra_compile_args=extra_compile_args) ], cmdclass={'build_ext': BuildExtension}, diff --git a/torchsparse/__init__.py b/torchsparse/__init__.py index c046f68..e7ab36a 100644 --- a/torchsparse/__init__.py +++ b/torchsparse/__init__.py @@ -1,18 +1,3 @@ -import torch -from .sparse_tensor import * -from .point_tensor import * - -__version__ = '1.3.0' - - -def cat(input_list, dim=1): - assert len(input_list) > 0 - inputs = input_list[0] - features = inputs.F - coords = inputs.C - cur_stride = inputs.s - output_tensor = SparseTensor( - torch.cat([inputs.F for inputs in input_list], 1), coords, cur_stride) - output_tensor.coord_maps = inputs.coord_maps - output_tensor.kernel_maps = inputs.kernel_maps - return output_tensor +from .operators import * +from .tensors import * +from .version import __version__ diff --git a/torchsparse/src/common/gpu.cuh b/torchsparse/backend/common/gpu.cuh similarity index 69% rename from torchsparse/src/common/gpu.cuh rename to torchsparse/backend/common/gpu.cuh index 1ada73f..e36e753 100644 --- a/torchsparse/src/common/gpu.cuh +++ b/torchsparse/backend/common/gpu.cuh @@ -6,79 +6,77 @@ #include #include #include -#include // cuda driver types - +#include #include +#include #include #include #include -#include - // // CUDA macros // // CUDA: various checks for different function calls. -#define CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - { \ - cudaError_t error = condition; \ - if (error != cudaSuccess) { \ - throw std::runtime_error(cudaGetErrorString(error) << " at " \ - << __FILE__ << ":" << __LINE__); \ - } \ +#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + { \ + cudaError_t error = condition; \ + if (error != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(error) \ + << " at " << __FILE__ << ":" << __LINE__); \ + } \ } -#define CUBLAS_CHECK(condition) \ - { \ - cublasStatus_t status = condition; \ - if (status != CUBLAS_STATUS_SUCCESS) { \ - throw std::runtime_error(cublasGetErrorString(status) << " at " \ - << __FILE__ << ":" << __LINE__); \ - } \ +#define CUBLAS_CHECK(condition) \ + { \ + cublasStatus_t status = condition; \ + if (status != CUBLAS_STATUS_SUCCESS) { \ + throw std::runtime_error(cublasGetErrorString(status) \ + << " at " << __FILE__ << ":" << __LINE__); \ + } \ } -#define CUSPARSE_CHECK(call) \ - { \ - cusparseStatus_t err; \ - if ((err = (call)) != CUSPARSE_STATUS_SUCCESS) { \ - throw std::runtime_error(cusparseGetErrorString(err) << " at " \ - << __FILE__ << ":" << __LINE__); \ - } \ +#define CUSPARSE_CHECK(call) \ + { \ + cusparseStatus_t err; \ + if ((err = (call)) != CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error(cusparseGetErrorString(err) \ + << " at " << __FILE__ << ":" << __LINE__); \ + } \ } -#define CURAND_CHECK(condition) \ - { \ - curandStatus_t status = condition; \ - if (status != CURAND_STATUS_SUCCESS) { \ - throw std::runtime_error(curandGetErrorString(status) << " at " \ - << __FILE__ << ":" << __LINE__); \ - } \ +#define CURAND_CHECK(condition) \ + { \ + curandStatus_t status = condition; \ + if (status != CURAND_STATUS_SUCCESS) { \ + throw std::runtime_error(curandGetErrorString(status) \ + << " at " << __FILE__ << ":" << __LINE__); \ + } \ } // CUDA: grid stride looping -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ i += blockDim.x * gridDim.x) // CUDA: check for error after kernel execution and exit loudly if there is one. -#define CUDA_POST_KERNEL_CHECK \ - { \ - cudaError_t status = cudaPeekAtLastError(); \ - if (status != cudaSuccess) { \ - throw std::runtime_error(cudaGetErrorString(status) << " at " \ - << __FILE__ << ":" << __LINE__); \ - } \ +#define CUDA_POST_KERNEL_CHECK \ + { \ + cudaError_t status = cudaPeekAtLastError(); \ + if (status != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(status) \ + << " at " << __FILE__ << ":" << __LINE__); \ + } \ } #define THRUST_CHECK(condition) \ try { \ condition; \ } catch (thrust::system_error e) { \ - throw std::runtime_error("Thrust error: " << e.what() << " at " \ - << __FILE__ << ":" << __LINE__); \ + throw std::runtime_error("Thrust error: " << e.what() << " at " \ + << __FILE__ << ":" << __LINE__); \ } // CUDA: library error reporting. @@ -93,13 +91,12 @@ constexpr int SHARED_BLOCK_SIZE = 32; constexpr int BLOCK_SIZE = 32; - - inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } -template void print(const thrust::device_vector &v); +template +void print(const thrust::device_vector &v); template void print(const thrust::device_vector &v1, const thrust::device_vector &v2); @@ -119,5 +116,4 @@ __device__ double atomicAdd(double *address, double val) { } #endif -#endif // GPU_H_ - +#endif // GPU_H_ diff --git a/torchsparse/backend/convolution/convolution_cpu.cpp b/torchsparse/backend/convolution/convolution_cpu.cpp new file mode 100644 index 0000000..b0b7925 --- /dev/null +++ b/torchsparse/backend/convolution/convolution_cpu.cpp @@ -0,0 +1,183 @@ +#include "convolution_cpu.h" + +#include + +#include +#include + +void scatter_cpu(const int n_in, const int n_out, const int c, + const float *in_feat, float *out_feat, const int *kmap, + const bool transpose) { + for (int i = 0; i < n_in; i++) { + int out_pos = kmap[2 * i + 1 - transpose]; + if (out_pos < 0) { + continue; + } +#pragma omp parallel for + for (int j = 0; j < c; j++) { + out_feat[out_pos * c + j] += in_feat[i * c + j]; + } + } +} + +void gather_cpu(const int n_k, const int n_in, const int c, + const float *in_feat, float *out_feat, const int *kmap, + const bool transpose) { + for (int i = 0; i < n_k; i++) { + int in_pos = kmap[2 * i + transpose]; + if (in_pos < 0) { + continue; + } +#pragma omp parallel for + for (int j = 0; j < c; j++) { + out_feat[i * c + j] = in_feat[in_pos * c + j]; + } + } +} + +void convolution_forward_cpu(at::Tensor in_feat, at::Tensor out_feat, + at::Tensor kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, const bool transpose) { + if (in_feat.size(1) != kernel.size(1)) { + throw std::invalid_argument("Input feature size and kernel size mismatch"); + } + + int out_nrows = out_feat.size(0); + out_feat.resize_({out_nrows, kernel.size(2)}); + out_feat.zero_(); + + int kernel_volume = kernel.size(0); + int in_buffer_size = 1; + bool flag = false; + // memory optimization + if (kernel_volume % 2 && out_nrows == in_feat.size(0)) { + flag = true; + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + kernel_volume / 2); + in_buffer_size = + std::max(in_buffer_size, + *std::max_element( + neighbor_offset.data_ptr() + kernel_volume / 2 + 1, + neighbor_offset.data_ptr() + kernel_volume)); + in_buffer_size = std::max(in_buffer_size, 1); + + torch::mm_out(out_feat, in_feat, kernel[kernel_volume / 2]); + } else { + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + kernel_volume); + } + + auto options = + torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); + auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); + auto out_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options); + int cur_offset = 0; + for (int i = 0; i < kernel_volume; i++) { + if (flag && (i == kernel_volume / 2)) { + cur_offset += 2 * neighbor_offset.data_ptr()[i]; + continue; + } + + if (neighbor_offset.data_ptr()[i] == 0) { + continue; + } + + auto out_buffer_activated = torch::from_blob( + out_buffer.data_ptr(), + {neighbor_offset.data_ptr()[i], kernel.size(2)}, options); + auto in_buffer_activated = torch::from_blob( + in_buffer.data_ptr(), + {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); + + // gather + gather_cpu(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1), + in_feat.data_ptr(), in_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + + // matmul + torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]); + + // scatter + scatter_cpu(neighbor_offset.data_ptr()[i], out_nrows, kernel.size(2), + out_buffer_activated.data_ptr(), + out_feat.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + cur_offset += 2 * neighbor_offset.data_ptr()[i]; + } +} + +void convolution_backward_cpu(at::Tensor in_feat, at::Tensor grad_in_feat, + at::Tensor grad_out_feat, at::Tensor kernel, + at::Tensor grad_kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, + const bool transpose) { + grad_in_feat.resize_as_(in_feat); + grad_in_feat.zero_(); + grad_kernel.resize_as_(kernel); + grad_kernel.zero_(); + + int kernel_volume = kernel.size(0); + bool flag = false; + int in_buffer_size; + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + kernel_volume); + + auto options = + torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); + auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); + auto in_grad_buffer = + torch::zeros({in_buffer_size, in_feat.size(1)}, options); + auto out_grad_buffer = + torch::zeros({in_buffer_size, kernel.size(2)}, options); + + int cur_offset = 0; + for (int i = 0; i < kernel_volume; i++) { + auto kernel_grad_buffer = grad_kernel[i]; + if (flag && (i == kernel_volume / 2)) { + cur_offset += 2 * neighbor_offset.data_ptr()[i]; + continue; + } + + if (neighbor_offset.data_ptr()[i] == 0) { + continue; + } + + auto out_grad_buffer_activated = torch::from_blob( + out_grad_buffer.data_ptr(), + {neighbor_offset.data_ptr()[i], kernel.size(2)}, options); + auto in_grad_buffer_activated = torch::from_blob( + in_grad_buffer.data_ptr(), + {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); + auto in_buffer_activated = torch::from_blob( + in_buffer.data_ptr(), + {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); + + // gather + gather_cpu(out_grad_buffer_activated.size(0), grad_out_feat.size(0), + kernel.size(2), grad_out_feat.data_ptr(), + out_grad_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, !transpose); + + gather_cpu(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1), + in_feat.data_ptr(), in_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + + // matmul + torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated, + torch::transpose(kernel[i], 0, 1)); + torch::mm_out(kernel_grad_buffer, + torch::transpose(in_buffer_activated, 0, 1), + out_grad_buffer_activated); + + // scatter + scatter_cpu(neighbor_offset.data_ptr()[i], in_feat.size(0), + kernel.size(1), in_grad_buffer_activated.data_ptr(), + grad_in_feat.data_ptr(), + neighbor_map.data_ptr() + cur_offset, !transpose); + + cur_offset += 2 * neighbor_offset.data_ptr()[i]; + } +} diff --git a/torchsparse/backend/convolution/convolution_cpu.h b/torchsparse/backend/convolution/convolution_cpu.h new file mode 100644 index 0000000..4e37bcc --- /dev/null +++ b/torchsparse/backend/convolution/convolution_cpu.h @@ -0,0 +1,15 @@ +#ifndef TORCHSPARSE_CONVOLUTION_CPU +#define TORCHSPARSE_CONVOLUTION_CPU + +#include + +void convolution_forward_cpu(at::Tensor in_feat, at::Tensor out_feat, + at::Tensor kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, const bool transpose); + +void convolution_backward_cpu(at::Tensor in_feat, at::Tensor grad_in_feat, + at::Tensor grad_out_feat, at::Tensor kernel, + at::Tensor grad_kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, const bool transpose); + +#endif diff --git a/torchsparse/backend/convolution/convolution_cuda.cu b/torchsparse/backend/convolution/convolution_cuda.cu new file mode 100644 index 0000000..026f0e7 --- /dev/null +++ b/torchsparse/backend/convolution/convolution_cuda.cu @@ -0,0 +1,278 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "convolution_cuda.h" + +template +__global__ void gather_kernel(const int n_k, const int n_in, const int c, + const scalar_t *in_feat, scalar_t *out_feat, + const int *kmap, const bool transpose) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int i = index / c; + int j = index % c; + if (i >= n_k) return; + int in_pos = kmap[2 * i + transpose]; + if (in_pos < 0) return; + out_feat[i * c + j] = in_feat[in_pos * c + j]; +} + +template +__global__ void scatter_kernel(const int n_in, const int n_out, const int c, + const scalar_t *in_feat, scalar_t *out_feat, + const int *kmap, const bool transpose) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int i = index / c; + int j = index % c; + if (i >= n_in) return; + int out_pos = kmap[2 * i + 1 - transpose]; + if (out_pos < 0) return; + out_feat[out_pos * c + j] += in_feat[i * c + j]; +} + +// in_feat: (N, c) N=# of input points, c = input channels +// out_feat: (M, o) M=# of output points, o = output channels +// for stride=1, M=N. For stride>1, the N input coords +// are requantized to M points with grid size (stride * +// cur_stride) +// kernel: (k^3, c, o) for a 3D convolution of length k +// neighbor_map: (a, 2) the hash table query results from out_coords to +// in_coords +// where neighbor_map[:,0] is the index of the output +// feature and neighbor_map[:,1] is the index of the input +// feature +// neighbor_offset: (k^3) count of active weights based on neighbor_map +// with unused weights having 0 and neighbor_offset[k^3/2] +// holding w[0,0]. +void convolution_forward_cuda(at::Tensor in_feat, at::Tensor out_feat, + at::Tensor kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, + const bool transpose) { + if (in_feat.size(1) != kernel.size(1)) { + throw std::invalid_argument("Input feature size and kernel size mismatch"); + } + + bool is_half = in_feat.scalar_type() == at::ScalarType::Half; + + int n_in_feats = in_feat.size(0); + int n_in_channels = in_feat.size(1); + int n_out_feats = out_feat.size(0); + int n_out_channels = out_feat.size(1); + ; + + int kernel_volume = kernel.size(0); + + // memory optimization + bool precompute_mid = false; + int mid_kernel = kernel_volume / 2; + int in_buffer_size = 1; + // we can precompute features for w[0,0] which avoids gather/scatter + if (kernel_volume % 2 == 1 && n_in_feats == n_out_feats) { + precompute_mid = true; + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + mid_kernel); + in_buffer_size = std::max( + in_buffer_size, + *std::max_element(neighbor_offset.data_ptr() + mid_kernel + 1, + neighbor_offset.data_ptr() + kernel_volume)); + in_buffer_size = std::max(in_buffer_size, 1); + + // (N, c) X (c, o) = (N, o) + torch::mm_out(out_feat, in_feat, kernel[mid_kernel]); + } else { + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + kernel_volume); + } + + auto options = + torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); + auto in_buffer = torch::zeros({in_buffer_size, n_in_channels}, options); + auto out_buffer = torch::zeros({in_buffer_size, n_out_channels}, options); + int cur_offset = 0; + // gather/gemm/scatter on each weight + for (int i = 0; i < kernel_volume; i++) { + int n_active_feats = neighbor_offset.data_ptr()[i]; + // if there's no active features for this weight, skip it + if (n_active_feats == 0) { + continue; + } + + // if w[0,0] was precomputed above, skip it + if ((i == mid_kernel) && precompute_mid) { + cur_offset += 2 * n_active_feats; + continue; + } + + // in_buffer_activated (i, c) holds the dense input features from gather + // for i = n_active_feats (# of features in the activated kernel from + // neighbor_offset) out_buffer_activated (i, o) holds the dense output + // features to scatter + at::Tensor out_buffer_activated; + at::Tensor in_buffer_activated; + if (is_half) { + out_buffer_activated = + torch::from_blob(out_buffer.data_ptr(), + {n_active_feats, n_out_channels}, options); + in_buffer_activated = + torch::from_blob(in_buffer.data_ptr(), + {n_active_feats, n_in_channels}, options); + } else { + out_buffer_activated = + torch::from_blob(out_buffer.data_ptr(), + {n_active_feats, n_out_channels}, options); + in_buffer_activated = + torch::from_blob(in_buffer.data_ptr(), + {n_active_feats, n_in_channels}, options); + } + + // gather n_active_feats dense features from N sparse input features with c + // feature dimensions + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + in_feat.type(), "convolution_forward_cuda", ([&] { + gather_kernel + <<>>( + n_active_feats, n_in_feats, n_in_channels, + in_feat.data_ptr(), + in_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + })); + + // gemm: (i, c) X (c, o) = (i, o) + torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]); + + // scatter n_active_feats dense features into n_out_feats output features of + // dimension n_out_channels + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + in_feat.type(), "convolution_forward_cuda", ([&] { + scatter_kernel + <<>>( + n_active_feats, n_out_feats, n_out_channels, + out_buffer_activated.data_ptr(), + out_feat.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + })); + + cur_offset += 2 * n_active_feats; + } +} + +void convolution_backward_cuda(at::Tensor in_feat, at::Tensor grad_in_feat, + at::Tensor grad_out_feat, at::Tensor kernel, + at::Tensor grad_kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, + const bool transpose) { + grad_in_feat.resize_as_(in_feat); + grad_in_feat.zero_(); + grad_kernel.resize_as_(kernel); + grad_kernel.zero_(); + + bool is_half = in_feat.scalar_type() == at::ScalarType::Half; + int n_in_feats = in_feat.size(0); + int n_in_channels = in_feat.size(1); + int n_out_feats = grad_out_feat.size(0); + int n_out_channels = kernel.size(-1); + + int kernel_volume = kernel.size(0); + bool flag = false; + int in_buffer_size; + in_buffer_size = + *std::max_element(neighbor_offset.data_ptr(), + neighbor_offset.data_ptr() + kernel_volume); + + auto options = + torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); + auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); + auto in_grad_buffer = + torch::zeros({in_buffer_size, in_feat.size(1)}, options); + auto out_grad_buffer = + torch::zeros({in_buffer_size, kernel.size(2)}, options); + + int cur_offset = 0; + for (int i = 0; i < kernel_volume; i++) { + auto kernel_grad_buffer = grad_kernel[i]; + int n_active_feats = neighbor_offset.data_ptr()[i]; + if (flag && (i == kernel_volume / 2)) { + cur_offset += 2 * n_active_feats; + continue; + } + + if (n_active_feats == 0) { + continue; + } + + // Can't figure out a cleaner way to do this + at::Tensor out_grad_buffer_activated; + at::Tensor in_grad_buffer_activated; + at::Tensor in_buffer_activated; + if (is_half) { + out_grad_buffer_activated = + torch::from_blob(out_grad_buffer.data_ptr(), + {n_active_feats, kernel.size(2)}, options); + in_grad_buffer_activated = + torch::from_blob(in_grad_buffer.data_ptr(), + {n_active_feats, in_feat.size(1)}, options); + in_buffer_activated = + torch::from_blob(in_buffer.data_ptr(), + {n_active_feats, in_feat.size(1)}, options); + } else { + out_grad_buffer_activated = + torch::from_blob(out_grad_buffer.data_ptr(), + {n_active_feats, kernel.size(2)}, options); + in_grad_buffer_activated = + torch::from_blob(in_grad_buffer.data_ptr(), + {n_active_feats, in_feat.size(1)}, options); + in_buffer_activated = + torch::from_blob(in_buffer.data_ptr(), + {n_active_feats, in_feat.size(1)}, options); + } + + // gather + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + in_feat.type(), "convolution_forward_cuda", ([&] { + gather_kernel + <<>>( + n_active_feats, n_out_feats, n_out_channels, + grad_out_feat.data_ptr(), + out_grad_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, !transpose); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + in_feat.type(), "convolution_forward_cuda", ([&] { + gather_kernel + <<>>( + n_active_feats, n_in_feats, n_in_channels, + in_feat.data_ptr(), + in_buffer_activated.data_ptr(), + neighbor_map.data_ptr() + cur_offset, transpose); + })); + + // gemm + torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated, + torch::transpose(kernel[i], 0, 1)); + torch::mm_out(kernel_grad_buffer, + torch::transpose(in_buffer_activated, 0, 1), + out_grad_buffer_activated); + + // scatter + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + in_feat.type(), "convolution_forward_cuda", ([&] { + scatter_kernel + <<>>( + n_active_feats, n_in_feats, n_in_channels, + in_grad_buffer_activated.data_ptr(), + grad_in_feat.data_ptr(), + neighbor_map.data_ptr() + cur_offset, !transpose); + })); + + cur_offset += 2 * n_active_feats; + } +} diff --git a/torchsparse/backend/convolution/convolution_cuda.h b/torchsparse/backend/convolution/convolution_cuda.h new file mode 100644 index 0000000..52880e1 --- /dev/null +++ b/torchsparse/backend/convolution/convolution_cuda.h @@ -0,0 +1,16 @@ +#ifndef TORCHSPARSE_CONVOLUTION_CUDA +#define TORCHSPARSE_CONVOLUTION_CUDA + +#include + +void convolution_forward_cuda(at::Tensor in_feat, at::Tensor out_feat, + at::Tensor kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, const bool transpose); + +void convolution_backward_cuda(at::Tensor in_feat, at::Tensor grad_in_feat, + at::Tensor grad_out_feat, at::Tensor kernel, + at::Tensor grad_kernel, at::Tensor neighbor_map, + at::Tensor neighbor_offset, + const bool transpose); + +#endif diff --git a/torchsparse/backend/devoxelize/devoxelize_cpu.cpp b/torchsparse/backend/devoxelize/devoxelize_cpu.cpp new file mode 100644 index 0000000..dc2afd5 --- /dev/null +++ b/torchsparse/backend/devoxelize/devoxelize_cpu.cpp @@ -0,0 +1,59 @@ +#include "devoxelize_cpu.h" + +#include + +#include + +// make sure indices is int type +// feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) +at::Tensor devoxelize_forward_cpu(const at::Tensor feat, + const at::Tensor indices, + const at::Tensor weight) { + int c = feat.size(1); + int N = indices.size(0); + + at::Tensor out = torch::zeros( + {N, c}, at::device(feat.device()).dtype(at::ScalarType::Float)); +#pragma omp parallel for + for (int i = 0; i < N; i++) { + int *indices_ = indices.data_ptr() + i * 8; + float *weight_ = weight.data_ptr() + i * 8; + for (int j = 0; j < c; j++) { + float *feat_ = feat.data_ptr() + j; + float cur_feat; + for (int k = 0; k < 8; k++) { + cur_feat = (indices_[k] >= 0) ? feat_[indices_[k] * c] : 0; + *(out.data_ptr() + i * c + j) += weight_[k] * cur_feat; + } + } + } + return out; +} + +// top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: +// (b,c,s), s=r^3 +at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad, + const at::Tensor indices, + const at::Tensor weight, int n) { + int c = top_grad.size(1); + int N = top_grad.size(0); + at::Tensor bottom_grad = torch::zeros( + {n, c}, at::device(top_grad.device()).dtype(at::ScalarType::Float)); + + for (int i = 0; i < N; i++) { + int *indices_ = indices.data_ptr() + i * 8; + float *weight_ = weight.data_ptr() + i * 8; +#pragma omp parallel for + for (int j = 0; j < c; j++) { + float *top_grad_ = top_grad.data_ptr() + j; + float cur_top_grad; + for (int k = 0; k < 8; k++) { + cur_top_grad = (indices_[k] >= 0) ? top_grad_[indices_[k] * c] : 0; + *(bottom_grad.data_ptr() + indices_[k] * c + j) += + weight_[k] * cur_top_grad; + } + } + } + + return bottom_grad; +} diff --git a/torchsparse/backend/devoxelize/devoxelize_cpu.h b/torchsparse/backend/devoxelize/devoxelize_cpu.h new file mode 100644 index 0000000..38dce96 --- /dev/null +++ b/torchsparse/backend/devoxelize/devoxelize_cpu.h @@ -0,0 +1,14 @@ +#ifndef TORCHSPARSE_DEVOXELIZE_CPU +#define TORCHSPARSE_DEVOXELIZE_CPU + +#include + +at::Tensor devoxelize_forward_cpu(const at::Tensor feat, + const at::Tensor indices, + const at::Tensor weight); + +at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad, + const at::Tensor indices, + const at::Tensor weight, int n); + +#endif diff --git a/torchsparse/backend/devoxelize/devoxelize_cuda.cu b/torchsparse/backend/devoxelize/devoxelize_cuda.cu new file mode 100644 index 0000000..c2c0423 --- /dev/null +++ b/torchsparse/backend/devoxelize/devoxelize_cuda.cu @@ -0,0 +1,98 @@ +#include +#include +#include +#include + +#include + +// input features (n, c), indices (N, 8), weight (N, 8) -> output features (N, +// c) +template +__global__ void devoxelize_forward_kernel(int N, int c, + const int *__restrict__ indices, + const scalar_t *__restrict__ weight, + const scalar_t *__restrict__ feat, + scalar_t *__restrict__ out) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int i = index / c; + int j = index % c; + + if (i < N) { + const int *indices_ = indices + 8 * i; + const scalar_t *weight_ = weight + 8 * i; + const scalar_t *feat_ = feat + j; + + scalar_t cur_feat; + for (int k = 0; k < 8; k++) { + cur_feat = 0; + if (indices_[k] >= 0) cur_feat = feat_[indices_[k] * c]; + + out[i * c + j] += weight_[k] * cur_feat; + } + } +} + +// input weight (N, 8), indices (N, 8), top_grad (N, c) -> bottom grad (n, c) +template +__global__ void devoxelize_backward_kernel( + int N, int n, int c, const int *__restrict__ indices, + const scalar_t *__restrict__ weight, const scalar_t *__restrict__ top_grad, + scalar_t *__restrict__ bottom_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int i = index / c; + int j = index % c; + + if (i < N) { + const int *indices_ = indices + 8 * i; + const scalar_t *weight_ = weight + 8 * i; + + scalar_t cur_top_grad = top_grad[i * c + j]; + +#pragma unroll + for (int k = 0; k < 8; k++) { + if (indices_[k] >= 0) + atomicAdd(&bottom_grad[indices_[k] * c + j], weight_[k] * cur_top_grad); + } + } +} + +// make sure indices is int type +// feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) +at::Tensor devoxelize_forward_cuda(const at::Tensor feat, + const at::Tensor indices, + const at::Tensor weight) { + int c = feat.size(1); + int N = indices.size(0); + + at::Tensor out = + torch::zeros({N, c}, at::device(feat.device()).dtype(feat.dtype())); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + feat.type(), "devoxelize_forward_cuda", ([&] { + devoxelize_forward_kernel<<>>( + N, c, indices.data_ptr(), weight.data_ptr(), + feat.data_ptr(), out.data_ptr()); + })); + + return out; +} + +// top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: +// (b,c,s), s=r^3 +at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad, + const at::Tensor indices, + const at::Tensor weight, int n) { + int c = top_grad.size(1); + int N = top_grad.size(0); + at::Tensor bottom_grad = torch::zeros( + {n, c}, at::device(top_grad.device()).dtype(top_grad.dtype())); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "devoxelize_backward_cuda", ([&] { + devoxelize_backward_kernel<<>>( + N, n, c, indices.data_ptr(), weight.data_ptr(), + top_grad.data_ptr(), bottom_grad.data_ptr()); + })); + + return bottom_grad; +} diff --git a/torchsparse/backend/devoxelize/devoxelize_cuda.h b/torchsparse/backend/devoxelize/devoxelize_cuda.h new file mode 100644 index 0000000..ade14dc --- /dev/null +++ b/torchsparse/backend/devoxelize/devoxelize_cuda.h @@ -0,0 +1,14 @@ +#ifndef TORCHSPARSE_DEVOXELIZE_CUDA +#define TORCHSPARSE_DEVOXELIZE_CUDA + +#include + +at::Tensor devoxelize_forward_cuda(const at::Tensor feat, + const at::Tensor indices, + const at::Tensor weight); + +at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad, + const at::Tensor indices, + const at::Tensor weight, int n); + +#endif diff --git a/torchsparse/backend/hash/hash_cpu.cpp b/torchsparse/backend/hash/hash_cpu.cpp new file mode 100644 index 0000000..a017214 --- /dev/null +++ b/torchsparse/backend/hash/hash_cpu.cpp @@ -0,0 +1,58 @@ +#include "hash_cpu.h" + +#include + +#include + +void cpu_hash_wrapper(int N, const int *data, long *out) { +#pragma omp parallel for + for (int i = 0; i < N; i++) { + unsigned long long hash = 14695981039346656037UL; + for (int j = 0; j < 4; j++) { + hash ^= (unsigned int)data[4 * i + j]; + hash *= 1099511628211UL; + } + hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); + out[i] = hash; + } +} + +void cpu_kernel_hash_wrapper(int N, int K, const int *data, + const int *kernel_offset, long int *out) { + for (int k = 0; k < K; k++) { +#pragma omp parallel for + for (int i = 0; i < N; i++) { + int cur_coord[4]; + for (int j = 0; j < 3; j++) { + cur_coord[j] = data[i * 4 + j] + kernel_offset[k * 3 + j]; + } + cur_coord[3] = data[3]; + unsigned long long hash = 14695981039346656037UL; + for (int j = 0; j < 4; j++) { + hash ^= (unsigned int)cur_coord[j]; + hash *= 1099511628211UL; + } + hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); + out[k * N + i] = hash; + } + } +} + +at::Tensor hash_cpu(const at::Tensor idx) { + int N = idx.size(0); + at::Tensor out = + torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); + cpu_hash_wrapper(N, idx.data_ptr(), out.data_ptr()); + return out; +} + +at::Tensor kernel_hash_cpu(const at::Tensor idx, + const at::Tensor kernel_offset) { + int N = idx.size(0); + int K = kernel_offset.size(0); + at::Tensor out = torch::zeros( + {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); + cpu_kernel_hash_wrapper(N, K, idx.data_ptr(), + kernel_offset.data_ptr(), out.data_ptr()); + return out; +} diff --git a/torchsparse/backend/hash/hash_cpu.h b/torchsparse/backend/hash/hash_cpu.h new file mode 100644 index 0000000..6367480 --- /dev/null +++ b/torchsparse/backend/hash/hash_cpu.h @@ -0,0 +1,11 @@ +#ifndef _SPARSE_HASH_CPU +#define _SPARSE_HASH_CPU + +#include + +at::Tensor hash_cpu(const at::Tensor idx); + +at::Tensor kernel_hash_cpu(const at::Tensor idx, + const at::Tensor kernel_offset); + +#endif diff --git a/torchsparse/backend/hash/hash_cuda.cu b/torchsparse/backend/hash/hash_cuda.cu new file mode 100644 index 0000000..7193da7 --- /dev/null +++ b/torchsparse/backend/hash/hash_cuda.cu @@ -0,0 +1,84 @@ +#include +#include +#include + +#include +#include + +// hashing +// input N*4 int32 tensor output N*1 int64 tensor +__global__ void hash_kernel(int N, const int *__restrict__ data, + long int *__restrict__ out) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < N) { + data += i * 4; + unsigned long long hash = 14695981039346656037UL; + for (int j = 0; j < 4; j++) { + hash ^= (unsigned int)data[j]; + hash *= 1099511628211UL; + } + hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); + out[i] = hash; + } +} + +// kernel hashing: given data D and offset map K, generate D x K +// input N*4 int32 tensor, |K|*3 int32 tensor, output |K|*N int64 tensor +__global__ void kernel_hash_kernel(int N, int K, const int *__restrict__ data, + const int *__restrict__ kernel_offset, + long int *__restrict__ out) { + extern __shared__ int kernel_offset_local[]; + + for (int i = 0; i < K * 3; i++) { + kernel_offset_local[i] = kernel_offset[i]; + } + __syncthreads(); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int k = idx % K; + int i = idx / K; + int cur_coord[4]; + if (i < N) { + data += i * 4; + for (int j = 0; j < 3; j++) { + cur_coord[j] = data[j] + kernel_offset[k * 3 + j]; + } + cur_coord[3] = data[3]; + unsigned long long hash = 14695981039346656037UL; + for (int j = 0; j < 4; j++) { + hash ^= (unsigned int)cur_coord[j]; + hash *= 1099511628211UL; + } + hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); + out[k * N + i] = hash; + } +} + +void kernel_hash_wrapper(int N, int K, const int *data, + const int *kernel_offset, long int *out) { + kernel_hash_kernel<<>>( + N, K, data, kernel_offset, out); +} + +void hash_wrapper(int N, const int *data, long int *out) { + hash_kernel<<>>(N, data, out); +} + +at::Tensor hash_cuda(const at::Tensor idx) { + int N = idx.size(0); + at::Tensor out = + torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); + hash_wrapper(N, idx.data_ptr(), out.data_ptr()); + return out; +} + +at::Tensor kernel_hash_cuda(const at::Tensor idx, + const at::Tensor kernel_offset) { + int N = idx.size(0); + int K = kernel_offset.size(0); + at::Tensor out = torch::zeros( + {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); + kernel_hash_wrapper(N, K, idx.data_ptr(), kernel_offset.data_ptr(), + out.data_ptr()); + return out; +} diff --git a/torchsparse/backend/hash/hash_cuda.h b/torchsparse/backend/hash/hash_cuda.h new file mode 100644 index 0000000..83b5807 --- /dev/null +++ b/torchsparse/backend/hash/hash_cuda.h @@ -0,0 +1,11 @@ +#ifndef TORCHSPARSE_HASH_CUDA +#define TORCHSPARSE_HASH_CUDA + +#include + +at::Tensor hash_cuda(const at::Tensor idx); + +at::Tensor kernel_hash_cuda(const at::Tensor idx, + const at::Tensor kernel_offset); + +#endif diff --git a/torchsparse/backend/hashmap/hashmap_cpu.cpp b/torchsparse/backend/hashmap/hashmap_cpu.cpp new file mode 100644 index 0000000..1cdce23 --- /dev/null +++ b/torchsparse/backend/hashmap/hashmap_cpu.cpp @@ -0,0 +1,28 @@ +#include "hashmap_cpu.hpp" + +#include +#include +#include +#include + +void HashTableCPU::lookup_vals(const int64_t* const keys, + int64_t* const results, const int n) { +#pragma omp parallel for + for (int idx = 0; idx < n; idx++) { + int64_t key = keys[idx]; + google::dense_hash_map::iterator iter = hashmap.find(key); + if (iter != hashmap.end()) { + results[idx] = iter->second; + } else { + results[idx] = 0; + } + } +} + +void HashTableCPU::insert_vals(const int64_t* const keys, + const int64_t* const vals, const int n) { + for (int i = 0; i < 10; i++) { + printf("%d, %d, %d, %d\n", i, i < n, n, i < 10); + // hashmap[(int)keys[idx]] = (int)vals[idx]+1; + } +} diff --git a/torchsparse/backend/hashmap/hashmap_cpu.hpp b/torchsparse/backend/hashmap/hashmap_cpu.hpp new file mode 100644 index 0000000..d1edeff --- /dev/null +++ b/torchsparse/backend/hashmap/hashmap_cpu.hpp @@ -0,0 +1,27 @@ +#ifndef _CUCKOO_MULTI_CPU_HPP_ +#define _CUCKOO_MULTI_CPU_HPP_ + +#include +#include +#include +#include +#include +#include + +class HashTableCPU { + private: + google::dense_hash_map hashmap; + + public: + HashTableCPU() {} + + ~HashTableCPU() {} + + void insert_vals(const int64_t* const keys, const int64_t* const vals, + const int n); + + void lookup_vals(const int64_t* const keys, int64_t* const results, + const int n); +}; + +#endif diff --git a/torchsparse/backend/hashmap/hashmap_cuda.cu b/torchsparse/backend/hashmap/hashmap_cuda.cu new file mode 100644 index 0000000..5b6db6a --- /dev/null +++ b/torchsparse/backend/hashmap/hashmap_cuda.cu @@ -0,0 +1,214 @@ +#include +#include +#include + +#include "hashmap_cuda.cuh" + +typedef unsigned long long int VTYPE; + +__global__ void cuckooBucketKernel_Multi(VTYPE *const key_buf, + VTYPE *const val_buf, const int size, + const VTYPE *const keys, + const VTYPE *const vals, const int n, + int *const counters, + const int num_buckets) { + // Get thread index. + int idx = threadIdx.x + blockIdx.x * blockDim.x; + + // Only threads within range are active. + if (idx < n) { + // Do 1st-level hashing to get bucket id, then do atomic add to get index + // inside the bucket. + VTYPE key = keys[idx]; + VTYPE val = vals[idx]; + + int bucket_num = do_1st_hash(key, num_buckets); + int bucket_ofs = atomicAdd(&counters[bucket_num], 1); + + // Directly write the key into the table buffer. + if (bucket_ofs >= BUCKET_SIZE) { + printf("%d/%d ERROR: bucket overflow! (n=%d, bucket_num=%d/%d, key=%d)\n", + bucket_ofs, BUCKET_SIZE, n, bucket_num, num_buckets, key); + } else { + key_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = key; + val_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = val; + } + } +} + +__global__ void cuckooInsertKernel_Multi( + VTYPE *const key, VTYPE *const val, const VTYPE *const key_buf, + const VTYPE *const val_buf, const int size, + const FuncConfig *const hash_func_configs, const int num_funcs, + const int *const counters, const int num_buckets, const int evict_bound, + const int pos_width, int *const rehash_requests) { + // Create local cuckoo table in shared memory. Size passed in as the third + // kernel parameter. + extern __shared__ VTYPE local_key[]; + for (int i = 0; i < num_funcs; ++i) { + local_key[i * BUCKET_SIZE + threadIdx.x] = EMPTY_CELL; + } + + // might be useful + __syncthreads(); + + // Get thread index. + int idx = threadIdx.x + blockIdx.x * blockDim.x; + VTYPE cur_idx = idx; + + // Only threads within local bucket range are active. + if (threadIdx.x < counters[blockIdx.x]) { + // Set initial conditions. + VTYPE cur_key = key_buf[cur_idx]; + int cur_func = 0; + int evict_count = 0; + + // Start the test-kick-and-reinsert loops. + do { + int pos = do_2nd_hash(cur_key, hash_func_configs, cur_func, BUCKET_SIZE); + + VTYPE new_data = make_data(cur_idx + 1, cur_func, pos_width); + + VTYPE old_idx = + atomicExch(&local_key[cur_func * BUCKET_SIZE + pos], new_data); + + if (old_idx != EMPTY_CELL) { + cur_idx = fetch_val(old_idx, pos_width) - 1; + // potential overflow here. It seems that cur_idx < 0 is possible! + cur_key = key_buf[cur_idx]; + cur_func = (fetch_func(old_idx, pos_width) + 1) % num_funcs; + evict_count++; + } else { + break; + } + + } while (evict_count < num_funcs * evict_bound); + + // If exceeds eviction bound, then needs rehashing. + if (evict_count >= num_funcs * evict_bound) { + atomicAdd(rehash_requests, 1); + } + } + + // Every thread write its responsible local slot into the global data table. + __syncthreads(); + for (int i = 0; i < num_funcs; ++i) { + VTYPE cur_idx = local_key[i * BUCKET_SIZE + threadIdx.x]; + if (cur_idx == EMPTY_CELL) { + continue; + } + int cur_func = fetch_func(cur_idx, pos_width); + cur_idx = fetch_val(cur_idx, pos_width) - 1; + key[i * size + idx] = key_buf[cur_idx]; + val[i * size + idx] = val_buf[cur_idx]; + } +} + +__global__ void cuckooLookupKernel_Multi( + const VTYPE *const keys, VTYPE *const results, const int n, + const VTYPE *const all_keys, const VTYPE *const all_vals, const int size, + const FuncConfig *const hash_func_configs, const int num_funcs, + const int num_buckets, const int pos_width) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + + // Only threads within range are active. + if (idx < n) { + VTYPE key = keys[idx]; + int bucket_num = do_1st_hash(key, num_buckets); + for (int i = 0; i < num_funcs; ++i) { + int pos = bucket_num * BUCKET_SIZE + + do_2nd_hash(key, hash_func_configs, i, BUCKET_SIZE); + if (all_keys[i * size + pos] == key) { + results[idx] = all_vals[i * size + pos] + 1; + return; + } + } + + // TODO(Haotian): should be a value that will not be encountered. + results[idx] = EMPTY_CELL; + } +} + +void CuckooHashTableCuda_Multi::lookup_vals(const VTYPE *const keys, + VTYPE *d_key, VTYPE *d_val, + VTYPE *const results, const int n) { + // Launch the lookup kernel. + cuckooLookupKernel_Multi<<>>( + keys, results, n, d_key, d_val, _size, _d_hash_func_configs, _num_funcs, + _num_buckets, _pos_width); +} + +int CuckooHashTableCuda_Multi::insert_vals(const VTYPE *const keys, + const VTYPE *const vals, + VTYPE *d_key_buf, VTYPE *d_val_buf, + VTYPE *d_key, VTYPE *d_val, + const int n) { + // + // Phase 1: Distribute keys into buckets. + // + + // Allocate GPU memory. + + int *d_counters = NULL; + + cudaMalloc((void **)&d_counters, _num_buckets * sizeof(int)); + + cudaMemset(d_counters, 0, _num_buckets * sizeof(int)); + + // Invoke bucket kernel. + cuckooBucketKernel_Multi<<>>( + d_key_buf, d_val_buf, _size, keys, vals, n, d_counters, _num_buckets); + + // + // Phase 2: Local cuckoo hashing. + // + + // Allocate GPU memory. + + cudaDeviceSynchronize(); + int *d_rehash_requests = NULL; + + cudaMalloc((void **)&d_rehash_requests, sizeof(int)); + + // Copy values onto GPU memory. + cudaMemcpy(_d_hash_func_configs, _hash_func_configs, + _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice); + + // Invoke insert kernel. Passes shared memory table size by the third + // argument. Loops until no rehashing needed. + + int rehash_count = 0; + do { + int rehash_requests = 0; + cudaMemset(d_rehash_requests, 0, sizeof(int)); + cuckooInsertKernel_Multi<<>>( + d_key, d_val, d_key_buf, d_val_buf, _size, _d_hash_func_configs, + _num_funcs, d_counters, _num_buckets, _evict_bound, _pos_width, + d_rehash_requests); + cudaMemcpy(&rehash_requests, d_rehash_requests, sizeof(int), + cudaMemcpyDeviceToHost); + + if (rehash_requests == 0) { + break; + } else { + rehash_count++; + gen_hash_funcs(); + cudaMemcpy(_d_hash_func_configs, _hash_func_configs, + _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice); + } + } while (rehash_count < MAX_DEPTH); + + cudaDeviceSynchronize(); + + // Free GPU resources. + + if (d_counters != NULL) { + cudaFree(d_counters); + } + if (d_rehash_requests != NULL) { + cudaFree(d_rehash_requests); + } + + return (rehash_count < MAX_DEPTH) ? rehash_count : ERR_DEPTH; +} diff --git a/torchsparse/backend/hashmap/hashmap_cuda.cuh b/torchsparse/backend/hashmap/hashmap_cuda.cuh new file mode 100644 index 0000000..cdaca02 --- /dev/null +++ b/torchsparse/backend/hashmap/hashmap_cuda.cuh @@ -0,0 +1,146 @@ +#ifndef _CUCKOO_CUDA_MULTI_HPP_ +#define _CUCKOO_CUDA_MULTI_HPP_ + +#include +#include +#include +#include +#include + +#include "cuda_runtime.h" + +/** Reserved value for indicating "empty". */ +#define EMPTY_CELL (0) +/** Max rehashing depth, and error depth. */ +#define MAX_DEPTH (100) +#define ERR_DEPTH (-1) +/** CUDA naive thread block size. */ +#define BLOCK_SIZE (256) +/** CUDA multi-level thread block size = bucket size. */ +#define BUCKET_SIZE (512) + +typedef unsigned long long int VTYPE; + +/** Struct of a hash function config. */ +typedef struct { + int rv; // Randomized XOR value. + int ss; // Randomized shift filter start position. +} FuncConfig; + +/** Hard code hash functions and all inline helper functions for CUDA kernels' + * use. */ +inline __device__ int do_1st_hash(const VTYPE val, const int num_buckets) { + return val % num_buckets; +} + +inline __device__ int do_2nd_hash(const VTYPE val, + const FuncConfig *const hash_func_configs, + const int func_idx, const int size) { + FuncConfig fc = hash_func_configs[func_idx]; + return ((val ^ fc.rv) >> fc.ss) % size; // XOR function as 2nd-level hashing. +} + +// trying to ignore EMPTY_CELL by adding 1 at make_data. +inline __device__ VTYPE fetch_val(const VTYPE data, const int pos_width) { + return data >> pos_width; +} + +inline __device__ int fetch_func(const VTYPE data, const int pos_width) { + return data & ((0x1 << pos_width) - 1); +} + +inline __device__ VTYPE make_data(const VTYPE val, const int func, + const int pos_width) { + return (val << pos_width) ^ func; +} + +class CuckooHashTableCuda_Multi { + private: + const int _size; + const int _evict_bound; + const int _num_funcs; + const int _pos_width; + const int _num_buckets; + + FuncConfig *_d_hash_func_configs; + + /** Cuckoo hash function set. */ + FuncConfig *_hash_func_configs; + + /** Private operations. */ + void gen_hash_funcs() { + // Calculate bit width of value range and table size. + int val_width = 8 * sizeof(VTYPE) - ceil(log2((double)_num_funcs)); + int bucket_width = ceil(log2((double)_num_buckets)); + int size_width = ceil(log2((double)BUCKET_SIZE)); + // Generate randomized configurations. + for (int i = 0; i < _num_funcs; ++i) { // At index 0 is a dummy function. + if (val_width - bucket_width <= size_width) + _hash_func_configs[i] = {rand(), 0}; + else { + _hash_func_configs[i] = { + rand(), rand() % (val_width - bucket_width - size_width + 1) + + bucket_width}; + } + } + }; + + inline VTYPE fetch_val(const VTYPE data) { return data >> _pos_width; } + inline int fetch_func(const VTYPE data) { + return data & ((0x1 << _pos_width) - 1); + } + + public: + CuckooHashTableCuda_Multi(const int size, const int evict_bound, + const int num_funcs) + : _size(size), + _evict_bound(evict_bound), + _num_funcs(num_funcs), + _pos_width(ceil(log2((double)_num_funcs))), + _num_buckets(ceil((double)_size / BUCKET_SIZE)) { + srand(time(NULL)); + _d_hash_func_configs = NULL; + _hash_func_configs = NULL; + _hash_func_configs = new FuncConfig[num_funcs]; + + gen_hash_funcs(); + + cudaMalloc((void **)&_d_hash_func_configs, _num_funcs * sizeof(FuncConfig)); + cudaMemcpy(_d_hash_func_configs, _hash_func_configs, + _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice); + }; + ~CuckooHashTableCuda_Multi() { + if (_hash_func_configs != NULL) delete[] _hash_func_configs; + + if (_d_hash_func_configs != NULL) cudaFree(_d_hash_func_configs); + }; + + int insert_vals(const VTYPE *const keys, const VTYPE *const vals, + VTYPE *d_key_buf, VTYPE *d_val_buf, VTYPE *d_key, + VTYPE *d_val, const int n); + + void lookup_vals(const VTYPE *const keys, VTYPE *const results, VTYPE *d_key, + VTYPE *d_val, const int n); +}; + +__global__ void cuckooBucketKernel_Multi(VTYPE *const key_buf, + VTYPE *const val_buf, const int size, + const VTYPE *const keys, + const VTYPE *const vals, const int n, + int *const counters, + const int num_buckets); + +__global__ void cuckooInsertKernel_Multi( + VTYPE *const key, VTYPE *const val, const VTYPE *const key_buf, + const VTYPE *const val_buf, const int size, + const FuncConfig *const hash_func_configs, const int num_funcs, + const int *const counters, const int num_buckets, const int evict_bound, + const int pos_width, int *const rehash_requests); + +__global__ void cuckooLookupKernel_Multi( + const VTYPE *const keys, VTYPE *const results, const int n, + const VTYPE *const all_keys, const VTYPE *const all_vals, const int size, + const FuncConfig *const hash_func_configs, const int num_funcs, + const int num_buckets, const int pos_width); + +#endif diff --git a/torchsparse/backend/others/count_cpu.cpp b/torchsparse/backend/others/count_cpu.cpp new file mode 100644 index 0000000..ba0611c --- /dev/null +++ b/torchsparse/backend/others/count_cpu.cpp @@ -0,0 +1,23 @@ +#include "count_cpu.h" + +#include + +#include + +at::Tensor count_cpu(const at::Tensor idx, const int s) { + int N = idx.size(0); + at::Tensor out = + torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); + int *idx_ = idx.data_ptr(); + int *out_ = out.data_ptr(); +#pragma omp parallel for + for (int i = 0; i < N; i++) { + int cur_idx = idx_[i]; + if (cur_idx < 0) { + continue; + } +#pragma omp atomic + out_[cur_idx]++; + } + return out; +} diff --git a/torchsparse/backend/others/count_cpu.h b/torchsparse/backend/others/count_cpu.h new file mode 100644 index 0000000..f2a0ab3 --- /dev/null +++ b/torchsparse/backend/others/count_cpu.h @@ -0,0 +1,8 @@ +#ifndef _SPARSE_COUNT_CPU +#define _SPARSE_COUNT_CPU + +#include + +at::Tensor count_cpu(const at::Tensor idx, const int s); + +#endif diff --git a/torchsparse/backend/others/count_cuda.cu b/torchsparse/backend/others/count_cuda.cu new file mode 100644 index 0000000..4860422 --- /dev/null +++ b/torchsparse/backend/others/count_cuda.cu @@ -0,0 +1,31 @@ +#include +#include +#include + +#include +#include + +// counting +// input N*3 int32 tensor output N*1 int64 tensor +__global__ void count_kernel(int N, const int *__restrict__ data, + int *__restrict__ out) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < N && data[i] >= 0) { + atomicAdd(&out[data[i]], 1); + } +} + +void count_wrapper(int N, const int *data, int *out) { + count_kernel<<>>(N, data, out); +} + +// make sure indices is int type +// feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n) +// (preprocessed indices) +at::Tensor count_cuda(const at::Tensor idx, const int s) { + int N = idx.size(0); + at::Tensor out = + torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); + count_wrapper(N, idx.data_ptr(), out.data_ptr()); + return out; +} diff --git a/torchsparse/backend/others/count_cuda.h b/torchsparse/backend/others/count_cuda.h new file mode 100644 index 0000000..2bb64f6 --- /dev/null +++ b/torchsparse/backend/others/count_cuda.h @@ -0,0 +1,8 @@ +#ifndef _SPARSE_COUNT +#define _SPARSE_COUNT + +#include + +at::Tensor count_cuda(const at::Tensor idx, const int s); + +#endif diff --git a/torchsparse/src/others/query_cpu.cpp b/torchsparse/backend/others/query_cpu.cpp similarity index 50% rename from torchsparse/src/others/query_cpu.cpp rename to torchsparse/backend/others/query_cpu.cpp index cfba240..915aa83 100644 --- a/torchsparse/src/others/query_cpu.cpp +++ b/torchsparse/backend/others/query_cpu.cpp @@ -1,41 +1,34 @@ +#include "query_cpu.h" + #include -#include "../hashmap/hashmap_cpu_header.hpp" -#include + #include -#include -#include "query_cpu_header.h" #include +#include +#include + +#include "../hashmap/hashmap_cpu.hpp" -at::Tensor cpu_query_forward( - const at::Tensor hash_query, - const at::Tensor hash_target, - const at::Tensor idx_target) -{ - //return group_point_forward_gpu(points, indices); +at::Tensor hash_query_cpu(const at::Tensor hash_query, + const at::Tensor hash_target, + const at::Tensor idx_target) { int n = hash_target.size(0); int n1 = hash_query.size(0); google::dense_hash_map hashmap; hashmap.set_empty_key(0); - /* - HashTableCPU in_hash_table; - printf("inserting %d %d...\n", n, n1); - in_hash_table.insert_vals(hash_target.data_ptr(), idx_target.data_ptr(), n); - */ - at::Tensor out = torch::zeros({n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - for (int idx = 0; idx < n; idx++) - { + at::Tensor out = torch::zeros( + {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); + for (int idx = 0; idx < n; idx++) { int64_t key = *(hash_target.data_ptr() + idx); int64_t val = *(idx_target.data_ptr() + idx) + 1; hashmap.insert(std::make_pair(key, val)); } #pragma omp parallel for - for (int idx = 0; idx < n1; idx++) - { + for (int idx = 0; idx < n1; idx++) { int64_t key = *(hash_query.data_ptr() + idx); google::dense_hash_map::iterator iter = hashmap.find(key); - if (iter != hashmap.end()) - { + if (iter != hashmap.end()) { *(out.data_ptr() + idx) = iter->second; } } diff --git a/torchsparse/backend/others/query_cpu.h b/torchsparse/backend/others/query_cpu.h new file mode 100644 index 0000000..a63f219 --- /dev/null +++ b/torchsparse/backend/others/query_cpu.h @@ -0,0 +1,9 @@ +#ifndef _SPARSE_QUERY_CPU +#define _SPARSE_QUERY_CPU + +#include + +at::Tensor hash_query_cpu(const at::Tensor hash_query, + const at::Tensor hash_target, + const at::Tensor idx_target); +#endif diff --git a/torchsparse/backend/others/query_cuda.cu b/torchsparse/backend/others/query_cuda.cu new file mode 100644 index 0000000..0209c00 --- /dev/null +++ b/torchsparse/backend/others/query_cuda.cu @@ -0,0 +1,58 @@ +#include + +#include +#include +#include + +#include "../hashmap/hashmap_cuda.cuh" + +at::Tensor hash_query_cuda(const at::Tensor hash_query, + const at::Tensor hash_target, + const at::Tensor idx_target) { + // return group_point_forward_gpu(points, indices); + int n = hash_target.size(0); + int n1 = hash_query.size(0); + const int nextPow2 = pow(2, ceil(log2((double)n))); + // When n is large, the hash values tend to be more evenly distrubuted and + // choosing table_size to be 2 * nextPow2 typically suffices. For smaller n, + // the effect of uneven distribution of hash values is more pronounced and + // hence we choose table_size to be 4 * nextPow2 to reduce the chance of + // bucket overflow. + int table_size = (n < 2048) ? 4 * nextPow2 : 2 * nextPow2; + if (table_size < 512) { + table_size = 512; + } + int num_funcs = 3; + CuckooHashTableCuda_Multi in_hash_table(table_size, 8 * ceil(log2((double)n)), + num_funcs); + at::Tensor key_buf = + torch::zeros({table_size}, + at::device(hash_query.device()).dtype(at::ScalarType::Long)); + at::Tensor val_buf = + torch::zeros({table_size}, + at::device(hash_query.device()).dtype(at::ScalarType::Long)); + at::Tensor key = + torch::zeros({num_funcs * table_size}, + at::device(hash_query.device()).dtype(at::ScalarType::Long)); + at::Tensor val = + torch::zeros({num_funcs * table_size}, + at::device(hash_query.device()).dtype(at::ScalarType::Long)); + + in_hash_table.insert_vals( + (unsigned long long int *)(hash_target.data_ptr()), + (unsigned long long int *)(idx_target.data_ptr()), + (unsigned long long int *)(key_buf.data_ptr()), + (unsigned long long int *)(val_buf.data_ptr()), + (unsigned long long int *)(key.data_ptr()), + (unsigned long long int *)(val.data_ptr()), n); + + at::Tensor out = torch::zeros( + {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); + + in_hash_table.lookup_vals( + (unsigned long long int *)(hash_query.data_ptr()), + (unsigned long long int *)(key.data_ptr()), + (unsigned long long int *)(val.data_ptr()), + (unsigned long long int *)(out.data_ptr()), n1); + return out; +} diff --git a/torchsparse/backend/others/query_cuda.h b/torchsparse/backend/others/query_cuda.h new file mode 100644 index 0000000..175c527 --- /dev/null +++ b/torchsparse/backend/others/query_cuda.h @@ -0,0 +1,9 @@ +#ifndef _SPARSE_QUERY +#define _SPARSE_QUERY + +#include + +at::Tensor hash_query_cuda(const at::Tensor hash_query, + const at::Tensor hash_target, + const at::Tensor idx_target); +#endif diff --git a/torchsparse/backend/pybind_cpu.cpp b/torchsparse/backend/pybind_cpu.cpp new file mode 100644 index 0000000..d7ab41c --- /dev/null +++ b/torchsparse/backend/pybind_cpu.cpp @@ -0,0 +1,23 @@ +#include +#include +#include + +#include "convolution/convolution_cpu.h" +#include "devoxelize/devoxelize_cpu.h" +#include "hash/hash_cpu.h" +#include "others/count_cpu.h" +#include "others/query_cpu.h" +#include "voxelize/voxelize_cpu.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("convolution_forward_cpu", &convolution_forward_cpu); + m.def("convolution_backward_cpu", &convolution_backward_cpu); + m.def("voxelize_forward_cpu", &voxelize_forward_cpu); + m.def("voxelize_backward_cpu", &voxelize_backward_cpu); + m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu); + m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu); + m.def("hash_cpu", &hash_cpu); + m.def("kernel_hash_cpu", &kernel_hash_cpu); + m.def("hash_query_cpu", &hash_query_cpu); + m.def("count_cpu", &count_cpu); +} diff --git a/torchsparse/backend/pybind_cuda.cpp b/torchsparse/backend/pybind_cuda.cpp new file mode 100644 index 0000000..be0e78d --- /dev/null +++ b/torchsparse/backend/pybind_cuda.cpp @@ -0,0 +1,39 @@ +#include +#include +#include + +#include "convolution/convolution_cpu.h" +#include "convolution/convolution_cuda.h" +#include "devoxelize/devoxelize_cpu.h" +#include "devoxelize/devoxelize_cuda.h" +#include "hash/hash_cpu.h" +#include "hash/hash_cuda.h" +#include "others/count_cpu.h" +#include "others/count_cuda.h" +#include "others/query_cpu.h" +#include "others/query_cuda.h" +#include "voxelize/voxelize_cpu.h" +#include "voxelize/voxelize_cuda.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("convolution_forward_cpu", &convolution_forward_cpu); + m.def("convolution_forward_cuda", &convolution_forward_cuda); + m.def("convolution_backward_cpu", &convolution_backward_cpu); + m.def("convolution_backward_cuda", &convolution_backward_cuda); + m.def("voxelize_forward_cpu", &voxelize_forward_cpu); + m.def("voxelize_forward_cuda", &voxelize_forward_cuda); + m.def("voxelize_backward_cpu", &voxelize_backward_cpu); + m.def("voxelize_backward_cuda", &voxelize_backward_cuda); + m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu); + m.def("devoxelize_forward_cuda", &devoxelize_forward_cuda); + m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu); + m.def("devoxelize_backward_cuda", &devoxelize_backward_cuda); + m.def("hash_cpu", &hash_cpu); + m.def("hash_cuda", &hash_cuda); + m.def("kernel_hash_cpu", &kernel_hash_cpu); + m.def("kernel_hash_cuda", &kernel_hash_cuda); + m.def("hash_query_cpu", &hash_query_cpu); + m.def("hash_query_cuda", &hash_query_cuda); + m.def("count_cpu", &count_cpu); + m.def("count_cuda", &count_cuda); +} diff --git a/torchsparse/backend/voxelize/voxelize_cpu.cpp b/torchsparse/backend/voxelize/voxelize_cpu.cpp new file mode 100644 index 0000000..938b7bc --- /dev/null +++ b/torchsparse/backend/voxelize/voxelize_cpu.cpp @@ -0,0 +1,43 @@ +#include "voxelize_cpu.h" + +#include + +#include + +at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx, + const at::Tensor counts) { + int N = inputs.size(0); + int c = inputs.size(1); + int N1 = counts.size(0); + at::Tensor out = torch::zeros( + {N1, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); + for (int i = 0; i < N; i++) { + int pos = *(idx.data_ptr() + i); + if (*(counts.data_ptr() + pos) == 0) continue; +#pragma omp parallel for + for (int j = 0; j < c; j++) { + *(out.data_ptr() + pos * c + j) += + *(inputs.data_ptr() + i * c + j) / + (float)(*(counts.data_ptr() + pos)); + } + } + return out; +} + +at::Tensor voxelize_backward_cpu(const at::Tensor top_grad, + const at::Tensor idx, const at::Tensor counts, + const int N) { + int c = top_grad.size(1); + at::Tensor bottom_grad = torch::zeros( + {N, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); + for (int i = 0; i < N; i++) { + if (*(counts.data_ptr() + *(idx.data_ptr() + i)) == 0) continue; +#pragma omp parallel for + for (int j = 0; j < c; j++) { + *(bottom_grad.data_ptr() + i * c + j) = + *(top_grad.data_ptr() + *(idx.data_ptr() + i) * c + j) / + (float)(*(counts.data_ptr() + *(idx.data_ptr() + i))); + } + } + return bottom_grad; +} diff --git a/torchsparse/backend/voxelize/voxelize_cpu.h b/torchsparse/backend/voxelize/voxelize_cpu.h new file mode 100644 index 0000000..bed480e --- /dev/null +++ b/torchsparse/backend/voxelize/voxelize_cpu.h @@ -0,0 +1,13 @@ +#ifndef TORCHSPARSE_VOXELIZE_CPU +#define TORCHSPARSE_VOXELIZE_CPU + +#include + +at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx, + const at::Tensor counts); + +at::Tensor voxelize_backward_cpu(const at::Tensor top_grad, + const at::Tensor idx, const at::Tensor counts, + const int N); + +#endif diff --git a/torchsparse/backend/voxelize/voxelize_cuda.cu b/torchsparse/backend/voxelize/voxelize_cuda.cu new file mode 100644 index 0000000..a47f605 --- /dev/null +++ b/torchsparse/backend/voxelize/voxelize_cuda.cu @@ -0,0 +1,80 @@ +#include +#include +#include + +#include +#include + +// hashing +// input N*F float tensor, pointer to output N'*F int64 tensor, N*1 count +// tensor, N*1 index tensor +template +__global__ void voxelize_forward_kernel(int N, int c, int s, + const scalar_t *__restrict__ data, + const int *__restrict__ idx, + const int *__restrict__ counts, + scalar_t *__restrict__ out) { + int index = blockDim.x * blockIdx.x + threadIdx.x; + int i = index / c; + int j = index % c; + if (i < N) { + int pos = idx[i]; + if (pos < 0 || pos >= s || counts[pos] == 0) return; + atomicAdd(&out[pos * c + j], data[i * c + j] / float(counts[pos])); + } +} + +template +__global__ void voxelize_backward_kernel(int N, int c, int s, + const scalar_t *__restrict__ top_grad, + const int *__restrict__ idx, + const int *__restrict__ counts, + scalar_t *__restrict__ bottom_grad) { + int index = blockDim.x * blockIdx.x + threadIdx.x; + int i = index / c; + int j = index % c; + if (i < N) { + int pos = idx[i]; + if (pos < 0 || pos >= s || counts[pos] == 0) return; + atomicAdd(&bottom_grad[i * c + j], + top_grad[pos * c + j] / float(counts[pos])); + } +} + +at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx, + const at::Tensor counts) { + int N = inputs.size(0); + int c = inputs.size(1); + int N1 = counts.size(0); + + at::Tensor out = + torch::zeros({N1, c}, at::device(idx.device()).dtype(inputs.dtype())); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.type(), "voxelize_forward_cuda", ([&] { + voxelize_forward_kernel<<>>( + N, c, N1, inputs.data_ptr(), idx.data_ptr(), + counts.data_ptr(), out.data_ptr()); + })); + + return out; +} + +at::Tensor voxelize_backward_cuda(const at::Tensor top_grad, + const at::Tensor idx, const at::Tensor counts, + const int N) { + int c = top_grad.size(1); + int N1 = counts.size(0); + + at::Tensor bottom_grad = + torch::zeros({N, c}, at::device(idx.device()).dtype(top_grad.dtype())); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.type(), "voxelize_backward_cuda", ([&] { + voxelize_backward_kernel<<>>( + N, c, N1, top_grad.data_ptr(), idx.data_ptr(), + counts.data_ptr(), bottom_grad.data_ptr()); + })); + + return bottom_grad; +} diff --git a/torchsparse/backend/voxelize/voxelize_cuda.h b/torchsparse/backend/voxelize/voxelize_cuda.h new file mode 100644 index 0000000..4b4cb0e --- /dev/null +++ b/torchsparse/backend/voxelize/voxelize_cuda.h @@ -0,0 +1,13 @@ +#ifndef TORCHSPARSE_VOXELIZE_CUDA +#define TORCHSPARSE_VOXELIZE_CUDA + +#include + +at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx, + const at::Tensor counts); + +at::Tensor voxelize_backward_cuda(const at::Tensor top_grad, + const at::Tensor idx, const at::Tensor counts, + const int N); + +#endif diff --git a/torchsparse/nn/functional/__init__.py b/torchsparse/nn/functional/__init__.py index 94e6d43..5c0121d 100644 --- a/torchsparse/nn/functional/__init__.py +++ b/torchsparse/nn/functional/__init__.py @@ -1,9 +1,8 @@ from .activation import * from .conv import * -from .squeeze_nmap import * from .count import * from .crop import * -from .devox import * +from .devoxelize import * from .downsample import * from .hash import * from .pooling import * diff --git a/torchsparse/nn/functional/activation.py b/torchsparse/nn/functional/activation.py index 5f6aa64..b0cab37 100644 --- a/torchsparse/nn/functional/activation.py +++ b/torchsparse/nn/functional/activation.py @@ -1,29 +1,19 @@ -import functools - from torch.nn import functional as F -from torchsparse.sparse_tensor import SparseTensor - -__all__ = ['spact', 'sprelu', 'spleaky_relu'] +from torchsparse import SparseTensor +from torchsparse.nn.utils import fapply -def spact(inputs, act_funct=F.relu): - feats = inputs.F - coords = inputs.C - stride = inputs.s - output_features = act_funct(feats) - outputs = SparseTensor(output_features, coords, stride) - outputs.coord_maps = inputs.coord_maps - outputs.kernel_maps = inputs.kernel_maps - return outputs +__all__ = ['relu', 'leaky_relu'] -def sprelu(inputs, inplace=True): - return spact(inputs, functools.partial(F.relu, inplace=inplace)) +def relu(input: SparseTensor, inplace: bool = True) -> SparseTensor: + return fapply(input, F.relu, inplace=inplace) -def spleaky_relu(inputs, negative_slope=0.1, inplace=True): - return spact( - inputs, - functools.partial(F.leaky_relu, - inplace=inplace, - negative_slope=negative_slope)) +def leaky_relu(input: SparseTensor, + negative_slope: float = 0.1, + inplace: bool = True) -> SparseTensor: + return fapply(input, + F.leaky_relu, + negative_slope=negative_slope, + inplace=inplace) diff --git a/torchsparse/nn/functional/conv.py b/torchsparse/nn/functional/conv.py index 807efc0..2e9b73d 100644 --- a/torchsparse/nn/functional/conv.py +++ b/torchsparse/nn/functional/conv.py @@ -1,218 +1,147 @@ -import copy +from typing import Optional, Tuple, Union import torch -import torchsparse_backend from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd -from torchsparse import SparseTensor -from torchsparse.nn import functional as spF -from torchsparse.utils.helpers import make_tuple -from torchsparse.utils.kernel import KernelRegion, KernelMapKey +from torch.cuda.amp import custom_bwd, custom_fwd -from typing import Union, List, Tuple, Optional +import torchsparse.backend +from torchsparse import SparseTensor +from torchsparse.nn import functional as F +from torchsparse.nn.utils import get_kernel_offsets +from torchsparse.utils import make_ntuple __all__ = ['conv3d'] -class SpConvolution(Function): +class ConvolutionFunction(Function): + @staticmethod @custom_fwd(cast_inputs=torch.half) def forward(ctx, - features, - kernel, - neighbor_map, - neighbor_offset, - sizes, - transpose=False): - features = features.contiguous() - kernel = kernel.contiguous() - if not transpose: - out = torch.zeros(sizes[1], - kernel.size(-1), - dtype=features.dtype, - device=features.device) + input: torch.Tensor, + weight: torch.Tensor, + nbmaps: torch.Tensor, + nbsizes: torch.Tensor, + sizes: Tuple[int, int], + transposed: bool = False) -> torch.Tensor: + input = input.contiguous() + weight = weight.contiguous() + nbmaps = nbmaps.int().contiguous() + nbsizes = nbsizes.int().contiguous() + + if not transposed: + output = torch.zeros(sizes[1], + weight.size(-1), + dtype=input.dtype, + device=input.device) else: - # tbd: ensure the original, upsampled size to be the same. - out = torch.zeros(sizes[0], - kernel.size(-1), - dtype=features.dtype, - device=features.device) - - if 'cuda' in str(features.device): - torchsparse_backend.sparseconv_forward(features, out, kernel, - neighbor_map, - neighbor_offset, transpose) + # TODO(Haotian): ensure the original, upsampled size to be the same. + output = torch.zeros(sizes[0], + weight.size(-1), + dtype=input.dtype, + device=input.device) + + if input.device.type == 'cuda': + torchsparse.backend.convolution_forward_cuda( + input, output, weight, nbmaps, nbsizes.cpu(), transposed) else: # use the native pytorch XLA APIs for the TPU. cur_st = 0 - for kernel_idx in range(kernel.shape[0]): - cur_ed = cur_st + neighbor_offset[kernel_idx] - in_map = neighbor_map[cur_st:cur_ed, 0].long() - out_map = neighbor_map[cur_st:cur_ed, 1].long() - cur_st += neighbor_offset[kernel_idx] + for kernel_idx in range(weight.shape[0]): + cur_ed = cur_st + nbsizes[kernel_idx] + in_map = nbmaps[cur_st:cur_ed, 0].long() + out_map = nbmaps[cur_st:cur_ed, 1].long() + cur_st += nbsizes[kernel_idx] - if transpose: + if transposed: in_map, out_map = out_map, in_map - # gather - cur_feat = features[in_map] - # gemm - cur_feat = torch.mm(cur_feat, kernel[kernel_idx]) - # scatter - out[out_map] += cur_feat - ctx.for_backwards = (features, kernel, neighbor_map, neighbor_offset, - transpose) - return out + cur_feat = input[in_map] + cur_feat = torch.mm(cur_feat, weight[kernel_idx]) + output[out_map] += cur_feat + + ctx.for_backwards = (input, weight, nbmaps, nbsizes, transposed) + return output @staticmethod @custom_bwd - def backward(ctx, grad_out): - features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards - K, c_in, c_out = kernel.size() - N_in = features.size(0) - grad_features = torch.zeros(N_in, - c_in, - device=features.device, - dtype=features.dtype) - grad_kernel = torch.zeros(K, - c_in, - c_out, - device=kernel.device, - dtype=features.dtype) - - if 'cuda' in str(features.device): - torchsparse_backend.sparseconv_backward(features, grad_features, - grad_out.contiguous(), - kernel, grad_kernel, - neighbor_map, - neighbor_offset, transpose) - else: - raise NotImplementedError - return grad_features, grad_kernel, None, None, None, None + def backward(ctx, grad_output: torch.Tensor): + input, weight, nbmaps, nbsizes, transposed = ctx.for_backwards + grad_input = torch.zeros_like(input) + grad_weight = torch.zeros_like(weight) -sparseconv_op = SpConvolution.apply + if input.device.type == 'cuda': + torchsparse.backend.convolution_backward_cuda( + input, grad_input, grad_output.contiguous(), weight, + grad_weight, nbmaps, nbsizes.cpu(), transposed) + else: + raise NotImplementedError + return grad_input, grad_weight, None, None, None, None -def conv3d(inputs: SparseTensor, - kernel: torch.Tensor, - kernel_size: Union[int, List[int], Tuple[int, int, int]], +def conv3d(input: SparseTensor, + weight: torch.Tensor, + kernel_size: Union[int, Tuple[int, ...]], bias: Optional[torch.Tensor] = None, - stride: Union[int, List[int], Tuple[int, int, int]] = 1, - dilation: Union[int, List[int], Tuple[int, int, int]] = 1, - transpose: bool = False) -> SparseTensor: - features = inputs.F - coords = inputs.C - cur_stride = inputs.s - - # convert to hashable types - kernel_size = make_tuple(kernel_size) - stride = make_tuple(stride) - dilation = make_tuple(dilation) - - if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1, - 1): - output_features = features.matmul(kernel) + stride: Union[int, Tuple[int, ...]] = 1, + dilation: Union[int, Tuple[int, ...]] = 1, + transposed: bool = False) -> SparseTensor: + feats, coords = input.feats, input.coords + + kernel_size = make_ntuple(kernel_size, ndim=3) + stride = make_ntuple(stride, ndim=3) + dilation = make_ntuple(dilation, ndim=3) + + if (kernel_size == (1, 1, 1) and stride == (1, 1, 1) + and dilation == (1, 1, 1)): + feats = feats.matmul(weight) if bias is not None: - output_features += bias - output_tensor = SparseTensor(output_features, coords, cur_stride) - output_tensor.coord_maps = inputs.coord_maps - output_tensor.kernel_maps = inputs.kernel_maps - output_tensor.check() - elif not transpose: - kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride, - dilation) - kernel_map = inputs.kernel_maps.get(kernel_map_key, None) - - if any(x > 1 for x in stride): - # do downsample - kRegion = KernelRegion(kernel_size=kernel_size, - tensor_stride=cur_stride) - kOffset = kRegion.get_kernel_offset().to(features.device) - new_coords = spF.spdownsample(coords, stride, kernel_size, - cur_stride) - hash_query = spF.sphash(new_coords, kOffset) - hash_target = spF.sphash(coords) - idx_query = spF.sphashquery(hash_query, hash_target) - idx_query = list(spF.squeeze_nmap(idx_query)) - idx_query[1] = idx_query[1].to('cpu') - sizes = (features.shape[0], new_coords.shape[0]) - output_features = sparseconv_op(features, kernel, idx_query[0], - idx_query[1], sizes, transpose) - if bias is not None: - output_features += bias - output_tensor = SparseTensor( - output_features, new_coords, - [a * b for a, b in zip(cur_stride, stride)]) - output_tensor.coord_maps = copy.deepcopy(inputs.coord_maps) - output_tensor.check() - - kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride, - dilation) - output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps) - output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes] - - else: - if kernel_map is None: - kRegion = KernelRegion(kernel_size=kernel_size, - tensor_stride=cur_stride) - try: - kOffset = kRegion.get_kernel_offset().to(features.device) - except: - raise - hash_query = spF.sphash(coords, kOffset) - hash_target = spF.sphash(coords) - idx_query = spF.sphashquery(hash_query, hash_target) - idx_query = list(spF.squeeze_nmap(idx_query)) - idx_query[1] = idx_query[1].to('cpu') - sizes = (features.shape[0], features.shape[0]) - output_features = sparseconv_op(features, kernel, idx_query[0], - idx_query[1], sizes, transpose) - if bias is not None: - output_features += bias - output_tensor = SparseTensor(output_features, coords, - cur_stride) - output_tensor.coord_maps = inputs.coord_maps - output_tensor.check() - output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps) - kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride, - dilation) - output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes] - else: - output_features = sparseconv_op(features, kernel, - kernel_map[0], kernel_map[1], - kernel_map[2], transpose) - if bias is not None: - output_features += bias - output_tensor = SparseTensor(output_features, coords, - cur_stride) - output_tensor.coord_maps = inputs.coord_maps - output_tensor.check() - output_tensor.kernel_maps = inputs.kernel_maps - - else: - # do upsample - - original_stride = tuple( - [int(a / b) for a, b in zip(cur_stride, stride)]) - - kernel_map_key = KernelMapKey(kernel_size, original_stride, stride, - dilation) - kernel_map = inputs.kernel_maps.get(kernel_map_key, None) - assert kernel_map is not None, f'{kernel_map_key} does not exist.' - output_features = sparseconv_op(features, kernel, kernel_map[0], - kernel_map[1], kernel_map[2], - transpose) + feats += bias + output = SparseTensor(coords=coords, feats=feats, stride=input.stride) + elif not transposed: + kmap = input.kmaps.get((input.stride, kernel_size, stride, dilation)) + if kmap is None: + offsets = get_kernel_offsets(kernel_size, + stride=input.stride, + device=feats.device) + + references = F.sphash(coords) + if any(s > 1 for s in stride): + coords = F.spdownsample(coords, stride, kernel_size, + input.stride) + queries = F.sphash(coords, offsets) + results = F.sphashquery(queries, references) + + nbsizes = torch.sum(results != -1, dim=1) + nbmaps = torch.nonzero(results != -1) + nbmaps[:, 0] = results.view(-1)[nbmaps[:, 0] * results.size(1) + + nbmaps[:, 1]] + + kmap = [nbmaps, nbsizes, (feats.shape[0], coords.shape[0])] + input.kmaps[(input.stride, kernel_size, stride, dilation)] = kmap + + feats = ConvolutionFunction.apply(feats, weight, kmap[0], kmap[1], + kmap[2], transposed) if bias is not None: - output_features += bias - - cur_coords = inputs.coord_maps.get(original_stride, None) - assert cur_coords is not None, f'{original_stride} not in coord maps.' - - output_tensor = SparseTensor(output_features, cur_coords, - original_stride) - output_tensor.coord_maps = inputs.coord_maps - output_tensor.check() - output_tensor.kernel_maps = inputs.kernel_maps + feats += bias + output = SparseTensor( + coords=coords, + feats=feats, + stride=tuple(input.stride[k] * stride[k] for k in range(3))) + else: + tensor_stride = tuple(input.stride[k] // stride[k] for k in range(3)) + kmap = input.kmaps[(tensor_stride, kernel_size, stride, dilation)] - return output_tensor + feats = ConvolutionFunction.apply(feats, weight, kmap[0], kmap[1], + kmap[2], transposed) + if bias is not None: + feats += bias + output = SparseTensor(coords=input.cmaps[tensor_stride], + feats=feats, + stride=tensor_stride) + + output.cmaps = input.cmaps + output.cmaps.setdefault(output.stride, output.coords) + output.kmaps = input.kmaps + return output diff --git a/torchsparse/nn/functional/count.py b/torchsparse/nn/functional/count.py index 434e078..4c0fb6b 100644 --- a/torchsparse/nn/functional/count.py +++ b/torchsparse/nn/functional/count.py @@ -1,18 +1,16 @@ -import torchsparse_backend -from torch.autograd import Function - -__all__ = ['spcount'] +import torch +import torchsparse.backend -class CountGPU(Function): - @staticmethod - def forward(ctx, idx, num): - if 'cuda' in str(idx.device): - outs = torchsparse_backend.count_forward(idx.contiguous(), num) - else: - outs = torchsparse_backend.cpu_count_forward(idx.contiguous(), num) - return outs +__all__ = ['spcount'] -def spcount(idx, num): - return CountGPU.apply(idx, num) +def spcount(coords: torch.Tensor, num: torch.Tensor) -> torch.Tensor: + coords = coords.contiguous() + if coords.device.type == 'cuda': + return torchsparse.backend.count_cuda(coords, num) + elif coords.device.type == 'cpu': + return torchsparse.backend.count_cpu(coords, num) + else: + device = coords.device + return torchsparse.backend.count_cpu(coords.cpu(), num).to(device) diff --git a/torchsparse/nn/functional/crop.py b/torchsparse/nn/functional/crop.py index feebb0f..819e88f 100644 --- a/torchsparse/nn/functional/crop.py +++ b/torchsparse/nn/functional/crop.py @@ -1,4 +1,4 @@ -from torchsparse.sparse_tensor import SparseTensor +from torchsparse import SparseTensor __all__ = ['spcrop'] diff --git a/torchsparse/nn/functional/devox.py b/torchsparse/nn/functional/devox.py deleted file mode 100644 index 6c364f8..0000000 --- a/torchsparse/nn/functional/devox.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torchsparse_backend -from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd - -__all__ = ['spdevoxelize', 'calc_ti_weights'] - - -def calc_ti_weights(pc, idx_query, scale=1.0): - # TBD: normalize the weights to a probability distribution. Note that some indices are "-1". - with torch.no_grad(): - # don't want points to lie exactly on grid - pc_grid = pc - # don't use np.floor then convert to torch. numerical errors. - if scale != 1.: - pc_floor = torch.floor(pc / scale) * scale - else: - pc_floor = torch.floor(pc) - pc_ceil = pc_floor + scale - pc_gridx = pc_grid[:, 0].view(-1, 1) - pc_gridy = pc_grid[:, 1].view(-1, 1) - pc_gridz = pc_grid[:, 2].view(-1, 1) - pc_floorx = pc_floor[:, 0].view(-1, 1) - pc_floory = pc_floor[:, 1].view(-1, 1) - pc_floorz = pc_floor[:, 2].view(-1, 1) - pc_ceilx = pc_ceil[:, 0].view(-1, 1) - pc_ceily = pc_ceil[:, 1].view(-1, 1) - pc_ceilz = pc_ceil[:, 2].view(-1, 1) - pc_floorx = pc_floorx.float() - pc_floory = pc_floory.float() - pc_floorz = pc_floorz.float() - pc_ceilx = pc_ceilx.float() - pc_ceily = pc_ceily.float() - pc_ceilz = pc_ceilz.float() - weight000 = (pc_ceilx - pc_gridx) * (pc_ceily - pc_gridy) * (pc_ceilz - - pc_gridz) - weight001 = (pc_ceilx - pc_gridx) * (pc_ceily - pc_gridy) * (pc_gridz - - pc_floorz) - weight010 = (pc_ceilx - pc_gridx) * (pc_gridy - - pc_floory) * (pc_ceilz - pc_gridz) - weight011 = (pc_ceilx - pc_gridx) * (pc_gridy - pc_floory) * ( - pc_gridz - pc_floorz) - weight100 = (pc_gridx - pc_floorx) * (pc_ceily - - pc_gridy) * (pc_ceilz - pc_gridz) - weight101 = (pc_gridx - pc_floorx) * (pc_ceily - pc_gridy) * ( - pc_gridz - pc_floorz) - weight110 = (pc_gridx - pc_floorx) * (pc_gridy - pc_floory) * ( - pc_ceilz - pc_gridz) - weight111 = (pc_gridx - pc_floorx) * (pc_gridy - pc_floory) * ( - pc_gridz - pc_floorz) - - all_weights = torch.cat([ - weight000, weight001, weight010, weight011, weight100, weight101, - weight110, weight111 - ], 1).transpose(1, 0).contiguous() - if scale != 1: - all_weights /= scale ** 3 - all_weights[idx_query == -1] = 0 - all_weights /= all_weights.sum(0) + 1e-8 - return all_weights - - -class DevoxelizationGPU(Function): - @staticmethod - @custom_fwd(cast_inputs=torch.half) - def forward(ctx, feat, indices, weights): - if 'cuda' in str(feat.device): - out = torchsparse_backend.devoxelize_forward( - feat.contiguous(), - indices.contiguous().int(), weights.contiguous()) - else: - out = torchsparse_backend.cpu_devoxelize_forward( - feat.contiguous(), - indices.contiguous().int(), weights.contiguous()) - - ctx.for_backwards = (indices.contiguous().int(), weights, - feat.shape[0]) - - return out - - @staticmethod - @custom_bwd - def backward(ctx, grad_out): - indices, weights, n = ctx.for_backwards - - if 'cuda' in str(grad_out.device): - grad_features = torchsparse_backend.devoxelize_backward( - grad_out.contiguous(), indices, weights, n) - else: - grad_features = torchsparse_backend.cpu_devoxelize_backward( - grad_out.contiguous(), indices, weights, n) - - return grad_features, None, None - - -devoxelize = DevoxelizationGPU.apply - - -def spdevoxelize(feat, indices, weights): - return devoxelize(feat, indices, weights) diff --git a/torchsparse/nn/functional/devoxelize.py b/torchsparse/nn/functional/devoxelize.py new file mode 100644 index 0000000..4f17d2c --- /dev/null +++ b/torchsparse/nn/functional/devoxelize.py @@ -0,0 +1,99 @@ +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +import torchsparse.backend + +__all__ = ['spdevoxelize', 'calc_ti_weights'] + + +def calc_ti_weights(coords, idx_query, scale: float = 1): + # TODO(Haotian): normalize the weights to a probability distribution. + with torch.no_grad(): + # don't want points to lie exactly on grid + p = coords + # don't use np.floor then convert to torch. numerical errors. + if scale != 1: + pf = torch.floor(coords / scale) * scale + else: + pf = torch.floor(coords) + pc = pf + scale + + x = p[:, 0].view(-1, 1) + y = p[:, 1].view(-1, 1) + z = p[:, 2].view(-1, 1) + + xf = pf[:, 0].view(-1, 1).float() + yf = pf[:, 1].view(-1, 1).float() + zf = pf[:, 2].view(-1, 1).float() + + xc = pc[:, 0].view(-1, 1).float() + yc = pc[:, 1].view(-1, 1).float() + zc = pc[:, 2].view(-1, 1).float() + + w0 = (xc - x) * (yc - y) * (zc - z) + w1 = (xc - x) * (yc - y) * (z - zf) + w2 = (xc - x) * (y - yf) * (zc - z) + w3 = (xc - x) * (y - yf) * (z - zf) + w4 = (x - xf) * (yc - y) * (zc - z) + w5 = (x - xf) * (yc - y) * (z - zf) + w6 = (x - xf) * (y - yf) * (zc - z) + w7 = (x - xf) * (y - yf) * (z - zf) + + w = torch.cat([w0, w1, w2, w3, w4, w5, w6, w7], dim=1) + w = w.transpose(1, 0).contiguous() + if scale != 1: + w /= scale ** 3 + w[idx_query == -1] = 0 + w /= w.sum(0) + 1e-8 + return w + + +class DevoxelizeFunction(Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.half) + def forward(ctx, feats: torch.Tensor, coords: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + feats = feats.contiguous() + coords = coords.contiguous().int() + weights = weights.contiguous() + + if feats.device.type == 'cuda': + output = torchsparse.backend.devoxelize_forward_cuda( + feats, coords, weights) + elif feats.device.type == 'cpu': + output = torchsparse.backend.devoxelize_forward_cpu( + feats, coords, weights) + else: + device = feats.device + output = torchsparse.backend.devoxelize_forward_cpu( + feats.cpu(), coords.cpu(), weights.cpu()).to(device) + + ctx.for_backwards = (coords, weights, feats.shape[0]) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output: torch.Tensor): + coords, weights, input_size = ctx.for_backwards + grad_output = grad_output.contiguous() + + if grad_output.device.type == 'cuda': + grad_feats = torchsparse.backend.devoxelize_backward_cuda( + grad_output, coords, weights, input_size) + elif grad_output.device.type == 'cpu': + grad_feats = torchsparse.backend.devoxelize_backward_cpu( + grad_output, coords, weights, input_size) + else: + device = grad_output.device + grad_feats = torchsparse.backend.devoxelize_backward_cpu( + grad_output.cpu(), coords.cpu(), weights.cpu(), + input_size).to(device) + + return grad_feats, None, None + + +def spdevoxelize(feats: torch.Tensor, coords: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + return DevoxelizeFunction.apply(feats, coords, weights) diff --git a/torchsparse/nn/functional/downsample.py b/torchsparse/nn/functional/downsample.py index 7d7c31b..e1893ed 100644 --- a/torchsparse/nn/functional/downsample.py +++ b/torchsparse/nn/functional/downsample.py @@ -1,70 +1,49 @@ +from typing import Tuple, Union + import torch -import torchsparse_backend -from torch.autograd import Function -from torchsparse.nn.functional.hash import * -from torchsparse.utils.kernel import KernelRegion -from typing import Tuple, List, Union + +from torchsparse.nn.utils import get_kernel_offsets +from torchsparse.utils import make_ntuple __all__ = ['spdownsample'] def spdownsample( - coords: torch.Tensor, - ratio: Union[int, List[int], Tuple[int, int, int]] = 2, - kernel_size: Union[int, List[int], Tuple[int, int, int]] = 2, - tensor_stride: Union[int, List[int], Tuple[int, int, int]] = 1 -) -> torch.Tensor: + coords: torch.Tensor, + stride: Union[int, Tuple[int, ...]] = 2, + kernel_size: Union[int, Tuple[int, ...]] = 2, + tensor_stride: Union[int, Tuple[int, ...]] = 1) -> torch.Tensor: + stride = make_ntuple(stride, ndim=3) + kernel_size = make_ntuple(kernel_size, ndim=3) + tensor_stride = make_ntuple(tensor_stride, ndim=3) - if not isinstance(ratio, int): - ratio = torch.IntTensor(ratio).to(coords.device).unsqueeze(0) - if not isinstance(tensor_stride, int): - tensor_stride = torch.IntTensor(tensor_stride).to( - coords.device).unsqueeze(0) + sample_stride = [stride[k] * tensor_stride[k] for k in range(3)] + sample_stride = torch.tensor(sample_stride, + dtype=torch.int, + device=coords.device).unsqueeze(0) - if isinstance(kernel_size, int) and isinstance(ratio, int): - direct_downsample = kernel_size == ratio + if all(stride[k] in [1, kernel_size[k]] for k in range(3)): + coords = coords.clone() + coords[:, :3] = coords[:, :3] // sample_stride * sample_stride else: - if isinstance(kernel_size, int): - # ratio is a permutation of [1, 1, kernel_size] - direct_downsample = (kernel_size == ratio.prod().item()) & \ - (torch.sum(ratio == kernel_size) == 1).item() - else: - direct_downsample = False + offsets = get_kernel_offsets(kernel_size, + tensor_stride, + device=coords.device) - if direct_downsample: - _ratio = ratio * tensor_stride - new_coords = torch.cat( - [coords[:, :3] // _ratio * _ratio, coords[:, 3:]], 1) - return torch.unique(new_coords, dim=0) - else: - kernel_region = KernelRegion(kernel_size, tensor_stride, dilation=1) - # kernel volume x 3 - kernel_offset = kernel_region.get_kernel_offset().to(coords.device) - new_coords = coords[:, :3].unsqueeze(1).repeat( - 1, kernel_offset.size(0), 1) + kernel_offset - # (N x kernel volume) x 4 - new_coords = torch.cat([ - coords[:, 3:].repeat(1, kernel_offset.size(0)).view(-1, 1), - new_coords.view(-1, 3) - ], - dim=1) - new_ts = tensor_stride * ratio - # only keep these coordinates that is multiple of new_ts. - if isinstance(new_ts, torch.Tensor): - new_ts = new_ts[0] - new_coords = new_coords[ - (new_coords[:, 1] % new_ts[0].item() == 0) & (new_coords[:, 2] % new_ts[1].item() == 0) & \ - (new_coords[:, 3] % new_ts[2].item() == 0) - ] - else: - new_coords = new_coords[ - (new_coords[:, 1] % new_ts == 0) & (new_coords[:, 2] % new_ts == 0) & \ - (new_coords[:, 3] % new_ts == 0) - ] - new_coords = new_coords[(new_coords[:, 1] >= 0) - & (new_coords[:, 2] >= 0) & - (new_coords[:, 3] >= 0)] - # filter out duplicates - new_coords = torch.unique(new_coords, dim=0) - new_coords = new_coords[:, [1, 2, 3, 0]] - return new_coords + coords_min = torch.min(coords[:, :3], dim=0, keepdim=True).values + + xyz = coords[:, :3].unsqueeze(1).repeat(1, offsets.size(0), 1) + offsets + b = coords[:, 3:].repeat(1, offsets.size(0)) + coords = torch.cat([xyz.view(-1, 3), b.view(-1, 1)], dim=1) + + # TODO(Zhijian): We need to also filter `coords` based on `coords_max`. + mask = (coords[:, :3] % sample_stride == 0) + mask &= (coords[:, :3] >= coords_min) + coords = coords[torch.sum(mask, dim=1) == 3, :] + + # This makes sure that the points will be ordered with respect to the batch + # index, but this will not affect the correctness of the result. + coords = coords[:, [3, 0, 1, 2]] + coords = torch.unique(coords, dim=0) + coords = coords[:, [1, 2, 3, 0]] + return coords diff --git a/torchsparse/nn/functional/hash.py b/torchsparse/nn/functional/hash.py index ad5580c..9499b96 100644 --- a/torchsparse/nn/functional/hash.py +++ b/torchsparse/nn/functional/hash.py @@ -1,41 +1,37 @@ -import torchsparse_backend -from torch.autograd import Function - -__all__ = ['sphash'] +from typing import Optional +import torch -class HashGPU(Function): - @staticmethod - def forward(ctx, idx): - if 'cuda' in str(idx.device): - return torchsparse_backend.hash_forward(idx.contiguous()) - elif 'cpu' in str(idx.device): - return torchsparse_backend.cpu_hash_forward(idx.int().contiguous()) - else: - device = idx.device - return torchsparse_backend.cpu_hash_forward( - idx.int().contiguous().cpu()).to(device) +import torchsparse.backend +__all__ = ['sphash'] -class KernelHashGPU(Function): - @staticmethod - def forward(ctx, idx, koffset): - if 'cuda' in str(idx.device): - return torchsparse_backend.kernel_hash_forward( - idx.contiguous(), koffset.contiguous()) - elif 'cpu' in str(idx.device): - return torchsparse_backend.cpu_kernel_hash_forward( - idx.int().contiguous(), - koffset.int().contiguous()) - else: - device = idx.device - return torchsparse_backend.cpu_kernel_hash_forward( - idx.int().contiguous().cpu(), - koffset.int().contiguous().cpu()).to(device) +def sphash(coords: torch.Tensor, + offsets: Optional[torch.Tensor] = None) -> torch.Tensor: + assert coords.dtype == torch.int, coords.dtype + assert coords.ndim == 2 and coords.shape[1] == 4, coords.shape + coords = coords.contiguous() -def sphash(idx, koffset=None): - if koffset is None: - return HashGPU.apply(idx) + # TODO(Zhijian): We might be able to merge `hash_kernel` and `hash`. + if offsets is None: + if coords.device.type == 'cuda': + return torchsparse.backend.hash_cuda(coords) + elif coords.device.type == 'cpu': + return torchsparse.backend.hash_cpu(coords) + else: + device = coords.device + return torchsparse.backend.hash_cpu(coords.cpu()).to(device) else: - return KernelHashGPU.apply(idx, koffset) + assert offsets.dtype == torch.int, offsets.dtype + assert offsets.ndim == 2 and offsets.shape[1] == 3, offsets.shape + offsets = offsets.contiguous() + + if coords.device.type == 'cuda': + return torchsparse.backend.kernel_hash_cuda(coords, offsets) + elif coords.device.type == 'cpu': + return torchsparse.backend.kernel_hash_cpu(coords, offsets) + else: + device = coords.device + return torchsparse.backend.kernel_hash_cpu(coords.cpu(), + offsets.cpu()).to(device) diff --git a/torchsparse/nn/functional/pooling.py b/torchsparse/nn/functional/pooling.py index cc314f2..0d70d8b 100644 --- a/torchsparse/nn/functional/pooling.py +++ b/torchsparse/nn/functional/pooling.py @@ -1,9 +1,11 @@ import torch +from torchsparse import SparseTensor + __all__ = ['global_avg_pool', 'global_max_pool'] -def global_avg_pool(inputs): +def global_avg_pool(inputs: SparseTensor) -> torch.Tensor: batch_index = inputs.C[:, -1] max_index = torch.max(batch_index).item() outputs = [] @@ -16,7 +18,7 @@ def global_avg_pool(inputs): return outputs -def global_max_pool(inputs): +def global_max_pool(inputs: SparseTensor) -> torch.Tensor: batch_index = inputs.C[:, -1] max_index = torch.max(batch_index).item() outputs = [] diff --git a/torchsparse/nn/functional/query.py b/torchsparse/nn/functional/query.py index 79e8b9f..913047d 100644 --- a/torchsparse/nn/functional/query.py +++ b/torchsparse/nn/functional/query.py @@ -1,40 +1,33 @@ import torch -import torchsparse_backend -from torch.autograd import Function + +import torchsparse.backend __all__ = ['sphashquery'] -class SparseQuery(Function): - @staticmethod - def forward(ctx, hash_query, hash_target): - if len(hash_query.size()) == 2: - C = hash_query.size(1) - else: - C = 1 - - idx_target = torch.arange(len(hash_target), - device=hash_query.device, - dtype=torch.long) - - if 'cuda' in str(hash_query.device): - out, key_buf, val_buf, key = torchsparse_backend.query_forward( - hash_query.view(-1).contiguous(), hash_target.contiguous(), - idx_target) - elif 'cpu' in str(hash_query.device): - out = torchsparse_backend.cpu_query_forward( - hash_query.view(-1).contiguous(), hash_target.contiguous(), - idx_target) - else: - device = hash_query.device - out = torchsparse_backend.cpu_query_forward( - hash_query.view(-1).contiguous().cpu(), - hash_target.contiguous().cpu(), idx_target.cpu()).to(device) - - if C > 1: - out = out.view(-1, C) - return (out - 1) - - -def sphashquery(hash_query, hash_target): - return SparseQuery.apply(hash_query, hash_target) +def sphashquery(queries: torch.Tensor, + references: torch.Tensor) -> torch.Tensor: + queries = queries.contiguous() + references = references.contiguous() + + sizes = queries.size() + queries = queries.view(-1) + + indices = torch.arange(len(references), + device=queries.device, + dtype=torch.long) + + if queries.device.type == 'cuda': + output = torchsparse.backend.hash_query_cuda(queries, references, + indices) + elif queries.device.type == 'cpu': + output = torchsparse.backend.hash_query_cpu(queries, references, + indices) + else: + device = queries.device + output = torchsparse.backend.hash_query_cpu(queries.cpu(), + references.cpu(), + indices.cpu()).to(device) + + output = (output - 1).view(*sizes) + return output diff --git a/torchsparse/nn/functional/squeeze_nmap.py b/torchsparse/nn/functional/squeeze_nmap.py deleted file mode 100644 index 2bd3f17..0000000 --- a/torchsparse/nn/functional/squeeze_nmap.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch - -__all__ = ['squeeze_nmap'] - - -def squeeze_nmap(neighbor_map: torch.Tensor) -> torch.Tensor: - idx_batch, idx_point = torch.where(neighbor_map != -1) - map_converted = neighbor_map.view(-1)[idx_batch * neighbor_map.size(1) + - idx_point] - map_converted = torch.stack([map_converted, idx_point], dim=1) - nmap_offset = torch.sum(neighbor_map != -1, 1) - return map_converted.int().contiguous(), nmap_offset.int().contiguous() diff --git a/torchsparse/nn/functional/voxelize.py b/torchsparse/nn/functional/voxelize.py index 12cc305..12537dc 100644 --- a/torchsparse/nn/functional/voxelize.py +++ b/torchsparse/nn/functional/voxelize.py @@ -1,30 +1,56 @@ import torch -import torchsparse_backend from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd -from torchsparse.nn.functional.hash import * +from torch.cuda.amp import custom_bwd, custom_fwd + +import torchsparse.backend __all__ = ['spvoxelize'] -class VoxelizeGPU(Function): +class VoxelizeFunction(Function): + @staticmethod @custom_fwd(cast_inputs=torch.half) - def forward(ctx, feat, idx, cnt): - out = torchsparse_backend.insertion_forward(feat.contiguous(), - idx.int().contiguous(), - cnt) - ctx.for_backwards = (idx.int().contiguous(), cnt, feat.shape[0]) - return out + def forward(ctx, feats: torch.Tensor, coords: torch.Tensor, + counts: torch.Tensor) -> torch.Tensor: + feats = feats.contiguous() + coords = coords.contiguous().int() + + if feats.device.type == 'cuda': + output = torchsparse.backend.voxelize_forward_cuda( + feats, coords, counts) + elif feats.device.type == 'cpu': + output = torchsparse.backend.voxelize_forward_cpu( + feats, coords, counts) + else: + device = feats.device + output = torchsparse.backend.voxelize_forward_cpu( + feats.cpu(), coords.cpu(), counts.cpu()).to(device) + + ctx.for_backwards = (coords, counts, feats.shape[0]) + return output @staticmethod @custom_bwd - def backward(ctx, top_grad): - idx, cnt, N = ctx.for_backwards - bottom_grad = torchsparse_backend.insertion_backward( - top_grad.contiguous(), idx, cnt, N) - return bottom_grad, None, None + def backward(ctx, grad_output: torch.Tensor): + coords, counts, input_size = ctx.for_backwards + grad_output = grad_output.contiguous() + + if grad_output.device.type == 'cuda': + grad_feats = torchsparse.backend.voxelize_backward_cuda( + grad_output, coords, counts, input_size) + elif grad_output.device.type == 'cpu': + grad_feats = torchsparse.backend.voxelize_backward_cpu( + grad_output, coords, counts, input_size) + else: + device = grad_output.device + grad_feats = torchsparse.backend.voxelize_backward_cpu( + grad_output.cpu(), coords.cpu(), counts.cpu(), + input_size).to(device) + + return grad_feats, None, None -def spvoxelize(feat, idx, cnt): - return VoxelizeGPU.apply(feat, idx, cnt) +def spvoxelize(feats: torch.Tensor, coords: torch.Tensor, + counts: torch.Tensor) -> torch.Tensor: + return VoxelizeFunction.apply(feats, coords, counts) diff --git a/torchsparse/nn/modules/__init__.py b/torchsparse/nn/modules/__init__.py index 2e7eda7..cb0e53e 100644 --- a/torchsparse/nn/modules/__init__.py +++ b/torchsparse/nn/modules/__init__.py @@ -1,4 +1,5 @@ from .activation import * +from .bev import * from .conv import * from .crop import * from .norm import * diff --git a/torchsparse/nn/modules/activation.py b/torchsparse/nn/modules/activation.py index cdce13d..41ec040 100644 --- a/torchsparse/nn/modules/activation.py +++ b/torchsparse/nn/modules/activation.py @@ -1,47 +1,18 @@ -import functools - from torch import nn -from torchsparse.sparse_tensor import SparseTensor -from torchsparse.nn import functional as spF +from torchsparse import SparseTensor +from torchsparse.nn.utils import fapply __all__ = ['ReLU', 'LeakyReLU'] -class Activation(nn.Module): - def __init__(self, inplace: bool = True) -> None: - super().__init__() - self.activation = spF.spact - self.inplace = inplace - - def forward(self, inputs): - return self.activation(inputs) - - -class ReLU(Activation): - def __init__(self, inplace: bool = True) -> None: - super().__init__() - self.activation = functools.partial(spF.sprelu, inplace=inplace) +class ReLU(nn.ReLU): - def __repr__(self): - if self.inplace: - return 'ReLU(inplace=True)' - else: - return 'ReLU(inplace=False)' + def forward(self, input: SparseTensor) -> SparseTensor: + return fapply(input, super().forward) -class LeakyReLU(Activation): - def __init__(self, - negative_slope: float = 0.1, - inplace: bool = True) -> None: - super().__init__() - self.activation = functools.partial(spF.spleaky_relu, - negative_slope=negative_slope, - inplace=inplace) - self.negative_slope = negative_slope +class LeakyReLU(nn.LeakyReLU): - def __repr__(self): - if self.inplace: - return 'LeakyReLU(negative_slope=%f, inplace=True)' % self.negative_slope - else: - return 'LeakyReLU(negative_slope=%f, inplace=False)' % self.negative_slope + def forward(self, input: SparseTensor) -> SparseTensor: + return fapply(input, super().forward) diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py new file mode 100644 index 0000000..6e87aa7 --- /dev/null +++ b/torchsparse/nn/modules/bev.py @@ -0,0 +1,225 @@ +import math +from typing import List, Tuple, Union + +import torch +from torch import nn + +from torchsparse import SparseTensor + +__all__ = [ + 'ToBEVConvolution', 'ToBEVReduction', 'ToDenseBEVConvolution', + 'ToBEVHeightCompression' +] + + +class ToBEVReduction(nn.Module): + + def __init__(self, dim: int = 1) -> None: + super().__init__() + self.dim = dim + + def extra_repr(self): + return f'dim = {self.dim}' + + def forward(self, input: SparseTensor) -> SparseTensor: + coords, feats, stride = input.C, input.F, input.s + + coords = coords.clone() + coords[:, self.dim] = 0 + feats = torch.cat([torch.ones_like(feats[:, :1]), feats], axis=1) + tensor = torch.cuda.sparse.FloatTensor(coords.t().long(), + feats).coalesce() + coords = tensor.indices().t().int() + feats = tensor.values()[:, 1:] / tensor.values()[:, :1] + return SparseTensor(coords=coords, feats=feats, stride=stride) + + +class ToDenseBEVConvolution(nn.Module): + """ + + Converts a torchsparse.SparseTensor to a BEV feature map. + Group points with the same z value together and apply the same FC kernel. + Aggregate the results by summing up all features within one BEV grid. + + in_channels: input channels + out_channels: output channels + shape: shape of BEV map. + dim: dimension index for z. (default: 1 for KITTI coords) + bias: whether to use bias. + + Warning: usually larger memory consumption than ToBEVHeightCompression. + + + """ + + def __init__(self, + in_channels: int, + out_channels: int, + shape: Union[List[int], Tuple[int, int, int], torch.Tensor], + offset: Tuple[int, int, int] = (0, 0, 0), + dim: int = 1, + bias: bool = False) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.register_buffer('offset', torch.IntTensor([list(offset) + [0]])) + if isinstance(shape, torch.Tensor): + self.register_buffer('shape', shape.int()) + else: + self.register_buffer('shape', torch.IntTensor(shape)) + self.dim = dim + self.n_kernels = int(self.shape[self.dim]) + self.bev_dims = [i for i in range(3) if i != self.dim] + self.bev_shape = self.shape[self.bev_dims] + self.kernel = nn.Parameter( + torch.zeros(self.n_kernels, in_channels, out_channels)) + self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0 + self.reset_parameters() + + def extra_repr(self): + return 'in_channels={}, out_channels={}, n_kernels={}'.format( + self.in_channels, self.out_channels, self.n_kernels) + + def reset_parameters(self): + std = 1. / math.sqrt(self.in_channels) + self.kernel.data.uniform_(-std, std) + + def forward(self, input: SparseTensor) -> torch.Tensor: + coords, feats, stride = input.C, input.F, input.s + if isinstance(stride, tuple): + stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + + kernel = torch.index_select(self.kernel, 0, + (coords[:, self.dim] // stride).long()) + feats = (feats.unsqueeze(-1) * kernel).sum(1) + self.bias + coords = (coords - self.offset).t()[[3] + self.bev_dims].long() + coords[1:] = (coords[1:] // stride).long() + indices = coords[0] * int(self.bev_shape.prod()) + coords[1] * int( + self.bev_shape[1]) + coords[2] + batch_size = coords[0].max().item() + 1 + output = torch.sparse_coo_tensor( + indices.unsqueeze(0), + feats, + torch.Size( + [batch_size * int(self.bev_shape.prod()), + feats.size(-1)]), + ).to_dense() + output = output.view(batch_size, *self.bev_shape, -1) + output = output.permute(0, 3, 1, 2).contiguous() + return output + + +class ToBEVConvolution(nn.Module): + """ Sparse version of ToDenseBEVConvolution. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + n_kernels: int, + stride: int = 1, + dim: int = 1, + bias: bool = False) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.n_kernels = n_kernels + self.stride = stride + self.dim = dim + self.kernel = nn.Parameter( + torch.zeros(n_kernels, in_channels, out_channels)) + self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0 + self.reset_parameters() + + def reset_parameters(self): + std = 1. / math.sqrt(self.in_channels) + self.kernel.data.uniform_(-std, std) + + def extra_repr(self): + return 'in_channels={}, out_channels={}, n_kernels={}, stride={}'.format( + self.in_channels, self.out_channels, self.n_kernels, self.stride) + + def forward(self, input: SparseTensor) -> torch.Tensor: + coords, feats, stride = input.C, input.F, input.s + ratio = stride * self.stride + if isinstance(stride, tuple): + stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + + kernels = torch.index_select(self.kernel, 0, + coords[:, self.dim].long() // stride) + feats = (feats.unsqueeze(-1) * kernels).sum(1) + self.bias + coords = coords.t().long() + coords[self.dim, :] = 0 + if self.stride > 1: + coords[:3] /= ratio + coords[:3] *= ratio + flatten = torch.cuda.sparse.FloatTensor(coords, feats).coalesce() + return SparseTensor(flatten.values(), + flatten.indices().t().int(), ratio) + + +class ToBEVHeightCompression(nn.Module): + """ + + Converts a torchsparse.SparseTensor to a dense volumetric tensor, + then flatten the z dimension. + E.g. input [N, C] (assume batch_size=1), spatial size [128,2,128] + then output will be [1, 2C, 128, 128] + + channels: input channels + (Note: output channels = channels x #unique z values) + shape: shape of BEV map. + dim: dimension index for z. (default: 1 for KITTI coords) + bias: whether to use bias. + + + """ + + def __init__(self, + channels: int, + shape: Union[List[int], Tuple[int, int, int], torch.Tensor], + offset: Tuple[int, int, int] = (0, 0, 0), + dim: int = 1, + bias: bool = False) -> None: + super().__init__() + self.channels = channels + self.register_buffer('offset', torch.IntTensor([list(offset) + [0]])) + if isinstance(shape, torch.Tensor): + self.register_buffer('shape', shape.int()) + else: + self.register_buffer('shape', torch.IntTensor(shape)) + self.dim = dim + self.bev_dims = [i for i in range(3) if i != self.dim] + self.bev_shape = self.shape[self.bev_dims] + + def extra_repr(self) -> str: + return f'channels={self.channels}' + + def forward(self, input: SparseTensor) -> torch.Tensor: + coords, feats, stride = input.C, input.F, input.s + if isinstance(stride, tuple): + stride = torch.Tensor(stride).unsqueeze(0).to(feats) + assert isinstance(stride, torch.Tensor) + + # [b, x, y, z] + coords = (coords - self.offset).t()[[3] + self.bev_dims + + [self.dim]].long() + shape = self.shape[self.bev_dims + [self.dim]] + + # now stride must be torch.Tensor since input.s is tuple. + stride = stride[:, self.bev_dims + [self.dim]].t() + + coords[1:] = (coords[1:] // stride).long() + coords[-1] = torch.clamp(coords[-1], 0, shape[-1] - 1) + indices = coords[0] * int(shape.prod()) + coords[1] * int( + shape[1:].prod()) + coords[2] * int(shape[2]) + coords[3] + batch_size = coords[0].max().item() + 1 + output = torch.sparse_coo_tensor( + indices.unsqueeze(0), + feats, + torch.Size([batch_size * int(self.shape.prod()), + feats.size(-1)]), + ).to_dense() + output = output.view(batch_size, *self.bev_shape.cpu().numpy(), -1) + output = output.permute(0, 3, 1, 2).contiguous() + return output diff --git a/torchsparse/nn/modules/conv.py b/torchsparse/nn/modules/conv.py index 3dcb2af..407a2ac 100644 --- a/torchsparse/nn/modules/conv.py +++ b/torchsparse/nn/modules/conv.py @@ -1,301 +1,72 @@ import math +from typing import Tuple, Union +import numpy as np import torch from torch import nn -from torchsparse.sparse_tensor import SparseTensor -from torchsparse.nn import functional as spF -from torchsparse.utils.helpers import make_tuple -from typing import Union, List, Tuple +from torchsparse import SparseTensor +from torchsparse.nn import functional as F +from torchsparse.utils import make_ntuple -__all__ = [ - 'Conv3d', 'ToBEVConvolution', 'ToBEVReduction', 'ToDenseBEVConvolution', - 'ToBEVHeightCompression' -] +__all__ = ['Conv3d'] class Conv3d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, - kernel_size: Union[int, List[int], Tuple[int, int, int]] = 3, - stride: Union[int, List[int], Tuple[int, int, int]] = 1, + kernel_size: Union[int, Tuple[int, ...]] = 3, + stride: Union[int, Tuple[int, ...]] = 1, dilation: int = 1, bias: bool = False, - transpose: bool = False) -> None: + transposed: bool = False) -> None: super().__init__() - self.in_channels = inc = in_channels - self.out_channels = outc = out_channels - if isinstance(kernel_size, list): - self.kernel_size = tuple(kernel_size) - else: - self.kernel_size = kernel_size - if isinstance(stride, list): - self.stride = tuple(stride) - else: - self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = make_ntuple(kernel_size, ndim=3) + self.stride = make_ntuple(stride, ndim=3) self.dilation = dilation + self.transposed = transposed - if not isinstance(kernel_size, (list, tuple)): - self.kernel_volume = self.kernel_size ** 3 + self.kernel_volume = int(np.prod(self.kernel_size)) + if self.kernel_volume > 1: self.kernel = nn.Parameter( - torch.zeros(self.kernel_volume, inc, - outc)) if self.kernel_size > 1 else nn.Parameter( - torch.zeros(inc, outc)) + torch.zeros(self.kernel_volume, in_channels, out_channels)) else: - if len(self.kernel_size) == 3: - self.kernel_volume = self.kernel_size[0] * self.kernel_size[ - 1] * self.kernel_size[2] - self.kernel = nn.Parameter( - torch.zeros(self.kernel_volume, inc, outc)) - else: - raise ValueError( - "kernel_size must be either an integer of a 3 dimensional tuple" - ) - - self.bias = None if not bias else nn.Parameter(torch.zeros(outc)) - self.t = transpose - self.reset_parameters() - - if kernel_size == 1: - assert not transpose - - def __repr__(self): - if not self.t: - return 'Conv3d(in_channels={}, out_channels={}, kernel_size={}, stride={}, dilation={})'.format( - self.in_channels, self.out_channels, self.kernel_size, - self.stride, self.dilation) + self.kernel = nn.Parameter(torch.zeros(in_channels, out_channels)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) else: - return 'Conv3dTranspose(in_channels={}, out_channels={}, kernel_size={}, stride={}, dilation={})'.format( - self.in_channels, self.out_channels, self.kernel_size, - self.stride, self.dilation) + self.register_parameter('bias', None) + self.reset_parameters() - def reset_parameters(self): - std = 1. / math.sqrt( - self.out_channels if self.t else self.in_channels * - (self.kernel_volume)) + def extra_repr(self) -> str: + s = '{in_channels}, {out_channels}, kernel_size={kernel_size}' + if self.stride != (1,) * len(self.stride): + s += ', stride={stride}' + if self.dilation != 1: + s += ', dilation={dilation}' + if self.bias is None: + s += ', bias=False' + if self.transposed: + s += ', transposed=True' + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + std = 1 / math.sqrt( + (self.out_channels if self.transposed else self.in_channels) + * self.kernel_volume) self.kernel.data.uniform_(-std, std) if self.bias is not None: self.bias.data.uniform_(-std, std) - def forward(self, inputs: SparseTensor) -> SparseTensor: - return spF.conv3d(inputs, - self.kernel, - kernel_size=self.kernel_size, - bias=self.bias, - stride=self.stride, - dilation=self.dilation, - transpose=self.t) - - -class ToBEVReduction(nn.Module): - def __init__(self, dim: int = 1) -> None: - super().__init__() - self.dim = dim - - def extra_repr(self): - return 'dim = {}'.format(self.dim) - - def forward(self, inputs: SparseTensor) -> SparseTensor: - coords, feats, stride = inputs.C, inputs.F, inputs.s - - coords = coords.clone() - coords[:, self.dim] = 0 - feats = torch.cat([torch.ones_like(feats[:, :1]), feats], axis=1) - tensor = torch.cuda.sparse.FloatTensor(coords.t().long(), - feats).coalesce() - coords = tensor.indices().t().int() - feats = tensor.values()[:, 1:] / tensor.values()[:, :1] - return SparseTensor(coords=coords, feats=feats, stride=stride) - - -class ToDenseBEVConvolution(nn.Module): - """ - - Converts a torchsparse.SparseTensor to a BEV feature map. - Group points with the same z value together and apply the same FC kernel. - Aggregate the results by summing up all features within one BEV grid. - - in_channels: input channels - out_channels: output channels - shape: shape of BEV map. - dim: dimension index for z. (default: 1 for KITTI coords) - bias: whether to use bias. - - Warning: usually larger memory consumption than ToBEVHeightCompression. - - - """ - def __init__(self, - in_channels: int, - out_channels: int, - shape: Union[List[int], Tuple[int, int, int], torch.Tensor], - offset: List[int] = [0, 0, 0], - dim: int = 1, - bias: bool = False) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.register_buffer('offset', torch.IntTensor([list(offset) + [0]])) - if isinstance(shape, torch.Tensor): - self.register_buffer('shape', shape.int()) - else: - self.register_buffer('shape', torch.IntTensor(shape)) - self.dim = dim - self.n_kernels = int(self.shape[self.dim]) - self.bev_dims = [i for i in range(3) if i != self.dim] - self.bev_shape = self.shape[self.bev_dims] - self.kernel = nn.Parameter( - torch.zeros(self.n_kernels, in_channels, out_channels)) - self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0 - self.reset_parameters() - - def extra_repr(self): - return 'in_channels={}, out_channels={}, n_kernels={}'.format( - self.in_channels, self.out_channels, self.n_kernels) - - def reset_parameters(self): - std = 1. / math.sqrt(self.in_channels) - self.kernel.data.uniform_(-std, std) - - def forward(self, inputs: SparseTensor) -> torch.Tensor: - coords, feats, stride = inputs.C, inputs.F, inputs.s - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] - - kernel = torch.index_select(self.kernel, 0, - (coords[:, self.dim] // stride).long()) - feats = (feats.unsqueeze(-1) * kernel).sum(1) + self.bias - coords = (coords - self.offset).t()[[3] + self.bev_dims].long() - coords[1:] = (coords[1:] // stride).long() - indices = coords[0] * int(self.bev_shape.prod()) + coords[1] * int( - self.bev_shape[1]) + coords[2] - batch_size = coords[0].max().item() + 1 - outputs = torch.sparse_coo_tensor( - indices.unsqueeze(0), - feats, - torch.Size( - [batch_size * int(self.bev_shape.prod()), - feats.size(-1)]), - ).to_dense() - outputs = outputs.view(batch_size, *self.bev_shape, -1) - outputs = outputs.permute(0, 3, 1, 2).contiguous() - return outputs - - -class ToBEVConvolution(nn.Module): - """ - - Sparse version of ToDenseBEVConvolution. - - """ - def __init__(self, - in_channels: int, - out_channels: int, - n_kernels: int, - stride: int = 1, - dim: int = 1, - bias: bool = False) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.n_kernels = n_kernels - self.stride = stride - self.dim = dim - self.kernel = nn.Parameter( - torch.zeros(n_kernels, in_channels, out_channels)) - self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0 - self.reset_parameters() - - def reset_parameters(self): - std = 1. / math.sqrt(self.in_channels) - self.kernel.data.uniform_(-std, std) - - def extra_repr(self): - return 'in_channels={}, out_channels={}, n_kernels={}, stride={}'.format( - self.in_channels, self.out_channels, self.n_kernels, self.stride) - - def forward(self, inputs: SparseTensor) -> torch.Tensor: - coords, feats, stride = inputs.C, inputs.F, inputs.s - ratio = stride * self.stride - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] - - kernels = torch.index_select(self.kernel, 0, - coords[:, self.dim].long() // stride) - feats = (feats.unsqueeze(-1) * kernels).sum(1) + self.bias - coords = coords.t().long() - coords[self.dim, :] = 0 - if self.stride > 1: - coords[:3] /= ratio - coords[:3] *= ratio - flatten = torch.cuda.sparse.FloatTensor(coords, feats).coalesce() - return SparseTensor(flatten.values(), - flatten.indices().t().int(), ratio) - - -class ToBEVHeightCompression(nn.Module): - """ - - Converts a torchsparse.SparseTensor to a dense volumetric tensor, - then flatten the z dimension. - E.g. input [N, C] (assume batch_size=1), spatial size [128,2,128] - then output will be [1, 2C, 128, 128] - - channels: input channels - (Note: output channels = channels x #unique z values) - shape: shape of BEV map. - dim: dimension index for z. (default: 1 for KITTI coords) - bias: whether to use bias. - - - """ - def __init__(self, - channels: int, - shape: Union[List[int], Tuple[int, int, int], torch.Tensor], - offset: List[int] = [0, 0, 0], - dim: int = 1, - bias: bool = False) -> None: - super().__init__() - self.channels = channels - self.register_buffer('offset', torch.IntTensor([list(offset) + [0]])) - if isinstance(shape, torch.Tensor): - self.register_buffer('shape', shape.int()) - else: - self.register_buffer('shape', torch.IntTensor(shape)) - self.dim = dim - self.bev_dims = [i for i in range(3) if i != self.dim] - self.bev_shape = self.shape[self.bev_dims] - - def extra_repr(self): - return 'channels={}'.format(self.channels) - - def forward(self, inputs: SparseTensor) -> torch.Tensor: - coords, feats, stride = inputs.C, inputs.F, inputs.s - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats) - - # [b, x, y, z] - coords = (coords - self.offset).t()[[3] + self.bev_dims + - [self.dim]].long() - shape = self.shape[self.bev_dims + [self.dim]] - - # now stride must be torch.Tensor since inputs.s is tuple. - dim = self.dim - stride = stride[:, self.bev_dims + [self.dim]] - stride = stride.t() - - coords[1:] = (coords[1:] // stride).long() - coords[-1] = torch.clamp(coords[-1], 0, shape[-1] - 1) - indices = coords[0] * int(shape.prod()) + coords[1] * int( - shape[1:].prod()) + coords[2] * int(shape[2]) + coords[3] - batch_size = coords[0].max().item() + 1 - outputs = torch.sparse_coo_tensor( - indices.unsqueeze(0), - feats, - torch.Size([batch_size * int(self.shape.prod()), - feats.size(-1)]), - ).to_dense() - outputs = outputs.view(batch_size, *self.bev_shape.cpu().numpy(), -1) - outputs = outputs.permute(0, 3, 1, 2).contiguous() - return outputs + def forward(self, input: SparseTensor) -> SparseTensor: + return F.conv3d(input, + self.kernel, + kernel_size=self.kernel_size, + bias=self.bias, + stride=self.stride, + dilation=self.dilation, + transposed=self.transposed) diff --git a/torchsparse/nn/modules/crop.py b/torchsparse/nn/modules/crop.py index f73bf27..9b58421 100644 --- a/torchsparse/nn/modules/crop.py +++ b/torchsparse/nn/modules/crop.py @@ -1,15 +1,18 @@ import torch from torch import nn + +from torchsparse import SparseTensor from torchsparse.nn.functional import spcrop __all__ = ['SparseCrop'] class SparseCrop(nn.Module): + def __init__(self, loc_min, loc_max): super().__init__() self.loc_min = torch.cuda.IntTensor([list(loc_min)]) self.loc_max = torch.cuda.IntTensor([list(loc_max)]) - def forward(self, inputs): - return spcrop(inputs, self.loc_min, self.loc_max) + def forward(self, input: SparseTensor) -> SparseTensor: + return spcrop(input, self.loc_min, self.loc_max) diff --git a/torchsparse/nn/modules/norm.py b/torchsparse/nn/modules/norm.py index 7590114..1c22246 100644 --- a/torchsparse/nn/modules/norm.py +++ b/torchsparse/nn/modules/norm.py @@ -1,45 +1,29 @@ import torch from torch import nn -from torchsparse.sparse_tensor import SparseTensor + +from torchsparse import SparseTensor +from torchsparse.nn.utils import fapply __all__ = ['BatchNorm', 'GroupNorm'] class BatchNorm(nn.BatchNorm1d): - def __init__(self, - num_features: int, - *, - eps: float = 1e-5, - momentum: float = 0.1) -> None: - super().__init__(num_features=num_features, eps=eps, momentum=momentum) - - def forward(self, inputs): - feats = inputs.F - coords = inputs.C - stride = inputs.s - feats = super().forward(feats) - outputs = SparseTensor(coords=coords, feats=feats, stride=stride) - outputs.coord_maps = inputs.coord_maps - outputs.kernel_maps = inputs.kernel_maps - return outputs + + def forward(self, input: SparseTensor) -> SparseTensor: + return fapply(input, super().forward) class GroupNorm(nn.GroupNorm): - def __init__(self, - num_groups: int, - num_channels: int, - eps: float = 1e-5, - affine: bool = True) -> None: - super().__init__(num_groups, num_channels, eps=eps, affine=affine) - - def forward(self, inputs): - feats = inputs.F - coords = inputs.C - stride = inputs.s - # PyTorch's GroupNorm function expects the input to be in (N, C, *) format where - # N is batch size, and C is number of channels. "feats" is not in that format. - # So, we extract the feats corresponding to each sample, bring it to the format - # expected by PyTorch's GroupNorm function, and invoke it. + + def forward(self, input: SparseTensor) -> SparseTensor: + feats = input.F + coords = input.C + stride = input.s + # PyTorch's GroupNorm function expects the input to be in (N, C, *) + # format where N is batch size, and C is number of channels. "feats" + # is not in that format. So, we extract the feats corresponding to + # each sample, bring it to the format expected by PyTorch's GroupNorm + # function, and invoke it. batch_size = coords[-1][-1] + 1 num_channels = feats.shape[1] new_feats = torch.zeros_like(feats) @@ -47,13 +31,12 @@ def forward(self, inputs): indices = coords[:, -1] == sample_idx sample_feats = feats[indices] sample_feats = torch.transpose(sample_feats, 0, 1) - sample_feats = sample_feats.reshape( - 1, num_channels, -1) # N=1. since we have a single sample here + sample_feats = sample_feats.reshape(1, num_channels, -1) normalized_feats = super().forward(sample_feats) normalized_feats = normalized_feats.reshape(num_channels, -1) normalized_feats = torch.transpose(normalized_feats, 0, 1) new_feats[indices] = normalized_feats - outputs = SparseTensor(coords=coords, feats=new_feats, stride=stride) - outputs.coord_maps = inputs.coord_maps - outputs.kernel_maps = inputs.kernel_maps - return outputs + output = SparseTensor(coords=coords, feats=new_feats, stride=stride) + output.cmaps = input.cmaps + output.kmaps = input.kmaps + return output diff --git a/torchsparse/nn/modules/pooling.py b/torchsparse/nn/modules/pooling.py index 947f5e2..d58aa6b 100644 --- a/torchsparse/nn/modules/pooling.py +++ b/torchsparse/nn/modules/pooling.py @@ -1,17 +1,18 @@ from torch import nn -from torchsparse.sparse_tensor import SparseTensor - -from torchsparse.nn import functional as spF +from torchsparse import SparseTensor +from torchsparse.nn import functional as F __all__ = ['GlobalAveragePooling', 'GlobalMaxPooling'] class GlobalAveragePooling(nn.Module): - def forward(self, inputs): - return spF.global_avg_pool(inputs) + + def forward(self, input: SparseTensor) -> SparseTensor: + return F.global_avg_pool(input) class GlobalMaxPooling(nn.Module): - def forward(self, inputs): - return spF.global_max_pool(inputs) + + def forward(self, input: SparseTensor) -> SparseTensor: + return F.global_max_pool(input) diff --git a/torchsparse/nn/utils/__init__.py b/torchsparse/nn/utils/__init__.py new file mode 100644 index 0000000..0630f8b --- /dev/null +++ b/torchsparse/nn/utils/__init__.py @@ -0,0 +1,2 @@ +from .apply import * +from .kernel import * diff --git a/torchsparse/nn/utils/apply.py b/torchsparse/nn/utils/apply.py new file mode 100644 index 0000000..cc38b18 --- /dev/null +++ b/torchsparse/nn/utils/apply.py @@ -0,0 +1,16 @@ +from typing import Callable + +import torch + +from torchsparse import SparseTensor + +__all__ = ['fapply'] + + +def fapply(input: SparseTensor, fn: Callable[..., torch.Tensor], *args, + **kwargs) -> SparseTensor: + feats = fn(input.feats, *args, **kwargs) + output = SparseTensor(coords=input.coords, feats=feats, stride=input.stride) + output.cmaps = input.cmaps + output.kmaps = input.kmaps + return output diff --git a/torchsparse/nn/utils/kernel.py b/torchsparse/nn/utils/kernel.py new file mode 100644 index 0000000..0973e8f --- /dev/null +++ b/torchsparse/nn/utils/kernel.py @@ -0,0 +1,32 @@ +from typing import Tuple, Union + +import numpy as np +import torch + +from torchsparse.utils import make_ntuple + +__all__ = ['get_kernel_offsets'] + + +def get_kernel_offsets(size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + dilation: Union[int, Tuple[int, ...]] = 1, + device: str = 'cpu') -> torch.Tensor: + size = make_ntuple(size, ndim=3) + stride = make_ntuple(stride, ndim=3) + dilation = make_ntuple(dilation, ndim=3) + + offsets = [(np.arange(-size[k] // 2 + 1, size[k] // 2 + 1) * stride[k] + * dilation[k]) for k in range(3)] + + # This condition check is only to make sure that our weight layout is + # compatible with `MinkowskiEngine`. + if np.prod(size) % 2 == 1: + offsets = [[x, y, z] for z in offsets[2] for y in offsets[1] + for x in offsets[0]] + else: + offsets = [[x, y, z] for x in offsets[0] for y in offsets[1] + for z in offsets[2]] + + offsets = torch.tensor(offsets, dtype=torch.int, device=device) + return offsets diff --git a/torchsparse/operators.py b/torchsparse/operators.py new file mode 100644 index 0000000..d03a379 --- /dev/null +++ b/torchsparse/operators.py @@ -0,0 +1,16 @@ +from typing import List + +import torch + +from torchsparse.tensors import SparseTensor + +__all__ = ['cat'] + + +def cat(inputs: List[SparseTensor]) -> SparseTensor: + coords, stride = inputs[0].coords, inputs[0].stride + feats = torch.cat([inputs.feats for inputs in inputs], dim=1) + outputs = SparseTensor(feats, coords, stride) + outputs.cmaps = inputs[0].cmaps + outputs.kmaps = inputs[0].kmaps + return outputs diff --git a/torchsparse/point_tensor.py b/torchsparse/point_tensor.py deleted file mode 100644 index 99f7c05..0000000 --- a/torchsparse/point_tensor.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -__all__ = ['PointTensor'] - - -class PointTensor: - def __init__(self, feat, coords, idx_query=None, weights=None): - self.F = feat - self.C = coords - self.idx_query = idx_query if idx_query is not None else {} - self.weights = weights if weights is not None else {} - self.additional_features = {} - self.additional_features['idx_query'] = {} - self.additional_features['counts'] = {} - - def cuda(self): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.cuda() - self.C = self.C.cuda() - return self - - def detach(self): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.detach() - self.C = self.C.detach() - return self - - def to(self, device, non_blocking=True): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.to(device, non_blocking=non_blocking) - self.C = self.C.to(device, non_blocking=non_blocking) - return self - - def __add__(self, other): - tensor = PointTensor(self.F + other.F, self.C, self.idx_query, - self.weights) - tensor.additional_features = self.additional_features - return tensor diff --git a/torchsparse/sparse_tensor.py b/torchsparse/sparse_tensor.py deleted file mode 100644 index 095d4a3..0000000 --- a/torchsparse/sparse_tensor.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import torch -from typing import Union, List, Tuple - -__all__ = ['SparseTensor'] - - -class SparseTensor: - def __init__( - self, - feats: Union[np.ndarray, torch.Tensor], - coords: Union[np.ndarray, torch.Tensor], - stride: Union[int, List[int], Tuple[int, int, int]] = 1) -> None: - self.F = feats - self.C = coords - if isinstance(stride, int): - self.s = (stride, stride, stride) - elif isinstance(stride, list): - self.s = tuple(stride) - else: - self.s = stride - self.coord_maps = {} - self.kernel_maps = {} - - def check(self): - if self.s not in self.coord_maps: - self.coord_maps[self.s] = self.C - - def cuda(self): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.cuda() - self.C = self.C.cuda() - return self - - def detach(self): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.detach() - self.C = self.C.detach() - return self - - def to(self, device, non_blocking=True): - assert type(self.F) == torch.Tensor - assert type(self.C) == torch.Tensor - self.F = self.F.to(device, non_blocking=non_blocking) - self.C = self.C.to(device, non_blocking=non_blocking) - return self - - def __add__(self, other): - tensor = SparseTensor(self.F + other.F, self.C, self.s) - tensor.coord_maps = self.coord_maps - tensor.kernel_maps = self.kernel_maps - return tensor diff --git a/torchsparse/src/convolution/convolution.cu b/torchsparse/src/convolution/convolution.cu deleted file mode 100644 index 3701886..0000000 --- a/torchsparse/src/convolution/convolution.cu +++ /dev/null @@ -1,283 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include "convolution_gpu.h" - -template -__global__ void gather_kernel(const int n_k, const int n_in, const int c, - const scalar_t *in_feat, scalar_t *out_feat, const int *kmap, - const bool transpose){ - - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - if(i >= n_k) return; - int in_pos = kmap[2 * i + transpose]; - if(in_pos < 0) return; - out_feat[i * c + j] = in_feat[in_pos * c + j]; -} - -template -__global__ void scatter_kernel(const int n_in, const int n_out, const int c, - const scalar_t *in_feat, scalar_t *out_feat, const int *kmap, - const bool transpose){ - - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - if(i >= n_in) return; - int out_pos = kmap[2 * i + 1 - transpose]; - if(out_pos < 0) return; - out_feat[out_pos * c + j] += in_feat[i * c + j]; -} - -// in_feat: (N, c) N=# of input points, c = input channels -// out_feat: (M, o) M=# of output points, o = output channels -// for stride=1, M=N. For stride>1, the N input coords -// are requantized to M points with grid size (stride * cur_stride) -// kernel: (k^3, c, o) for a 3D convolution of length k -// neighbor_map: (a, 2) the hash table query results from out_coords to in_coords -// where neighbor_map[:,0] is the index of the output feature -// and neighbor_map[:,1] is the index of the input feature -// neighbor_offset: (k^3) count of active weights based on neighbor_map -// with unused weights having 0 and neighbor_offset[k^3/2] -// holding w[0,0]. -void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose) -{ - if (in_feat.size(1) != kernel.size(1)) - { - throw std::invalid_argument("Input feature size and kernel size mismatch"); - } - - bool is_half = in_feat.scalar_type() == at::ScalarType::Half; - - int n_in_feats = in_feat.size(0); - int n_in_channels = in_feat.size(1); - int n_out_feats = out_feat.size(0); - int n_out_channels = out_feat.size(1);; - - int kernel_volume = kernel.size(0); - - // memory optimization - bool precompute_mid = false; - int mid_kernel = kernel_volume / 2; - int in_buffer_size = 1; - // we can precompute features for w[0,0] which avoids gather/scatter - if (kernel_volume % 2 == 1 && n_in_feats == n_out_feats) - { - precompute_mid = true; - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + mid_kernel); - in_buffer_size = std::max(in_buffer_size, - *std::max_element(neighbor_offset.data_ptr() + mid_kernel + 1, - neighbor_offset.data_ptr() + kernel_volume)); - in_buffer_size = std::max(in_buffer_size, 1); - - // (N, c) X (c, o) = (N, o) - torch::mm_out(out_feat, in_feat, kernel[mid_kernel]); - } - else - { - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + kernel_volume); - } - - auto options = - torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); - auto in_buffer = torch::zeros({in_buffer_size, n_in_channels}, options); - auto out_buffer = torch::zeros({in_buffer_size, n_out_channels}, options); - int cur_offset = 0; - // gather/gemm/scatter on each weight - for (int i = 0; i < kernel_volume; i++) - { - int n_active_feats = neighbor_offset.data_ptr()[i]; - // if there's no active features for this weight, skip it - if (n_active_feats == 0) - { - continue; - } - - // if w[0,0] was precomputed above, skip it - if ((i == mid_kernel) && precompute_mid) - { - cur_offset += 2 * n_active_feats; - continue; - } - - // in_buffer_activated (i, c) holds the dense input features from gather - // for i = n_active_feats (# of features in the activated kernel from neighbor_offset) - // out_buffer_activated (i, o) holds the dense output features to scatter - at::Tensor out_buffer_activated; - at::Tensor in_buffer_activated; - if (is_half) - { - out_buffer_activated = - torch::from_blob(out_buffer.data_ptr(), - {n_active_feats, n_out_channels}, options); - in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {n_active_feats, n_in_channels}, options); - } - else - { - out_buffer_activated = - torch::from_blob(out_buffer.data_ptr(), - {n_active_feats, n_out_channels}, options); - in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {n_active_feats, n_in_channels}, options); - } - - // gather n_active_feats dense features from N sparse input features with c feature dimensions - AT_DISPATCH_FLOATING_TYPES_AND_HALF(in_feat.type(), "ConvolutionForwardGPU", ([&] { - gather_kernel<<>>( - n_active_feats, - n_in_feats, - n_in_channels, - in_feat.data_ptr(), - in_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, - transpose); - })); - - // gemm: (i, c) X (c, o) = (i, o) - torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]); - - // scatter n_active_feats dense features into n_out_feats output features of dimension n_out_channels - AT_DISPATCH_FLOATING_TYPES_AND_HALF(in_feat.type(), "ConvolutionForwardGPU", ([&] { - scatter_kernel<<>>( - n_active_feats, - n_out_feats, - n_out_channels, - out_buffer_activated.data_ptr(), - out_feat.data_ptr(), - neighbor_map.data_ptr() + cur_offset, - transpose); - })); - - cur_offset += 2 * n_active_feats; - } -} - -void ConvolutionBackwardGPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor kernel, at::Tensor grad_kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose) -{ - grad_in_feat.resize_as_(in_feat); - grad_in_feat.zero_(); - grad_kernel.resize_as_(kernel); - grad_kernel.zero_(); - - bool is_half = in_feat.scalar_type() == at::ScalarType::Half; - int n_in_feats = in_feat.size(0); - int n_in_channels = in_feat.size(1); - int n_out_feats = grad_out_feat.size(0); - int n_out_channels = kernel.size(-1); - - int kernel_volume = kernel.size(0); - bool flag = false; - int in_buffer_size; - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + kernel_volume); - - auto options = - torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); - auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); - auto in_grad_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); - auto out_grad_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options); - - int cur_offset = 0; - for (int i = 0; i < kernel_volume; i++) - { - auto kernel_grad_buffer = grad_kernel[i]; - int n_active_feats = neighbor_offset.data_ptr()[i]; - if (flag && (i == kernel_volume / 2)) - { - cur_offset += 2 * n_active_feats; - continue; - } - - if (n_active_feats == 0) - { - continue; - } - - // Can't figure out a cleaner way to do this - at::Tensor out_grad_buffer_activated; - at::Tensor in_grad_buffer_activated; - at::Tensor in_buffer_activated; - if (is_half) - { - out_grad_buffer_activated = - torch::from_blob(out_grad_buffer.data_ptr(), - {n_active_feats, kernel.size(2)}, options); - in_grad_buffer_activated = - torch::from_blob(in_grad_buffer.data_ptr(), - {n_active_feats, in_feat.size(1)}, options); - in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {n_active_feats, in_feat.size(1)}, options); - } - else - { - out_grad_buffer_activated = - torch::from_blob(out_grad_buffer.data_ptr(), - {n_active_feats, kernel.size(2)}, options); - in_grad_buffer_activated = - torch::from_blob(in_grad_buffer.data_ptr(), - {n_active_feats, in_feat.size(1)}, options); - in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {n_active_feats, in_feat.size(1)}, options); - } - - // gather - AT_DISPATCH_FLOATING_TYPES_AND_HALF(in_feat.type(), "ConvolutionForwardGPU", ([&] { - gather_kernel<<>>( - n_active_feats, - n_out_feats, - n_out_channels, - grad_out_feat.data_ptr(), - out_grad_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, - !transpose); - })); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(in_feat.type(), "ConvolutionForwardGPU", ([&] { - gather_kernel<<>>( - n_active_feats, - n_in_feats, - n_in_channels, - in_feat.data_ptr(), - in_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, - transpose); - })); - - // gemm - torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated, torch::transpose(kernel[i], 0, 1)); - torch::mm_out(kernel_grad_buffer, torch::transpose(in_buffer_activated, 0, 1), out_grad_buffer_activated); - - // scatter - AT_DISPATCH_FLOATING_TYPES_AND_HALF(in_feat.type(), "ConvolutionForwardGPU", ([&] { - scatter_kernel<<>>( - n_active_feats, - n_in_feats, - n_in_channels, - in_grad_buffer_activated.data_ptr(), - grad_in_feat.data_ptr(), - neighbor_map.data_ptr() + cur_offset, - !transpose); - })); - - cur_offset += 2 * n_active_feats; - } -} diff --git a/torchsparse/src/convolution/convolution_cpu.cpp b/torchsparse/src/convolution/convolution_cpu.cpp deleted file mode 100644 index 8aabf15..0000000 --- a/torchsparse/src/convolution/convolution_cpu.cpp +++ /dev/null @@ -1,195 +0,0 @@ -#include -#include -#include -#include "convolution_cpu_header.h" - -void cpu_scatter_launch(const int n_in, const int n_out, const int c, - const float *in_feat, float *out_feat, - const int *kmap, const bool transpose) -{ - for (int i = 0; i < n_in; i++) - { - int out_pos = kmap[2 * i + 1 - transpose]; - if (out_pos < 0) - continue; -#pragma omp parallel for - for (int j = 0; j < c; j++) - { - out_feat[out_pos * c + j] += in_feat[i * c + j]; - } - } -} - -void cpu_gather_launch(const int n_k, const int n_in, const int c, - const float *in_feat, float *out_feat, - const int *kmap, const bool transpose) -{ - for (int i = 0; i < n_k; i++) - { - int in_pos = kmap[2 * i + transpose]; - if (in_pos < 0) - continue; -#pragma omp parallel for - for (int j = 0; j < c; j++) - { - out_feat[i * c + j] = in_feat[in_pos * c + j]; - } - } -} - -void ConvolutionForwardCPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose) -{ - - if (in_feat.size(1) != kernel.size(1)) - { - throw std::invalid_argument("Input feature size and kernel size mismatch"); - } - - int out_nrows = out_feat.size(0); - out_feat.resize_({out_nrows, kernel.size(2)}); - out_feat.zero_(); - - int kernel_volume = kernel.size(0); - int in_buffer_size = 1; - bool flag = false; - // memory optimization - if (kernel_volume % 2 && out_nrows == in_feat.size(0)) - { - flag = true; - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + kernel_volume / 2); - in_buffer_size = std::max(in_buffer_size, - *std::max_element(neighbor_offset.data_ptr() + kernel_volume / 2 + 1, - neighbor_offset.data_ptr() + kernel_volume)); - in_buffer_size = std::max(in_buffer_size, 1); - - torch::mm_out(out_feat, in_feat, kernel[kernel_volume / 2]); - } - else - { - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + kernel_volume); - } - - auto options = - torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); - auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); - auto out_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options); - int cur_offset = 0; - for (int i = 0; i < kernel_volume; i++) - { - if (flag && (i == kernel_volume / 2)) - { - cur_offset += 2 * neighbor_offset.data_ptr()[i]; - continue; - } - - if (neighbor_offset.data_ptr()[i] == 0) - { - continue; - } - - auto out_buffer_activated = - torch::from_blob(out_buffer.data_ptr(), - {neighbor_offset.data_ptr()[i], kernel.size(2)}, options); - auto in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); - // gather - cpu_gather_launch(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1), - in_feat.data_ptr(), in_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, transpose); - // GEMM - torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]); - // scatter - cpu_scatter_launch(neighbor_offset.data_ptr()[i], out_nrows, kernel.size(2), out_buffer_activated.data_ptr(), - out_feat.data_ptr(), neighbor_map.data_ptr() + cur_offset, transpose); - cur_offset += 2 * neighbor_offset.data_ptr()[i]; - } - - /* - cublasHandle_t handle = - //THCState_getCurrentBlasHandle(at::globalContext().getTHCState()); - at::cuda::getCurrentCUDABlasHandle(); - - ConvolutionForwardKernelGPU( - in_feat.data_ptr(), in_feat.size(1), out_feat.data_ptr(), - out_feat.size(1), kernel.data_ptr(), neighbor_map.data_ptr(), - neighbor_offset.data_ptr(), in_feat.size(0), out_feat.size(0), - kernel.size(0), transpose, handle, - at::cuda::getCurrentCUDAStream()); - */ -} - -void ConvolutionBackwardCPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor kernel, at::Tensor grad_kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose) -{ - - grad_in_feat.resize_as_(in_feat); - grad_in_feat.zero_(); - grad_kernel.resize_as_(kernel); - grad_kernel.zero_(); - - int kernel_volume = kernel.size(0); - bool flag = false; - int in_buffer_size; - in_buffer_size = *std::max_element(neighbor_offset.data_ptr(), - neighbor_offset.data_ptr() + kernel_volume); - - auto options = - torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device()); - auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); - auto in_grad_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options); - auto out_grad_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options); - - int cur_offset = 0; - for (int i = 0; i < kernel_volume; i++) - { - auto kernel_grad_buffer = grad_kernel[i]; - if (flag && (i == kernel_volume / 2)) - { - cur_offset += 2 * neighbor_offset.data_ptr()[i]; - continue; - } - - if (neighbor_offset.data_ptr()[i] == 0) - { - continue; - } - - auto out_grad_buffer_activated = - torch::from_blob(out_grad_buffer.data_ptr(), - {neighbor_offset.data_ptr()[i], kernel.size(2)}, options); - auto in_grad_buffer_activated = - torch::from_blob(in_grad_buffer.data_ptr(), - {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); - auto in_buffer_activated = - torch::from_blob(in_buffer.data_ptr(), - {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options); - // gather - - cpu_gather_launch(out_grad_buffer_activated.size(0), grad_out_feat.size(0), kernel.size(2), - grad_out_feat.data_ptr(), out_grad_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, !transpose); - - cpu_gather_launch(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1), - in_feat.data_ptr(), in_buffer_activated.data_ptr(), - neighbor_map.data_ptr() + cur_offset, transpose); - - // GEMM - //torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]); - torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated, torch::transpose(kernel[i], 0, 1)); - torch::mm_out(kernel_grad_buffer, torch::transpose(in_buffer_activated, 0, 1), out_grad_buffer_activated); - // scatter - //grad_kernel[i] = kernel_grad_buffer; - - cpu_scatter_launch(neighbor_offset.data_ptr()[i], in_feat.size(0), kernel.size(1), in_grad_buffer_activated.data_ptr(), - grad_in_feat.data_ptr(), neighbor_map.data_ptr() + cur_offset, !transpose); - - cur_offset += 2 * neighbor_offset.data_ptr()[i]; - } -} diff --git a/torchsparse/src/convolution/convolution_cpu_header.h b/torchsparse/src/convolution/convolution_cpu_header.h deleted file mode 100644 index be59f6d..0000000 --- a/torchsparse/src/convolution/convolution_cpu_header.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _SPARSE_CONVOLUTION_CPU -#define _SPARSE_CONVOLUTION_CPU -#include -#include -#include - -void cpu_scatter_launch(const int n_in, const int n_out, const int c, - const float *in_feat, float *out_feat, - const int *kmap, const bool transpose); - -void cpu_gather_launch(const int n_k, const int n_in, const int c, - const float *in_feat, float *out_feat, - const int *kmap, const bool transpose); - -void ConvolutionForwardCPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose); - -void ConvolutionBackwardCPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor kernel, at::Tensor grad_kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose); - -#endif \ No newline at end of file diff --git a/torchsparse/src/convolution/convolution_gpu.cu b/torchsparse/src/convolution/convolution_gpu.cu deleted file mode 100644 index fd538f3..0000000 --- a/torchsparse/src/convolution/convolution_gpu.cu +++ /dev/null @@ -1,318 +0,0 @@ -#ifndef GPU_CONVOLUTION -#define GPU_CONVOLUTION -#include "../common/gpu.cuh" -#include -#include -#include - -// Given each output, get an input feature for each corresponding kernel weight -// and add the output in place -__global__ void inplace_convolution(const int n, const float *in_feat, - const int in_nchannel, float *out_feat, - const int out_nchannel, const float *kernel, - const int *neighbor_map) { - // n = out_nchannel * out_nrows - // The kernel computes one output scalar for each output index and each output - // channel. - CUDA_KERNEL_LOOP(index, n) { - const int out_ch = index % out_nchannel; - const int out_row = index / out_nchannel; - // Pytorch tensors in C-ordering with in_nchannels x out_nchannels - float tmp = 0.0; - const float *curr_kernel = kernel + out_ch; - const float *curr_in_feat = in_feat + out_row * in_nchannel; - for (int in_ch = 0; in_ch < in_nchannel; in_ch++) { - tmp += (*curr_kernel) * (*curr_in_feat); - curr_kernel += out_nchannel; - curr_in_feat += 1; - } - // Done independently, no need for atomicAdd - out_feat[neighbor_map[out_row] * out_nchannel + out_ch] += tmp; - } -} - -/** - * Matrix multiplication (CUDA Kernel) on the device: C = A * B - * wA is A's width and wB is B's width - */ -__global__ void matmul(const float *A, const int wA, const int hA, - const float *B, const int wB, const int hB, float *C, - const int *neighbor_map, const int nmap_size, - const bool transpose) { - // Use in_feat as A and kernel as B - - // Block index - const int bx = blockIdx.x; - const int by = blockIdx.y; - - // Thread index - const int tx = threadIdx.x; - const int ty = threadIdx.y; - - // Coordinate. x is for rows, y is for columns. - const int x = BLOCK_SIZE * bx + tx; - const int y = BLOCK_SIZE * by + ty; - - // Csub is used to store the element of the block sub-matrix - // that is computed by the thread - float Csub = 0; - - // out_npoints is the output size. - // conv: out_npoints <= hA; deconv: out_npoints >= hA. - // be careful about in_row_! - const int out_row_ = y < nmap_size ? neighbor_map[2 * y + 1]: -1; - const int in_row_ = y < nmap_size ? neighbor_map[2 * y] : -1; - const int out_row = transpose ? in_row_ : out_row_; - const int in_row = transpose ? out_row_ : in_row_; - - - - // Loop over all the sub-matrices of A and B - // required to compute the block sub-matrix - for (int s = 0; s < wA; s += BLOCK_SIZE) { - // Declaration of the shared memory array As used to - // store the sub-matrix of A - __shared__ float As[BLOCK_SIZE][BLOCK_SIZE]; - - // Declaration of the shared memory array Bs used to - // store the sub-matrix of B - __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE]; - - // Load the matrices from device memory - // to shared memory; each thread loads - // one element of each matrix - - - As[ty][tx] = ((s + tx) < wA && y < hA && in_row >= 0) ? A[wA * in_row + s + tx] : 0; - Bs[ty][tx] = ((s + ty) < hB && x < wB) ? B[wB * (s + ty) + x] : 0; - - // Synchronize to make sure the matrices are loaded - __syncthreads(); - - if(in_row >= 0 && out_row >= 0){ - // Multiply the two matrices together; - // each thread computes one element - // of the block sub-matrix - #pragma unroll - for (int k = 0; k < BLOCK_SIZE; ++k) { - Csub += As[ty][k] * Bs[k][tx]; - } - } - - // Synchronize to make sure that the preceding - // computation is done before loading two new - // sub-matrices of A and B in the next iteration - __syncthreads(); - } - - // Write the block sub-matrix to device memory; - // each thread writes one element - - if (out_row >= 0 && y < hA && x < wB){ - C[wB * out_row + x] += Csub; - } - // TODO: atomicAdd(&C[wB * out_row + x], Csub); // For conv transpose, it - // might fail due to overlapping outputs -} - -/** - * Matrix multiplication (CUDA Kernel) on the device: C = A * B^T, E = D^T * A - * wA is A's width and wB is B's width - * - * +---+ - * |B^T| - * +-------+ - * | | | - * | A | C | - * | | | - * | | | - * +------------------+ - * | D^T | E | - * +----------+---+ - * - */ -__global__ void matmul2(const float *A, const int wA, const int hA, - const float *B, const int wB, const int hB, - const float *D, const int wD, const int hD, float *C, - float *E, const int *neighbor_map, const int nmap_size, - const bool transpose) { - // Use grad_out_feat as A, transposed kernel weight as B, and in_feat as D - - // Block index - const int bx = blockIdx.x; - const int by = blockIdx.y; - - // Thread index - const int tx = threadIdx.x; - const int ty = threadIdx.y; - - // Coordinate. x is for rows, y is for columns. - const int x = BLOCK_SIZE * bx + tx; - const int y = BLOCK_SIZE * by + ty; - - - const int out_row_ = y < nmap_size ? neighbor_map[2 * y + 1]: -1; - const int in_row_ = y < nmap_size ? neighbor_map[2 * y] : -1; - const int out_row = transpose ? in_row_ : out_row_; - const int in_row = transpose ? out_row_ : in_row_; - - - // Csub is used to store the element of the block sub-matrix - // that is computed by the thread - float Csub = 0; - float Esub = 0; - - // Declaration of the shared memory array As used to - // store the sub-matrix of A - __shared__ float As[BLOCK_SIZE][BLOCK_SIZE]; - - // Declaration of the shared memory array Bs used to - // store the sub-matrix of B - __shared__ float BTs[BLOCK_SIZE][BLOCK_SIZE]; - - // Declaration of the shared memory array Ds used to - // store the sub-matrix of D - __shared__ float DTs[BLOCK_SIZE][BLOCK_SIZE]; - - // For Ds = D^T[...:..., ...:...], use the transposed grid dimension for A - DTs[ty][tx] = (x < wD && y < hD && in_row >= 0) ? D[wD * in_row + x] : 0; - - // Loop over all the sub-matrices of A and B - // required to compute the block sub-matrix - for (int s = 0; s < wA; s += BLOCK_SIZE) { - // Load the matrices from device memory - // to shared memory; each thread loads - // one element of each matrix - As[ty][tx] = ((s + tx) < wA && y < hA && out_row >= 0) ? A[wA * out_row + s + tx] : 0; - - // Transposed kernel - BTs[ty][tx] = ((s + ty) < wB && x < hB) ? B[wB * x + s + ty] : 0; - - // Synchronize to make sure the matrices are loaded - __syncthreads(); - - - // Multiply the two matrices together; - // each thread computes one element - // of the block sub-matrix -#pragma unroll - for (int k = 0; k < BLOCK_SIZE; ++k) { - Csub += As[ty][k] * BTs[k][tx]; - } - - Esub = 0; - - // For Esub, reset to 0 -#pragma unroll - for (int k = 0; k < BLOCK_SIZE; ++k) { - Esub += DTs[k][ty] * As[k][tx]; - } - - // Synchronize to make sure that the preceding - // computation is done before loading two new - // sub-matrices of A and B in the next iteration - __syncthreads(); - - // For the E matrix which requires accmulation of multiple blocks, use - // atomic addition. This can be replaced with a more sophisticaed reduction - // algorithm. - if ((bx * BLOCK_SIZE + ty) < wD && (s + tx) < wA) - atomicAdd(&E[wA * (bx * BLOCK_SIZE + ty) + (s + tx)], Esub); - } - - - // Write the block sub-matrix to device memory; - // each thread writes one element - if (y < hA && x < hB && in_row >= 0) - atomicAdd(&C[hB * in_row + x], Csub); -} - -void ConvolutionForwardKernelGPU( - const float *d_in_feat, int in_nchannel, float *d_out_feat, - int out_nchannel, const float *d_kernel, - const int* neighbor_map, - const int* neighbor_offset, - const int in_npoints, - const int out_npoints, - const int n_neighbors, - const bool transpose, - cublasHandle_t cuhandle, cudaStream_t stream) { - // For the in out buffer, use the pre allocated GPU memory space as thrust - // resize gives segfault. Also initializing it with torch allows us to - // allocate memory faster and efficiently. - - - int kernel_volume=n_neighbors, n_active_in_volume, num_kernels, - neighbor_step=min(out_npoints, in_npoints); - int cur_offset = 0; - - //printf("%d %d\n", in_buffer_size, in_npoints); - - // Iterate through each spatial kernel and get indices for in_map and out_map - - for (int k = 0; k < kernel_volume; k++) { - - n_active_in_volume = in_npoints; - if (n_active_in_volume / SHARED_BLOCK_SIZE < 65536) { - dim3 threads(SHARED_BLOCK_SIZE, SHARED_BLOCK_SIZE); - dim3 grid((out_nchannel + threads.x - 1) / threads.x, - (n_active_in_volume + threads.y - 1) / threads.y); - matmul<<>>( - d_in_feat, in_nchannel, n_active_in_volume, - &d_kernel[k * in_nchannel * out_nchannel], out_nchannel, in_nchannel, - d_out_feat, &neighbor_map[cur_offset], neighbor_offset[k], transpose); - } else { - printf("call2\n"); - num_kernels = out_nchannel * n_active_in_volume; - inplace_convolution - <<>>( - num_kernels, d_in_feat, in_nchannel, d_out_feat, out_nchannel, - &d_kernel[k * in_nchannel * out_nchannel], neighbor_map + cur_offset); - } - cur_offset += 2 * neighbor_offset[k]; - - } - -} - -void ConvolutionBackwardKernelGPU( - const float *d_in_feat, float *d_grad_in_feat, int in_nchannel, - const float *d_grad_out_feat, int out_nchannel, float *d_kernel, - float *d_grad_kernel, const int * neighbor_map, - const int * neighbor_offset, - const int in_npoints, - const int out_npoints, - const int n_neighbors, - const bool transpose, - cublasHandle_t cuhandle, cudaStream_t stream) { - int kernel_volume=n_neighbors, n_active_in_volume; - int neighbor_step=min(in_npoints, out_npoints); - int cur_offset = 0; - // Assume that old kernel will never be used. - for (int k = 0; k < kernel_volume; k++) { - // acceleration by setting good n_active_in_volume. - n_active_in_volume = neighbor_offset[k]; - if (n_active_in_volume == 0) - continue; - - - dim3 threads(SHARED_BLOCK_SIZE, SHARED_BLOCK_SIZE); - dim3 grid((in_nchannel + threads.x - 1) / threads.x, - (n_active_in_volume + threads.y - 1) / threads.y); - - matmul2<<>>( - d_grad_out_feat, out_nchannel, n_active_in_volume, // A - &d_kernel[k * in_nchannel * out_nchannel], out_nchannel, - in_nchannel, // B - d_in_feat, in_nchannel, n_active_in_volume, // D - d_grad_in_feat, // C - &d_grad_kernel[k * in_nchannel * out_nchannel], // E - neighbor_map + cur_offset, neighbor_offset[k], transpose); - - cur_offset += 2 * neighbor_offset[k]; - } - - -} - -#endif diff --git a/torchsparse/src/convolution/convolution_gpu.h b/torchsparse/src/convolution/convolution_gpu.h deleted file mode 100644 index 5ca82f4..0000000 --- a/torchsparse/src/convolution/convolution_gpu.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _SPARSE_CONVOLUTION -#define _SPARSE_CONVOLUTION -#include -#include -#include -#include -#include -#include -#include -#include - -void ConvolutionForwardKernelGPU( - const float *d_in_feat, int in_nchannel, float *d_out_feat, - int out_nchannel, const float *d_kernel, - const int *neighbor_map, - const int *neighbor_offset, - const int in_npoints, - const int out_npoints, - const int n_neighbors, - const bool transpose, - cublasHandle_t cuhandle, cudaStream_t stream); - -void ConvolutionBackwardKernelGPU( - const float *d_in_feat, float *d_grad_in_feat, int in_nchannel, - const float *d_grad_out_feat, int out_nchannel, float *d_kernel, - float *d_grad_kernel, const int *neighbor_map, - const int *neighbor_offset, - const int in_npoints, - const int out_npoints, - const int n_neighbors, - const bool transpose, - cublasHandle_t cuhandle, cudaStream_t stream); - -void ConvolutionForwardGPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose); - -void ConvolutionBackwardGPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor kernel, at::Tensor grad_kernel, at::Tensor neighbor_map, - at::Tensor neighbor_offset, const bool transpose); - -#endif \ No newline at end of file diff --git a/torchsparse/src/hash/hash.cpp b/torchsparse/src/hash/hash.cpp deleted file mode 100644 index d6f33e8..0000000 --- a/torchsparse/src/hash/hash.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include -#include "hash_gpu.h" - -at::Tensor hash_forward( - const at::Tensor idx -) -{ - int N = idx.size(0); - at::Tensor out = torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); - hash_wrapper(N, idx.data_ptr(), out.data_ptr()); - return out; -} - - -at::Tensor kernel_hash_forward( - const at::Tensor idx, - const at::Tensor kernel_offset -) -{ - int N = idx.size(0); - int K = kernel_offset.size(0); - at::Tensor out = torch::zeros({K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); - kernel_hash_wrapper(N, K, idx.data_ptr(), kernel_offset.data_ptr(), out.data_ptr()); - return out; -} - - - -/* -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("hash_forward", &hash_forward, "Hashing forward (CUDA)"); - m.def("kernel_hash_forward", &kernel_hash_forward, "Kernel Hashing forward (CUDA)"); -} -*/ - diff --git a/torchsparse/src/hash/hash_cpu.cpp b/torchsparse/src/hash/hash_cpu.cpp deleted file mode 100644 index 1ba4503..0000000 --- a/torchsparse/src/hash/hash_cpu.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include -#include -#include "hash_cpu_header.h" - -void cpu_hash_wrapper(int N, const int * data, long * out){ - #pragma omp parallel for - for(int i = 0; i < N; i++){ - unsigned long long hash = 14695981039346656037UL; - for(int j = 0; j < 4; j++){ - hash ^= (unsigned int)data[4 * i + j]; - hash *= 1099511628211UL; - } - hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); - out[i] = hash; - } -} - - -void cpu_kernel_hash_wrapper(int N, int K, const int * data, const int *kernel_offset, long int * out){ - for(int k = 0; k < K; k++){ - #pragma omp parallel for - for(int i = 0; i < N; i++){ - int cur_coord[4]; - for(int j = 0; j < 3; j++){ - cur_coord[j] = data[i * 4 + j]+kernel_offset[k*3+j]; - } - cur_coord[3] = data[3]; - unsigned long long hash = 14695981039346656037UL; - for(int j = 0; j < 4; j++){ - hash ^= (unsigned int)cur_coord[j]; - hash *= 1099511628211UL; - } - hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); - out[k*N+i] = hash; - } - } -} - - -at::Tensor cpu_hash_forward( - const at::Tensor idx -) -{ - int N = idx.size(0); - at::Tensor out = torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); - cpu_hash_wrapper(N, idx.data_ptr(), out.data_ptr()); - return out; -} - - -at::Tensor cpu_kernel_hash_forward( - const at::Tensor idx, - const at::Tensor kernel_offset -) -{ - int N = idx.size(0); - int K = kernel_offset.size(0); - at::Tensor out = torch::zeros({K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); - cpu_kernel_hash_wrapper(N, K, idx.data_ptr(), kernel_offset.data_ptr(), out.data_ptr()); - return out; -} - - - -/* -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("hash_forward", &hash_forward, "Hashing forward (CUDA)"); - m.def("kernel_hash_forward", &kernel_hash_forward, "Kernel Hashing forward (CUDA)"); -} -*/ - diff --git a/torchsparse/src/hash/hash_cpu_header.h b/torchsparse/src/hash/hash_cpu_header.h deleted file mode 100644 index ece965d..0000000 --- a/torchsparse/src/hash/hash_cpu_header.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _SPARSE_HASH_CPU -#define _SPARSE_HASH_CPU -#include -#include - -//CUDA forward declarations -void cpu_hash_wrapper(int N, const int * data, long * out); -void cpu_kernel_hash_wrapper(int N, int K, const int * data, const int *kernel_offset, long int * out); -at::Tensor cpu_hash_forward( - const at::Tensor idx -); -at::Tensor cpu_kernel_hash_forward( - const at::Tensor idx, - const at::Tensor kernel_offset -); -#endif \ No newline at end of file diff --git a/torchsparse/src/hash/hash_gpu.cu b/torchsparse/src/hash/hash_gpu.cu deleted file mode 100644 index c28c80d..0000000 --- a/torchsparse/src/hash/hash_gpu.cu +++ /dev/null @@ -1,61 +0,0 @@ -#include -#include -#include - -//hashing -//input N*4 int32 tensor output N*1 int64 tensor -__global__ void hash_kernel(int N, const int *__restrict__ data, long int *__restrict__ out){ - int i = blockDim.x * blockIdx.x + threadIdx.x; - if(i < N){ - data += i * 4; - unsigned long long hash = 14695981039346656037UL; - for(int j = 0; j < 4; j++){ - hash ^= (unsigned int)data[j]; - hash *= 1099511628211UL; - } - hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); - out[i] = hash; - } -} - - -//kernel hashing: given data D and offset map K, generate D x K -//input N*4 int32 tensor, |K|*3 int32 tensor, output |K|*N int64 tensor -__global__ void kernel_hash_kernel(int N, int K, const int *__restrict__ data, const int * __restrict__ kernel_offset, long int *__restrict__ out){ - - extern __shared__ int kernel_offset_local[]; - - for(int i = 0; i < K * 3; i++){ - kernel_offset_local[i] = kernel_offset[i]; - } - __syncthreads(); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int k = idx % K; - int i = idx / K; - int cur_coord[4]; - if(i < N){ - data += i * 4; - for(int j = 0; j < 3; j++){ - cur_coord[j] = data[j]+kernel_offset[k*3+j]; - } - cur_coord[3] = data[3]; - unsigned long long hash = 14695981039346656037UL; - for(int j = 0; j < 4; j++){ - hash ^= (unsigned int)cur_coord[j]; - hash *= 1099511628211UL; - } - hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); - out[k*N+i] = hash; - } -} - - -void kernel_hash_wrapper(int N, int K, const int * data, const int *kernel_offset, long int * out){ - kernel_hash_kernel<<>>(N, K, data, kernel_offset, out); -} - - -void hash_wrapper(int N, const int * data, long int * out){ - hash_kernel<<>>(N, data, out); -} diff --git a/torchsparse/src/hash/hash_gpu.h b/torchsparse/src/hash/hash_gpu.h deleted file mode 100644 index 12d35d3..0000000 --- a/torchsparse/src/hash/hash_gpu.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _SPARSE_HASH -#define _SPARSE_HASH -#include -#include - -//CUDA forward declarations -void hash_wrapper(int N, const int * data, long * out); -void kernel_hash_wrapper(int N, int K, const int * data, const int *kernel_offset, long int * out); -at::Tensor hash_forward( - const at::Tensor idx -); -at::Tensor kernel_hash_forward( - const at::Tensor idx, - const at::Tensor kernel_offset -); -#endif \ No newline at end of file diff --git a/torchsparse/src/hashmap/hashmap.cu b/torchsparse/src/hashmap/hashmap.cu deleted file mode 100644 index 0b2a4f4..0000000 --- a/torchsparse/src/hashmap/hashmap.cu +++ /dev/null @@ -1,245 +0,0 @@ -#ifndef GPU_HASHMAP -#define GPU_HASHMAP -#include "hashmap.cuh" -#include -#include -#include - -typedef unsigned long long int VTYPE; - -__global__ void -cuckooBucketKernel_Multi(VTYPE * const key_buf, VTYPE * const val_buf, - const int size, const VTYPE * const keys, const VTYPE * const vals, - const int n, int * const counters, const int num_buckets) { - - // Get thread index. - int idx = threadIdx.x + blockIdx.x * blockDim.x; - - // Only threads within range are active. - if (idx < n) { - - // Do 1st-level hashing to get bucket id, then do atomic add to get index inside the bucket. - VTYPE key = keys[idx]; - VTYPE val = vals[idx]; - - int bucket_num = do_1st_hash(key, num_buckets); - int bucket_ofs = atomicAdd(&counters[bucket_num], 1); - - // Directly write the key into the table buffer. - if (bucket_ofs >= BUCKET_SIZE) - printf("%d/%d ERROR: bucket overflow! (n=%d, bucket_num=%d/%d, key=%d)\n", bucket_ofs, BUCKET_SIZE, n, bucket_num, num_buckets, key); - else { - //printf("good!\n"); - key_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = key; - val_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = val; - } - - } -} - - -__global__ void -cuckooInsertKernel_Multi(VTYPE * const key, VTYPE * const val, const VTYPE * const key_buf, const VTYPE * const val_buf, - const int size, const FuncConfig * const hash_func_configs, const int num_funcs, - const int * const counters, const int num_buckets, const int evict_bound, const int pos_width, - int * const rehash_requests) { - - // Create local cuckoo table in shared memory. Size passed in as the third kernel parameter. - extern __shared__ VTYPE local_key[]; - for (int i = 0; i < num_funcs; ++i){ - local_key[i * BUCKET_SIZE + threadIdx.x] = EMPTY_CELL; - } - - // might be useful - __syncthreads(); - - // Get thread index. - int idx = threadIdx.x + blockIdx.x * blockDim.x; - VTYPE cur_idx = idx; - - // Only threads within local bucket range are active. - if (threadIdx.x < counters[blockIdx.x]) { - // Set initial conditions. - VTYPE cur_key = key_buf[cur_idx]; - int cur_func = 0; - int evict_count = 0; - - // Start the test-kick-and-reinsert loops. - do { - int pos = do_2nd_hash(cur_key, hash_func_configs, cur_func, BUCKET_SIZE); - - VTYPE new_data = make_data(cur_idx+1, cur_func, pos_width); - - VTYPE old_idx = atomicExch(&local_key[cur_func * BUCKET_SIZE + pos], new_data); - /* - if(new_data == EMPTY_CELL){ - printf("WARNING, new data is 0, this should not happen! (%d %llu %llu)\n", pos, cur_idx+1, cur_key); - } - */ - - if (old_idx != EMPTY_CELL) { - cur_idx = fetch_val(old_idx, pos_width)-1; - //fixme - /* - if((int)cur_idx >= size || (int)cur_idx < 0){ - printf("WARNING, there's an overflow %d %d.\n", (int)cur_idx, size); - break; - } - */ - // potential overflow here. It seems that cur_idx < 0 is possible! - cur_key = key_buf[cur_idx]; - cur_func = (fetch_func(old_idx, pos_width) + 1) % num_funcs; - evict_count++; - } - - else{ - break; - } - - } while (evict_count < num_funcs * evict_bound); - - - // If exceeds eviction bound, then needs rehashing. - if (evict_count >= num_funcs * evict_bound) - atomicAdd(rehash_requests, 1); - } - - // Every thread write its responsible local slot into the global data table. - __syncthreads(); - for (int i = 0; i < num_funcs; ++i){ - VTYPE cur_idx = local_key[i * BUCKET_SIZE + threadIdx.x]; - if(cur_idx == EMPTY_CELL) { - continue; - } - int cur_func = fetch_func(cur_idx, pos_width); - cur_idx = fetch_val(cur_idx, pos_width)-1; - /* - if(cur_idx < 0 || cur_idx >= size){ - printf("WARNING %d\n", cur_idx); - } - */ - key[i * size + idx] = key_buf[cur_idx];//make_data(key_buf[cur_idx], cur_func, pos_width); - val[i * size + idx] = val_buf[cur_idx]; - } - -} - - - -__global__ void -cuckooLookupKernel_Multi(const VTYPE * const keys, VTYPE * const results, const int n, - const VTYPE * const all_keys, const VTYPE * const all_vals, const int size, - const FuncConfig * const hash_func_configs, const int num_funcs, - const int num_buckets, const int pos_width) { - - // Get thread index. - int idx = threadIdx.x + blockIdx.x * blockDim.x; - - // Only threads within range are active. - if (idx < n) { - VTYPE key = keys[idx]; - //int bucket_num = key % num_buckets; - int bucket_num = do_1st_hash(key, num_buckets); - for (int i = 0; i < num_funcs; ++i) { - int pos = bucket_num * BUCKET_SIZE + do_2nd_hash(key, hash_func_configs, i, BUCKET_SIZE); - //int pos = bucket_num * BUCKET_SIZE + do_2nd_hash(key, hash_func_configs, i, BUCKET_SIZE-1); - if(all_keys[i * size + pos] == key){ - //if (fetch_val(all_keys[i * size + pos], pos_width) == key) { - results[idx] = all_vals[i * size + pos] + 1; - return; - } - } - - // fixme: should be a value that will not be encountered. - results[idx] = EMPTY_CELL; - } -} - - -void -CuckooHashTableCuda_Multi::lookup_vals(const VTYPE * const keys, VTYPE *d_key, VTYPE *d_val, VTYPE * const results, const int n) { - // Launch the lookup kernel. - cuckooLookupKernel_Multi<<>>(keys, results, - n, d_key, d_val, _size, - _d_hash_func_configs, _num_funcs, - _num_buckets, _pos_width); - -} - - -int -CuckooHashTableCuda_Multi::insert_vals(const VTYPE * const keys, const VTYPE * const vals, VTYPE * d_key_buf, VTYPE * d_val_buf, VTYPE * d_key, VTYPE * d_val, const int n) { - - // - // Phase 1: Distribute keys into buckets. - // - - // Allocate GPU memory. - - int *d_counters = NULL; - - - cudaMalloc((void **) &d_counters, _num_buckets * sizeof(int)); - - cudaMemset(d_counters, 0, _num_buckets * sizeof(int)); - - // Invoke bucket kernel. - cuckooBucketKernel_Multi<<>>(d_key_buf, d_val_buf, - _size, keys, vals, - n, d_counters, _num_buckets); - - // - // Phase 2: Local cuckoo hashing. - // - - // Allocate GPU memory. - - cudaDeviceSynchronize(); - int *d_rehash_requests=NULL; - - cudaMalloc((void **) &d_rehash_requests, sizeof(int)); - - // Copy values onto GPU memory. - cudaMemcpy(_d_hash_func_configs, _hash_func_configs, _num_funcs * sizeof(FuncConfig), - cudaMemcpyHostToDevice); - - // Invoke insert kernel. Passes shared memory table size by the third argument. - // Loops until no rehashing needed. - - int rehash_count = 0; - do { - int rehash_requests = 0; - cudaMemset(d_rehash_requests, 0, sizeof(int)); - cuckooInsertKernel_Multi<<>>(d_key, d_val, d_key_buf, - d_val_buf, _size, _d_hash_func_configs, - _num_funcs, d_counters, _num_buckets, - _evict_bound, _pos_width, d_rehash_requests); - cudaMemcpy(&rehash_requests, d_rehash_requests, sizeof(int), cudaMemcpyDeviceToHost); - - if (rehash_requests == 0) - break; - else { - //printf("rehash %d %d\n", rehash_count, rehash_requests); - rehash_count++; - gen_hash_funcs(); - cudaMemcpy(_d_hash_func_configs, _hash_func_configs, _num_funcs * sizeof(FuncConfig), - cudaMemcpyHostToDevice); - } - } while (rehash_count < MAX_DEPTH); - - cudaDeviceSynchronize(); - - // Free GPU resources. - - if(d_counters!=NULL) cudaFree(d_counters); - if(d_rehash_requests!=NULL) cudaFree(d_rehash_requests); - - - //printf("%d\n", rehash_count); - return (rehash_count < MAX_DEPTH) ? rehash_count : ERR_DEPTH; -} - - -#endif - diff --git a/torchsparse/src/hashmap/hashmap.cuh b/torchsparse/src/hashmap/hashmap.cuh deleted file mode 100644 index 54624ab..0000000 --- a/torchsparse/src/hashmap/hashmap.cuh +++ /dev/null @@ -1,178 +0,0 @@ -#ifndef _CUCKOO_CUDA_MULTI_HPP_ -#define _CUCKOO_CUDA_MULTI_HPP_ -#include -#include -#include -#include -#include -#include -#include "cuda_runtime.h" - -/** Reserved value for indicating "empty". */ -#define EMPTY_CELL (0) -/** Max rehashing depth, and error depth. */ -#define MAX_DEPTH (100) -#define ERR_DEPTH (-1) -/** CUDA naive thread block size. */ -#define BLOCK_SIZE (256) -/** CUDA multi-level thread block size = bucket size. */ -#define BUCKET_SIZE (512) - -typedef unsigned long long int VTYPE; - -/** Struct of a hash function config. */ -typedef struct { - int rv; // Randomized XOR value. - int ss; // Randomized shift filter start position. -} FuncConfig; - - - -/** Hard code hash functions and all inline helper functions for CUDA kernels' use. */ -inline __device__ int -do_1st_hash(const VTYPE val, const int num_buckets) { - /* - if(num_buckets % 2) return val % num_buckets; - return val % (num_buckets-1); // Simply using modulo as 1st-level hashing. - */ - - //return val % (num_buckets-1); - return val % num_buckets; -} - -inline __device__ int -do_2nd_hash(const VTYPE val, const FuncConfig * const hash_func_configs, const int func_idx, - const int size) { - FuncConfig fc = hash_func_configs[func_idx]; - return ((val ^ fc.rv) >> fc.ss) % size; // XOR function as 2nd-level hashing. -} - -// trying to ignore EMPTY_CELL by adding 1 at make_data. -inline __device__ VTYPE -fetch_val(const VTYPE data, const int pos_width) { - return data >> pos_width; -} - -inline __device__ int -fetch_func(const VTYPE data, const int pos_width) { - return data & ((0x1 << pos_width) - 1); -} - -inline __device__ VTYPE -make_data(const VTYPE val, const int func, const int pos_width) { - return (val << pos_width) ^ func; // VTYPE CANNOT handle signed values currently! -} - - -/** - * - * Cuckoo hash table generic class. - * - */ -class CuckooHashTableCuda_Multi { - -private: - - /** Input parameters. */ - const int _size; - const int _evict_bound; - const int _num_funcs; - const int _pos_width; - const int _num_buckets; - - - //int *_d_out_coords; - //VTYPE *_d_results; - FuncConfig *_d_hash_func_configs; - - - /** Cuckoo hash function set. */ - FuncConfig *_hash_func_configs; - - /** Private operations. */ - void gen_hash_funcs() { - - // Calculate bit width of value range and table size. - int val_width = 8 * sizeof(VTYPE) - ceil(log2((double) _num_funcs)); - int bucket_width = ceil(log2((double) _num_buckets)); - int size_width = ceil(log2((double) BUCKET_SIZE)); - // Generate randomized configurations. - for (int i = 0; i < _num_funcs; ++i) { // At index 0 is a dummy function. - if (val_width - bucket_width <= size_width) - _hash_func_configs[i] = {rand(), 0}; - else{ - _hash_func_configs[i] = {rand(), rand() % (val_width - bucket_width - size_width + 1) - + bucket_width}; - } - } - }; - - /** Inline helper functions. */ - inline VTYPE fetch_val(const VTYPE data) { - return data >> _pos_width; - } - inline int fetch_func(const VTYPE data) { - return data & ((0x1 << _pos_width) - 1); - } - - - -public: - - /** Constructor & Destructor. */ - CuckooHashTableCuda_Multi(const int size, const int evict_bound, const int num_funcs) - : _size(size), _evict_bound(evict_bound), _num_funcs(num_funcs), - _pos_width(ceil(log2((double) _num_funcs))), - _num_buckets(ceil((double) _size / BUCKET_SIZE)) { - - srand(time(NULL)); - //_d_out_coords = NULL; - //_d_results = NULL; - _d_hash_func_configs = NULL; - _hash_func_configs = NULL; - _hash_func_configs = new FuncConfig[num_funcs]; - - // Generate initial hash function configs. - - gen_hash_funcs(); - - - cudaMalloc((void **) &_d_hash_func_configs, _num_funcs * sizeof(FuncConfig)); - cudaMemcpy(_d_hash_func_configs, _hash_func_configs, _num_funcs * sizeof(FuncConfig), - cudaMemcpyHostToDevice); - - - }; - ~CuckooHashTableCuda_Multi() { - if(_hash_func_configs!=NULL) delete[] _hash_func_configs; - - if(_d_hash_func_configs!=NULL) cudaFree(_d_hash_func_configs); - }; - - /** Supported operations. */ - int insert_vals(const VTYPE * const keys, const VTYPE * const vals, VTYPE * d_key_buf, VTYPE * d_val_buf, VTYPE * d_key, VTYPE * d_val, const int n); - // delete is not supported. - void lookup_vals(const VTYPE * const keys, VTYPE * const results, VTYPE *d_key, VTYPE *d_val, const int n); -}; - - - -__global__ void -cuckooBucketKernel_Multi(VTYPE * const key_buf, VTYPE * const val_buf, - const int size, const VTYPE * const keys, const VTYPE * const vals, - const int n, int * const counters, const int num_buckets); - - -__global__ void -cuckooInsertKernel_Multi(VTYPE * const key, VTYPE * const val, const VTYPE * const key_buf, const VTYPE * const val_buf, - const int size, const FuncConfig * const hash_func_configs, const int num_funcs, - const int * const counters, const int num_buckets, const int evict_bound, const int pos_width, - int * const rehash_requests); - -__global__ void -cuckooLookupKernel_Multi(const VTYPE * const keys, VTYPE * const results, const int n, - const VTYPE * const all_keys, const VTYPE * const all_vals, const int size, - const FuncConfig * const hash_func_configs, const int num_funcs, - const int num_buckets, const int pos_width); - -#endif diff --git a/torchsparse/src/hashmap/hashmap_cpu.cpp b/torchsparse/src/hashmap/hashmap_cpu.cpp deleted file mode 100644 index 5e6dff2..0000000 --- a/torchsparse/src/hashmap/hashmap_cpu.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include -#include "hashmap_cpu_header.hpp" -#include -#include -#include - - - - -void -HashTableCPU::lookup_vals(const int64_t * const keys, int64_t * const results, const int n) { - #pragma omp parallel for - for(int idx = 0; idx< n; idx++) { - int64_t key = keys[idx]; - //int bucket_num = key % num_buckets; - google::dense_hash_map::iterator iter = hashmap.find(key); - if(iter != hashmap.end()){ - results[idx] = iter->second; - } - else{ - results[idx] = 0; - } - } - -} - - - -void -HashTableCPU::insert_vals(const int64_t * const keys, const int64_t * const vals, const int n) { - for(int i = 0; i < 10; i++){ - printf("%d, %d, %d, %d\n", i, i -#include -#include -#include -#include -#include -#include - - - - -class HashTableCPU { - -private: - google::dense_hash_map hashmap; - -public: - HashTableCPU(){ - //hashmap.set_empty_key(0); - } - ~HashTableCPU(){} - void insert_vals(const int64_t * const keys, const int64_t * const vals, const int n); - void lookup_vals(const int64_t * const keys, int64_t * const results, const int n); - -}; - -#endif - diff --git a/torchsparse/src/interpolation/devox_cpu.cpp b/torchsparse/src/interpolation/devox_cpu.cpp deleted file mode 100644 index 4bf18c1..0000000 --- a/torchsparse/src/interpolation/devox_cpu.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include -#include -#include "devox_cpu_header.h" - -//make sure indices is int type -//feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) -at::Tensor cpu_devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight) -{ - int c = feat.size(1); - int N = indices.size(0); - - at::Tensor out = torch::zeros({N, c}, at::device(feat.device()).dtype(at::ScalarType::Float)); -#pragma omp parallel for - for (int i = 0; i < N; i++) - { - int *indices_ = indices.data_ptr() + i * 8; - float *weight_ = weight.data_ptr() + i * 8; - for (int j = 0; j < c; j++) - { - float *feat_ = feat.data_ptr() + j; - float cur_feat; - for (int k = 0; k < 8; k++) - { - cur_feat = (indices_[k] >= 0) ? feat_[indices_[k] * c] : 0; - *(out.data_ptr() + i * c + j) += weight_[k] * cur_feat; - } - } - } - return out; -} - -//top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: (b,c,s), s=r^3 -at::Tensor cpu_devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n) -{ - int c = top_grad.size(1); - int N = top_grad.size(0); - at::Tensor bottom_grad = torch::zeros({n, c}, at::device(top_grad.device()).dtype(at::ScalarType::Float)); - - for (int i = 0; i < N; i++) - { - int *indices_ = indices.data_ptr() + i * 8; - float *weight_ = weight.data_ptr() + i * 8; -#pragma omp parallel for - for (int j = 0; j < c; j++) - { - float *top_grad_ = top_grad.data_ptr() + j; - float cur_top_grad; - for (int k = 0; k < 8; k++) - { - cur_top_grad = (indices_[k] >= 0) ? top_grad_[indices_[k] * c] : 0; - *(bottom_grad.data_ptr() + indices_[k] * c + j) += weight_[k] * cur_top_grad; - } - } - } - - return bottom_grad; -} diff --git a/torchsparse/src/interpolation/devox_cpu_header.h b/torchsparse/src/interpolation/devox_cpu_header.h deleted file mode 100644 index 99d889c..0000000 --- a/torchsparse/src/interpolation/devox_cpu_header.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _SPARSE_DEVOXELIZE_CPU -#define _SPARSE_DEVOXELIZE_CPU -#include -#include - -at::Tensor cpu_devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight); -at::Tensor cpu_devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n); - -#endif \ No newline at end of file diff --git a/torchsparse/src/interpolation/devox_deterministic.cpp b/torchsparse/src/interpolation/devox_deterministic.cpp deleted file mode 100644 index ec785bc..0000000 --- a/torchsparse/src/interpolation/devox_deterministic.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include -#include -#include "devox_gpu.h" - -//make sure indices is int type -//feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) -at::Tensor deterministic_devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight) -{ - int c = feat.size(1); - int N = indices.size(0); - - at::Tensor out = torch::zeros({N, c}, at::device(feat.device()).dtype(at::ScalarType::Float)); - deterministic_devoxelize_wrapper(N, c, indices.data_ptr(), weight.data_ptr(), feat.data_ptr(), out.data_ptr()); - return out; -} - -//top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: (b,c,s), s=r^3 -at::Tensor deterministic_devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n) -{ - int c = top_grad.size(1); - int N = top_grad.size(0); - at::Tensor bottom_grad_int = torch::zeros({n, c}, at::device(top_grad.device()).dtype(at::ScalarType::Int)); - deterministic_devoxelize_grad_wrapper(N, n, c, indices.data_ptr(), weight.data_ptr(), top_grad.data_ptr(), bottom_grad_int.data_ptr()); - - at::Tensor bottom_grad = bottom_grad_int.to(at::ScalarType::Double); - bottom_grad /= 1e10; - return bottom_grad.to(at::ScalarType::Float); -} diff --git a/torchsparse/src/interpolation/devox_deterministic_gpu.cu b/torchsparse/src/interpolation/devox_deterministic_gpu.cu deleted file mode 100644 index e533472..0000000 --- a/torchsparse/src/interpolation/devox_deterministic_gpu.cu +++ /dev/null @@ -1,61 +0,0 @@ -#include -#include -#include -#include "../common/gpu.cuh" - -//input features (n, c), indices (N, 8), weight (N, 8) -> output features (N, c) -__global__ void deterministic_devoxelize_kernel(int N, int c, const int *__restrict__ indices, const float *__restrict__ weight, const float *__restrict__ feat, float *__restrict__ out){ - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - - if(i < N){ - const int* indices_ = indices + 8 * i; - const float *weight_ = weight + 8 * i; - const float *feat_ = feat + j; - - float cur_feat; - for(int k = 0; k < 8; k++){ - cur_feat = (indices_[k] >= 0) ? feat_[indices_[k] * c] : 0; - out[i * c + j] += weight_[k] * cur_feat; - } - - } - -} - -//input weight (N, 8), indices (N, 8), top_grad (N, c) -> bottom grad (n, c) -__global__ void deterministic_devoxelize_grad_kernel(int N, int n, int c, const int *__restrict__ indices, const float *__restrict__ weight, const float *__restrict__ top_grad, int *__restrict__ bottom_grad){ - - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - - - if(i < N){ - const int* indices_ = indices + 8 * i; - const float *weight_ = weight + 8 * i; - - float cur_top_grad = top_grad[i * c + j]; - - - #pragma unroll - for(int k = 0; k < 8; k++){ - float grad_float = weight_[k]*cur_top_grad; - int64_t grad_int = (int64_t)round(grad_float * 1e10); - if(indices_[k] >= 0) atomicAdd(&bottom_grad[indices_[k]*c+j], (int)grad_int); - } - - - } -} - - - -void deterministic_devoxelize_wrapper(int N, int c, const int * indices, const float * weight, const float * feat, float * out){ - deterministic_devoxelize_kernel<<>>(N, c, indices, weight, feat, out); -} - -void deterministic_devoxelize_grad_wrapper(int N, int n, int c, const int *indices, const float * weight, const float * top_grad, int * bottom_grad){ - deterministic_devoxelize_grad_kernel<<>>(N, n, c, indices, weight, top_grad, bottom_grad); -} diff --git a/torchsparse/src/interpolation/devox_gpu.cu b/torchsparse/src/interpolation/devox_gpu.cu deleted file mode 100644 index 1f6286b..0000000 --- a/torchsparse/src/interpolation/devox_gpu.cu +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include -#include -#include - - -//input features (n, c), indices (N, 8), weight (N, 8) -> output features (N, c) -template -__global__ void devoxelize_kernel(int N, int c, const int *__restrict__ indices, const scalar_t *__restrict__ weight, const scalar_t *__restrict__ feat, scalar_t *__restrict__ out){ - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - - if(i < N){ - const int* indices_ = indices + 8 * i; - const scalar_t *weight_ = weight + 8 * i; - const scalar_t *feat_ = feat + j; - - scalar_t cur_feat; - for(int k = 0; k < 8; k++){ - cur_feat = 0; - if (indices_[k] >= 0) - cur_feat = feat_[indices_[k] * c]; - - out[i * c + j] += weight_[k] * cur_feat; - } - - } - -} - -//input weight (N, 8), indices (N, 8), top_grad (N, c) -> bottom grad (n, c) -template -__global__ void devoxelize_grad_kernel(int N, int n, int c, const int *__restrict__ indices, const scalar_t *__restrict__ weight, const scalar_t *__restrict__ top_grad, scalar_t *__restrict__ bottom_grad){ - - int index = blockIdx.x * blockDim.x + threadIdx.x; - int i = index / c; - int j = index % c; - - - if(i < N){ - const int* indices_ = indices + 8 * i; - const scalar_t *weight_ = weight + 8 * i; - - scalar_t cur_top_grad = top_grad[i * c + j]; - - #pragma unroll - for(int k = 0; k < 8; k++){ - if(indices_[k] >= 0) atomicAdd(&bottom_grad[indices_[k]*c+j], weight_[k]*cur_top_grad); - } - - } -} - - -//make sure indices is int type -//feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) -at::Tensor devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight -) -{ - int c = feat.size(1); - int N = indices.size(0); - - at::Tensor out = torch::zeros({N, c}, at::device(feat.device()).dtype(feat.dtype())); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(feat.type(), "devoxelize_forward", ([&] - { devoxelize_kernel<<>>( - N, - c, - indices.data_ptr(), - weight.data_ptr(), - feat.data_ptr(), - out.data_ptr()); - })); - - return out; -} - - -//top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: (b,c,s), s=r^3 -at::Tensor devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n -) -{ - int c = top_grad.size(1); - int N = top_grad.size(0); - at::Tensor bottom_grad = torch::zeros({n, c}, at::device(top_grad.device()).dtype(top_grad.dtype())); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(top_grad.type(), "devoxelize_backward", ([&] - { devoxelize_grad_kernel<<>>( - N, - n, - c, - indices.data_ptr(), - weight.data_ptr(), - top_grad.data_ptr(), - bottom_grad.data_ptr()); - })); - - return bottom_grad; -} diff --git a/torchsparse/src/interpolation/devox_gpu.h b/torchsparse/src/interpolation/devox_gpu.h deleted file mode 100644 index a522052..0000000 --- a/torchsparse/src/interpolation/devox_gpu.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _SPARSE_DEVOXELIZE -#define _SPARSE_DEVOXELIZE -#include -#include - -//CUDA forward declarations -void deterministic_devoxelize_wrapper(int N, int c, const int *indices, const float *weight, const float *feat, float *out); -void deterministic_devoxelize_grad_wrapper(int N, int n, int c, const int *indices, const float *weight, const float *top_grad, int *bottom_grad); -at::Tensor devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight); -at::Tensor devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n); -at::Tensor deterministic_devoxelize_forward( - const at::Tensor feat, - const at::Tensor indices, - const at::Tensor weight); -at::Tensor deterministic_devoxelize_backward( - const at::Tensor top_grad, - const at::Tensor indices, - const at::Tensor weight, - int n); -#endif \ No newline at end of file diff --git a/torchsparse/src/others/count.cpp b/torchsparse/src/others/count.cpp deleted file mode 100644 index bf12279..0000000 --- a/torchsparse/src/others/count.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include -#include -#include "count_gpu.h" - - -//make sure indices is int type -//feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n) (preprocessed indices) -at::Tensor count_forward( - const at::Tensor idx, - const int s -) -{ - //return group_point_forward_gpu(points, indices); - int N = idx.size(0); - at::Tensor out = torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); - count_wrapper(N, idx.data_ptr(), out.data_ptr()); - return out; -} - - -/* -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("count_forward", &count_forward, "Counting forward (CUDA)"); -} -*/ - diff --git a/torchsparse/src/others/count_cpu.cpp b/torchsparse/src/others/count_cpu.cpp deleted file mode 100644 index 93e20b6..0000000 --- a/torchsparse/src/others/count_cpu.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include -#include "count_cpu_header.h" - -at::Tensor cpu_count_forward( - const at::Tensor idx, - const int s) -{ - //return group_point_forward_gpu(points, indices); - int N = idx.size(0); - at::Tensor out = torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); - //count_wrapper(N, idx.data_ptr(), out.data_ptr()); - int *idx_ = idx.data_ptr(); - int *out_ = out.data_ptr(); -#pragma omp parallel for - for (int i = 0; i < N; i++) - { - int cur_idx = idx_[i]; - if (cur_idx < 0) - continue; -#pragma omp atomic - out_[cur_idx]++; - } - return out; -} diff --git a/torchsparse/src/others/count_cpu_header.h b/torchsparse/src/others/count_cpu_header.h deleted file mode 100644 index cf69f86..0000000 --- a/torchsparse/src/others/count_cpu_header.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _SPARSE_COUNT_CPU -#define _SPARSE_COUNT_CPU -#include - -at::Tensor cpu_count_forward( - const at::Tensor idx, - const int s -); - -#endif \ No newline at end of file diff --git a/torchsparse/src/others/count_gpu.cu b/torchsparse/src/others/count_gpu.cu deleted file mode 100644 index d748277..0000000 --- a/torchsparse/src/others/count_gpu.cu +++ /dev/null @@ -1,17 +0,0 @@ -#include -#include -#include - -//counting -//input N*3 int32 tensor output N*1 int64 tensor -__global__ void count_kernel(int N, const int *__restrict__ data, int *__restrict__ out){ - int i = blockDim.x * blockIdx.x + threadIdx.x; - if(i < N){ - if(data[i] >= 0) atomicAdd(&out[data[i]], 1); - } -} - - -void count_wrapper(int N, const int * data, int * out){ - count_kernel<<>>(N, data, out); -} diff --git a/torchsparse/src/others/count_gpu.h b/torchsparse/src/others/count_gpu.h deleted file mode 100644 index a97cbd9..0000000 --- a/torchsparse/src/others/count_gpu.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _SPARSE_COUNT -#define _SPARSE_COUNT -#include -#include - -//CUDA forward declarations -void count_wrapper(int N, const int * data, int * out); -at::Tensor count_forward( - const at::Tensor idx, - const int s -); -#endif \ No newline at end of file diff --git a/torchsparse/src/others/insertion_cpu.cpp b/torchsparse/src/others/insertion_cpu.cpp deleted file mode 100644 index 25ffbd8..0000000 --- a/torchsparse/src/others/insertion_cpu.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include -#include "insertion_cpu_header.h" - -at::Tensor cpu_insertion_forward( - const at::Tensor inputs, - const at::Tensor idx, - const at::Tensor counts) -{ - //return group_point_forward_gpu(points, indices); - - int N = inputs.size(0); - int c = inputs.size(1); - int N1 = counts.size(0); - at::Tensor out = torch::zeros({N1, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); - for (int i = 0; i < N; i++) - { - int pos = *(idx.data_ptr() + i); - if (*(counts.data_ptr() + pos) == 0) - continue; -#pragma omp parallel for - for (int j = 0; j < c; j++) - { - *(out.data_ptr() + pos * c + j) += *(inputs.data_ptr() + i * c + j) / (float)(*(counts.data_ptr() + pos)); - } - } - return out; -} - -at::Tensor cpu_insertion_backward( - const at::Tensor top_grad, - const at::Tensor idx, - const at::Tensor counts, - const int N) -{ - //return group_point_forward_gpu(points, indices); - - int c = top_grad.size(1); - //int N1 = counts.size(0); - at::Tensor bottom_grad = torch::zeros({N, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); - for (int i = 0; i < N; i++) - { - if (*(counts.data_ptr() + *(idx.data_ptr() + i)) == 0) - continue; -#pragma omp parallel for - for (int j = 0; j < c; j++) - { - *(bottom_grad.data_ptr() + i * c + j) = *(top_grad.data_ptr() + *(idx.data_ptr() + i) * c + j) / (float)(*(counts.data_ptr() + *(idx.data_ptr() + i))); - } - } - return bottom_grad; -} diff --git a/torchsparse/src/others/insertion_cpu_header.h b/torchsparse/src/others/insertion_cpu_header.h deleted file mode 100644 index 6450550..0000000 --- a/torchsparse/src/others/insertion_cpu_header.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _SPARSE_INSERT_CPU -#define _SPARSE_INSERT_CPU -#include -#include - - -at::Tensor cpu_insertion_forward( - const at::Tensor inputs, - const at::Tensor idx, - const at::Tensor counts -); -at::Tensor cpu_insertion_backward( - const at::Tensor top_grad, - const at::Tensor idx, - const at::Tensor counts, - const int N -); -#endif \ No newline at end of file diff --git a/torchsparse/src/others/insertion_gpu.cu b/torchsparse/src/others/insertion_gpu.cu deleted file mode 100644 index 5e8dc97..0000000 --- a/torchsparse/src/others/insertion_gpu.cu +++ /dev/null @@ -1,85 +0,0 @@ -#include -#include -#include -#include -#include - -//hashing -//input N*F float tensor, pointer to output N'*F int64 tensor, N*1 count tensor, N*1 index tensor -template -__global__ void insertion_kernel(int N, int c, int s, const scalar_t *__restrict__ data, const int *__restrict__ idx, const int *__restrict__ counts, scalar_t *__restrict__ out){ - int index = blockDim.x * blockIdx.x + threadIdx.x; - int i = index / c; - int j = index % c; - if(i < N){ - int pos = idx[i]; - if(pos < 0 || pos >= s || counts[pos] == 0) return; - atomicAdd(&out[pos*c+j], data[i*c+j] / float(counts[pos])); - } -} - -template -__global__ void insertion_grad_kernel(int N, int c, int s, const scalar_t *__restrict__ top_grad, const int *__restrict__ idx, const int *__restrict__ counts, scalar_t *__restrict__ bottom_grad){ - int index = blockDim.x * blockIdx.x + threadIdx.x; - int i = index / c; - int j = index % c; - if(i < N){ - int pos = idx[i]; - if(pos < 0 || pos >= s || counts[pos]==0) return; - atomicAdd(&bottom_grad[i*c+j], top_grad[pos*c+j] / float(counts[pos])); - } -} - - -at::Tensor insertion_forward( - const at::Tensor inputs, - const at::Tensor idx, - const at::Tensor counts -) -{ - int N = inputs.size(0); - int c = inputs.size(1); - int N1 = counts.size(0); - - at::Tensor out = torch::zeros({N1, c}, at::device(idx.device()).dtype(inputs.dtype())); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs.type(), "insertion_forward", ([&] - { insertion_kernel<<>>( - N, - c, - N1, - inputs.data_ptr(), - idx.data_ptr(), - counts.data_ptr(), - out.data_ptr()); - })); - - return out; -} - - -at::Tensor insertion_backward( - const at::Tensor top_grad, - const at::Tensor idx, - const at::Tensor counts, - const int N -) -{ - int c = top_grad.size(1); - int N1 = counts.size(0); - - at::Tensor bottom_grad = torch::zeros({N, c}, at::device(idx.device()).dtype(top_grad.dtype())); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(top_grad.type(), "insertion_backward", ([&] - { insertion_grad_kernel<<>>( - N, - c, - N1, - top_grad.data_ptr(), - idx.data_ptr(), - counts.data_ptr(), - bottom_grad.data_ptr()); - })); - - return bottom_grad; -} diff --git a/torchsparse/src/others/insertion_gpu.h b/torchsparse/src/others/insertion_gpu.h deleted file mode 100644 index 1c5fd02..0000000 --- a/torchsparse/src/others/insertion_gpu.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _SPARSE_INSERT -#define _SPARSE_INSERT -#include -#include - -//make sure indices is int type -//feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n) (preprocessed indices) - -at::Tensor insertion_forward( - const at::Tensor inputs, - const at::Tensor idx, - const at::Tensor counts -); -at::Tensor insertion_backward( - const at::Tensor top_grad, - const at::Tensor idx, - const at::Tensor counts, - const int N -); -#endif \ No newline at end of file diff --git a/torchsparse/src/others/query.cpp b/torchsparse/src/others/query.cpp deleted file mode 100644 index b305977..0000000 --- a/torchsparse/src/others/query.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include "../hashmap/hashmap.cuh" -#include -#include -#include -#include "query_gpu.h" - - -std::vector query_forward( - const at::Tensor hash_query, - const at::Tensor hash_target, - const at::Tensor idx_target -) -{ - //return group_point_forward_gpu(points, indices); - int n = hash_target.size(0); - int n1 = hash_query.size(0); - const int nextPow2 = pow(2,ceil(log2((double)n))); - // When n is large, the hash values tend to be more evenly distrubuted and choosing table_size to be - // 2 * nextPow2 typically suffices. For smaller n, the effect of uneven distribution of hash values is more - // pronounced and hence we choose table_size to be 4 * nextPow2 to reduce the chance of bucket overflow. - int table_size = (n < 2048) ? 4 * nextPow2 : 2 * nextPow2; - if(table_size < 512){ - table_size = 512; - } - int num_funcs = 3; - CuckooHashTableCuda_Multi in_hash_table(table_size, 8 * ceil(log2((double)n)), - num_funcs); - at::Tensor key_buf = torch::zeros({table_size}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - at::Tensor val_buf = torch::zeros({table_size}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - at::Tensor key = torch::zeros({num_funcs*table_size}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - at::Tensor val = torch::zeros({num_funcs*table_size}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - - in_hash_table.insert_vals((unsigned long long int*)(hash_target.data_ptr()), (unsigned long long int*)(idx_target.data_ptr()), (unsigned long long int*)(key_buf.data_ptr()), (unsigned long long int*)(val_buf.data_ptr()), (unsigned long long int*)(key.data_ptr()), (unsigned long long int*)(val.data_ptr()), n); - - at::Tensor out = torch::zeros({n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); - - in_hash_table.lookup_vals((unsigned long long int*)(hash_query.data_ptr()), (unsigned long long int*)(key.data_ptr()), (unsigned long long int*)(val.data_ptr()), (unsigned long long int*)(out.data_ptr()), n1); - return {out, key_buf, val_buf, key}; -} - - -/* -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("query_forward", &query_forward, "hash query forward (CUDA)"); -} -*/ - - diff --git a/torchsparse/src/others/query_cpu_header.h b/torchsparse/src/others/query_cpu_header.h deleted file mode 100644 index 9033f3d..0000000 --- a/torchsparse/src/others/query_cpu_header.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _SPARSE_QUERY_CPU -#define _SPARSE_QUERY_CPU -#include -#include -#include -#include - -at::Tensor cpu_query_forward( - const at::Tensor hash_query, - const at::Tensor hash_target, - const at::Tensor idx_target -); -#endif \ No newline at end of file diff --git a/torchsparse/src/others/query_gpu.h b/torchsparse/src/others/query_gpu.h deleted file mode 100644 index 8e4c884..0000000 --- a/torchsparse/src/others/query_gpu.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _SPARSE_QUERY -#define _SPARSE_QUERY -#include -#include -#include -#include - -std::vector query_forward( - const at::Tensor hash_query, - const at::Tensor hash_target, - const at::Tensor idx_target -); -#endif \ No newline at end of file diff --git a/torchsparse/src/torchsparse_bindings.cpp b/torchsparse/src/torchsparse_bindings.cpp deleted file mode 100644 index d12f5f3..0000000 --- a/torchsparse/src/torchsparse_bindings.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include -#include -#include "convolution/convolution_cpu_header.h" -#include "hash/hash_cpu_header.h" -#include "interpolation/devox_cpu_header.h" -#include "others/insertion_cpu_header.h" -#include "others/query_cpu_header.h" -#include "others/count_cpu_header.h" - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("sparseconv_cpu_forward", &ConvolutionForwardCPU, "point cloud convolution forward (CPU)"); - m.def("sparseconv_cpu_backward", &ConvolutionBackwardCPU, "point cloud convolution backward (CPU)"); - m.def("cpu_hash_forward", &cpu_hash_forward, "Hashing forward (CPU)"); - m.def("cpu_kernel_hash_forward", &cpu_kernel_hash_forward, "Kernel Hashing forward (CPU)"); - m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)"); - m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)"); - m.def("cpu_devoxelize_forward", &cpu_devoxelize_forward, "Devoxelization forward (CPU)"); - m.def("cpu_devoxelize_backward", &cpu_devoxelize_backward, "Devoxelization backward (CPU)"); - m.def("cpu_query_forward", &cpu_query_forward, "hash query forward (CPU)"); - m.def("cpu_count_forward", &cpu_count_forward, "count forward (CPU)"); -} - - diff --git a/torchsparse/src/torchsparse_bindings_gpu.cpp b/torchsparse/src/torchsparse_bindings_gpu.cpp deleted file mode 100644 index e27977f..0000000 --- a/torchsparse/src/torchsparse_bindings_gpu.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include -#include -#include "convolution/convolution_cpu_header.h" -#include "hash/hash_cpu_header.h" -#include "others/insertion_cpu_header.h" -#include "others/query_cpu_header.h" -#include "convolution/convolution_gpu.h" -#include "hash/hash_gpu.h" -#include "interpolation/devox_gpu.h" -#include "interpolation/devox_cpu_header.h" -#include "others/count_gpu.h" -#include "others/insertion_gpu.h" -#include "others/insertion_cpu_header.h" -#include "others/query_gpu.h" -#include "others/count_cpu_header.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("sparseconv_cpu_forward", &ConvolutionForwardCPU, "point cloud convolution forward (CPU)"); - m.def("sparseconv_cpu_backward", &ConvolutionBackwardCPU, "point cloud convolution backward (CPU)"); - m.def("cpu_kernel_hash_forward", &cpu_kernel_hash_forward, "Kernel Hashing forward (CPU)"); - m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)"); - m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)"); - m.def("cpu_query_forward", &cpu_query_forward, "hash query forward (CPU)"); - m.def("sparseconv_forward", &ConvolutionForwardGPU, "point cloud convolution forward (CUDA)"); - m.def("sparseconv_backward", &ConvolutionBackwardGPU, "point cloud convolution backward (CUDA)"); - m.def("hash_forward", &hash_forward, "Hashing forward (CUDA)"); - m.def("kernel_hash_forward", &kernel_hash_forward, "Kernel Hashing forward (CUDA)"); - m.def("cpu_hash_forward", &cpu_hash_forward, "Hashing forward (CPU)"); - m.def("devoxelize_forward", &devoxelize_forward, "Devoxelization forward (CUDA)"); - m.def("devoxelize_backward", &devoxelize_backward, "Devoxelization backward (CUDA)"); - m.def("deterministic_devoxelize_forward", &deterministic_devoxelize_forward, "Devoxelization forward (CUDA)"); - m.def("deterministic_devoxelize_backward", &deterministic_devoxelize_backward, "Devoxelization backward (CUDA)"); - m.def("cpu_devoxelize_forward", &cpu_devoxelize_forward, "Devoxelization forward (CPU)"); - m.def("cpu_devoxelize_backward", &cpu_devoxelize_backward, "Devoxelization backward (CPU)"); - m.def("count_forward", &count_forward, "Counting forward (CUDA)"); - m.def("cpu_count_forward", &cpu_count_forward, "count forward (CPU)"); - m.def("insertion_forward", &insertion_forward, "Insertion forward (CUDA)"); - m.def("insertion_backward", &insertion_backward, "Insertion backward (CUDA)"); - m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)"); - m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)"); - m.def("query_forward", &query_forward, "hash query forward (CUDA)"); -} diff --git a/torchsparse/tensors.py b/torchsparse/tensors.py new file mode 100644 index 0000000..810f020 --- /dev/null +++ b/torchsparse/tensors.py @@ -0,0 +1,99 @@ +from typing import Any, Dict, Tuple, Union + +import torch + +from torchsparse.utils import make_ntuple + +__all__ = ['SparseTensor', 'PointTensor'] + + +class SparseTensor: + + def __init__(self, + feats: torch.Tensor, + coords: torch.Tensor, + stride: Union[int, Tuple[int, ...]] = 1) -> None: + self.feats = feats + self.coords = coords + self.stride = make_ntuple(stride, ndim=3) + self.cmaps: Dict[Tuple[int, ...], torch.Tensor] = {} + self.kmaps: Dict[Tuple[Any, ...], Any] = {} + + @property + def F(self) -> torch.Tensor: + return self.feats + + @F.setter + def F(self, feats: torch.Tensor) -> None: + self.feats = feats + + @property + def C(self) -> torch.Tensor: + return self.coords + + @C.setter + def C(self, coords: torch.Tensor) -> None: + self.coords = coords + + @property + def s(self) -> Tuple[int, ...]: + return self.stride + + @s.setter + def s(self, stride: Union[int, Tuple[int, ...]]) -> None: + self.stride = make_ntuple(stride, ndim=3) + + def cuda(self): + self.feats = self.feats.cuda() + self.coords = self.coords.cuda() + return self + + def detach(self): + self.feats = self.feats.detach() + self.coords = self.coords.detach() + return self + + def to(self, device, non_blocking=True): + self.feats = self.feats.to(device, non_blocking=non_blocking) + self.coords = self.coords.to(device, non_blocking=non_blocking) + return self + + def __add__(self, other): + tensor = SparseTensor(self.feats + other.feats, self.coords, + self.stride) + tensor.cmaps = self.cmaps + tensor.kmaps = self.kmaps + return tensor + + +class PointTensor: + + def __init__(self, feats, coords, idx_query=None, weights=None): + self.F = feats + self.C = coords + self.idx_query = idx_query if idx_query is not None else {} + self.weights = weights if weights is not None else {} + self.additional_features = {} + self.additional_features['idx_query'] = {} + self.additional_features['counts'] = {} + + def cuda(self): + self.F = self.F.cuda() + self.C = self.C.cuda() + return self + + def detach(self): + self.F = self.F.detach() + self.C = self.C.detach() + return self + + def to(self, device, non_blocking=True): + self.F = self.F.to(device, non_blocking=non_blocking) + self.C = self.C.to(device, non_blocking=non_blocking) + return self + + def __add__(self, other): + tensor = PointTensor(self.F + other.F, self.C, self.idx_query, + self.weights) + tensor.additional_features = self.additional_features + return tensor diff --git a/torchsparse/utils/__init__.py b/torchsparse/utils/__init__.py index 4215378..16281fe 100644 --- a/torchsparse/utils/__init__.py +++ b/torchsparse/utils/__init__.py @@ -1,2 +1 @@ -from .helpers import * -from .kernel import * \ No newline at end of file +from .utils import * diff --git a/torchsparse/utils/collate.py b/torchsparse/utils/collate.py new file mode 100644 index 0000000..f9ae5b4 --- /dev/null +++ b/torchsparse/utils/collate.py @@ -0,0 +1,59 @@ +from typing import Any, List + +import numpy as np +import torch + +from torchsparse import SparseTensor + +__all__ = ['sparse_collate', 'sparse_collate_fn'] + + +def sparse_collate(inputs: List[SparseTensor]) -> SparseTensor: + coords, feats = [], [] + stride = inputs[0].stride + + for k, x in enumerate(inputs): + if isinstance(x.coords, np.ndarray): + x.coords = torch.tensor(x.coords) + if isinstance(x.feats, np.ndarray): + x.feats = torch.tensor(x.feats) + + assert isinstance(x.coords, torch.Tensor), type(x.coords) + assert isinstance(x.feats, torch.Tensor), type(x.feats) + assert x.stride == stride, (x.stride, stride) + + input_size = x.coords.shape[0] + batch = torch.full((input_size, 1), + k, + device=x.coords.device, + dtype=torch.int) + + coords.append(torch.cat((x.coords, batch), dim=1)) + feats.append(x.feats) + + coords = torch.cat(coords, dim=0) + feats = torch.cat(feats, dim=0) + output = SparseTensor(coords=coords, feats=feats, stride=stride) + return output + + +def sparse_collate_fn(inputs: List[Any]) -> Any: + if isinstance(inputs[0], dict): + output = {} + for name in inputs[0].keys(): + if isinstance(inputs[0][name], dict): + output[name] = sparse_collate_fn( + [input[name] for input in inputs]) + elif isinstance(inputs[0][name], np.ndarray): + output[name] = torch.stack( + [torch.tensor(input[name]) for input in inputs], dim=0) + elif isinstance(inputs[0][name], torch.Tensor): + output[name] = torch.stack([input[name] for input in inputs], + dim=0) + elif isinstance(inputs[0][name], SparseTensor): + output[name] = sparse_collate([input[name] for input in inputs]) + else: + output[name] = [input[name] for input in inputs] + return output + else: + return inputs diff --git a/torchsparse/utils/helpers.py b/torchsparse/utils/helpers.py deleted file mode 100644 index 76b6552..0000000 --- a/torchsparse/utils/helpers.py +++ /dev/null @@ -1,265 +0,0 @@ -from collections import Sequence - -import numpy as np -import torch -from torchsparse import SparseTensor - -__all__ = [ - 'ravel_hash_vec', 'sparse_quantize', 'sparse_collate', 'sparse_collate_fn', - 'sparse_collate_tensors', 'make_tuple' -] - - -def ravel_hash_vec(arr): - assert arr.ndim == 2 - arr -= arr.min(0) - arr = arr.astype(np.uint64, copy=False) - arr_max = arr.max(0).astype(np.uint64) + 1 - - keys = np.zeros(arr.shape[0], dtype=np.uint64) - # Fortran style indexing - for j in range(arr.shape[1] - 1): - keys += arr[:, j] - keys *= arr_max[j + 1] - keys += arr[:, -1] - return keys - - -def sparse_quantize(coords, - feats=None, - labels=None, - ignore_label=255, - return_index=False, - return_invs=False, - hash_type='ravel', - quantization_size=1): - - use_label = labels is not None - use_feat = feats is not None - if not use_label and not use_feat: - return_index = True - - assert hash_type in [ - 'ravel' - ], "Invalid hash_type. Either ravel, or fnv allowed. You put hash_type=" + hash_type - assert coords.ndim == 2 - if use_feat: - assert feats.ndim == 2 - assert coords.shape[0] == feats.shape[0] - if use_label: - assert coords.shape[0] == len(labels) - - # Quantize the coordinates - dimension = coords.shape[1] - if isinstance(quantization_size, (Sequence, np.ndarray, torch.Tensor)): - assert len( - quantization_size - ) == dimension, "Quantization size and coordinates size mismatch." - quantization_size = [i for i in quantization_size] - elif np.isscalar(quantization_size): # Assume that it is a scalar - quantization_size = [int(quantization_size) for i in range(dimension)] - else: - raise ValueError('Not supported type for quantization_size.') - discrete_coords = np.floor(coords / np.array(quantization_size)) - - # Hash function type - key = ravel_hash_vec(discrete_coords) - if use_label: - _, inds, invs, counts = np.unique(key, - return_index=True, - return_inverse=True, - return_counts=True) - filtered_labels = labels[inds] - filtered_labels[counts > 1] = ignore_label - if return_invs: - if return_index: - return inds, filtered_labels, invs - else: - return discrete_coords[inds], feats[ - inds], filtered_labels, invs - else: - if return_index: - return inds, filtered_labels - else: - return discrete_coords[inds], feats[inds], filtered_labels - - else: - _, inds, invs = np.unique(key, return_index=True, return_inverse=True) - if return_invs: - if return_index: - return inds, invs - else: - if use_feat: - return discrete_coords[inds], feats[inds], invs - else: - return discrete_coords[inds], invs - else: - if return_index: - return inds - else: - if use_feat: - return discrete_coords[inds], feats[inds] - else: - return discrete_coords[inds] - - -def sparse_collate(coords, - feats, - labels=None, - is_double=False, - coord_float=False): - r"""Create a sparse tensor with batch indices C in `the documentation - `_. - - Convert a set of coordinates and features into the batch coordinates and - batch features. - - Args: - coords (set of `torch.Tensor` or `numpy.ndarray`): a set of coordinates. - - feats (set of `torch.Tensor` or `numpy.ndarray`): a set of features. - - labels (set of `torch.Tensor` or `numpy.ndarray`): a set of labels - associated to the inputs. - - is_double (`bool`): return double precision features if True. False by - default. - - """ - use_label = False if labels is None else True - coords_batch, feats_batch, labels_batch = [], [], [] - - batch_id = 0 - for coord, feat in zip(coords, feats): - if isinstance(coord, np.ndarray): - coord = torch.from_numpy(coord) - else: - assert isinstance( - coord, torch.Tensor - ), "Coords must be of type numpy.ndarray or torch.Tensor" - - if not coord_float: - coord = coord.int() - else: - coord = coord.float() - - if isinstance(feat, np.ndarray): - feat = torch.from_numpy(feat) - else: - assert isinstance( - feat, torch.Tensor - ), "Features must be of type numpy.ndarray or torch.Tensor" - feat = feat.double() if is_double else feat.float() - - # Batched coords - num_points = coord.shape[0] - - if not coord_float: - coords_batch.append( - torch.cat( - (coord, torch.ones( - (num_points, 1), device=coord.device).int() * - batch_id), 1)) - else: - coords_batch.append( - torch.cat( - (coord, torch.ones( - (num_points, 1), device=coord.device).float() * - batch_id), 1)) - - # Features - feats_batch.append(feat) - - # Labels - if use_label: - label = labels[batch_id] - if isinstance(label, np.ndarray): - label = torch.from_numpy(label) - else: - assert isinstance( - label, torch.Tensor - ), "labels must be of type numpy.ndarray or torch.Tensor" - labels_batch.append(label) - - batch_id += 1 - - # Concatenate all lists - if not coord_float: - coords_batch = torch.cat(coords_batch, 0).int() - else: - coords_batch = torch.cat(coords_batch, 0).float() - feats_batch = torch.cat(feats_batch, 0) - if use_label: - labels_batch = torch.cat(labels_batch, 0) - return coords_batch, feats_batch, labels_batch - else: - return coords_batch, feats_batch - - -def sparse_collate_tensors(sparse_tensors): - coords, feats = sparse_collate([x.C for x in sparse_tensors], - [x.F for x in sparse_tensors]) - return SparseTensor(feats, coords, sparse_tensors[0].s) - - -def sparse_collate_fn(batch): - if isinstance(batch[0], dict): - batch_size = batch.__len__() - ans_dict = {} - for key in batch[0].keys(): - if isinstance(batch[0][key], SparseTensor): - ans_dict[key] = sparse_collate_tensors( - [sample[key] for sample in batch]) - elif isinstance(batch[0][key], np.ndarray): - ans_dict[key] = torch.stack( - [torch.from_numpy(sample[key]) for sample in batch], - axis=0) - elif isinstance(batch[0][key], torch.Tensor): - ans_dict[key] = torch.stack([sample[key] for sample in batch], - axis=0) - elif isinstance(batch[0][key], dict): - ans_dict[key] = sparse_collate_fn( - [sample[key] for sample in batch]) - else: - ans_dict[key] = [sample[key] for sample in batch] - return ans_dict - else: - batch_size = batch.__len__() - ans_dict = tuple() - for i in range(len(batch[0])): - key = batch[0][i] - if isinstance(key, SparseTensor): - ans_dict += sparse_collate_tensors( - [sample[i] for sample in batch]), - elif isinstance(key, np.ndarray): - ans_dict += torch.stack( - [torch.from_numpy(sample[i]) for sample in batch], axis=0), - elif isinstance(key, torch.Tensor): - ans_dict += torch.stack([sample[i] for sample in batch], - axis=0), - elif isinstance(key, dict): - ans_dict += sparse_collate_fn([sample[i] for sample in batch]), - else: - ans_dict += [sample[i] for sample in batch], - return ans_dict - - -def make_tuple(inputs, dimension=3): - if isinstance(inputs, int): - outputs = tuple() - for d in range(dimension): - outputs += inputs, - return outputs - elif isinstance(inputs, list): - assert len(inputs) == dimension, 'Input length and dimension mismatch' - return tuple(inputs) - elif isinstance(inputs, tuple): - assert len(inputs) == dimension, 'Input length and dimension mismatch' - return inputs - elif isinstance(inputs, torch.Tensor): - inputs = inputs.squeeze() - shape = inputs.shape - assert len(shape) == 1 and shape[0] == dimension, 'Input length and dimension mismatch' - if inputs.is_cuda: - inputs = inputs.cpu() - return tuple((t.item() for t in inputs)) diff --git a/torchsparse/utils/kernel.py b/torchsparse/utils/kernel.py deleted file mode 100644 index 29c18d1..0000000 --- a/torchsparse/utils/kernel.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections import namedtuple -import numpy as np -import torch -from typing import Union, List, Tuple -from torchsparse.utils import make_tuple - -__all__ = ['KernelRegion', 'KernelMapKey'] - -KernelMapKey = namedtuple('KernelMapKey', - ['kernel_size', 'cur_stride', 'stride', 'dilation']) - - -class KernelRegion: - def __init__(self, - kernel_size: Union[int, List[int], Tuple[int, int, int]] = 3, - tensor_stride: Union[int, List[int], Tuple[int, int, int], - torch.Tensor] = 1, - dilation: Union[int, List[int], Tuple[int, int, int]] = 1, - dim: List[int] = [0, 1, 2]) -> None: - self.kernel_size = kernel_size - self.tensor_stride = make_tuple(tensor_stride) - self.dilation = make_tuple(dilation) - assert len(self.tensor_stride) == 3, 'Wrong tensor_stride' - assert len(self.dilation) == 3, 'Wrong dilation' - - ts = self.tensor_stride - d = self.dilation - - if not isinstance(kernel_size, (list, tuple)): - if kernel_size % 2 == 0: - # even - region_type = 0 - else: - # odd - region_type = 1 - - self.region_type = region_type - - x_offset = ( - np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) * - ts[0] * d[0]).tolist() if 0 in dim else [0] - y_offset = ( - np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) * - ts[1] * d[1]).tolist() if 1 in dim else [0] - z_offset = ( - np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) * - ts[2] * d[2]).tolist() if 2 in dim else [0] - - if self.region_type == 1: - kernel_offset = [[x, y, z] for z in z_offset for y in y_offset - for x in x_offset] - else: - kernel_offset = [[x, y, z] for x in x_offset for y in y_offset - for z in z_offset] - kernel_offset = np.array(kernel_offset) - self.kernel_offset = torch.from_numpy(kernel_offset).int() - else: - if dim == [0, 1, 2] and len(kernel_size) == 3: - kernel_x_size = kernel_size[0] - kernel_y_size = kernel_size[1] - kernel_z_size = kernel_size[2] - - x_offset = (np.arange(-kernel_x_size // 2 + 1, - kernel_x_size // 2 + 1) * ts[0] * - d[0]).tolist() - y_offset = (np.arange(-kernel_y_size // 2 + 1, - kernel_y_size // 2 + 1) * ts[1] * - d[1]).tolist() - z_offset = (np.arange(-kernel_z_size // 2 + 1, - kernel_z_size // 2 + 1) * ts[2] * - d[2]).tolist() - - kernel_offset = [[x, y, z] for x in x_offset for y in y_offset - for z in z_offset] - - kernel_offset = np.array(kernel_offset) - self.kernel_offset = torch.from_numpy(kernel_offset).int() - else: - raise NotImplementedError - - def get_kernel_offset(self): - return self.kernel_offset diff --git a/torchsparse/utils/quantize.py b/torchsparse/utils/quantize.py new file mode 100644 index 0000000..eaea1b5 --- /dev/null +++ b/torchsparse/utils/quantize.py @@ -0,0 +1,48 @@ +from itertools import repeat +from typing import Tuple, Union + +import numpy as np + +__all__ = ['sparse_quantize'] + + +def ravel_hash(x: np.ndarray) -> np.ndarray: + assert x.ndim == 2, x.shape + + x -= x.min(axis=0) + x = x.astype(np.uint64, copy=False) + xmax = x.max(axis=0).astype(np.uint64) + 1 + + h = np.zeros(x.shape[0], dtype=np.uint64) + for j in range(x.shape[1] - 1): + h += x[:, j] + h *= xmax[j + 1] + h += x[:, -1] + return h + + +def sparse_quantize( + coords, + voxel_size: Union[float, Tuple[float, ...]] = 1, + *, + return_index: bool = False, + return_inverse: bool = False +) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: + if isinstance(voxel_size, (float, int)): + voxel_size = tuple(repeat(voxel_size, 3)) + assert isinstance(voxel_size, tuple) and len(voxel_size) == 3 + + voxel_size = np.array(voxel_size) + coords = np.floor(coords / voxel_size).astype(np.int32) + + _, indices, inverse_indices = np.unique(ravel_hash(coords), + return_index=True, + return_inverse=True) + coords = coords[indices] + + output = [coords] + if return_index: + output += [indices] + if return_inverse: + output += [inverse_indices] + return output[0] if len(output) == 1 else tuple(output) diff --git a/torchsparse/utils/utils.py b/torchsparse/utils/utils.py new file mode 100644 index 0000000..2b9ca65 --- /dev/null +++ b/torchsparse/utils/utils.py @@ -0,0 +1,19 @@ +from itertools import repeat +from typing import List, Tuple, Union + +import torch + +__all__ = ['make_ntuple'] + + +def make_ntuple(x: Union[int, List[int], Tuple[int, ...], torch.Tensor], + ndim: int) -> Tuple[int, ...]: + if isinstance(x, int): + x = tuple(repeat(x, ndim)) + elif isinstance(x, list): + x = tuple(x) + elif isinstance(x, torch.Tensor): + x = tuple(x.view(-1).cpu().numpy().tolist()) + + assert isinstance(x, tuple) and len(x) == ndim, x + return x diff --git a/torchsparse/version.py b/torchsparse/version.py new file mode 100644 index 0000000..19b4f1d --- /dev/null +++ b/torchsparse/version.py @@ -0,0 +1 @@ +__version__ = '1.3.0' From 01a84bd50a63181a45136d4ea2d1dee3ef33b321 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 17:10:58 -0400 Subject: [PATCH 02/12] Reformat `setup.py` --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index aff6ee4..299f005 100644 --- a/setup.py +++ b/setup.py @@ -17,9 +17,8 @@ sources = [os.path.join('torchsparse', 'backend', f'pybind_{device}.cpp')] for fpath in glob.glob(os.path.join('torchsparse', 'backend', '**', '*')): - if fpath.endswith('_cpu.cpp') and device in ['cpu', 'cuda']: - sources.append(fpath) - elif fpath.endswith('_cuda.cu') and device == 'cuda': + if ((fpath.endswith('_cpu.cpp') and device in ['cpu', 'cuda']) + or (fpath.endswith('_cuda.cu') and device == 'cuda')): sources.append(fpath) extension_type = CUDAExtension if device == 'cuda' else CppExtension From 0fb6314542c2b0017b18818c15c42707f6837fe2 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 17:15:54 -0400 Subject: [PATCH 03/12] Rename variable --- examples/example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/example.py b/examples/example.py index 6a4e475..be177bd 100644 --- a/examples/example.py +++ b/examples/example.py @@ -80,13 +80,13 @@ def __len__(self): for k, feed_dict in enumerate(dataflow): inputs = feed_dict['input'].cuda() - targets = feed_dict['label'].cuda() + labels = feed_dict['label'].cuda() with amp.autocast(enabled=args.amp_enabled): outputs = model(inputs) - loss = criterion(outputs.F, targets.F) + loss = criterion(outputs.feats, labels.feats) - print(f'[step {k + 1}] loss = {loss.item()}.') + print(f'[step {k + 1}] loss = {loss.item()}') optimizer.zero_grad() scaler.scale(loss).backward() From 4fb0c3e0cd636f4418abc2b01701ca873dadb557 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 17:18:13 -0400 Subject: [PATCH 04/12] Rename variables --- examples/example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/example.py b/examples/example.py index be177bd..6d19ebd 100644 --- a/examples/example.py +++ b/examples/example.py @@ -21,10 +21,10 @@ def __init__(self, input_size: int, voxel_size: float) -> None: self.voxel_size = voxel_size def __getitem__(self, _: int) -> Dict[str, Any]: - lidar = np.random.uniform(-100, 100, size=(self.input_size, 4)) + inputs = np.random.uniform(-100, 100, size=(self.input_size, 4)) labels = np.random.choice(10, size=self.input_size) - coords, feats = lidar[:, :3], lidar + coords, feats = inputs[:, :3], inputs coords -= np.min(coords, axis=0, keepdims=True) coords, indices = sparse_quantize(coords, self.voxel_size, From 2953ab9a0f0e55c505477a181062e6672d86679e Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 18:00:40 -0400 Subject: [PATCH 05/12] Add type annotations for `SparseCrop` --- torchsparse/nn/functional/crop.py | 26 ++++++++++++++++++++++---- torchsparse/nn/modules/crop.py | 15 +++++++++------ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/torchsparse/nn/functional/crop.py b/torchsparse/nn/functional/crop.py index 819e88f..d3c3c1d 100644 --- a/torchsparse/nn/functional/crop.py +++ b/torchsparse/nn/functional/crop.py @@ -1,10 +1,28 @@ +from typing import Optional, Tuple + +import torch + from torchsparse import SparseTensor __all__ = ['spcrop'] -def spcrop(inputs: SparseTensor, loc_min, loc_max) -> SparseTensor: - coords, feats, stride = inputs.C, inputs.F, inputs.s - mask = ((coords[:, :3] >= loc_min) & (coords[:, :3] < loc_max)).all(-1) +def spcrop(input: SparseTensor, + coords_min: Optional[Tuple[int, ...]] = None, + coords_max: Optional[Tuple[int, ...]] = None) -> SparseTensor: + coords, feats, stride = input.coords, input.feats, input.stride + + mask = torch.ones_like(coords) + if coords_min is not None: + coords_min = torch.tensor(coords_min, dtype=int, + device=coords.device).unsqueeze(dim=0) + mask &= (coords[:, :3] >= coords_min) + if coords_max is not None: + coords_max = torch.tensor(coords_max, dtype=int, + device=coords.device).unsqueeze(dim=0) + mask &= (coords[:, :3] <= coords_max) + + mask = torch.all(mask, dim=1) coords, feats = coords[mask], feats[mask] - return SparseTensor(coords=coords, feats=feats, stride=stride) + output = SparseTensor(coords=coords, feats=feats, stride=stride) + return output diff --git a/torchsparse/nn/modules/crop.py b/torchsparse/nn/modules/crop.py index 9b58421..a672f35 100644 --- a/torchsparse/nn/modules/crop.py +++ b/torchsparse/nn/modules/crop.py @@ -1,18 +1,21 @@ -import torch +from typing import Optional, Tuple + from torch import nn from torchsparse import SparseTensor -from torchsparse.nn.functional import spcrop +from torchsparse.nn import functional as F __all__ = ['SparseCrop'] class SparseCrop(nn.Module): - def __init__(self, loc_min, loc_max): + def __init__(self, + coords_min: Optional[Tuple[int, ...]] = None, + coords_max: Optional[Tuple[int, ...]] = None) -> None: super().__init__() - self.loc_min = torch.cuda.IntTensor([list(loc_min)]) - self.loc_max = torch.cuda.IntTensor([list(loc_max)]) + self.coords_min = coords_min + self.coords_max = coords_max def forward(self, input: SparseTensor) -> SparseTensor: - return spcrop(input, self.loc_min, self.loc_max) + return F.spcrop(input, self.coords_min, self.coords_max) From ce69a41f008d12210549cdf4ebc3a4df6499f1ae Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 18:04:20 -0400 Subject: [PATCH 06/12] Switch to `torch.sparse_coo_tensor` to support different devices --- torchsparse/nn/modules/bev.py | 39 ++++++++++++----------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py index 6e87aa7..c07b010 100644 --- a/torchsparse/nn/modules/bev.py +++ b/torchsparse/nn/modules/bev.py @@ -22,22 +22,19 @@ def extra_repr(self): return f'dim = {self.dim}' def forward(self, input: SparseTensor) -> SparseTensor: - coords, feats, stride = input.C, input.F, input.s + coords, feats, stride = input.coords, input.feats, input.stride coords = coords.clone() coords[:, self.dim] = 0 feats = torch.cat([torch.ones_like(feats[:, :1]), feats], axis=1) - tensor = torch.cuda.sparse.FloatTensor(coords.t().long(), - feats).coalesce() + tensor = torch.sparse_coo_tensor(coords.t().long(), feats).to_dense() coords = tensor.indices().t().int() feats = tensor.values()[:, 1:] / tensor.values()[:, :1] return SparseTensor(coords=coords, feats=feats, stride=stride) class ToDenseBEVConvolution(nn.Module): - """ - - Converts a torchsparse.SparseTensor to a BEV feature map. + """ Converts a torchsparse.SparseTensor to a BEV feature map. Group points with the same z value together and apply the same FC kernel. Aggregate the results by summing up all features within one BEV grid. @@ -48,8 +45,6 @@ class ToDenseBEVConvolution(nn.Module): bias: whether to use bias. Warning: usually larger memory consumption than ToBEVHeightCompression. - - """ def __init__(self, @@ -85,9 +80,8 @@ def reset_parameters(self): self.kernel.data.uniform_(-std, std) def forward(self, input: SparseTensor) -> torch.Tensor: - coords, feats, stride = input.C, input.F, input.s - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + coords, feats, stride = input.coords, input.feats, input.stride + stride = torch.tensor(stride).unsqueeze(0).to(feats)[:, self.dim] kernel = torch.index_select(self.kernel, 0, (coords[:, self.dim] // stride).long()) @@ -140,10 +134,9 @@ def extra_repr(self): self.in_channels, self.out_channels, self.n_kernels, self.stride) def forward(self, input: SparseTensor) -> torch.Tensor: - coords, feats, stride = input.C, input.F, input.s + coords, feats, stride = input.coords, input.feats, input.stride ratio = stride * self.stride - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + stride = torch.tensor(stride).unsqueeze(0).to(feats)[:, self.dim] kernels = torch.index_select(self.kernel, 0, coords[:, self.dim].long() // stride) @@ -153,15 +146,13 @@ def forward(self, input: SparseTensor) -> torch.Tensor: if self.stride > 1: coords[:3] /= ratio coords[:3] *= ratio - flatten = torch.cuda.sparse.FloatTensor(coords, feats).coalesce() + flatten = torch.sparse_coo_tensor(coords, feats).to_dense() return SparseTensor(flatten.values(), flatten.indices().t().int(), ratio) class ToBEVHeightCompression(nn.Module): - """ - - Converts a torchsparse.SparseTensor to a dense volumetric tensor, + """ Converts a torchsparse.SparseTensor to a dense volumetric tensor, then flatten the z dimension. E.g. input [N, C] (assume batch_size=1), spatial size [128,2,128] then output will be [1, 2C, 128, 128] @@ -171,16 +162,13 @@ class ToBEVHeightCompression(nn.Module): shape: shape of BEV map. dim: dimension index for z. (default: 1 for KITTI coords) bias: whether to use bias. - - """ def __init__(self, channels: int, shape: Union[List[int], Tuple[int, int, int], torch.Tensor], offset: Tuple[int, int, int] = (0, 0, 0), - dim: int = 1, - bias: bool = False) -> None: + dim: int = 1) -> None: super().__init__() self.channels = channels self.register_buffer('offset', torch.IntTensor([list(offset) + [0]])) @@ -196,10 +184,9 @@ def extra_repr(self) -> str: return f'channels={self.channels}' def forward(self, input: SparseTensor) -> torch.Tensor: - coords, feats, stride = input.C, input.F, input.s - if isinstance(stride, tuple): - stride = torch.Tensor(stride).unsqueeze(0).to(feats) - assert isinstance(stride, torch.Tensor) + coords, feats, stride = input.coords, input.feats, input.stride + stride = torch.tensor(stride).unsqueeze(0).to(coords.device) + assert isinstance(stride, torch.Tensor), type(stride) # [b, x, y, z] coords = (coords - self.offset).t()[[3] + self.bev_dims From 473594ecaecd8afeedc0fcbd6b9f128d600a11d5 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 18:30:55 -0400 Subject: [PATCH 07/12] Reformat code --- torchsparse/__init__.py | 2 +- torchsparse/backend/common/gpu.cuh | 119 ------------------------ torchsparse/backend/others/query_cpu.h | 1 + torchsparse/backend/others/query_cuda.h | 1 + torchsparse/nn/functional/conv.py | 2 +- torchsparse/nn/functional/crop.py | 6 +- torchsparse/nn/functional/devoxelize.py | 9 +- torchsparse/nn/functional/downsample.py | 12 ++- torchsparse/nn/functional/pooling.py | 28 +++--- torchsparse/nn/modules/bev.py | 14 +-- torchsparse/nn/modules/norm.py | 31 +++--- torchsparse/nn/modules/pooling.py | 6 +- torchsparse/operators.py | 15 +-- torchsparse/{tensors.py => tensor.py} | 13 +-- torchsparse/utils/quantize.py | 32 +++---- 15 files changed, 86 insertions(+), 205 deletions(-) delete mode 100644 torchsparse/backend/common/gpu.cuh rename torchsparse/{tensors.py => tensor.py} (89%) diff --git a/torchsparse/__init__.py b/torchsparse/__init__.py index e7ab36a..269c7ea 100644 --- a/torchsparse/__init__.py +++ b/torchsparse/__init__.py @@ -1,3 +1,3 @@ from .operators import * -from .tensors import * +from .tensor import * from .version import __version__ diff --git a/torchsparse/backend/common/gpu.cuh b/torchsparse/backend/common/gpu.cuh deleted file mode 100644 index e36e753..0000000 --- a/torchsparse/backend/common/gpu.cuh +++ /dev/null @@ -1,119 +0,0 @@ -#ifndef GPU_H_ -#define GPU_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -// -// CUDA macros -// - -// CUDA: various checks for different function calls. -#define CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - { \ - cudaError_t error = condition; \ - if (error != cudaSuccess) { \ - throw std::runtime_error(cudaGetErrorString(error) \ - << " at " << __FILE__ << ":" << __LINE__); \ - } \ - } - -#define CUBLAS_CHECK(condition) \ - { \ - cublasStatus_t status = condition; \ - if (status != CUBLAS_STATUS_SUCCESS) { \ - throw std::runtime_error(cublasGetErrorString(status) \ - << " at " << __FILE__ << ":" << __LINE__); \ - } \ - } - -#define CUSPARSE_CHECK(call) \ - { \ - cusparseStatus_t err; \ - if ((err = (call)) != CUSPARSE_STATUS_SUCCESS) { \ - throw std::runtime_error(cusparseGetErrorString(err) \ - << " at " << __FILE__ << ":" << __LINE__); \ - } \ - } - -#define CURAND_CHECK(condition) \ - { \ - curandStatus_t status = condition; \ - if (status != CURAND_STATUS_SUCCESS) { \ - throw std::runtime_error(curandGetErrorString(status) \ - << " at " << __FILE__ << ":" << __LINE__); \ - } \ - } - -// CUDA: grid stride looping -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - -// CUDA: check for error after kernel execution and exit loudly if there is one. -#define CUDA_POST_KERNEL_CHECK \ - { \ - cudaError_t status = cudaPeekAtLastError(); \ - if (status != cudaSuccess) { \ - throw std::runtime_error(cudaGetErrorString(status) \ - << " at " << __FILE__ << ":" << __LINE__); \ - } \ - } - -#define THRUST_CHECK(condition) \ - try { \ - condition; \ - } catch (thrust::system_error e) { \ - throw std::runtime_error("Thrust error: " << e.what() << " at " \ - << __FILE__ << ":" << __LINE__); \ - } - -// CUDA: library error reporting. -const char *cublasGetErrorString(cublasStatus_t error); - -// CUSparse error reporting. -const char *cusparseGetErrorString(cusparseStatus_t error); - -constexpr int CUDA_NUM_THREADS = 256; - -constexpr int SHARED_BLOCK_SIZE = 32; - -constexpr int BLOCK_SIZE = 32; - -inline int GET_BLOCKS(const int N) { - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -template -void print(const thrust::device_vector &v); -template -void print(const thrust::device_vector &v1, - const thrust::device_vector &v2); - -// AtomicAddition for double with cuda arch <= 600 -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -#else -__device__ double atomicAdd(double *address, double val) { - unsigned long long int *address_as_ull = (unsigned long long int *)address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} -#endif - -#endif // GPU_H_ diff --git a/torchsparse/backend/others/query_cpu.h b/torchsparse/backend/others/query_cpu.h index a63f219..b3c6970 100644 --- a/torchsparse/backend/others/query_cpu.h +++ b/torchsparse/backend/others/query_cpu.h @@ -6,4 +6,5 @@ at::Tensor hash_query_cpu(const at::Tensor hash_query, const at::Tensor hash_target, const at::Tensor idx_target); + #endif diff --git a/torchsparse/backend/others/query_cuda.h b/torchsparse/backend/others/query_cuda.h index 175c527..a46aedf 100644 --- a/torchsparse/backend/others/query_cuda.h +++ b/torchsparse/backend/others/query_cuda.h @@ -6,4 +6,5 @@ at::Tensor hash_query_cuda(const at::Tensor hash_query, const at::Tensor hash_target, const at::Tensor idx_target); + #endif diff --git a/torchsparse/nn/functional/conv.py b/torchsparse/nn/functional/conv.py index 2e9b73d..ad0a448 100644 --- a/torchsparse/nn/functional/conv.py +++ b/torchsparse/nn/functional/conv.py @@ -71,7 +71,7 @@ def backward(ctx, grad_output: torch.Tensor): grad_input = torch.zeros_like(input) grad_weight = torch.zeros_like(weight) - if input.device.type == 'cuda': + if grad_output.device.type == 'cuda': torchsparse.backend.convolution_backward_cuda( input, grad_input, grad_output.contiguous(), weight, grad_weight, nbmaps, nbsizes.cpu(), transposed) diff --git a/torchsparse/nn/functional/crop.py b/torchsparse/nn/functional/crop.py index d3c3c1d..0f0f076 100644 --- a/torchsparse/nn/functional/crop.py +++ b/torchsparse/nn/functional/crop.py @@ -14,11 +14,13 @@ def spcrop(input: SparseTensor, mask = torch.ones_like(coords) if coords_min is not None: - coords_min = torch.tensor(coords_min, dtype=int, + coords_min = torch.tensor(coords_min, + dtype=torch.int, device=coords.device).unsqueeze(dim=0) mask &= (coords[:, :3] >= coords_min) if coords_max is not None: - coords_max = torch.tensor(coords_max, dtype=int, + coords_max = torch.tensor(coords_max, + dtype=torch.int, device=coords.device).unsqueeze(dim=0) mask &= (coords[:, :3] <= coords_max) diff --git a/torchsparse/nn/functional/devoxelize.py b/torchsparse/nn/functional/devoxelize.py index 4f17d2c..73b3287 100644 --- a/torchsparse/nn/functional/devoxelize.py +++ b/torchsparse/nn/functional/devoxelize.py @@ -7,12 +7,11 @@ __all__ = ['spdevoxelize', 'calc_ti_weights'] -def calc_ti_weights(coords, idx_query, scale: float = 1): - # TODO(Haotian): normalize the weights to a probability distribution. +def calc_ti_weights(coords: torch.Tensor, + idx_query: torch.Tensor, + scale: float = 1) -> torch.Tensor: with torch.no_grad(): - # don't want points to lie exactly on grid p = coords - # don't use np.floor then convert to torch. numerical errors. if scale != 1: pf = torch.floor(coords / scale) * scale else: @@ -45,7 +44,7 @@ def calc_ti_weights(coords, idx_query, scale: float = 1): if scale != 1: w /= scale ** 3 w[idx_query == -1] = 0 - w /= w.sum(0) + 1e-8 + w /= torch.sum(w, dim=0) + 1e-8 return w diff --git a/torchsparse/nn/functional/downsample.py b/torchsparse/nn/functional/downsample.py index e1893ed..0fe699b 100644 --- a/torchsparse/nn/functional/downsample.py +++ b/torchsparse/nn/functional/downsample.py @@ -20,7 +20,7 @@ def spdownsample( sample_stride = [stride[k] * tensor_stride[k] for k in range(3)] sample_stride = torch.tensor(sample_stride, dtype=torch.int, - device=coords.device).unsqueeze(0) + device=coords.device).unsqueeze(dim=0) if all(stride[k] in [1, kernel_size[k]] for k in range(3)): coords = coords.clone() @@ -29,17 +29,19 @@ def spdownsample( offsets = get_kernel_offsets(kernel_size, tensor_stride, device=coords.device) + kernel_volume = offsets.size(0) coords_min = torch.min(coords[:, :3], dim=0, keepdim=True).values - xyz = coords[:, :3].unsqueeze(1).repeat(1, offsets.size(0), 1) + offsets - b = coords[:, 3:].repeat(1, offsets.size(0)) - coords = torch.cat([xyz.view(-1, 3), b.view(-1, 1)], dim=1) + x = coords[:, :3].unsqueeze(dim=1).repeat(1, kernel_volume, 1) + offsets + b = coords[:, 3:].repeat(1, kernel_volume) + coords = torch.cat([x.view(-1, 3), b.view(-1, 1)], dim=1) # TODO(Zhijian): We need to also filter `coords` based on `coords_max`. mask = (coords[:, :3] % sample_stride == 0) mask &= (coords[:, :3] >= coords_min) - coords = coords[torch.sum(mask, dim=1) == 3, :] + mask = torch.all(mask, dim=1) + coords = coords[mask] # This makes sure that the points will be ordered with respect to the batch # index, but this will not affect the correctness of the result. diff --git a/torchsparse/nn/functional/pooling.py b/torchsparse/nn/functional/pooling.py index 0d70d8b..a20ecc6 100644 --- a/torchsparse/nn/functional/pooling.py +++ b/torchsparse/nn/functional/pooling.py @@ -6,26 +6,22 @@ def global_avg_pool(inputs: SparseTensor) -> torch.Tensor: - batch_index = inputs.C[:, -1] - max_index = torch.max(batch_index).item() + batch_size = torch.max(inputs.coords[:, -1]).item() + 1 outputs = [] - for i in range(max_index + 1): - cur_inputs = torch.index_select(inputs.F, 0, - torch.where(batch_index == i)[0]) - cur_outputs = cur_inputs.mean(0).unsqueeze(0) - outputs.append(cur_outputs) - outputs = torch.cat(outputs, 0) + for k in range(batch_size): + input = inputs.feats[inputs.coords[:, -1] == k] + output = torch.mean(input, dim=0) + outputs.append(output) + outputs = torch.stack(outputs, dim=0) return outputs def global_max_pool(inputs: SparseTensor) -> torch.Tensor: - batch_index = inputs.C[:, -1] - max_index = torch.max(batch_index).item() + batch_size = torch.max(inputs.coords[:, -1]).item() + 1 outputs = [] - for i in range(max_index + 1): - cur_inputs = torch.index_select(inputs.F, 0, - torch.where(batch_index == i)[0]) - cur_outputs = cur_inputs.max(0)[0].unsqueeze(0) - outputs.append(cur_outputs) - outputs = torch.cat(outputs, 0) + for k in range(batch_size): + input = inputs.feats[inputs.coords[:, -1] == k] + output = torch.max(input, dim=0)[0] + outputs.append(output) + outputs = torch.stack(outputs, dim=0) return outputs diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py index c07b010..f724a9f 100644 --- a/torchsparse/nn/modules/bev.py +++ b/torchsparse/nn/modules/bev.py @@ -81,18 +81,18 @@ def reset_parameters(self): def forward(self, input: SparseTensor) -> torch.Tensor: coords, feats, stride = input.coords, input.feats, input.stride - stride = torch.tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + stride = torch.tensor(stride).unsqueeze(dim=0).to(feats)[:, self.dim] kernel = torch.index_select(self.kernel, 0, (coords[:, self.dim] // stride).long()) - feats = (feats.unsqueeze(-1) * kernel).sum(1) + self.bias + feats = (feats.unsqueeze(dim=-1) * kernel).sum(1) + self.bias coords = (coords - self.offset).t()[[3] + self.bev_dims].long() coords[1:] = (coords[1:] // stride).long() indices = coords[0] * int(self.bev_shape.prod()) + coords[1] * int( self.bev_shape[1]) + coords[2] batch_size = coords[0].max().item() + 1 output = torch.sparse_coo_tensor( - indices.unsqueeze(0), + indices.unsqueeze(dim=0), feats, torch.Size( [batch_size * int(self.bev_shape.prod()), @@ -136,11 +136,11 @@ def extra_repr(self): def forward(self, input: SparseTensor) -> torch.Tensor: coords, feats, stride = input.coords, input.feats, input.stride ratio = stride * self.stride - stride = torch.tensor(stride).unsqueeze(0).to(feats)[:, self.dim] + stride = torch.tensor(stride).unsqueeze(dim=0).to(feats)[:, self.dim] kernels = torch.index_select(self.kernel, 0, coords[:, self.dim].long() // stride) - feats = (feats.unsqueeze(-1) * kernels).sum(1) + self.bias + feats = (feats.unsqueeze(dim=-1) * kernels).sum(1) + self.bias coords = coords.t().long() coords[self.dim, :] = 0 if self.stride > 1: @@ -185,7 +185,7 @@ def extra_repr(self) -> str: def forward(self, input: SparseTensor) -> torch.Tensor: coords, feats, stride = input.coords, input.feats, input.stride - stride = torch.tensor(stride).unsqueeze(0).to(coords.device) + stride = torch.tensor(stride).unsqueeze(dim=0).to(coords.device) assert isinstance(stride, torch.Tensor), type(stride) # [b, x, y, z] @@ -202,7 +202,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: shape[1:].prod()) + coords[2] * int(shape[2]) + coords[3] batch_size = coords[0].max().item() + 1 output = torch.sparse_coo_tensor( - indices.unsqueeze(0), + indices.unsqueeze(dim=0), feats, torch.Size([batch_size * int(self.shape.prod()), feats.size(-1)]), diff --git a/torchsparse/nn/modules/norm.py b/torchsparse/nn/modules/norm.py index 1c22246..9dc65f7 100644 --- a/torchsparse/nn/modules/norm.py +++ b/torchsparse/nn/modules/norm.py @@ -16,27 +16,26 @@ def forward(self, input: SparseTensor) -> SparseTensor: class GroupNorm(nn.GroupNorm): def forward(self, input: SparseTensor) -> SparseTensor: - feats = input.F - coords = input.C - stride = input.s + coords, feats, stride = input.coords, input.feats, input.stride + + batch_size = torch.max(coords[:, -1]).item() + 1 + num_channels = feats.shape[1] + # PyTorch's GroupNorm function expects the input to be in (N, C, *) # format where N is batch size, and C is number of channels. "feats" # is not in that format. So, we extract the feats corresponding to # each sample, bring it to the format expected by PyTorch's GroupNorm # function, and invoke it. - batch_size = coords[-1][-1] + 1 - num_channels = feats.shape[1] - new_feats = torch.zeros_like(feats) - for sample_idx in range(batch_size): - indices = coords[:, -1] == sample_idx - sample_feats = feats[indices] - sample_feats = torch.transpose(sample_feats, 0, 1) - sample_feats = sample_feats.reshape(1, num_channels, -1) - normalized_feats = super().forward(sample_feats) - normalized_feats = normalized_feats.reshape(num_channels, -1) - normalized_feats = torch.transpose(normalized_feats, 0, 1) - new_feats[indices] = normalized_feats - output = SparseTensor(coords=coords, feats=new_feats, stride=stride) + nfeats = torch.zeros_like(feats) + for k in range(batch_size): + indices = coords[:, -1] == k + bfeats = feats[indices] + bfeats = bfeats.transpose(0, 1).reshape(1, num_channels, -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(num_channels, -1).transpose(0, 1) + nfeats[indices] = bfeats + + output = SparseTensor(coords=coords, feats=nfeats, stride=stride) output.cmaps = input.cmaps output.kmaps = input.kmaps return output diff --git a/torchsparse/nn/modules/pooling.py b/torchsparse/nn/modules/pooling.py index d58aa6b..faaa3cb 100644 --- a/torchsparse/nn/modules/pooling.py +++ b/torchsparse/nn/modules/pooling.py @@ -3,16 +3,16 @@ from torchsparse import SparseTensor from torchsparse.nn import functional as F -__all__ = ['GlobalAveragePooling', 'GlobalMaxPooling'] +__all__ = ['GlobalAvgPool', 'GlobalMaxPool'] -class GlobalAveragePooling(nn.Module): +class GlobalAvgPool(nn.Module): def forward(self, input: SparseTensor) -> SparseTensor: return F.global_avg_pool(input) -class GlobalMaxPooling(nn.Module): +class GlobalMaxPool(nn.Module): def forward(self, input: SparseTensor) -> SparseTensor: return F.global_max_pool(input) diff --git a/torchsparse/operators.py b/torchsparse/operators.py index d03a379..231dfe5 100644 --- a/torchsparse/operators.py +++ b/torchsparse/operators.py @@ -2,15 +2,16 @@ import torch -from torchsparse.tensors import SparseTensor +from torchsparse.tensor import SparseTensor __all__ = ['cat'] def cat(inputs: List[SparseTensor]) -> SparseTensor: - coords, stride = inputs[0].coords, inputs[0].stride - feats = torch.cat([inputs.feats for inputs in inputs], dim=1) - outputs = SparseTensor(feats, coords, stride) - outputs.cmaps = inputs[0].cmaps - outputs.kmaps = inputs[0].kmaps - return outputs + feats = torch.cat([input.feats for input in inputs], dim=1) + output = SparseTensor(coords=inputs[0].coords, + feats=feats, + stride=inputs[0].stride) + output.cmaps = inputs[0].cmaps + output.kmaps = inputs[0].kmaps + return output diff --git a/torchsparse/tensors.py b/torchsparse/tensor.py similarity index 89% rename from torchsparse/tensors.py rename to torchsparse/tensor.py index 810f020..47237da 100644 --- a/torchsparse/tensors.py +++ b/torchsparse/tensor.py @@ -53,17 +53,18 @@ def detach(self): self.coords = self.coords.detach() return self - def to(self, device, non_blocking=True): + def to(self, device: str, non_blocking: bool = True): self.feats = self.feats.to(device, non_blocking=non_blocking) self.coords = self.coords.to(device, non_blocking=non_blocking) return self def __add__(self, other): - tensor = SparseTensor(self.feats + other.feats, self.coords, - self.stride) - tensor.cmaps = self.cmaps - tensor.kmaps = self.kmaps - return tensor + output = SparseTensor(coords=self.coords, + feats=self.feats + other.feats, + stride=self.stride) + output.cmaps = self.cmaps + output.kmaps = self.kmaps + return output class PointTensor: diff --git a/torchsparse/utils/quantize.py b/torchsparse/utils/quantize.py index eaea1b5..10df84e 100644 --- a/torchsparse/utils/quantize.py +++ b/torchsparse/utils/quantize.py @@ -1,5 +1,5 @@ from itertools import repeat -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np @@ -9,25 +9,23 @@ def ravel_hash(x: np.ndarray) -> np.ndarray: assert x.ndim == 2, x.shape - x -= x.min(axis=0) + x -= np.min(x, axis=0) x = x.astype(np.uint64, copy=False) - xmax = x.max(axis=0).astype(np.uint64) + 1 + xmax = np.max(x, axis=0).astype(np.uint64) + 1 h = np.zeros(x.shape[0], dtype=np.uint64) - for j in range(x.shape[1] - 1): - h += x[:, j] - h *= xmax[j + 1] + for k in range(x.shape[1] - 1): + h += x[:, k] + h *= xmax[k + 1] h += x[:, -1] return h -def sparse_quantize( - coords, - voxel_size: Union[float, Tuple[float, ...]] = 1, - *, - return_index: bool = False, - return_inverse: bool = False -) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: +def sparse_quantize(coords, + voxel_size: Union[float, Tuple[float, ...]] = 1, + *, + return_index: bool = False, + return_inverse: bool = False) -> List[np.ndarray]: if isinstance(voxel_size, (float, int)): voxel_size = tuple(repeat(voxel_size, 3)) assert isinstance(voxel_size, tuple) and len(voxel_size) == 3 @@ -40,9 +38,9 @@ def sparse_quantize( return_inverse=True) coords = coords[indices] - output = [coords] + outputs = [coords] if return_index: - output += [indices] + outputs += [indices] if return_inverse: - output += [inverse_indices] - return output[0] if len(output) == 1 else tuple(output) + outputs += [inverse_indices] + return outputs[0] if len(outputs) == 1 else outputs From fcec738baffa0c0a0ee0566b9d26fad4b87df682 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 20:02:54 -0400 Subject: [PATCH 08/12] Update version to 1.4.0 --- torchsparse/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsparse/version.py b/torchsparse/version.py index 19b4f1d..96e3ce8 100644 --- a/torchsparse/version.py +++ b/torchsparse/version.py @@ -1 +1 @@ -__version__ = '1.3.0' +__version__ = '1.4.0' From 67b07b63bdf9ed6220853f8ce62ca019e3f4320d Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 20:04:30 -0400 Subject: [PATCH 09/12] Update version in `README.md` --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 929b49d..f25d894 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,10 @@ TorchSparse depends on the [Google Sparse Hash](https://github.com/sparsehash/sp * You can also compile the library locally (if you do not have the sudo permission) and add the library path to the environment variable `CPLUS_INCLUDE_PATH`. -The latest released TorchSparse (v1.3.0) can then be installed by +The latest released TorchSparse (v1.4.0) can then be installed by ```bash -pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.3.0 +pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0 ``` If you use TorchSparse in your code, please remember to specify the exact version as your dependencies. From 823dada3df5100807fa1762df72597eca5d10d29 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 20:18:49 -0400 Subject: [PATCH 10/12] Fix the shape of `mask` --- torchsparse/nn/functional/crop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchsparse/nn/functional/crop.py b/torchsparse/nn/functional/crop.py index 0f0f076..75df98c 100644 --- a/torchsparse/nn/functional/crop.py +++ b/torchsparse/nn/functional/crop.py @@ -12,7 +12,9 @@ def spcrop(input: SparseTensor, coords_max: Optional[Tuple[int, ...]] = None) -> SparseTensor: coords, feats, stride = input.coords, input.feats, input.stride - mask = torch.ones_like(coords) + mask = torch.ones((coords.shape[0], 3), + dtype=torch.bool, + device=coords.device) if coords_min is not None: coords_min = torch.tensor(coords_min, dtype=torch.int, From eac36f29c54c0dff0fff981c12b4f8c8f8366965 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 20:34:06 -0400 Subject: [PATCH 11/12] Switch from `<` to `<=` for `spcrop` --- torchsparse/nn/functional/crop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchsparse/nn/functional/crop.py b/torchsparse/nn/functional/crop.py index 75df98c..69194c6 100644 --- a/torchsparse/nn/functional/crop.py +++ b/torchsparse/nn/functional/crop.py @@ -24,7 +24,10 @@ def spcrop(input: SparseTensor, coords_max = torch.tensor(coords_max, dtype=torch.int, device=coords.device).unsqueeze(dim=0) - mask &= (coords[:, :3] <= coords_max) + # Using "<" instead of "<=" is for the backward compatability (in + # some existing detection codebase). We might need to reflect this + # in the document or change it back to "<=" in the future. + mask &= (coords[:, :3] < coords_max) mask = torch.all(mask, dim=1) coords, feats = coords[mask], feats[mask] From 002beb585cfabd348c1602ed50bc64fb442e3dd0 Mon Sep 17 00:00:00 2001 From: Zhijian Liu Date: Thu, 24 Jun 2021 21:00:08 -0400 Subject: [PATCH 12/12] Update docstrings --- setup.cfg | 7 +++++-- torchsparse/nn/modules/bev.py | 37 +++++++++++++++++------------------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/setup.cfg b/setup.cfg index ecaad42..c662f4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,8 +8,11 @@ split_before_bitwise_operator = true [isort] known_first_party = torchsparse +[pydocstyle] +convention = google + [flake8] -select = B, C, E, F, P, T4, W, B9 -ignore = E501, E722, W503 +select = B, C, D, E, F, P, T4, W, B9 +ignore = D10, E501, E722, W503 per-file-ignores = __init__.py: F401, F403 diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py index f724a9f..3b10330 100644 --- a/torchsparse/nn/modules/bev.py +++ b/torchsparse/nn/modules/bev.py @@ -34,17 +34,20 @@ def forward(self, input: SparseTensor) -> SparseTensor: class ToDenseBEVConvolution(nn.Module): - """ Converts a torchsparse.SparseTensor to a BEV feature map. + """Converts a SparseTensor into a dense BEV feature map. + Group points with the same z value together and apply the same FC kernel. Aggregate the results by summing up all features within one BEV grid. - in_channels: input channels - out_channels: output channels - shape: shape of BEV map. - dim: dimension index for z. (default: 1 for KITTI coords) - bias: whether to use bias. + Note: + This module consumes larger memory than `ToBEVHeightCompression`. - Warning: usually larger memory consumption than ToBEVHeightCompression. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + shape: Shape of BEV map + dim: Dimension index for z (default: 1 for KITTI coords) + bias: Whether to use bias """ def __init__(self, @@ -104,8 +107,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: class ToBEVConvolution(nn.Module): - """ Sparse version of ToDenseBEVConvolution. - """ + """Converts a SparseTensor into a sparse BEV feature map.""" def __init__(self, in_channels: int, @@ -152,16 +154,13 @@ def forward(self, input: SparseTensor) -> torch.Tensor: class ToBEVHeightCompression(nn.Module): - """ Converts a torchsparse.SparseTensor to a dense volumetric tensor, - then flatten the z dimension. - E.g. input [N, C] (assume batch_size=1), spatial size [128,2,128] - then output will be [1, 2C, 128, 128] - - channels: input channels - (Note: output channels = channels x #unique z values) - shape: shape of BEV map. - dim: dimension index for z. (default: 1 for KITTI coords) - bias: whether to use bias. + """Converts a SparseTensor to a flattened volumetric tensor. + + Args: + channels: Number of input channels + (Note: output channels = channels x #unique z values) + shape: Shape of BEV map + dim: Dimension index for z (default: 1 for KITTI coords) """ def __init__(self,