# Setup and imports

In [2]:
# Colab only: mount Google Drive at /content/drive so files stored there are reachable from the notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
cd /content/drive/MyDrive/PointNet

/content/drive/MyDrive/PointNet


In [4]:
!pip install plyfile
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting plyfile
  Downloading plyfile-0.7.4-py3-none-any.whl (39 kB)
Installing collected packages: plyfile
Successfully installed plyfile-0.7.4
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [5]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from plyfile import PlyData
import numpy as np
import einops

BATCHSIZE = 32                    # samples per training batch
NPOINTS = 2500                    # points sampled per shape
EPOCH = 5                         # training epochs
WORKERS = 2                       # DataLoader worker processes
DATAPATH = "./data/ModelNet40"    # dataset root holding {mode}.txt split lists
SEED = 42                         # global seed for reproducible sampling/augmentation/init
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every RNG the notebook draws from (the dataset samples/augments with np.random).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Convert ModelNet40 meshes from .off to .ply

In [None]:
import os
import glob
import multiprocessing
from multiprocessing import Pool
from pathlib import Path
from plyfile import PlyData
from tqdm import tqdm, trange
import time
import pdb
import gc

def preprocess_off(in_file):
    """
    Repair, in place, an OFF header whose counts were fused onto the "OFF" keyword.

    Some OFF files in the original dataset start with e.g. "OFF345 345 344" instead of
    an "OFF" line followed by a separate count line; such a first line is split into
    "OFF\n" + "345 345 344\n". Files whose first line is already just "OFF" (plus
    optional whitespace), files that do not start with "OFF" at all (e.g. already
    converted to PLY), and empty files are left untouched, so re-running is safe.

    Args:
        in_file: path of the file to fix.
    """
    with open(in_file, "rt") as f:
        lines = f.readlines()
    if not lines:
        return
    first = lines[0]
    # Only rewrite when something other than whitespace follows the keyword.
    if first.startswith('OFF') and first[3:].strip():
        lines[0] = 'OFF\n' + first[3:]
        with open(in_file, "wt") as f:
            f.write("".join(lines))

def offFormat_to_plyFormat(C_file):
    """
    Rewrite an OFF mesh file in place as ASCII PLY (the file name is not changed).

    Expects line 0 to be the "OFF" keyword and line 1 to start with
    "<n_vertices> <n_faces>". Both lines are replaced by a PLY header declaring
    float x/y/z vertices and a `list uchar int vertex_index` face property; all
    remaining lines are kept verbatim. Files already starting with "ply" are
    left untouched, so re-running is safe.

    Args:
        C_file: path of the file to convert.

    Raises:
        ValueError: if the count line is missing or has fewer than two fields;
            the file is left unmodified in that case.
    """
    with open(C_file, "rt") as Cf:
        lines = Cf.readlines()
    if lines and lines[0].startswith('ply'):
        return  # already converted
    # Parse the counts BEFORE opening for writing: "wt" truncates the file, so a
    # failure after that point would leave an empty file behind.
    counts = lines[1].split() if len(lines) > 1 else []
    if len(counts) < 2:
        raise ValueError('%s: missing OFF vertex/face count line' % C_file)
    num_points, num_faces = counts[0], counts[1]
    lines[0] = 'ply\n'
    lines[1] = ('format ascii 1.0\n'
                'element vertex %s\n' % num_points +
                'property float x\n'
                'property float y\n'
                'property float z\n'
                'element face %s\n' % num_faces +
                'property list uchar int vertex_index\n'
                'end_header\n')
    with open(C_file, "wt") as Cf:
        Cf.write("".join(lines))


def _convert_off_file(path):
    """Fix a fused OFF header, then rewrite the same file as ASCII PLY."""
    preprocess_off(path)
    offFormat_to_plyFormat(path)


def full_off_to_ply(p, l, chunksize, num_cpu=None):
    """
    Convert every OFF file in `l` to ASCII PLY in place (file names are not changed here).

    Each file is independent, so header repair and conversion run back-to-back per
    file in a single pass (one worker pool instead of two).

    Args:
        p: root path; only used in the progress message.
        l: iterable of .off file paths to convert.
        chunksize: number of files handed to a worker per task.
        num_cpu: number of worker processes; None converts serially in this process.
    """
    print('searching path', p, 'to convert .off files to .ply')
    start = time.time()
    if num_cpu is not None:
        with Pool(processes=num_cpu) as pool:
            # Drain the iterator so worker exceptions surface here.
            for _ in pool.imap_unordered(_convert_off_file, l, chunksize=chunksize):
                pass
        end = time.time()
        print("Muti core full_off_to_ply use %3f sec!" %(end - start))
    else:
        for f in l:
            _convert_off_file(f)
        end = time.time()
        print("Single core full_off_to_ply use %3f sec!" %(end - start))




def process_txt(root="."):
    """
    Replace ".off" with ".ply" inside every *.txt file directly under `root`.

    Intended for the split lists (test.txt, train.txt, trainval.txt, val.txt) so
    their entries point at the converted files. Files containing no ".off" are
    not rewritten.

    Args:
        root: directory to scan (non-recursive). Defaults to the current working
            directory, matching the previous behavior.
    """
    for txtName in glob.glob(os.path.join(root, "*.txt")):
        with open(txtName, "rt") as file:
            text = file.read()
        converted = text.replace(".off", ".ply")
        if converted != text:
            with open(txtName, "wt") as file:
                file.write(converted)

def suffix(l):
    """
    Rename each file in `l` on disk to the same name with a ".ply" extension.

    Args:
        l: iterable of pathlib.Path objects (typically the *.off files).
    """
    for path in l:
        path.rename(path.with_suffix('.ply'))


def sub_check(f):
    """
    Parse one .ply file with PlyData.read and discard the result.

    Any parse error raised by PlyData.read propagates to the caller, which is how
    `check` detects broken files.
    """
    # The parsed object is dropped right away. A full gc.collect() per file walked
    # the whole heap thousands of times without being needed for that release.
    PlyData.read(f)
    

def check(p, chunksize, num_cpu=None):
    """
    Verify that every .ply file under `p` (recursively) parses.

    A broken file makes `sub_check` raise; the exception propagates out of this
    function (from the worker when a pool is used).

    Args:
        p: pathlib.Path root to search.
        chunksize: number of files handed to a worker per task.
        num_cpu: number of worker processes; None checks serially in this process.
    """
    ply = list(p.glob('**/*.ply'))
    start = time.time()
    # Context manager closes the progress bar even if a check raises.
    with tqdm(total=len(ply)) as progress:
        if num_cpu is not None:
            with Pool(processes=num_cpu) as po:
                for _ in po.imap_unordered(sub_check, ply, chunksize=chunksize):
                    progress.update(1)
        else:
            for f in ply:
                sub_check(f)
                progress.update(1)
    end = time.time()
    if num_cpu is not None:
        print("Muti core check use %3f sec!" %(end - start))
    else:
        print("Single core check use %3f sec!" %(end - start))
# Driver: convert every .off under the working directory to PLY, update the *.txt
# split lists there, rename the files to .ply, then verify each .ply parses.
# os.cpu_count() returns None when the CPU count cannot be determined, which is what
# the single-process fallback below tests for (multiprocessing.cpu_count() raises instead).
num_cpu = os.cpu_count()
chunksize = 100
currPATH = os.getcwd().replace('\\','/')
p = Path(currPATH)
l = list(p.glob('**/*.off'))
if isinstance(num_cpu, int):
    print("Now use %d Threads!" %num_cpu)
