## HKPS

In [1]:
import numpy as np
import math
import random
import os
import sys
import torch
import scipy.spatial.distance
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
sys.path.append('./utils')
import Hybrid_kmeans
import Pointnet
import Merge
import Visualize
import Normalize

import warnings
warnings.filterwarnings(action='ignore') 

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Training

In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(device)

# Build the network on the selected device and give it an Adam optimizer.
pointnet = Pointnet.PointNet().to(device)
optimizer = torch.optim.Adam(params=pointnet.parameters(), lr=0.001)

# Load the dataset.
# Each sample is an unorganized point cloud in the form of Nx3 (xyz),
# preprocessed by voxel downsampling.
# NOTE: allow_pickle=True lets numpy unpickle object arrays, which can execute
# arbitrary code -- only load files from a trusted source.
dataset_path = "./dataset/with_noise.npy"
dataset = np.load(dataset_path, allow_pickle=True)
print(dataset.shape)

cuda:0
(67, 1600, 3)


In [3]:
def pointnetloss(outputs, labels):
    """Return the mean cross-entropy between class logits and integer labels.

    Args:
        outputs: un-normalized class scores (logits), one row per sample.
        labels: target class index for each sample.

    Returns:
        A scalar tensor holding the loss averaged over the batch.
    """
    # Functional form of torch.nn.CrossEntropyLoss() with its default settings.
    return torch.nn.functional.cross_entropy(input=outputs, target=labels)

def make_labels(dataset, max_k=15, iteration=10, threshold=0.01):
    """Estimate a cluster-count label for every point cloud and save them.

    For each cloud, Hybrid-Kmeans is run for k = 1 .. max_k and the loss it
    reports is recorded. The label is then chosen by scanning the loss curve
    from the largest k downwards and stopping at the first step where the loss
    drops by more than ``threshold``.

    Labels are stored 0-based (label ``i`` corresponds to ``k = i + 1``), which
    is why the progress message prints ``label + 1``.

    Args:
        dataset: indexable collection of point clouds; ``dataset[i]`` is passed
            unchanged to ``Hybrid_kmeans.Kmeans_normal``.
        max_k: largest number of clusters to try.
        iteration: number of Hybrid-Kmeans iterations per run (forwarded as
            ``iter``). Fewer iterations are faster but may be less stable.
        threshold: minimum loss decrease between two consecutive k values that
            counts as a significant improvement.

    Returns:
        numpy.ndarray of shape ``(len(dataset),)`` with the 0-based labels.
        The same array is written to ``HKPS_labels.npy`` after every sample,
        so partial progress survives an interrupted run.
    """
    labels = np.zeros(len(dataset))
    for size in range(len(dataset)):
        print("make labels %d / %d: " % (int(size)+1,len(dataset)))

        # loss[k-1] holds the clustering loss obtained with k clusters.
        loss = np.zeros(max_k)
        for k in range(1, max_k + 1):
            #Kmeans
            #labels_,outputs,cos_loss,centroids = Hybrid_kmeans.Kmeans(dataset[size],k,10)
            #Hybrid_Kmeans
            labels_, outputs, cos_loss, centroids = Hybrid_kmeans.Kmeans_normal(
                dataset[size], k, iter=iteration)
            loss[k - 1] = cos_loss

        # Default to the largest k; walk down the curve and keep the first
        # index whose step from the previous k still improves the loss
        # by more than `threshold`.
        labels[size] = max_k - 1
        for i in range(max_k - 1, 0, -1):
            if (loss[i - 1] - loss[i]) > threshold:
                labels[size] = i
                break
        print("label : ",labels[size]+1)

        # Saved inside the loop on purpose: acts as a checkpoint per sample.
        np.save("HKPS_labels", labels)

    return labels
    
