#PyTorch

##utils

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [447]:
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm；
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].unsqueeze(1)#.view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint:
        radius:
        nsample:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
    new_xyz = index_points(xyz, fps_idx)
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points

def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points

In [448]:
class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

##Net

In [449]:
class get_model(nn.Module):
    def __init__(self,num_class,normal_channel=False):
        super(get_model, self).__init__()
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.linear(x)


        return x, l3_points

In [450]:
net = get_model(num_class=6)

#Tensorflow

In [None]:
!pip install tensorflow==2.10

In [26]:
import tensorflow as tf
import keras.backend as K
import numpy as np
tf.__version__

'2.10.0'

In [4]:
def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    B, _, _ = points.shape
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = tf.range(B, dtype=tf.int32)
    batch_indices = tf.reshape(batch_indices, shape=view_shape)  # reshape to view_shape
    batch_indices = tf.tile(batch_indices, multiples=repeat_shape)  # repeat according to repeat_shape
    new_points = tf.gather(points, idx, axis=1, batch_dims=1)
    return new_points

def farthest_point_sample(npoint, xyz):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    B, N, C = xyz.shape
    centroids = tf.TensorArray(dtype=tf.int32, size=npoint, dynamic_size=False)
    distance = tf.ones((B, N), dtype=tf.float32) * 1e10
    farthest = tf.random.uniform((B,), minval=0, maxval=N, dtype=tf.int32)
    batch_indices = tf.range(B, dtype=tf.int32)

    def loop_body(i, centroids, distance, farthest):
        centroids = centroids.write(i, farthest)
        centroid = tf.gather(xyz, farthest, batch_dims=1)[:, tf.newaxis, :]  # [B, 1, 3]
        dist = tf.reduce_sum((xyz - centroid) ** 2, axis=-1)  # [B, N]
        mask = dist < distance
        distance = tf.where(mask, dist, distance)
        farthest = tf.argmax(distance, axis=-1, output_type=tf.int32)
        return i + 1, centroids, distance, farthest

    _, centroids, _, _ = tf.while_loop(
        lambda i, *args: i < npoint,
        loop_body,
        [0, centroids, distance, farthest]
    )

    return tf.transpose(centroids.stack(), perm=[1, 0])  # [B, npoint]

def square_distance(src, dst):
    """
    Compute squared distances between two point sets.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: squared distances, [B, N, M]
    """
    dist = tf.reduce_sum(src**2, axis=-1, keepdims=True) - \
           2 * tf.matmul(src, tf.transpose(dst, perm=[0, 2, 1])) + \
           tf.transpose(tf.reduce_sum(dst**2, axis=-1, keepdims=True), perm=[0, 2, 1])
    return tf.maximum(dist, 0)  # Avoid small negative values due to floating-point errors


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    B, N, C = xyz.shape
    S = tf.shape(new_xyz)[1]

    # Generate group indices [1, 1, N] -> [B, S, N]
    group_idx = tf.tile(tf.range(N, dtype=tf.int32)[tf.newaxis, tf.newaxis, :], [B, S, 1])

    # Compute squared distances between new_xyz and xyz
    sqrdists = square_distance(new_xyz, xyz)

    # Mask out points outside the radius
    group_idx = tf.where(sqrdists > radius**2, N, group_idx)

    # Sort and select the top nsample points
    group_idx = tf.sort(group_idx, axis=-1)[:, :, :nsample]

    # Handle cases where there are fewer than nsample points within the radius
    group_first = tf.tile(group_idx[:, :, 0:1], [1, 1, nsample])  # Repeat the first valid index
    mask = tf.equal(group_idx, N)
    group_idx = tf.where(mask, group_first, group_idx)

    return group_idx