else:
    num_cpu = None
    print("Now use single Thread!")
full_off_to_ply(p, l, chunksize, num_cpu)
process_txt()
suffix(l)
check(p, chunksize, num_cpu)
print('ur great!')
print('All .off format has converted!')

In [None]:
# Run the external training script train_classification.py (not defined in this notebook) on ModelNet40;
# --nepoch=1 presumably limits it to a single epoch — confirm against the script's argument parser.
!python train_classification.py --dataset ./data/ModelNet40 --nepoch=1 --dataset_type modelnet40

# Dataset

In [None]:
class ModelNetDataset(Dataset):
  """
  ModelNet point-cloud classification dataset backed by .ply files.

  Each item is (points, category):
    points   : float32 tensor (n_points, 3) - vertices resampled with replacement,
               standardized per axis, optionally augmented.
    category : int64 tensor of shape (1,) holding the class id.

  Args:
    dataPath     : dataset root containing `{mode}.txt` and the class folders.
    n_points     : number of points sampled per shape.
    mode         : split name; selects `{mode}.txt` under dataPath.
    augmentation : if True, randomly rotate the (x, z) coordinates and add Gaussian jitter.
  """
  def __init__(self, dataPath, n_points=2500, mode='train', augmentation=True):
    # Category table: one "<name> <id>" pair per line; blank lines are skipped.
    with open('./misc/modelnet_id.txt') as f:
      ids = [line.strip().split() for line in f if line.strip()]
    category_to_id = {cat: int(i) for cat, i in ids}
    # Split list: one relative path per line, e.g. glass_box/train/glass_box_0090.ply.
    with open(os.path.join(dataPath, f'{mode}.txt'), 'r') as f:
      fileNames = [line.strip() for line in f if line.strip()]
    # (path, category id); the category name is the first path component.
    self.files = [(os.path.join(dataPath, n), category_to_id[n.split('/')[0]]) for n in fileNames]
    self.n_points = n_points
    self.augmentation = augmentation
    self.category = len(category_to_id)  # number of classes
    self.device = device  # kept for compatibility; not used inside this class

  def __len__(self):
    return len(self.files)

  def __getitem__(self, index):
    """Load, resample, standardize and (optionally) augment one shape."""
    path, category = self.files[index]
    points = self.read_points(path)  # (n_vertices, 3)
    points = self.sampling(points, self.n_points)  # (n_points, 3)
    points = self.standardization(points)
    if self.augmentation is True:
      points = self.augment_points(points)
    return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(np.array([category]).astype(np.int64))

  def read_points(self, path):
    """Return the vertex x/y/z coordinates of a .ply file as an (n_vertices, 3) array."""
    with open(path, 'rb') as f:
      plydata = PlyData.read(f)
    p_x, p_y, p_z = plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']
    return np.vstack((p_x, p_y, p_z)).T

  def sampling(self, points, n_points):
    """Draw n_points rows uniformly with replacement (works for any vertex count)."""
    choice = np.random.choice(len(points), n_points, replace=True)
    return points[choice, :]

  def standardization(self, points):
    """Zero-mean, unit-std per axis; epsilon guards against a zero spread."""
    epsilon = 1e-7
    return (points - np.mean(points, axis=0)) / (np.std(points, axis=0) + epsilon)

  def augment_points(self, points):
    """Rotate the (x, z) coordinates by a random angle and jitter all coordinates, in place."""
    # points[:, [0, 2]] is a COPY (fancy indexing), so rotating it alone never reached
    # `points`; rotate the copy and write it back.
    xz = points[:, [0, 2]]
    self.rotation(xz)
    points[:, [0, 2]] = xz
    self.jitter(points)
    return points

  def rotation(self, matrix):
    """Rotate an (n, 2) array in place by an angle drawn uniformly from [0, 2*pi)."""
    theta = np.random.uniform(0, np.pi * 2)
    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    matrix[:] = matrix.dot(rotation_matrix)

  def jitter(self, points):
    """Add N(0, 0.02^2) noise to every coordinate, in place."""
    points[:] = points + np.random.normal(0, 0.02, points.shape)

In [None]:
# Training split. n_points is tied to the NPOINTS config constant (same value as the
# class default) so changing the config actually changes the sample size.
dataset = ModelNetDataset(DATAPATH, n_points=NPOINTS, mode='train')
dataloader = DataLoader(
  dataset,
  batch_size=BATCHSIZE,
  shuffle=True,
  num_workers=WORKERS
)

# PointNet Model

In [10]:
class STN(nn.Module):
  """
  Spatial Transformer Network (PointNet T-Net).

  Predicts one (dim, dim) matrix per sample from a point cloud of shape
  (batch, dim, n_points): shared 1x1 convs -> max-pool over points -> MLP.
  The flattened identity is added to the MLP output, so a zero output maps
  to the identity transform.

  Args:
    dim : number of input channels (and size of the predicted square matrix).
  """
  def __init__(self, dim):
    super(STN, self).__init__()
    self.dim = dim
    self.conv1 = nn.Conv1d(dim, 64, 1)
    self.conv2 = nn.Conv1d(64, 128, 1)
    self.conv3 = nn.Conv1d(128, 1024, 1)
    self.fc1 = nn.Linear(1024, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, dim*dim)
    self.relu = nn.ReLU()
    self.bn1 = nn.BatchNorm1d(64)
    self.bn2 = nn.BatchNorm1d(128)
    self.bn3 = nn.BatchNorm1d(1024)
    self.bn4 = nn.BatchNorm1d(512)
    self.bn5 = nn.BatchNorm1d(256)

  def forward(self, x):
    """
    Input:
        x : (B, dim, N)
    Return:
        (B, dim, dim) transform matrices
    """
    _, dim, _ = x.shape
    x = self.relu(self.bn1(self.conv1(x)))  # (B, 64, N)
    x = self.relu(self.bn2(self.conv2(x)))  # (B, 128, N)
    x = self.relu(self.bn3(self.conv3(x)))  # (B, 1024, N)
    x = einops.reduce(x, 'B D N -> B D', 'max')  # max-pool over points (B, 1024)
    x = self.relu(self.bn4(self.fc1(x)))  # (B, 512)
    x = self.relu(self.bn5(self.fc2(x)))  # (B, 256)
    x = self.fc3(x)  # (B, dim*dim)
    # Identity on x's own device/dtype: `.cuda()` always targeted the *current* CUDA
    # device, which fails when x lives on another GPU (and ignored non-CUDA devices).
    iden = torch.eye(dim, device=x.device, dtype=x.dtype).view(1, -1)
    return (x + iden).view(-1, dim, dim)


