In [70]:
# requires torch, torch_geometric, open3d, plotly
# open3d needs python 3.10, anything higher will not work

import torch
from torch_geometric.datasets import ModelNet, FAUST
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime
import os
import random

In [71]:
# pointnet fully implemented, consolidated to another file for import
from pointnet import PointNetClassifier, PointNetClassificationLoss

In [72]:
# change to your device, also change device in pointnet.py file
device = "mps"

In [79]:
# ModelNet10: sampling, batching and storage settings
num_points = 1024
batch_size = 64
root = 'data/ModelNet10'

# transforms handed to the dataset constructors below
pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

# training split (shuffled loader)
dataset_train = ModelNet(
    root=root,
    name='10',
    train=True,
    pre_transform=pre_transform,
    transform=transform,
)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# test split (loader keeps dataset order)
dataset_test = ModelNet(
    root=root,
    name='10',
    train=False,
    pre_transform=pre_transform,
    transform=transform,
)
testloader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

# class names are read from the dataset object
classes = dataset_test.raw_file_names
print(classes)

# inspect the first training sample
data = dataset_train[0]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 3991
Number of test examples: 908
['bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
Data(pos=[1024, 3], y=[1])
Point cloud shape: torch.Size([1024, 3])
Label: tensor([0])


In [78]:
# ModelNet40: sampling, batching and storage settings
# NOTE: this cell assigns the same names (dataset_train, trainloader, dataset_test,
# testloader, classes, data) as the ModelNet10 cell, so whichever of the two runs
# last decides which dataset the rest of the notebook uses.
num_points = 1024
batch_size = 64
root = 'data/ModelNet40'

# transforms handed to the dataset constructors below
pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

# training split (shuffled loader)
dataset_train = ModelNet(
    root=root,
    name='40',
    train=True,
    pre_transform=pre_transform,
    transform=transform,
)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# test split (loader keeps dataset order)
dataset_test = ModelNet(
    root=root,
    name='40',
    train=False,
    pre_transform=pre_transform,
    transform=transform,
)
testloader = DataLoader(dataset_test, batch_size=batch_size)

# class names, listed by hand for the 40-class variant
classes = [
    "airplane", "bathtub", "bed", "bench", "bookshelf",
    "bottle", "bowl", "car", "chair", "cone",
    "cup", "curtain", "desk", "door", "dresser",
    "flower_pot", "glass_box", "guitar", "keyboard", "lamp",
    "laptop", "mantel", "monitor", "night_stand", "person",
    "piano", "plant", "radio", "range_hood", "sink",
    "sofa", "stairs", "stool", "table", "tent",
    "toilet", "tv_stand", "vase", "wardrobe", "xbox",
]

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

# inspect the first training sample
data = dataset_train[0]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 9843
Number of test examples: 2468
Data(pos=[1024, 3], y=[1])
Point cloud shape: torch.Size([1024, 3])
Label: tensor([0])


In [80]:
# 3D scatter of the sample currently held in `data` (first training element)
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=data.pos[:, 0],
            y=data.pos[:, 1],
            z=data.pos[:, 2],
            mode='markers',
            marker={'size': 1, 'color': "white"},
        )
    ]
)

# hide all three axes and use the dark theme
fig.update_layout(
    scene={axis_name: {'visible': False} for axis_name in ('xaxis', 'yaxis', 'zaxis')},
    template='plotly_dark',
)

fig.show()

In [81]:
# dodgy but somehow working implementation of the point transformer architecture
# https://arxiv.org/pdf/2012.09164

