In [1]:

import vtk
import os
import numpy as np
import itertools
import math, random
import data_process_ml
random.seed = 42
import copy

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision import transforms, utils

import scipy.spatial.distance
# import plotly.graph_objects as go
# import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix



In [27]:
import open3d as o3d
def reduce_mesh(target_mesh, reference_mesh):
    # Find common vertices between the two meshes
    output_mesh = copy.deepcopy(target_mesh)
    target_vertices = set(tuple(vertex) for vertex in np.asarray(target_mesh.vertices))
    target_vertices_list = [tuple(vertex) for vertex in np.asarray(target_mesh.vertices)]
    index_dict = {item: index for index, item in enumerate(target_vertices_list)}
    reference_vertices = set(tuple(vertex) for vertex in np.asarray(reference_mesh.vertices))

    common_vertices = target_vertices.intersection(reference_vertices)
    #print(len(list(common_vertices)))
    list_mine = [index_dict[element_to_find] for element_to_find in list(common_vertices)]

    #print(list_mine)

    # Filter out vertices and triangles based on common vertices
    output_mesh.remove_vertices_by_index(list_mine)

    return output_mesh

def mesh_to_point_cloud(file_path,points = 10000):
    """
    Convert a mesh object to a point cloud object.
    """
    mesh = o3d.io.read_triangle_mesh(file_path)
    point_cloud = mesh.sample_points_uniformly(number_of_points=points) # Adjust number_of_points as needed
    return point_cloud

In [28]:
import vtk
import numpy as np

def vtp_to_mesh(file_path):
    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(file_path)
    reader.Update()
    polydata = reader.GetOutput()

    vertices = np.array(polydata.GetPoints().GetData())
    polygons = np.array(polydata.GetPolys().GetData())

    # Extract vertices
    vertices = vertices.reshape(-1, 3)

    # Extract faces
    faces = polygons.reshape(-1, 4)[:, 1:]  

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    mesh.triangles = o3d.utility.Vector3iVector(faces)
    return mesh

def vtp_to_point_cloud(file_path,points = 10000):
    mesh = o3d.geometry.TriangleMesh()
    if ".vtp" not in file_path:
        mesh = o3d.io.read_triangle_mesh(file_path)
    else:
        mesh = vtp_to_mesh(file_path)

    point_cloud = mesh.sample_points_uniformly(number_of_points=points)

    return point_cloud

def vtp_to_point_cloud_cutvessel(vessel_file_path,cut_file_path, points = 10000):
    """
    Convert a mesh object to a point cloud object.
    """
    if ".vtp" not in vessel_file_path:
        vessel_mesh = o3d.io.read_triangle_mesh(vessel_file_path)
    else:
        vessel_mesh = vtp_to_mesh(vessel_file_path)
    if ".vtp" not in cut_file_path:
        cut_mesh = o3d.io.read_triangle_mesh(cut_file_path)
    else:
        cut_mesh = vtp_to_mesh(cut_file_path)
    
    box_mesh = cut_mesh.get_axis_aligned_bounding_box()
    reduced_mesh =  vessel_mesh.crop(box_mesh) 
    result_mesh = reduce_mesh(vessel_mesh,reduced_mesh)
    #print(cut1_mesh,result_mesh)
    if vessel_mesh.vertices == result_mesh.vertices:
        return False
                  
    #o3d.visualization.draw_geometries([result_mesh])
    
    point_cloud = result_mesh.sample_points_uniformly(number_of_points=points) # Adjust number_of_points as needed
    return point_cloud



test case

In [29]:
modelcov_vtp = vtp_to_point_cloud("./test_model/SNF00000036_02.vtp")
# o3d.visualization.draw_geometries([modelcov_vtp])

In [30]:
print(np.asarray(modelcov_vtp.points))

