In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

In [2]:
torch.random.manual_seed(1234)  # seed the global torch RNG before the layers below are initialised

<torch._C.Generator at 0x7ff8901928b0>

In [3]:
def square_distance(src, dst):
    """
    Squared Euclidean distance between every source point and every target point.

        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2

    Mathematically this equals sum(src**2) + sum(dst**2) - 2 * src . dst, but it is
    computed directly from the pairwise differences, which keeps every result
    non-negative (exactly 0 for identical points) at the cost of materialising an
    intermediate [B, N, M, C] tensor.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to every pairwise difference, [B, N, M, C].
    pairwise_diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return pairwise_diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather rows of `points` along dim 1 using (possibly multi-dimensional) indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K] (any trailing index dims work)
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, K, C]
    """
    batch_size = idx.shape[0]
    n_channels = points.shape[-1]
    # Flatten every index dim after the batch so a single gather along dim 1 suffices.
    flat_idx = idx.reshape(batch_size, -1)
    # gather needs the index repeated across every channel.
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, n_channels)
    picked = points.gather(1, gather_idx)
    # Restore the original index layout and append the channel axis.
    return picked.reshape(*idx.shape, -1)

In [4]:
# Layer widths: per-point feature size (d_points) and attention width (d_model).
d_model = 512
d_points = 32

# Projections into and out of the attention width.
fc1 = nn.Linear(in_features=d_points, out_features=d_model)
fc2 = nn.Linear(in_features=d_model, out_features=d_points)

# Positional-encoding MLP: maps 3-D relative offsets to d_model channels.
fc_delta = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
# Attention-weight MLP, later applied to (q - k + pos_enc).
fc_gamma = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))

# Bias-free query / key / value projections (created in this order: q, k, v).
w_qs, w_ks, w_vs = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

# Number of nearest neighbours gathered per point.
# NOTE(review): `k` is later rebound to the key tensor (`q, k, v = ...`); after that,
# re-running the kNN cell fails. A distinct name would avoid this hidden-state trap.
k = 16

# Start: step-by-step forward pass of the kNN vector-attention layer

In [5]:
# Reseed so the synthetic cloud is reproducible regardless of earlier RNG use.
random_seed = 1234
torch.manual_seed(random_seed)
# Synthetic batch of 8 clouds x 1024 points x 6 channels; channels 0-2 serve as xyz.
point = torch.randn(8, 1024, 6)
xyz = point[:, :, :3]
print(xyz.size())

torch.Size([8, 1024, 3])


In [6]:
# Per-point feature embedding: lift the raw input channels of `point` to d_points
# channels. The output width must equal fc1's in_features (d_points), so it is
# derived from d_points instead of being hard-coded.
# Shapes with the values used here: [8, 1024, 6] -> [8, 1024, 32].
fcc1 = nn.Sequential(
    nn.Linear(point.size(-1), d_points),
    nn.ReLU(),
    nn.Linear(d_points, d_points),
)
features = fcc1(point)
print(features.shape)

torch.Size([8, 1024, 32])


In [7]:
# Squared distance between every pair of points within each cloud, [B, N, N]:
#   dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
dists = square_distance(src=xyz, dst=xyz)
print(dists.size())

torch.Size([8, 1024, 1024])


In [8]:
# Indices of each point's k closest points in ascending distance, [B, N, k]
# (normally including the point itself, whose self-distance is 0).
# NOTE(review): requires `k` to still be the integer neighbour count; it fails if the
# q/k/v cell has already rebound `k` to the key tensor.
knn_idx = torch.argsort(dists, dim=-1)[..., :k]
print(knn_idx.size())

torch.Size([8, 1024, 16])


In [9]:
# Coordinates of each point's k nearest neighbours, [B, N, k, 3].
knn_xyz = index_points(points=xyz, idx=knn_idx)
print(knn_xyz.size())

torch.Size([8, 1024, 16, 3])


In [10]:
# Keep the un-projected features for the residual add after fc2.
pre = features
# Lift features from d_points to d_model channels: [8, 1024, 32] -> [8, 1024, 512].
x = fc1(features)
print(x.size())
# Queries stay per point; keys and values are gathered for each point's k neighbours.
# NOTE(review): this rebinds `k` (previously the neighbour count, 16) to the key tensor.
q = w_qs(x)
k = index_points(w_ks(x), knn_idx)
v = index_points(w_vs(x), knn_idx)
print(q.size())
print(k.size())
print(v.size())

torch.Size([8, 1024, 512])
torch.Size([8, 1024, 512])
torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 512])


In [11]:
k.shape[-1]  # channel width of the key tensor (here `k` is the key tensor, not the neighbour count)

512

In [12]:
np.sqrt(k.shape[-1])  # sqrt of the key width: the divisor applied to the attention logits below

22.627416997969522

In [13]:
# Relative offsets p_i - p_j from every neighbour j to its centre point i, [B, N, k, 3].
rel_xyz = xyz[:, :, None] - knn_xyz
# Positional encoding: an MLP over the offsets, b x n x k x f.
pos_enc = fc_delta(rel_xyz)
print(pos_enc.shape)
print(rel_xyz.shape)

torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 3])