def sample_and_group(npoint, radius, nsample, xyz, points, use_xyz=True):
    '''
    Input:
        npoint: int32
        radius: float32
        nsample: int32
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor, if None will just use xyz as points
        knn: bool, if True use kNN instead of radius search
        use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    Output:
        new_xyz: (batch_size, npoint, 3) TF tensor
        new_points: (batch_size, npoint, nsample, 3+channel) TF tensor
        idx: (batch_size, npoint, nsample) TF tensor, indices of local points as in ndataset points
        grouped_xyz: (batch_size, npoint, nsample, 3) TF tensor, normalized point XYZs
            (subtracted by seed point XYZ) in local regions
    '''

    # Farthest point sampling
    fps_idx = farthest_point_sample(npoint, xyz)
    new_xyz = index_points(xyz, fps_idx) # (batch_size, npoint, 3)

    # Ball query
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx) # (batch_size, npoint, nsample, 3)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization

    # Concatenate point features if provided
    if points is not None:
        grouped_points = index_points(points, idx) # (batch_size, npoint, nsample, channel)
        if use_xyz:
            new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nample, 3+channel)
        else:
            new_points = grouped_points
    else:
      new_points = grouped_xyz

    return new_xyz, new_points, idx, grouped_xyz

def sample_and_group_all(xyz, points, use_xyz=True):
    '''
    Inputs:
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor, if None will just use xyz as points
        use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    Outputs:
        new_xyz: (batch_size, 1, 3) as (0,0,0)
        new_points: (batch_size, 1, ndataset, 3+channel) TF tensor
    Note:
        Equivalent to sample_and_group with npoint=1, radius=inf, use (0,0,0) as the centroid
    '''
    batch_size, nsample, _ = xyz.shape
    new_xyz = tf.constant(np.tile(np.array([0,0,0]).reshape((1,1,3)), (batch_size,1,1)),dtype=tf.float32) # (batch_size, 1, 3)
    idx = tf.constant(np.tile(np.array(range(nsample)).reshape((1,1,nsample)), (batch_size,1,1)))
    grouped_xyz = tf.reshape(xyz, (batch_size, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3)
    if points is not None:
        if use_xyz:
            new_points = tf.concat([xyz, points], axis=2) # (batch_size, 16, 259)
        else:
            new_points = points
        new_points = tf.expand_dims(new_points, 1) # (batch_size, 1, 16, 259)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, idx, grouped_xyz

In [5]:
def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, group_all, pooling='max', use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius: float32 -- search radius in local region
            nsample: int32 -- how many points in each local region
            mlp: list of int32 -- output size for MLP on each point
            mlp2: list of int32 -- output size for MLP on each region
            group_all: bool -- group all points into one PC if set true, OVERRIDE
                npoint, radius and nsample settings
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- indices for local regions
    '''

    # Sample and Grouping
    if group_all:
        _, nsample, _ = tf.shape(xyz)
        new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
    else:
        new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, use_xyz)

    # Point Feature Embedding
    if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
    for i, num_out_channel in enumerate(mlp):
       new_points = tf.keras.layers.Conv2D(num_out_channel, [1,1], padding='valid', activation='relu')(new_points)

    if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

    # Pooling in Local Regions
    if pooling=='max':
        new_points = tf.reduce_max(new_points, axis=[2], keepdims=True, name='maxpool')
    '''
    elif pooling=='avg':
        new_points = tf.reduce_mean(new_points, axis=[2], keepdims=True, name='avgpool')
    elif pooling=='weighted_avg':
        #with tf.variable_scope('weighted_avg'):
        dists = tf.norm(grouped_xyz,axis=-1,ord=2,keepdims=True)
        exp_dists = tf.exp(-dists * 5)
        weights = exp_dists/tf.reduce_sum(exp_dists,axis=2,keepdims=True) # (batch_size, npoint, nsample, 1)
        new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
        new_points = tf.reduce_sum(new_points, axis=2, keepdims=True)
    elif pooling=='max_and_avg':
        max_points = tf.reduce_max(new_points, axis=[2], keepdims=True, name='maxpool')
        avg_points = tf.reduce_mean(new_points, axis=[2], keepdims=True, name='avgpool')
        new_points = tf.concat([avg_points, max_points], axis=-1)
    '''
    new_points = tf.squeeze(new_points, [2]) # (batch_size, npoints, mlp2[-1])
    return new_xyz, new_points, idx

In [6]:
class SampleAndGroup(tf.keras.layers.Layer):
    def __init__(self, npoint, radius, nsample, use_xyz=True, **kwargs):
      super(SampleAndGroup, self).__init__(**kwargs)
      self.npoint = npoint
      self.radius = radius
      self.nsample = nsample
      self.use_xyz = use_xyz

    def call(self, xyz, points):
      return sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, self.use_xyz)

class SampleAndGroupAll(tf.keras.layers.Layer):
    def __init__(self, use_xyz=True, **kwargs):
      super(SampleAndGroupAll, self).__init__(**kwargs)
      self.use_xyz = use_xyz

    def call(self, xyz, points):
      return sample_and_group_all(xyz, points, self.use_xyz)

In [19]:
class PointNetSAModule(tf.keras.layers.Layer):
    def __init__(self, npoint, radius, nsample, mlp, group_all, pooling='max', use_xyz=True, use_nchw=False, **kwargs):
        super(PointNetSAModule, self).__init__(**kwargs)
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp = mlp
        self.group_all = group_all
        self.pooling = pooling
        self.use_xyz = use_xyz
        self.use_nchw = use_nchw

        # Create Conv2D layers for the MLP in advance
        self.conv_layers = [
            tf.keras.layers.Conv2D(num_out_channel, [1, 1], padding='valid', activation='relu')
            for num_out_channel in self.mlp
        ]

    def call(self, xyz, points):
        # Sample and Grouping
        _, nsample, _ = xyz.shape
        if self.group_all:
            new_xyz, new_points, idx, grouped_xyz = SampleAndGroupAll(self.use_xyz)(xyz, points)
        else:
            new_xyz, new_points, idx, grouped_xyz = SampleAndGroup(self.npoint, self.radius, nsample, self.use_xyz)(xyz, points)

        # Point Feature Embedding
        if self.use_nchw:
            new_points = tf.transpose(new_points, [0, 3, 1, 2])

        for conv in self.conv_layers:  # Apply the pre-created Conv2D layers
            new_points = conv(new_points)

        if self.use_nchw:
            new_points = tf.transpose(new_points, [0, 2, 3, 1])

        # Pooling in Local Regions
        if self.pooling == 'max':
            new_points = tf.reduce_max(new_points, axis=[2], keepdims=True, name='maxpool')

        new_points = tf.squeeze(new_points, [2])  # (batch_size, npoints, mlp[-1])
        return new_xyz, new_points, idx

    def get_config(self):
        config = super(PointNetSAModule, self).get_config()
        config.update({
            'npoint': self.npoint,  # Assuming these are attributes of your PointNetSAModule
            'radius': self.radius,
            'nsample': self.nsample,
            'mlp': self.mlp,
            'group_all': self.group_all,
            'pooling': self.pooling,
            'use_xyz': self.use_xyz,
            'use_nchw': self.use_nchw,
        })
        return config

In [20]:
def PointNet2(n_outs, batch_size):
    """ Classification PointNet, input is BxNx3, output BxN_classes """
    inputs = tf.keras.Input(shape=(1024, 3), batch_size=batch_size)

    l0_xyz = inputs
    l0_points = None

    # Set abstraction layers
    l1_xyz, l1_points, l1_indices = PointNetSAModule(npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], group_all=False, use_nchw=True)(l0_xyz, l0_points)
    l2_xyz, l2_points, l2_indices = PointNetSAModule(npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], group_all=False)(l1_xyz, l1_points)
    l3_xyz, l3_points, l3_indices = PointNetSAModule(npoint=None, radius=None, nsample=None, mlp=[256,512,1024], group_all=True)(l2_xyz, l2_points)

    # Fully connected layers
    net = tf.reshape(l3_points, [batch_size, -1])

    dense1 = tf.keras.layers.Dense(512)(net)
    dense1 = tf.keras.layers.BatchNormalization(momentum=0.0)(dense1)
    dense1 = tf.keras.layers.Activation("relu")(dense1)

    drop = tf.keras.layers.Dropout(0.3)(dense1)

    dense2 = tf.keras.layers.Dense(256)(drop)
    dense2 = tf.keras.layers.BatchNormalization(momentum=0.0)(dense2)
    dense2 = tf.keras.layers.Activation("relu")(dense2)

    drop = tf.keras.layers.Dropout(0.3)(dense2)

    outputs = tf.keras.layers.Dense(n_outs, activation=None)(drop)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name="PointNet2")

In [21]:
def sqrd_euclidean_distance_loss(y_true, y_pred):
    """
    Squared Euclidean distance loss
    https://en.wikipedia.org/wiki/Euclidean_distance
    :param y_true: TensorFlow/Theano tensor
    :param y_pred: TensorFlow/Theano tensor of the same shape as y_true
    :return: float
    """
    return K.sum(K.square(y_pred - y_true), axis=-1)

model = PointNet2(6, 64)

# Compiling
model.compile(
  loss=sqrd_euclidean_distance_loss,
  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
  metrics=[sqrd_euclidean_distance_loss],
)

model.summary()

Model: "PointNet2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(8, 1024, 3)]       0           []                               
                                                                                                  
 point_net_sa_module_3 (PointNe  ((8, 512, 3),       78080       ['input_2[0][0]']                
 tSAModule)                      (8, 512, 3),                                                     
                                 (8, 512, 1024))                                                  
                                                                                                  
 point_net_sa_module_4 (PointNe  ((8, 128, 3),       50432       ['point_net_sa_module_3[0][0]',  
 tSAModule)                      (8, 128, 256),                   'point_net_sa_module_3[0

#Training example

'points.npy' is a np.array with size (B, 1024, 3)
  containing the point clouds

'targets.npy' is a np.array with size (B, 1, 6)
  containing the target rotations

In [22]:
### Loading example files ###
array_file = open('points.npy', 'rb')
points = np.load(array_file, allow_pickle=True)
points = points.astype(np.float32)

array_file = open('targets.npy', 'rb')
targets = np.load(array_file, allow_pickle=True)
targets = targets.astype(np.float32)
targets = [np.concatenate((np.array(T[:3,:3][:,0]), np.array(T[:3,:3][:,1]))) for T in targets]

In [23]:
from sklearn.model_selection import train_test_split

# BATCH SIZE
batch_size = 64


### Splitting data ###
train_points, valid_points, train_targets, valid_targets = train_test_split(points, targets, train_size=0.9, shuffle=True)

def augment(points, label):
    # jitter points
    points += tf.random.uniform(points.shape, -0.00005, 0.00005, dtype=tf.float32)
    # shuffle points
    points = tf.random.shuffle(points)
    return points, label


train_dataset = tf.data.Dataset.from_tensor_slices((train_points, train_targets))
valid_dataset = tf.data.Dataset.from_tensor_slices((valid_points, valid_targets))

train_dataset = train_dataset.shuffle(len(train_points)).map(augment).batch(batch_size, drop_remainder=True)
valid_dataset = valid_dataset.shuffle(len(valid_points)).map(augment).batch(batch_size, drop_remainder=True)

print('Dataset splitted!')
print(f'train: {len(train_dataset)}\nvalidation: {len(valid_dataset)}')

points = 0
targets = 0
train_points = 0
valid_points = 0
train_targets = 0
valid_targets = 0

checkpoint = tf.keras.callbacks.ModelCheckpoint(f'best_model.hdf5', monitor='val_loss', verbose=1,
    save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')

Dataset splitted!
train: 11
validation: 1


In [None]:
model.fit(train_dataset, epochs=1500, validation_data=valid_dataset, callbacks=[checkpoint], verbose=0)