In [None]:
# requires torch, torch_geometric, open3d, plotly
# open3d needs Python 3.10; newer Python versions will not work (note: o3d is imported but not used by any cell below)

import torch
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime
import os
import random

In [None]:
# the full PointTransformer implementation lives in the local pointtransformer module and is imported here
from pointtransformer import PointTransformerClassifier

In [None]:
# Compute device: "mps" is PyTorch's Apple-silicon GPU backend; use "cuda" or "cpu" elsewhere.
# If pointtransformer.py also hardcodes a device, keep it in sync with this value.
device = "mps"

# Fix the seeds so data shuffling, point sampling and weight init are repeatable
# across Restart & Run All (covers Python's `random` and torch's default generator).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# ModelNet10 dataset config.
# NOTE: this cell and the ModelNet40 cell below bind the same names; run only the
# one you want (on Restart & Run All, the ModelNet40 cell overwrites these).

num_points = 1024   # points sampled from each mesh surface
batch_size = 64

pre_transform = NormalizeScale()        # applied once, when PyG processes the raw meshes
transform = SamplePoints(num_points)    # applied every time a sample is accessed

root = 'data/ModelNet10'

dataset_train = ModelNet(root=root, name='10', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='10', train=False, pre_transform=pre_transform, transform=transform)
testloader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

# In PyG, ModelNet.raw_file_names is the list of ModelNet10 category folder names,
# which doubles as the label -> class-name lookup table.
classes = list(dataset_test.raw_file_names)
print(classes)

In [None]:
# ModelNet40 dataset config.
# NOTE: this cell rebinds the same names as the ModelNet10 cell above; run only the
# one you want (on Restart & Run All, this cell wins).

num_points = 1024   # points sampled from each mesh surface
batch_size = 64

pre_transform = NormalizeScale()        # applied once, when PyG processes the raw meshes
transform = SamplePoints(num_points)    # applied every time a sample is accessed

root = 'data/ModelNet40'

dataset_train = ModelNet(root=root, name='40', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='40', train=False, pre_transform=pre_transform, transform=transform)
testloader = DataLoader(dataset_test, batch_size=batch_size)

# Label index -> class name, in alphabetical order.
# PyG's raw_file_names only lists the ModelNet10 categories, so the 40 names are spelled out.
classes = [
    "airplane", "bathtub", "bed", "bench", "bookshelf", "bottle", "bowl", "car",
    "chair", "cone", "cup", "curtain", "desk", "door", "dresser", "flower_pot",
    "glass_box", "guitar", "keyboard", "lamp", "laptop", "mantel", "monitor",
    "night_stand", "person", "piano", "plant", "radio", "range_hood", "sink",
    "sofa", "stairs", "stool", "table", "tent", "toilet", "tv_stand", "vase",
    "wardrobe", "xbox",
]

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')
print(classes)

In [None]:
# Sanity check: plot the first element of the *test* split.
# SamplePoints draws a fresh random surface sample each time an item is accessed,
# so the exact points differ between runs unless seeds are fixed.

sample = dataset_test[0]