class TransformerLayer(nn.Module):
    """Scalar dot-product self-attention across the points of a cloud.

    Args:
        features: width of the per-point feature vectors. Input and output of
            ``forward`` are 3-D tensors whose last dimension is ``features``.
    """

    def __init__(self, features):
        super().__init__()

        # query (phi), key (psi) and value (alpha) projections
        self.phi = nn.Linear(in_features=features, out_features=features)
        self.psi = nn.Linear(in_features=features, out_features=features)
        self.alpha = nn.Linear(in_features=features, out_features=features)

        # positional-encoding MLP over the 3 coordinates; it is registered (and
        # therefore part of the state dict) but not applied in forward() yet
        self.theta = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=features, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=features, out_channels=features, kernel_size=1))

    def forward(self, x, p):
        """Attend every point to every other point of the same cloud.

        Args:
            x: point features, shape (batch, points, features).
            p: point coordinates; currently unused because the positional
                encoding below is disabled.

        Returns:
            Tensor with the same shape as ``x``.
        """

        # require the calculation of positional encoding still
        #x += self.theta(p.transpose(2, 1)).transpose(2, 1)

        queries = self.phi(x)
        keys = self.psi(x)
        values = self.alpha(x)

        # scores[b, i, j] = <query_i, key_j>
        scores = torch.bmm(queries, keys.transpose(2, 1))

        # Normalise over the key axis (last dim) so that each output point is a
        # convex combination of the values. The previous dim=1 normalised over
        # the query axis instead, which left the per-point weights unnormalised.
        # NOTE: weights trained with the old normalisation should be re-validated.
        attention = F.softmax(scores, dim=-1)
        x = torch.bmm(attention, values)

        # also, normalization isn't working well? not sure why...
        #x = F.layer_norm(x, normalized_shape=[x.shape[1], x.shape[2]])

        return x

class PointTransformer(nn.Module):
    """Residual block: linear -> self-attention -> linear, plus a skip connection.

    Args:
        channels: width of the per-point features (unchanged by the block).
    """

    def __init__(self, channels):
        super().__init__()

        # Sub-modules are created on the default device; moving the assembled
        # model with .to(...) moves them too, so no global `device` is needed here.
        self.fcl1 = nn.Linear(in_features=channels, out_features=channels)
        self.transformer = TransformerLayer(channels)
        self.fcl2 = nn.Linear(in_features=channels, out_features=channels)

    def forward(self, x, p):
        """Apply the block.

        Args:
            x: point features whose last dimension is ``channels``.
            p: point coordinates, forwarded to the attention layer and
                returned unchanged.

        Returns:
            Tuple ``(features, p)`` where ``features`` has the shape of ``x``.
        """

        residual = x

        x = self.fcl1(x)
        x = self.transformer(x, p)
        x = self.fcl2(x)

        return x + residual, p

class TransitionDown(nn.Module):
    """Keep a random quarter of the points and double the feature width.

    Args:
        features: input feature width; the output width is ``features * 2``.
    """

    def __init__(self, features):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels=features, out_channels=features*2, kernel_size=1),
            nn.BatchNorm1d(features*2),
            nn.ReLU())

    def forward(self, x, p):
        """Downsample the cloud.

        Args:
            x: point features, shape (batch, points, features).
            p: point coordinates, indexed along dim 1 with the same indices as ``x``.

        Returns:
            Tuple ``(x, p)`` with ``points // 4`` points each; ``x`` has
            ``features * 2`` channels.
        """

        num_points = x.shape[1]
        num_keep = num_points // 4

        # Random sampling without replacement: a permutation prefix never picks
        # the same point twice (randint could), and the index is built directly
        # on the input's device. The same subset is used for every cloud in the batch.
        idx = torch.randperm(num_points, device=x.device)[:num_keep]

        # Conv1d/BatchNorm1d expect channels-first, hence the transposes
        x = self.mlp(x[:, idx, :].transpose(2, 1)).transpose(2, 1)
        p = p[:, idx, :]

        return x, p

class PointTransformerBackbone(nn.Module):
    """Feature extractor: five attention blocks separated by four downsampling stages.

    Feature width grows 32 -> 64 -> 128 -> 256 -> 512 while each TransitionDown
    keeps a quarter of the points. The result is max-pooled over the remaining
    points into one 512-dimensional vector per cloud.
    """

    def __init__(self):
        super().__init__()

        # No .to(device) here: the whole model is moved once by its owner, which
        # keeps every sub-module (including init_linear) on the same device and
        # removes the dependency on a notebook-global `device`.
        self.init_linear = nn.Linear(in_features=3, out_features=32)

        self.transformer1 = PointTransformer(channels=32)
        self.down1 = TransitionDown(features=32)
        self.transformer2 = PointTransformer(channels=64)
        self.down2 = TransitionDown(features=64)
        self.transformer3 = PointTransformer(channels=128)
        self.down3 = TransitionDown(features=128)
        self.transformer4 = PointTransformer(channels=256)
        self.down4 = TransitionDown(features=256)
        self.transformer5 = PointTransformer(channels=512)

    def forward(self, x):
        """Encode a batch of clouds.

        Args:
            x: point coordinates, shape (batch, points, 3).

        Returns:
            Tensor of shape (batch, 512).
        """

        # the raw coordinates double as the positions passed alongside the features
        p = x
        x = self.init_linear(x)

        x, p = self.transformer1(x, p)
        x, p = self.down1(x, p)
        x, p = self.transformer2(x, p)
        x, p = self.down2(x, p)
        x, p = self.transformer3(x, p)
        x, p = self.down3(x, p)
        x, p = self.transformer4(x, p)
        x, p = self.down4(x, p)
        x, p = self.transformer5(x, p)

        # global max pool over whatever points are left, independent of the
        # original cloud size (previously a kernel size derived from num_points/256)
        return x.max(dim=1).values

