<a href="https://colab.research.google.com/github/nikitakaraevv/pointnet/blob/master/PointNetSeg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive into the Colab filesystem; force_remount=True re-mounts
# even if a previous mount is still present.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
# Project folder on Drive. Later cells append it to sys.path and build the
# dataset and checkpoint paths from it.
root_dir = "/content/gdrive/My Drive/PointNet3D"


Mounted at /content/gdrive


In [0]:
!pip install path.py;
from path import Path
import sys
sys.path.append(root_dir)



In [0]:
import plotly.graph_objects as go
import numpy as np
import scipy.spatial.distance
import math
import random
import utils


# NOTE(review): this path is not referenced by any other cell visible in this
# notebook — confirm it is still needed before keeping it.
class10_dir = "/datasets/ModelNet10txt/ModelNet10/ModelNet10/"

In [0]:
import random

def read_pts(file):
    """Load a point file with np.genfromtxt and pass it through utils.cent_norm.

    `file` is anything np.genfromtxt accepts (path or open file object).
    Returns whatever utils.cent_norm returns for the parsed array
    (presumably the centred/normalised points — confirm in utils).
    """
    points = np.genfromtxt(file)
    normalised = utils.cent_norm(points)
    return normalised

def read_seg(file):
    """Parse a label file into an integer NumPy array.

    `file` is anything np.genfromtxt accepts (path or open file object).
    """
    return np.genfromtxt(file, dtype=int)

def sample_2000(pts, pts_cat, n_samples=2000):
    """Draw a fixed-size random sample (with replacement) of labelled points.

    Args:
        pts: array of shape (N, >=3); the first three columns are used as
            the point coordinates.
        pts_cat: array of N labels, one per row of `pts`. Labels are
            shifted down by one on output (1-based -> 0-based).
        n_samples: number of points to draw. Defaults to 2000, the value
            that used to be hard-coded.

    Returns:
        (images, categories): `images` has shape (n_samples, 3) and
        `categories` has shape (n_samples,). As before, both come out of a
        single concatenated array, so `categories` has that array's
        (floating) dtype; callers cast it to int themselves.

    Raises:
        IndexError: if `pts` is empty (raised by random.choices).
    """
    labels_col = np.reshape(pts_cat, (pts_cat.shape[0], 1))
    stacked = np.concatenate((pts, labels_col), axis=1)

    # Sample row *indices* with the stdlib RNG and gather once with NumPy,
    # instead of building a Python list of n_samples row arrays.
    row_ids = random.choices(range(stacked.shape[0]), k=n_samples)
    sampled = stacked[row_ids]

    images = sampled[:, 0:3]
    categories = sampled[:, 3] - 1
    return images, categories

## Model

