In [None]:
import torch
import numpy as np
from torch import nn
from typing import List
import sys

import spconv
from pointnet2.pointnet2_utils import ball_query, gather_operation, furthest_point_sample

In [2]:
"""
Note: To compute voxel_size, multiply base size (specified in SECOND config) by downsampling ratio.
"""

'\nNote: To compute voxel_size, multiply base size (specified in SECOND config) by downsampling ratio.\n'

In [22]:
class PvrcnnConfig:
    """Static settings shared by the PV-RCNN cells of this notebook."""

    # Number of keypoints to sample (presumably via furthest point sampling -- confirm at call site).
    n_keypoints: int = 2048
    # One stride per feature level (presumably cumulative downsampling ratios -- TODO confirm).
    strides: List[int] = [1, 2, 4, 8]

In [54]:
class VSA_MLP(nn.Module):
    """
    Three-layer perceptron followed by a max reduction.
    Represents G in equation 2.
    """

    def __init__(self, C_in: int, channels: List):
        """
        C_in: number of input channels.
        channels: length-3 list giving the output width of each linear layer.
        """
        super().__init__()
        widths = [C_in, channels[0], channels[1], channels[2]]
        stack = []
        for depth in range(3):
            stack.append(nn.Linear(widths[depth], widths[depth + 1], bias=True))
            # Only the first two linear layers are followed by BatchNorm + ReLU.
            if depth < 2:
                stack.append(nn.BatchNorm1d(widths[depth + 1]))
                stack.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*stack)

    def forward(self, voxel_set: torch.Tensor):
        """
        Apply the layer stack to `voxel_set`, then reduce with max over dimension 2.

        NOTE(review): Tensor.max(dim) yields a (values, indices) pair, so callers
        receive both rather than a single tensor -- confirm this is intended.
        """
        return self.layers(voxel_set).max(2)


class VoxelSetAbstraction(nn.Module):
    """
    For each keypoint, fetch the voxels that fall inside a ball query
    around it and form the neighborhood feature set (equation 1).

    Voxel indices are converted to raw point cloud coordinates first, so
    the ball query is evaluated against the keypoint locations directly.
    """

    def __init__(self, radius: float, nsample: int, voxel_size: torch.Tensor, volume_offset: torch.Tensor):
        """
        radius: maximum distance for ball query, measured in raw point cloud coordinates.
        nsample: maximum number of neighbors to return in ball query.
        voxel_size: length-3 tensor describing size of atomic voxel, accounting for stride.
        volume_offset: length-3 tensor describing coordinate offset of voxel grid.
        """
        super(VoxelSetAbstraction, self).__init__()
        self.radius = radius
        self.nsample = nsample
        # Stored as plain attributes (not buffers), so they do not follow
        # Module.to()/.cuda(); callers must place them on the right device.
        self.voxel_size = voxel_size
        self.volume_offset = volume_offset

    def to_raw_coordinates(self, voxel_index: torch.Tensor):
        """
        voxel_index: shape (B, Tk, 3) array of coordinates
        return: shape (B, Tk, 3) array of locations in raw coordinates.
        """
        location = (voxel_index * self.voxel_size) + self.volume_offset
        return location

    def get_neighbors(self, keypoint_location: torch.Tensor, voxel_feature: torch.Tensor, voxel_location: torch.Tensor):
        """
        Run the ball query around each keypoint and gather the matching voxels.

        keypoint_location: query centers, in raw coordinates.
        voxel_feature: features of the candidate voxels.
        voxel_location: locations of the candidate voxels, in raw coordinates.
        return: (neighbor_feature, neighbor_location) gathered with the ball query index.

        NOTE(review): the tensor layouts expected by ball_query / gather_operation
        are not visible here -- confirm gather_operation accepts the index shape
        that ball_query produces.
        """
        neighbor_index = ball_query(self.radius, self.nsample, voxel_location, keypoint_location)
        neighbor_feature = gather_operation(voxel_feature, neighbor_index)
        neighbor_location = gather_operation(voxel_location, neighbor_index)
        return neighbor_feature, neighbor_location

    def combine_features(self, neighbor_feature: torch.Tensor, neighbor_location: torch.Tensor, keypoint_location: torch.Tensor):
        """
        Form neighborhood feature set (equation 1).

        Concatenates each neighbor's feature with its location relative to
        the keypoint along dimension 2.
        """
        offset_location = neighbor_location - keypoint_location
        combined_feature = torch.cat((neighbor_feature, offset_location), dim=2)
        return combined_feature

    def forward(self, keypoint_location: torch.Tensor, voxel_feature: torch.Tensor, voxel_index: torch.Tensor):
        """
        keypoint_location: keypoint locations in raw coordinates.
        voxel_feature: features of the voxels to query.
        voxel_index: voxel grid indices, converted to raw coordinates before the query.
        return: combined neighbor features and relative locations.
        """
        voxel_location = self.to_raw_coordinates(voxel_index)
        # Fixed: this previously called the non-existent method self.ball_query,
        # which raised AttributeError; the neighbor lookup lives in get_neighbors.
        neighbor_feature, neighbor_location = self.get_neighbors(keypoint_location, voxel_feature, voxel_location)
        feature = self.combine_features(neighbor_feature, neighbor_location, keypoint_location)
        return feature