class PointNetFeature(nn.Module):
  """
  PointNet encoder: a stack of shared 1x1 Conv1d + BatchNorm + ReLU layers followed
  by a global max-pool over points.

  Args:
    in_channel : number of input channels per point.
    channels   : output width of each 1x1 conv layer (D1 = channels[0], D3 = channels[-1]).
    stn        : if True, apply an input T-Net (in_channel x in_channel) before the
                 first layer and a feature T-Net (D1 x D1) after it.
  """
  def __init__(self, in_channel=3, channels=(64, 128, 1024), stn=True):
    super(PointNetFeature, self).__init__()
    # Keep the on/off flag separate from the submodule: `self.stn` is replaced by the
    # STN module below, so the old `self.stn is True` test in forward() was always
    # False and the transformers were silently never applied.
    self.use_stn = stn is True
    self.stn = stn
    if self.use_stn:
      self.stn = STN(in_channel)
      self.fstn = STN(channels[0])
    self.convs = nn.ModuleList()
    self.bns = nn.ModuleList()
    last_channel = in_channel
    for c in channels:
      self.convs.append(nn.Conv1d(last_channel, c, 1))
      self.bns.append(nn.BatchNorm1d(c))
      last_channel = c
    self.relu = nn.ReLU()

  def forward(self, x):
    """
    Input:
        x : (B, C, N) with C == in_channel
    Return:
        global_feature : (B, D3)     max over points of the last layer
        local_feature  : (B, D1, N)  output of the first layer (after the feature transform when stn is on)
        last_feature   : (B, D3, N)  per-point output of the last layer
        trans_mat      : (B, D1, D1) feature transform when stn is on, else None
                         (returned so the caller can regularise it)
    """
    if self.use_stn:
      # Each point, as a row vector, is right-multiplied by the predicted matrix.
      x = einops.einsum(x, self.stn(x), 'B C N, B C C2 -> B C2 N')
      x = self.relu(self.bns[0](self.convs[0](x)))  # (B, D1, N)
      trans_mat = self.fstn(x)  # (B, D1, D1)
      x = einops.einsum(x, trans_mat, 'B C N, B C C2 -> B C2 N')
      local_feature = x
      for conv, bn in zip(self.convs[1:], self.bns[1:]):
        x = self.relu(bn(conv(x)))
    else:
      for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
        x = self.relu(bn(conv(x)))
        if i == 0:
          local_feature = x
      trans_mat = None
    last_feature = x
    global_feature = einops.reduce(x, 'B C N -> B C', 'max')
    return global_feature, local_feature, last_feature, trans_mat

class PointNetCls(nn.Module):
  """
  PointNet classifier: PointNetFeature encoder + 3-layer MLP head
  (1024 -> 512 -> 256 -> category) with BatchNorm, ReLU and dropout.

  Args:
    category : number of output classes.
  """
  def __init__(self, category=2):
    super(PointNetCls, self).__init__()
    self.feat = PointNetFeature()
    self.linear1 = nn.Linear(1024, 512)
    self.linear2 = nn.Linear(512, 256)
    self.linear3 = nn.Linear(256, category)
    self.bn1 = nn.BatchNorm1d(512)
    self.bn2 = nn.BatchNorm1d(256)
    self.dropout = nn.Dropout(p=0.3)
    self.relu = nn.ReLU()

  def forward(self, x):
    """
    Input:
        x : (B, N, 3) point clouds
    Return:
        logits    : (B, category)
        trans_mat : feature transform from the encoder (None when its STN is disabled)
    """
    x = einops.rearrange(x, 'B N C -> B C N')
    global_feature, _, _, trans_mat = self.feat(x)  # (B, 1024), ..., trans_mat
    # ReLU between layers: without it the head collapses to a single affine map.
    x = self.relu(self.bn1(self.linear1(global_feature)))
    x = self.relu(self.bn2(self.dropout(self.linear2(x))))
    x = self.linear3(x)
    return x, trans_mat

class PointNetSeg(nn.Module):
  """
  PointNet segmentation network: per-point local features (64) concatenated with the
  tiled global feature (1024), then a shared MLP 1088 -> 512 -> 256 -> 128 -> category.

  Args:
    category : number of per-point classes.
  """
  def __init__(self, category=2):
    super(PointNetSeg, self).__init__()
    self.feat = PointNetFeature()
    self.conv1 = torch.nn.Conv1d(1024+64, 512, 1)
    self.conv2 = torch.nn.Conv1d(512, 256, 1)
    self.conv3 = torch.nn.Conv1d(256, 128, 1)
    self.conv4 = torch.nn.Conv1d(128, category, 1)
    self.bn1 = nn.BatchNorm1d(512)
    self.bn2 = nn.BatchNorm1d(256)
    self.bn3 = nn.BatchNorm1d(128)
    self.relu = nn.ReLU()

  def forward(self, x):
    """
    Input:
        x : (B, N, 3) point clouds
    Return:
        logits    : (B, category, N) channel-first per-point scores
        trans_mat : feature transform from the encoder (None when its STN is disabled)
    """
    _, n_points, _ = x.shape
    x = einops.rearrange(x, 'B N C -> B C N')
    global_feature, local_feature, _, trans_mat = self.feat(x)  # (B, 1024), (B, 64, N)
    # Tile the global descriptor onto every point and append it to the local features.
    concat_feature = torch.cat((local_feature, global_feature.unsqueeze(-1).repeat(1, 1, n_points)), dim=1)  # (B, 1088, N)
    # ReLU between layers: without it the shared MLP is a single affine map.
    x = self.relu(self.bn1(self.conv1(concat_feature)))
    x = self.relu(self.bn2(self.conv2(x)))
    x = self.relu(self.bn3(self.conv3(x)))
    x = self.conv4(x)  # (B, category, N)
    return x, trans_mat

class OrthogonalRegLoss(nn.Module):
  """
  Orthogonality regulariser for predicted transforms:
  alpha * mean_b || A_b A_b^T - I ||_F.

  Args:
    alpha : weight applied to the batch-mean Frobenius norm.
  """
  def __init__(self, alpha=1e-4):
    super(OrthogonalRegLoss, self).__init__()
    self.alpha = alpha

  def forward(self, mat):
    """
    Input:
        mat : (B, D, D) batch of transform matrices, or None.
    Return:
        scalar tensor; 0 when mat is None (nothing to regularise, e.g. STN disabled).
    """
    if mat is None:
      return torch.zeros(())
    _, dim, _ = mat.shape
    # Identity on mat's own device/dtype (`.cuda()` only ever targeted the current CUDA device).
    iden = torch.eye(dim, device=mat.device, dtype=mat.dtype).unsqueeze(0)
    gram = torch.bmm(mat, mat.transpose(1, 2))  # A A^T, (B, D, D)
    return torch.norm(gram - iden, dim=(1, 2)).mean() * self.alpha

# PointNet Training

In [None]:
classifier = PointNetCls(category=dataset.category).to(device)
classifier.train()
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
# StepLR counts *epochs*: halve the learning rate every 20 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
loss_fn = nn.CrossEntropyLoss()
orthloss = OrthogonalRegLoss(alpha=1e-4)
batch_num = len(dataloader)
for epoch in range(EPOCH):
  losses = []
  n_correct, n_seen = 0, 0
  for i, (points, category) in enumerate(dataloader):
    points, category = points.to(device), category.to(device)
    target = category.flatten()
    optimizer.zero_grad()
    pred, trans_mat = classifier(points)
    loss = loss_fn(pred, target)
    if trans_mat is not None:  # no feature transform -> nothing to regularise
      loss = loss + orthloss(trans_mat)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    n_correct += (pred.argmax(dim=1) == target).sum().item()
    n_seen += target.size(0)  # the last batch can be smaller than BATCHSIZE
    if (i + 1) % 50 == 0 or i + 1 == batch_num:
      print('[Avg Epoch %d: %d/%d] train avg loss: %f accuracy: %f' % (epoch, i+1, batch_num, sum(losses)/len(losses), n_correct / n_seen))
  scheduler.step()  # once per epoch (stepping per batch halved the LR every 20 batches)
  torch.save(classifier.state_dict(), "pointnet_modelnet.pt")