In [0]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Tnet(nn.Module):
    """Small network that regresses one k x k transform matrix per sample.

    A per-point MLP (1x1 convolutions) is max-pooled over the point axis and
    fed through fully connected layers; the result is reshaped to (k, k) and
    added to the identity matrix.

    Args:
        k: number of input channels, and side length of the output matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, input):
        """Return a (bs, k, k) tensor of transform matrices.

        Args:
            input: tensor of shape (bs, k, n) — channels first, as Conv1d
                requires.
        """
        xb = F.relu(self.bn1(self.conv1(input)))
        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = F.relu(self.bn3(self.conv3(xb)))

        # Global max over the point axis -> (bs, 1024). The functional form
        # avoids constructing new pooling/flatten modules on every call.
        flat = F.max_pool1d(xb, xb.size(-1)).flatten(1)
        xb = F.relu(self.bn4(self.fc1(flat)))
        xb = F.relu(self.bn5(self.fc2(xb)))

        # Add the identity so the raw fc3 output is an offset from "no
        # transform". The identity is a constant: it is created directly on
        # the activations' device/dtype (works for any device, not only the
        # default CUDA one) and broadcasts over the batch dimension.
        identity = torch.eye(self.k, device=xb.device, dtype=xb.dtype)
        matrix = self.fc3(xb).view(-1, self.k, self.k) + identity
        return matrix


class Transform(nn.Module):
    """Feature extractor: aligns the input and intermediate features with
    two Tnet transforms and collects every intermediate activation."""

    def __init__(self):
        super().__init__()
        self.input_transform = Tnet(k=3)
        self.feature_transform = Tnet(k=128)
        # Per-point MLP implemented as 1x1 convolutions.
        self.fc1 = nn.Conv1d(3, 64, 1)
        self.fc2 = nn.Conv1d(64, 128, 1)
        self.fc3 = nn.Conv1d(128, 128, 1)
        self.fc4 = nn.Conv1d(128, 512, 1)
        self.fc5 = nn.Conv1d(512, 2048, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self, input):
        """Return ([per-stage features...], matrix3x3, matrix128x128).

        `input` is (bs, 3, n). The returned list holds six tensors, each of
        shape (bs, C, n) with C = 64, 128, 128, 128, 512, 2048; the last one
        is the max-pooled global feature repeated for every point.
        """
        num_points = input.size(2)

        # Align the raw points with the predicted 3x3 matrix.
        matrix3x3 = self.input_transform(input)
        aligned = torch.bmm(input.transpose(1, 2), matrix3x3).transpose(1, 2)

        feat_a = F.relu(self.bn1(self.fc1(aligned)))
        feat_b = F.relu(self.bn2(self.fc2(feat_a)))
        feat_c = F.relu(self.bn3(self.fc3(feat_b)))

        # Align the 128-channel features with the predicted 128x128 matrix.
        matrix128x128 = self.feature_transform(feat_c)
        feat_d = torch.bmm(feat_c.transpose(1, 2), matrix128x128).transpose(1, 2)
        feat_e = F.relu(self.bn4(self.fc4(feat_d)))

        # Global feature: max over all points, then repeated per point.
        pooled = F.max_pool1d(self.bn5(self.fc5(feat_e)), num_points)
        feat_global = pooled.flatten(1).repeat(num_points, 1, 1).permute(1, 2, 0)

        outs = [feat_a, feat_b, feat_c, feat_d, feat_e, feat_global]
        return outs, matrix3x3, matrix128x128


class PointNetSeg(nn.Module):
    """Per-point segmentation head on top of `Transform`.

    Args:
        classes: number of part classes predicted for every point.
            Previously this argument was accepted but ignored and the head
            was hard-coded to 4 outputs. It is now honoured; the default is
            4 so that `PointNetSeg()` builds exactly the same network as
            before and existing checkpoints still load.
    """

    def __init__(self, classes=4):
        super().__init__()
        self.transform = Transform()

        # 3008 = 64 + 128 + 128 + 128 + 512 + 2048: the channel counts of the
        # six feature maps returned by Transform, concatenated in forward().
        self.fc1 = nn.Conv1d(3008, 256, 1)
        self.fc2 = nn.Conv1d(256, 256, 1)
        self.fc3 = nn.Conv1d(256, 128, 1)
        self.fc4 = nn.Conv1d(128, classes, 1)

        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(classes)

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        """Return (log-probabilities, matrix3x3, matrix128x128).

        `input` is (bs, 3, n); the log-probabilities are (bs, classes, n),
        normalised over the class dimension.
        """
        inputs, matrix3x3, matrix128x128 = self.transform(input)
        stack = torch.cat(inputs, 1)

        xb = F.relu(self.bn1(self.fc1(stack)))
        xb = F.relu(self.bn2(self.fc2(xb)))
        xb = F.relu(self.bn3(self.fc3(xb)))

        # NOTE(review): the class scores pass through ReLU before
        # log-softmax, which clamps them to >= 0. Kept as-is because the
        # saved checkpoint was trained with it — confirm whether intended.
        output = F.relu(self.bn4(self.fc4(xb)))

        return self.logsoftmax(output), matrix3x3, matrix128x128



## Dataset

In [0]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data.dataset import random_split
import utils

class Data(Dataset):
    """Airplane part-segmentation dataset read from `root_dir`.

    Each item pairs a label file from
    `<root_dir>/datasets/airplane_part_seg/02691156/expert_verified/points_label/`
    with the `.pts` file of the same name under `.../02691156/points/`.

    Args:
        root_dir: folder that contains the `datasets` directory.
        valid: when True, items are returned without augmentation.
        transform: accepted for interface compatibility; not used.
    """

    def __init__(self, root_dir, valid=False, transform=None):
        self.root_dir = root_dir
        self.files = []
        self.valid = valid

        newdir = root_dir + '/datasets/airplane_part_seg/02691156/expert_verified/points_label/'
        points_dir = root_dir + '/datasets/airplane_part_seg/02691156/points/'

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping (and therefore any seeded split) stable
        # between runs.
        for file in sorted(os.listdir(newdir)):
            self.files.append({
                'category': newdir + file,
                'img_path': points_dir + file.replace('.seg', '.pts'),
            })

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return {'image': float32 points, 'category': int labels} for one
        sample, using the default sample size of `sample_2000`."""
        entry = self.files[idx]
        with open(entry['img_path'], 'r') as f:
            image1 = read_pts(f)
        with open(entry['category'], 'r') as f:
            category1 = read_seg(f)
        image2, category2 = sample_2000(image1, category1)

        if not self.valid:
            # Training-time augmentation: noise, then a random rotation
            # about z.
            # NOTE(review): theta is drawn in [0, 360); confirm that
            # utils.rotation_z expects degrees rather than radians.
            theta = random.random()*360
            image2 = utils.rotation_z(utils.add_noise(image2), theta)

        return {'image': np.array(image2, dtype="float32"), 'category': category2.astype(int)}


