In [1]:
from pathlib import Path
from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

  return torch._C._show_config()


## dataset loading

In [2]:
dataset_dir = '../dataset/modelnet/'

# Preprocessing pipeline passed as pre_transform: sample 1024 surface points per mesh
# (faces dropped, per-point normals kept), then apply NormalizeScale.
point_sampler = T.SamplePoints(1024, remove_faces=True, include_normals=True)
pre_transform = T.Compose([point_sampler, T.NormalizeScale()])

In [3]:
# Both splits share the same root and preprocessing so their point clouds match.
dataset_kwargs = dict(pre_transform=pre_transform, pre_filter=None)
train_dataset = ModelNet(dataset_dir, train=True, **dataset_kwargs)
test_dataset = ModelNet(dataset_dir, train=False, **dataset_kwargs)

In [4]:
from torch_geometric.loader import DataLoader
# Unshuffled loader over the training split; used below to grab a fixed sample batch.
dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=False)

## pointnet

In [5]:
from torch_geometric.nn import global_max_pool
import torch.nn as nn

In [6]:
class SymmFunction(nn.Module):
    """Permutation-invariant global feature extractor for batched point clouds.

    A shared per-point MLP lifts each xyz coordinate to a 512-d feature; a per-cloud
    max pool then collapses all points of a cloud into a single vector.

    forward(batch):
        batch.pos:   [total_points, 3] coordinates (e.g. 32 clouds x 1024 points).
        batch.batch: [total_points] index of the cloud each point belongs to.
        returns      [num_clouds, 512].
    """

    def __init__(self):
        super().__init__()
        hidden_layers = []
        for d_in, d_out in [(3, 64), (64, 128)]:
            hidden_layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU()]
        # The last projection goes straight into the pooling (no norm/activation).
        self.shared_mlp = nn.Sequential(*hidden_layers, nn.Linear(128, 512))

    def forward(self, batch):
        per_point = self.shared_mlp(batch.pos)           # [total_points, 512]
        return global_max_pool(per_point, batch.batch)   # [num_clouds, 512]

In [7]:
# Fresh SymmFunction instance for a quick shape check on a sample batch.
f = SymmFunction()

In [8]:
# First mini-batch of the (unshuffled) sample loader.
batch_iter = iter(dataloader)
batch = next(batch_iter)

In [9]:
print(batch)      # collated mini-batch summary (attribute shapes)
y = f(batch)      # one pooled feature vector per point cloud
print(y.shape)

DataBatch(pos=[32768, 3], y=[32], normal=[32768, 3], batch=[32768], ptr=[33])
torch.Size([32, 512])


In [10]:
import torch

In [11]:
class InputTNet(nn.Module):
    """Predicts one 3x3 transform per point cloud for normalising the input coordinates.

    The result is ``I + D`` where ``D`` is an unconstrained 3x3 matrix regressed from a
    max-pooled global feature. Nothing constrains it to be a rotation (orthogonal).

    forward(x, batch):
        x:     [total_points, 3] point coordinates.
        batch: [total_points] index of the point cloud each point belongs to.
        returns [num_clouds, 3, 3], with the same device and dtype as the MLP output.

    NOTE(review): the final Linear uses PyTorch's default (random) init, so the initial
    transform is identity plus a random offset. Consider zero-initialising it if early
    training is unstable.
    """

    def __init__(self):
        super().__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )

        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 9),
        )

    def forward(self, x, batch):
        x = self.input_mlp(x)          # [total_points, 1024]
        x = global_max_pool(x, batch)  # [num_clouds, 1024]
        x = self.output_mlp(x)         # [num_clouds, 9]
        x = x.view(-1, 3, 3)           # [num_clouds, 3, 3]
        # Residual around the identity. Build the identity directly on x's device/dtype and
        # let broadcasting add it to every cloud. This avoids a CPU allocation, a host-to-device
        # copy and a per-cloud .repeat() on every forward pass.
        x = x + torch.eye(3, device=x.device, dtype=x.dtype)
        return x

In [12]:
# Sanity check: torch.eye(3) is an unbatched 3x3 identity.
torch.eye(3).size()

torch.Size([3, 3])

In [13]:
# Smoke test: the input T-Net maps the sample batch to one 3x3 matrix per point cloud.
f = InputTNet()
y = f(x=batch.pos, batch=batch.batch)

In [14]:
class FeatureTNet(nn.Module):
    """Predicts one 64x64 transform per point cloud from 64-d per-point features.

    The result is ``I + D`` where ``D`` is an unconstrained 64x64 matrix regressed from a
    max-pooled global feature.

    forward(x, batch):
        x:     [total_points, 64] per-point features.
        batch: [total_points] index of the point cloud each point belongs to.
        returns [num_clouds, 64, 64], with the same device and dtype as the MLP output.

    NOTE(review): the final Linear uses PyTorch's default (random) init, so the initial
    transform is identity plus a random offset.
    """

    def __init__(self):
        super().__init__()
        self.input_mlp = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.output_mlp = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 64 * 64),
        )

    def forward(self, x, batch):
        x = self.input_mlp(x)          # [total_points, 1024]
        x = global_max_pool(x, batch)  # [num_clouds, 1024]
        x = self.output_mlp(x)         # [num_clouds, 4096]
        x = x.view(-1, 64, 64)         # [num_clouds, 64, 64]
        # Residual around the identity, built on x's device/dtype and broadcast over clouds
        # (no CPU allocation, host-to-device copy or per-cloud .repeat()).
        x = x + torch.eye(64, device=x.device, dtype=x.dtype)
        return x

