TensorFlow implementation of the Point Cloud Transformer (PCT), ported from the original implementation linked below.

Original repo: https://github.com/qinglew/PointCloudTransformer/

Paper: https://link.springer.com/article/10.1007/s41095-021-0229-5

In [None]:
import tensorflow as tf
import math
import tensorflow.keras.backend as K

# Naive PCT Classification

In [None]:
class Embedding(tf.keras.layers.Layer):
  '''
  Point-wise embedding made of two consecutive
  Conv1D(kernel_size=1, no bias) -> BatchNormalization -> ReLU stages.

  Input: (B, N, in_channels)
  Output: (B, N, out_channels)
  '''
  def __init__(self, in_channels=3, out_channels=128, **kwargs):
    super().__init__(**kwargs)

    # Kept for reference only; Conv1D infers its input width when it is built.
    self.input_dim = in_channels
    self.output_dim = out_channels

    # Both 1x1 convolutions project to `out_channels`.
    self.conv1 = tf.keras.layers.Conv1D(out_channels, kernel_size=1, use_bias=False)
    self.conv2 = tf.keras.layers.Conv1D(out_channels, kernel_size=1, use_bias=False)

    self.bn1 = tf.keras.layers.BatchNormalization()
    self.bn2 = tf.keras.layers.BatchNormalization()

    self.relu1 = tf.keras.layers.Activation("relu")
    self.relu2 = tf.keras.layers.Activation("relu")

  def call(self, x):
    # Run the two conv -> norm -> activation stages back to back.
    stages = (
        (self.conv1, self.bn1, self.relu1),
        (self.conv2, self.bn2, self.relu2),
    )
    for conv, norm, activation in stages:
      x = activation(norm(conv(x)))
    return x

In [None]:
class SA(tf.keras.layers.Layer):
  '''
  Scaled dot-product self-attention block with a residual connection.

  Input: (B, N, out_channels)
  Output: (B, N, out_channels)
  '''
  def __init__(self, channels=128, **kwargs):
    super().__init__(**kwargs)

    # Width of the query/key projections; also used to scale the energy.
    self.da = channels // 4

    self.q_conv = tf.keras.layers.Conv1D(self.da, kernel_size=1, use_bias=False)
    self.k_conv = tf.keras.layers.Conv1D(self.da, kernel_size=1, use_bias=False)
    self.v_conv = tf.keras.layers.Conv1D(channels, kernel_size=1)

    self.trans_conv = tf.keras.layers.Conv1D(channels, kernel_size=1)
    self.after_norm = tf.keras.layers.BatchNormalization()

    self.act = tf.keras.layers.Activation("relu")
    self.softmax = tf.keras.layers.Softmax(axis=-1)

  def call(self, x):
    # Projections: queries are transposed to (B, da, N) for the matmul below.
    queries_t = tf.transpose(self.q_conv(x), perm=[0, 2, 1])
    keys = self.k_conv(x)      # (B, N, da)
    values = self.v_conv(x)    # (B, N, channels)

    # (B, N, N) similarity matrix, scaled by sqrt(da), normalised on the last axis.
    scale = math.sqrt(self.da)
    attention = self.softmax(tf.matmul(keys, queries_t) / scale)

    # values^T @ attention -> (B, channels, N), then back to (B, N, channels).
    attended = tf.matmul(values, attention, transpose_a=True)
    attended = tf.transpose(attended, perm=[0, 2, 1])
    refined = self.act(self.after_norm(self.trans_conv(attended)))

    # Residual connection.
    return x + refined

