# Reading Notes: "Graph Attention Convolution for Point Cloud Semantic Segmentation"



In [1]:
import torch
import torch.nn as nn

In [2]:
# Toy input: N points, each row = xyz coordinates followed by dim_feat features,
# i.e. a [N, 3 + dim_feat] = [10, 9] tensor.
torch.manual_seed(1)
N = 10
dim_feat = 6
pc_wf = torch.rand(N, 3 + dim_feat)  # point cloud with features

# Split the columns into the 3-D coordinates and the remaining features
# (both results are views into pc_wf, exactly like column slicing).
pc_xyz, pc_feat = torch.split(pc_wf, [3, dim_feat], dim=1)

print(pc_wf, pc_wf.shape)
print(pc_xyz, pc_xyz.shape)
print(pc_feat, pc_feat.shape)

tensor([[0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999, 0.3971, 0.7544, 0.5695],
        [0.4388, 0.6387, 0.5247, 0.6826, 0.3051, 0.4635, 0.4550, 0.5725, 0.4980],
        [0.9371, 0.6556, 0.3138, 0.1980, 0.4162, 0.2843, 0.3398, 0.5239, 0.7981],
        [0.7718, 0.0112, 0.8100, 0.6397, 0.9743, 0.8300, 0.0444, 0.0246, 0.2588],
        [0.9391, 0.4167, 0.7140, 0.2676, 0.9906, 0.2885, 0.8750, 0.5059, 0.2366],
        [0.7570, 0.2346, 0.6471, 0.3556, 0.4452, 0.0193, 0.2616, 0.7713, 0.3785],
        [0.9980, 0.9008, 0.4766, 0.1663, 0.8045, 0.6552, 0.1768, 0.8248, 0.8036],
        [0.9434, 0.2197, 0.4177, 0.4903, 0.5730, 0.1205, 0.1452, 0.7720, 0.3828],
        [0.7442, 0.5285, 0.6642, 0.6099, 0.6818, 0.7479, 0.0369, 0.7517, 0.1484],
        [0.1227, 0.5304, 0.4148, 0.7937, 0.2104, 0.0555, 0.8639, 0.4259, 0.7812]]) torch.Size([10, 9])
