In [1]:
import torch
from torch_geometric.datasets import ModelNet
from torch_geometric.transforms import SamplePoints, NormalizeScale
from torch_geometric.loader import DataLoader

import open3d as o3d
import plotly.graph_objects as go

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
device = "cuda"

In [3]:
num_points = 2048

pre_transform = NormalizeScale()
transform = SamplePoints(num_points)

batch_size = 64

root = 'data/ModelNet10'
dataset_train = ModelNet(root=root, name='10', train=True, pre_transform=pre_transform, transform=transform)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_test = ModelNet(root=root, name='10', train=False, pre_transform=pre_transform, transform=transform)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

print(f'Number of training examples: {len(dataset_train)}')
print(f'Number of test examples: {len(dataset_test)}')

data = dataset_train[1001]
print(data)
print(f'Point cloud shape: {data.pos.shape}')
print(f'Label: {data.y}')

Number of training examples: 3991
Number of test examples: 908
Data(pos=[2048, 3], y=[1])
Point cloud shape: torch.Size([2048, 3])
Label: tensor([2])


In [4]:
fig = go.Figure(
  data=[
    go.Scatter3d(
      x=data.pos[:,0], y=data.pos[:,1], z=data.pos[:,2],
      mode='markers',
      marker=dict(size=1, color="white"))],
  layout=dict(
    scene=dict(
      xaxis=dict(visible=False),
      yaxis=dict(visible=False),
      zaxis=dict(visible=False))))

fig.update_layout(template='plotly_dark')

fig.show()

In [15]:
class Transformer(nn.Module):
    def __init__(self, num_points, features):
        super(Transformer, self).__init__()

        self.features = features

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=features, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128))

        self.mlp3 = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.ll1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512))

        self.ll2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256))
        
        self.ll3 = nn.Linear(256, features*features)
    
    def forward(self, x):

        bs = x.shape[0]

        x = self.mlp1(x)
        x = self.mlp2(x)
        x = self.mlp3(x)

        x = self.max_pool(x).view(bs, -1)

        x = self.ll1(x)
        x = self.ll2(x)

        x = self.ll3(x)

        eye = torch.eye(self.features, requires_grad=True).repeat(bs, 1, 1).to(device)

        x = x.view(-1, self.features, self.features) + eye

        print(x)

        return x

class PointNet(nn.Module):
    def __init__(self, num_points, num_classes):
        super(PointNet, self).__init__()

        self.tnet1 = Transformer(num_points=num_points, features=3).to(device)

        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64))
    
        self.tnet2 = Transformer(num_points=num_points, features=64).to(device)

        self.mlp2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(1024))

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        self.classification = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            #nn.Dropout(p=0.3),
            nn.Linear(in_features=256, out_features=num_classes)
        )
    
    def forward(self, x):

        bs = x.shape[0]

        input_transform = self.tnet1(x)
        
        x = torch.bmm(x.transpose(2, 1), input_transform).transpose(2, 1)
        x = self.mlp1(x)

        feature_transform = self.tnet2(x)

        x = torch.bmm(x.transpose(2, 1), feature_transform).transpose(2, 1)
        x = self.mlp2(x)

        x = self.max_pool(x).view(bs, -1)

        x = self.classification(x)

        return x


In [16]:
pointnet = PointNet(num_points=num_points, num_classes=10)
pointnet.to(device)