In [None]:
class NaivePCT(tf.keras.layers.Layer):
  '''
  Naive PCT encoder: point embedding, four stacked self-attention blocks,
  a 1024-wide point-wise projection and global max/mean pooling.

  Tensors are channels-last throughout, i.e. (B, N, C).

  Input: (B, N, C_in)  -- the model in this file feeds (B, 1024, 3)
  Output: (B, 2048)    -- concat of the max-pooled and mean-pooled 1024-d features
  '''
  def __init__(self, **kwargs):
    super(NaivePCT, self).__init__(**kwargs)

    self.embedding = Embedding(3, 128)

    self.sa1 = SA(128)
    self.sa2 = SA(128)
    self.sa3 = SA(128)
    self.sa4 = SA(128)

    # Dense acts on the last axis, so on (B, N, C) it is a point-wise projection.
    self.conv1d = tf.keras.layers.Dense(1024)
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.LeakyReLU(negative_slope=0.2)

  def call(self, x):
    x = self.embedding(x)  # (B, N, 128)

    x1 = self.sa1(x)
    x2 = self.sa2(x1)
    x3 = self.sa3(x2)
    x4 = self.sa4(x3)

    # Concatenate the attention outputs along the CHANNEL axis. With the
    # channels-last layout used here that is the last axis; axis=1 would stack
    # the four outputs as extra points instead of richer per-point features.
    x = tf.concat([x1, x2, x3, x4], axis=-1)  # (B, N, 512)

    x = self.conv1d(x)  # (B, N, 1024)
    x = self.bn(x)
    x = self.act(x)

    # Global pooling over the POINT axis so the descriptor size does not
    # depend on the number of input points.
    x_max = tf.math.reduce_max(x, axis=1)    # (B, 1024)
    x_mean = tf.math.reduce_mean(x, axis=1)  # (B, 1024)

    return tf.concat([x_max, x_mean], axis=-1)  # (B, 2048)

In [None]:
def PCT_Naive_Classification(out_classes, batch_size=128):
  """Build the naive PCT Keras model.

  A NaivePCT encoder is followed by two Dense -> BatchNorm -> ReLU -> Dropout(0.5)
  stages (512 then 256 units) and a final linear Dense layer.

  Args:
    out_classes: number of units of the final Dense layer.
    batch_size: fixed batch size baked into the Input layer.

  Returns:
    A tf.keras.Model named 'PCT_Naive_Classification' taking (1024, 3) clouds.
  """
  point_cloud = tf.keras.layers.Input(shape=(1024,3), batch_size=batch_size)
  hidden = NaivePCT()(point_cloud)

  # Fully connected head: identical stages that differ only in width.
  for units in (512, 256):
    hidden = tf.keras.layers.Dense(units)(hidden)
    hidden = tf.keras.layers.BatchNormalization()(hidden)
    hidden = tf.keras.layers.Activation("relu")(hidden)
    hidden = tf.keras.layers.Dropout(0.5)(hidden)

  # No activation on the output layer.
  predictions = tf.keras.layers.Dense(out_classes)(hidden)

  return tf.keras.Model(inputs=point_cloud, outputs=predictions, name='PCT_Naive_Classification')

In [None]:
def sqrd_euclidean_distance_loss(y_true, y_pred):
    """
    Squared Euclidean distance between predictions and targets.
    https://en.wikipedia.org/wiki/Euclidean_distance
    :param y_true: target tensor
    :param y_pred: predicted tensor of the same shape as y_true
    :return: tensor holding the sum of squared differences over the last axis
    """
    residual = y_pred - y_true
    squared = K.square(residual)
    return K.sum(squared, axis=-1)

# Build the naive PCT network with a 6-unit output layer.
model = PCT_Naive_Classification(out_classes=6)

# Compile with Adam; the squared Euclidean distance is both the training
# objective and the reported metric.
model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
  loss=sqrd_euclidean_distance_loss,
  metrics=[sqrd_euclidean_distance_loss],
)

# Print the layer-by-layer overview.
model.summary()

# PCT Classification

## Utils

In [None]:
def farthest_point_sampling(npoint, xyz):
    """
    Iterative farthest point sampling.

    Starting from one random point per cloud, repeatedly picks the point whose
    distance to the already selected set is largest.

    Input:
        npoint: number of samples
        xyz: pointcloud data, [B, N, C]
    Return:
        centroids: sampled pointcloud index (int32), [B, npoint]
    """
    # Use static dimensions when they are known and fall back to the dynamic
    # shape otherwise, so an unspecified batch size does not break the sampler.
    dynamic_shape = tf.shape(xyz)
    B = xyz.shape[0] if xyz.shape[0] is not None else dynamic_shape[0]
    N = xyz.shape[1] if xyz.shape[1] is not None else dynamic_shape[1]

    centroids = tf.TensorArray(dtype=tf.int32, size=npoint, dynamic_size=False)
    # Running squared distance of every point to the selected set. It is created
    # in the dtype of `xyz` so the comparison in the loop never mixes dtypes.
    distance = tf.fill([B, N], tf.cast(1e10, xyz.dtype))
    # Random seed point for each cloud in the batch.
    farthest = tf.random.uniform((B,), minval=0, maxval=N, dtype=tf.int32)

    def loop_body(i, centroids, distance, farthest):
        centroids = centroids.write(i, farthest)
        centroid = tf.gather(xyz, farthest, batch_dims=1)[:, tf.newaxis, :]  # [B, 1, C]
        dist = tf.reduce_sum((xyz - centroid) ** 2, axis=-1)  # [B, N]
        # Keep, per point, the smallest distance to any selected centroid.
        mask = dist < distance
        distance = tf.where(mask, dist, distance)
        # Next centroid: the point farthest from the selected set.
        farthest = tf.argmax(distance, axis=-1, output_type=tf.int32)
        return i + 1, centroids, distance, farthest

    _, centroids, _, _ = tf.while_loop(
        lambda i, *args: i < npoint,
        loop_body,
        [0, centroids, distance, farthest]
    )

    # TensorArray stacks to [npoint, B]; callers expect [B, npoint].
    return tf.transpose(centroids.stack(), perm=[1, 0])  # [B, npoint]

