In [4]:
import os
import sys

# Add the parent of the current working directory to sys.path (presumably so
# the project-local `src` package imported below resolves -- confirm layout).
# Guarded so that re-running this cell does not append duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import standard libraries
import json
import numpy as np
import matplotlib.pyplot as plt

# Import PyTorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Import Open3D for visualization (optional)
import open3d as o3d



#from src.models.pointnetplusplus import PointNetPlusPlus
from src.models.pointnet2_utils import PointNetSetAbstraction, PointNetSetAbstractionMsg, PointNetFeaturePropagation

In [5]:
# shapenet_dataset.py

class ShapeNetPartDataset(Dataset):
    """Part-segmentation dataset read from per-category ``.pts`` / ``.seg`` text files.

    Expected layout under ``root_dir`` (as read by this class):

    * ``synsetoffset2category.txt`` -- one ``<category> <folder>`` pair per line.
    * ``misc/num_seg_classes.txt``  -- one ``<category> <part count>`` pair per line.
    * ``<folder>/points/*.pts``     -- one point per line.
    * ``<folder>/points_label/*.seg`` -- one 1-based label per line.

    Args:
        root_dir: Dataset root directory.
        split: ``'train'`` selects the first 90% of each category's sorted
            files; any other value selects the remaining 10%.
        num_points: Number of points drawn (with replacement) per sample.
        class_choice: Optional container of category names to keep; when
            falsy, all categories are used.

    Attributes:
        datapath: List of ``(point_file, seg_file, category)`` tuples.
        classes: Mapping of category name to folder name.
        class_to_seg_map: Mapping of category name to its part count.
        num_seg_classes: Sum of part counts over every line of
            ``num_seg_classes.txt``.
    """

    def __init__(self, root_dir, split='train', num_points=2048, class_choice=None):
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.class_choice = class_choice

        # Load data
        self.datapath = []
        self.classes = {}
        self.class_to_seg_map = {}
        self.num_seg_classes = 0

        # Load the class names
        with open(os.path.join(self.root_dir, 'synsetoffset2category.txt'), 'r') as f:
            for line in f:
                fields = line.strip()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                cls, idx = fields.split()
                self.classes[cls] = idx

        if self.class_choice:
            self.classes = {k: v for k, v in self.classes.items() if k in self.class_choice}

        # Load segmentation label mappings.
        # NOTE(review): this totals the part counts of *all* listed categories,
        # even when class_choice narrows self.classes -- confirm that is intended.
        with open(os.path.join(self.root_dir, 'misc', 'num_seg_classes.txt'), 'r') as f:
            for line in f:
                fields = line.strip()
                if not fields:
                    continue
                cls, num = fields.split()
                self.class_to_seg_map[cls] = int(num)
                self.num_seg_classes += int(num)

        # Load file paths
        for cls in self.classes:
            cls_root = os.path.join(self.root_dir, self.classes[cls], 'points')
            seg_root = os.path.join(self.root_dir, self.classes[cls], 'points_label')

            # Only real point files take part in the split; stray entries such
            # as OS metadata files would otherwise shift the 90/10 boundary and
            # fail to parse later.
            files = sorted(name for name in os.listdir(cls_root) if name.endswith('.pts'))
            cut = int(len(files) * 0.9)
            if split == 'train':
                files = files[:cut]
            else:
                files = files[cut:]

            for file in files:
                point_file = os.path.join(cls_root, file)
                # Swap only the extension (a plain str.replace would also
                # rewrite '.pts' occurring inside the stem).
                seg_file = os.path.join(seg_root, os.path.splitext(file)[0] + '.seg')
                self.datapath.append((point_file, seg_file, cls))

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.datapath)

    def __getitem__(self, idx):
        """Load, resample and normalise one sample.

        Returns:
            tuple: ``(point_set, seg)`` where ``point_set`` is a float32 array
            with ``num_points`` rows (centred on its mean and scaled by its
            maximum row norm) and ``seg`` is an int64 array of ``num_points``
            labels shifted to start at 0.

        Raises:
            ValueError: If the label file contains no entries.
        """
        point_file, seg_file, _ = self.datapath[idx]
        # ndmin keeps single-point files as (1, C) / (1,) instead of collapsing
        # to lower-rank arrays that break len() and row indexing.
        point_set = np.loadtxt(point_file, ndmin=2).astype(np.float32)
        seg = np.loadtxt(seg_file, ndmin=1).astype(np.int64) - 1  # Labels start from 1
        # NOTE(review): no per-category offset is applied to the labels here;
        # whether they are global part ids depends on the .seg files -- confirm.

        if len(seg) == 0:
            raise ValueError(f"No labelled points found in {seg_file}")

        # Sample points (with replacement, so clouds smaller than num_points work)
        choice = np.random.choice(len(seg), self.num_points, replace=True)
        point_set = point_set[choice, :]
        seg = seg[choice]

        # Normalize: centre on the mean, then scale by the largest row norm.
        point_set = point_set - np.mean(point_set, axis=0)
        norm = np.max(np.linalg.norm(point_set, axis=1))
        if norm > 0:
            # A zero norm means every sampled point coincides; dividing would
            # yield NaNs, so the (all-zero) centred cloud is returned as is.
            point_set = point_set / norm

        return point_set, seg