PointNet(
  (tnet1): Transformer(
    (mlp1): Sequential(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (mlp2): Sequential(
      (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (mlp3): Sequential(
      (0): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (max_pool): MaxPool1d(kernel_size=2048, stride=2048, padding=0, dilation=1, ceil_mode=False)
    (ll1): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (ll2): Sequential(
      (0): Linear(in_features=512, out_features=

In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetLoss(nn.Module):
    def __init__(self, alpha=None, gamma=0, reg_weight=0, size_average=True):
        super(PointNetLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reg_weight = reg_weight
        self.size_average = size_average

        # sanitize inputs
        if isinstance(alpha,(float, int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,(list, np.ndarray)): self.alpha = torch.Tensor(alpha)

        # get Balanced Cross Entropy Loss
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=self.alpha)

    def forward(self, predictions, targets, A):

        # get batch size
        bs = predictions.size(0)

        # get Balanced Cross Entropy Loss
        ce_loss = self.cross_entropy_loss(predictions, targets)

        # reformat predictions and targets (segmentation only)
        if len(predictions.shape) > 2:
            predictions = predictions.transpose(1, 2) # (b, c, n) -> (b, n, c)
            predictions = predictions.contiguous() \
                                     .view(-1, predictions.size(2)) # (b, n, c) -> (b*n, c)

        # get predicted class probabilities for the true class
        pn = F.softmax(predictions)
        pn = pn.gather(1, targets.view(-1, 1)).view(-1)

        # get regularization term
        if self.reg_weight > 0:
            I = torch.eye(64).unsqueeze(0).repeat(A.shape[0], 1, 1) # .to(device)
            if A.is_cuda: I = I.cuda()
            reg = torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))
            reg = self.reg_weight*reg/bs
        else:
            reg = 0

        # compute loss (negative sign is included in ce_loss)
        loss = ((1 - pn)**self.gamma * ce_loss)
        if self.size_average: return loss.mean() + reg
        else: return loss.sum() + reg

In [20]:
learning_rate = 0.01
weight_decay = 0.001
num_epochs = 10
momentum = 0.9

criterion = PointNetLoss()
optimizer = optim.SGD(pointnet.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)

In [21]:
pointnet.train()

test_data = torch.rand(32, 3, num_points).to(device)
output = pointnet(test_data)
print(output.shape)

tensor([[[-3.2453e-01,  2.1320e-01, -8.9657e-01],
         [ 8.5364e-01,  1.1810e+00, -4.8812e-01],
         [ 6.4396e-01, -8.6291e-01,  1.8614e+00]],

        [[ 7.0053e-01,  1.7361e-01,  3.1551e-01],
         [ 5.3966e-03, -2.0258e-01,  2.2178e-01],
         [-2.3244e-01,  4.9471e-01,  1.5581e+00]],

        [[ 1.5127e+00,  1.8806e-01,  3.1942e-03],
         [-3.3128e-01,  1.8441e+00,  6.2730e-01],
         [ 5.1252e-01, -2.1749e-02,  4.8961e-01]],

        [[ 1.0923e+00,  2.2026e-01, -1.2789e-01],
         [ 4.7465e-01,  1.7396e+00, -1.6729e-01],
         [-4.9962e-01, -6.7512e-02,  9.0775e-01]],

        [[ 1.2147e+00, -2.3655e-01,  1.8408e-01],
         [ 4.5251e-01,  7.6995e-01, -4.9652e-01],
         [-9.0555e-02, -4.0941e-01,  9.7235e-01]],

        [[ 1.2454e+00, -2.6460e-01, -1.5269e-01],
         [ 2.5864e-01,  1.5695e+00,  6.5494e-01],
         [ 3.6819e-01, -7.6436e-01,  7.8043e-01]],

        [[ 1.3150e-01,  2.8598e-01,  5.5346e-01],
         [-7.0264e-01,  2.3203e-01, -7

In [22]:

for epoch in range(2*num_epochs):

    accuracy = 0
    loss_avg = 0
    count = 0

    pointnet.train()
    for data in test_loader:

        clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)

        optimizer.zero_grad()

        logits = pointnet(clouds)
        loss = criterion(logits, labels)
        optimizer.step()

        loss_avg += loss.item()
        count += 1
    
    loss_avg = loss_avg/count
    
    pointnet.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in test_loader:
            
            clouds = data.pos.view(data.batch[-1]+1, num_points, 3)
            clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

            labels = data.y.to(device)
            
            logits = pointnet(clouds)

            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct/total
    
    print("{}   [Epoch {:3}]  Loss: {:8.4}  Accuracy:   {:8.4}%".format(datetime.datetime.now(), epoch+1, loss_avg, 100*accuracy))

tensor([[[ 1.6270e+00,  4.0374e-01, -5.5226e-01],
         [ 5.9350e-01,  1.2622e+00,  2.0541e-01],
         [-7.9745e-02,  2.2768e-01,  7.0685e-01]],

        [[ 1.0002e+00, -4.9960e-01, -1.3954e-01],
         [-7.6615e-02,  8.8340e-01, -9.2814e-01],
         [ 6.8637e-02,  7.5411e-02,  3.7312e-01]],

        [[ 1.1933e+00, -1.7367e-01, -6.1312e-01],
         [-1.0280e-01,  1.2919e+00, -5.2838e-02],
         [-6.3746e-03, -1.0190e-01,  3.6018e-01]],

        [[ 1.6491e+00,  2.6571e-01, -1.2288e-01],
         [-2.3176e-02,  5.1634e-01,  8.8678e-01],
         [-1.0918e+00,  4.9649e-01,  1.4230e+00]],

        [[ 1.6021e+00, -1.5966e-01, -2.2475e-01],
         [ 3.3448e-01,  1.2885e+00,  5.0692e-01],
         [-5.1645e-01, -1.5417e+00,  7.9385e-01]],

        [[ 1.5159e+00,  7.5126e-01,  6.1435e-01],
         [ 2.0660e-01,  1.3062e+00, -3.8195e-01],
         [-3.2498e-01,  2.1877e-01,  1.4871e+00]],

        [[ 7.3406e-01, -7.1309e-01,  5.5707e-01],
         [-1.3237e+00,  1.5125e+00,  1

TypeError: PointNetLoss.forward() missing 1 required positional argument: 'A'

In [19]:
pointnet.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for data in test_loader:
        
        clouds = data.pos.resize(data.batch[-1]+1, num_points, 3)
        clouds = clouds.view(-1, clouds.shape[2], clouds.shape[1]).to(device)

        labels = data.y.to(device)
        
        logits = pointnet(clouds)

        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy:    {}%'.format(100 * correct / total))

Accuracy:    11.45374449339207%