tensor([[0.7576, 0.2793, 0.4031],
        [0.4388, 0.6387, 0.5247],
        [0.9371, 0.6556, 0.3138],
        [0.7718, 0.0112, 0.8100],
        [0.9391, 0.4167

# Sanity check of the expansion ||s_i - d_j||^2 = ||s_i||^2 + ||d_j||^2 - 2 <s_i, d_j>
# on toy data (every pair is (1,1,1) vs (2,2,2), so every entry should be 3).
src = torch.ones(10, 3)
dst = 2 * torch.ones(10, 3)

# Cross term -2 * <src_i, dst_j> for every pair -> [M, N]
dist = -2 * torch.matmul(src, dst.permute(-1, -2))
print(dist)

# ||src_i||^2 belongs to row i: broadcast it as a column vector [M, 1].
# (With constant inputs a transposed orientation would go unnoticed, so be explicit.)
dist += torch.sum(src ** 2, dim=-1).view(-1, 1)
# ||dst_j||^2 belongs to column j: broadcast it as a row vector [1, N].
dist += torch.sum(dst ** 2, dim=-1).view(1, -1)
print(dist)

In [3]:
def sqrt_dist(src, dst):
    '''
    Calculate the squared Euclidean distance between every pair of points.

    Despite the name, no square root is taken: the result is the *squared*
    distance, computed with the expansion
        ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * <s, d>

    Input:
        src: source points [..., M, C] (C is 3 for xyz)
        dst: target points [..., N, C]; any leading (batch) dims must
             broadcast with those of src
    Output:
        dist: per-pair squared Euclidean distance [..., M, N]; entry
              [..., i, j] is the squared distance between src[..., i, :]
              and dst[..., j, :]
    '''
    # Cross term -2 * <src_i, dst_j> for every pair; transpose(-1, -2) keeps
    # any leading batch dims intact (permute(-1, -2) only works for 2-D).
    dist = -2 * torch.matmul(src, dst.transpose(-1, -2))
    # + ||src_i||^2, a column vector [..., M, 1] broadcast along the dst axis
    dist = dist + torch.sum(src ** 2, dim=-1, keepdim=True)
    # + ||dst_j||^2, a row vector [..., 1, N] broadcast along the src axis
    dist = dist + torch.sum(dst ** 2, dim=-1).unsqueeze(-2)
    # The expansion suffers from floating-point cancellation, so (near-)coincident
    # points can come out slightly negative; a squared distance never is.
    return dist.clamp_min(0)

# sqrt_dist = sqrt_dist(pc_3d, pc_3d)
# print(sqrt_dist)
# vertices_id = torch.arange(10).view(1,-1).repeat(10, 1)
# vertices_id[sqrt_dist>0.5**2] = 10
# vertices_id = sqrt_dist.sort(dim=-1)[:5]


In [4]:
def creat_graph(pc_xyz, radius=128, k=5):
    '''
    Create a graph linking each point to its k nearest points within a radius.

    Input:
        pc_xyz: point cloud with 3d info [N, 3]
        radius: neighbourhood radius; a neighbour whose squared distance
                exceeds radius**2 is discarded
        k: number of graph entries per point, including the point itself
           (defaults to 5, the previously hard-coded value); clipped to N
    Output:
        graph: index matrix [N, min(k, N)]. Row i starts with i itself,
               followed by the nearest other points in ascending distance.
               Slots whose neighbour lies outside the radius are filled
               with i itself.
    '''
    n_points = pc_xyz.shape[0]
    # With fewer than k points there are only n_points candidates per row
    k = min(k, n_points)
    # Squared Euclidean distance between every pair of points [N, N]
    dist = sqrt_dist(pc_xyz, pc_xyz)
    # Force each point to sort first in its own row. Relying on the self
    # distance being the smallest breaks for coincident points or rounding
    # error, and callers treat column 0 as the query point.
    dist.fill_diagonal_(-1)
    # Sort once; keep the k smallest distances and their point indices
    dist_value, graph = dist.sort(dim=-1)
    dist_value, graph = dist_value[:, :k], graph[:, :k]
    # Replace neighbours outside the radius with the point's own index
    self_index = torch.arange(n_points, device=graph.device).view(-1, 1).expand(-1, k)
    mask = dist_value > radius ** 2
    graph = torch.where(mask, self_index, graph)
    return graph

# Build the neighbourhood graph of the toy point cloud with the default radius
graph = creat_graph(pc_xyz=pc_xyz)
graph  # bare expression: the notebook displays the resulting index matrix


tensor([[0, 7, 5, 8, 4],
        [1, 9, 8, 0, 5],
        [2, 6, 8, 0, 7],
        [3, 5, 4, 7, 0],
        [4, 8, 5, 7, 0],
        [5, 0, 4, 3, 8],
        [6, 2, 8, 4, 1],
        [7, 0, 5, 4, 8],
        [8, 4, 5, 1, 0],
        [9, 1, 8, 0, 5]])

In [5]:
# Gather every point's neighbourhood by indexing the per-point rows with the graph;
# indexing with the 2-D graph adds a neighbour axis: [N, neighbours, channels]
group_xyz = pc_xyz[graph]
group_feat = pc_feat[graph]
print(group_xyz.shape)
print(group_feat.shape)


torch.Size([10, 5, 3])
torch.Size([10, 5, 6])


In [6]:
def gac_forward(group_xyz, group_feat, K, M_g=None, M_a=None):
    '''
    Implement Graph Attention Convolution (GAC) to get new features with xyz info.

    Column 0 of every neighbourhood is taken as the centre point i; for each
    neighbour j in that group:
        delta_p_ij = p_j - p_i
        delta_h_ij = M_g(h_j) - M_g(h_i)
        alpha_ij   = softmax over j of M_a([delta_p_ij || delta_h_ij])  (per channel)
        h'_i       = sum_j alpha_ij * M_g(h_j)

    Input:
        group_xyz: points neighbourhood with xyz info [N, S, C]
        group_feat: points neighbourhood without xyz info (just features, can be
                    original or learned) [N, S, F]
        K: dimension of transformed features
        M_g: optional feature-transform module mapping the last dim F -> K.
             If None, a freshly initialised nn.Linear(F, K) is created on every
             call (random weights, not reused across calls).
        M_a: optional attention module mapping the last dim C + K -> K; same
             default rule as M_g.
    Output:
        new_feat: new features after GAC computation (H' in the paper) [N, K]
    '''
    N, S, C = group_xyz.shape
    _, _, F = group_feat.shape
    if M_g is None:
        M_g = nn.Linear(F, K)
    if M_a is None:
        M_a = nn.Linear(C + K, K)
    # Relative position of every neighbour w.r.t. the centre; the [:, :1, :]
    # slice broadcasts over the S neighbours (no explicit repeat needed).
    delta_p = group_xyz - group_xyz[:, :1, :]
    # Transform all neighbour features once, then difference against the centre.
    # (Subtracting the centre's transform from itself would always give zero and
    # make the attention blind to feature differences.)
    feat_g = M_g(group_feat)
    delta_h = feat_g - feat_g[:, :1, :]
    att_score = M_a(torch.cat([delta_p, delta_h], dim=-1))
    # Normalise over the S neighbours, independently for each of the K channels
    att_score = att_score.softmax(dim=1)
    new_feat = torch.sum(att_score * feat_g, dim=1)
    return new_feat

# Run one GAC layer over the gathered neighbourhoods, producing 10 channels per point
new_feat = gac_forward(group_xyz, group_feat, K=10)