[[38.39359112 25.73729217 11.00451665]
 [40.55299039 24.10619762 13.48027918]
 [36.71958714 28.03707383 14.44223189]
 ...
 [22.93412879 25.76327997 33.00578464]
 [23.1041807  26.44252294 31.81658701]
 [37.61510272 34.44956764 20.37695921]]


In [5]:
# modelcov_vtp = vtp_to_point_cloud_cutvessel("./test_model/ANSYS_UNIGE_09.vtp","./test_model/ANSYS_UNIGE_09_dome.vtp")
# o3d.visualization.draw_geometries([modelcov_vtp])

In [6]:
# modelcov_vtp = vtp_to_point_cloud_cutvessel("./test_model/C0040.vtp","./test_model/C0040_dome.vtp")
# o3d.visualization.draw_geometries([modelcov_vtp])

In [92]:
class Aneuxmodel_Dataset(Dataset):

    def __init__(self, df, root = "", transform = None,mesh = "area-001",cuttype = "dome",crop = False):

        self.root = root
        self.transform = transform
        self.mesh = mesh
        self.cuttype = cuttype
        self.crop = crop
        if self.cuttype!= "dome" and self.cuttype!= "cut1":
            return "type error"
        
        self.df = df
        self.label = []
        self.vessel_model_file = []
        self.cut1_model_file = []
        self.dome_model_file = []
        self.model_table = []
        
        self.cropdome_vessel_file = []
        self.cropcut1_vessel_file = []
        
        self.training_data_load()
        self.label_loader()
        self.my_device = "cuda:0"
            
    def label_loader(self):
        self.label = []
        for model in self.model_table:
            if  model in list(self.df["dataset"]):
                label_num = self.df[self.df["dataset"] == model]["status"]
                label_num = list(label_num)[0]
                if label_num == "rupture":
                    label_num = 1
                else:
                    label_num = 0
                self.label.append(label_num)
        self.label = torch.from_numpy(np.array(self.label))
        return True
    
    def find_dome_cut1(self,all_IA_model,model_name):
        IA_name_cut1 = ""
        IA_name_dome = ""
        for IA_model in all_IA_model:
                
            if model_name in IA_model and "cut1" in IA_model and model_name not in IA_name_cut1:
                #print("find:" + IA_model)
                IA_name_cut1 = IA_model
            elif model_name.split("_")[0] in IA_model and "cut1" in IA_model and IA_name_cut1 == "":
                IA_name_cut1 = IA_model
                
            if model_name in IA_model and "dome" in IA_model and model_name not in IA_name_dome:
                IA_name_dome = IA_model
            elif model_name.split("_")[0] in IA_model and "dome" in IA_model and IA_name_dome == "":
                IA_name_dome = IA_model
                
            if IA_name_cut1 != "" and IA_name_dome != "" and model_name in IA_name_cut1 and model_name in IA_name_dome:
                break
        return IA_name_cut1,IA_name_dome
                
    def training_data_load(self):
        self.label = []
        self.vessel_model_file = []
        self.cut1_model_file = []
        self.dome_model_file = []
        self.model_table = []
        self.cropdome_vessel_file = []
        self.cropcut1_vessel_file = []
        
        #path of the files
        IA = "aneurysms\\remeshed\\area-001"
        Vessel = "vessels\\remeshed\\area-001"
        if self.mesh == "area-001":
            IA = "aneurysms\\remeshed\\area-001"
            Vessel = "vessels\\remeshed\\area-001"
        elif self.mesh == "area-005":
            IA = "aneurysms\r\emeshed\\area-005"
            Vessel = "vessels\\remeshed\\area-005"
        elif self.mesh == "orginal":
            IA = "aneurysms\\orginal"
            Vessel = "vessels\\orginal"
            
        IA_root = os.path.join(self.root,IA)
        Vessel_root = os.path.join(self.root,Vessel)
        #list of the model files
        all_vessel_model = os.listdir(Vessel_root)
        all_IA_model = os.listdir(IA_root)

        for model in all_vessel_model[:50]:
            #file name for IA cut1 and dome
            model_name = model[:-4]
            # get cut1 file name and dome file name from ALL_IA_model list
            IA_name_cut1 = ""
            IA_name_dome = ""
            IA_name_cut1,IA_name_dome = self.find_dome_cut1(all_IA_model,model_name)

            #read the file path and add the model to the list
            if IA_name_cut1 in all_IA_model:
                IA_root_cut1 = os.path.join(IA_root,IA_name_cut1)
                cut1 = vtp_to_point_cloud(IA_root_cut1,points = 10000)
                self.cut1_model_file.append(cut1)
            else:
                print("missing a cut1 model: " + model_name)
                
            if IA_name_dome in all_IA_model:
                IA_root_dome = os.path.join(IA_root,IA_name_dome)
                dome = vtp_to_point_cloud(IA_root_dome,points = 10000)
                self.dome_model_file.append(dome)
            else:
                print("missing a dome model: " + model_name)
            
            
            # append vessel list             
            if model in all_vessel_model:
                Vessel_model_root = os.path.join(Vessel_root,model)
                vessel = vtp_to_point_cloud(Vessel_model_root,points = 10000)
                self.vessel_model_file.append(vessel)
            else:
                print("missing a vessel model: " + model_name)
            if self.crop:
                IA_root_cut1 = os.path.join(IA_root,IA_name_cut1)
                IA_root_dome = os.path.join(IA_root,IA_name_dome)
                Vessel_model_root = os.path.join(Vessel_root,model)
                Vessel_crop_cut1 = vtp_to_point_cloud_cutvessel(Vessel_model_root,IA_root_cut1,points = 10000)
                Vessel_crop_dome = vtp_to_point_cloud_cutvessel(Vessel_model_root,IA_root_dome,points = 10000)
                self.cropcut1_vessel_file.append(Vessel_crop_cut1)
                self.cropdome_vessel_file.append(Vessel_crop_dome)
                        
            self.model_table.append(IA_name_cut1.split("_")[0])
        
        return True
        # self.org_imgs = np.array(self.org_imgs)
        # self.total_imgs = torch.from_numpy(np.array(self.total_imgs)) 
        # self.label = torch.from_numpy(np.array(self.label))  
        # return True

    
    def __getitem__(self, index):
        
        """ Returns one data pair (image and target caption). """
        cut_model = self.cut1_model_file[index]
        vessel_model = self.vessel_model_file[index]
        if self.cuttype == "dome":
            cut_model = self.dome_model_file[index]
        if self.crop and self.cuttype == "cut1":
            vessel_model = self.cropcut1_vessel_file[index]
        if self.crop and self.cuttype == "dome":
            vessel_model = self.cropdome_vessel_file[index]
            
        vessel_model = np.asarray(vessel_model.points)
        cut_model = np.asarray(cut_model.points)
        
        if self.transform is not None:      
            cut_model= self.transform(cut_model)
            vessel_model = self.transform(vessel_model)
            
        label_return = self.label[index]
            
         
        return vessel_model,cut_model,label_return

    def __len__(self):
                                
        return len(self.model_table)