In [15]:
class PointNetClassification(nn.Module):
    """PointNet-style classifier for batched point clouds.

    Pipeline:
        input T-Net (3x3) -> shared MLP 3->64->64 -> feature T-Net (64x64)
        -> shared MLP 64->64->128->1024 -> per-cloud max pool
        -> classifier MLP 1024->512->256->num_classes (with dropout 0.3).

    Args:
        num_classes: number of logits produced per point cloud (default 10).

    forward(batch_data):
        batch_data.pos:   [total_points, 3] coordinates.
        batch_data.batch: [total_points] index of the cloud each point belongs to.
        returns (logits [num_clouds, num_classes],
                 input_transform [num_clouds, 3, 3],
                 feature_transform [num_clouds, 64, 64]).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.input_tnet = InputTNet()
        self.mlp1 = nn.Sequential(
            nn.Linear(3, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
        )

        self.feature_tnet = FeatureTNet()
        self.mlp2 = nn.Sequential(
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
        )
        self.mlp3 = nn.Sequential(
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, batch_data):
        x = batch_data.pos
        batch = batch_data.batch  # cloud index of every point

        # One 3x3 matrix per cloud, gathered per point so each point is multiplied by its
        # own cloud's matrix.
        input_transform = self.input_tnet(x, batch)   # [num_clouds, 3, 3]
        transform = input_transform[batch]             # [total_points, 3, 3]
        x = torch.bmm(transform, x.view(-1, 3, 1)).view(-1, 3)

        x = self.mlp1(x)                               # [total_points, 64]

        feature_transform = self.feature_tnet(x, batch)  # [num_clouds, 64, 64]
        # NOTE(review): this gather materialises a [total_points, 64, 64] tensor, roughly
        # 512 MiB in float32 for 32 clouds x 1024 points. Worth revisiting if memory is tight.
        transform = feature_transform[batch]
        x = torch.bmm(transform, x.view(-1, 64, 1)).view(-1, 64)

        x = self.mlp2(x)                               # [total_points, 1024]
        x = global_max_pool(x, batch)                  # [num_clouds, 1024]
        x = self.mlp3(x)                               # [num_clouds, num_classes]

        return x, input_transform, feature_transform

## tensorboard

In [16]:
from tensorboardX import SummaryWriter

In [17]:
num_epoch = 400
batch_size = 32

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PointNetClassification().to(device)

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate every num_epoch // 4 scheduler steps.
# NOTE(review): 'scheculer' is misspelled; the name is kept in case later cells reference it.
scheculer = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=num_epoch // 4,
    gamma=0.5,
)

In [20]:
import os

# Output directory for the SummaryWriter below; created if it does not exist yet.
log_dir = './log_modelnet10_classification'
os.makedirs(name=log_dir, exist_ok=True)

In [22]:
# TensorBoard writer; event files go under log_dir.
writer = SummaryWriter(log_dir=log_dir)

In [23]:
# Same batch size for both splits; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [24]:
# Classification loss on raw logits, averaged over the mini-batch.
criteria = torch.nn.CrossEntropyLoss(reduction='mean')

## learning loop

In [25]:
from tqdm import tqdm

In [None]:
for epoch in range(num_epoch):
    model = model.train()

    losses = []
    for batch_data in tqdm(train_dataloader, total=len(train_dataloader)):
        batch_data = batch_data.to(device)
        # Number of point clouds in this mini-batch. The last batch of an epoch can be
        # smaller than batch_size, so count it instead of assuming batch_size.
        this_batch_size = batch_data.num_graphs

        # The model returns (logits, input_transform, feature_transform); the input
        # transform is not used here.
        pred_y, _, feature_transform = model(batch_data)
        true_y = batch_data.y.detach()

        class_loss = criteria(pred_y, true_y)
        accuracy = float((pred_y.argmax(dim=1) == true_y).sum()) / float(this_batch_size)

        # Per-sample Frobenius norm of (A A^T - I) for each feature transform A, i.e. how far
        # each A is from orthogonal. The identity is created directly on A's device/dtype.
        feat_dim = feature_transform.shape[1]
        id_matrix = torch.eye(
            feat_dim, device=feature_transform.device, dtype=feature_transform.dtype
        ).expand(feature_transform.shape[0], feat_dim, feat_dim)
        transform_norm = torch.norm(
            torch.bmm(feature_transform, feature_transform.transpose(1, 2)) - id_matrix,
            dim=(1, 2),
        )

        # Start from L.21