# Custom PyTorch Dataset

In [1]:
# Prerequisites:
import sys, os, platform 
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from glob import glob

import open3d as o3d

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from my_dataset import CustomDataset

print("OS: ", platform.platform())
print("Python Version: ", sys.version)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
OS:  Windows-11-10.0.26100-SP0
Python Version:  3.12.7 (tags/v3.12.7:0b05ead, Oct  1 2024, 03:06:41) [MSC v.1941 64 bit (AMD64)]


Check for Cuda availability

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using NVIDIA GeForce RTX 3060 Laptop GPU


### Create Training, Validation, and Testing point cloud file lists

In [3]:
# Get list of training and testing files
project_dir = 'data/aerial_01'
pc_train_files = glob(os.path.join(project_dir,"train/*.txt"))
pc_test_files = glob(os.path.join(project_dir,"test/*.txt"))

In [4]:
# Get Validation file set from training set
val_index = np.random.choice(len(pc_train_files), int(len(pc_train_files)/5), replace=False)
val_list = [pc_train_files[i] for i in val_index]
train_list = [pc_train_files[i] for i in np.setdiff1d(list(range(len(pc_train_files))), val_index)]
test_list = pc_test_files

In [5]:
# Check sizes of datasets
print(f"# of training files: {len(train_list)}")
print(f"# of validation files: {len(val_list)}")
print(f"# of test files: {len(test_list)}")

# of training files: 101
# of validation files: 25
# of test files: 17


### Use CustomDataset class to load them

In [7]:
nr_points = 4096
train_dataset = CustomDataset(train_list, "xyz", num_points=nr_points, compute_normals=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = CustomDataset(val_list, "xyz", is_training=False, num_points=nr_points, compute_normals=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataset = CustomDataset(test_list, "xyz", is_training=False, num_points=nr_points, compute_normals=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [8]:
# Check number of classes
num_classes = train_dataset.num_classes
print("Number of classes: ", num_classes)

Number of classes:  5


### Visualize Dataset

In [None]:
def visualize_dataset(dataset):
    """
    Visualize the point clouds in the dataset using Open3D.
    Args:
        dataset (CustomDataset): The dataset containing point clouds.
    """
    pcds = []
    nb_pcd = len(dataset)
    for i in range(nb_pcd):
        point_cloud = dataset.get_data(i)
        if isinstance(point_cloud,torch.Tensor):
            point_cloud = point_cloud.cpu().numpy()
        point_cloud = point_cloud[:3].T
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(point_cloud)
        pcd.estimate_normals()

        random_color = np.random.random((3,))
        n = len(pcd.points)
        colors = np.tile(random_color,(n,1))
        pcd.colors = o3d.utility.Vector3dVector(colors)
        pcds.append(pcd)
        
    o3d.visualization.draw_geometries(pcds)

visualize_dataset(test_dataset)

### Define Training Loop

In [9]:
def train_and_save_model(train_loader, val_loader, num_classes, model, save_path="model.pth", num_epochs=100, learning_rate=0.001):
    """
    Train the model and save it to a file.
    
    Args:
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        num_classes (int): Number of classes in the dataset.
        model (torch.nn.Module): The model to be trained.
        save_path (str): Path to save the trained model.
        num_epochs (int): Number of epochs to train the model.
        learning_rate (float): Learning rate for the optimizer.
    """
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=-1)  
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            # Unpack the batch
            inputs, labels, index = batch
            p = inputs
            # Convert to float and long tensors
            p, labels = p.float(), labels.long()
            # Move to device
            p, labels = p.to(device), labels.to(device)
            # Forward pass
            seg_pred = model(p)
            seg_pred=seg_pred.contiguous().view(-1,num_classes)
            labels=labels.view(-1,1)[:,0]
            # Compute loss
            loss = criterion(seg_pred, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        # Evaluate the model on the training set
        model.eval()
        total_correct = 0
        total_points = 0
            
        with torch.no_grad():
            for batch in val_loader:
                inputs,labels,index = batch
                p = inputs
                p,labels = p.float(), labels.long()
                p,labels = p.to(device), labels.to(device)
                seg_pred = model(p)
                seg_pred = seg_pred.contiguous().view(-1,num_classes)
                labels = labels.view(-1,1)[:,0]
                _,predicted = seg_pred.max(1)
                total_correct += (predicted==labels).sum().item()
                total_points += labels.size(0)
        
        accuracy = 100*total_correct / total_points
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss {loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Save the trained model
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

### Define the ANN

In [10]:
class SimpleNN(nn.Module):
    """
    A simple feedforward neural network for point cloud classification.
    Args:
        input_channels (int): Number of input channels (features).
        num_classes (int): Number of output classes.
    """

    def __init__(self, input_channels,num_classes) -> None:
        super(SimpleNN,self).__init__()
        self.layer1 = nn.Linear(input_channels,64)
        self.layer2 = nn.Linear(64,128)
        self.layer3 = nn.Linear(128,num_classes)
        
    def forward(self,x):
        batch_size,channels,num_points = x.shape
        x = x.permute(0,2,1).reshape(batch_size*num_points,channels)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = self.layer3(x)
        return x

### Train the model

In [11]:
num_classes = train_loader.dataset.num_classes
print(f"Shape: {train_loader.dataset[0][0].shape[0]}, # of classes: {num_classes}")

# Instantiate the model
model = SimpleNN(train_loader.dataset[0][0].shape[0], num_classes)
model.to(device)

# Train and save the model
train_and_save_model(train_loader, val_loader, num_classes, model, save_path="hello_3D_sem_seg.pth")

Shape: 6, # of classes: 5
Epoch [1/100], Loss 0.7046, Accuracy: 69.81%
Epoch [2/100], Loss 0.6737, Accuracy: 69.81%
Epoch [3/100], Loss 0.6281, Accuracy: 69.81%
Epoch [4/100], Loss 0.6381, Accuracy: 69.81%
Epoch [5/100], Loss 0.8205, Accuracy: 69.81%
Epoch [6/100], Loss 0.8155, Accuracy: 69.81%
Epoch [7/100], Loss 0.7679, Accuracy: 69.81%
Epoch [8/100], Loss 0.7527, Accuracy: 69.81%
Epoch [9/100], Loss 0.7143, Accuracy: 69.81%
Epoch [10/100], Loss 0.7161, Accuracy: 69.81%
Epoch [11/100], Loss 0.6096, Accuracy: 69.81%
Epoch [12/100], Loss 0.6424, Accuracy: 69.81%
Epoch [13/100], Loss 0.6721, Accuracy: 69.81%
Epoch [14/100], Loss 0.6669, Accuracy: 69.81%
Epoch [15/100], Loss 0.7095, Accuracy: 69.81%
Epoch [16/100], Loss 0.6991, Accuracy: 69.81%
Epoch [17/100], Loss 0.6485, Accuracy: 69.81%
Epoch [18/100], Loss 0.7793, Accuracy: 69.81%
Epoch [19/100], Loss 0.5810, Accuracy: 69.81%
Epoch [20/100], Loss 0.6581, Accuracy: 69.81%
Epoch [21/100], Loss 0.5911, Accuracy: 69.81%
Epoch [22/100], L