In [None]:
# Parameters
root_dir = './data/shapenet_part/'
num_points = 2048
batch_size = 16

# Build the training split first, then the held-out ('test') split, from the
# same root directory and with the same per-sample point count.
train_dataset, val_dataset = (
    ShapeNetPartDataset(root_dir=root_dir, split=split_name, num_points=num_points)
    for split_name in ('train', 'test')
)

# Wrap each dataset in a loader; only the training loader shuffles.
train_loader, val_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle, num_workers=4)
    for dataset, do_shuffle in ((train_dataset, True), (val_dataset, False))
)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")

In [None]:
class PointNetPlusPlus(nn.Module):
    """Per-point segmentation network built from set-abstraction and
    feature-propagation layers.

    The encoder is two multi-scale set-abstraction layers followed by one
    group-all set-abstraction layer; the decoder is three feature-propagation
    layers followed by a small 1x1-convolution head.

    The channel counts below follow the conventions of the project-local
    ``pointnet2_utils`` layers (e.g. the first entry of each ``mlps`` list
    appears to be the input width) -- TODO confirm against that module.

    Args:
        num_classes: Number of output classes predicted for each point.
    """

    def __init__(self, num_classes: int):
        super(PointNetPlusPlus, self).__init__()

        # Set Abstraction layers with MSG
        self.sa1 = PointNetSetAbstractionMsg(
            npoint=512,
            radii=[0.1, 0.2],
            nsamples=[32, 64],
            mlps=[[3, 32, 32, 64], [3, 64, 64, 128]]
        )
        self.sa2 = PointNetSetAbstractionMsg(
            npoint=128,
            radii=[0.2, 0.4],
            nsamples=[64, 128],
            mlps=[[195, 128, 128, 256], [195, 256, 256, 512]]  # 195 = 192 + 3
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            mlp=[771, 512, 1024],  # 771 = 768 + 3
            group_all=True
        )

        # Feature Propagation layers
        self.fp3 = PointNetFeaturePropagation(in_channel=1792, mlp=[512, 512])  # 768 + 1024
        self.fp2 = PointNetFeaturePropagation(in_channel=704, mlp=[512, 256])   # 192 + 512
        # 259 = 3 + 256. This layer must be created here, once, so that its
        # parameters are registered with the module (and therefore seen by
        # optimizers, .to(device), state_dict, ...).
        self.fp1 = PointNetFeaturePropagation(in_channel=259, mlp=[256, 128])

        # Fully connected layers (implemented as 1x1 convolutions over points)
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the network.

        Args:
            xyz (torch.Tensor): Input point cloud data of shape (B, N, 3).

        Returns:
            torch.Tensor: Per-point log-probabilities (log-softmax over the
            last dimension), expected shape (B, N, num_classes).

        Raises:
            ValueError: If ``xyz`` is not a 3-dimensional tensor.
        """
        if xyz.dim() != 3:
            raise ValueError(
                f"expected xyz of shape (B, N, 3), got tensor with shape {tuple(xyz.shape)}"
            )

        # Input Transformation
        l0_xyz = xyz.transpose(2, 1).contiguous()  # Shape: (B, 3, N)
        l0_points = None  # No additional features at input

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)     # l1_points: expected 192 channels
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)     # l2_points: expected 768 channels
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)     # l3_points: expected 1024 channels

        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)  # l2_points: expected 512 channels
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)  # l1_points: expected 256 channels

        # Use the fp1 layer built in __init__. Re-creating it here would
        # re-randomise its weights on every call, hide them from any optimizer
        # constructed earlier, and leave them on the default device.
        l0_points = self.fp1(l0_xyz, l1_xyz, l0_points, l1_points)  # l0_points: expected 128 channels

        # Fully connected layers
        x = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.dropout(x)
        x = self.conv2(x)

        x = x.transpose(2, 1).contiguous()  # Shape: (B, N, num_classes)
        x = F.log_softmax(x, dim=-1)

        return x