# Utils

In [12]:
class Utils:
  """Point-cloud helpers used by the PointNet++ layers: distances, gathering, sampling and grouping."""

  @staticmethod
  def distance(src_points, tgt_points, channel_first=False):
    """
    Pairwise squared Euclidean distances.

    src_points : (B, N, C)  ((B, C, N) if channel_first)
    tgt_points : (B, M, C)  ((B, C, M) if channel_first)
    return : (B, N, M)
    """
    if channel_first is True:
      src_points = einops.rearrange(src_points, 'B C N -> B N C')
      tgt_points = einops.rearrange(tgt_points, 'B C N -> B N C')
    # (B, N, 1, C) - (B, 1, M, C) -> (B, N, M, C), summed over C.
    return ((src_points.unsqueeze(2) - tgt_points.unsqueeze(1)) ** 2).sum(dim=-1)

  @staticmethod
  def points_by_idx(points, idx, channel_first=False):
    """
    Gather points along the point axis.

    points : (B, N, C)       ((B, C, N) if channel_first)
    idx : (B, S) or (B, S1, S2, ...) integer indices into N
    return : (B, S..., C)    ((B, C, S...) if channel_first)
    """
    if channel_first is True:
      # Real transpose; a reshape would reinterpret memory and scramble the data.
      points = points.movedim(1, -1)
    B = points.shape[0]
    view_shape = [B] + [1] * (idx.dim() - 1)
    repeat_shape = [1] + list(idx.shape[1:])
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]  # (B, S..., C)
    if channel_first is True:
      new_points = new_points.movedim(-1, 1)
    return new_points

  @staticmethod
  def farthest_point_sampling(points, n_sample, channel_first=False):
    """
    Iterative farthest point sampling, always starting from point index 0.

    points : (B, N, C)  ((B, C, N) if channel_first)
    n_sample : S
    return : sampled points (B, S, C) ((B, C, S) if channel_first), indices (B, S)
    """
    if channel_first is True:
      points = einops.rearrange(points, 'B C N -> B N C')
    B, N, _ = points.shape
    S = n_sample
    device = points.device
    batchidx = torch.arange(B, device=device)
    farthest_idx = torch.zeros(B, S, dtype=torch.long, device=device)  # slot 0 = index 0
    # Running squared distance from each point to its nearest selected point:
    # O(S*N) per batch instead of recomputing distances to all S slots every step.
    min_dist = None
    current = torch.zeros(B, dtype=torch.long, device=device)
    for i in range(1, S):
      centroid = points[batchidx, current].unsqueeze(1)  # (B, 1, C)
      dist = ((points - centroid) ** 2).sum(dim=-1)  # (B, N)
      min_dist = dist if min_dist is None else torch.minimum(min_dist, dist)
      current = min_dist.argmax(dim=-1)
      farthest_idx[:, i] = current
    sampled_points = Utils.points_by_idx(points, farthest_idx)
    if channel_first is True:
      sampled_points = einops.rearrange(sampled_points, 'B S C -> B C S')
    return sampled_points, farthest_idx

  @staticmethod
  def query_ball(radius, n_sample, points, centroids, channel_first=False):
    """
    Ball query: for each centroid take up to n_sample point indices (lowest first) within
    `radius`; groups with fewer neighbours are padded with their first neighbour.

    n_sample : S
    points : (B, N, C)     ((B, C, N) if channel_first)
    centroids : (B, M, C)  ((B, C, M) if channel_first)
    return : grouped points (B, M, S, C) ((B, C, M, S) if channel_first), indices (B, M, S)
    """
    if channel_first is True:
      points = einops.rearrange(points, 'B C N -> B N C')
      centroids = einops.rearrange(centroids, 'B C N -> B N C')
    device = points.device
    B, N, _ = points.shape
    _, M, _ = centroids.shape
    point_idx = torch.arange(N, device=device).view(1, 1, N).repeat(B, M, 1)  # (B, M, N)
    distanceMat = Utils.distance(centroids, points)  # (B, M, N)
    point_idx[distanceMat > radius ** 2] = N  # N marks "outside the ball" (invalid)
    point_idx, _ = point_idx.sort(dim=-1)
    point_idx = point_idx[:, :, :n_sample]  # (B, M, S)
    invalid_mask = point_idx == N
    pad_points = point_idx[:, :, :1].repeat(1, 1, n_sample)
    point_idx[invalid_mask] = pad_points[invalid_mask]
    grouped = Utils.points_by_idx(points, point_idx.reshape(B, -1)).reshape(B, M, n_sample, -1)  # (B, M, S, C)
    if channel_first is True:
      grouped = einops.rearrange(grouped, 'B M S C -> B C M S')
    return grouped, point_idx

  @staticmethod
  def sample_and_group(n_keypoint, radius, n_sample, point_xyz, point_features=None, concat=True, channel_first=False):
    """
    FPS keypoints, ball-query groups around them, and group-local coordinates.
    * channel dim is the second dim when channel_first is True (e.g. (B, N, 3) -> (B, 3, N)).

    Input:
        point_xyz: point positions, [B, N, 3]
        point_features: point features, [B, N, D] or None
    Return:
        keypoint_xyz : [B, n_keypoint, 3]
        keypoint_idx : [B, n_keypoint]
        grouped_idx : [B, n_keypoint, n_sample]
        grouped_xyz_norm : positions relative to their keypoint, [B, n_keypoint, n_sample, 3]
        final_features : [B, n_keypoint, n_sample, 3+D] if concat, [B, n_keypoint, n_sample, D] if not,
                         grouped_xyz_norm when point_features is None
    """
    if channel_first is True:
      point_xyz = einops.rearrange(point_xyz, 'B C N -> B N C')
      if point_features is not None:
        point_features = einops.rearrange(point_features, 'B C N -> B N C')
    B, _, _ = point_xyz.shape
    keypoint_xyz, keypoints_idx = Utils.farthest_point_sampling(point_xyz, n_keypoint)  # [B, S, 3]
    grouped_xyz, grouped_idx = Utils.query_ball(radius, n_sample, point_xyz, keypoint_xyz)  # [B, S, n_sample, 3]
    grouped_xyz_norm = grouped_xyz - keypoint_xyz.unsqueeze(2)
    if point_features is None:
      final_features = grouped_xyz_norm
    else:
      grouped_features = Utils.points_by_idx(point_features, grouped_idx.reshape(B, -1)).reshape(B, n_keypoint, n_sample, -1)
      if concat is True:
        final_features = torch.cat([grouped_xyz_norm, grouped_features], dim=-1)
      else:
        final_features = grouped_features
    if channel_first is True:
      keypoint_xyz = einops.rearrange(keypoint_xyz, 'B N C -> B C N')
      grouped_xyz_norm = einops.rearrange(grouped_xyz_norm, 'B N S C -> B C N S')
      final_features = einops.rearrange(final_features, 'B N S C -> B C N S')
    return keypoint_xyz, keypoints_idx, grouped_idx, grouped_xyz_norm, final_features

  @staticmethod
  def sample_and_group_all(point_xyz, point_features=None, concat=True, channel_first=False):
    """
    Treat the whole cloud as a single group (used by the global set-abstraction layer).
    * channel dim is the second dim when channel_first is True (e.g. (B, N, 3) -> (B, 3, N)).

    Input:
        point_xyz: point positions, [B, N, 3]
        point_features: point features, [B, N, D] or None
    Return:
        None, None, None (no keypoints / indices),
        grouped_xyz: [B, 1, N, 3],
        final_features: [B, 1, N, 3+D] if concat, [B, 1, N, D] if not,
                        grouped_xyz when point_features is None
    """
    if channel_first is True:
      point_xyz = einops.rearrange(point_xyz, 'B C N -> B N C')
      if point_features is not None:
        point_features = einops.rearrange(point_features, 'B C N -> B N C')
    grouped_xyz = point_xyz.unsqueeze(1)
    if point_features is None:
      final_features = grouped_xyz
    else:
      grouped_features = point_features.unsqueeze(1)
      if concat is True:
        final_features = torch.cat([grouped_xyz, grouped_features], dim=-1)
      else:
        final_features = grouped_features  # previously returned the xyz by mistake
    if channel_first is True:
      grouped_xyz = einops.rearrange(grouped_xyz, 'B N S C -> B C N S')
      final_features = einops.rearrange(final_features, 'B N S C -> B C N S')
    # Returned for both layouts (this used to sit inside the channel_first branch).
    return None, None, None, grouped_xyz, final_features