def train(model, dataset_, train_loss, epochs=5, make_label=True, save=True, batch_size=20):
    """Train ``model`` to predict the cluster-count label of each point cloud.

    The function relies on names defined elsewhere in the notebook:
    ``device``, ``optimizer`` (must have been built from ``model.parameters()``),
    ``pointnetloss``, ``make_labels`` and the ``Normalize`` module.

    Args:
        model: network to train; called with a ``(batch, 3, points)`` tensor
            and expected to return ``(logits, m3x3, m64x64)``.
        dataset_: raw point clouds. They are passed unchanged to
            ``make_labels`` and through ``Normalize.normalize`` for training.
        train_loss: list that receives the loss value of every training batch.
        epochs: number of passes over the data.
        make_label: when true, regenerate ``HKPS_labels.npy`` with
            ``make_labels`` before training; otherwise the file must already
            exist.
        save: when true, write ``./model_save/save_<epoch>.pth`` after every
            epoch (``<epoch>`` is 0-based).
        batch_size: number of samples per optimisation step. Samples that do
            not fill a whole batch are skipped in that epoch (the data is
            reshuffled every epoch, so different samples are skipped).
    """
    #you can change 'max_k' and 'iteration'
    #'max_k' is the maximum number that PointNet will estimate
    #'iteration' is the number of Hybrid-Kmeans iterations
    #'iteration' can be reduced to reduce time consumption
    #but it might cause unstable results
    if make_label:
        # Must use the `dataset_` argument here: `dataset` is a local name
        # that is only assigned below, so reading it at this point raised
        # UnboundLocalError.
        make_labels(dataset_, max_k=15, iteration=10)

    #normalize dataset
    dataset = Normalize.normalize(dataset_)

    labels_ = np.load("HKPS_labels.npy", allow_pickle=True)

    # Labels are stored 0-based; show them as cluster counts.
    print("Labels: ",labels_ +1 )
    print("train pointnet")
    inputs_ = torch.tensor(dataset).to(device).float()
    labels_ = torch.tensor(labels_).to(device).long()

    # Only full batches are used.
    num_batches = len(inputs_) // batch_size

    if save:
        # torch.save does not create missing directories.
        os.makedirs("./model_save", exist_ok=True)

    for epoch in range(epochs):
        # Reshuffle samples and labels together at the start of each epoch.
        shuf_idx = np.random.permutation(len(labels_))
        inputs_ = inputs_[shuf_idx]
        labels_ = labels_[shuf_idx]

        model.train()
        for i in range(num_batches):
            start, stop = batch_size * i, batch_size * (i + 1)
            inputs = inputs_[start:stop]
            labels = labels_[start:stop]

            optimizer.zero_grad()
            # The network takes channels-first input: (batch, 3, points).
            outputs, m3x3, m64x64 = model(inputs.transpose(1, 2))

            loss = pointnetloss(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

            # print statistics
            print('[Epoch: %d, Batch: %4d / %4d], loss: %.3f' %
                        (epoch + 1, i + 1, num_batches, loss.item()), end='\r')

        # Quick sanity check on the first batch of the (shuffled) training
        # data. This is not a held-out set. No gradients are needed here.
        model.eval()
        with torch.no_grad():
            inputs_val = inputs_[:batch_size]
            labels_val = labels_[:batch_size]
            outputs, m3x3, m64x64 = model(inputs_val.transpose(1, 2))
            loss = pointnetloss(outputs, labels_val)
        print("                                                   ", end='\r')
        print("Epoch: %d, valid loss : %.3f" %(epoch + 1, loss.item()), end='\r')

        # save the model
        if save:
            torch.save(model.state_dict(), "./model_save/save_"+str(epoch)+".pth")

In [4]:
# Label generation is skipped in this call (make_label=False), so the labels
# are read from the existing HKPS_labels.npy file. To build the labels, run
# make_labels(dataset, max_k=15, iteration=10) beforehand or set make_label=True.
# 'iteration' can be reduced to reduce time consumption,
# but it might cause unstable results.

# Collects the loss of every training batch.
train_loss = list()
train(
    model=pointnet,
    dataset_=dataset,
    train_loss=train_loss,
    epochs=50,
    make_label=False,
    save=True,
)

[5.479565740705611, -90.89614486694336, -3.7586827278137207]
[149.84942626953125, 91.87315368652344, 22.7500057220459]
x_range : 0.0 1.0
y_range : 0.0 1.0
z_range : 0.0 1.0
Labels:  [10. 14. 14. 11. 15. 12. 11. 14. 13. 14. 14. 15. 15. 14.  8. 15. 14. 15.
 15. 10.  5. 14. 14. 14.  7.  6.  6.  5.  5. 13.  5.  4.  9.  4.  5.  4.
  4.  4.  4.  4.  3.  4.  4.  4. 14. 15. 15. 15. 15. 15. 15. 15.  4. 15.
 15. 15. 15.  9.  8. 10.  2. 15.  2.  2.  2.  2.  2.]
train pointnet
Epoch: 50, valid loss : 0.915                      

## Validation

In [12]:
def valid(model, dataset, save_num=99, iteration=10):
    """Segment point clouds using a trained checkpoint and save the results.

    The checkpoint ``./model_save/save_<save_num>.pth`` is loaded into
    ``model``, which predicts a cluster count for every cloud. Each cloud is
    then clustered with Hybrid-Kmeans using that count, the clusters are passed
    through ``Merge.PlaneMerge`` and the outcome is handed to
    ``Visualize.visualize`` under the name ``HKPS_<index>``.

    The function relies on the notebook-level ``device`` and on the
    ``Normalize``, ``Hybrid_kmeans``, ``Merge`` and ``Visualize`` modules.

    Args:
        model: network to evaluate; called with a ``(batch, 3, points)`` tensor
            and expected to return ``(logits, m3x3, m64x64)``.
        dataset: a batch of point clouds, or a single 2-D cloud which is
            treated as a batch of one.
        save_num: index of the checkpoint file to load.
        iteration: number of Hybrid-Kmeans iterations (forwarded as ``iter``).
            Fewer iterations are faster but may be less stable.
    """
    # Promote a single cloud to a batch of one, whatever its point count.
    if len(dataset.shape) == 2:
        dataset = dataset.reshape((1,) + tuple(dataset.shape))
    elif len(dataset.shape) < 2:
        # Flat input: keep the historical interpretation of 1600 points.
        dataset = dataset.reshape(1, 1600, -1)

    #normalize dataset
    dataset_normal = Normalize.normalize(dataset)

    inputs = torch.tensor(dataset_normal).to(device).float()
    file = "./model_save/save_" + str(save_num) + ".pth"
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa).
    model.load_state_dict(torch.load(file, map_location=device))
    print(inputs.shape)

    model.eval()
    with torch.no_grad():
        outputs, m3x3, m64x64 = model(inputs.transpose(1, 2))
    # Softmax is monotonic, so the arg-max of the logits is the predicted class.
    outputs_ = torch.argmax(outputs, dim=1)

    # Classes are 0-based; +1 turns them into cluster counts.
    print("PointNet result: ", outputs_ + 1)
    K = np.asarray(outputs_.cpu() + 1)
    for n in range(len(K)):
        print("%d / %d" %(n+1, len(K)))
        # Clustering runs on the original (un-normalized) cloud.
        labels_k_, outputs_k, cos_loss, centroids = Hybrid_kmeans.Kmeans_normal(
            dataset[n], K[n], iter=iteration)
        # PlaneMerge receives a copy so the raw K-means labels stay untouched.
        labels_k = Merge.PlaneMerge(outputs_k, np.copy(labels_k_), K[n])
        result_file = "HKPS_" + str(n)
        Visualize.visualize(labels_k, outputs_k, result_file)
    print("Result Saved")

In [None]:
# Results are saved as result_HKPS_'index'.txt
# Each result is shaped as xyzRGB.
# 'iteration' can be reduced to reduce time consumption,
# but it might cause unstable results.
#
# save_num selects ./model_save/save_<save_num>.pth. The training cell in this
# notebook runs 50 epochs and names checkpoints by 0-based epoch index, so the
# last file it writes is save_49.pth. Change save_num if you trained for a
# different number of epochs or want to load another checkpoint.
valid(pointnet, dataset[:10], save_num=49, iteration=5)

In [14]:
# Show the installed PyTorch version so the environment can be reproduced.
print(torch.__version__)

1.7.1+cu101
