# Reading note "Sparse Voxel-Graph Attention Network (SVGA-Net)"


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Synthetic point cloud shaped like the paper's input: n points, each with
# dim_v coordinate channels plus dim_s extra channel(s) -> [n, 4]
torch.manual_seed(0)  # fixed seed so the random cloud is reproducible
n, dim_v, dim_s = 100, 3, 1
dim_D = dim_v + dim_s
PCs = torch.rand((n, dim_D))
PCs.shape

torch.Size([100, 4])

In [3]:

def farthest_point_sample(xyz, N):
    """
    Farthest point sampling: starting from a random point, repeatedly pick
    the point that is farthest from all points picked so far.

    Input:
        xyz: point cloud data, [n, 3]
        N: number of samples
    Return:
        P_idx: sampled point cloud index, [N] (dtype long, same device as xyz)
    """
    # Take the cloud size from the argument itself instead of relying on a
    # global `n` defined elsewhere in the notebook.
    n = xyz.shape[0]

    # Indices of the sampled points [N]; kept as integers so they can be used
    # directly for indexing and are never rounded by a float representation.
    P_idx = torch.zeros(N, dtype=torch.long, device=xyz.device)

    # Squared distance from every point to its nearest sampled point [n].
    # Starts very large so the first iteration overwrites every entry.
    # Created with xyz's dtype/device so it can be combined with `dist`.
    distance = xyz.new_full((n,), 1e10)

    # Current farthest point; the first one is chosen at random.
    farthest = torch.randint(0, n, (1,), device=xyz.device)[0]

    for i in range(N):
        # Record the i-th sampled point
        P_idx[i] = farthest

        # Coordinates of the point that was just sampled
        centroid = xyz[farthest, :]

        # Squared Euclidean distance from every point to that sample
        dist = torch.sum((xyz - centroid) ** 2, -1)

        # Keep, for every point, the smallest distance to any sample so far
        distance = torch.minimum(distance, dist)

        # The next sample is the point with the largest such distance
        farthest = torch.max(distance, -1)[1]

    return P_idx

In [4]:
# Sample N centre indices from the cloud, using only the first three
# (coordinate) channels of each point.
N = 50
xyz = PCs[:, :3]
P_idx = farthest_point_sample(xyz, N).long()  # long dtype so it can index PCs

In [6]:
# Gather the sampled points (all channels) by their indices
P = PCs[P_idx]
print(P.shape)

torch.Size([50, 4])


In [7]:
def sqrt_dist(src, dst):
    '''
    Pairwise squared Euclidean distance between two point sets.

    Despite the name, no square root is taken. The result is built from the
    expansion |a - b|^2 = -2 * a.b + |a|^2 + |b|^2.

    Input:
        src: source points [M, 3]
        dst: target points [N, 3]
    Output:
        [M, N] tensor; entry (i, j) is the squared distance between
        src[i] and dst[j]
    '''
    num_src, _ = src.shape
    num_dst, _ = dst.shape

    # Cross term: -2 * <src_i, dst_j>  -> [M, N]
    pairwise = -2 * torch.matmul(src, dst.permute(-1, -2))

    # Squared norms, shaped as a column [M, 1] and a row [1, N] so that
    # they broadcast over the cross term
    src_sq = torch.sum(src**2, dim=-1).view(num_src, -1)
    dst_sq = torch.sum(dst**2, dim=-1).view(-1, num_dst)

    pairwise = pairwise + src_sq
    pairwise = pairwise + dst_sq
    return pairwise