# Set Abstraction & Propagation

In [13]:
class SetAbstraction(nn.Module):
    """
    PointNet++ set-abstraction layer: sample keypoints (FPS), group neighbours (ball
    query) or take the whole cloud (group_all), encode every group with a shared
    PointNet and max-pool within each group.

    Args:
        n_keypoint, radius, n_sample: sampling/grouping parameters (ignored when group_all).
        in_channel: feature channels per input point (3 for raw xyz when no_feature).
        channels: widths of the shared PointNet layers; the last one is the output width.
        group_all: if True, the whole cloud forms one group.
        no_feature: if True, group only xyz; otherwise xyz is concatenated to the features.
    """
    def __init__(self, n_keypoint, radius, n_sample, in_channel, channels, group_all, no_feature=False):
        super(SetAbstraction, self).__init__()
        self.n_keypoint = n_keypoint
        self.radius = radius
        self.n_sample = n_sample
        self.in_channel = in_channel
        self.channels = channels
        self.group_all = group_all
        self.concat = no_feature is False
        self.relu = nn.ReLU()
        xyz_dim = 3
        feat_in = in_channel + xyz_dim if no_feature is False else in_channel
        # No T-Nets here: groups are flattened into one point axis below, so a T-Net's
        # max-pool would mix all groups of a sample together.
        self.feat = PointNetFeature(in_channel=feat_in, channels=channels, stn=False)

    def forward(self, point_xyz, point_features):
        """
        Input:
            point_xyz: input point positions, [B, 3, N']
            point_features: input point features, [B, D, N'] or None
        Return:
            keypoint_xyz: keypoint positions, [B, 3, N] (None when group_all)
            keypoint_idx: keypoint indices into the N' inputs, [B, N] (None when group_all)
            keypoint_features: pooled group features, [B, channels[-1], N] (N == 1 when group_all)
        """
        if self.group_all is True:
            # None, None, [B, 3+D, 1, N']
            keypoint_xyz, keypoint_idx, _, _, grouped_features = Utils.sample_and_group_all(point_xyz, point_features=point_features, concat=self.concat, channel_first=True)
        else:
            # [B, 3, n_keypoint], [B, n_keypoint], [B, 3+D, n_keypoint, n_sample]
            keypoint_xyz, keypoint_idx, _, _, grouped_features = Utils.sample_and_group(self.n_keypoint, self.radius, self.n_sample, point_xyz, point_features=point_features, concat=self.concat, channel_first=True)
        _, _, N, S = grouped_features.shape
        # Flatten groups into one point axis so the 1x1-conv PointNet runs on all of them at once.
        grouped_features = einops.rearrange(grouped_features, "B D N S -> B D (N S)")
        _, _, grouped_features, _ = self.feat(grouped_features)
        grouped_features = einops.rearrange(grouped_features, "B D (N S) -> B D N S", N=N, S=S)
        keypoint_features = einops.reduce(grouped_features, "B D N S -> B D N", 'max')
        return keypoint_xyz, keypoint_idx, keypoint_features

class SetMSGAbstraction(nn.Module):
    """
    Multi-scale grouping (MSG): one SetAbstraction per (radius, n_sample, channels)
    triple, all sampling n_keypoint keypoints; their pooled features are concatenated
    along the channel axis.
    """
    def __init__(self, n_keypoint, radius_list, n_sample_list, in_channel, channels_list, no_feature=False):
        super(SetMSGAbstraction, self).__init__()
        self.n_keypoint = n_keypoint
        self.radius_list = radius_list
        self.n_sample_list = n_sample_list
        self.in_channel = in_channel
        self.channels_list = channels_list
        scales = zip(radius_list, n_sample_list, channels_list)
        self.sa_list = nn.ModuleList(
            [SetAbstraction(n_keypoint, r, k, in_channel, widths, False, no_feature=no_feature)
             for r, k, widths in scales]
        )
        self.concat = no_feature is False

    def forward(self, point_xyz, point_features):
        """
        Input:
            point_xyz: input point positions, [B, 3, N']
            point_features: input point features, [B, D, N'] or None
        Return:
            keypoint_xyz: keypoint positions from the last scale, [B, 3, N]
            keypoint_idx: keypoint indices from the last scale, [B, N]
            keypoint_features: per-scale features concatenated, [B, sum of last widths, N]
        """
        per_scale = [layer(point_xyz, point_features) for layer in self.sa_list]
        keypoint_xyz, keypoint_idx, _ = per_scale[-1]
        merged = torch.cat([feat for _, _, feat in per_scale], dim=1)
        return keypoint_xyz, keypoint_idx, merged

class FeaturePropagation(nn.Module):
    """
    PointNet++ feature-propagation layer: interpolate features from S source points
    onto N target points (inverse squared-distance weights over the 3 nearest
    sources), concatenate the skip features, then apply a shared Conv1d+BN+ReLU MLP.

    Args:
        in_channel: channels after concatenation (D1 + D2).
        mlp: output widths of the shared MLP layers.
    """
    def __init__(self, in_channel, mlp):
        super(FeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel
        self.relu = nn.ReLU()

    def forward(self, xyz1, xyz2, features1, features2):
        """
        Input:
            xyz1: target point positions, [B, 3, N]
            xyz2: source point positions, [B, 3, S] (or None for a single global source)
            features1: skip features of the targets, [B, D1, N] or None
            features2: source features, [B, D2, S]
        Return:
            upsampled_points: [B, mlp[-1], N]
        """
        _, _, N = xyz1.shape

        if xyz2 is None or xyz2.shape[-1] == 1:
            # A single (global) source point: broadcast its feature to every target.
            interpolated_points = features2.repeat(1, 1, N)
        else:
            dists = Utils.distance(xyz1, xyz2, channel_first=True)  # [B, N, S] squared distances
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, T], T <= 3
            dist_recip = 1.0 / (dists + 1e-8)
            norm = einops.reduce(dist_recip, 'B N T -> B N 1', 'sum')
            weight = einops.rearrange(dist_recip / norm, 'B N T -> B 1 N T')
            # Gather channel-last through points_by_idx's plain (B, S, D) indexing path, then
            # move channels back: [B, S, D] -> [B, N, T, D] -> [B, D, N, T].
            sampled_points = Utils.points_by_idx(einops.rearrange(features2, 'B D S -> B S D'), idx)
            sampled_points = einops.rearrange(sampled_points, 'B N T D -> B D N T')
            interpolated_points = einops.reduce(sampled_points * weight, 'B D N T -> B D N', 'sum')
        if features1 is not None:
            upsampled_points = torch.cat([features1, interpolated_points], dim=1)
        else:
            upsampled_points = interpolated_points
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            upsampled_points = self.relu(bn(conv(upsampled_points)))
        return upsampled_points