class PV_RCNN(nn.Module):
    """
    Feature computation described in the PV-RCNN paper (only keypoint sampling so far).
    """

    def __init__(self, num_keypoint: int):
        """
        num_keypoint: number of keypoints to sample from the raw points.
        """
        super().__init__()
        self.num_keypoint = num_keypoint

    def forward(self, raw_point):
        """
        Sample keypoint indices with furthest point sampling and gather those points.

        NOTE(review): the gathered keypoints are neither stored nor returned,
        so this method currently evaluates to None -- presumably unfinished.
        """
        sampled_index = furthest_point_sample(raw_point, self.num_keypoint)
        sampled_point = gather_operation(raw_point, sampled_index)

In [53]:
class CNN_3D(nn.Module):
    """
    Placeholder sparse 3D CNN with three blocks:
    block_0: [1600, 1200, 41] -> [800, 600, 21]
    block_1: [800, 600, 21]   -> [400, 300, 11]
    block_2: [400, 300, 11]   -> [200, 150, 5]
    """

    def __init__(self, C_in, shape, return_dense=False):
        """
        C_in: number of input feature channels.
        shape: spatial shape handed to spconv.SparseConvTensor in forward.
        return_dense: if True, forward converts every level with .dense().
        """
        super(CNN_3D, self).__init__()
        # Three 3x3x3 sparse convolutions with stride 2 and padding 1.
        self.blocks = spconv.SparseSequential(
            spconv.SparseConv3d(C_in, 16, 3, 2, padding=1, bias=False),
            spconv.SparseConv3d(16, 32, 3, 2, padding=1, bias=False),
            spconv.SparseConv3d(32, 64, 3, 2, padding=1, bias=False),
        )
        self.shape = shape
        self.return_dense = return_dense

    def forward(self, features, coordinates, batch_size):
        """
        features: voxel features passed to spconv.SparseConvTensor.
        coordinates: voxel indices; cast to int32 before building the sparse tensor.
        batch_size: batch size passed to spconv.SparseConvTensor.
        return: list [x0, x1, x2, x3] holding the input tensor and the output
            of each block, converted with .dense() when return_dense is set.
        """
        # Fixed: this previously read the undefined name `coors` (NameError on
        # a fresh kernel); the cast must apply to the `coordinates` argument.
        coordinates = coordinates.int()
        x0 = spconv.SparseConvTensor(features, coordinates, self.shape, batch_size)
        # Blocks are applied one by one so every intermediate level is kept.
        x1 = self.blocks[0](x0)
        x2 = self.blocks[1](x1)
        x3 = self.blocks[2](x2)
        x = [x0, x1, x2, x3]
        x = [xi.dense() for xi in x] if self.return_dense else x
        return x

In [81]:
cfg = PvrcnnConfig()

# Voxel edge lengths and the bounds of the voxelized volume
# (lower corner in the first three entries, upper corner in the last three).
voxel_size = np.array([0.05, 0.05, 0.1])
grid_bounds = np.array([0, -40, -3, 70.4, 40, 1])

voxel_generator = spconv.utils.VoxelGenerator(
    voxel_size=voxel_size,
    point_cloud_range=grid_bounds,
    max_num_points=5,
    max_voxels=40000,
)

# The dump is read as float32 values, four per point.
points = np.fromfile('./sample.bin', dtype=np.float32).reshape(-1, 4)
voxels, coords, voxel_population = voxel_generator.generate(points)


def from_numpy(x):
    """Wrap a NumPy array as a CUDA tensor with a leading singleton dimension."""
    return torch.from_numpy(x).unsqueeze(0).cuda()


points, voxels, coords, voxel_population = (
    from_numpy(array) for array in (points, voxels, coords, voxel_population)
)
#indices_keypoint = furthest_point_sample(points, cfg.n_keypoints)

In [56]:
# Build the placeholder sparse CNN for 4-channel input features.
# NOTE(review): this shape ([1600, 1200, 41]) differs from the grid_shape
# computed later in the notebook ([41, 1600, 1408]) -- confirm which is intended.
cnn_3d = CNN_3D(C_in=4, shape=[1600, 1200, 41])

In [97]:
# Number of voxels along each axis of the bounded volume (still floating point here).
grid_shape = (grid_bounds[3:] - grid_bounds[:3]) / voxel_size
# Round before the integer cast: astype() truncates, so a quotient such as
# 2.9999999999999996 would otherwise silently drop a voxel.
# Axis order is reversed and the leading axis padded by one:
# [1408, 1600, 40] -> [41, 1600, 1408]
grid_shape = (np.round(grid_shape[::-1]) + [1, 0, 0]).astype(np.int32)
# NOTE(review): voxels/coords carry the leading singleton dimension added when
# they were converted to tensors -- confirm SparseConvTensor accepts that layout.
x_sparse = spconv.SparseConvTensor(voxels, coords, grid_shape, batch_size=1)

In [99]:
# Display the spatial shape recorded on the sparse tensor.
x_sparse.spatial_shape

# TODO: densify and fold the depth axis into channels. Left disabled;
# `ret` is not defined in the cells shown here.
#ret = ret.dense()
#N, C, D, H, W = ret.shape
#ret = ret.view(N, C * D, H, W)

array([  41, 1600, 1408], dtype=int32)