# Reading note for "RandLA-Net"

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch_points_kernels import knn

In [7]:
# Draw 10 indices from range(100). `np.random.choice` samples with
# replacement by default, so the same index can show up more than once.
rand_idx = np.random.choice(a=100, size=10)
print(f"rand_idx: {rand_idx}")

rand_idx: [81 47 54 54 91 49 95  3 23 73]


In [10]:
# Synthetic point cloud for the cells below; seeded so reruns give the same data.
torch.manual_seed(0)

BATCH, NUM_POINT, NUM_NEIGHBOUR = 8, 2 ** 4, 4
D_IN = 3 + 4  # 3 coordinate channels followed by 4 extra feature channels

pc = torch.rand((BATCH, NUM_POINT, D_IN))
pc_xyz = pc[..., :3]  # leading three channels, used as xyz coordinates
print(f"pc: {pc.shape}")
print(f"pc_xyz: {pc_xyz.shape}")

pc: torch.Size([8, 16, 7])
pc_xyz: torch.Size([8, 16, 3])
idx: torch.Size([8, 16, 4])
extended_idx: torch.Size([8, 3, 16, 4])


In [11]:
# k-nearest-neighbour query of the cloud against itself. `pc_xyz` is a slice
# of `pc`, so it is moved to CPU and made contiguous once and that copy is
# passed as both arguments.
xyz_cpu = pc_xyz.cpu().contiguous()
idx, dist = knn(xyz_cpu, xyz_cpu, NUM_NEIGHBOUR)
print(f"idx: {idx.shape}")

idx: torch.Size([8, 16, 4])
extended_idx: torch.Size([8, 3, 16, 4])


In [19]:
# Copy the kNN results across the 3 coordinate channels: [B, N, K] -> [B, 3, N, K].
extended_idx = idx[:, None].repeat(1, 3, 1, 1)
print(f"extended_idx: {extended_idx.shape}")
extended_dist = dist[:, None].repeat(1, 3, 1, 1)
print(f"extended_dist: {extended_dist.shape}")

# Channel-first coordinates copied along a new trailing K axis: [B, N, 3] -> [B, 3, N, K].
extended_xyz = pc_xyz.transpose(1, 2)[..., None].repeat(1, 1, 1, NUM_NEIGHBOUR)
print(f"extended_xyz: {extended_xyz.shape}")

# Gathering along the point axis gives neighbour[b, c, n, k] = pc_xyz[b, idx[b, n, k], c].
neighbour = extended_xyz.gather(2, extended_idx)
print(f"neighbour: {neighbour.shape}")

extended_idx: torch.Size([8, 3, 16, 4])
extended_dist: torch.Size([8, 3, 16, 4])
extended_xyz: torch.Size([8, 3, 16, 4])
neighbour: torch.Size([8, 3, 16, 4])


