In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data

from module.dataset import ModelNet40
from module.utils import *

import os, sys
from collections import OrderedDict

# Compute device shared by the notebook: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
class BatchNorm(nn.Module):
    '''
        Layout-tolerant wrapper around nn.BatchNorm1d.
        Input: A tensor of size (N, M, feature_dim), (N, feature_dim, M) or (N, feature_dim).
               A 3-D input whose last axis equals feature_dim is always treated as
               (N, M, feature_dim), so the (N, feature_dim, M) layout is only
               reachable when M != feature_dim.
        Output: A tensor of the same size as input.
    '''
    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.batchnorm = nn.BatchNorm1d(feature_dim)
        # Swaps the last two axes; applying it twice restores the input layout.
        self.permute = Permute((0, 2, 1))

    def forward(self, x):
        features_last = x.dim() == 3 and x.size(-1) == self.feature_dim
        if not features_last:
            # Features already sit on axis 1 (2-D input or (N, feature_dim, M)).
            return self.batchnorm(x)
        # Move features to axis 1 for BatchNorm1d, then move them back.
        features_first = self.permute(x)
        normalized = self.batchnorm(features_first)
        return self.permute(normalized)

In [23]:
class Permute(nn.Module):
    '''
        Module form of Tensor.permute, so an axis reordering can be stored as a
        submodule or placed inside nn.Sequential.
        Input: A tensor with as many axes as there are entries in `param`.
        Output: The input with its axes reordered according to `param`.
    '''
    def __init__(self, param):
        super(Permute, self).__init__()
        # Target axis order, e.g. (0, 2, 1) swaps the last two axes of a 3-D tensor.
        self.param = param

    def forward(self, x):
        return x.permute(self.param)

    def extra_repr(self):
        # Show the axis order in print(model) instead of a bare "Permute()".
        return "dims={!r}".format(self.param)

In [24]:
class MLP(nn.Module):
    '''
        Stack of Linear layers built from consecutive pairs of `hidden_size`.
        Every Linear except the last is followed by (optional) BatchNorm and an
        in-place ReLU; the last Linear gets the same tail only when
        `last_activation` is true.
        Submodules are named "Linear_i", "Batchnorm_i" and "ReLU_i".
    '''
    def __init__(self, hidden_size, batchnorm = True, last_activation = True):
        super().__init__()
        final_idx = len(hidden_size) - 2
        layers = OrderedDict()
        for idx, (in_dim, out_dim) in enumerate(zip(hidden_size[:-1], hidden_size[1:])):
            layers["Linear_%d" % idx] = nn.Linear(in_dim, out_dim)
            # idx never exceeds final_idx, so this covers "hidden layer" and
            # "last layer with activation requested".
            if idx < final_idx or last_activation:
                if batchnorm:
                    layers["Batchnorm_%d" % idx] = BatchNorm(out_dim)
                layers["ReLU_%d" % idx] = nn.ReLU(inplace=True)
        self.mlp = nn.Sequential(layers)

    def forward(self, x):
        return self.mlp(x)

