In [1]:
import torch

In [2]:
def square_distance(src, dst):
    '''
    Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion
        ||s - d||^2 = sum(s^2) + sum(d^2) - 2 * (s . d)
    so all B*N*M distances come from a single batched matmul.

    Input:
        src: [B, N, C]
        dst: [B, M, C]
    Return:
        dist: [B, N, M], dist[b, n, m] = ||src[b, n] - dst[b, m]||^2
    '''
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    # [B, N, M]: dot products s . d
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    # [B, N, 1] and [B, 1, M]: squared norms, shaped to broadcast over the other set
    src_norm = torch.sum(src ** 2, dim=-1).view(batch, n_src, 1)
    dst_norm = torch.sum(dst ** 2, dim=-1).view(batch, 1, n_dst)
    return -2 * cross + src_norm + dst_norm

In [3]:
# Sanity-check square_distance on integer points whose three coordinates are equal:
# for points (k, k, k) and (m, m, m) the squared distance is 3 * (k - m)^2.
src = torch.arange(20).view(4, 5, 1).repeat(1, 1, 3)
dst = torch.arange(24).view(4, 6, 1).repeat(1, 1, 3)
dist = square_distance(src, dst)
assert torch.equal(dist, 3 * (src[..., :1] - dst[..., 0].unsqueeze(1)) ** 2)
dist.shape

torch.Size([4, 5, 6])