# PointNet++ (Classification + SSG)

In [None]:
class PointNet2Cls(nn.Module):
  """
  PointNet++ classifier with single-scale grouping (SSG): three set-abstraction
  stages (512 keypoints -> 128 keypoints -> one global group) and a 3-layer MLP head.

  Args:
    category      : number of output classes.
    channel_first : if True, forward() takes (B, 3, N) instead of (B, N, 3).
  """
  def __init__(self, category=2, channel_first=False):
    super(PointNet2Cls, self).__init__()
    stages = [
        SetAbstraction(512, 0.2, 32, 3, [64, 64, 128], False, no_feature=True),
        SetAbstraction(128, 0.4, 32, 128, [128, 128, 256], False),
        SetAbstraction(None, None, None, 256, [256, 512, 1024], True),
    ]
    self.sa_layers = nn.ModuleList(stages)
    head = [
        nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(256, category),
    ]
    self.ffn = nn.Sequential(*head)
    self.channel_first = channel_first

  def forward(self, xyz):
    """Return class logits (B, category) for a batch of point clouds."""
    points = einops.rearrange(xyz, 'B N C -> B C N') if self.channel_first is False else xyz
    _batch, _, _ = points.shape
    position, feature = points, None
    for stage in self.sa_layers:
      position, _, feature = stage(position, feature)
    return self.ffn(einops.rearrange(feature, 'B C 1 -> B C'))

# PointNet++ (Classification + MSG)

In [None]:
class PointNet2ClsMSG(nn.Module):
  """
  PointNet++ classifier with multi-scale grouping (MSG): two multi-scale
  set-abstraction stages (512 and 128 keypoints, three radii each), one global
  set-abstraction stage, and a 3-layer MLP head.

  Args:
    category      : number of output classes.
    channel_first : if True, forward() takes (B, 3, N) instead of (B, N, 3).
  """
  def __init__(self, category=2, channel_first=False):
    super(PointNet2ClsMSG, self).__init__()
    stages = [
        SetMSGAbstraction(512, [0.1, 0.2, 0.4], [16, 32, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]], no_feature=True),
        SetMSGAbstraction(128, [0.2, 0.4, 0.8], [32, 64, 128], 64+128+128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]]),
        SetAbstraction(None, None, None, 128+256+256, [256, 512, 1024], True),
    ]
    self.sa_layers = nn.ModuleList(stages)
    head = [
        nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(256, category),
    ]
    self.ffn = nn.Sequential(*head)
    self.channel_first = channel_first

  def forward(self, xyz):
    """Return class logits (B, category) for a batch of point clouds."""
    points = einops.rearrange(xyz, 'B N C -> B C N') if self.channel_first is False else xyz
    _batch, _, _ = points.shape
    position, feature = points, None
    for stage in self.sa_layers:
      position, _, feature = stage(position, feature)
    return self.ffn(einops.rearrange(feature, 'B C 1 -> B C'))

# PointNet++ (Classification + MRG)

In [None]:
class PointNet2ClsMRG(nn.Module):
  """
  PointNet++ classifier with multi-resolution grouping (MRG).

  branch1: two SA stages (512 -> 64 keypoints).
  branch2: one SA stage at a larger radius (512 keypoints), re-indexed onto branch1's 64 keypoints.
  branch3: a PointNet over the whole raw cloud (global 512-d feature).
  branch4: a global SA over the concatenated branch1/branch2 keypoint features (1024-d).
  The branch3 and branch4 globals are concatenated and fed to a 3-layer MLP head.

  Args:
    category      : number of output classes.
    channel_first : if True, forward() takes (B, 3, N) instead of (B, N, 3).
  """
  def __init__(self, category=2, channel_first=False):
    super(PointNet2ClsMRG, self).__init__()
    self.branch1 = nn.ModuleList([
        SetAbstraction(512, 0.2, 32, 3, [64, 64, 128], False, no_feature=True),
        SetAbstraction(64, 0.4, 32, 128, [128, 128, 256], False)
    ])
    self.branch2 = nn.ModuleList([
        SetAbstraction(512, 0.4, 32, 3, [64, 128, 256], False, no_feature=True),
    ])
    self.branch3 = nn.ModuleList([
        SetAbstraction(None, None, None, 3, [64, 128, 256, 512], True, no_feature=True),
    ])
    self.branch4 = nn.ModuleList([
        SetAbstraction(None, None, None, 256+256, [256, 512, 1024], True, no_feature=False),
    ])
    self.ffn = nn.Sequential(
        nn.Linear(1024 + 512, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, category)
    )
    self.channel_first = channel_first

  def forward(self, xyz):
    """Return class logits (B, category) for a batch of point clouds."""
    if self.channel_first is False:
      xyz = einops.rearrange(xyz, 'B N C -> B C N')
    B, _, _ = xyz.shape

    features1 = None
    xyz1 = xyz
    for sa in self.branch1:
      xyz1, idx1, features1 = sa(xyz1, features1)  # (B, 3, 64), (B, 64), (B, 256, 64)

    features2 = None
    xyz2 = xyz
    for sa in self.branch2:
      xyz2, _, features2 = sa(xyz2, features2)  # (B, 3, 512), (B, 256, 512)
    # idx1 indexes the 512 keypoints of branch1's first stage. This assumes branch2
    # sampled the same 512 keypoints (both run FPS from index 0 on the same input).
    features2 = einops.rearrange(features2, "B C N -> B N C")
    # Batch index on the same device as the indexed tensor (was always created on CPU).
    batchidx = einops.rearrange(torch.arange(B, device=features2.device), "B -> B 1")
    features2 = einops.rearrange(features2[batchidx, idx1], "B N C -> B C N")  # (B, 256, 64)

    features3 = None
    xyz3 = xyz
    for sa in self.branch3:
      _, _, features3 = sa(xyz3, features3)  # (B, 512, 1)

    features4 = torch.cat([features1, features2], dim=1)  # (B, 512, 64)
    xyz4 = xyz1
    for sa in self.branch4:
      _, _, features4 = sa(xyz4, features4)  # (B, 1024, 1)

    final_features = torch.cat([features3, features4], dim=1)  # (B, 1536, 1)

    x = einops.rearrange(final_features, 'B C 1 -> B C')
    x = self.ffn(x)
    return x

# PointNet++ (Part segmentation + SSG)

