diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml
deleted file mode 100644
index 225bb62..0000000
--- a/.github/workflows/formatter.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-name: Formatter
-
-on:
- push:
- branches:
- -master
- pull_request:
-
-jobs:
- build:
- name: Formatter
- runs-on: ubuntu-latest
- steps:
- - name: Checkout repo
- uses: actions/checkout@v2.3.4
- with:
- repository: ${{ github.repository }}
- token: ${{ github.token }}
- ref: ${{ github.event.pull_request.head.ref }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- python -m pip install --upgrade yapf
- echo $(yapf --version)
- - name: Format with YAPF
- run: |
- yapf --verbose --recursive --in-place --parallel --style '{SPACES_AROUND_POWER_OPERATOR: True}' .
- - name: Push commit
- run: |
- git config user.name github-actions
- git config user.email github-actions@github.com
- git add .
- git commit -m "Automation: Formatter" --all | exit 0
- git push
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 0000000..14e7d69
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,17 @@
+name: pre-commit
+
+on:
+ pull_request:
+ push:
+ branches: [master]
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - run: |
+ sudo apt-get update
+ sudo apt-get install -y --no-install-recommends clang-format
+ - uses: actions/checkout@v2
+ - uses: actions/setup-python@v2
+ - uses: pre-commit/action@v2.0.3
diff --git a/.gitignore b/.gitignore
index 51d4ba0..e01e346 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
.vscode/
build/
-*.pyc
\ No newline at end of file
+*.pyc
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..8e45a40
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,63 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.0.1
+ hooks:
+ - id: trailing-whitespace
+ name: (Common) Remove trailing whitespaces
+ - id: mixed-line-ending
+ name: (Common) Fix mixed line ending
+ args: ['--fix=lf']
+ - id: end-of-file-fixer
+ name: (Common) Remove extra EOF newlines
+ - id: check-merge-conflict
+ name: (Common) Check for merge conflicts
+ - id: requirements-txt-fixer
+ name: (Common) Sort "requirements.txt"
+ - id: fix-encoding-pragma
+ name: (Python) Remove encoding pragmas
+ args: ['--remove']
+ - id: double-quote-string-fixer
+ name: (Python) Fix double-quoted strings
+ - id: debug-statements
+ name: (Python) Check for debugger imports
+ - id: check-json
+ name: (JSON) Check syntax
+ - id: check-yaml
+ name: (YAML) Check syntax
+ - id: check-toml
+ name: (TOML) Check syntax
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v2.19.4
+ hooks:
+ - id: pyupgrade
+ name: (Python) Update syntax for newer versions
+ args: ['--py36-plus']
+ - repo: https://github.com/google/yapf
+ rev: v0.31.0
+ hooks:
+ - id: yapf
+ name: (Python) Format with yapf
+ - repo: https://github.com/pycqa/isort
+ rev: 5.8.0
+ hooks:
+ - id: isort
+ name: (Python) Sort imports with isort
+ - repo: https://github.com/pycqa/flake8
+ rev: 3.9.2
+ hooks:
+ - id: flake8
+ name: (Python) Check with flake8
+ additional_dependencies: [flake8-bugbear, flake8-comprehensions, flake8-docstrings, flake8-executable, flake8-quotes]
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v0.902
+ hooks:
+ - id: mypy
+ name: (Python) Check with mypy
+ additional_dependencies: [tokenize-rt]
+ - repo: local
+ hooks:
+ - id: clang-format
+ name: (C/C++/CUDA) Format with clang-format
+ entry: clang-format -style=google -i
+ language: system
+ files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
diff --git a/LICENSE b/LICENSE
index 1f23cf0..b6edbf4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2020 Haotian Tang, Zhijian Liu, Song Han
+Copyright (c) 2020-2021 TorchSparse Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -19,31 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-
---------------------------- LICENSE FOR MinkowskiEngine --------------------------------
-MIT License
-
-Copyright (c) 2020 NVIDIA CORPORATION.
-Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu)
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
-of the Software, and to permit persons to whom the Software is furnished to do
-so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
-Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
-of the code.
\ No newline at end of file
diff --git a/README.md b/README.md
index a9c9789..f25d894 100644
--- a/README.md
+++ b/README.md
@@ -1,141 +1,92 @@
# TorchSparse
-## News
+TorchSparse is a high-performance neural network library for point cloud processing.
-2020/09/20: We released `torchsparse` v1.1, which is significantly faster than our `torchsparse` v1.0 and is also achieves **1.9x** speedup over [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) v0.5 alpha when running MinkUNet18C!
-
-2020/08/30: We released `torchsparse` v1.0.
-
-## Overview
+## Installation
-We release `torchsparse`, a high-performance computing library for efficient 3D sparse convolution. This library aims at accelerating sparse computation in 3D, in particular the Sparse Convolution operation.
+TorchSparse depends on the [Google Sparse Hash](https://github.com/sparsehash/sparsehash) library.
-
+* On Ubuntu, it can be installed by
-The major advantage of this library is that we support all computation on the GPU, especially the kernel map construction (which is done on the CPU in latest [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) V0.4.3).
+ ```bash
+ sudo apt-get install libsparsehash-dev
+ ```
-## Installation
+* On Mac OS, it can be installed by
-You may run the following command to install torchsparse.
+ ```bash
+ brew install google-sparsehash
+ ```
-```bash
-pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git
-```
+* You can also compile the library locally (if you do not have the sudo permission) and add the library path to the environment variable `CPLUS_INCLUDE_PATH`.
-Note that this library depends on Google's [sparse hash map project](https://github.com/sparsehash/sparsehash). In order to install this library, you may run
+The latest released TorchSparse (v1.4.0) can then be installed by
```bash
-sudo apt-get install libsparsehash-dev
+pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
```
-on Ubuntu servers. If you are not sudo, please clone Google's codebase, compile it and install locally. Finally, add the path to this library to your `CPLUS_INCLUDE_PATH` environmental variable.
-
-For GPU server users, we currently support PyTorch 1.6.0 + CUDA 10.2 + CUDNN 7.6.2. For CPU users, we support PyTorch 1.6.0 (CPU version), MKLDNN backend is optional.
-
-## Usage
-
-Our [SPVNAS](https://github.com/mit-han-lab/e3d) project (ECCV2020) is built with torchsparse. You may navigate to this project and follow the instructions in that codebase to play around.
-
-Here, we also provide a walk-through on some important concepts in torchsparse.
-
-### Sparse Tensor and Point Tensor
-
-In torchsparse, we have two data structures for point cloud storage, namely `torchsparse.SparseTensor` and `torchsparse.PointTensor`. Both structures has two data fields `C` (coordinates) and `F` (features). In `SparseTensor`, we assume that all coordinates are **integer** and **do not duplicate**. However, in `PointTensor`, all coordinates are **floating-point** and can duplicate.
-
-### Sparse Quantize and Sparse Collate
-
-The way to convert a point cloud to `SparseTensor` so that it can be consumed by networks built with Sparse Convolution or Sparse Point-Voxel Convolution is to use the function `torchsparse.utils.sparse_quantize`. An example is given here:
-
-```python
-inds, labels, inverse_map = sparse_quantize(pc, feat, labels, return_index=True, return_invs=True)
-```
+If you use TorchSparse in your code, please remember to specify the exact version as your dependencies.
-where `pc`, `feat`, `labels` corresponds to point cloud (coordinates, should be integer), feature and ground-truth. The `inds` denotes unique indices in the point cloud coordinates, and `inverse_map` denotes the unique index each point is corresponding to. The `inverse map` is used to restore full point cloud prediction from downsampled prediction.
+## Benchmark
-To combine a list of `SparseTensor`s to a batch, you may want to use the `torchsparse.utils.sparse_collate_fn` function.
+We compare TorchSparse with [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) (where the latency is measured on NVIDIA GTX 1080Ti):
-Detailed results are given in [SemanticKITTI dataset preprocessing code](https://github.com/mit-han-lab/e3d/blob/master/spvnas/core/datasets/semantic_kitti.py) in our [SPVNAS](https://github.com/mit-han-lab/e3d) project.
+| | MinkowskiEngine v0.4.3 | TorchSparse v1.0.0 |
+| :----------------------- | :--------------------: | :----------------: |
+| MinkUNet18C (MACs / 10) | 224.7 ms | 124.3 ms |
+| MinkUNet18C (MACs / 4) | 244.3 ms | 160.9 ms |
+| MinkUNet18C (MACs / 2.5) | 269.6 ms | 214.3 ms |
+| MinkUNet18C | 323.5 ms | 294.0 ms |
-### Computation API
+## Getting Started
-The computation interface in torchsparse is straightforward and very similar to original PyTorch. An example here defines a basic convolution block:
+### Sparse Tensor
-```python
-class BasicConvolutionBlock(nn.Module):
- def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
- super().__init__()
- self.net = nn.Sequential(
- spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
- spnn.BatchNorm(outc),
- spnn.ReLU(True)
- )
-
- def forward(self, x):
- out = self.net(x)
- return out
-```
+Sparse tensor (`SparseTensor`) is the main data structure for point cloud, which has two data fields:
+* Coordinates (`coords`): a 2D integer tensor with a shape of N x 4, where the first three dimensions correspond to quantized x, y, z coordinates, and the last dimension denotes the batch index.
+* Features (`feats`): a 2D tensor with a shape of N x C, where C is the number of feature channels.
-where `spnn`denotes `torchsparse.nn`, and `spnn.Conv3d` means 3D sparse convolution operation, `spnn.BatchNorm` and `spnn.ReLU` denotes 3D sparse tensor batchnorm and activations, respectively. We also support direct convolution kernel call via `torchsparse.nn.functional`, for example:
+Most existing datasets provide raw point cloud data with float coordinates. We can use `sparse_quantize` (provided in `torchsparse.utils.quantize`) to voxelize x, y, z coordinates and remove duplicates:
```python
-outputs = torchsparse.nn.functional.conv3d(inputs, kernel, stride=1, dilation=1, transpose=False)
+coords -= np.min(coords, axis=0, keepdims=True)
+coords, indices = sparse_quantize(coords, voxel_size, return_index=True)
+coords = torch.tensor(coords, dtype=torch.int)
+feats = torch.tensor(feats[indices], dtype=torch.float)
+tensor = SparseTensor(coords=coords, feats=feats)
```
-where we need to define `inputs`(SparseTensor), `kernel` (of shape k^3 x OC x IC when k > 1, or OC x IC when k = 1, where k denotes the kernel size and IC, OC means input / output channels). The `outputs` is still a SparseTensor.
+We can then use `sparse_collate_fn` (provided in `torchsparse.utils.collate`) to assemble a batch of `SparseTensor`'s (and add the batch dimension to `coords`). Please refer to [this example](https://github.com/mit-han-lab/torchsparse/blob/dev/pre-commit/examples/example.py) for more details.
-Detailed examples are given in [here](https://github.com/mit-han-lab/e3d/blob/master/spvnas/core/modules/dynamic_sparseop.py), where we use the `torchsparse.nn.functional` interfaces to implement weight-shared 3D-NAS modules.
+### Sparse Neural Network
-### Sparse Hashmap API
-
-Sparse hash map query is important in 3D sparse computation. It is mainly used to infer a point's memory location (*i.e.* index) given its coordinates. For example, we use this operation in kernel map construction part of 3D sparse convolution, and also sparse voxelization / devoxelization in [Sparse Point-Voxel Convolution](https://arxiv.org/abs/2007.16100). Here, we provide the following example for hash map API:
+The neural network interface in TorchSparse is very similar to PyTorch:
```python
-source_hash = torchsparse.nn.functional.sphash(torch.floor(source_coords).int())
-target_hash = torchsparse.nn.functional.sphash(torch.floor(target_coords).int())
-idx_query = torchsparse.nn.functional.sphashquery(source_hash, target_hash)
+from torch import nn
+from torchsparse import nn as spnn
+
+model = nn.Sequential(
+ spnn.Conv3d(in_channels, out_channels, kernel_size),
+ spnn.BatchNorm(out_channels),
+ spnn.ReLU(True),
+)
```
-In this example, `sphash` is the function converting integer coordinates to hashing. The `sphashquery(source_hash, target_hash)` performs the hash table lookup. Here, the hash map has key `target_hash` and value corresponding to point indices in the target point cloud tensor. For each point in the `source_coords`, we find the point index in `target_coords` which has the same coordinate as it.
-
-### Dummy Training Example
-
-We here provides an entire training example with dummy input [here](examples/example.py). In this example, we cover
-
-- How we start from point cloud data and convert it to SparseTensor format;
-- How we can implement SparseTensor batching;
-- How to train a semantic segmentation SparseConvNet.
-
-You are also welcomed to check out our [SPVNAS](https://github.com/mit-han-lab/e3d) project to implement training / inference with real data.
-
-### Mixed Precision (float16) Support
-
-Mixed precision training is supported via `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler`. Enabling mixed precision training can speed up training and reduce GPU memory usage. By wrapping your training code in a `torch.cuda.amp.autocast` block, feature tensors will automatically be converted to float16 if possible. See [here](examples/example.py) for a complete example.
-
-## Speed Comparison Between torchsparse and MinkowskiEngine
-
-We benchmark the performance of our torchsparse and latest [MinkowskiEngine V0.4.3](https://github.com/NVIDIA/MinkowskiEngine) here, latency is measured on NVIDIA GTX 1080Ti GPU:
-
-| Network | Latency (ME V0.4.3) | Latency (torchsparse V1.0.0) |
-| :----------------------: | :-----------------: | :--------------------------: |
-| MinkUNet18C (MACs / 10) | 224.7 | 124.3 |
-| MinkUNet18C (MACs / 4) | 244.3 | 160.9 |
-| MinkUNet18C (MACs / 2.5) | 269.6 | 214.3 |
-| MinkUNet18C | 323.5 | 294.0 |
-
## Citation
-If you find this code useful, please consider citing:
+If you use TorchSparse in your research, please use the following BibTeX entry:
```bibtex
-@inproceedings{
- tang2020searching,
- title = {Searching Efficient 3D Architectures with Sparse Point-Voxel Convolution},
- author = {Tang, Haotian* and Liu, Zhijian* and Zhao, Shengyu and Lin, Yujun and Lin, Ji and Wang, Hanrui and Han, Song},
- booktitle = {European Conference on Computer Vision},
+@inproceedings{tang2020searching,
+ title = {{Searching Efficient 3D Architectures with Sparse Point-Voxel Convolution}},
+ author = {Tang, Haotian and Liu, Zhijian and Zhao, Shengyu and Lin, Yujun and Lin, Ji and Wang, Hanrui and Han, Song},
+ booktitle = {European Conference on Computer Vision (ECCV)},
year = {2020}
}
```
## Acknowledgements
-This library is inspired by [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine), [SECOND](https://github.com/traveller59/second.pytorch) and [SparseConvNet](https://github.com/facebookresearch/SparseConvNet).
+TorchSparse is inspired by many existing open-source libraries, including (but not limited to) [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine), [SECOND](https://github.com/traveller59/second.pytorch) and [SparseConvNet](https://github.com/facebookresearch/SparseConvNet).
diff --git a/examples/example.py b/examples/example.py
index 2bc0ce1..6d19ebd 100644
--- a/examples/example.py
+++ b/examples/example.py
@@ -1,82 +1,94 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torchsparse
-import torchsparse.nn as spnn
-from torchsparse import SparseTensor
-from torchsparse.utils import sparse_collate_fn, sparse_quantize
import argparse
+import random
+from typing import Any, Dict
+import numpy as np
+import torch
+import torch.utils.data
+from torch import nn
+from torch.cuda import amp
-def generate_random_point_cloud(size=100000, voxel_size=0.2):
- pc = np.random.randn(size, 4)
- pc[:, :3] = pc[:, :3] * 10
- rounded_pc = np.round(pc[:, :3] / voxel_size).astype(np.int32)
- labels = np.random.choice(10, size)
- inds, _, inverse_map = sparse_quantize(rounded_pc,
- pc,
- labels,
- return_index=True,
- return_invs=True)
+from torchsparse import SparseTensor
+from torchsparse import nn as spnn
+from torchsparse.utils.collate import sparse_collate_fn
+from torchsparse.utils.quantize import sparse_quantize
- voxel_pc = rounded_pc[inds]
- voxel_feat = pc[inds]
- voxel_labels = labels[inds]
- sparse_tensor = SparseTensor(voxel_feat, voxel_pc)
- label_tensor = SparseTensor(voxel_labels, voxel_pc)
+class RandomDataset:
- feed_dict = {'lidar': sparse_tensor, 'targets': label_tensor}
+ def __init__(self, input_size: int, voxel_size: float) -> None:
+ self.input_size = input_size
+ self.voxel_size = voxel_size
- return feed_dict
+ def __getitem__(self, _: int) -> Dict[str, Any]:
+ inputs = np.random.uniform(-100, 100, size=(self.input_size, 4))
+ labels = np.random.choice(10, size=self.input_size)
+ coords, feats = inputs[:, :3], inputs
+ coords -= np.min(coords, axis=0, keepdims=True)
+ coords, indices = sparse_quantize(coords,
+ self.voxel_size,
+ return_index=True)
-def generate_batched_random_point_clouds(size=100000,
- voxel_size=0.2,
- batch_size=2):
- batch = []
- for i in range(batch_size):
- batch.append(generate_random_point_cloud(size, voxel_size))
- return sparse_collate_fn(batch)
+ coords = torch.tensor(coords, dtype=torch.int)
+ feats = torch.tensor(feats[indices], dtype=torch.float)
+ labels = torch.tensor(labels[indices], dtype=torch.long)
+ input = SparseTensor(coords=coords, feats=feats)
+ label = SparseTensor(coords=coords, feats=labels)
+ return {'input': input, 'label': label}
-def dummy_train(device, mixed=False):
- model = nn.Sequential(
- spnn.Conv3d(4, 32, kernel_size=3, stride=1), spnn.BatchNorm(32),
- spnn.ReLU(True), spnn.Conv3d(32, 64, kernel_size=2, stride=2),
- spnn.BatchNorm(64), spnn.ReLU(True),
- spnn.Conv3d(64, 64, kernel_size=2, stride=2, transpose=True),
- spnn.BatchNorm(64), spnn.ReLU(True),
- spnn.Conv3d(64, 32, kernel_size=3, stride=1), spnn.BatchNorm(32),
- spnn.ReLU(True), spnn.Conv3d(32, 10, kernel_size=1)).to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
- criterion = nn.CrossEntropyLoss().to(device)
- scaler = torch.cuda.amp.GradScaler(enabled=mixed)
-
- print('Starting dummy training...')
- for i in range(10):
- optimizer.zero_grad()
- feed_dict = generate_batched_random_point_clouds()
- inputs = feed_dict['lidar'].to(device)
- targets = feed_dict['targets'].F.to(device).long()
- with torch.cuda.amp.autocast(enabled=mixed):
- outputs = model(inputs)
- loss = criterion(outputs.F, targets)
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
- print('[step %d] loss = %f.' % (i, loss.item()))
- print('Finished dummy training!')
+ def __len__(self):
+ return 100
if __name__ == '__main__':
parser = argparse.ArgumentParser()
- parser.add_argument("--mixed", action="store_true")
+ parser.add_argument('--amp_enabled', action='store_true')
args = parser.parse_args()
- # set seeds for reproducibility
- np.random.seed(2021)
- torch.manual_seed(2021)
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
- dummy_train(device, args.mixed)
\ No newline at end of file
+ dataset = RandomDataset(input_size=10000, voxel_size=0.2)
+ dataflow = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=2,
+ collate_fn=sparse_collate_fn,
+ )
+
+ model = nn.Sequential(
+ spnn.Conv3d(4, 32, 3),
+ spnn.BatchNorm(32),
+ spnn.ReLU(True),
+ spnn.Conv3d(32, 64, 2, stride=2),
+ spnn.BatchNorm(64),
+ spnn.ReLU(True),
+ spnn.Conv3d(64, 64, 2, stride=2, transposed=True),
+ spnn.BatchNorm(64),
+ spnn.ReLU(True),
+ spnn.Conv3d(64, 32, 3),
+ spnn.BatchNorm(32),
+ spnn.ReLU(True),
+ spnn.Conv3d(32, 10, 1),
+ ).cuda()
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ scaler = amp.GradScaler(enabled=args.amp_enabled)
+
+ for k, feed_dict in enumerate(dataflow):
+ inputs = feed_dict['input'].cuda()
+ labels = feed_dict['label'].cuda()
+
+ with amp.autocast(enabled=args.amp_enabled):
+ outputs = model(inputs)
+ loss = criterion(outputs.feats, labels.feats)
+
+ print(f'[step {k + 1}] loss = {loss.item()}')
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
diff --git a/examples/performance.py b/examples/performance.py
index 250c569..1b52ccf 100644
--- a/examples/performance.py
+++ b/examples/performance.py
@@ -3,10 +3,14 @@
import numpy as np
import torch
import torch.autograd.profiler as profiler
+import torch.cuda
import torch.nn as nn
+import torch.optim
+
import torchsparse.nn as spnn
from torchsparse import SparseTensor
-from torchsparse.utils import sparse_collate_fn, sparse_quantize
+from torchsparse.utils.collate import sparse_collate_fn
+from torchsparse.utils.quantize import sparse_quantize
def generate_random_point_cloud(size=100000, voxel_size=0.2):
@@ -36,7 +40,7 @@ def generate_batched_random_point_clouds(size=100000,
voxel_size=0.2,
batch_size=2):
batch = []
- for i in range(batch_size):
+ for _ in range(batch_size):
batch.append(generate_random_point_cloud(size, voxel_size))
return sparse_collate_fn(batch)
@@ -47,10 +51,10 @@ def dummy_train_3x3(device):
spnn.Conv3d(32, 64, kernel_size=3, stride=1),
spnn.Conv3d(64, 128, kernel_size=3, stride=1),
spnn.Conv3d(128, 256, kernel_size=3, stride=1),
- spnn.Conv3d(256, 128, kernel_size=3, stride=1, transpose=True),
- spnn.Conv3d(128, 64, kernel_size=3, stride=1, transpose=True),
- spnn.Conv3d(64, 32, kernel_size=3, stride=1, transpose=True),
- spnn.Conv3d(32, 10, kernel_size=3, stride=1, transpose=True),
+ spnn.Conv3d(256, 128, kernel_size=3, stride=1, transposed=True),
+ spnn.Conv3d(128, 64, kernel_size=3, stride=1, transposed=True),
+ spnn.Conv3d(64, 32, kernel_size=3, stride=1, transposed=True),
+ spnn.Conv3d(32, 10, kernel_size=3, stride=1, transposed=True),
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)
@@ -58,8 +62,8 @@ def dummy_train_3x3(device):
print('Starting dummy_train_3x3...')
time = datetime.now()
with profiler.profile(profile_memory=True, use_cuda=True) as prof:
- with profiler.record_function("model_inference"):
- for i in range(10):
+ with profiler.record_function('model_inference'):
+ for _ in range(10):
feed_dict = generate_batched_random_point_clouds()
inputs = feed_dict['lidar'].to(device)
targets = feed_dict['targets'].F.to(device).long()
@@ -69,8 +73,8 @@ def dummy_train_3x3(device):
loss.backward()
optimizer.step()
# print('[step %d] loss = %f.'%(i, loss.item()))
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
- prof.export_chrome_trace("trace_dummy_3x3.json")
+ print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
+ prof.export_chrome_trace('trace_dummy_3x3.json')
time = datetime.now() - time
print('Finished dummy_train_3x3 in ', time)
@@ -82,10 +86,10 @@ def dummy_train_3x1(device):
spnn.Conv3d(32, 64, kernel_size=(1, 3, 3), stride=1),
spnn.Conv3d(64, 128, kernel_size=(3, 1, 3), stride=1),
spnn.Conv3d(128, 256, kernel_size=(1, 3, 3), stride=1),
- spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transpose=True),
- spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transpose=True),
- spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transpose=True),
- spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transpose=True),
+ spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transposed=True),
+ spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transposed=True),
+ spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transposed=True),
+ spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transposed=True),
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)
@@ -93,8 +97,8 @@ def dummy_train_3x1(device):
print('Starting dummy_train_3x1 ...')
time = datetime.now()
with profiler.profile(profile_memory=True, use_cuda=True) as prof:
- with profiler.record_function("model_inference"):
- for i in range(10):
+ with profiler.record_function('model_inference'):
+ for _ in range(10):
feed_dict = generate_batched_random_point_clouds()
inputs = feed_dict['lidar'].to(device)
targets = feed_dict['targets'].F.to(device).long()
@@ -104,8 +108,8 @@ def dummy_train_3x1(device):
loss.backward()
optimizer.step()
# print('[step %d] loss = %f.'%(i, loss.item()))
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
- prof.export_chrome_trace("trace_dummy_3x1.json")
+ print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
+ prof.export_chrome_trace('trace_dummy_3x1.json')
time = datetime.now() - time
print('Finished dummy_train_3x1 in ', time)
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..c662f4c
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,18 @@
+[yapf]
+based_on_style = google
+spaces_around_power_operator = true
+split_before_arithmetic_operator = true
+split_before_logical_operator = true
+split_before_bitwise_operator = true
+
+[isort]
+known_first_party = torchsparse
+
+[pydocstyle]
+convention = google
+
+[flake8]
+select = B, C, D, E, F, P, T4, W, B9
+ignore = D10, E501, E722, W503
+per-file-ignores =
+ __init__.py: F401, F403
diff --git a/setup.py b/setup.py
index 6177a69..299f005 100644
--- a/setup.py
+++ b/setup.py
@@ -1,66 +1,39 @@
+import glob
import os
import torch
+import torch.cuda
from setuptools import find_packages, setup
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)
-has_cuda = (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv(
- 'FORCE_CUDA', '0') == '1'
-
from torchsparse import __version__
-# Notice that CUDA files, header files should not share names with CPP files.
-# Otherwise, there will be "ninja: warning: multiple rules generate xxx.o", which leads to
-# multiple definitions error!
+if ((torch.cuda.is_available() and CUDA_HOME is not None)
+ or (os.getenv('FORCE_CUDA', '0') == '1')):
+ device = 'cuda'
+else:
+ device = 'cpu'
-file_lis = [
- 'torchsparse/src/torchsparse_bindings_gpu.cpp',
- 'torchsparse/src/convolution/convolution_cpu.cpp',
- 'torchsparse/src/convolution/convolution.cu',
- 'torchsparse/src/convolution/convolution_gpu.cu',
- 'torchsparse/src/hash/hash_cpu.cpp',
- 'torchsparse/src/hash/hash.cpp',
- 'torchsparse/src/hash/hash_gpu.cu',
- 'torchsparse/src/hashmap/hashmap.cu',
- 'torchsparse/src/hashmap/hashmap_cpu.cpp',
- 'torchsparse/src/interpolation/devox_gpu.cu',
- 'torchsparse/src/interpolation/devox_deterministic.cpp',
- 'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
- 'torchsparse/src/interpolation/devox_cpu.cpp',
- 'torchsparse/src/others/count.cpp',
- 'torchsparse/src/others/count_gpu.cu',
- 'torchsparse/src/others/count_cpu.cpp',
- 'torchsparse/src/others/insertion_gpu.cu',
- 'torchsparse/src/others/insertion_cpu.cpp',
- 'torchsparse/src/others/query.cpp',
- 'torchsparse/src/others/query_cpu.cpp',
-] if has_cuda else [
- 'torchsparse/src/torchsparse_bindings.cpp',
- 'torchsparse/src/convolution/convolution_cpu.cpp',
- 'torchsparse/src/hash/hash_cpu.cpp',
- 'torchsparse/src/hashmap/hashmap_cpu.cpp',
- 'torchsparse/src/interpolation/devox_cpu.cpp',
- 'torchsparse/src/others/insertion_cpu.cpp',
- 'torchsparse/src/others/query_cpu.cpp',
- 'torchsparse/src/others/count_cpu.cpp'
-]
+sources = [os.path.join('torchsparse', 'backend', f'pybind_{device}.cpp')]
+for fpath in glob.glob(os.path.join('torchsparse', 'backend', '**', '*')):
+ if ((fpath.endswith('_cpu.cpp') and device in ['cpu', 'cuda'])
+ or (fpath.endswith('_cuda.cu') and device == 'cuda')):
+ sources.append(fpath)
+extension_type = CUDAExtension if device == 'cuda' else CppExtension
extra_compile_args = {
'cxx': ['-g', '-O3', '-fopenmp', '-lgomp'],
'nvcc': ['-O3']
-} if has_cuda else {
- 'cxx': ['-g', '-O3', '-fopenmp', '-lgomp']
}
-extension_type = CUDAExtension if has_cuda else CppExtension
setup(
name='torchsparse',
version=__version__,
packages=find_packages(),
ext_modules=[
- extension_type('torchsparse_backend',
- file_lis,
+ extension_type('torchsparse.backend',
+ sources,
extra_compile_args=extra_compile_args)
],
cmdclass={'build_ext': BuildExtension},
diff --git a/torchsparse/__init__.py b/torchsparse/__init__.py
index c046f68..269c7ea 100644
--- a/torchsparse/__init__.py
+++ b/torchsparse/__init__.py
@@ -1,18 +1,3 @@
-import torch
-from .sparse_tensor import *
-from .point_tensor import *
-
-__version__ = '1.3.0'
-
-
-def cat(input_list, dim=1):
- assert len(input_list) > 0
- inputs = input_list[0]
- features = inputs.F
- coords = inputs.C
- cur_stride = inputs.s
- output_tensor = SparseTensor(
- torch.cat([inputs.F for inputs in input_list], 1), coords, cur_stride)
- output_tensor.coord_maps = inputs.coord_maps
- output_tensor.kernel_maps = inputs.kernel_maps
- return output_tensor
+from .operators import *
+from .tensor import *
+from .version import __version__
diff --git a/torchsparse/backend/convolution/convolution_cpu.cpp b/torchsparse/backend/convolution/convolution_cpu.cpp
new file mode 100644
index 0000000..b0b7925
--- /dev/null
+++ b/torchsparse/backend/convolution/convolution_cpu.cpp
@@ -0,0 +1,183 @@
+#include "convolution_cpu.h"
+
+#include
+
+#include
+#include
+
+void scatter_cpu(const int n_in, const int n_out, const int c,
+ const float *in_feat, float *out_feat, const int *kmap,
+ const bool transpose) {
+ for (int i = 0; i < n_in; i++) {
+ int out_pos = kmap[2 * i + 1 - transpose];
+ if (out_pos < 0) {
+ continue;
+ }
+#pragma omp parallel for
+ for (int j = 0; j < c; j++) {
+ out_feat[out_pos * c + j] += in_feat[i * c + j];
+ }
+ }
+}
+
+void gather_cpu(const int n_k, const int n_in, const int c,
+ const float *in_feat, float *out_feat, const int *kmap,
+ const bool transpose) {
+ for (int i = 0; i < n_k; i++) {
+ int in_pos = kmap[2 * i + transpose];
+ if (in_pos < 0) {
+ continue;
+ }
+#pragma omp parallel for
+ for (int j = 0; j < c; j++) {
+ out_feat[i * c + j] = in_feat[in_pos * c + j];
+ }
+ }
+}
+
+void convolution_forward_cpu(at::Tensor in_feat, at::Tensor out_feat,
+ at::Tensor kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset, const bool transpose) {
+ if (in_feat.size(1) != kernel.size(1)) {
+ throw std::invalid_argument("Input feature size and kernel size mismatch");
+ }
+
+ int out_nrows = out_feat.size(0);
+ out_feat.resize_({out_nrows, kernel.size(2)});
+ out_feat.zero_();
+
+ int kernel_volume = kernel.size(0);
+ int in_buffer_size = 1;
+ bool flag = false;
+ // memory optimization
+ if (kernel_volume % 2 && out_nrows == in_feat.size(0)) {
+ flag = true;
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + kernel_volume / 2);
+ in_buffer_size =
+ std::max(in_buffer_size,
+ *std::max_element(
+ neighbor_offset.data_ptr() + kernel_volume / 2 + 1,
+ neighbor_offset.data_ptr() + kernel_volume));
+ in_buffer_size = std::max(in_buffer_size, 1);
+
+ torch::mm_out(out_feat, in_feat, kernel[kernel_volume / 2]);
+ } else {
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + kernel_volume);
+ }
+
+ auto options =
+ torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
+ auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
+ auto out_buffer = torch::zeros({in_buffer_size, kernel.size(2)}, options);
+ int cur_offset = 0;
+ for (int i = 0; i < kernel_volume; i++) {
+ if (flag && (i == kernel_volume / 2)) {
+ cur_offset += 2 * neighbor_offset.data_ptr()[i];
+ continue;
+ }
+
+ if (neighbor_offset.data_ptr()[i] == 0) {
+ continue;
+ }
+
+ auto out_buffer_activated = torch::from_blob(
+ out_buffer.data_ptr(),
+ {neighbor_offset.data_ptr()[i], kernel.size(2)}, options);
+ auto in_buffer_activated = torch::from_blob(
+ in_buffer.data_ptr(),
+ {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options);
+
+ // gather
+ gather_cpu(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1),
+ in_feat.data_ptr(), in_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+
+ // matmul
+ torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]);
+
+ // scatter
+ scatter_cpu(neighbor_offset.data_ptr()[i], out_nrows, kernel.size(2),
+ out_buffer_activated.data_ptr(),
+ out_feat.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+ cur_offset += 2 * neighbor_offset.data_ptr()[i];
+ }
+}
+
+void convolution_backward_cpu(at::Tensor in_feat, at::Tensor grad_in_feat,
+ at::Tensor grad_out_feat, at::Tensor kernel,
+ at::Tensor grad_kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset,
+ const bool transpose) {
+ grad_in_feat.resize_as_(in_feat);
+ grad_in_feat.zero_();
+ grad_kernel.resize_as_(kernel);
+ grad_kernel.zero_();
+
+ int kernel_volume = kernel.size(0);
+ bool flag = false;
+ int in_buffer_size;
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + kernel_volume);
+
+ auto options =
+ torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
+ auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
+ auto in_grad_buffer =
+ torch::zeros({in_buffer_size, in_feat.size(1)}, options);
+ auto out_grad_buffer =
+ torch::zeros({in_buffer_size, kernel.size(2)}, options);
+
+ int cur_offset = 0;
+ for (int i = 0; i < kernel_volume; i++) {
+ auto kernel_grad_buffer = grad_kernel[i];
+ if (flag && (i == kernel_volume / 2)) {
+ cur_offset += 2 * neighbor_offset.data_ptr()[i];
+ continue;
+ }
+
+ if (neighbor_offset.data_ptr()[i] == 0) {
+ continue;
+ }
+
+ auto out_grad_buffer_activated = torch::from_blob(
+ out_grad_buffer.data_ptr(),
+ {neighbor_offset.data_ptr()[i], kernel.size(2)}, options);
+ auto in_grad_buffer_activated = torch::from_blob(
+ in_grad_buffer.data_ptr(),
+ {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options);
+ auto in_buffer_activated = torch::from_blob(
+ in_buffer.data_ptr(),
+ {neighbor_offset.data_ptr()[i], in_feat.size(1)}, options);
+
+ // gather
+ gather_cpu(out_grad_buffer_activated.size(0), grad_out_feat.size(0),
+ kernel.size(2), grad_out_feat.data_ptr(),
+ out_grad_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, !transpose);
+
+ gather_cpu(in_buffer_activated.size(0), in_feat.size(0), kernel.size(1),
+ in_feat.data_ptr(), in_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+
+ // matmul
+ torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated,
+ torch::transpose(kernel[i], 0, 1));
+ torch::mm_out(kernel_grad_buffer,
+ torch::transpose(in_buffer_activated, 0, 1),
+ out_grad_buffer_activated);
+
+ // scatter
+ scatter_cpu(neighbor_offset.data_ptr()[i], in_feat.size(0),
+ kernel.size(1), in_grad_buffer_activated.data_ptr(),
+ grad_in_feat.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, !transpose);
+
+ cur_offset += 2 * neighbor_offset.data_ptr()[i];
+ }
+}
diff --git a/torchsparse/backend/convolution/convolution_cpu.h b/torchsparse/backend/convolution/convolution_cpu.h
new file mode 100644
index 0000000..4e37bcc
--- /dev/null
+++ b/torchsparse/backend/convolution/convolution_cpu.h
@@ -0,0 +1,15 @@
+#ifndef TORCHSPARSE_CONVOLUTION_CPU
+#define TORCHSPARSE_CONVOLUTION_CPU
+
+#include
+
+void convolution_forward_cpu(at::Tensor in_feat, at::Tensor out_feat,
+ at::Tensor kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset, const bool transpose);
+
+void convolution_backward_cpu(at::Tensor in_feat, at::Tensor grad_in_feat,
+ at::Tensor grad_out_feat, at::Tensor kernel,
+ at::Tensor grad_kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset, const bool transpose);
+
+#endif
diff --git a/torchsparse/backend/convolution/convolution_cuda.cu b/torchsparse/backend/convolution/convolution_cuda.cu
new file mode 100644
index 0000000..026f0e7
--- /dev/null
+++ b/torchsparse/backend/convolution/convolution_cuda.cu
@@ -0,0 +1,278 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "convolution_cuda.h"
+
+template
+__global__ void gather_kernel(const int n_k, const int n_in, const int c,
+ const scalar_t *in_feat, scalar_t *out_feat,
+ const int *kmap, const bool transpose) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+ if (i >= n_k) return;
+ int in_pos = kmap[2 * i + transpose];
+ if (in_pos < 0) return;
+ out_feat[i * c + j] = in_feat[in_pos * c + j];
+}
+
+template
+__global__ void scatter_kernel(const int n_in, const int n_out, const int c,
+ const scalar_t *in_feat, scalar_t *out_feat,
+ const int *kmap, const bool transpose) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+ if (i >= n_in) return;
+ int out_pos = kmap[2 * i + 1 - transpose];
+ if (out_pos < 0) return;
+ out_feat[out_pos * c + j] += in_feat[i * c + j];
+}
+
+// in_feat: (N, c) N=# of input points, c = input channels
+// out_feat: (M, o) M=# of output points, o = output channels
+// for stride=1, M=N. For stride>1, the N input coords
+// are requantized to M points with grid size (stride *
+// cur_stride)
+// kernel: (k^3, c, o) for a 3D convolution of length k
+// neighbor_map: (a, 2) the hash table query results from out_coords to
+// in_coords
+// where neighbor_map[:,0] is the index of the output
+// feature and neighbor_map[:,1] is the index of the input
+// feature
+// neighbor_offset: (k^3) count of active weights based on neighbor_map
+// with unused weights having 0 and neighbor_offset[k^3/2]
+// holding w[0,0].
+void convolution_forward_cuda(at::Tensor in_feat, at::Tensor out_feat,
+ at::Tensor kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset,
+ const bool transpose) {
+ if (in_feat.size(1) != kernel.size(1)) {
+ throw std::invalid_argument("Input feature size and kernel size mismatch");
+ }
+
+ bool is_half = in_feat.scalar_type() == at::ScalarType::Half;
+
+ int n_in_feats = in_feat.size(0);
+ int n_in_channels = in_feat.size(1);
+ int n_out_feats = out_feat.size(0);
+ int n_out_channels = out_feat.size(1);
+ ;
+
+ int kernel_volume = kernel.size(0);
+
+ // memory optimization
+ bool precompute_mid = false;
+ int mid_kernel = kernel_volume / 2;
+ int in_buffer_size = 1;
+ // we can precompute features for w[0,0] which avoids gather/scatter
+ if (kernel_volume % 2 == 1 && n_in_feats == n_out_feats) {
+ precompute_mid = true;
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + mid_kernel);
+ in_buffer_size = std::max(
+ in_buffer_size,
+ *std::max_element(neighbor_offset.data_ptr() + mid_kernel + 1,
+ neighbor_offset.data_ptr() + kernel_volume));
+ in_buffer_size = std::max(in_buffer_size, 1);
+
+ // (N, c) X (c, o) = (N, o)
+ torch::mm_out(out_feat, in_feat, kernel[mid_kernel]);
+ } else {
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + kernel_volume);
+ }
+
+ auto options =
+ torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
+ auto in_buffer = torch::zeros({in_buffer_size, n_in_channels}, options);
+ auto out_buffer = torch::zeros({in_buffer_size, n_out_channels}, options);
+ int cur_offset = 0;
+ // gather/gemm/scatter on each weight
+ for (int i = 0; i < kernel_volume; i++) {
+ int n_active_feats = neighbor_offset.data_ptr()[i];
+ // if there's no active features for this weight, skip it
+ if (n_active_feats == 0) {
+ continue;
+ }
+
+ // if w[0,0] was precomputed above, skip it
+ if ((i == mid_kernel) && precompute_mid) {
+ cur_offset += 2 * n_active_feats;
+ continue;
+ }
+
+ // in_buffer_activated (i, c) holds the dense input features from gather
+ // for i = n_active_feats (# of features in the activated kernel from
+ // neighbor_offset) out_buffer_activated (i, o) holds the dense output
+ // features to scatter
+ at::Tensor out_buffer_activated;
+ at::Tensor in_buffer_activated;
+ if (is_half) {
+ out_buffer_activated =
+ torch::from_blob(out_buffer.data_ptr(),
+ {n_active_feats, n_out_channels}, options);
+ in_buffer_activated =
+ torch::from_blob(in_buffer.data_ptr(),
+ {n_active_feats, n_in_channels}, options);
+ } else {
+ out_buffer_activated =
+ torch::from_blob(out_buffer.data_ptr(),
+ {n_active_feats, n_out_channels}, options);
+ in_buffer_activated =
+ torch::from_blob(in_buffer.data_ptr(),
+ {n_active_feats, n_in_channels}, options);
+ }
+
+ // gather n_active_feats dense features from N sparse input features with c
+ // feature dimensions
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ in_feat.type(), "convolution_forward_cuda", ([&] {
+ gather_kernel
+ <<>>(
+ n_active_feats, n_in_feats, n_in_channels,
+ in_feat.data_ptr(),
+ in_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+ }));
+
+ // gemm: (i, c) X (c, o) = (i, o)
+ torch::mm_out(out_buffer_activated, in_buffer_activated, kernel[i]);
+
+ // scatter n_active_feats dense features into n_out_feats output features of
+ // dimension n_out_channels
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ in_feat.type(), "convolution_forward_cuda", ([&] {
+ scatter_kernel
+ <<>>(
+ n_active_feats, n_out_feats, n_out_channels,
+ out_buffer_activated.data_ptr(),
+ out_feat.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+ }));
+
+ cur_offset += 2 * n_active_feats;
+ }
+}
+
+void convolution_backward_cuda(at::Tensor in_feat, at::Tensor grad_in_feat,
+ at::Tensor grad_out_feat, at::Tensor kernel,
+ at::Tensor grad_kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset,
+ const bool transpose) {
+ grad_in_feat.resize_as_(in_feat);
+ grad_in_feat.zero_();
+ grad_kernel.resize_as_(kernel);
+ grad_kernel.zero_();
+
+ bool is_half = in_feat.scalar_type() == at::ScalarType::Half;
+ int n_in_feats = in_feat.size(0);
+ int n_in_channels = in_feat.size(1);
+ int n_out_feats = grad_out_feat.size(0);
+ int n_out_channels = kernel.size(-1);
+
+ int kernel_volume = kernel.size(0);
+ bool flag = false;
+ int in_buffer_size;
+ in_buffer_size =
+ *std::max_element(neighbor_offset.data_ptr(),
+ neighbor_offset.data_ptr() + kernel_volume);
+
+ auto options =
+ torch::TensorOptions().dtype(in_feat.dtype()).device(in_feat.device());
+ auto in_buffer = torch::zeros({in_buffer_size, in_feat.size(1)}, options);
+ auto in_grad_buffer =
+ torch::zeros({in_buffer_size, in_feat.size(1)}, options);
+ auto out_grad_buffer =
+ torch::zeros({in_buffer_size, kernel.size(2)}, options);
+
+ int cur_offset = 0;
+ for (int i = 0; i < kernel_volume; i++) {
+ auto kernel_grad_buffer = grad_kernel[i];
+ int n_active_feats = neighbor_offset.data_ptr()[i];
+ if (flag && (i == kernel_volume / 2)) {
+ cur_offset += 2 * n_active_feats;
+ continue;
+ }
+
+ if (n_active_feats == 0) {
+ continue;
+ }
+
+ // Can't figure out a cleaner way to do this
+ at::Tensor out_grad_buffer_activated;
+ at::Tensor in_grad_buffer_activated;
+ at::Tensor in_buffer_activated;
+ if (is_half) {
+ out_grad_buffer_activated =
+ torch::from_blob(out_grad_buffer.data_ptr(),
+ {n_active_feats, kernel.size(2)}, options);
+ in_grad_buffer_activated =
+ torch::from_blob(in_grad_buffer.data_ptr(),
+ {n_active_feats, in_feat.size(1)}, options);
+ in_buffer_activated =
+ torch::from_blob(in_buffer.data_ptr(),
+ {n_active_feats, in_feat.size(1)}, options);
+ } else {
+ out_grad_buffer_activated =
+ torch::from_blob(out_grad_buffer.data_ptr(),
+ {n_active_feats, kernel.size(2)}, options);
+ in_grad_buffer_activated =
+ torch::from_blob(in_grad_buffer.data_ptr(),
+ {n_active_feats, in_feat.size(1)}, options);
+ in_buffer_activated =
+ torch::from_blob(in_buffer.data_ptr(),
+ {n_active_feats, in_feat.size(1)}, options);
+ }
+
+ // gather
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ in_feat.type(), "convolution_forward_cuda", ([&] {
+ gather_kernel
+ <<>>(
+ n_active_feats, n_out_feats, n_out_channels,
+ grad_out_feat.data_ptr(),
+ out_grad_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, !transpose);
+ }));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ in_feat.type(), "convolution_forward_cuda", ([&] {
+ gather_kernel
+ <<>>(
+ n_active_feats, n_in_feats, n_in_channels,
+ in_feat.data_ptr(),
+ in_buffer_activated.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, transpose);
+ }));
+
+ // gemm
+ torch::mm_out(in_grad_buffer_activated, out_grad_buffer_activated,
+ torch::transpose(kernel[i], 0, 1));
+ torch::mm_out(kernel_grad_buffer,
+ torch::transpose(in_buffer_activated, 0, 1),
+ out_grad_buffer_activated);
+
+ // scatter
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ in_feat.type(), "convolution_forward_cuda", ([&] {
+ scatter_kernel
+ <<>>(
+ n_active_feats, n_in_feats, n_in_channels,
+ in_grad_buffer_activated.data_ptr(),
+ grad_in_feat.data_ptr(),
+ neighbor_map.data_ptr() + cur_offset, !transpose);
+ }));
+
+ cur_offset += 2 * n_active_feats;
+ }
+}
diff --git a/torchsparse/backend/convolution/convolution_cuda.h b/torchsparse/backend/convolution/convolution_cuda.h
new file mode 100644
index 0000000..52880e1
--- /dev/null
+++ b/torchsparse/backend/convolution/convolution_cuda.h
@@ -0,0 +1,16 @@
+#ifndef TORCHSPARSE_CONVOLUTION_CUDA
+#define TORCHSPARSE_CONVOLUTION_CUDA
+
+#include
+
+void convolution_forward_cuda(at::Tensor in_feat, at::Tensor out_feat,
+ at::Tensor kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset, const bool transpose);
+
+void convolution_backward_cuda(at::Tensor in_feat, at::Tensor grad_in_feat,
+ at::Tensor grad_out_feat, at::Tensor kernel,
+ at::Tensor grad_kernel, at::Tensor neighbor_map,
+ at::Tensor neighbor_offset,
+ const bool transpose);
+
+#endif
diff --git a/torchsparse/backend/devoxelize/devoxelize_cpu.cpp b/torchsparse/backend/devoxelize/devoxelize_cpu.cpp
new file mode 100644
index 0000000..dc2afd5
--- /dev/null
+++ b/torchsparse/backend/devoxelize/devoxelize_cpu.cpp
@@ -0,0 +1,59 @@
+#include "devoxelize_cpu.h"
+
+#include
+
+#include
+
+// make sure indices is int type
+// feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c)
+at::Tensor devoxelize_forward_cpu(const at::Tensor feat,
+ const at::Tensor indices,
+ const at::Tensor weight) {
+ int c = feat.size(1);
+ int N = indices.size(0);
+
+ at::Tensor out = torch::zeros(
+ {N, c}, at::device(feat.device()).dtype(at::ScalarType::Float));
+#pragma omp parallel for
+ for (int i = 0; i < N; i++) {
+ int *indices_ = indices.data_ptr() + i * 8;
+ float *weight_ = weight.data_ptr() + i * 8;
+ for (int j = 0; j < c; j++) {
+ float *feat_ = feat.data_ptr() + j;
+ float cur_feat;
+ for (int k = 0; k < 8; k++) {
+ cur_feat = (indices_[k] >= 0) ? feat_[indices_[k] * c] : 0;
+ *(out.data_ptr() + i * c + j) += weight_[k] * cur_feat;
+ }
+ }
+ }
+ return out;
+}
+
+// top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad:
+// (b,c,s), s=r^3
+at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad,
+ const at::Tensor indices,
+ const at::Tensor weight, int n) {
+ int c = top_grad.size(1);
+ int N = top_grad.size(0);
+ at::Tensor bottom_grad = torch::zeros(
+ {n, c}, at::device(top_grad.device()).dtype(at::ScalarType::Float));
+
+ for (int i = 0; i < N; i++) {
+ int *indices_ = indices.data_ptr() + i * 8;
+ float *weight_ = weight.data_ptr() + i * 8;
+#pragma omp parallel for
+ for (int j = 0; j < c; j++) {
+ float *top_grad_ = top_grad.data_ptr() + j;
+ float cur_top_grad;
+ for (int k = 0; k < 8; k++) {
+ cur_top_grad = (indices_[k] >= 0) ? top_grad_[indices_[k] * c] : 0;
+ *(bottom_grad.data_ptr() + indices_[k] * c + j) +=
+ weight_[k] * cur_top_grad;
+ }
+ }
+ }
+
+ return bottom_grad;
+}
diff --git a/torchsparse/backend/devoxelize/devoxelize_cpu.h b/torchsparse/backend/devoxelize/devoxelize_cpu.h
new file mode 100644
index 0000000..38dce96
--- /dev/null
+++ b/torchsparse/backend/devoxelize/devoxelize_cpu.h
@@ -0,0 +1,14 @@
+#ifndef TORCHSPARSE_DEVOXELIZE_CPU
+#define TORCHSPARSE_DEVOXELIZE_CPU
+
+#include
+
+at::Tensor devoxelize_forward_cpu(const at::Tensor feat,
+ const at::Tensor indices,
+ const at::Tensor weight);
+
+at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad,
+ const at::Tensor indices,
+ const at::Tensor weight, int n);
+
+#endif
diff --git a/torchsparse/backend/devoxelize/devoxelize_cuda.cu b/torchsparse/backend/devoxelize/devoxelize_cuda.cu
new file mode 100644
index 0000000..c2c0423
--- /dev/null
+++ b/torchsparse/backend/devoxelize/devoxelize_cuda.cu
@@ -0,0 +1,98 @@
+#include
+#include
+#include
+#include
+
+#include
+
+// input features (n, c), indices (N, 8), weight (N, 8) -> output features (N,
+// c)
+template
+__global__ void devoxelize_forward_kernel(int N, int c,
+ const int *__restrict__ indices,
+ const scalar_t *__restrict__ weight,
+ const scalar_t *__restrict__ feat,
+ scalar_t *__restrict__ out) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+
+ if (i < N) {
+ const int *indices_ = indices + 8 * i;
+ const scalar_t *weight_ = weight + 8 * i;
+ const scalar_t *feat_ = feat + j;
+
+ scalar_t cur_feat;
+ for (int k = 0; k < 8; k++) {
+ cur_feat = 0;
+ if (indices_[k] >= 0) cur_feat = feat_[indices_[k] * c];
+
+ out[i * c + j] += weight_[k] * cur_feat;
+ }
+ }
+}
+
+// input weight (N, 8), indices (N, 8), top_grad (N, c) -> bottom grad (n, c)
+template
+__global__ void devoxelize_backward_kernel(
+ int N, int n, int c, const int *__restrict__ indices,
+ const scalar_t *__restrict__ weight, const scalar_t *__restrict__ top_grad,
+ scalar_t *__restrict__ bottom_grad) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+
+ if (i < N) {
+ const int *indices_ = indices + 8 * i;
+ const scalar_t *weight_ = weight + 8 * i;
+
+ scalar_t cur_top_grad = top_grad[i * c + j];
+
+#pragma unroll
+ for (int k = 0; k < 8; k++) {
+ if (indices_[k] >= 0)
+ atomicAdd(&bottom_grad[indices_[k] * c + j], weight_[k] * cur_top_grad);
+ }
+ }
+}
+
+// make sure indices is int type
+// feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c)
+at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
+ const at::Tensor indices,
+ const at::Tensor weight) {
+ int c = feat.size(1);
+ int N = indices.size(0);
+
+ at::Tensor out =
+ torch::zeros({N, c}, at::device(feat.device()).dtype(feat.dtype()));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ feat.type(), "devoxelize_forward_cuda", ([&] {
+ devoxelize_forward_kernel<<>>(
+ N, c, indices.data_ptr(), weight.data_ptr(),
+ feat.data_ptr(), out.data_ptr());
+ }));
+
+ return out;
+}
+
+// top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad:
+// (b,c,s), s=r^3
+at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad,
+ const at::Tensor indices,
+ const at::Tensor weight, int n) {
+ int c = top_grad.size(1);
+ int N = top_grad.size(0);
+ at::Tensor bottom_grad = torch::zeros(
+ {n, c}, at::device(top_grad.device()).dtype(top_grad.dtype()));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ top_grad.type(), "devoxelize_backward_cuda", ([&] {
+ devoxelize_backward_kernel<<>>(
+ N, n, c, indices.data_ptr(), weight.data_ptr(),
+ top_grad.data_ptr(), bottom_grad.data_ptr());
+ }));
+
+ return bottom_grad;
+}
diff --git a/torchsparse/backend/devoxelize/devoxelize_cuda.h b/torchsparse/backend/devoxelize/devoxelize_cuda.h
new file mode 100644
index 0000000..ade14dc
--- /dev/null
+++ b/torchsparse/backend/devoxelize/devoxelize_cuda.h
@@ -0,0 +1,14 @@
+#ifndef TORCHSPARSE_DEVOXELIZE_CUDA
+#define TORCHSPARSE_DEVOXELIZE_CUDA
+
+#include
+
+at::Tensor devoxelize_forward_cuda(const at::Tensor feat,
+ const at::Tensor indices,
+ const at::Tensor weight);
+
+at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad,
+ const at::Tensor indices,
+ const at::Tensor weight, int n);
+
+#endif
diff --git a/torchsparse/backend/hash/hash_cpu.cpp b/torchsparse/backend/hash/hash_cpu.cpp
new file mode 100644
index 0000000..a017214
--- /dev/null
+++ b/torchsparse/backend/hash/hash_cpu.cpp
@@ -0,0 +1,58 @@
+#include "hash_cpu.h"
+
+#include
+
+#include
+
+void cpu_hash_wrapper(int N, const int *data, long *out) {
+#pragma omp parallel for
+ for (int i = 0; i < N; i++) {
+ unsigned long long hash = 14695981039346656037UL;
+ for (int j = 0; j < 4; j++) {
+ hash ^= (unsigned int)data[4 * i + j];
+ hash *= 1099511628211UL;
+ }
+ hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF);
+ out[i] = hash;
+ }
+}
+
+void cpu_kernel_hash_wrapper(int N, int K, const int *data,
+ const int *kernel_offset, long int *out) {
+ for (int k = 0; k < K; k++) {
+#pragma omp parallel for
+ for (int i = 0; i < N; i++) {
+ int cur_coord[4];
+ for (int j = 0; j < 3; j++) {
+ cur_coord[j] = data[i * 4 + j] + kernel_offset[k * 3 + j];
+ }
+ cur_coord[3] = data[3];
+ unsigned long long hash = 14695981039346656037UL;
+ for (int j = 0; j < 4; j++) {
+ hash ^= (unsigned int)cur_coord[j];
+ hash *= 1099511628211UL;
+ }
+ hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF);
+ out[k * N + i] = hash;
+ }
+ }
+}
+
+at::Tensor hash_cpu(const at::Tensor idx) {
+ int N = idx.size(0);
+ at::Tensor out =
+ torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long));
+ cpu_hash_wrapper(N, idx.data_ptr(), out.data_ptr());
+ return out;
+}
+
+at::Tensor kernel_hash_cpu(const at::Tensor idx,
+ const at::Tensor kernel_offset) {
+ int N = idx.size(0);
+ int K = kernel_offset.size(0);
+ at::Tensor out = torch::zeros(
+ {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long));
+ cpu_kernel_hash_wrapper(N, K, idx.data_ptr(),
+ kernel_offset.data_ptr(), out.data_ptr());
+ return out;
+}
diff --git a/torchsparse/backend/hash/hash_cpu.h b/torchsparse/backend/hash/hash_cpu.h
new file mode 100644
index 0000000..6367480
--- /dev/null
+++ b/torchsparse/backend/hash/hash_cpu.h
@@ -0,0 +1,11 @@
+#ifndef _SPARSE_HASH_CPU
+#define _SPARSE_HASH_CPU
+
+#include
+
+at::Tensor hash_cpu(const at::Tensor idx);
+
+at::Tensor kernel_hash_cpu(const at::Tensor idx,
+ const at::Tensor kernel_offset);
+
+#endif
diff --git a/torchsparse/backend/hash/hash_cuda.cu b/torchsparse/backend/hash/hash_cuda.cu
new file mode 100644
index 0000000..7193da7
--- /dev/null
+++ b/torchsparse/backend/hash/hash_cuda.cu
@@ -0,0 +1,84 @@
+#include
+#include
+#include
+
+#include
+#include
+
+// hashing
+// input N*4 int32 tensor output N*1 int64 tensor
+__global__ void hash_kernel(int N, const int *__restrict__ data,
+ long int *__restrict__ out) {
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i < N) {
+ data += i * 4;
+ unsigned long long hash = 14695981039346656037UL;
+ for (int j = 0; j < 4; j++) {
+ hash ^= (unsigned int)data[j];
+ hash *= 1099511628211UL;
+ }
+ hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF);
+ out[i] = hash;
+ }
+}
+
+// kernel hashing: given data D and offset map K, generate D x K
+// input N*4 int32 tensor, |K|*3 int32 tensor, output |K|*N int64 tensor
+__global__ void kernel_hash_kernel(int N, int K, const int *__restrict__ data,
+ const int *__restrict__ kernel_offset,
+ long int *__restrict__ out) {
+ extern __shared__ int kernel_offset_local[];
+
+ for (int i = 0; i < K * 3; i++) {
+ kernel_offset_local[i] = kernel_offset[i];
+ }
+ __syncthreads();
+
+ int idx = blockDim.x * blockIdx.x + threadIdx.x;
+ int k = idx % K;
+ int i = idx / K;
+ int cur_coord[4];
+ if (i < N) {
+ data += i * 4;
+ for (int j = 0; j < 3; j++) {
+ cur_coord[j] = data[j] + kernel_offset[k * 3 + j];
+ }
+ cur_coord[3] = data[3];
+ unsigned long long hash = 14695981039346656037UL;
+ for (int j = 0; j < 4; j++) {
+ hash ^= (unsigned int)cur_coord[j];
+ hash *= 1099511628211UL;
+ }
+ hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF);
+ out[k * N + i] = hash;
+ }
+}
+
+void kernel_hash_wrapper(int N, int K, const int *data,
+ const int *kernel_offset, long int *out) {
+ kernel_hash_kernel<<>>(
+ N, K, data, kernel_offset, out);
+}
+
+void hash_wrapper(int N, const int *data, long int *out) {
+ hash_kernel<<>>(N, data, out);
+}
+
+at::Tensor hash_cuda(const at::Tensor idx) {
+ int N = idx.size(0);
+ at::Tensor out =
+ torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long));
+ hash_wrapper(N, idx.data_ptr(), out.data_ptr());
+ return out;
+}
+
+at::Tensor kernel_hash_cuda(const at::Tensor idx,
+ const at::Tensor kernel_offset) {
+ int N = idx.size(0);
+ int K = kernel_offset.size(0);
+ at::Tensor out = torch::zeros(
+ {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long));
+ kernel_hash_wrapper(N, K, idx.data_ptr(), kernel_offset.data_ptr(),
+ out.data_ptr());
+ return out;
+}
diff --git a/torchsparse/backend/hash/hash_cuda.h b/torchsparse/backend/hash/hash_cuda.h
new file mode 100644
index 0000000..83b5807
--- /dev/null
+++ b/torchsparse/backend/hash/hash_cuda.h
@@ -0,0 +1,11 @@
+#ifndef TORCHSPARSE_HASH_CUDA
+#define TORCHSPARSE_HASH_CUDA
+
+#include
+
+at::Tensor hash_cuda(const at::Tensor idx);
+
+at::Tensor kernel_hash_cuda(const at::Tensor idx,
+ const at::Tensor kernel_offset);
+
+#endif
diff --git a/torchsparse/backend/hashmap/hashmap_cpu.cpp b/torchsparse/backend/hashmap/hashmap_cpu.cpp
new file mode 100644
index 0000000..1cdce23
--- /dev/null
+++ b/torchsparse/backend/hashmap/hashmap_cpu.cpp
@@ -0,0 +1,28 @@
+#include "hashmap_cpu.hpp"
+
+#include
+#include
+#include
+#include
+
+void HashTableCPU::lookup_vals(const int64_t* const keys,
+ int64_t* const results, const int n) {
+#pragma omp parallel for
+ for (int idx = 0; idx < n; idx++) {
+ int64_t key = keys[idx];
+ google::dense_hash_map::iterator iter = hashmap.find(key);
+ if (iter != hashmap.end()) {
+ results[idx] = iter->second;
+ } else {
+ results[idx] = 0;
+ }
+ }
+}
+
+void HashTableCPU::insert_vals(const int64_t* const keys,
+ const int64_t* const vals, const int n) {
+ for (int i = 0; i < 10; i++) {
+ printf("%d, %d, %d, %d\n", i, i < n, n, i < 10);
+ // hashmap[(int)keys[idx]] = (int)vals[idx]+1;
+ }
+}
diff --git a/torchsparse/backend/hashmap/hashmap_cpu.hpp b/torchsparse/backend/hashmap/hashmap_cpu.hpp
new file mode 100644
index 0000000..d1edeff
--- /dev/null
+++ b/torchsparse/backend/hashmap/hashmap_cpu.hpp
@@ -0,0 +1,27 @@
+#ifndef _CUCKOO_MULTI_CPU_HPP_
+#define _CUCKOO_MULTI_CPU_HPP_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+class HashTableCPU {
+ private:
+ google::dense_hash_map hashmap;
+
+ public:
+ HashTableCPU() {}
+
+ ~HashTableCPU() {}
+
+ void insert_vals(const int64_t* const keys, const int64_t* const vals,
+ const int n);
+
+ void lookup_vals(const int64_t* const keys, int64_t* const results,
+ const int n);
+};
+
+#endif
diff --git a/torchsparse/backend/hashmap/hashmap_cuda.cu b/torchsparse/backend/hashmap/hashmap_cuda.cu
new file mode 100644
index 0000000..5b6db6a
--- /dev/null
+++ b/torchsparse/backend/hashmap/hashmap_cuda.cu
@@ -0,0 +1,214 @@
+#include
+#include
+#include
+
+#include "hashmap_cuda.cuh"
+
+typedef unsigned long long int VTYPE;
+
+__global__ void cuckooBucketKernel_Multi(VTYPE *const key_buf,
+ VTYPE *const val_buf, const int size,
+ const VTYPE *const keys,
+ const VTYPE *const vals, const int n,
+ int *const counters,
+ const int num_buckets) {
+ // Get thread index.
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+ // Only threads within range are active.
+ if (idx < n) {
+ // Do 1st-level hashing to get bucket id, then do atomic add to get index
+ // inside the bucket.
+ VTYPE key = keys[idx];
+ VTYPE val = vals[idx];
+
+ int bucket_num = do_1st_hash(key, num_buckets);
+ int bucket_ofs = atomicAdd(&counters[bucket_num], 1);
+
+ // Directly write the key into the table buffer.
+ if (bucket_ofs >= BUCKET_SIZE) {
+ printf("%d/%d ERROR: bucket overflow! (n=%d, bucket_num=%d/%d, key=%d)\n",
+ bucket_ofs, BUCKET_SIZE, n, bucket_num, num_buckets, key);
+ } else {
+ key_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = key;
+ val_buf[bucket_num * BUCKET_SIZE + bucket_ofs] = val;
+ }
+ }
+}
+
+__global__ void cuckooInsertKernel_Multi(
+ VTYPE *const key, VTYPE *const val, const VTYPE *const key_buf,
+ const VTYPE *const val_buf, const int size,
+ const FuncConfig *const hash_func_configs, const int num_funcs,
+ const int *const counters, const int num_buckets, const int evict_bound,
+ const int pos_width, int *const rehash_requests) {
+ // Create local cuckoo table in shared memory. Size passed in as the third
+ // kernel parameter.
+ extern __shared__ VTYPE local_key[];
+ for (int i = 0; i < num_funcs; ++i) {
+ local_key[i * BUCKET_SIZE + threadIdx.x] = EMPTY_CELL;
+ }
+
+ // might be useful
+ __syncthreads();
+
+ // Get thread index.
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+ VTYPE cur_idx = idx;
+
+ // Only threads within local bucket range are active.
+ if (threadIdx.x < counters[blockIdx.x]) {
+ // Set initial conditions.
+ VTYPE cur_key = key_buf[cur_idx];
+ int cur_func = 0;
+ int evict_count = 0;
+
+ // Start the test-kick-and-reinsert loops.
+ do {
+ int pos = do_2nd_hash(cur_key, hash_func_configs, cur_func, BUCKET_SIZE);
+
+ VTYPE new_data = make_data(cur_idx + 1, cur_func, pos_width);
+
+ VTYPE old_idx =
+ atomicExch(&local_key[cur_func * BUCKET_SIZE + pos], new_data);
+
+ if (old_idx != EMPTY_CELL) {
+ cur_idx = fetch_val(old_idx, pos_width) - 1;
+ // potential overflow here. It seems that cur_idx < 0 is possible!
+ cur_key = key_buf[cur_idx];
+ cur_func = (fetch_func(old_idx, pos_width) + 1) % num_funcs;
+ evict_count++;
+ } else {
+ break;
+ }
+
+ } while (evict_count < num_funcs * evict_bound);
+
+ // If exceeds eviction bound, then needs rehashing.
+ if (evict_count >= num_funcs * evict_bound) {
+ atomicAdd(rehash_requests, 1);
+ }
+ }
+
+ // Every thread write its responsible local slot into the global data table.
+ __syncthreads();
+ for (int i = 0; i < num_funcs; ++i) {
+ VTYPE cur_idx = local_key[i * BUCKET_SIZE + threadIdx.x];
+ if (cur_idx == EMPTY_CELL) {
+ continue;
+ }
+ int cur_func = fetch_func(cur_idx, pos_width);
+ cur_idx = fetch_val(cur_idx, pos_width) - 1;
+ key[i * size + idx] = key_buf[cur_idx];
+ val[i * size + idx] = val_buf[cur_idx];
+ }
+}
+
+__global__ void cuckooLookupKernel_Multi(
+ const VTYPE *const keys, VTYPE *const results, const int n,
+ const VTYPE *const all_keys, const VTYPE *const all_vals, const int size,
+ const FuncConfig *const hash_func_configs, const int num_funcs,
+ const int num_buckets, const int pos_width) {
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+ // Only threads within range are active.
+ if (idx < n) {
+ VTYPE key = keys[idx];
+ int bucket_num = do_1st_hash(key, num_buckets);
+ for (int i = 0; i < num_funcs; ++i) {
+ int pos = bucket_num * BUCKET_SIZE +
+ do_2nd_hash(key, hash_func_configs, i, BUCKET_SIZE);
+ if (all_keys[i * size + pos] == key) {
+ results[idx] = all_vals[i * size + pos] + 1;
+ return;
+ }
+ }
+
+ // TODO(Haotian): should be a value that will not be encountered.
+ results[idx] = EMPTY_CELL;
+ }
+}
+
+void CuckooHashTableCuda_Multi::lookup_vals(const VTYPE *const keys,
+ VTYPE *d_key, VTYPE *d_val,
+ VTYPE *const results, const int n) {
+ // Launch the lookup kernel.
+ cuckooLookupKernel_Multi<<>>(
+ keys, results, n, d_key, d_val, _size, _d_hash_func_configs, _num_funcs,
+ _num_buckets, _pos_width);
+}
+
+int CuckooHashTableCuda_Multi::insert_vals(const VTYPE *const keys,
+ const VTYPE *const vals,
+ VTYPE *d_key_buf, VTYPE *d_val_buf,
+ VTYPE *d_key, VTYPE *d_val,
+ const int n) {
+ //
+ // Phase 1: Distribute keys into buckets.
+ //
+
+ // Allocate GPU memory.
+
+ int *d_counters = NULL;
+
+ cudaMalloc((void **)&d_counters, _num_buckets * sizeof(int));
+
+ cudaMemset(d_counters, 0, _num_buckets * sizeof(int));
+
+ // Invoke bucket kernel.
+ cuckooBucketKernel_Multi<<>>(
+ d_key_buf, d_val_buf, _size, keys, vals, n, d_counters, _num_buckets);
+
+ //
+ // Phase 2: Local cuckoo hashing.
+ //
+
+ // Allocate GPU memory.
+
+ cudaDeviceSynchronize();
+ int *d_rehash_requests = NULL;
+
+ cudaMalloc((void **)&d_rehash_requests, sizeof(int));
+
+ // Copy values onto GPU memory.
+ cudaMemcpy(_d_hash_func_configs, _hash_func_configs,
+ _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice);
+
+ // Invoke insert kernel. Passes shared memory table size by the third
+ // argument. Loops until no rehashing needed.
+
+ int rehash_count = 0;
+ do {
+ int rehash_requests = 0;
+ cudaMemset(d_rehash_requests, 0, sizeof(int));
+ cuckooInsertKernel_Multi<<>>(
+ d_key, d_val, d_key_buf, d_val_buf, _size, _d_hash_func_configs,
+ _num_funcs, d_counters, _num_buckets, _evict_bound, _pos_width,
+ d_rehash_requests);
+ cudaMemcpy(&rehash_requests, d_rehash_requests, sizeof(int),
+ cudaMemcpyDeviceToHost);
+
+ if (rehash_requests == 0) {
+ break;
+ } else {
+ rehash_count++;
+ gen_hash_funcs();
+ cudaMemcpy(_d_hash_func_configs, _hash_func_configs,
+ _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice);
+ }
+ } while (rehash_count < MAX_DEPTH);
+
+ cudaDeviceSynchronize();
+
+ // Free GPU resources.
+
+ if (d_counters != NULL) {
+ cudaFree(d_counters);
+ }
+ if (d_rehash_requests != NULL) {
+ cudaFree(d_rehash_requests);
+ }
+
+ return (rehash_count < MAX_DEPTH) ? rehash_count : ERR_DEPTH;
+}
diff --git a/torchsparse/backend/hashmap/hashmap_cuda.cuh b/torchsparse/backend/hashmap/hashmap_cuda.cuh
new file mode 100644
index 0000000..cdaca02
--- /dev/null
+++ b/torchsparse/backend/hashmap/hashmap_cuda.cuh
@@ -0,0 +1,146 @@
+#ifndef _CUCKOO_CUDA_MULTI_HPP_
+#define _CUCKOO_CUDA_MULTI_HPP_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "cuda_runtime.h"
+
+/** Reserved value for indicating "empty". */
+#define EMPTY_CELL (0)
+/** Max rehashing depth, and error depth. */
+#define MAX_DEPTH (100)
+#define ERR_DEPTH (-1)
+/** CUDA naive thread block size. */
+#define BLOCK_SIZE (256)
+/** CUDA multi-level thread block size = bucket size. */
+#define BUCKET_SIZE (512)
+
+typedef unsigned long long int VTYPE;
+
+/** Struct of a hash function config. */
+typedef struct {
+ int rv; // Randomized XOR value.
+ int ss; // Randomized shift filter start position.
+} FuncConfig;
+
+/** Hard code hash functions and all inline helper functions for CUDA kernels'
+ * use. */
+inline __device__ int do_1st_hash(const VTYPE val, const int num_buckets) {
+ return val % num_buckets;
+}
+
+inline __device__ int do_2nd_hash(const VTYPE val,
+ const FuncConfig *const hash_func_configs,
+ const int func_idx, const int size) {
+ FuncConfig fc = hash_func_configs[func_idx];
+ return ((val ^ fc.rv) >> fc.ss) % size; // XOR function as 2nd-level hashing.
+}
+
+// trying to ignore EMPTY_CELL by adding 1 at make_data.
+inline __device__ VTYPE fetch_val(const VTYPE data, const int pos_width) {
+ return data >> pos_width;
+}
+
+inline __device__ int fetch_func(const VTYPE data, const int pos_width) {
+ return data & ((0x1 << pos_width) - 1);
+}
+
+inline __device__ VTYPE make_data(const VTYPE val, const int func,
+ const int pos_width) {
+ return (val << pos_width) ^ func;
+}
+
+class CuckooHashTableCuda_Multi {
+ private:
+ const int _size;
+ const int _evict_bound;
+ const int _num_funcs;
+ const int _pos_width;
+ const int _num_buckets;
+
+ FuncConfig *_d_hash_func_configs;
+
+ /** Cuckoo hash function set. */
+ FuncConfig *_hash_func_configs;
+
+ /** Private operations. */
+ void gen_hash_funcs() {
+ // Calculate bit width of value range and table size.
+ int val_width = 8 * sizeof(VTYPE) - ceil(log2((double)_num_funcs));
+ int bucket_width = ceil(log2((double)_num_buckets));
+ int size_width = ceil(log2((double)BUCKET_SIZE));
+ // Generate randomized configurations.
+ for (int i = 0; i < _num_funcs; ++i) { // At index 0 is a dummy function.
+ if (val_width - bucket_width <= size_width)
+ _hash_func_configs[i] = {rand(), 0};
+ else {
+ _hash_func_configs[i] = {
+ rand(), rand() % (val_width - bucket_width - size_width + 1) +
+ bucket_width};
+ }
+ }
+ };
+
+ inline VTYPE fetch_val(const VTYPE data) { return data >> _pos_width; }
+ inline int fetch_func(const VTYPE data) {
+ return data & ((0x1 << _pos_width) - 1);
+ }
+
+ public:
+ CuckooHashTableCuda_Multi(const int size, const int evict_bound,
+ const int num_funcs)
+ : _size(size),
+ _evict_bound(evict_bound),
+ _num_funcs(num_funcs),
+ _pos_width(ceil(log2((double)_num_funcs))),
+ _num_buckets(ceil((double)_size / BUCKET_SIZE)) {
+ srand(time(NULL));
+ _d_hash_func_configs = NULL;
+ _hash_func_configs = NULL;
+ _hash_func_configs = new FuncConfig[num_funcs];
+
+ gen_hash_funcs();
+
+ cudaMalloc((void **)&_d_hash_func_configs, _num_funcs * sizeof(FuncConfig));
+ cudaMemcpy(_d_hash_func_configs, _hash_func_configs,
+ _num_funcs * sizeof(FuncConfig), cudaMemcpyHostToDevice);
+ };
+ ~CuckooHashTableCuda_Multi() {
+ if (_hash_func_configs != NULL) delete[] _hash_func_configs;
+
+ if (_d_hash_func_configs != NULL) cudaFree(_d_hash_func_configs);
+ };
+
+ int insert_vals(const VTYPE *const keys, const VTYPE *const vals,
+ VTYPE *d_key_buf, VTYPE *d_val_buf, VTYPE *d_key,
+ VTYPE *d_val, const int n);
+
+ void lookup_vals(const VTYPE *const keys, VTYPE *const results, VTYPE *d_key,
+ VTYPE *d_val, const int n);
+};
+
+__global__ void cuckooBucketKernel_Multi(VTYPE *const key_buf,
+ VTYPE *const val_buf, const int size,
+ const VTYPE *const keys,
+ const VTYPE *const vals, const int n,
+ int *const counters,
+ const int num_buckets);
+
+__global__ void cuckooInsertKernel_Multi(
+ VTYPE *const key, VTYPE *const val, const VTYPE *const key_buf,
+ const VTYPE *const val_buf, const int size,
+ const FuncConfig *const hash_func_configs, const int num_funcs,
+ const int *const counters, const int num_buckets, const int evict_bound,
+ const int pos_width, int *const rehash_requests);
+
+__global__ void cuckooLookupKernel_Multi(
+ const VTYPE *const keys, VTYPE *const results, const int n,
+ const VTYPE *const all_keys, const VTYPE *const all_vals, const int size,
+ const FuncConfig *const hash_func_configs, const int num_funcs,
+ const int num_buckets, const int pos_width);
+
+#endif
diff --git a/torchsparse/backend/others/count_cpu.cpp b/torchsparse/backend/others/count_cpu.cpp
new file mode 100644
index 0000000..ba0611c
--- /dev/null
+++ b/torchsparse/backend/others/count_cpu.cpp
@@ -0,0 +1,23 @@
+#include "count_cpu.h"
+
+#include
+
+#include
+
+at::Tensor count_cpu(const at::Tensor idx, const int s) {
+ int N = idx.size(0);
+ at::Tensor out =
+ torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int));
+ int *idx_ = idx.data_ptr();
+ int *out_ = out.data_ptr();
+#pragma omp parallel for
+ for (int i = 0; i < N; i++) {
+ int cur_idx = idx_[i];
+ if (cur_idx < 0) {
+ continue;
+ }
+#pragma omp atomic
+ out_[cur_idx]++;
+ }
+ return out;
+}
diff --git a/torchsparse/backend/others/count_cpu.h b/torchsparse/backend/others/count_cpu.h
new file mode 100644
index 0000000..f2a0ab3
--- /dev/null
+++ b/torchsparse/backend/others/count_cpu.h
@@ -0,0 +1,8 @@
+#ifndef _SPARSE_COUNT_CPU
+#define _SPARSE_COUNT_CPU
+
+#include
+
+at::Tensor count_cpu(const at::Tensor idx, const int s);
+
+#endif
diff --git a/torchsparse/backend/others/count_cuda.cu b/torchsparse/backend/others/count_cuda.cu
new file mode 100644
index 0000000..4860422
--- /dev/null
+++ b/torchsparse/backend/others/count_cuda.cu
@@ -0,0 +1,31 @@
+#include
+#include
+#include
+
+#include
+#include
+
+// counting
+// input N*3 int32 tensor output N*1 int64 tensor
+__global__ void count_kernel(int N, const int *__restrict__ data,
+ int *__restrict__ out) {
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i < N && data[i] >= 0) {
+ atomicAdd(&out[data[i]], 1);
+ }
+}
+
+void count_wrapper(int N, const int *data, int *out) {
+ count_kernel<<>>(N, data, out);
+}
+
+// make sure indices is int type
+// feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n)
+// (preprocessed indices)
+at::Tensor count_cuda(const at::Tensor idx, const int s) {
+ int N = idx.size(0);
+ at::Tensor out =
+ torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int));
+ count_wrapper(N, idx.data_ptr(), out.data_ptr());
+ return out;
+}
diff --git a/torchsparse/backend/others/count_cuda.h b/torchsparse/backend/others/count_cuda.h
new file mode 100644
index 0000000..2bb64f6
--- /dev/null
+++ b/torchsparse/backend/others/count_cuda.h
@@ -0,0 +1,8 @@
+#ifndef _SPARSE_COUNT
+#define _SPARSE_COUNT
+
+#include
+
+at::Tensor count_cuda(const at::Tensor idx, const int s);
+
+#endif
diff --git a/torchsparse/src/others/query_cpu.cpp b/torchsparse/backend/others/query_cpu.cpp
similarity index 50%
rename from torchsparse/src/others/query_cpu.cpp
rename to torchsparse/backend/others/query_cpu.cpp
index cfba240..915aa83 100644
--- a/torchsparse/src/others/query_cpu.cpp
+++ b/torchsparse/backend/others/query_cpu.cpp
@@ -1,41 +1,34 @@
+#include "query_cpu.h"
+
#include
-#include "../hashmap/hashmap_cpu_header.hpp"
-#include
+
#include
-#include
-#include "query_cpu_header.h"
#include
+#include
+#include
+
+#include "../hashmap/hashmap_cpu.hpp"
-at::Tensor cpu_query_forward(
- const at::Tensor hash_query,
- const at::Tensor hash_target,
- const at::Tensor idx_target)
-{
- //return group_point_forward_gpu(points, indices);
+at::Tensor hash_query_cpu(const at::Tensor hash_query,
+ const at::Tensor hash_target,
+ const at::Tensor idx_target) {
int n = hash_target.size(0);
int n1 = hash_query.size(0);
google::dense_hash_map hashmap;
hashmap.set_empty_key(0);
- /*
- HashTableCPU in_hash_table;
- printf("inserting %d %d...\n", n, n1);
- in_hash_table.insert_vals(hash_target.data_ptr(), idx_target.data_ptr(), n);
- */
- at::Tensor out = torch::zeros({n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long));
- for (int idx = 0; idx < n; idx++)
- {
+ at::Tensor out = torch::zeros(
+ {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long));
+ for (int idx = 0; idx < n; idx++) {
int64_t key = *(hash_target.data_ptr() + idx);
int64_t val = *(idx_target.data_ptr() + idx) + 1;
hashmap.insert(std::make_pair(key, val));
}
#pragma omp parallel for
- for (int idx = 0; idx < n1; idx++)
- {
+ for (int idx = 0; idx < n1; idx++) {
int64_t key = *(hash_query.data_ptr() + idx);
google::dense_hash_map::iterator iter = hashmap.find(key);
- if (iter != hashmap.end())
- {
+ if (iter != hashmap.end()) {
*(out.data_ptr() + idx) = iter->second;
}
}
diff --git a/torchsparse/backend/others/query_cpu.h b/torchsparse/backend/others/query_cpu.h
new file mode 100644
index 0000000..b3c6970
--- /dev/null
+++ b/torchsparse/backend/others/query_cpu.h
@@ -0,0 +1,10 @@
+#ifndef _SPARSE_QUERY_CPU
+#define _SPARSE_QUERY_CPU
+
+#include
+
+at::Tensor hash_query_cpu(const at::Tensor hash_query,
+ const at::Tensor hash_target,
+ const at::Tensor idx_target);
+
+#endif
diff --git a/torchsparse/backend/others/query_cuda.cu b/torchsparse/backend/others/query_cuda.cu
new file mode 100644
index 0000000..0209c00
--- /dev/null
+++ b/torchsparse/backend/others/query_cuda.cu
@@ -0,0 +1,58 @@
+#include
+
+#include
+#include
+#include
+
+#include "../hashmap/hashmap_cuda.cuh"
+
+at::Tensor hash_query_cuda(const at::Tensor hash_query,
+ const at::Tensor hash_target,
+ const at::Tensor idx_target) {
+ // return group_point_forward_gpu(points, indices);
+ int n = hash_target.size(0);
+ int n1 = hash_query.size(0);
+ const int nextPow2 = pow(2, ceil(log2((double)n)));
+ // When n is large, the hash values tend to be more evenly distrubuted and
+ // choosing table_size to be 2 * nextPow2 typically suffices. For smaller n,
+ // the effect of uneven distribution of hash values is more pronounced and
+ // hence we choose table_size to be 4 * nextPow2 to reduce the chance of
+ // bucket overflow.
+ int table_size = (n < 2048) ? 4 * nextPow2 : 2 * nextPow2;
+ if (table_size < 512) {
+ table_size = 512;
+ }
+ int num_funcs = 3;
+ CuckooHashTableCuda_Multi in_hash_table(table_size, 8 * ceil(log2((double)n)),
+ num_funcs);
+ at::Tensor key_buf =
+ torch::zeros({table_size},
+ at::device(hash_query.device()).dtype(at::ScalarType::Long));
+ at::Tensor val_buf =
+ torch::zeros({table_size},
+ at::device(hash_query.device()).dtype(at::ScalarType::Long));
+ at::Tensor key =
+ torch::zeros({num_funcs * table_size},
+ at::device(hash_query.device()).dtype(at::ScalarType::Long));
+ at::Tensor val =
+ torch::zeros({num_funcs * table_size},
+ at::device(hash_query.device()).dtype(at::ScalarType::Long));
+
+ in_hash_table.insert_vals(
+ (unsigned long long int *)(hash_target.data_ptr()),
+ (unsigned long long int *)(idx_target.data_ptr()),
+ (unsigned long long int *)(key_buf.data_ptr()),
+ (unsigned long long int *)(val_buf.data_ptr()),
+ (unsigned long long int *)(key.data_ptr()),
+ (unsigned long long int *)(val.data_ptr()), n);
+
+ at::Tensor out = torch::zeros(
+ {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long));
+
+ in_hash_table.lookup_vals(
+ (unsigned long long int *)(hash_query.data_ptr()),
+ (unsigned long long int *)(key.data_ptr()),
+ (unsigned long long int *)(val.data_ptr()),
+ (unsigned long long int *)(out.data_ptr()), n1);
+ return out;
+}
diff --git a/torchsparse/backend/others/query_cuda.h b/torchsparse/backend/others/query_cuda.h
new file mode 100644
index 0000000..a46aedf
--- /dev/null
+++ b/torchsparse/backend/others/query_cuda.h
@@ -0,0 +1,10 @@
+#ifndef _SPARSE_QUERY
+#define _SPARSE_QUERY
+
+#include
+
+at::Tensor hash_query_cuda(const at::Tensor hash_query,
+ const at::Tensor hash_target,
+ const at::Tensor idx_target);
+
+#endif
diff --git a/torchsparse/backend/pybind_cpu.cpp b/torchsparse/backend/pybind_cpu.cpp
new file mode 100644
index 0000000..d7ab41c
--- /dev/null
+++ b/torchsparse/backend/pybind_cpu.cpp
@@ -0,0 +1,23 @@
+#include
+#include
+#include
+
+#include "convolution/convolution_cpu.h"
+#include "devoxelize/devoxelize_cpu.h"
+#include "hash/hash_cpu.h"
+#include "others/count_cpu.h"
+#include "others/query_cpu.h"
+#include "voxelize/voxelize_cpu.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("convolution_forward_cpu", &convolution_forward_cpu);
+ m.def("convolution_backward_cpu", &convolution_backward_cpu);
+ m.def("voxelize_forward_cpu", &voxelize_forward_cpu);
+ m.def("voxelize_backward_cpu", &voxelize_backward_cpu);
+ m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu);
+ m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu);
+ m.def("hash_cpu", &hash_cpu);
+ m.def("kernel_hash_cpu", &kernel_hash_cpu);
+ m.def("hash_query_cpu", &hash_query_cpu);
+ m.def("count_cpu", &count_cpu);
+}
diff --git a/torchsparse/backend/pybind_cuda.cpp b/torchsparse/backend/pybind_cuda.cpp
new file mode 100644
index 0000000..be0e78d
--- /dev/null
+++ b/torchsparse/backend/pybind_cuda.cpp
@@ -0,0 +1,39 @@
+#include
+#include
+#include
+
+#include "convolution/convolution_cpu.h"
+#include "convolution/convolution_cuda.h"
+#include "devoxelize/devoxelize_cpu.h"
+#include "devoxelize/devoxelize_cuda.h"
+#include "hash/hash_cpu.h"
+#include "hash/hash_cuda.h"
+#include "others/count_cpu.h"
+#include "others/count_cuda.h"
+#include "others/query_cpu.h"
+#include "others/query_cuda.h"
+#include "voxelize/voxelize_cpu.h"
+#include "voxelize/voxelize_cuda.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("convolution_forward_cpu", &convolution_forward_cpu);
+ m.def("convolution_forward_cuda", &convolution_forward_cuda);
+ m.def("convolution_backward_cpu", &convolution_backward_cpu);
+ m.def("convolution_backward_cuda", &convolution_backward_cuda);
+ m.def("voxelize_forward_cpu", &voxelize_forward_cpu);
+ m.def("voxelize_forward_cuda", &voxelize_forward_cuda);
+ m.def("voxelize_backward_cpu", &voxelize_backward_cpu);
+ m.def("voxelize_backward_cuda", &voxelize_backward_cuda);
+ m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu);
+ m.def("devoxelize_forward_cuda", &devoxelize_forward_cuda);
+ m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu);
+ m.def("devoxelize_backward_cuda", &devoxelize_backward_cuda);
+ m.def("hash_cpu", &hash_cpu);
+ m.def("hash_cuda", &hash_cuda);
+ m.def("kernel_hash_cpu", &kernel_hash_cpu);
+ m.def("kernel_hash_cuda", &kernel_hash_cuda);
+ m.def("hash_query_cpu", &hash_query_cpu);
+ m.def("hash_query_cuda", &hash_query_cuda);
+ m.def("count_cpu", &count_cpu);
+ m.def("count_cuda", &count_cuda);
+}
diff --git a/torchsparse/backend/voxelize/voxelize_cpu.cpp b/torchsparse/backend/voxelize/voxelize_cpu.cpp
new file mode 100644
index 0000000..938b7bc
--- /dev/null
+++ b/torchsparse/backend/voxelize/voxelize_cpu.cpp
@@ -0,0 +1,43 @@
+#include "voxelize_cpu.h"
+
+#include
+
+#include
+
+at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx,
+ const at::Tensor counts) {
+ int N = inputs.size(0);
+ int c = inputs.size(1);
+ int N1 = counts.size(0);
+ at::Tensor out = torch::zeros(
+ {N1, c}, at::device(idx.device()).dtype(at::ScalarType::Float));
+ for (int i = 0; i < N; i++) {
+ int pos = *(idx.data_ptr() + i);
+ if (*(counts.data_ptr() + pos) == 0) continue;
+#pragma omp parallel for
+ for (int j = 0; j < c; j++) {
+ *(out.data_ptr() + pos * c + j) +=
+ *(inputs.data_ptr() + i * c + j) /
+ (float)(*(counts.data_ptr() + pos));
+ }
+ }
+ return out;
+}
+
+at::Tensor voxelize_backward_cpu(const at::Tensor top_grad,
+ const at::Tensor idx, const at::Tensor counts,
+ const int N) {
+ int c = top_grad.size(1);
+ at::Tensor bottom_grad = torch::zeros(
+ {N, c}, at::device(idx.device()).dtype(at::ScalarType::Float));
+ for (int i = 0; i < N; i++) {
+ if (*(counts.data_ptr() + *(idx.data_ptr() + i)) == 0) continue;
+#pragma omp parallel for
+ for (int j = 0; j < c; j++) {
+ *(bottom_grad.data_ptr() + i * c + j) =
+ *(top_grad.data_ptr() + *(idx.data_ptr() + i) * c + j) /
+ (float)(*(counts.data_ptr() + *(idx.data_ptr() + i)));
+ }
+ }
+ return bottom_grad;
+}
diff --git a/torchsparse/backend/voxelize/voxelize_cpu.h b/torchsparse/backend/voxelize/voxelize_cpu.h
new file mode 100644
index 0000000..bed480e
--- /dev/null
+++ b/torchsparse/backend/voxelize/voxelize_cpu.h
@@ -0,0 +1,13 @@
+#ifndef TORCHSPARSE_VOXELIZE_CPU
+#define TORCHSPARSE_VOXELIZE_CPU
+
+#include
+
+at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx,
+ const at::Tensor counts);
+
+at::Tensor voxelize_backward_cpu(const at::Tensor top_grad,
+ const at::Tensor idx, const at::Tensor counts,
+ const int N);
+
+#endif
diff --git a/torchsparse/backend/voxelize/voxelize_cuda.cu b/torchsparse/backend/voxelize/voxelize_cuda.cu
new file mode 100644
index 0000000..a47f605
--- /dev/null
+++ b/torchsparse/backend/voxelize/voxelize_cuda.cu
@@ -0,0 +1,80 @@
+#include
+#include
+#include
+
+#include
+#include
+
+// hashing
+// input N*F float tensor, pointer to output N'*F int64 tensor, N*1 count
+// tensor, N*1 index tensor
+template
+__global__ void voxelize_forward_kernel(int N, int c, int s,
+ const scalar_t *__restrict__ data,
+ const int *__restrict__ idx,
+ const int *__restrict__ counts,
+ scalar_t *__restrict__ out) {
+ int index = blockDim.x * blockIdx.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+ if (i < N) {
+ int pos = idx[i];
+ if (pos < 0 || pos >= s || counts[pos] == 0) return;
+ atomicAdd(&out[pos * c + j], data[i * c + j] / float(counts[pos]));
+ }
+}
+
+template
+__global__ void voxelize_backward_kernel(int N, int c, int s,
+ const scalar_t *__restrict__ top_grad,
+ const int *__restrict__ idx,
+ const int *__restrict__ counts,
+ scalar_t *__restrict__ bottom_grad) {
+ int index = blockDim.x * blockIdx.x + threadIdx.x;
+ int i = index / c;
+ int j = index % c;
+ if (i < N) {
+ int pos = idx[i];
+ if (pos < 0 || pos >= s || counts[pos] == 0) return;
+ atomicAdd(&bottom_grad[i * c + j],
+ top_grad[pos * c + j] / float(counts[pos]));
+ }
+}
+
+at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx,
+ const at::Tensor counts) {
+ int N = inputs.size(0);
+ int c = inputs.size(1);
+ int N1 = counts.size(0);
+
+ at::Tensor out =
+ torch::zeros({N1, c}, at::device(idx.device()).dtype(inputs.dtype()));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ inputs.type(), "voxelize_forward_cuda", ([&] {
+ voxelize_forward_kernel<<>>(
+ N, c, N1, inputs.data_ptr(), idx.data_ptr(),
+ counts.data_ptr(), out.data_ptr