class PointTransformerClassifier(nn.Module):
    """Point-cloud classifier: PointTransformerBackbone followed by an MLP head.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Created on the default device like the head below; call .to(...) on the
        # classifier to move backbone and head together.
        self.backbone = PointTransformerBackbone()

        self.classification_head = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=num_classes))

    def forward(self, x):
        """Classify a batch of clouds.

        Args:
            x: input passed straight to the backbone.

        Returns:
            Unnormalised class scores, shape (batch, num_classes).
        """

        x = self.backbone(x)
        x = self.classification_head(x)

        return x

In [85]:
# build a fresh pointtransformer with 10 output classes and move it to `device`
pointtransformer = PointTransformerClassifier(num_classes=10).to(device)
pointtransformer

PointTransformerClassifier(
  (backbone): PointTransformerBackbone(
    (init_linear): Linear(in_features=3, out_features=32, bias=True)
    (transformer1): PointTransformer(
      (fcl1): Linear(in_features=32, out_features=32, bias=True)
      (transformer): TransformerLayer(
        (phi): Linear(in_features=32, out_features=32, bias=True)
        (psi): Linear(in_features=32, out_features=32, bias=True)
        (alpha): Linear(in_features=32, out_features=32, bias=True)
        (theta): Sequential(
          (0): Conv1d(3, 32, kernel_size=(1,), stride=(1,))
          (1): ReLU()
          (2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
        )
      )
      (fcl2): Linear(in_features=32, out_features=32, bias=True)
    )
    (down1): TransitionDown(
      (mlp): Sequential(
        (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (transformer2): 

In [86]:
# quick sanity check: one random batch through the network, print the output shape
test_data = torch.rand(batch_size, num_points, 3).to(device)

# Run in eval mode without autograd so the random batch neither updates the
# BatchNorm running statistics nor builds a graph; the previous mode is restored.
was_training = pointtransformer.training
pointtransformer.eval()
with torch.no_grad():
    x = pointtransformer(test_data)
pointtransformer.train(was_training)

print(x.shape)

torch.Size([64, 10])


In [87]:
# build a fresh pointnet with 10 output classes and move it to `device`
pointnet = PointNetClassifier(num_classes=10).to(device)
pointnet

PointNetClassifier(
  (backbone): PointNetBackbone(
    (tnet1): Transformer(
      (mlp): Sequential(
        (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
        (1): ReLU()
        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
        (4): ReLU()
        (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
        (7): ReLU()
        (8): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (fcl): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): ReLU()
        (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Linear(in_features=512, out_features=256, bias=True)
        (4): ReLU()
        (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_r

In [88]:
# quick sanity check: pointnet takes channels-first clouds (batch, 3, num_points)
test_data = torch.rand(10, 3, num_points).to(device)

# Run in eval mode without autograd so the random batch neither updates the
# BatchNorm running statistics nor builds a graph; the previous mode is restored.
was_training = pointnet.training
pointnet.eval()
with torch.no_grad():
    output, A = pointnet(test_data)
pointnet.train(was_training)

print(A.shape)
print(output.shape)

torch.Size([10, 64, 64])
torch.Size([10, 10])


In [89]:
# training hyperparameters
num_epochs = 50
learning_rate = 1e-5

# reg_weight used for training pointnet
#reg_weight = 0.0001

# Adam over the pointtransformer parameters
optimizer = optim.Adam(
    pointtransformer.parameters(),
    lr=learning_rate,
)

# scheduler sometimes works, maybe a better one can be used
#scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.00001, max_lr=0.0001, 
#                                              step_size_up=2000, cycle_momentum=False)

# plain cross-entropy for pointtransformer
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# special crossentropy loss defined for pointnet that incorporates reg_weight
#criterion = PointNetClassificationLoss(reg_weight=reg_weight).to(device)

In [91]:
# load pointnet and pointtransformer weights from saved checkpoints
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code; the checkpoints only need to
# provide a 'model_state_dict' entry.
checkpoint_pointnet = torch.load(
    "pointnet_modelnet10.pth",
    map_location=torch.device(device),
    weights_only=True)
pointnet.load_state_dict(checkpoint_pointnet['model_state_dict'])

checkpoint_pointtransformer = torch.load(
    "pointtransformer_modelnet10.pth",
    map_location=torch.device(device),
    weights_only=True)
pointtransformer.load_state_dict(checkpoint_pointtransformer['model_state_dict'])

<All keys matched successfully>

In [None]:
# training loop, currently set up for training pointtransformer
# each epoch: one optimisation pass over trainloader, then accuracy on testloader
#directory = "./pointtransformer_modelnet10"
#os.makedirs(directory, exist_ok=True)

for epoch in range(num_epochs):

    loss_avg = 0
    count = 0

    pointtransformer.train()
    for data in trainloader:

        # data.pos holds the points of all clouds in the batch stacked together;
        # every cloud has num_points points, so -1 recovers the batch size
        clouds = data.pos.view(-1, num_points, 3).to(device)
        # pointnet requires transposed data, pointtransformer does not
        #clouds = clouds.transpose(2, 1).to(device)

        labels = data.y.to(device)

        optimizer.zero_grad()

        outputs = pointtransformer(clouds)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        #scheduler.step()

        loss_avg += loss.item()
        count += 1

    # mean of the per-batch losses
    loss_avg = loss_avg/count

    pointtransformer.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:

            clouds = data.pos.view(-1, num_points, 3).to(device)
            # pointnet requires transposed data, pointtransformer does not
            #clouds = clouds.transpose(2,1).to(device)

            labels = data.y.to(device)

            outputs = pointtransformer(clouds)

            # index of the highest score per sample
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct/total

    print("{}   [Epoch {:3}]  Loss: {:8.4}  Accuracy:   {:8.4}%".format(datetime.datetime.now(), epoch, loss_avg, 100*accuracy))

    # optional per-epoch checkpoint of the model being trained; enable together
    # with the `directory` lines at the top of this cell
    #torch.save(
    #    {'model_state_dict': pointtransformer.state_dict()},
    #    directory + "/epoch_" + str(epoch) + ".pth")

In [93]:
# accuracy of pointnet over the whole test loader
pointnet.eval()
with torch.no_grad():
    correct, total = 0, 0
    for data in testloader:

        # regroup the stacked points per cloud, then go channels-first for pointnet
        clouds = data.pos.view(data.batch[-1]+1, num_points, 3).to(device)
        clouds = clouds.transpose(2, 1)

        labels = data.y.to(device)

        # second return value of pointnet is not needed here
        outputs, _ = pointnet(clouds)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

    accuracy = correct/total

    print("{}   Accuracy:   {:8.4}%".format(datetime.datetime.now(), 100*accuracy))

2025-03-11 01:55:22.046050   Accuracy:      90.86%


In [96]:
# classify one random test cloud with pointnet and plot it
# randrange excludes the upper bound; randint(0, len(...)) could return
# len(dataset_test) and index one past the end of the dataset
idx = random.randrange(len(dataset_test))

data = dataset_test[idx]

cloud = data.pos.view(1, num_points, 3).to(device)
# pointnet requires transposed data, pointtransformer does not
cloud = cloud.transpose(2, 1)

# single-sample batch: make sure BatchNorm uses its running statistics and
# skip gradient tracking instead of relying on an earlier cell having called eval()
pointnet.eval()
with torch.no_grad():
    output, _ = pointnet(cloud)

# class probabilities in percent for the one sample in the batch
probabilities = 100*F.softmax(output, dim=1)[0]

predicted = output.argmax(dim=1)
label = data.y

print('Predicted Class: {}    Certainty: {:8.4}   Actual Class:   {}'.format(classes[predicted.item()], probabilities[predicted.item()].item(), classes[label.item()]))

fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:,0], y=data.pos[:,1], z=data.pos[:,2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

Predicted Class: table    Certainty:    71.37   Actual Class:   desk