In [25]:
class MaxPooling(nn.Module):
    '''
        Max-reduce a tensor along one axis and return only the values
        (the argmax indices are discarded).
        By default the reduction runs over axis 1 and that axis is dropped.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x, dim=1, keepdim = False):
        values_and_indices = x.max(dim=dim, keepdim=keepdim)
        return values_and_indices[0]

In [26]:
class TNet(nn.Module):
    '''
        Regress one (nfeat x nfeat) matrix per sample.
        Pipeline: per-point MLP (nfeat -> 64 -> 128 -> 1024), max over axis 1,
        BatchNorm, then MLP (1024 -> 512 -> 256 -> nfeat*nfeat), reshaped to
        (batch, nfeat, nfeat).
        Input: A tensor whose last axis has size nfeat; PointNet feeds
               (batch, num_points, nfeat).
        Output: A tensor of size (batch, nfeat, nfeat).

        NOTE(review): the final MLP keeps the default last_activation=True, so
        the regressed matrix passes through BatchNorm + ReLU and all of its
        entries are non-negative. Confirm this is intended.
    '''
    def __init__(self, nfeat):
        super().__init__()
        self.nfeat = nfeat
        # Stages are built in pipeline order, which is also their order in
        # the Sequential.
        point_encoder = MLP((nfeat, 64, 128, 1024))
        pool = MaxPooling()
        global_norm = BatchNorm(1024)
        matrix_head = MLP((1024, 512, 256, nfeat*nfeat))
        self.tnet = nn.Sequential(point_encoder, pool, global_norm, matrix_head)

    def forward(self, x):
        n_samples = x.size(0)
        flat_matrix = self.tnet(x)
        return flat_matrix.view(n_samples, self.nfeat, self.nfeat)

In [27]:
class PointNet(nn.Module):
    '''
        Point-cloud classifier with a learned input transform and a learned
        64-d feature transform.
        Input: A tensor of size (batch, num_points, nfeat).
        Output: In training mode, a tuple (log_probs, L_reg) where log_probs has
                size (batch, nclass) and L_reg is the batch-averaged squared
                Frobenius distance between A @ A^T and the identity, A being the
                64 x 64 feature transform.
                In eval mode, only log_probs.
        Args:
            nfeat:   number of input features per point.
            nclass:  number of output classes.
            dropout: dropout probability applied before the final Linear layer.
    '''
    def __init__(self, nfeat, nclass, dropout = 0):
        super(PointNet, self).__init__()

        self.input_transform = TNet(nfeat)
        # Normalise over the actual input width; this was hard-coded to 3,
        # which broke any nfeat != 3.
        self.mlp1 = nn.Sequential(BatchNorm(nfeat), MLP((nfeat, 64, 64)))
        self.feature_transform = TNet(64)
        self.mlp2 = nn.Sequential(BatchNorm(64), MLP((64, 64, 128, 1024)))
        self.maxpooling = MaxPooling()
        self.mlp3 = nn.Sequential(BatchNorm(1024), MLP((1024, 512, 256)), nn.Dropout(dropout), nn.Linear(256, nclass))

    def forward(self, xs):
        batch_size = xs.shape[0]

        # Align the raw points: (B, M, nfeat) @ (B, nfeat, nfeat).
        transform = self.input_transform(xs)
        xs = torch.bmm(xs, transform)
        xs = self.mlp1(xs)

        # Align the per-point features: (B, M, 64) @ (B, 64, 64).
        transform = self.feature_transform(xs)
        xs = torch.bmm(xs, transform)
        xs = self.mlp2(xs)

        # Max over the point axis gives one global descriptor per sample.
        xs = self.mlp3(self.maxpooling(xs))
        log_probs = F.log_softmax(xs, dim=1)

        if not self.training:
            return log_probs

        # Orthogonality penalty on the feature transform. The identity is built
        # on the transform's own device/dtype so the module keeps working after
        # net.to(...) and does not depend on a notebook-global `device`.
        gram = torch.bmm(transform, transform.transpose(1, 2))
        identity = torch.eye(transform.size(-1), device=transform.device, dtype=transform.dtype)
        L_reg = ((gram - identity) ** 2).sum() / batch_size

        return log_probs, L_reg

In [33]:
# Hyper-parameters shared by the cells below.
lr = 0.001                 # Adam learning rate; also re-applied when resuming from a checkpoint
num_points = 128           # forwarded to ModelNet40
save_name = "PointNet.pt"  # checkpoint path used for resuming and by EarlyStop
#batch_size = 512
batch_size = 16            # DataLoader batch size; also scales max_epochs in the training cell

########### loading data ###########

train_data = ModelNet40(num_points)
test_data = ModelNet40(num_points, 'test')

# Hold out 10% of the training set for validation.
# NOTE(review): no seed is set, so the split differs on every run; when training
# resumes from a saved checkpoint, validation samples may have been training
# samples in the earlier run.
train_size = int(0.9 * len(train_data))
valid_size = len(train_data) - train_size
train_data, valid_data = Data.random_split(train_data, [train_size, valid_size])
# NOTE(review): random_split returns Subset wrappers, so these assignments set an
# attribute on the wrappers, not on the underlying ModelNet40 object. Verify that
# whatever reads `partition` actually sees these values.
valid_data.partition = 'valid'
train_data.partition = 'train'

print("train data size: ", len(train_data))
print("valid data size: ", len(valid_data))
print("test data size: ", len(test_data))

def collate_fn(batch):
    """Collate (sample, label) pairs into a stacked tensor and a label vector.

    Args:
        batch: sequence of (X, Y) pairs; the X tensors must share one shape.

    Returns:
        (Xs, Ys): Xs stacks the X tensors along a new leading axis; Ys is a
        1-D torch.long tensor of the labels in the same order.
    """
    samples, labels = [], []
    for sample, label in batch:
        samples.append(sample)
        labels.append(label)
    stacked = torch.stack(samples)
    targets = torch.tensor(labels, dtype = torch.long)
    return stacked, targets

# Batch iterators for the three splits. Only the training loader shuffles;
# validation and test keep their dataset order.
train_iter  = Data.DataLoader(train_data, shuffle = True, batch_size = batch_size, collate_fn = collate_fn)
valid_iter = Data.DataLoader(valid_data, batch_size = batch_size, collate_fn = collate_fn)
test_iter = Data.DataLoader(test_data, batch_size = batch_size, collate_fn = collate_fn)

train data size:  8856
valid data size:  984
test data size:  2468


In [34]:
############### loading model ####################

# Select the device again so this cell does not rely on the import cell's
# `device`, then build the classifier directly on it (Module.to returns the
# module itself).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = PointNet(nfeat=3, nclass=40, dropout=0.3).to(device)
print(net)

PointNet(
  (input_transform): TNet(
    (tnet): Sequential(
      (0): MLP(
        (mlp): Sequential(
          (Linear_0): Linear(in_features=3, out_features=64, bias=True)
          (Batchnorm_0): BatchNorm(
            (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (permute): Permute()
          )
          (ReLU_0): ReLU(inplace=True)
          (Linear_1): Linear(in_features=64, out_features=128, bias=True)
          (Batchnorm_1): BatchNorm(
            (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (permute): Permute()
          )
          (ReLU_1): ReLU(inplace=True)
          (Linear_2): Linear(in_features=128, out_features=1024, bias=True)
          (Batchnorm_2): BatchNorm(
            (batchnorm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (permute): Permute()
          )
          (ReLU_2): ReLU(inplace=Tr

In [35]:
############### training #########################

# Adam over the trainable parameters only, with L2 weight decay.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay = 0.0001)
# NLLLoss expects log-probabilities, which is what PointNet.forward emits via
# log_softmax.
# NOTE(review): in training mode PointNet returns (log_probs, L_reg); how
# train_model combines the two is not visible in this notebook.
loss = nn.NLLLoss()

def adjust_lr(optimizer, decay_rate=0.95):
    """Multiply the learning rate of every parameter group by `decay_rate`.

    Args:
        optimizer: object exposing `param_groups`, a list of dicts with an 'lr' key.
        decay_rate: multiplicative factor applied to each group's 'lr' in place.
    """
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate

# NOTE(review): `retrain` is assigned here but never read in the visible cells.
retrain = True
if os.path.exists(save_name):
    # Answering "y" resumes from the saved checkpoint (model and optimizer state);
    # any other answer trains the freshly initialised network.
    # NOTE(review): input() blocks a non-interactive "Run All".
    print("Model parameters have already been trained before. Retrain ? (y/n)")
    ans = input()
    # Tolerate surrounding whitespace and upper case ("Y", " y").
    if (ans.strip().lower() == 'y'):
        checkpoint = torch.load(save_name, map_location = device)
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Restart from the configured learning rate instead of the decayed value
        # stored in the checkpoint.
        for g in optimizer.param_groups:
            g['lr'] = lr

# The epoch budget grows as the batch shrinks: 1000 epochs at batch size 64,
# i.e. 4000 at batch size 16.
train_model(train_iter, valid_iter, net, loss, optimizer, device = device, max_epochs = int(1000/(batch_size/64)),
            adjust_lr = adjust_lr, early_stop = EarlyStop(patience = 20, save_name = save_name))


############### testing ##########################

# Keep the criterion bound to `loss`: the result goes into `test_loss` so the
# global is not overwritten with a number and stays usable after this cell.
# NOTE(review): this evaluates `net` as left by train_model; whether that is the
# best checkpoint or the last epoch depends on train_model/EarlyStop — confirm.
test_loss, acc = evaluate_model(test_iter, net, loss)
print('test acc = %.6f' % (acc))


training on cuda


100%|██████████| 554/554 [00:15<00:00, 35.30it/s]


epoch 1 / 4000, train loss 3.3343 (acc 0.119015), valid loss 3.2928 (acc 0.166667)


100%|██████████| 554/554 [00:15<00:00, 35.47it/s]


epoch 2 / 4000, train loss 2.9125 (acc 0.219851), valid loss 2.9936 (acc 0.217480)


100%|██████████| 554/554 [00:15<00:00, 35.84it/s]


epoch 3 / 4000, train loss 2.6170 (acc 0.302620), valid loss 2.5957 (acc 0.307927)


100%|██████████| 554/554 [00:15<00:00, 35.92it/s]


epoch 4 / 4000, train loss 2.3930 (acc 0.353546), valid loss 2.7020 (acc 0.365854)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.87it/s]


epoch 5 / 4000, train loss 2.2952 (acc 0.376242), valid loss 2.1055 (acc 0.426829)


100%|██████████| 554/554 [00:15<00:00, 35.45it/s]


epoch 6 / 4000, train loss 2.0506 (acc 0.436427), valid loss 1.8373 (acc 0.484756)


100%|██████████| 554/554 [00:15<00:00, 35.93it/s]


epoch 7 / 4000, train loss 1.8717 (acc 0.479110), valid loss 1.9249 (acc 0.462398)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.88it/s]


epoch 8 / 4000, train loss 1.7038 (acc 0.518970), valid loss 1.6760 (acc 0.528455)


100%|██████████| 554/554 [00:15<00:00, 35.87it/s]


epoch 9 / 4000, train loss 1.5752 (acc 0.550587), valid loss 1.4773 (acc 0.559959)


100%|██████████| 554/554 [00:15<00:00, 35.90it/s]


epoch 10 / 4000, train loss 1.4712 (acc 0.572267), valid loss 1.5827 (acc 0.573171)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.86it/s]


epoch 11 / 4000, train loss 1.3626 (acc 0.602755), valid loss 1.3541 (acc 0.611789)


100%|██████████| 554/554 [00:15<00:00, 35.94it/s]


epoch 12 / 4000, train loss 1.2861 (acc 0.625000), valid loss 1.6044 (acc 0.549797)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.85it/s]


epoch 13 / 4000, train loss 1.2332 (acc 0.639453), valid loss 1.2955 (acc 0.636179)


100%|██████████| 554/554 [00:15<00:00, 35.95it/s]


epoch 14 / 4000, train loss 1.1598 (acc 0.660117), valid loss 1.2192 (acc 0.646341)


100%|██████████| 554/554 [00:15<00:00, 35.84it/s]


epoch 15 / 4000, train loss 1.0794 (acc 0.681911), valid loss 0.9748 (acc 0.706301)


100%|██████████| 554/554 [00:15<00:00, 35.81it/s]


epoch 16 / 4000, train loss 1.0676 (acc 0.682136), valid loss 1.1875 (acc 0.654472)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.56it/s]


epoch 17 / 4000, train loss 0.9909 (acc 0.703252), valid loss 1.0210 (acc 0.711382)
EarlyStopping counter: 2 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.76it/s]


epoch 18 / 4000, train loss 0.9606 (acc 0.713640), valid loss 1.0199 (acc 0.686992)
EarlyStopping counter: 3 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.45it/s]


epoch 19 / 4000, train loss 0.9212 (acc 0.719399), valid loss 0.8778 (acc 0.748984)


100%|██████████| 554/554 [00:15<00:00, 35.88it/s]


epoch 20 / 4000, train loss 0.8549 (acc 0.740176), valid loss 0.8138 (acc 0.758130)


100%|██████████| 554/554 [00:15<00:00, 35.82it/s]


epoch 21 / 4000, train loss 0.8362 (acc 0.744580), valid loss 0.8120 (acc 0.747967)


100%|██████████| 554/554 [00:15<00:00, 35.86it/s]


epoch 22 / 4000, train loss 0.8056 (acc 0.753839), valid loss 0.8527 (acc 0.747967)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.81it/s]


epoch 23 / 4000, train loss 0.7677 (acc 0.760276), valid loss 0.7835 (acc 0.762195)


100%|██████████| 554/554 [00:15<00:00, 35.79it/s]


epoch 24 / 4000, train loss 0.7493 (acc 0.770890), valid loss 1.0197 (acc 0.739837)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.86it/s]


epoch 25 / 4000, train loss 0.7429 (acc 0.769196), valid loss 0.7408 (acc 0.777439)


100%|██████████| 554/554 [00:15<00:00, 35.72it/s]


epoch 26 / 4000, train loss 0.7673 (acc 0.762308), valid loss 0.8452 (acc 0.764228)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.93it/s]


epoch 27 / 4000, train loss 0.7150 (acc 0.782182), valid loss 0.6940 (acc 0.796748)


100%|██████████| 554/554 [00:15<00:00, 35.83it/s]


epoch 28 / 4000, train loss 0.6799 (acc 0.787602), valid loss 0.9908 (acc 0.757114)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.88it/s]


epoch 29 / 4000, train loss 0.6588 (acc 0.792683), valid loss 0.7288 (acc 0.782520)
EarlyStopping counter: 2 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.87it/s]


epoch 30 / 4000, train loss 0.6421 (acc 0.797990), valid loss 1.2145 (acc 0.782520)
EarlyStopping counter: 3 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.81it/s]


epoch 31 / 4000, train loss 0.6751 (acc 0.792683), valid loss 0.6787 (acc 0.802846)


100%|██████████| 554/554 [00:15<00:00, 35.91it/s]


epoch 32 / 4000, train loss 0.6356 (acc 0.804313), valid loss 0.6677 (acc 0.808943)


100%|██████████| 554/554 [00:15<00:00, 35.62it/s]


epoch 33 / 4000, train loss 0.6163 (acc 0.803297), valid loss 0.7982 (acc 0.779472)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.85it/s]


epoch 34 / 4000, train loss 0.6137 (acc 0.806685), valid loss 0.8597 (acc 0.800813)
EarlyStopping counter: 2 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.86it/s]


epoch 35 / 4000, train loss 0.5828 (acc 0.817638), valid loss 0.7345 (acc 0.786585)
EarlyStopping counter: 3 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.79it/s]


epoch 36 / 4000, train loss 0.5632 (acc 0.821251), valid loss 0.6522 (acc 0.813008)


100%|██████████| 554/554 [00:15<00:00, 35.67it/s]


epoch 37 / 4000, train loss 0.5772 (acc 0.818993), valid loss 0.9810 (acc 0.807927)
EarlyStopping counter: 1 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.77it/s]


epoch 38 / 4000, train loss 0.5620 (acc 0.827575), valid loss 1.2125 (acc 0.802846)
EarlyStopping counter: 2 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.61it/s]


epoch 39 / 4000, train loss 0.5185 (acc 0.835818), valid loss 0.8656 (acc 0.807927)
EarlyStopping counter: 3 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.35it/s]


epoch 40 / 4000, train loss 0.5264 (acc 0.830849), valid loss 2.6010 (acc 0.779472)
EarlyStopping counter: 4 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.10it/s]


epoch 41 / 4000, train loss 0.5039 (acc 0.837173), valid loss 1.3872 (acc 0.818089)
EarlyStopping counter: 5 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.58it/s]


epoch 42 / 4000, train loss 0.5144 (acc 0.832769), valid loss 0.9311 (acc 0.831301)
EarlyStopping counter: 6 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.60it/s]


epoch 43 / 4000, train loss 0.5015 (acc 0.838866), valid loss 0.6936 (acc 0.832317)
EarlyStopping counter: 7 out of 20


100%|██████████| 554/554 [00:14<00:00, 39.17it/s]


epoch 44 / 4000, train loss 0.4905 (acc 0.842254), valid loss 1.1147 (acc 0.816057)
EarlyStopping counter: 8 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.03it/s]


epoch 45 / 4000, train loss 0.4749 (acc 0.847561), valid loss 8.1201 (acc 0.789634)
EarlyStopping counter: 9 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.65it/s]


epoch 46 / 4000, train loss 0.4801 (acc 0.848238), valid loss 1.1451 (acc 0.830285)
EarlyStopping counter: 10 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.06it/s]


epoch 47 / 4000, train loss 0.4652 (acc 0.849481), valid loss 2.6186 (acc 0.832317)
EarlyStopping counter: 11 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.52it/s]


epoch 48 / 4000, train loss 0.4596 (acc 0.851287), valid loss 4.9415 (acc 0.783537)
EarlyStopping counter: 12 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.33it/s]


epoch 49 / 4000, train loss 0.4590 (acc 0.852868), valid loss 7.7225 (acc 0.764228)
EarlyStopping counter: 13 out of 20


100%|██████████| 554/554 [00:13<00:00, 40.39it/s]


epoch 50 / 4000, train loss 0.4380 (acc 0.858175), valid loss 2.7887 (acc 0.814024)
EarlyStopping counter: 14 out of 20


100%|██████████| 554/554 [00:16<00:00, 32.97it/s]


epoch 51 / 4000, train loss 0.4373 (acc 0.857836), valid loss 2.3674 (acc 0.813008)
EarlyStopping counter: 15 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.56it/s]


epoch 52 / 4000, train loss 0.4292 (acc 0.865176), valid loss 7.3395 (acc 0.803862)
EarlyStopping counter: 16 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.39it/s]


epoch 53 / 4000, train loss 0.4353 (acc 0.856820), valid loss 3.5521 (acc 0.816057)
EarlyStopping counter: 17 out of 20


100%|██████████| 554/554 [00:15<00:00, 34.95it/s]


epoch 54 / 4000, train loss 0.4349 (acc 0.861337), valid loss 1.5773 (acc 0.816057)
EarlyStopping counter: 18 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.25it/s]


epoch 55 / 4000, train loss 0.4174 (acc 0.866983), valid loss 18.8756 (acc 0.785569)
EarlyStopping counter: 19 out of 20


100%|██████████| 554/554 [00:15<00:00, 35.66it/s]


epoch 56 / 4000, train loss 0.4092 (acc 0.871500), valid loss 1.1155 (acc 0.843496)
EarlyStopping counter: 20 out of 20
test acc = 0.825770