In [4]:
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling (FPS).

    Starts from a random point per batch element, then repeatedly selects the point
    whose squared distance to the already-selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of centroids
    Return:
        centered_idx: sampled pointcloud index, [B, npoint]
    """
    B, N, C = xyz.shape
    device = xyz.device
    # all bookkeeping tensors live on the same device as the input
    centered_idx = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # running minimum squared distance from every point to the selected set
    distance = torch.ones(B, N, device=device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(0, B, dtype=torch.long, device=device)
    for i in range(npoint):
        centered_idx[:, i] = farthest
        # [B, 1, C]
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        # [B, N]; the direct difference avoids the cancellation error of the matmul
        # expansion, and reducing only the last dim keeps the shape intact when B == 1 or N == 1
        dist = torch.sum((xyz - centroid) ** 2, dim=-1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = distance.max(dim=-1)[1]

    return centered_idx

In [5]:
# seed so the random FPS start and the point cloud are reproducible
torch.manual_seed(4*20*3)
xyz = torch.randn(4, 20, 3)
npoint = 5

centered_idx = farthest_point_sample(xyz, npoint)
# every row holds npoint distinct, in-range indices
assert ((centered_idx >= 0) & (centered_idx < xyz.shape[1])).all()
assert all(len(set(row)) == npoint for row in centered_idx.tolist())
centered_idx.shape

torch.Size([4, 5])

In [6]:
def index_points(points, idx):
    '''
    Gather points along dim 1 using per-batch indices.

    Input:
        points: [B, N, C or D]
        idx: integer indices into dim 1 of points, [B, npoint] or [B, npoint, nsample]...
    Return:
        new_points: idx.shape + [C or D], i.e. [B, npoint, C or D] or [B, npoint, nsample, C or D]...
    '''
    B = points.shape[0]
    # Batch index of shape [B, 1, ..., 1]; advanced indexing broadcasts it against idx,
    # so no repeat/copy to idx's full shape is needed. Created on points' device to
    # avoid mixing devices in the index.
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape)
    new_points = points[batch_indices, idx, :]
    return new_points

In [22]:
points = torch.randn(4, 20, 9)
npoint = 5
# use the npoint variable rather than a duplicated literal
idx = farthest_point_sample(points, npoint)

new_points = index_points(points, idx)
# row b of new_points is points[b] gathered at idx[b]
assert all(torch.equal(new_points[b], points[b, idx[b]]) for b in range(points.shape[0]))
new_points.shape

torch.Size([4, 5, 9])

In [8]:
def query_ball_point(radius, nsample, xyz, centered_xyz):
    """
    Ball query: for each query point, pick up to nsample nearest points within radius.

    Candidates are ordered by increasing distance. Slots that would hold a point
    outside the ball are filled with the nearest point of that region (the first column).

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        centered_xyz: query points, [B, npoint, C]
    Return:
        grouped_idx: grouped points index, [B, npoint, nsample]
                     (only N columns if nsample > N, because of the slice below)
    """
    # [B, npoint, N] squared distance from every query point to all N points
    distance = square_distance(centered_xyz, xyz)
    # [B, npoint, N] True for points outside the ball (compared against radius^2)
    outside = distance > radius**2
    # push outside points to the end of the ordering with a large sentinel (no in-place write)
    distance = distance.masked_fill(outside, 1e10)
    # [B, npoint, nsample] sort once and keep the nsample nearest candidates,
    # which may still include outside points
    grouped_idx = distance.sort(dim=-1)[1][:, :, :nsample]
    # [B, npoint, nsample] which selected candidates are outside, read from the mask
    # itself rather than by comparing floats against the sentinel
    outside_selected = torch.gather(outside, -1, grouped_idx)
    # [B, npoint, nsample] the nearest point of each region, broadcast over the samples
    grouped_first = grouped_idx[:, :, :1].expand_as(grouped_idx)
    # replace every outside candidate with the nearest point
    grouped_idx = torch.where(outside_selected, grouped_first, grouped_idx)

    return grouped_idx

In [9]:
torch.manual_seed(4*20*3)
xyz = torch.randn(4, 20, 3)
npoint, nsample, radius = 5, 10, 2

# FPS centroids, then a ball query around each of them
centered_idx = farthest_point_sample(xyz, npoint)
centered_xyz = index_points(xyz, centered_idx)
grouped_idx = query_ball_point(radius=radius, nsample=nsample, xyz=xyz, centered_xyz=centered_xyz)
grouped_xyz = index_points(xyz, grouped_idx)
grouped_idx

tensor([[[11,  1, 17, 14, 11, 11, 11, 11, 11, 11],
         [19,  9, 12,  7, 13, 15,  4,  6, 16, 19],
         [10,  3, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 8,  1,  8,  8,  8,  8,  8,  8,  8,  8],
         [18,  4, 16,  0,  3,  5,  6, 14, 15, 12]],

        [[ 8, 14, 13,  4, 15,  5,  0, 11, 18, 19],
         [12,  2, 16,  6, 12, 12, 12, 12, 12, 12],
         [ 6, 16,  7, 14, 18, 13, 12,  5,  2,  8],
         [ 3, 15,  9,  4,  8, 11,  0, 17,  5, 10],
         [17,  1, 10,  0, 19,  8, 13,  4,  3, 11]],

        [[11,  6, 13,  7,  1, 16,  5, 11, 11, 11],
         [17, 18,  8, 17, 17, 17, 17, 17, 17, 17],
         [14, 15, 19, 10,  2,  5, 14, 14, 14, 14],
         [12,  2,  5, 19, 15, 12, 12, 12, 12, 12],
         [ 3,  4,  7,  3,  3,  3,  3,  3,  3,  3]],

        [[13, 10, 12, 15, 13, 13, 13, 13, 13, 13],
         [ 2,  0, 16, 19,  8,  5,  2,  2,  2,  2],
         [ 1,  7,  1,  1,  1,  1,  1,  1,  1,  1],
         [14, 11,  4, 17,  9,  8, 18, 12,  0, 14],
         [ 3, 10,  3,  3,

In [10]:
def sample_and_group(npoint, radius, nsample, xyz, features):
    """
    FPS-sample npoint centroids, then ball-query a local group around each one.

    The xyz part of every grouped point is expressed relative to its centroid,
    whether or not features are given.

    Input:
        npoint: Number of point for FPS
        radius: Radius of ball query
        nsample: Number of point for each ball query
        xyz: input points position data, [B, N, C]
        features: input points data, [B, N, D], or None
    Return:
        centered_points: centroid xyz (absolute) + features, [B, npoint, C+D]
                         ([B, npoint, C] when features is None)
        grouped_points: grouped xyz (relative to centroid) + features, [B, npoint, nsample, C+D]
                        ([B, npoint, nsample, C] when features is None)
    """
    # [B, npoint]
    centered_idx = farthest_point_sample(xyz, npoint)
    # [B, npoint, C]
    centered_xyz = index_points(xyz, centered_idx)
    # [B, npoint, nsample]
    grouped_idx = query_ball_point(radius, nsample, xyz, centered_xyz)
    # [B, npoint, nsample, C]
    grouped_xyz = index_points(xyz, grouped_idx)
    # [B, npoint, nsample, C] neighbour coordinates relative to their centroid
    grouped_xyz_norm = grouped_xyz - centered_xyz.unsqueeze(2)
    if features is None:
        # keep the same (centroid-relative) convention as the branch with features
        return centered_xyz, grouped_xyz_norm
    centered_features = index_points(features, centered_idx)
    grouped_features = index_points(features, grouped_idx)
    centered_points = torch.cat((centered_xyz, centered_features), dim=-1)
    grouped_points = torch.cat((grouped_xyz_norm, grouped_features), dim=-1)
    return centered_points, grouped_points

In [25]:
torch.manual_seed(4*20*(3+6))
xyz = torch.randn(4, 20, 3)
features = torch.randn(4, 20, 6)
B = 4
npoint = 5
nsample = 10
radius = 2

centered_points, grouped_points = sample_and_group(npoint, radius, nsample, xyz, features)
assert centered_points.shape == (B, npoint, 3 + 6)
# each centroid is its own nearest neighbour: zero relative offset, identical features
assert torch.allclose(grouped_points[:, :, 0, :3], torch.zeros(B, npoint, 3))
assert torch.equal(grouped_points[:, :, 0, 3:], centered_points[..., 3:])
grouped_points.shape


torch.Size([4, 5, 10, 9])

In [12]:
# take all points as a group
# npoint = 1
# nsample = N
def sample_and_group_all(xyz, features):
    """
    Treat the whole point cloud as one region (npoint = 1, nsample = N).

    Input:
        xyz: input points position data, [B, N, C]
        features: input points data, [B, N, D], or None
    Return:
        centered_points: the single centre picked by farthest_point_sample, [B, 1, C+D]
                         ([B, 1, C] when features is None)
        grouped_points: every point in absolute coordinates plus its features, [B, 1, N, C+D]
                        ([B, 1, N, C] when features is None)
    """
    batch, n_points, dim = xyz.shape
    # with npoint == 1, FPS just returns its randomly chosen starting point
    center_idx = farthest_point_sample(xyz, 1)
    center_xyz = index_points(xyz, center_idx)
    all_xyz = xyz.view(batch, 1, n_points, dim)
    if features is None:
        return center_xyz, all_xyz
    center_points = torch.cat((center_xyz, index_points(features, center_idx)), dim=-1)
    all_points = torch.cat((all_xyz, features.view(batch, 1, n_points, -1)), dim=-1)
    return center_points, all_points

In [13]:
torch.manual_seed(4*20*(3+6))
xyz = torch.randn(4, 20, 3)
features = torch.randn(4, 20, 6)

centered_points, grouped_points = sample_and_group_all(xyz, features)
# the single group holds every point, in absolute coordinates, followed by its features
assert torch.equal(grouped_points[:, 0], torch.cat((xyz, features), dim=-1))
grouped_points.shape

torch.Size([4, 1, 20, 9])

In [14]:
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class GraphAttention(nn.Module):
    '''
    Graph-attention aggregation over local neighbourhoods.

    For each centred point i and neighbour j, per-channel attention logits are computed
    from the full difference (grouped_points_j - centered_points_i) times a learned
    [all_channel, feature_channel] matrix. They are passed through a LeakyReLU,
    softmax-normalised over the neighbours, and used to weight the neighbour features.

    Args:
        all_channel: C + D, channel count of the point tensors passed to forward
        feature_channel: D, number of trailing feature channels
            (the leading all_channel - feature_channel channels are positions)
        dropout: dropout probability applied to the normalised attention weights
        alpha: negative slope of the LeakyReLU
    '''
    def __init__(self, all_channel, feature_channel, dropout, alpha):
        super(GraphAttention, self).__init__()
        self.a = nn.Parameter(torch.zeros(all_channel, feature_channel))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.dropout = dropout
        self.alpha = alpha
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, centered_points, grouped_points):
        '''
        Input:
            centered_points: centered point position data and feature [B, npoint, C+D]
            grouped_points: neighbouring points position data and feature [B, npoint, nsample, C+D]
            The difference is taken over all channels, so both tensors should carry
            positions in the same (absolute) frame.
        Return:
            centered_h: learned centered features [B, npoint, D]
        '''
        feature_dim = self.a.shape[-1]
        # [B, npoint, nsample, D]; positions occupy the leading C channels
        grouped_features = grouped_points[..., grouped_points.shape[-1] - feature_dim:]
        # [B, npoint, nsample, C+D] position and feature difference to the centre
        delta_points = grouped_points - centered_points.unsqueeze(2)
        # [B, npoint, nsample, D] attention logits via the learned [C+D, D] matrix
        attention = self.leakyrelu(torch.matmul(delta_points, self.a))
        # normalise each centre's scores over its neighbours
        attention = attention.softmax(dim=2)
        # dropout on the normalised coefficients, as in GAT (not on the raw logits)
        attention = F.dropout(attention, self.dropout, training=self.training)
        # [B, npoint, D] attention-weighted sum of neighbour features
        centered_h = torch.sum(torch.mul(attention, grouped_features), dim=2)

        return centered_h

In [16]:
torch.manual_seed(4*20*(3+6))
xyz = torch.randn(4, 20, 3)
features = torch.randn(4, 20, 6)
npoint, nsample, radius = 5, 10, 2

feature_channel = features.shape[-1]
all_channel = xyz.shape[-1] + feature_channel
print(f'all_channel: {all_channel}\nfeature_channel: {feature_channel}\n')

centered_points, grouped_points = sample_and_group(npoint, radius, nsample, xyz, features)
# split positions (first 3 channels) from features
centered_xyz, centered_features = centered_points[..., :3], centered_points[..., 3:]
grouped_xyz, grouped_features = grouped_points[..., :3], grouped_points[..., 3:]

print(f"centered_xyz: {centered_xyz.shape}\ngrouped_xyz: {grouped_xyz.shape}\ncentered_features: {centered_features.shape}\ngrouped_features: {grouped_features.shape}")


all_channel: 9
feature_channel: 6

centered_xyz: torch.Size([4, 5, 3])
grouped_xyz: torch.Size([4, 5, 10, 3])
centered_features: torch.Size([4, 5, 6])
grouped_features: torch.Size([4, 5, 10, 6])


In [17]:
GAC = GraphAttention(all_channel=all_channel, feature_channel=feature_channel, dropout=0.5, alpha=0.2)
centered_h = GAC(centered_points, grouped_points)
centered_h.shape

torch.Size([4, 5, 6])

In [18]:
class GraphAttentionConvLayer(nn.Module):
    def __init__(self,
                 npoint,
                 radius,
                 nsample,
                 feature_channel,
                 mlp,
                 group_all,
                 dropout=0.6,
                 alpha=0.2):
        '''
        Input:
                npoint: Number of point for FPS sampling
                radius: Radius for ball query
                nsample: Number of point for each ball query
                feature_channel: the dimension of the input feature channel
                mlp: A list for mlp input-output channel, such as [64, 64, 128]
                group_all: if True, group the whole cloud into one region
                           (npoint/radius/nsample are then unused)
                dropout: attention dropout probability passed to GraphAttention
                alpha: LeakyReLU negative slope passed to GraphAttention
        '''

        super(GraphAttentionConvLayer, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.dropout = dropout
        self.alpha = alpha
        self.conv2ds = nn.ModuleList()
        self.bn2ds = nn.ModuleList()
        last_channel = feature_channel
        for out_channel in mlp:
            self.conv2ds.append(nn.Conv2d(last_channel, out_channel, 1))
            self.bn2ds.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        # attention over 3 position channels + the last mlp width
        self.GAC = GraphAttention(3 + last_channel, last_channel, dropout=self.dropout, alpha=self.alpha)

    def forward(self, points):
        '''
        Input:
            points: position data and features [B, N, 3+D]; the first 3 channels are xyz
        Return:
            sampled_points: sampled position data and aggregated features [B, npoint, 3+D']
                            (npoint is 1 when group_all is set)
        '''
        xyz = points[..., :3]
        features = points[..., 3:]
        if self.group_all:
            # grouped xyz from sample_and_group_all is already in absolute coordinates
            centered_points, grouped_points = sample_and_group_all(xyz, features)
            grouped_xyz = grouped_points[..., :3]
        else:
            centered_points, grouped_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, features)
            # sample_and_group returns neighbour xyz relative to the centroid, but GraphAttention
            # subtracts the centroid itself; restore absolute coordinates so the attention sees
            # p_j - p_i instead of p_j - 2 * p_i
            grouped_xyz = grouped_points[..., :3] + centered_points[..., :3].unsqueeze(2)
        centered_xyz = centered_points[..., :3]

        # [B, D, 1, npoint] and [B, D, nsample, npoint] for the 1x1 Conv2d stack
        centered_features = centered_points[..., 3:].unsqueeze(2).permute(0, 3, 2, 1)
        grouped_features = grouped_points[..., 3:].permute(0, 3, 2, 1)
        # the same conv/bn weights are applied to centre and neighbour features
        for conv2d, bn2d in zip(self.conv2ds, self.bn2ds):
            centered_features = F.relu(bn2d(conv2d(centered_features)))
            grouped_features = F.relu(bn2d(conv2d(grouped_features)))

        # [B, npoint, D']; squeeze only the singleton sample dim so B == 1 or npoint == 1 still work
        centered_features = centered_features.permute(0, 3, 2, 1).squeeze(2)
        # [B, npoint, nsample, D']
        grouped_features = grouped_features.permute(0, 3, 2, 1)
        centered_points = torch.cat((centered_xyz, centered_features), dim=-1)
        grouped_points = torch.cat((grouped_xyz, grouped_features), dim=-1)
        sampled_features = self.GAC(centered_points, grouped_points)
        sampled_points = torch.cat((centered_xyz, sampled_features), dim=-1)
        return sampled_points

In [19]:
torch.manual_seed(4*20*(3+6))
xyz = torch.randn(4, 20, 3)
features = torch.randn(4, 20, 6)
points = torch.cat((xyz, features), dim=-1)
npoint, nsample, radius = 5, 10, 2
feature_channel = features.shape[-1]

# layer 2 consumes layer 1's output width (last entry of its mlp, 32)
GAC_Layer1 = GraphAttentionConvLayer(npoint=npoint, radius=radius, nsample=nsample,
                                     feature_channel=feature_channel, mlp=[16, 32], group_all=False)
GAC_Layer2 = GraphAttentionConvLayer(npoint=4, radius=1, nsample=5,
                                     feature_channel=32, mlp=[32, 64], group_all=False)
layer1_points = GAC_Layer1(points)
layer2_points = GAC_Layer2(layer1_points)
layer2_points.shape

torch.Size([4, 4, 67])