In [None]:
class PointNet2PartSeg(nn.Module):
  """PointNet++ part-segmentation network with single-scale grouping (SSG).

  Three SetAbstraction levels encode the cloud and three FeaturePropagation
  levels decode back to the input points. The object-category label and the
  raw coordinates are injected at the last decoder level.

  Args:
    cls_cat: number of object categories (width of the one-hot ``cls_label``).
    cls_seg: number of part classes predicted per point.
    channel_first: if True, ``forward`` expects xyz as (B, 3, N); else (B, N, 3).
  """
  def __init__(self, cls_cat=16, cls_seg=50, channel_first=False):
    super(PointNet2PartSeg, self).__init__()
    self.channel_first = channel_first
    self.sa1 = SetAbstraction(512, 0.2, 32, 3, [64, 64, 128], False, no_feature=True)
    self.sa2 = SetAbstraction(128, 0.4, 32, 128, [128, 128, 256], False)
    self.sa3 = SetAbstraction(None, None, None, 256, [256, 512, 1024], True)
    xyz_dim = 3
    # FP input widths = skip-connection channels + upsampled channels.
    self.fp1 = FeaturePropagation(1024+256, [256, 256])
    self.fp2 = FeaturePropagation(256+128, [256, 128])
    # Last level also receives the category one-hot and the raw xyz.
    self.fp3 = FeaturePropagation(128+cls_cat+xyz_dim, [128, 128, 128])
    self.ffn = nn.Sequential(
        nn.Dropout(0.5),
        nn.Conv1d(128, 128, 1),
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv1d(128, cls_seg, 1)
    )

  def forward(self, xyz, cls_label): # cls_label is one-hot (B, cls_cat)
    """Predict per-point part logits.

    Args:
      xyz: point coordinates, layout per ``self.channel_first``.
      cls_label: one-hot object category, (B, cls_cat). Any dtype/device is
        accepted; it is cast to match the points.

    Returns:
      Per-point logits of shape (B, N, cls_seg).
    """
    if self.channel_first is False:
      xyz = einops.rearrange(xyz, 'B N C -> B C N')
    _, _, N = xyz.shape
    xyz0 = xyz
    features0 = None
    xyz1, _, features1 = self.sa1(xyz0, features0)
    xyz2, _, features2 = self.sa2(xyz1, features1)
    xyz3, _, features3 = self.sa3(xyz2, features2)
    features2_upsampled = self.fp1(xyz2, xyz3, features2, features3)
    features1_upsampled = self.fp2(xyz1, xyz2, features1, features2_upsampled)
    # Match xyz0's dtype and device so an int64 one-hot (e.g. F.one_hot) or a
    # CPU label tensor can be concatenated with the point coordinates.
    cls_label = cls_label.to(xyz0)
    cls_label = einops.repeat(cls_label, 'B cat -> B cat N', N=N)
    features0_upsampled = self.fp3(xyz0, xyz1, torch.cat([cls_label, xyz0], dim=1), features1_upsampled)
    x = self.ffn(features0_upsampled) # (B, C, N)
    x = einops.rearrange(x, 'B C N -> B N C')
    return x

# PointNet++ (Part segmentation + MSG)

