In [3]:
import torch
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime

In [4]:
device = "mps"

In [58]:
num_points = 1024

pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

batch_size = 64

root = 'data/ModelNet10'
dataset_train = ModelNet(root=root, name='10', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='10', train=False, pre_transform=pre_transform, transform=transform)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

data = dataset_train[1001]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 3991
Number of test examples: 908
Data(pos=[1024, 3], y=[1])
Point cloud shape: torch.Size([1024, 3])
Label: tensor([2])


In [6]:
fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:,0], y=data.pos[:,1], z=data.pos[:,2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

In [59]:
class Transformer(nn.Module):
    def __init__(self, num_points, features, identity):
        super(Transformer, self).__init__()

        self.features = features
        self.identity = identity

        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels=features, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.fcl = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(in_features=256, out_features=features*features))
    
    def forward(self, x):

        bs = x.shape[0]

        x = self.mlp(x)
        x = self.max_pool(x).view(bs, -1)
        x = self.fcl(x)
        
        x = x.view(-1, self.features, self.features)

        if self.identity:
            x += torch.eye(self.features, requires_grad=True).repeat(bs, 1, 1).to(device)

        return x


class PointNet(nn.Module):
    def __init__(self, num_points, num_classes):
        super(PointNet, self).__init__()

        self.tnet1 = Transformer(num_points=num_points, features=3, identity=False).to(device)

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))
    
        self.tnet2 = Transformer(num_points=num_points, features=64, identity=True).to(device)

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.classification_head = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=256, out_features=num_classes)
        )
    
    def forward(self, x):

        bs = x.shape[0]

        input_transform = self.tnet1(x)
        
        x = torch.bmm(x.transpose(2, 1), input_transform).transpose(2, 1)
        x = self.mlp1(x)

        feature_transform = self.tnet2(x)

        x = torch.bmm(x.transpose(2, 1), feature_transform).transpose(2, 1)
        x = self.mlp2(x)

        x = self.max_pool(x).view(bs, -1)

        x = self.classification_head(x)

        return x, feature_transform


In [85]:
pointnet = PointNet(num_points=num_points, num_classes=10)
pointnet.to(device)

# rotate object around z-axis
# apply guassian noise to each point, mean 0 std deviation 0.002

PointNet(
  (tnet1): Transformer(
    (mlp): Sequential(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (4): ReLU()
      (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (7): ReLU()
      (8): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (max_pool): MaxPool1d(kernel_size=1024, stride=1024, padding=0, dilation=1, ceil_mode=False)
    (fcl): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=512, out_features=256, bias=True)
      (4): ReLU()
      (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, aff

In [61]:
pointnet.train()

test_data = torch.rand(10, 3, num_points).to(device)
A, output = pointnet(test_data)
print(A.shape)
print(output.shape)

torch.Size([10, 10])
torch.Size([10, 64, 64])


In [86]:
class PointNetClassificationLoss(nn.Module):
    def __init__(self, reg_weight):
        super(PointNetClassificationLoss, self).__init__()

        self.reg_weight = reg_weight
        self.cross_entropy = nn.CrossEntropyLoss()
    
    def forward(self, outputs, labels, A):
        bs = A.shape[0]

        loss = self.cross_entropy(outputs, labels)

        I = torch.eye(64).repeat(bs, 1, 1).view(bs, 64, 64).to(device)

        loss += self.reg_weight*torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))/bs

        return loss

In [87]:
learning_rate = 0.001
num_epochs = 10
momentum = 0.9
reg_weight = 0.001
weight_decay = 0.0

criterion = PointNetClassificationLoss(reg_weight=reg_weight)
optimizer = optim.AdamW(pointnet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [88]:
for epoch in range(num_epochs):

    accuracy = 0
    loss_avg = 0
    count = 0

    pointnet.train()
    for data in test_loader:

        clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)

        optimizer.zero_grad()

        outputs, A = pointnet(clouds)
        loss = criterion(outputs, labels, A)
        loss.backward()
        optimizer.step()

        loss_avg += loss.item()
        count += 1
    
    loss_avg = loss_avg/count
    
    pointnet.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in test_loader:
            
            clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
            clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

            labels = data.y.to(device)
            
            outputs, _ = pointnet(clouds)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct/total
    
    print("{}   [Epoch {:3}]  Loss: {:8.4}  Accuracy:   {:8.4}%".format(datetime.datetime.now(), epoch+1, loss_avg, 100*accuracy))

2025-03-04 01:00:31.283992   [Epoch   1]  Loss:     3.07  Accuracy:      11.01%
2025-03-04 01:00:45.834007   [Epoch   2]  Loss:    2.585  Accuracy:      10.24%
2025-03-04 01:01:00.031465   [Epoch   3]  Loss:    2.478  Accuracy:       10.9%
2025-03-04 01:01:14.233417   [Epoch   4]  Loss:    2.455  Accuracy:      13.99%
2025-03-04 01:01:28.139676   [Epoch   5]  Loss:     2.42  Accuracy:      11.23%
2025-03-04 01:01:42.794101   [Epoch   6]  Loss:    2.392  Accuracy:      11.12%
2025-03-04 01:01:57.300865   [Epoch   7]  Loss:    2.366  Accuracy:      13.88%
2025-03-04 01:02:11.754239   [Epoch   8]  Loss:    2.372  Accuracy:      13.99%
2025-03-04 01:02:25.672786   [Epoch   9]  Loss:    2.385  Accuracy:      12.22%
2025-03-04 01:02:40.473759   [Epoch  10]  Loss:    2.394  Accuracy:      14.32%


In [19]:
pointnet.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for data in test_loader:
        
        clouds = data.pos.resize(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)
        
        logits = pointnet(clouds)

        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy:    {}%'.format(100 * correct / total))

Accuracy:    11.45374449339207%
