In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

In [4]:
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Every point of `src` is compared with every point of `dst` in the same
    batch element:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2
    No square root is taken, so the values are squared distances.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to the full [B, N, M, C] grid of
    # coordinate differences.
    diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather rows of `points` along the point axis (dim 1), per batch element.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K]
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, K, C]
    """
    idx_shape = idx.shape
    # Collapse every index dimension after the batch into one so a single
    # gather along dim 1 handles both the [B, S] and [B, S, K] layouts.
    flat_idx = idx.reshape(idx_shape[0], -1)
    # gather needs an index per output element, so repeat each point index
    # across the channel dimension.
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, points.size(-1))
    gathered = points.gather(1, gather_idx)
    # Restore the original index layout, with the channel axis appended.
    return gathered.reshape(*idx_shape, -1)

In [5]:
# Sizes used by the layers below.
d_model = 512   # width of the attention features
d_points = 32   # width of the per-point features fed into fc1
# Number of nearest neighbours kept per point (used as a slice bound on the
# sorted distances). NOTE: a later cell unpacks `q, k, v`, which rebinds `k`
# to a tensor, so re-run this cell before re-running the neighbour search.
k = 16

# Seed before any layer is created so the randomly initialised weights are
# identical every time this cell runs. The data cell below re-seeds with the
# same value, so the sampled points are unaffected.
torch.manual_seed(1234)

# Input / output projections around the attention block.
fc1 = nn.Linear(d_points, d_model)
fc2 = nn.Linear(d_model, d_points)
# Maps 3-d relative coordinates to a d_model-wide positional encoding.
fc_delta = nn.Sequential(
    nn.Linear(3, d_model),
    nn.ReLU(),
    nn.Linear(d_model, d_model)
)
# Maps (query - key + positional encoding) to attention logits.
fc_gamma = nn.Sequential(
    nn.Linear(d_model, d_model),
    nn.ReLU(),
    nn.Linear(d_model, d_model)
)
# Bias-free query / key / value projections.
w_qs = nn.Linear(d_model, d_model, bias=False)
w_ks = nn.Linear(d_model, d_model, bias=False)
w_vs = nn.Linear(d_model, d_model, bias=False)

# start

In [6]:
# Fixed seed so the synthetic point cloud is the same on every run.
random_seed = 1234
torch.manual_seed(random_seed)

# 8 clouds of 1024 points with 6 channels each; the first three channels are
# used as the coordinates.
point = torch.randn(8, 1024, 6)
xyz = point[:, :, :3]
print(xyz.shape)

torch.Size([8, 1024, 3])


In [7]:
# Two-layer MLP lifting the 6 raw channels of every point to 32 features:
# point [8, 1024, 6] -> features [8, 1024, 32]
fcc1 = nn.Sequential(nn.Linear(6, 32), nn.ReLU(), nn.Linear(32, 32))
features = fcc1(point)
print(features.shape)

torch.Size([8, 1024, 32])


In [8]:
# Squared distance from every point to every other point of the same cloud,
# computed by element-wise subtraction of the coordinates:
#   dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
dists = square_distance(src=xyz, dst=xyz)
print(dists.shape)

torch.Size([8, 1024, 1024])


In [9]:
# Sort each row of distances in ascending order and keep the indices of the
# first k entries -> b x n x k.
knn_idx = torch.argsort(dists)[..., :k]
print(knn_idx.shape)

torch.Size([8, 1024, 16])


In [10]:
# Coordinates of the selected neighbours of every point: b x n x k x 3.
knn_xyz = index_points(points=xyz, idx=knn_idx)
print(knn_xyz.shape)

torch.Size([8, 1024, 16, 3])


In [11]:
# Keep the block input; it is added back as a residual at the end.
pre = features
# features [8, 1024, 32] -> x [8, 1024, 512]
x = fc1(features)
print(x.shape)

# Queries stay per point; keys and values are gathered for each point's
# selected neighbours.
# NOTE: `k` was the neighbour count (16) until here and is now the key
# tensor, so the neighbour-search cell cannot be re-run after this one
# without first re-running the configuration cell.
q = w_qs(x)
k = index_points(w_ks(x), knn_idx)
v = index_points(w_vs(x), knn_idx)
print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([8, 1024, 512])
torch.Size([8, 1024, 512])
torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 512])


In [12]:
# Size of the last (feature) dimension of the key tensor.
k.shape[-1]

512

In [13]:
# Square root of the key feature width; used below to scale the attention
# logits before the softmax.
np.sqrt(k.shape[-1])

22.627416997969522

In [14]:
# Offset from every point to each of its neighbours (b x n x k x 3), mapped by
# fc_delta to a positional encoding of shape b x n x k x f.
pos_enc = fc_delta(xyz.unsqueeze(2) - knn_xyz)
print(pos_enc.shape)
print((xyz.unsqueeze(2) - knn_xyz).shape)

torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 3])


In [15]:
# Attention logits from (query - key + positional encoding); the query is
# broadcast across the neighbour axis.
attn = fc_gamma(q.unsqueeze(2) - k + pos_enc)
print(attn.shape)
# Scale by sqrt(feature width), then normalise over the neighbour axis
# (dim=-2), independently for every feature channel -> b x n x k x f.
attn = torch.softmax(attn / np.sqrt(k.shape[-1]), dim=-2)
print(attn.shape)

torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 512])


In [16]:
# Attention-weighted sum over the neighbour axis:
#   res[b, m, f] = sum_n attn[b, m, n, f] * (v + pos_enc)[b, m, n, f]
res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
print(res.shape)
# Project back to the input feature width and add the residual kept in `pre`.
res = fc2(res) + pre
print(res.shape)

torch.Size([8, 1024, 512])
torch.Size([8, 1024, 32])


In [None]:
# Small worked example of the broadcasting used inside square_distance.
torch.manual_seed(1234)
point1 = torch.randn(2, 4, 3)
print(point1)
# Extra axis after the point axis: [2, 4, 1, 3]
print(point1.unsqueeze(2).shape)
# Extra axis before the point axis: [2, 1, 4, 3]
print(point1.unsqueeze(1).shape)
# Broadcasting the two gives every pair of points: [2, 4, 4, 3]
print((point1.unsqueeze(2) - point1.unsqueeze(1)).shape)
# Summing the squared differences over the coordinate axis: [2, 4, 4]
print(((point1.unsqueeze(2) - point1.unsqueeze(1)) ** 2).sum(dim=-1).shape)