In [1]:
import torch
from torch import nn

from typing import List

if False:
    from pointnet2.pointnet2_utils import (
        ball_query, gather_operation, furthest_point_sample)

In [2]:
"""
Note: To compute voxel_size, multiply base size (specified in SECOND config) by downsampling ratio.
"""

'\nNote: To compute voxel_size, multiply base size (specified in SECOND config) by downsampling ratio.\n'

In [3]:
class cfg:
    """Notebook-wide settings, accessed as class attributes (cfg.n_keypoints, ...)."""

    # Keypoint budget; presumably the count later passed to PV_RCNN -- confirm.
    n_keypoints = 2 ** 11  # == 2048

    # Presumably the per-stage downsampling ratios used to scale the base
    # voxel size (see the SECOND note cell above) -- confirm.
    strides = [2 ** stage for stage in range(4)]  # == [1, 2, 4, 8]

In [67]:
class VSA_MLP(nn.Module):
    """
    Represents G in equation 2: a shared 3-layer MLP (Linear/BN/ReLU x2, then
    Linear) applied to every feature vector, followed by max-pooling over dim 2.
    """

    def __init__(self, C_in: int, channels: List):
        """
        C_in: incoming channels.
        channels: length-3 list of channels in each layer.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(C_in, channels[0], bias=True),
            nn.BatchNorm1d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Linear(channels[0], channels[1], bias=True),
            nn.BatchNorm1d(channels[1]),
            nn.ReLU(inplace=True),
            nn.Linear(channels[1], channels[2], bias=True),
        )

    def forward(self, voxel_set: torch.Tensor) -> torch.Tensor:
        """
        voxel_set: channels-last tensor (..., C_in) with at least 3 dims.
            Dim 2 is max-pooled; presumably this is the neighbor axis of a
            (B, M, nsample, C_in) neighborhood tensor -- TODO confirm.
        return: tensor of per-position maxima, dim 2 removed, last dim channels[2].
        """
        leading_shape = voxel_set.shape[:-1]
        # nn.Linear acts on the last dim, but BatchNorm1d expects (N, C); run the
        # MLP on a flattened 2-D view and restore the leading dims afterwards.
        x = self.layers(voxel_set.reshape(-1, voxel_set.shape[-1]))
        x = x.reshape(*leading_shape, x.shape[-1])
        # Tensor.max(dim) returns a (values, indices) namedtuple; keep the values.
        return x.max(dim=2).values


class VoxelSetAbstraction(nn.Module):
    """
    For each keypoint, convert voxel index coordinates to continuous
    raw point-cloud locations, fetch the voxels within a ball query around
    the keypoint, and form the neighborhood feature set (equation 1):
    each neighbor's feature concatenated with its offset from the keypoint.
    """

    def __init__(self, radius: float, nsample: int, voxel_size: torch.Tensor, volume_offset: torch.Tensor):
        """
        radius: maximum distance for ball query, measured in raw point cloud coordinates.
        nsample: maximum number of neighbors to return in ball query.
        voxel_size: length-3 tensor describing size of atomic voxel, accounting for stride.
        volume_offset: length-3 tensor describing coordinate offset of voxel grid.
        """
        super().__init__()
        self.radius = radius
        self.nsample = nsample
        # Registered as buffers (not plain attributes) so module.to(device) /
        # .cuda() moves them with the module; non-persistent so they stay out
        # of the state_dict.
        self.register_buffer("voxel_size", torch.as_tensor(voxel_size), persistent=False)
        self.register_buffer("volume_offset", torch.as_tensor(volume_offset), persistent=False)

    def to_raw_coordinates(self, voxel_index: torch.Tensor) -> torch.Tensor:
        """
        voxel_index: shape (B, Tk, 3) array of coordinates
        return: shape (B, Tk, 3) array of locations in raw coordinates.
        """
        # Scale by the per-axis voxel size, then shift by the grid offset.
        return voxel_index * self.voxel_size + self.volume_offset

    def get_neighbors(self, keypoint_location: torch.Tensor, voxel_feature: torch.Tensor, voxel_location: torch.Tensor):
        """
        Ball-query the voxels around each keypoint and gather their features and locations.

        return: (neighbor_feature, neighbor_location)
        """
        # NOTE(review): ball_query / gather_operation come from pointnet2, whose
        # import in the first cell is disabled (`if False:`), so this raises
        # NameError until that import is enabled.
        # NOTE(review): pointnet2's gather_operation appears to expect
        # channels-first features and a (B, npoint) index, whereas ball_query
        # yields a (B, npoint, nsample) index; grouping_operation may be the
        # intended call -- confirm tensor layouts.
        neighbor_index = ball_query(self.radius, self.nsample, voxel_location, keypoint_location)
        neighbor_feature = gather_operation(voxel_feature, neighbor_index)
        neighbor_location = gather_operation(voxel_location, neighbor_index)
        return neighbor_feature, neighbor_location

    def combine_features(self, neighbor_feature: torch.Tensor, neighbor_location: torch.Tensor, keypoint_location: torch.Tensor) -> torch.Tensor:
        """
        Form neighborhood feature set (equation 1).

        Concatenates each neighbor's feature with its location relative to the
        keypoint along the last (channel) axis, i.e. channels-last, matching
        VSA_MLP whose nn.Linear layers act on the last dim.

        neighbor_location may have one more axis than keypoint_location
        (e.g. (B, M, nsample, 3) vs (B, M, 3)); the keypoint is then broadcast
        across that neighbor axis.
        """
        if neighbor_location.dim() == keypoint_location.dim() + 1:
            keypoint_location = keypoint_location.unsqueeze(-2)
        offset_location = neighbor_location - keypoint_location
        return torch.cat((neighbor_feature, offset_location), dim=-1)

    def forward(self, keypoint_location: torch.Tensor, voxel_feature: torch.Tensor, voxel_index: torch.Tensor):
        """
        keypoint_location: keypoint positions in raw point-cloud coordinates.
        voxel_feature: features of the voxels to aggregate.
        voxel_index: voxel index coordinates, converted via to_raw_coordinates.
        return: the neighborhood feature set produced by combine_features.
        """
        voxel_location = self.to_raw_coordinates(voxel_index)
        neighbor_feature, neighbor_location = self.get_neighbors(keypoint_location, voxel_feature, voxel_location)
        return self.combine_features(neighbor_feature, neighbor_location, keypoint_location)


class PV_RCNN(nn.Module):
    """
    Carry out feature computation described in PV-RCNN paper.

    So far only the keypoint-sampling step is implemented: furthest point
    sampling on the raw cloud, then gathering the sampled points.
    """

    def __init__(self, num_keypoint: int):
        """
        num_keypoint: number of keypoints
        """
        super().__init__()
        self.num_keypoint = num_keypoint

    def forward(self, raw_point: torch.Tensor) -> torch.Tensor:
        """
        raw_point: point cloud, presumably (B, N, C) with xyz in the first three
            channels (the example cell builds (B, N, 4)) -- TODO confirm.
        return: the keypoints gathered from raw_point at the sampled indices.
        """
        # NOTE(review): furthest_point_sample / gather_operation come from
        # pointnet2, whose import in the first cell is disabled (`if False:`).
        # Sampling should use geometry only; pointnet2's FPS appears to expect
        # contiguous (B, N, 3) xyz -- slicing is a no-op for 3-channel input.
        xyz = raw_point[..., :3].contiguous()
        keypoint_index = furthest_point_sample(xyz, self.num_keypoint)
        # NOTE(review): pointnet2's gather_operation appears to expect
        # channels-first (B, C, N) features -- confirm raw_point's layout.
        keypoint = gather_operation(raw_point, keypoint_index)
        return keypoint

In [76]:
# Synthetic smoke-test batch: B clouds of C points each, 4 features per point
# (presumably x, y, z, intensity -- confirm against the real data loader).
# Note: C here counts points, not channels.
torch.manual_seed(0)  # reproducible random example under Restart & Run All
B, C = 2, 18000
x = torch.randn((B, C, 4), dtype=torch.float32)