In [2]:
import torch
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
device = "cuda"

In [6]:
num_points = 512

pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

batch_size = 64

root = 'data/ModelNet10'
dataset_train = ModelNet(root=root, name='10', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='10', train=False, pre_transform=pre_transform, transform=transform)
testloader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

data = dataset_train[1001]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 3991
Number of test examples: 908
Data(pos=[512, 3], y=[1])
Point cloud shape: torch.Size([512, 3])
Label: tensor([2])


In [7]:
fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:,0], y=data.pos[:,1], z=data.pos[:,2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

In [8]:
class Transformer(nn.Module):
    def __init__(self, features, identity):
        super(Transformer, self).__init__()

        self.features = features
        self.identity = identity

        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels=features, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.fcl = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(in_features=256, out_features=features*features))
    
    def forward(self, x):

        bs = x.shape[0]
        num_points = x.shape[2]

        x = self.mlp(x)
        x = F.max_pool1d(x, kernel_size=num_points).view(bs, -1)
        x = self.fcl(x)
        
        x = x.view(-1, self.features, self.features)

        if self.identity:
            x += torch.eye(self.features, requires_grad=True).repeat(bs, 1, 1).to(device)

        return x


class PointNet(nn.Module):
    def __init__(self, num_classes):
        super(PointNet, self).__init__()

        self.tnet1 = Transformer(features=3, identity=False).to(device)

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))
    
        self.tnet2 = Transformer(features=64, identity=True).to(device)

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.classification_head = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=256, out_features=num_classes)
        )
    
    def forward(self, x):

        bs = x.shape[0]
        num_points = x.shape[2]

        input_transform = self.tnet1(x)
        
        x = torch.bmm(x.transpose(2, 1), input_transform).transpose(2, 1)
        x = self.mlp1(x)

        feature_transform = self.tnet2(x)

        x = torch.bmm(x.transpose(2, 1), feature_transform).transpose(2, 1)
        x = self.mlp2(x)

        x = F.max_pool1d(x, kernel_size=num_points).view(bs, -1)

        x = self.classification_head(x)

        return x, feature_transform


In [87]:
pointnet = PointNet(num_classes=10)
pointnet.to(device)

# rotate object around z-axis
# apply guassian noise to each point, mean 0 std deviation 0.002

PointNet(
  (tnet1): Transformer(
    (mlp): Sequential(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (4): ReLU()
      (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (7): ReLU()
      (8): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fcl): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=512, out_features=256, bias=True)
      (4): ReLU()
      (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Linear(in_features=256, out_features=9, bias=True)

In [88]:
pointnet.train()

test_data = torch.rand(10, 3, num_points).to(device)
A, output = pointnet(test_data)
print(A.shape)
print(output.shape)

torch.Size([10, 10])
torch.Size([10, 64, 64])


In [89]:
class PointNetClassificationLoss(nn.Module):
    def __init__(self, reg_weight):
        super(PointNetClassificationLoss, self).__init__()

        self.reg_weight = reg_weight
        self.cross_entropy = nn.CrossEntropyLoss()
    
    def forward(self, outputs, labels, A):
        bs = A.shape[0]

        loss = self.cross_entropy(outputs, labels)

        I = torch.eye(64).repeat(bs, 1, 1).view(bs, 64, 64).to(device)

        loss += self.reg_weight*torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))/bs

        return loss

In [95]:
learning_rate = 0.0001
num_epochs = 10
momentum = 0.9
reg_weight = 0.001
weight_decay = 0.005

criterion = PointNetClassificationLoss(reg_weight=reg_weight)
optimizer = optim.AdamW(pointnet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [96]:
for epoch in range(num_epochs):

    accuracy = 0
    loss_avg = 0
    count = 0

    pointnet.train()
    for data in trainloader:

        clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
        clouds = clouds.transpose(2, 1).to(device)

        labels = data.y.to(device)

        optimizer.zero_grad()

        outputs, A = pointnet(clouds)
        loss = criterion(outputs, labels, A)
        loss.backward()
        optimizer.step()

        loss_avg += loss.item()
        count += 1
    
    loss_avg = loss_avg/count
    
    pointnet.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            
            clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
            clouds = clouds.transpose(2,1).to(device)

            labels = data.y.to(device)
            
            outputs, _ = pointnet(clouds)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct/total
    
    print("{}   [Epoch {:3}]  Loss: {:8.4}  Accuracy:   {:8.4}%".format(datetime.datetime.now(), epoch+1, loss_avg, 100*accuracy))

2025-03-04 03:25:07.324285   [Epoch   1]  Loss:   0.2396  Accuracy:       84.8%


KeyboardInterrupt: 

In [97]:
pointnet.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for data in testloader:
        
        clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
        clouds = clouds.transpose(2, 1).to(device)

        labels = data.y.to(device)
            
        outputs, _ = pointnet(clouds)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy:    {}%'.format(100 * correct / total))

Accuracy:    82.59911894273128%


In [5]:
pointnet(data.pos.view(1, 3, num_points))

NameError: name 'pointnet' is not defined