In [23]:
class MLP(nn.Module):
    """2D convolution optionally followed by batch normalisation and an activation."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, bn=False, activation_fn=None):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
        # Disabled stages are stored as None so that `forward` can skip them.
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.activation_fn = activation_fn

    def forward(self, x):
        """
            Apply the convolution, then batch norm and the activation when present.

            Parameters
            ----------
            x: torch.Tensor, shape (B, d_in, N, K)

            Returns
            -------
            torch.Tensor, shape (B, d_out, N, K) with the default
            kernel_size=1 and stride=1
        """
        out = self.conv(x)
        # Same order as the layers were declared: normalise first, then activate.
        for stage in (self.bn, self.activation_fn):
            if stage is not None:
                out = stage(out)
        return out

In [31]:
class LocalSpatialEncoding(nn.Module):
    """Encode each point's position relative to its K neighbours.

    For every (point, neighbour) pair a 10-channel vector is assembled from
    the point position, the neighbour position, their difference and the kNN
    distance. It is passed through `self.mlp` and concatenated with the
    point's own feature channels.
    """

    def __init__(self, out_channels):
        super().__init__()

        # 3 (point xyz) + 3 (neighbour xyz) + 3 (offset) + 1 (distance) input channels.
        self.mlp = MLP(in_channels=3+3+3+1, out_channels=out_channels, bn=True, activation_fn=nn.ReLU())

    def forward(self, point, knn_output):
        '''
        Input:
            point: [B, N, 3+d], xyz coordinates in the first three channels
            knn_output: tuple (idx, dist), each [B, N, K]; idx holds integer
                indices along N and dist the matching distances
        Output:
            neighbouring_feat: [B, out_channels+d, N, K]
                (this is [B, 2*d, N, K] when out_channels == d)

        '''

        idx, dist = knn_output  # [B, N, K]
        # The kNN search may have run on another device (or produced another
        # float dtype) than `point`; align both so that `gather` and `cat`
        # below operate on matching tensors. `.to` is a no-op when they
        # already match.
        idx = idx.to(point.device)
        dist = dist.to(device=point.device, dtype=point.dtype)
        K = idx.size(-1)

        xyz = point[:, :, :3]  # [B, N, 3]
        feat = point[:, :, 3:]  # [B, N, d]

        extended_idx = idx.unsqueeze(1).repeat(1, 3, 1, 1)  # [B, 3, N, K]
        extended_xyz = xyz.transpose(-2, -1).unsqueeze(-1).repeat(1, 1, 1, K)  # [B, 3, N, K]
        # Gather along the point axis: neighbour[b, c, n, k] = xyz[b, idx[b, n, k], c]
        neighbour = extended_xyz.gather(dim=2, index=extended_idx)  # [B, 3, N, K]
        concat_xyz = torch.cat((extended_xyz, neighbour, extended_xyz - neighbour, dist.unsqueeze(1)), dim=1)  # [B, 10, N, K]
        relative_pnt_pos_enc = self.mlp(concat_xyz)  # [B, out_channels, N, K]
        # Each point's own features are repeated for all of its K neighbours.
        extended_feat = feat.transpose(-2, -1).unsqueeze(-1).repeat(1, 1, 1, K)  # [B, d, N, K]
        output = torch.cat((relative_pnt_pos_enc, extended_feat), dim=1)  # [B, out_channels+d, N, K]

        return output

In [33]:
# Run the local spatial encoding on the fake cloud. The kNN query uses one
# contiguous CPU copy of the coordinates for both of its point arguments.
LoSE = LocalSpatialEncoding(out_channels=4)
query_xyz = pc_xyz.cpu().contiguous()
knn_output = knn(query_xyz, query_xyz, NUM_NEIGHBOUR)
lose_feat = LoSE(pc, knn_output)
print(f"LoSE features: {lose_feat.shape}")

LoSE features: torch.Size([8, 8, 16, 4])


In [34]:
class AttentivePooling(nn.Module):
    """Collapse the K neighbour axis using learned per-channel attention weights."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # `score_fn` receives [B, N, K, C]: the linear layer mixes channels and
        # the softmax over dim=-2 normalises the scores across the K neighbours.
        self.score_fn = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Softmax(dim=-2),
        )
        self.mlp = MLP(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x):
        '''
        Input:
            x: [B, in_channels, N, K]
        Output:
            [B, out_channels, N, 1]
        '''
        channels_last = x.permute(0, 2, 3, 1)  # [B, N, K, in_channels]
        weights = self.score_fn(channels_last).permute(0, 3, 1, 2)  # [B, in_channels, N, K]
        # Attention-weighted sum over the neighbours; K is kept as a size-1 axis.
        pooled = (weights * x).sum(dim=-1, keepdim=True)  # [B, in_channels, N, 1]
        return self.mlp(pooled)  # [B, out_channels, N, 1]

In [36]:
# Size the pooling layer from the channel axis of the tensor it consumes
# (8 for the LoSE output above) rather than hard-coding the number, so this
# cell keeps working if the encoder's channel count changes.
AttPooling = AttentivePooling(in_channels=lose_feat.shape[1], out_channels=32)
agg_feat = AttPooling(lose_feat)
print(f"aggregated features: {agg_feat.shape}")

aggregated features: torch.Size([8, 32, 16, 1])