In [0]:

import copy

dset = Data(root_dir , transform=None)

# 5% of the samples (rounded down) for validation, everything else for training.
val_num = int(len(dset) * 0.05)
train_num = len(dset) - val_num
train_dataset, val_dataset = random_split(dset, [train_num, val_num])

# random_split returns Subset wrappers. Setting `.valid` on the wrapper has no
# effect because Data.__getitem__ reads the flag from the wrapped dataset, so
# validation samples would still be augmented. Point the validation subset at
# a shallow copy of the dataset (same file list) with augmentation disabled.
val_source = copy.copy(dset)
val_source.valid = True
val_dataset.dataset = val_source

print('######### Dataset class created #########')
print('Number of images: ', len(dset))
print('Sample image shape: ', dset[0]['image'].shape)

train_loader = DataLoader(dataset=train_dataset, batch_size=64)
val_loader = DataLoader(dataset=val_dataset, batch_size=64)

#dataloader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)

######### Dataset class created #########
Number of images:  2690
Sample image shape:  (2000, 3)


## Training loop

In [0]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [0]:
# Build the segmentation network with its default arguments.
pointnet = PointNetSeg()

In [0]:
# Move the model to the selected device; the trailing `;` suppresses the
# module repr in the cell output.
pointnet.to(device);

In [0]:
# Adam over all parameters of `pointnet`; train() below uses this
# module-level optimizer.
optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.001)

In [0]:
def pointnetloss(outputs, labels, m3x3, m128x128, alpha = 0.0001):
    """NLL loss plus an orthogonality penalty on the two transform matrices.

    Args:
        outputs: log-probabilities, class dimension at index 1.
        labels: integer class targets matching `outputs`.
        m3x3: (bs, k, k) input-transform matrices.
        m128x128: (bs, m, m) feature-transform matrices. The side length is
            read from the tensor, so it need not be 128.
        alpha: weight of the orthogonality penalty.

    Returns:
        Scalar tensor: NLLLoss(outputs, labels)
        + alpha * (||I - A A^T|| + ||I - B B^T||) / bs, where each norm is
        taken over the whole batch tensor.
    """
    criterion = torch.nn.NLLLoss()
    bs = outputs.size(0)

    def _orthogonality_gap(mats):
        # Norm of I - M M^T. The identity is a constant, so it needs no
        # gradient; it is created on the matrices' own device/dtype (works
        # for any device, not only the default CUDA one) and broadcasts over
        # the batch dimension.
        eye = torch.eye(mats.size(-1), device=mats.device, dtype=mats.dtype)
        return torch.norm(eye - torch.bmm(mats, mats.transpose(1, 2)))

    penalty = _orthogonality_gap(m3x3) + _orthogonality_gap(m128x128)
    return criterion(outputs, labels) + alpha * penalty / float(bs)
        

