In [None]:
def knn(x, k):
    """Return the indices of the k nearest points by squared Euclidean distance.

    ``x`` is treated as (batch_size, num_dims, num_points): squared norms are
    summed over dim 1 and the Gram matrix ``x^T x`` is (batch_size,
    num_points, num_points). The result has shape (batch_size, num_points, k),
    ordered nearest-first. Each point therefore normally appears as its own
    first neighbour, because its distance to itself is zero.
    """
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)  # (B, 1, N)
    gram = torch.matmul(x.transpose(2, 1), x)  # (B, N, N)
    # -|a - b|^2 = 2 a.b - |a|^2 - |b|^2, so the largest values are the closest points.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, nearest = torch.topk(neg_sq_dist, k=k, dim=-1)
    return nearest


def householder(x, y, eps=1e-12):
    """Build the Householder reflection that maps rescaled rows of ``y`` onto ``x``.

    For each row pair (x_i, y_i), y_i is first rescaled by
    rho_i = |x_i| / |y_i| so that both vectors have the same length. The
    reflection H_i = I - 2 w_i w_i^T, with
    w_i = (x_i - rho_i y_i) / |x_i - rho_i y_i|, then satisfies
    H_i (rho_i y_i) = x_i.

    Args:
        x: tensor of shape (..., D).
        y: tensor with exactly the same shape as ``x``.
        eps: lower bound on |x - rho*y| when normalising w. Rows where
            x_i == rho_i * y_i get w_i = 0, so H_i is the identity instead of
            NaN. A zero-norm row of ``y`` still produces inf/NaN in rho.

    Returns:
        ``(H, rho)``: H has shape (..., D, D); rho has shape (..., 1, 1).
        For 2-D input (N, D) these are (N, D, D) and (N, 1, 1).

    Raises:
        ValueError: if ``x`` and ``y`` have different shapes.
    """
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    rho = torch.norm(x, dim=-1, keepdim=True) / torch.norm(y, dim=-1, keepdim=True)
    diff = x - rho * y
    # clamp_min avoids 0/0 when rho*y already equals x (w becomes 0, so H = I).
    w = diff / torch.norm(diff, dim=-1, keepdim=True).clamp_min(eps)
    # Build the identity on the input's own device/dtype instead of a hard-coded CUDA device.
    eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
    H = eye - 2 * torch.matmul(w.unsqueeze(-1), w.unsqueeze(-2))
    return H, rho.unsqueeze(-1)


def _householder_feature(x, y, eps=1e-12):
    """Return rho * w per row: the rho-scaled unit Householder vector mapping rescaled y onto x.

    ``x`` and ``y`` are (M, D). rho = |x| / |y| rescales y to the length of x,
    and w = (x - rho*y) / |x - rho*y|. Rows where x == rho*y get w = 0 rather
    than NaN. Returns an (M, D) tensor.
    """
    rho = torch.norm(x, dim=-1, keepdim=True) / torch.norm(y, dim=-1, keepdim=True)
    diff = x - rho * y
    w = diff / torch.norm(diff, dim=-1, keepdim=True).clamp_min(eps)
    return rho * w


def get_graph_feature(x, k=20, idx=None):
    """Build per-edge Householder features between each point and its k - 1 nearest neighbours.

    Args:
        x: point features, viewed as (batch_size, num_dims, num_points).
        k: neighbourhood size *including* the point itself. Column 0 of the
            kNN indices is dropped as the self-match, leaving k - 1 neighbours.
        idx: optional precomputed (batch_size, num_points, k) kNN indices,
            local to each batch item; computed with ``knn`` when omitted.
            Its first column is assumed to be the point itself.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k - 1). For each
        edge it holds the centre point's features followed by rho * w, where
        w is the unit Householder vector reflecting the rescaled neighbour
        onto the centre and rho = |centre| / |neighbour|.

    Raises:
        ValueError: if k < 2 (no neighbours would remain after dropping self).
    """
    if k < 2:
        raise ValueError(f"k must be at least 2 (self plus one neighbour), got {k}")
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)  # (batch_size, num_points, k)
    num_dims = x.size(1)
    num_edges = batch_size * num_points * (k - 1)

    # Offset per-batch indices so they address rows of the flattened (B*N, D) table;
    # use the input's device rather than a hard-coded CUDA device.
    idx_base = torch.arange(batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).reshape(batch_size * num_points, k)
    # Drop column 0, assumed to be the centre point itself.
    idx = idx[:, 1:].reshape(num_edges)

    x = x.transpose(2, 1).contiguous()  # (B, N, D)
    neighbor = x.view(batch_size * num_points, num_dims)[idx, :]  # (B*N*(k-1), D)
    centre = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k - 1, -1)
    centre = centre.reshape(num_edges, num_dims)  # one copy of the centre per edge

    # householder() returns the full (M, D, D) matrix plus rho, which cannot be
    # viewed as (..., num_dims). This uses the (M, D) rho * w vector form instead.
    # NOTE(review): presumed intended feature — confirm against the model's channel count.
    w = _householder_feature(centre, neighbor)
    feature = torch.cat((centre, w), dim=-1).view(batch_size, num_points, k - 1, 2 * num_dims)
    return feature.permute(0, 3, 1, 2)