In [8]:
def spherical_voxels(P_xyz, PC_xyz, radius):
    '''
    For every voxel centre, rank all cloud points from nearest to farthest
    and flag the ones that fall inside the sphere of the given radius.

    Input:
        P_xyz: voxel centres [N, 3]
        PC_xyz: point cloud [n, 3]
        radius: sphere radius (squared before comparing, because the
            distances are squared)
    Output:
        B_idx: indices into PC_xyz, sorted nearest-first per centre [N, n]
        mask: boolean [N, n], True where the point at the same position
            in B_idx is strictly closer than radius
    '''
    # Squared Euclidean distance from every centre to every cloud point
    sq_dist = sqrt_dist(P_xyz, PC_xyz)

    # Ascending sort along the cloud axis gives both the ordered distances
    # and the indices of the points they belong to
    sorted_sq_dist, B_idx = torch.sort(sq_dist, dim=-1)

    mask = sorted_sq_dist < radius**2
    return B_idx, mask

In [9]:
# Rank all cloud points around each sampled centre (nearest first)
B_idx, mask = spherical_voxels(P[:, :3], PCs[:, :3], radius=0.5)

# Subdivide the 3D space into N spherical voxels B: [N, t, 4], keeping the
# 25 nearest points of every centre.
# NOTE(review): `mask` is not applied here, so a voxel may contain points
# that lie outside `radius` -- confirm whether that is intended.
B = PCs[B_idx[:, :25]]
print(B.shape)

torch.Size([50, 25, 4])


In [10]:
class LocalPointWiseFeat(nn.Module):
    """Point-wise MLP: in_channel -> 64 -> 128 -> 128 with ReLU in between.

    The linear layers act on the last dimension of the input, so every
    point is embedded independently of the others.
    """

    def __init__(self, in_channel):
        super().__init__()
        layers = [
            nn.Linear(in_channel, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
        ]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 128-channel feature of every point in x."""
        return self.linear(x)

In [13]:
net = LocalPointWiseFeat(dim_D)

# Point-wise features of every voxel
# local_feat: [50, 25, 128]
local_feat = net(B)

# Pairwise dot-product scores between the points of each voxel
# att_score: [50, 25, 25]
att_score = torch.matmul(local_feat, local_feat.permute(0, 2, 1))

# Normalise each row into attention weights: [50, 25, 25]
# softmax already exponentiates its input, so the raw scores are passed in
# directly; exponentiating them first would compute softmax(exp(x)) and can
# overflow for large scores.
att_score = att_score.softmax(-1)

In [14]:
# B[:, :1, :]: centers of each voxel sphere [50, 1, 4]

# Feature of every center: [50, 1, 128]
F_g = net(B[:, :1, :])

# construct KNN graph for each center, K = 3
K = 3
# Features of the K points that follow the center in B: [50, 3, 128]
KNN_graph = net(B[:, 1:K+1, :])

# Dot-product score between each center and its K neighbours: [50, 1, 3]
beta_m = F_g.matmul(KNN_graph.permute(0, 2, 1))

# Normalise over the K neighbours: [50, 1, 3]
beta_m = torch.softmax(beta_m, dim=-1)

# Weight every point feature by each of the K coefficients and sum over K.
# Broadcasting [50, 1, K, 1] * [50, 25, 1, 128] -> [50, 25, K, 128] replaces
# the hard-coded view(50, 1, 3, 1) / repeat(..., 128), so the cell no longer
# depends on the literal batch size, K or feature width.
# NOTE(review): beta_m sums to 1 over K, so this term is numerically just
# local_feat; the KNN_graph features themselves are never aggregated --
# confirm against the paper's global attention formula.
global_term = torch.sum(beta_m.unsqueeze(-1) * local_feat.unsqueeze(-2), dim=-2)

# [50, 25, 128]
new_local_feat = global_term + att_score.matmul(local_feat)
print(new_local_feat.shape)

torch.Size([50, 25, 128])


In [15]:
# Max-pool over all points of each voxel to get one feature vector per voxel.
# The kernel spans the whole point axis, so its size is read from the tensor
# instead of being hard-coded to 25.
voxel_feat = F.max_pool1d(
    new_local_feat.permute(0, 2, 1),
    kernel_size=new_local_feat.shape[1],
)
print(voxel_feat.shape)

torch.Size([50, 128, 1])