fig = go.Figure(
  data=[
    go.Scatter3d(
      x=sample.pos[:, 0], y=sample.pos[:, 1], z=sample.pos[:, 2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    title=f"Test sample 0: {classes[sample.y.item()]}",
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

In [None]:
# Create a new PointTransformer classifier.
# The output size follows whichever dataset config cell ran last (10 classes for
# ModelNet10, 40 for ModelNet40). A hardcoded 10 would make CrossEntropyLoss fail
# on ModelNet40 labels >= 10.
pointtransformer = PointTransformerClassifier(num_classes=len(classes))
pointtransformer.to(device)

In [None]:
# Quick sanity check: push one batch of uniform-random clouds through the model.
# eval() + no_grad() avoid building an autograd graph. If the model contains
# BatchNorm/dropout (not visible from this notebook), they also keep this random
# input from updating running statistics. The training loop calls .train() again.
test_data = torch.rand(batch_size, num_points, 3).to(device)

pointtransformer.eval()
with torch.no_grad():
    output = pointtransformer(test_data)

print(output.shape)
assert output.shape[0] == batch_size, "model should return one row of logits per cloud"

In [None]:
# Training hyperparameters.
num_epochs = 50
learning_rate = 0.01
momentum = 0.9
weight_decay = 0.00001
reg_weight = 0.0001  # NOTE(review): not referenced by any cell shown in this notebook

# Pass the named constants through so that editing them above actually takes effect.
optimizer = optim.SGD(pointtransformer.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Optionally resume from a saved checkpoint.
# This is skipped when the file is missing, so Restart & Run All on a fresh checkout
# does not fail here. The checkpoint must come from a model with the same num_classes,
# or load_state_dict (strict by default) raises on the size mismatch.
# torch.load unpickles the file, so only load checkpoints you trust.
checkpoint_path = "pointtransformer_modelnet10/002.pth"
if os.path.exists(checkpoint_path):
    checkpoint_pointnet = torch.load(checkpoint_path, map_location=torch.device(device))
    pointtransformer.load_state_dict(checkpoint_pointnet['model_state_dict'])
    print(f"Loaded weights from {checkpoint_path}")
else:
    print(f"No checkpoint at {checkpoint_path}; keeping freshly initialised weights")

In [None]:
# Training loop: each epoch does one pass over the training set, then evaluates on
# the test set and saves a checkpoint.

# Name the directory after the active dataset (modelnet10 / modelnet40) so runs on
# different datasets do not overwrite each other's checkpoints.
directory = f"./pointtransformer_modelnet{len(classes)}"
os.makedirs(directory, exist_ok=True)

for epoch in range(num_epochs):

    # --- train ---
    pointtransformer.train()
    loss_sum = 0.0
    seen = 0
    for data in trainloader:

        # PyG concatenates a batch's clouds along dim 0. Every cloud has exactly
        # num_points points (SamplePoints), so reshape to (clouds, num_points, 3).
        clouds = data.pos.view(data.num_graphs, num_points, 3).to(device)
        labels = data.y.to(device)

        optimizer.zero_grad()

        outputs = pointtransformer(clouds)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a per-batch mean (default reduction). Weight it by
        # batch size so a smaller final batch does not skew the epoch average.
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)

    loss_avg = loss_sum / seen

    # --- evaluate ---
    pointtransformer.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:

            clouds = data.pos.view(data.num_graphs, num_points, 3).to(device)
            labels = data.y.to(device)

            predicted = pointtransformer(clouds).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    print("{}   [Epoch {:3}]  Loss: {:8.4}  Accuracy:   {:8.4}%".format(datetime.datetime.now(), epoch, loss_avg, 100*accuracy))

    # 'model_state_dict' is the key the load cell reads. The optimizer state and
    # epoch are stored too, so training can be resumed exactly.
    torch.save(
        {
            'model_state_dict': pointtransformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
        },
        os.path.join(directory, "{:03d}.pth".format(epoch)))

In [None]:
# Final test-set accuracy of the current weights.

pointtransformer.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:

        # Batched clouds are concatenated along dim 0; reshape to (clouds, num_points, 3).
        clouds = data.pos.view(data.num_graphs, num_points, 3).to(device)
        labels = data.y.to(device)

        # Predicted class = index of the largest logit in each row.
        predicted = pointtransformer(clouds).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct/total

print("{}   Accuracy:   {:8.4}%".format(datetime.datetime.now(), 100*accuracy))

In [None]:
# Classify one randomly chosen test cloud and plot it.

# randrange excludes its upper bound. randint(0, len) could return len and raise IndexError.
idx = random.randrange(len(dataset_test))
data = dataset_test[idx]
cloud = data.pos.view(1, num_points, 3).to(device)

pointtransformer.eval()
with torch.no_grad():
    output = pointtransformer(cloud)               # one row of logits: (1, num_classes)
probabilities = 100 * F.softmax(output, dim=1)[0]  # per-class confidence, in percent

predicted = output.argmax(dim=1).item()
label = data.y.item()

print('Predicted Class: {}    Certainty: {:8.4}   Actual Class:   {}'.format(classes[predicted], probabilities[predicted].item(), classes[label]))

fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:, 0], y=data.pos[:, 1], z=data.pos[:, 2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()