In [0]:
def train(model, train_loader, val_loader=None,  epochs=15, save=True):
    """Train `model`, optionally validating and checkpointing every epoch.

    Uses the module-level `optimizer`, `device`, `pointnetloss` and
    `root_dir`. The optimizer must have been built from `model`'s
    parameters, otherwise the optimisation step will not update it.

    Args:
        model: network returning (log-probabilities, m3x3, m128x128).
        train_loader: yields dicts with 'image' and 'category' tensors.
        val_loader: optional loader of the same format; when given, the
            per-point accuracy is printed after every epoch.
        epochs: number of passes over `train_loader`.
        save: when True, the state dict is written after each epoch to
            `<root_dir>/modelsSeg/<epoch>_<val_acc>`; the suffix is "noval"
            when no validation accuracy is available.
    """
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data['image'].to(device), data['category'].to(device)
            optimizer.zero_grad()
            # The loader yields (bs, n, 3); the network wants channels first.
            outputs, m3x3, m128x128 = model(inputs.transpose(1,2))

            loss = pointnetloss(outputs, labels, m3x3, m128x128)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0

        model.eval()
        correct = total = 0
        # Stays None when there is nothing to validate on, so the checkpoint
        # name below is always defined.
        val_acc = None

        # validation
        if val_loader:
            with torch.no_grad():
                for data in val_loader:
                    inputs, labels = data['image'].to(device), data['category'].to(device)
                    outputs, _, _ = model(inputs.transpose(1,2))
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.numel()  # every point counts as one prediction
                    correct += (predicted == labels).sum().item()
            if total:
                val_acc = 100 * correct / total
                print('Valid accuracy: %d %%' % val_acc)

        # save the model
        if save:
            acc_tag = str(val_acc) if val_acc is not None else "noval"
            torch.save(model.state_dict(), root_dir+"/modelsSeg/"+str(epoch)+"_"+acc_tag)


In [0]:
# Run training with the default number of epochs, validating and writing a
# checkpoint after every epoch.
train(pointnet, train_loader, val_loader,  save=True)


[1,    10] loss: 1.203
[1,    20] loss: 0.924
[1,    30] loss: 0.844
[1,    40] loss: 0.800
Valid accuracy: 78 %
[2,    10] loss: 0.769
[2,    20] loss: 0.741
[2,    30] loss: 0.732
[2,    40] loss: 0.729
Valid accuracy: 82 %
[3,    10] loss: 0.709
[3,    20] loss: 0.685
[3,    30] loss: 0.680
[3,    40] loss: 0.676
Valid accuracy: 85 %
[4,    10] loss: 0.663
[4,    20] loss: 0.642
[4,    30] loss: 0.637
[4,    40] loss: 0.636
Valid accuracy: 87 %
[5,    10] loss: 0.626
[5,    20] loss: 0.617
[5,    30] loss: 0.610
[5,    40] loss: 0.613
Valid accuracy: 86 %
[6,    10] loss: 0.604
[6,    20] loss: 0.587
[6,    30] loss: 0.580
[6,    40] loss: 0.583
Valid accuracy: 86 %
[7,    10] loss: 0.574
[7,    20] loss: 0.563
[7,    30] loss: 0.558
[7,    40] loss: 0.565
Valid accuracy: 86 %
[8,    10] loss: 0.553
[8,    20] loss: 0.539
[8,    30] loss: 0.538
[8,    40] loss: 0.543
Valid accuracy: 87 %
[9,    10] loss: 0.533
[9,    20] loss: 0.525
[9,    30] loss: 0.516
[9,    40] loss: 0.521
Vali

## Test