def pairwise_distance(point_cloud):
  """Compute pairwise squared distance of a point cloud.

  Uses ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.

  Args:
    point_cloud: tensor (batch_size, num_points, num_dims). Inputs of any other
      rank are first passed through the legacy `tf.squeeze` path, which drops
      every size-1 axis and restores the batch axis when batch_size is 1.

  Returns:
    pairwise distance: (batch_size, num_points, num_points)
  """
  # A rank-3 tensor is already in the expected layout. Squeezing it would wrongly
  # drop a legitimate size-1 points or dims axis, so squeeze only other ranks.
  if point_cloud.shape.rank != 3:
    og_batch_size = point_cloud.shape[0]
    point_cloud = tf.squeeze(point_cloud)
    if og_batch_size == 1:
      point_cloud = tf.expand_dims(point_cloud, 0)

  point_cloud_transpose = tf.transpose(point_cloud, perm=[0, 2, 1])
  # -2 * a.b term, (batch_size, num_points, num_points)
  point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose)
  point_cloud_inner = -2*point_cloud_inner
  # ||a||^2 as a column and ||b||^2 as a row; broadcasting builds the full matrix.
  point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keepdims=True)
  point_cloud_square_tranpose = tf.transpose(point_cloud_square, perm=[0, 2, 1])
  return point_cloud_square + point_cloud_inner + point_cloud_square_tranpose

def knn(adj_matrix, k=16):
  """Select the k nearest neighbors from a pairwise distance matrix.

  Args:
    adj_matrix: pairwise distance, (batch_size, num_points, num_points)
    k: int, number of neighbors to keep

  Returns:
    nearest neighbors: indices, (batch_size, num_points, k)
  """
  # top_k returns the largest entries, so negate to get the smallest distances.
  nearest = tf.nn.top_k(tf.negative(adj_matrix), k=k)
  return nearest.indices

