In [701]:
import torch
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime

In [710]:
device = "mps"

In [703]:
num_points = 1024

pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

batch_size = 64

root = 'data/ModelNet10'
dataset_train = ModelNet(root=root, name='10', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='10', train=False, pre_transform=pre_transform, transform=transform)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

data = dataset_train[1001]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 3991
Number of test examples: 908
Data(pos=[1024, 3], y=[1])
Point cloud shape: torch.Size([1024, 3])
Label: tensor([2])


In [664]:
fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:,0], y=data.pos[:,1], z=data.pos[:,2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

In [791]:
class Transformer(nn.Module):
    def __init__(self, num_points, features):
        super(Transformer, self).__init__()

        self.features = features

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=features, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128))

        self.mlp3 = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.ll1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512))

        self.ll2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256))
        
        self.ll3 = nn.Linear(256, features*features)
    
    def forward(self, x):

        bs = x.shape[0]

        x = self.mlp1(x)
        x = self.mlp2(x)
        x = self.mlp3(x)

        x = self.max_pool(x).view(bs, -1)

        x = self.ll1(x)
        x = self.ll2(x)

        x = self.ll3(x)

        eye = torch.eye(self.features, requires_grad=True).repeat(bs, 1, 1).to(device)

        x = x.view(-1, self.features, self.features) + eye

        return x

class PointNet(nn.Module):
    def __init__(self, num_points, num_classes):
        super(PointNet, self).__init__()

        self.tnet1 = Transformer(num_points=num_points, features=3).to(device)

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))
    
        self.tnet2 = Transformer(num_points=num_points, features=64).to(device)

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.classification = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=256, out_features=num_classes)
        )
    
    def forward(self, x):

        bs = x.shape[0]

        input_transform = self.tnet1(x)
        
        x = torch.bmm(x.transpose(2, 1), input_transform).transpose(2, 1)
        x = self.mlp1(x)

        feature_transform = self.tnet2(x)
        x = torch.bmm(x.transpose(2, 1), feature_transform).transpose(2, 1)
        x = self.mlp2(x)

        x = self.max_pool(x).view(bs, -1)

        x = self.classification(x)

        return x


In [792]:
pointnet = PointNet(num_points=num_points, num_classes=10)
pointnet.to(device)

PointNet(
  (tnet1): Transformer(
    (mlp1): Sequential(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (mlp2): Sequential(
      (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (mlp3): Sequential(
      (0): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (max_pool): MaxPool1d(kernel_size=1024, stride=1024, padding=0, dilation=1, ceil_mode=False)
    (ll1): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (ll2): Sequential(
      (0): Linear(in_features=512, out_features=

In [793]:
learning_rate = 0.001
weight_decay = 0.001
num_epochs = 10

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(pointnet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [751]:
pointnet.train()

test_data = torch.rand(32, 3, num_points).to(device)
output = pointnet(test_data)
print(output.shape)

torch.Size([32, 10])


In [796]:
pointnet.train()

for epoch in range(num_epochs):
    for data in test_loader:

        clouds = data.pos.resize(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)

        optimizer.zero_grad()

        logits = pointnet(clouds)
        loss = criterion(logits, labels)
        optimizer.step()
    
    print("{}   [Epoch {:3}]  Loss: {:8.4}".format(datetime.datetime.now(), epoch+1, loss.item()))

2025-03-03 17:24:56.265981   [Epoch   1]  Loss:    2.598
2025-03-03 17:25:23.284272   [Epoch   2]  Loss:    2.591
2025-03-03 17:25:39.982744   [Epoch   3]  Loss:    2.477
2025-03-03 17:25:53.081311   [Epoch   4]  Loss:    2.578
2025-03-03 17:26:07.392929   [Epoch   5]  Loss:    2.529
2025-03-03 17:26:37.725503   [Epoch   6]  Loss:    2.606
2025-03-03 17:28:46.702383   [Epoch   7]  Loss:    2.743
2025-03-03 17:29:13.808445   [Epoch   8]  Loss:    2.828
2025-03-03 17:30:08.473483   [Epoch   9]  Loss:    2.547
2025-03-03 17:30:32.106480   [Epoch  10]  Loss:    2.592


In [797]:
pointnet.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for data in test_loader:
        
        clouds = data.pos.resize(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)
        
        logits = pointnet(clouds)

        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy:    {}%'.format(100 * correct / total))

Accuracy:    8.039647577092511%