load the dataset

In [44]:
root = "..\..\msc_data\models-v1.0\models"
IA = "aneurysms\\remeshed\\area-001"
Vessel = "vessels\\remeshed\\area-001"
IA_root = os.path.join(root,IA)
Vessel_root = os.path.join(root,Vessel)
list1 = os.listdir(Vessel_root)
list2 = os.listdir(IA_root)

In [45]:
morpho_path = ".\AneuX\data-v1.0\data\morpho-per-cut.csv"
patient_path = ".\AneuX\data-v1.0\data\clinical.csv"
morpho_data_patient = data_process_ml.read_and_combine_data(morpho_path,patient_path)
merged_dataset = data_process_ml.encode_column(morpho_data_patient)
merged_dataset = data_process_ml.drop_columns(merged_dataset)
morpho_data_cut1,morpho_data_dome = data_process_ml.output_cut1anddome(merged_dataset)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  morpho_data_cut1.drop(morpho_data_cut1.columns[3:23], axis=1, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  morpho_data_cut1.drop(['source_x',"cuttype","dataset"], axis=1, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  morpho_data_dome.drop(['source_x',"cuttype","dataset"], axis=1, inplace=True)


In [46]:
morpho_data_patient

Unnamed: 0,source_x,dataset,cuttype,curvature-gauss--L2N,curvature.1-gauss-H,curvature.2-gauss-L2N,curvature.3-gauss-L2NCH,curvature.4-mean--L2N,curvature.5-mean-H,curvature.6-mean-L2N,...,zmi_surf.136-energy-n16,zmi_surf.137-energy-n17,zmi_surf.138-energy-n18,zmi_surf.139-energy-n19,zmi_surf.140-energy-n20,status,location,side,sex,age
0,hug2016,p043_HAARCREcDAAQDQcbHgANDRQM,cut1,,,,,,,,...,45.5701,47.4192,49.5589,52.4916,54.9455,unruptured,ICA oph,left,female,64.2
1,hug2016,p043_HAARCREcDAAQDQcbHgANDRQM,cut2,,,,,,,,...,60.8748,64.7821,68.2609,71.7461,74.6314,unruptured,ICA oph,left,female,64.2
2,hug2016,p043_HAARCREcDAAQDQcbHgANDRQM,dome,0.725242,-0.0894094,2.60987,1.18701,0.198244,0.086614,1.86744,...,30.1846,32.0237,33.6042,35.1063,37.0662,unruptured,ICA oph,left,female,64.2
3,hug2016,p043_HAARCREcDAAQDQcbHgANDRQM,ninja,2.10899,0.214448,3.36637,1.46921,0.613297,0.297034,1.9215,...,32.396,34.5638,36.0719,37.9747,39.807,unruptured,ICA oph,left,female,64.2
4,hug2016,p044_BBMdFxESDBMcEwcVBhMBExQC,cut1,,,,,,,,...,42.1663,44.2741,46.3791,47.4028,49.1238,unruptured,VA V4,left,female,72.7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2771,aneurisk,C0098,ninja,4.02451,-1.3277,6.76148,1.58789,0.700816,-0.445498,2.39925,...,32.908,35.3017,36.9199,38.3504,40.132,ruptured,MCA bif,right,female,59.0
2772,aneurisk,C0099,cut1,,,,,,,,...,39.4508,42.9421,45,47.9252,50.0939,unruptured,MCA bif,right,female,42.0
2773,aneurisk,C0099,cut2,,,,,,,,...,48.4827,51.0062,53.8531,56.8737,60.1946,unruptured,MCA bif,right,female,42.0
2774,aneurisk,C0099,dome,0.606894,0.0592102,1.77604,1.05955,0.0435289,-0.0156052,1.5576,...,29.1724,30.676,32.1221,33.7636,35.7248,unruptured,MCA bif,right,female,42.0


In [96]:
import pandas as pd
df = pd.DataFrame()
Aneux_Dataset = Aneuxmodel_Dataset(root = root,
                                   df=morpho_data_patient,
                                   transform = transforms.ToTensor(),
                                   mesh = "area-001",
                                   cuttype = "dome",
                                   crop = False)

In [98]:
Aneux_Dataset
label_aneux = []
for model in Aneux_Dataset.model_table:
    if  model in list(Aneux_Dataset.df["dataset"]):
        label_num = Aneux_Dataset.df[Aneux_Dataset.df["dataset"] == model]["status"]
        label_num = list(label_num)[0]
        if label_num == "rupture":
            label_num = 1
        else:
            label_num = 0
        label_aneux.append(label_num)
    else:
        print(model)


C0007
C0025
C0039


In [93]:
# load dataset, transfer the open3d pointnet to the tensotflow object
train_size = int(len(Aneux_Dataset) * 0.8) # 80% training data
valid_size = len(Aneux_Dataset) - train_size
train_data, valid_data = random_split(Aneux_Dataset, [train_size, valid_size])

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    #num_workers=2, 
    pin_memory=True
)

valid_loader = torch.utils.data.DataLoader(
    valid_data,
    batch_size=100, # Forward pass only so batch size can be larger
    shuffle=False,
    #num_workers=2, 
    pin_memory=True
)

In [36]:
for i in train_loader:
    a,b,c = i
    print(a,b,c)

IndexError: index 532 is out of bounds for dimension 0 with size 0

In [18]:
file = "\\C0006"
model_list = []
for model_name in list1[:]:
    file = model_name[:-4]
    Vessel_root_file = Vessel_root + "\\" + file + ".vtp"
    
    IA_name_cut1 = ""
    IA_name_dome = ""
    for IA_model in list2:
        
        if file in IA_model and "cut1" in IA_model and file not in IA_name_cut1:
            print("find:" + IA_model)
            IA_name_cut1 = IA_model
        elif file.split("_")[0] in IA_model and "cut1" in IA_model and IA_name_cut1 == "":
            IA_name_cut1 = IA_model
            
        if file in IA_model and "dome" in IA_model and file not in IA_name_dome:
            IA_name_dome = IA_model
        elif file.split("_")[0] in IA_model and "dome" in IA_model and IA_name_dome == "":
            IA_name_dome = IA_model
            
        if IA_name_cut1 != "" and IA_name_dome != "" and file in IA_name_cut1 and file in IA_name_dome:
            break
            
    print(IA_name_cut1)
    IA_root_file = os.path.join(IA_root,IA_name_cut1)
    modelcov_vtp = True
    if IA_name_cut1 != "" and IA_name_dome!= "":
        modelcov_vtp = vtp_to_point_cloud_cutvessel(Vessel_root_file,IA_root_file)
    if modelcov_vtp == False:
        print("Can not crop: " + file)
        print("Vessel root name: " + Vessel_root_file)
        print("IA root name: " + IA_root_file)
        break
         #o3d.visualization.draw_geometries([modelcov_vtp])
    else:
        model_list.append(modelcov_vtp)

find:ANSYS_UNIGE_09_cut1.vtp
ANSYS_UNIGE_09_cut1.vtp
find:ANSYS_UNIGE_16_cut1.vtp
ANSYS_UNIGE_16_cut1.vtp
find:ANSYS_UNIGE_17_10_cut1.vtp
ANSYS_UNIGE_17_10_cut1.vtp
find:ANSYS_UNIGE_27_cut1.vtp
ANSYS_UNIGE_27_cut1.vtp
find:ANSYS_UNIGE_28_269_cut1.vtp
ANSYS_UNIGE_28_269_cut1.vtp
find:ANSYS_UNIGE_30_612_cut1.vtp
ANSYS_UNIGE_30_612_cut1.vtp
find:ANSYS_UNIGE_30_614_cut1.vtp
ANSYS_UNIGE_30_614_cut1.vtp
find:ANSYS_UNIGE_33_628_cut1.vtp
ANSYS_UNIGE_33_628_cut1.vtp
find:ANSYS_UNIGE_33_631_cut1.vtp
ANSYS_UNIGE_33_631_cut1.vtp
find:ANSYS_UNIGE_34_cut1.vtp
ANSYS_UNIGE_34_cut1.vtp
find:ANSYS_UNIGE_35_cut1.vtp
ANSYS_UNIGE_35_cut1.vtp
find:C0001_cut1.vtp
C0001_cut1.vtp
find:C0002_cut1.vtp
C0002_cut1.vtp
find:C0003_cut1.vtp
C0003_cut1.vtp
find:C0005_cut1.vtp
C0005_cut1.vtp
find:C0006_cut1.vtp
C0006_cut1.vtp
find:C0007a_cut1.vtp
C0007a_cut1.vtp
find:C0008_cut1.vtp
C0008_cut1.vtp
find:C0009_cut1.vtp
C0009_cut1.vtp
find:C0010_cut1.vtp
C0010_cut1.vtp
find:C0011_cut1.vtp
C0011_cut1.vtp
find:C0012_cut1.vtp

In [19]:
print(len(list1),len(model_list))
print(model_list)

682 682
[PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points., PointCloud with 10000 points.,

# PointNet model

In [28]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
#https://www.kaggle.com/code/balraj98/pointnet-for-3d-object-classification-pytorch/notebook
class Tnet(nn.Module):
    def __init__(self, k=3):
        super().__init__()
        self.k=k
        self.conv1 = nn.Conv1d(k,64,1)
        self.conv2 = nn.Conv1d(64,128,1)
        self.conv3 = nn.Conv1d(128,1024,1)
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, input):
        # input.shape == (bs,n,3)
        bs = input.size(0)
        xb = F.relu(self.bn1(self.conv1(input)))
        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = F.relu(self.bn3(self.conv3(xb)))
        pool = nn.MaxPool1d(xb.size(-1))(xb)
        flat = nn.Flatten(1)(pool)
        xb = F.relu(self.bn4(self.fc1(flat)))
        xb = F.relu(self.bn5(self.fc2(xb)))

        #initialize as identity
        init = torch.eye(self.k, requires_grad=True).repeat(bs,1,1)
        if xb.is_cuda:
            init=init.cuda()
        matrix = self.fc3(xb).view(-1,self.k,self.k) + init
        return matrix


class Transform(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_transform = Tnet(k=3)
        self.feature_transform = Tnet(k=64)
        self.conv1 = nn.Conv1d(3,64,1)

        self.conv2 = nn.Conv1d(64,128,1)
        self.conv3 = nn.Conv1d(128,1024,1)


        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, input):
        matrix3x3 = self.input_transform(input)
        # batch matrix multiplication
        xb = torch.bmm(torch.transpose(input,1,2), matrix3x3).transpose(1,2)

        xb = F.relu(self.bn1(self.conv1(xb)))

        matrix64x64 = self.feature_transform(xb)
        xb = torch.bmm(torch.transpose(xb,1,2), matrix64x64).transpose(1,2)

        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = self.bn3(self.conv3(xb))
        xb = nn.MaxPool1d(xb.size(-1))(xb)
        output = nn.Flatten(1)(xb)
        return output, matrix3x3, matrix64x64

class PointNet(nn.Module):
    def __init__(self, classes = 10):
        super().__init__()
        self.transform = Transform()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, classes)
        

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=0.3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        xb, matrix3x3, matrix64x64 = self.transform(input)
        xb = F.relu(self.bn1(self.fc1(xb)))
        xb = F.relu(self.bn2(self.dropout(self.fc2(xb))))
        output = self.fc3(xb)
        return self.logsoftmax(output)



In [29]:
pointnet = PointNet()

In [30]:
#Train use 3d model use Pointnet

# PointNet ++

In [31]:
#https://github.com/charlesq34/pointnet2

In [32]:
import torch
class PointNetplus(nn.Module):
    def __init__(self, classes = 10):
        super().__init__()
        self.transform = Transform()
        
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, classes)
        

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=0.3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input1, input2):
        xb, matrix3x3, matrix64x64 = self.transform(input)
        xa, matrix3x3, matrix64x64 = self.transform(input1)
        xb, matrix3x3, matrix64x64 = self.transform(input2)
        xc = xa + xb
        xc = F.relu(self.bn1(self.fc1(xc)))
        xc = F.relu(self.bn2(self.dropout(self.fc2(xc))))
        output = self.fc3(xb)
        return self.logsoftmax(output)

In [33]:
#Train use 3d model use Pointnet