In [14]:
# Vector-attention logits: an MLP over (q_i - k_j + delta_ij), one logit per channel.
attn = fc_gamma(q.unsqueeze(2) - k + pos_enc)
print(attn.size())
# Scale by sqrt(key width), then normalise over the neighbour axis (dim -2), so each
# channel's weights sum to 1 across the k neighbours; shape stays b x n x k x f.
attn = F.softmax(attn / np.sqrt(k.shape[-1]), dim=-2)
print(attn.size())

torch.Size([8, 1024, 16, 512])
torch.Size([8, 1024, 16, 512])


In [15]:
# Per-channel weighted sum of (value + positional encoding) over the neighbour
# axis n: [B, N, k, d_model] -> [B, N, d_model].
res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
print(res.size())
# Project back to d_points channels and add the residual kept in `pre`.
res = pre + fc2(res)
print(res.size())

torch.Size([8, 1024, 512])
torch.Size([8, 1024, 32])


In [None]:
# Broadcasting trick behind square_distance, shown on a tiny (2, 4, 3) tensor.
# (Removed commented-out prints that referenced the large `point` cloud instead of
# `point1`; uncommenting them would have dumped ~50M-element tensors.)
torch.manual_seed(1234)
point1 = torch.randn(2, 4, 3)
print(point1)
print(point1[:, :, None].shape)  # [2, 4, 1, 3]: each point laid out along dim 1
print(point1[:, None].shape)     # [2, 1, 4, 3]: each point laid out along dim 2
pairwise_diff = point1[:, :, None] - point1[:, None]
print(pairwise_diff.shape)       # [2, 4, 4, 3]: every pairwise difference
pairwise_sq_dist = torch.sum(pairwise_diff ** 2, dim=-1)
print(pairwise_sq_dist.shape)    # [2, 4, 4]
# Cross-check against the expanded identity |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
# quoted in square_distance's docstring.
sq_norm = torch.sum(point1 ** 2, dim=-1)
expanded = sq_norm[:, :, None] + sq_norm[:, None] - 2 * point1 @ point1.transpose(1, 2)
assert torch.allclose(pairwise_sq_dist, expanded, atol=1e-4)

In [2]:
# Seeded toy 4-D tensor for experimenting with reductions, [2, 4, 3, 2].
torch.manual_seed(1234)
x = torch.randn((2, 4, 3, 2))
x

tensor([[[[-0.1117, -0.4966],
          [ 0.1631, -0.8817],
          [ 0.0539,  0.6684]],

         [[-0.0597, -0.4675],
          [-0.2153,  0.8840],
          [-0.7584, -0.3689]],

         [[-0.3424, -1.4020],
          [ 0.3206, -1.0219],
          [ 0.7988, -0.0923]],

         [[-0.7049, -1.6024],
          [ 0.2891,  0.4899],
          [-0.3853, -0.7120]]],


        [[[-0.1706, -1.4594],
          [ 0.2207,  0.2463],
          [-1.3248,  0.6970]],

         [[-0.6631,  1.2158],
          [-1.4949,  0.8810],
          [-1.1786, -0.9340]],

         [[-0.5675, -0.2772],
          [-2.1834,  0.3668],
          [ 0.9380,  0.0078]],

         [[-0.3139, -1.1567],
          [ 1.8409, -1.0174],
          [ 1.2192,  0.1601]]]])

In [3]:
# Max-pool over dim 2 keeping only the values: [2, 4, 3, 2] -> [2, 4, 2].
# NOTE(review): rebinds `x` to its own reduction, so re-running this cell reduces
# again; re-run the cell that creates `x` first.
x = x.max(dim=2).values

In [4]:
# Display `x`; after the max-reduction cell has run once it is [2, 4, 2].
x

tensor([[[ 0.1631,  0.6684],
         [-0.0597,  0.8840],
         [ 0.7988, -0.0923],
         [ 0.2891,  0.4899]],

        [[ 0.2207,  0.6970],
         [-0.6631,  1.2158],
         [ 0.9380,  0.3668],
         [ 1.8409,  0.1601]]])

# BN (BatchNorm) experiment

In [6]:
# BatchNorm treating dim 1 of a 4-D input as channels, with learnable scale/shift
# (affine=True is the default). Passing affine=False drops the learnable parameters,
# e.g. nn.BatchNorm1d(4, affine=False).
# NOTE(review): for the [8, 1024, 16, 512] input used below, dim 1 is the 1024-point
# axis, not the 512 feature channels. If per-feature normalisation is intended, use
# nn.BatchNorm2d(512) on a tensor permuted to [B, F, N, K] — TODO confirm intent.
m = nn.BatchNorm2d(num_features=1024)

In [4]:
# Random stand-in with the same shape as the attention tensors above, [8, 1024, 16, 512].
# NOTE(review): `input` shadows the Python builtin of the same name; rename it
# together with the cell that consumes it (e.g. `bn_input`).
input = torch.randn((8, 1024, 16, 512))
input.size()

torch.Size([8, 1024, 16, 512])

In [7]:
# `m` is never switched to eval(), so this is a training-mode forward pass: it
# normalises with batch statistics and also updates m's running mean/var.
output = m(input)
output.size()

torch.Size([8, 1024, 16, 512])