In [None]:
class PointNet2PartSegMSG(nn.Module):
  """PointNet++ part-segmentation network with multi-scale grouping (MSG).

  Two SetMSGAbstraction levels and one global SetAbstraction encode the cloud.
  Three FeaturePropagation levels decode back to the input points, with the
  category label and raw coordinates injected at the last level.

  Args:
    cls_cat: number of object categories (width of the one-hot ``cls_label``).
    cls_seg: number of part classes predicted per point.
    channel_first: if True, ``forward`` expects xyz as (B, 3, N); else (B, N, 3).
  """
  def __init__(self, cls_cat=16, cls_seg=50, channel_first=False):
    super(PointNet2PartSegMSG, self).__init__()
    self.channel_first = channel_first
    # MSG output widths are the sums of each scale's last MLP width.
    self.sa1 = SetMSGAbstraction(512, [0.1, 0.2, 0.4], [16, 32, 128], 3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]], no_feature=True)
    self.sa2 = SetMSGAbstraction(128, [0.2, 0.4, 0.8], [32, 64, 128], 64+128+128, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
    self.sa3 = SetAbstraction(None, None, None, 128+256+256, [256, 512, 1024], True)
    xyz_dim = 3
    # FP input widths = skip-connection channels + upsampled channels.
    self.fp1 = FeaturePropagation((128+256+256)+1024, [256, 256])
    self.fp2 = FeaturePropagation((64+128+128)+256, [256, 128])
    self.fp3 = FeaturePropagation((cls_cat+xyz_dim)+128, [128, 128, 128])
    self.ffn = nn.Sequential(
        nn.Dropout(0.5),
        nn.Conv1d(128, 128, 1),
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv1d(128, cls_seg, 1)
    )

  def forward(self, xyz, cls_label): # cls_label is one-hot (B, cls_cat)
    """Predict per-point part logits.

    Args:
      xyz: point coordinates, layout per ``self.channel_first``.
      cls_label: one-hot object category, (B, cls_cat). Any dtype/device is
        accepted; it is cast to match the points.

    Returns:
      Per-point logits of shape (B, N, cls_seg).
    """
    if self.channel_first is False:
      xyz = einops.rearrange(xyz, 'B N C -> B C N')
    _, _, N = xyz.shape
    xyz0 = xyz
    features0 = None
    xyz1, _, features1 = self.sa1(xyz0, features0)
    xyz2, _, features2 = self.sa2(xyz1, features1)
    xyz3, _, features3 = self.sa3(xyz2, features2)
    features2_upsampled = self.fp1(xyz2, xyz3, features2, features3)
    features1_upsampled = self.fp2(xyz1, xyz2, features1, features2_upsampled)
    # Match xyz0's dtype and device so an int64 one-hot (e.g. F.one_hot) or a
    # CPU label tensor can be concatenated with the point coordinates.
    cls_label = cls_label.to(xyz0)
    cls_label = einops.repeat(cls_label, 'B cat -> B cat N', N=N)
    features0_upsampled = self.fp3(xyz0, xyz1, torch.cat([cls_label, xyz0], dim=1), features1_upsampled)
    x = self.ffn(features0_upsampled) # (B, C, N)
    x = einops.rearrange(x, 'B C N -> B N C')
    return x

# PointNet++ (Semantic Segmentation + SSG)

In [30]:
class PointNet2SemSeg(nn.Module):
  """PointNet++ semantic-segmentation network, single-scale grouping (SSG).

  Four SetAbstraction levels (feature widths 64 -> 128 -> 256 -> 512) encode
  the cloud. Four FeaturePropagation levels carry features back to every input
  point, and a 1x1-conv head emits ``cls_seg`` logits per point.

  Args:
    cls_seg: number of per-point classes.
    channel_first: if True ``forward`` takes xyz as (B, 3, N), else (B, N, 3).
  """
  def __init__(self, cls_seg=50, channel_first=False):
    super(PointNet2SemSeg, self).__init__()
    self.channel_first = channel_first
    # Encoder (construction order is kept for identical parameter init/order).
    self.sa1 = SetAbstraction(1024, 0.1, 32, 3, [32, 32, 64], False, no_feature=True)
    self.sa2 = SetAbstraction(256, 0.2, 32, 64, [64, 64, 128], False)
    self.sa3 = SetAbstraction(64, 0.4, 32, 128, [128, 128, 256], False)
    self.sa4 = SetAbstraction(16, 0.8, 32, 256, [256, 256, 512], False)
    # Decoder: each input width = skip channels + upsampled channels.
    self.fp1 = FeaturePropagation(512+256, [256, 256])
    self.fp2 = FeaturePropagation(256+128, [256, 256])
    self.fp3 = FeaturePropagation(256+64, [256, 128])
    self.fp4 = FeaturePropagation(128, [128, 128, 128])
    # Per-point classification head.
    self.ffn = nn.Sequential(
        nn.Dropout(0.5),
        nn.Conv1d(128, 128, 1),
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv1d(128, cls_seg, 1)
    )

  def forward(self, xyz):
    """Return per-point logits of shape (B, N, cls_seg)."""
    if self.channel_first is False:
      xyz = einops.rearrange(xyz, 'B N C -> B C N')
    _, _, _ = xyz.shape  # require a 3-D (B, C, N) tensor here

    # Encoder: remember each level's coordinates/features for the skip links.
    level_xyz, level_feat = [xyz], [None]
    for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
      next_xyz, _, next_feat = sa(level_xyz[-1], level_feat[-1])
      level_xyz.append(next_xyz)
      level_feat.append(next_feat)

    # Decoder: fp1 maps level 4 onto level 3, ..., fp4 maps level 1 onto the input.
    upsampled = level_feat[-1]
    for level, fp in zip((3, 2, 1, 0), (self.fp1, self.fp2, self.fp3, self.fp4)):
      upsampled = fp(level_xyz[level], level_xyz[level + 1], level_feat[level], upsampled)

    logits = self.ffn(upsampled)  # (B, cls_seg, N)
    return einops.rearrange(logits, 'B C N -> B N C')

# PointNet++ (Semantic Segmentation + MSG)

In [31]:
class PointNet2SemSegMSG(nn.Module):
  """PointNet++ semantic-segmentation network, multi-scale grouping (MSG).

  Four SetMSGAbstraction levels (two radii each) encode the cloud. Four
  FeaturePropagation levels carry features back to every input point, and a
  1x1-conv head emits ``cls_seg`` logits per point.

  Args:
    cls_seg: number of per-point classes.
    channel_first: if True ``forward`` takes xyz as (B, 3, N), else (B, N, 3).
  """
  def __init__(self, cls_seg=50, channel_first=False):
    super(PointNet2SemSegMSG, self).__init__()
    self.channel_first = channel_first
    # Encoder; each level's output width is the sum of its scales' last MLP widths.
    self.sa1 = SetMSGAbstraction(1024, [0.05, 0.1], [16, 32], 3, [[16, 16, 32], [32, 32, 64]], no_feature=True)
    self.sa2 = SetMSGAbstraction(256, [0.1, 0.2], [16, 32], 32+64, [[64, 64, 128], [64, 96, 128]])
    self.sa3 = SetMSGAbstraction(64, [0.2, 0.4], [16, 32], 128+128, [[128, 196, 256], [128, 196, 256]])
    self.sa4 = SetMSGAbstraction(16, [0.4, 0.8], [16, 32], 256+256, [[256, 256, 512], [256, 384, 512]])
    # Decoder: each input width = skip channels + upsampled channels.
    self.fp1 = FeaturePropagation(512+512+256+256, [256, 256])
    self.fp2 = FeaturePropagation(128+128+256, [256, 256])
    self.fp3 = FeaturePropagation(32+64+256, [256, 128])
    self.fp4 = FeaturePropagation(128, [128, 128, 128])
    # Per-point classification head.
    self.ffn = nn.Sequential(
        nn.Dropout(0.5),
        nn.Conv1d(128, 128, 1),
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv1d(128, cls_seg, 1)
    )

  def forward(self, xyz):
    """Return per-point logits of shape (B, N, cls_seg)."""
    if self.channel_first is False:
      xyz = einops.rearrange(xyz, 'B N C -> B C N')
    _, _, _ = xyz.shape  # require a 3-D (B, C, N) tensor here

    # Encoder: remember each level's coordinates/features for the skip links.
    level_xyz, level_feat = [xyz], [None]
    for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
      next_xyz, _, next_feat = sa(level_xyz[-1], level_feat[-1])
      level_xyz.append(next_xyz)
      level_feat.append(next_feat)

    # Decoder: fp1 maps level 4 onto level 3, ..., fp4 maps level 1 onto the input.
    upsampled = level_feat[-1]
    for level, fp in zip((3, 2, 1, 0), (self.fp1, self.fp2, self.fp3, self.fp4)):
      upsampled = fp(level_xyz[level], level_xyz[level + 1], level_feat[level], upsampled)

    logits = self.ffn(upsampled)  # (B, cls_seg, N)
    return einops.rearrange(logits, 'B C N -> B N C')

# PointNet++ smoke tests and training

In [None]:
# Smoke test: push a random batch through the SSG semantic-segmentation net and
# check that it returns one logit vector per point.
cls_cat = 16  # not used by PointNet2SemSeg (it takes no class label); kept for later cells
cls_seg = 50
classifier = PointNet2SemSeg(cls_seg=cls_seg)
B = 2
N = 1500
xyz = torch.rand(B, N, 3)
cls_label = torch.rand(B, cls_cat)  # unused by PointNet2SemSeg.forward
with torch.no_grad():  # shape check only; no autograd graph needed
  result = classifier(xyz) # (batch, N, seg_cat)
print(result.shape)
assert result.shape == (B, N, cls_seg), result.shape

In [None]:
# Smoke test: push a random batch through the MSG semantic-segmentation net and
# check that it returns one logit vector per point.
cls_cat = 16  # not used by PointNet2SemSegMSG (it takes no class label); kept for later cells
cls_seg = 50
classifier = PointNet2SemSegMSG(cls_seg=cls_seg)
B = 2
N = 2500
xyz = torch.rand(B, N, 3)
cls_label = torch.rand(B, cls_cat)  # unused by PointNet2SemSegMSG.forward
with torch.no_grad():  # shape check only; no autograd graph needed
  result = classifier(xyz) # (batch, N, seg_cat)
print(result.shape)
assert result.shape == (B, N, cls_seg), result.shape

In [None]:
# Train the PointNet++ classifier on the dataloader built earlier, reporting
# running loss/accuracy and checkpointing after every epoch.
classifier = PointNet2Cls(category=dataset.category).to(device)
classifier.train()
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
# StepLR counts scheduler.step() calls: stepping once per epoch (below) halves
# the LR every 20 epochs, not every 20 batches.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): orthloss is never used in this loop; kept only in case a later
# cell refers to it.
orthloss = OrthogonalRegLoss(alpha=1e-4)
batch_num = len(dataloader)
for epoch in range(EPOCH):
  losses, accuracy = [], []
  for i, (points, category) in enumerate(dataloader):
    points, category = points.to(device), category.to(device)
    target = category.flatten()
    optimizer.zero_grad()
    pred = classifier(points)
    loss = loss_fn(pred, target)
    loss.backward()
    losses.append(loss.item())
    optimizer.step()
    pred_choice = pred.argmax(dim=1)
    correct = (pred_choice == target).sum()
    # Use the real batch size: the final batch may be smaller than BATCHSIZE.
    accuracy.append(correct.item() / target.size(0))
    # Report every 50 batches and at the end of the epoch (even if it has 1 batch).
    if (i+1) % 50 == 0 or i+1 == batch_num:
      print('[Avg Epoch %d: %d/%d] train avg loss: %f accuracy: %f' % (epoch, i+1, batch_num, sum(losses)/len(losses), sum(accuracy)/len(accuracy)))
  scheduler.step()
  torch.save(classifier.state_dict(), "pointnet_modelnet.pt")