In [0]:
pointnet = PointNetSeg()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# runtime; this model instance is kept on the CPU for the cells below.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(root_dir+"/modelsSeg/"+"14_88.01940298507462", map_location="cpu")
pointnet.load_state_dict(state_dict)
pointnet.eval()

PointNetSeg(
  (transform): Transform(
    (input_transform): Tnet(
      (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (fc1): Linear(in_features=1024, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=256, bias=True)
      (fc3): Linear(in_features=256, out_features=9, bias=True)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (feature_transform): Tnet(
      (conv1): Conv1d(128, 64, kernel_size

In [0]:
batch = next(iter(val_loader))
# Inference only: no_grad avoids building the autograd graph for the whole
# batch of points.
with torch.no_grad():
    pred = pointnet(batch['image'].transpose(1,2))
# pred[0] holds the per-point log-probabilities; argmax over the class axis
# gives one predicted label per point.
pred_np = torch.argmax(pred[0], 1).cpu().numpy()
pred_np



array([[1, 1, 0, ..., 0, 0, 1],
       [1, 2, 1, ..., 1, 2, 1],
       [0, 0, 2, ..., 0, 1, 3],
       ...,
       [1, 0, 1, ..., 3, 2, 1],
       [2, 3, 0, ..., 0, 1, 3],
       [1, 1, 1, ..., 0, 0, 1]])

In [0]:
# Shape of the first point cloud in the batch.
batch['image'][0].shape

torch.Size([2000, 3])

In [0]:
# Element-wise comparison of predicted and ground-truth labels.
pred_np==np.array(batch['category'])

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ..., False,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ..., False,  True,  True]])

In [0]:
# Boolean mask: True where the predicted label equals the ground truth.
acc = np.equal(pred_np, np.array(batch['category']))

In [0]:
# Per-sample accuracy: fraction of correctly labelled points in each cloud.
# Taking the mean over the point axis avoids hard-coding the 2000 points per cloud.
resulting_acc = acc.mean(axis=1)

In [0]:
# Display the per-sample accuracies.
resulting_acc

array([0.863 , 0.906 , 0.91  , 0.867 , 0.9165, 0.882 , 0.9055, 0.8445,
       0.8405, 0.9235, 0.8865, 0.8855, 0.903 , 0.884 , 0.8945, 0.85  ,
       0.9125, 0.822 , 0.9345, 0.895 , 0.9135, 0.9395, 0.9385, 0.9165,
       0.8865, 0.848 , 0.8765, 0.9105, 0.8805, 0.83  , 0.852 , 0.9225,
       0.906 , 0.7705, 0.883 , 0.785 , 0.811 , 0.8565, 0.866 , 0.868 ,
       0.7855, 0.7305, 0.9155, 0.8915, 0.9065, 0.805 , 0.875 , 0.89  ,
       0.813 , 0.9005, 0.8325, 0.833 , 0.879 , 0.9215, 0.8185, 0.933 ,
       0.9325, 0.9   , 0.833 , 0.8535, 0.8545, 0.895 , 0.8325, 0.9295])

In [0]:
# Display the predicted labels again (same array as shown after prediction).
pred_np

array([[1, 1, 0, ..., 0, 0, 1],
       [1, 2, 1, ..., 1, 2, 1],
       [0, 0, 2, ..., 0, 1, 3],
       ...,
       [1, 0, 1, ..., 3, 2, 1],
       [2, 3, 0, ..., 0, 1, 3],
       [1, 1, 1, ..., 0, 0, 1]])

In [0]:
# Coordinates and ground-truth labels of the first point cloud in the batch.
x, y, z = np.array(batch['image'][0]).T
c = np.array(batch['category'][0])

# The marker style is given once here; the earlier size=30 was dead because a
# follow-up update_traces call immediately replaced it with size=2.
fig = go.Figure(data=[go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker=dict(
        size=2,
        color=c,                # set color to an array/list of desired values
        colorscale='Viridis',   # choose a colorscale
        opacity=1.0,
        line=dict(width=2, color='DarkSlateGrey'),
    ),
)])
fig.show()