def index_points(points, idx):
    """
    Gather points by per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    # batch_dims=1 makes tf.gather pair each batch entry of `idx` with the same
    # batch entry of `points`, so no explicit batch-index tensor is needed.
    return tf.gather(points, idx, axis=1, batch_dims=1)

def square_distance(src, dst):
    """
    Squared Euclidean distance between every pair of points of two sets,
    computed as ||s||^2 - 2 s.d + ||d||^2.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: squared distances, [B, N, M]
    """
    src_norm = tf.reduce_sum(src**2, axis=-1, keepdims=True)      # [B, N, 1]
    dst_norm = tf.reduce_sum(dst**2, axis=-1, keepdims=True)      # [B, M, 1]
    cross = tf.matmul(src, tf.transpose(dst, perm=[0, 2, 1]))     # [B, N, M]

    dist = src_norm - 2 * cross + tf.transpose(dst_norm, perm=[0, 2, 1])
    # Clamp: rounding can push near-zero distances slightly below zero.
    return tf.maximum(dist, 0)

def query_ball_point(radius, cardinality, xyz, new_xyz):
    """
    Ball query: for every query point, collect up to `cardinality` indices of
    points lying within `radius`, in ascending index order.

    Input:
        radius: local region radius
        cardinality: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, cardinality]
    """
    N = xyz.shape[1]
    if N is None:
        # Number of points not known statically; read it at run time.
        N = tf.shape(xyz)[1]

    # Candidate indices 0..N-1, shaped [1, 1, N]; tf.where broadcasts them over
    # batch and query axes, so no tiled [B, S, N] copy is needed.
    candidate_idx = tf.range(N, dtype=tf.int32)[tf.newaxis, tf.newaxis, :]

    # Compute squared distances between new_xyz and xyz
    sqrdists = square_distance(new_xyz, xyz)  # [B, S, N]

    # Points outside the radius get the sentinel index N, which sorts last.
    group_idx = tf.where(sqrdists > radius**2, N, candidate_idx)  # [B, S, N]

    # Sort and select the top cardinality points
    group_idx = tf.sort(group_idx, axis=-1)[:, :, :cardinality]

    # Where fewer than `cardinality` points were in range, pad the remaining
    # slots (still holding the sentinel) with the first index of the group.
    group_first = group_idx[:, :, 0:1]  # [B, S, 1], broadcast by tf.where
    group_idx = tf.where(tf.equal(group_idx, N), group_first, group_idx)

    return group_idx

def sample_and_group(xyz, cardinality, ball_query=False, radius=0.2, use_knn=True, k=32):
    '''
    Sample centroids with farthest point sampling and group their neighbors.

    Input:
        xyz: (batch_size, ndataset, C) TF tensor (original points)
        cardinality (npoint): int32, number of centroids to sample
        ball_query: bool, if True group by radius search (takes precedence over use_knn)
        radius: float32, search radius used when ball_query is True
        use_knn: bool, if True (and ball_query is False) group by k nearest neighbors
        k (nsample): int32, number of neighbors per centroid
    Output:
        new_xyz: (batch_size, cardinality, C) TF tensor (sampled points)
        new_points: (batch_size, cardinality, k, C) TF tensor (grouped points relative to sampled points)
        idx: (batch_size, cardinality, k) TF tensor, indices of local points as in ndataset points
        grouped_xyz: (batch_size, cardinality, k, C) TF tensor, normalized point XYZs
            (subtracted by seed point XYZ) in local regions
    Raises:
        ValueError: if both ball_query and use_knn are False
    '''
    # Fail early with a clear message instead of hitting an unbound `idx` below.
    if not (ball_query or use_knn):
        raise ValueError("sample_and_group requires ball_query=True or use_knn=True")

    # farthest point sampling to get centroids
    fps_idx = farthest_point_sampling(cardinality, xyz) # [B, cardinality]

    # gather points corresponding to centroids
    new_xyz = tf.gather(xyz, fps_idx, batch_dims=1) # [B, cardinality, C]

    if ball_query:
        idx = query_ball_point(radius, k, xyz, new_xyz)
    else:
        # squared distance between sampled points and original points
        sqrdists = square_distance(new_xyz, xyz) # [B, cardinality, num_points]

        # top_k picks the largest values, so negate to get the k nearest neighbors
        _, idx = tf.nn.top_k(-sqrdists, k=k) # [B, cardinality, k]

    # Gather the neighbors per batch element. batch_dims=1 pairs each batch entry
    # of `idx` with the matching entry of `xyz`, so no batch-index tensor is needed.
    grouped_xyz = tf.gather(xyz, idx, axis=1, batch_dims=1) # [B, cardinality, k, C]

    # Normalization (subtract centroid); the new axis broadcasts over k
    grouped_xyz -= new_xyz[:, :, tf.newaxis, :] # [B, cardinality, k, C]

    new_points = grouped_xyz # For this layer, new_points are the grouped and normalized xyz

    return new_xyz, new_points, idx, grouped_xyz

## Network

In [None]:
class SG(tf.keras.layers.Layer):
  '''
  Sampling-and-grouping layer.

  Samples `s` centroids, groups the 32 nearest neighbors of each one, runs two
  Conv1D(kernel_size=1) -> BatchNorm -> ReLU stages over every neighbor and
  max-pools across the neighbors.

  Tensors are channels-last.

  Input: coords (B, N, C)
  Output: new_xyz (B, s, C), new_feature (B, s, out_channels)
  '''
  def __init__(self, s, in_channels, out_channels, **kwargs):
    super(SG, self).__init__(**kwargs)

    self.s = s
    # Stored for reference; Conv1D infers its input width when it is built.
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.conv1 = tf.keras.layers.Conv1D(out_channels, kernel_size=1, use_bias=False)
    self.conv2 = tf.keras.layers.Conv1D(out_channels, kernel_size=1, use_bias=False)
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.bn2 = tf.keras.layers.BatchNormalization()
    self.relu1 = tf.keras.layers.Activation("relu")
    self.relu2 = tf.keras.layers.Activation("relu")

  def call(self, coords):
    k = 32  # neighbors per centroid
    new_xyz, grouped, _, _ = sample_and_group(xyz=coords, cardinality=self.s, use_knn=True, k=k)
    d = grouped.shape[-1]

    # (B, s, k, d) -> (B*s, k, d). Keras Conv1D is channels-last, so the feature
    # width d must be the LAST axis; feeding (B*s, d, k) would make the
    # convolution treat the k neighbors as channels.
    grouped = tf.reshape(grouped, [-1, k, d])

    grouped = self.relu1(self.bn1(self.conv1(grouped)))  # (B*s, k, out_channels)
    grouped = self.relu2(self.bn2(self.conv2(grouped)))  # (B*s, k, out_channels)

    # Max-pool over the neighbor axis, keeping one out_channels vector per centroid.
    pooled = tf.math.reduce_max(grouped, axis=1)  # (B*s, out_channels)

    # -1 for the batch axis keeps this valid when the batch size is not static.
    new_feature = tf.reshape(pooled, [-1, self.s, self.out_channels])  # (B, s, out_channels)

    return new_xyz, new_feature

In [None]:
class NeighborEmbedding(tf.keras.layers.Layer):
  '''
  Neighbor embedding: a point-wise embedding (two Conv1D(64, kernel_size=1) ->
  BatchNorm -> ReLU stages) followed by two sampling-and-grouping (SG) layers.

  Args:
    samples: pair of ints, the number of centroids sampled by the first and
      second SG layer.
  '''
  def __init__(self, samples=(512, 256), **kwargs):
    # The default is an immutable tuple, so it cannot be shared and mutated across instances.
    super(NeighborEmbedding, self).__init__(**kwargs)

    self.conv1 = tf.keras.layers.Conv1D(64, kernel_size=1, use_bias=False)
    self.conv2 = tf.keras.layers.Conv1D(64, kernel_size=1, use_bias=False)
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.bn2 = tf.keras.layers.BatchNormalization()
    self.relu1 = tf.keras.layers.Activation("relu")
    self.relu2 = tf.keras.layers.Activation("relu")

    self.sg1 = SG(s=samples[0], in_channels=128, out_channels=128)
    self.sg2 = SG(s=samples[1], in_channels=256, out_channels=256)

  def call(self, x):
    '''
    Input: [B, N, 3]
    Output: the feature tensor returned by the second SG layer.
      NOTE(review): originally annotated as [B, 256, 256]; the exact layout is
      whatever SG.call returns -- confirm against SG.
    '''
    features = self.relu1(self.bn1(self.conv1(x)))
    features = self.relu2(self.bn2(self.conv2(features)))  # [B, N, 64]

    # Each SG layer receives the previous feature tensor as the points to sample
    # and group; the sampled points it also returns are not used afterwards.
    _, new_features = self.sg1(features)
    _, new_features2 = self.sg2(new_features)

    return new_features2

In [None]:
class OA(tf.keras.layers.Layer):
  '''
  Offset-attention block: attention output is subtracted from the input, the
  offset goes through Conv1D -> BatchNorm -> ReLU and is added back to the input.

  Input: (B, N, channels)
  Output: (B, N, channels)
  '''
  def __init__(self, channels, **kwargs):
    super().__init__(**kwargs)

    # Queries and keys use a quarter of the channel width.
    reduced = channels // 4
    self.q_conv = tf.keras.layers.Conv1D(reduced, kernel_size=1, use_bias=False)
    self.k_conv = tf.keras.layers.Conv1D(reduced, kernel_size=1, use_bias=False)
    self.v_conv = tf.keras.layers.Conv1D(channels, kernel_size=1)

    self.trans_conv = tf.keras.layers.Conv1D(channels, kernel_size=1)
    self.after_norm = tf.keras.layers.BatchNormalization()

    self.act = tf.keras.layers.Activation("relu")
    self.softmax = tf.keras.layers.Softmax(axis=-1)

  def call(self, x):
    queries_t = tf.transpose(self.q_conv(x), perm=[0,2,1])  # (B, channels//4, N)
    keys = self.k_conv(x)                                    # (B, N, channels//4)
    values = self.v_conv(x)                                  # (B, N, channels)

    # Softmax over the last axis, then divide by the sums taken over axis 1.
    attention = self.softmax(tf.matmul(keys, queries_t))     # (B, N, N)
    axis1_sums = tf.reduce_sum(attention, axis=1, keepdims=True)
    attention = attention / (1e-9 + axis1_sums)

    # values^T @ attention -> (B, channels, N), then back to channels-last.
    aggregated = tf.matmul(values, attention, transpose_a=True)
    aggregated = tf.transpose(aggregated, perm=[0,2,1])

    # Transform the offset between the input and the attention features.
    offset = x - aggregated
    refined = self.act(self.after_norm(self.trans_conv(offset)))

    return x + refined

In [None]:
class PCT(tf.keras.layers.Layer):
  '''
  Full PCT encoder: neighbor embedding, four stacked offset-attention blocks,
  a 1024-wide point-wise projection and global max/mean pooling.

  Tensors are channels-last, i.e. (B, points, channels).

  Args:
    samples: pair of ints forwarded to NeighborEmbedding.

  Input: (B, N, 3)
  Output: (B, 2048) -- concat of the max-pooled and mean-pooled 1024-d features
  '''
  def __init__(self, samples=(512, 256), **kwargs):
    # The default is an immutable tuple, so it cannot be shared and mutated across instances.
    super(PCT, self).__init__(**kwargs)

    self.neighbor_embedding = NeighborEmbedding(samples)

    self.oa1 = OA(256)
    self.oa2 = OA(256)
    self.oa3 = OA(256)
    self.oa4 = OA(256)

    # Dense acts on the last axis, so on a rank-3 input it is a point-wise projection.
    self.conv1d = tf.keras.layers.Dense(1024)
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.LeakyReLU(negative_slope=0.2)

  def call(self, x):
    x = self.neighbor_embedding(x)

    x1 = self.oa1(x)
    x2 = self.oa2(x1)
    x3 = self.oa3(x2)
    x4 = self.oa4(x3)

    # Concatenate along the CHANNEL axis. With the channels-last layout used
    # here that is the last axis; axis=1 would stack the tensors as extra points.
    x = tf.concat([x, x1, x2, x3, x4], axis=-1)

    x = self.conv1d(x)  # (B, points, 1024)
    x = self.bn(x)
    x = self.act(x)

    # Global pooling over the POINT axis gives a fixed-size descriptor.
    x_max = tf.math.reduce_max(x, axis=1)    # (B, 1024)
    x_mean = tf.math.reduce_mean(x, axis=1)  # (B, 1024)

    return tf.concat([x_max, x_mean], axis=-1)  # (B, 2048)

In [None]:
def PCT_Classification(out_classes, batch_size=128):
  """Build the full PCT Keras model.

  A PCT encoder is followed by two Dense -> BatchNorm -> ReLU -> Dropout(0.5)
  stages (512 then 256 units) and a final linear Dense layer.

  Args:
    out_classes: number of units of the final Dense layer.
    batch_size: fixed batch size baked into the Input layer.

  Returns:
    A tf.keras.Model named 'PCT_Classification' taking (1024, 3) clouds.
  """
  def dense_stage(tensor, units):
    # One head stage: projection, normalisation, non-linearity, regularisation.
    tensor = tf.keras.layers.Dense(units)(tensor)
    tensor = tf.keras.layers.BatchNormalization()(tensor)
    tensor = tf.keras.layers.Activation("relu")(tensor)
    return tf.keras.layers.Dropout(0.5)(tensor)

  inputs = tf.keras.layers.Input(shape=(1024,3), batch_size=batch_size)
  features = PCT()(inputs)

  features = dense_stage(features, 512)
  features = dense_stage(features, 256)

  # No activation on the output layer.
  outputs = tf.keras.layers.Dense(out_classes)(features)

  return tf.keras.Model(inputs=inputs, outputs=outputs, name='PCT_Classification')

In [None]:
def sqrd_euclidean_distance_loss(y_true, y_pred):
    """
    Squared Euclidean distance loss: sum of squared element-wise errors.
    https://en.wikipedia.org/wiki/Euclidean_distance
    :param y_true: target tensor
    :param y_pred: predicted tensor of the same shape as y_true
    :return: tensor with the last axis reduced by summation
    """
    squared_error = K.square(y_pred - y_true)
    loss = K.sum(squared_error, axis=-1)
    return loss

# Build the full PCT network with a 6-unit output layer.
model = PCT_Classification(out_classes=6)

# Compile with Adam; the squared Euclidean distance is both the training
# objective and the reported metric.
model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
  loss=sqrd_euclidean_distance_loss,
  metrics=[sqrd_euclidean_distance_loss],
)

# Print the layer-by-layer overview.
model.summary()