# PointNet Semantic Segmentation Inference

Model and code modified from https://www.kaggle.com/code/mahdiasdzd/pointnet/notebook

In [1]:
import os
import sys
import json
import numpy as np
# from tqdm import tqdm
# import pandas as pd
# import glob
# import random
# from sklearn.metrics import roc_auc_score, roc_curve, auc

# plotting 
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# DeepLearning 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# import torch.optim as optim

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Print the library versions so the captured output records the environment.
print(np.__version__)
print(torch.__version__)
# IPython shell escape (notebook-only syntax): show driver / GPU status.
!nvidia-smi

2.2.6
2.9.0+cu130
Thu Nov  6 11:09:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.15                 Driver Version: 581.15         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2080 ...  WDDM  |   00000000:2D:00.0  On |                  N/A |
| 12%   55C    P0             41W /  250W |    5379MiB /   8192MiB |      7%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------

## Model Initialization

In [3]:
class STN3d(nn.Module):
    """Input transform network that regresses one 3x3 matrix per point cloud.

    Three point-wise (kernel size 1) convolutions lift every point to 1024
    channels, a max-pool over the point axis collapses them into a single
    vector, and three fully connected layers map that vector to 9 values.
    The flattened 3x3 identity is added before reshaping, so the layers only
    have to learn an offset from the identity transform.

    Args:
        num_points: Kernel size of the max-pool. Inputs need exactly this many
            points for the pooled tensor to reshape to one row per cloud.
    """

    def __init__(self, num_points = 2500):
        super(STN3d, self).__init__()
        self.num_points = num_points
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)

        self.mp1 = torch.nn.MaxPool1d(num_points)

        # FC layers
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        # Not used by forward() (F.relu is called instead); kept so the module
        # layout stays the same as before.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Predict the transform for a batch of point clouds.

        Args:
            x: Tensor of shape (bs, 3, num_points).

        Returns:
            Tensor of shape (bs, 3, 3).
        """
        # Expected input shape = (bs, 3, num_points)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.mp1(x)
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Create the identity on the same device and dtype as the activations
        # so the addition works on any accelerator index and in reduced
        # precision; the (1, 9) row broadcasts over the batch dimension.
        identity = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9)
        x = x + identity
        return x.view(-1, 3, 3)

In [4]:
class OpenShape(nn.Module):
    """Point-wise 3 -> 64 -> 128 -> 1024 encoder followed by a max-pool.

    Args:
        num_points: Kernel size of the max-pool applied along the last axis.
    """

    def __init__(self, num_points = 2500):
        super().__init__()

        # Attribute names are the state_dict keys, so they stay as they are.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.mp1 = nn.MaxPool1d(num_points)

    def forward(self, x):
        """Apply the three ELU-activated convolutions, then max-pool."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.elu(conv(x))
        return self.mp1(x)

In [5]:
class PointNetfeat(nn.Module):
    """PointNet feature extractor.

    The input cloud is multiplied by the 3x3 matrix predicted by ``STN3d``,
    passed through three point-wise convolutions and max-pooled into a
    1024-dimensional vector per cloud.

    Args:
        num_points: Number of points per cloud; used as the max-pool kernel
            size and as the repeat count when tiling the pooled vector.
        global_feat: When true, ``forward`` returns only the pooled vector;
            otherwise it returns per-point features with the pooled vector
            concatenated to each point.
    """

    def __init__(self, num_points = 2500, global_feat = True):
        super().__init__()
        # Sub-modules are registered in the original order so the state_dict
        # layout and the parameter initialisation sequence do not change.
        self.stn = STN3d(num_points = num_points)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.mp1 = nn.MaxPool1d(num_points)
        self.num_points = num_points
        self.global_feat = global_feat
        # Registered (and therefore part of the state_dict) but never called
        # in forward().
        self.OpenShape = OpenShape(num_points = num_points)

    def forward(self, x):
        """Return ``(features, trans)`` for a batch of point clouds.

        ``features`` is the pooled (bs, 1024) vector when ``global_feat`` is
        true; otherwise it is the pooled vector repeated for every point and
        concatenated with the 64-channel per-point features along the channel
        axis. ``trans`` is the matrix predicted by the ``STN3d`` sub-module.
        """
        trans = self.stn(x)

        # bmm needs points as rows: (bs, N, 3) @ (bs, 3, 3), then back to
        # channels-first.
        aligned = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)

        pointfeat = F.elu(self.bn1(self.conv1(aligned)))
        hidden = F.relu(self.bn2(self.conv2(pointfeat)))
        hidden = self.bn3(self.conv3(hidden))
        pooled = self.mp1(hidden).view(-1, 1024)

        if self.global_feat:
            return pooled, trans

        tiled = pooled.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([tiled, pointfeat], 1), trans

In [6]:
class PointNetDenseSeg(nn.Module):
    """PointNet with two per-point heads: semantic and instance.

    Both heads read the same 1088-channel per-point features produced by
    ``PointNetfeat`` (1024 pooled channels concatenated with 64 per-point
    channels) and reduce them 1088 -> 512 -> 256 -> 128 -> output channels
    with kernel-size-1 convolutions.

    Args:
        num_points: Number of points per input cloud.
        num_classes: Output channels of the semantic head.
        num_instances: Output channels of the instance head.
    """

    def __init__(self, num_points=2500, num_classes=16, num_instances=10):
        super().__init__()
        self.num_points = num_points
        self.num_classes = num_classes
        self.num_instances = num_instances

        # Shared backbone returning per-point features
        self.feat = PointNetfeat(num_points, global_feat=False)

        # Semantic head (layers registered in the original order)
        self.conv1_sem = nn.Conv1d(1088, 512, 1)
        self.conv2_sem = nn.Conv1d(512, 256, 1)
        self.conv3_sem = nn.Conv1d(256, 128, 1)
        self.conv4_sem = nn.Conv1d(128, num_classes, 1)
        self.bn1_sem = nn.BatchNorm1d(512)
        self.bn2_sem = nn.BatchNorm1d(256)
        self.bn3_sem = nn.BatchNorm1d(128)

        # Instance head
        self.conv1_inst = nn.Conv1d(1088, 512, 1)
        self.conv2_inst = nn.Conv1d(512, 256, 1)
        self.conv3_inst = nn.Conv1d(256, 128, 1)
        self.conv4_inst = nn.Conv1d(128, num_instances, 1)
        self.bn1_inst = nn.BatchNorm1d(512)
        self.bn2_inst = nn.BatchNorm1d(256)
        self.bn3_inst = nn.BatchNorm1d(128)

    def forward(self, x):
        """Return ``(x_sem, x_inst, trans)``.

        ``x_sem`` and ``x_inst`` are the raw outputs of the final convolution
        of each head (no activation applied); ``trans`` is the matrix returned
        by the backbone.
        """
        shared, trans = self.feat(x)

        # Semantic head: three conv+BN+ReLU stages, then the output conv.
        x_sem = shared
        for conv, bn in (
            (self.conv1_sem, self.bn1_sem),
            (self.conv2_sem, self.bn2_sem),
            (self.conv3_sem, self.bn3_sem),
        ):
            x_sem = F.relu(bn(conv(x_sem)))
        x_sem = self.conv4_sem(x_sem)

        # Instance head: same structure with its own weights.
        x_inst = shared
        for conv, bn in (
            (self.conv1_inst, self.bn1_inst),
            (self.conv2_inst, self.bn2_inst),
            (self.conv3_inst, self.bn3_inst),
        ):
            x_inst = F.relu(bn(conv(x_inst)))
        x_inst = self.conv4_inst(x_inst)

        return x_sem, x_inst, trans

In [7]:
# Points fed to the network per cloud. The model's max-pool kernels are sized
# with this value, so the dataset below must be sampled to the same count.
num_points = 100000
classifier = PointNetDenseSeg(num_points = num_points)
# Deserialize onto the CPU first: a checkpoint saved from a CUDA run would
# otherwise fail to load on a CPU-only host, even though `device` already
# falls back to the CPU. The weights are moved to `device` right after.
classifier.load_state_dict(torch.load('pointnet_se_seg.pth', map_location='cpu'))
classifier = classifier.to(device)
# Inference mode: BatchNorm layers use their stored running statistics.
classifier.eval()

PointNetDenseSeg(
  (feat): PointNetfeat(
    (stn): STN3d(
      (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (mp1): MaxPool1d(kernel_size=100000, stride=100000, padding=0, dilation=1, ceil_mode=False)
      (fc1): Linear(in_features=1024, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=256, bias=True)
      (fc3): Linear(in_features=256, out_features=9, bias=True)
      (relu): ReLU()
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=

In [8]:
# Sanity check that the checkpoint loaded: show only the first parameter.
first_entry = next(iter(classifier.named_parameters()), None)
if first_entry is not None:
    name, param = first_entry
    print(f"Layer: {name}, Weights/Biases: \n {param.data}")
    print(param.data.shape)

Layer: feat.stn.conv1.weight, Weights/Biases: 
 tensor([[[-0.2062],
         [ 0.0188],
         [-0.2074]],

        [[ 0.3418],
         [-0.2682],
         [ 0.3939]],

        [[-0.1998],
         [-0.3774],
         [-0.2125]],

        [[ 0.4823],
         [ 0.0336],
         [-0.4573]],

        [[-0.4988],
         [ 0.1188],
         [ 0.5872]],

        [[-0.3966],
         [-0.0720],
         [-0.1775]],

        [[ 0.5281],
         [ 0.1959],
         [ 0.0195]],

        [[ 0.0723],
         [-0.1757],
         [-0.5828]],

        [[ 0.1966],
         [-0.2205],
         [-0.2059]],

        [[ 0.4263],
         [ 0.2155],
         [-0.2420]],

        [[ 0.0096],
         [-0.1004],
         [-0.3947]],

        [[-0.3425],
         [ 0.1092],
         [ 0.2810]],

        [[ 0.2617],
         [ 0.3500],
         [-0.2710]],

        [[ 0.5850],
         [-0.2367],
         [ 0.6039]],

        [[ 0.4725],
         [ 0.5132],
         [-0.1394]],

        [[ 0.5483],
  

## Model Inference

In [9]:
# class ShapeNetDataset(torch.utils.data.Dataset):
#     def __init__(self, root_dir, split_type, num_samples=2500):
#         self.root_dir = root_dir
#         self.split_type = split_type
#         self.num_samples = num_samples
#         with open(os.path.join(root_dir, f'{self.split_type}_split.json'), 'r') as f:
#             self.split_data = json.load(f)       
            
#     def __getitem__(self, index):
#         # read point cloud data
#         class_id, class_name, point_cloud_path, seg_label_path = self.split_data[index]
        
#         # point cloud data
#         point_cloud_path = os.path.join(self.root_dir, point_cloud_path)
#         pc_data = np.load(point_cloud_path)
        
#         # segmentation labels
#         # -1 is to change part values from [1-16] to [0-15]
#         # which helps when running segmentation
#         pc_seg_labels = np.loadtxt(os.path.join(self.root_dir, seg_label_path)).astype(np.int8) - 1
#         #pc_seg_labels = pc_seg_labels.reshape(pc_seg_labels.size,1)
        
#         # Sample fixed number of points
#         num_points = pc_data.shape[0]
#         if num_points < self.num_samples:
#             # Duplicate random points if the number of points is less than max_num_points
#             additional_indices = np.random.choice(num_points, self.num_samples - num_points, replace=True)
#             pc_data = np.concatenate((pc_data, pc_data[additional_indices]), axis=0)
#             pc_seg_labels = np.concatenate((pc_seg_labels, pc_seg_labels[additional_indices]), axis=0)
                
#         else:
#             # Randomly sample max_num_points from the available points
#             random_indices = np.random.choice(num_points, self.num_samples)
#             pc_data = pc_data[random_indices]
#             pc_seg_labels = pc_seg_labels[random_indices]
        
#         # return variable
#         data_dict= {}
#         data_dict['class_id'] = class_id
#         data_dict['class_name'] = class_name        
#         data_dict['points'] = pc_data 
#         data_dict['seg_labels'] = pc_seg_labels 
#         return data_dict        
                    
#     def __len__(self):
#         return len(self.split_data)

In [10]:
# def collate_fn(batch_list):
#     ret = {}
#     ret['class_id'] =  torch.from_numpy(np.array([x['class_id'] for x in batch_list])).long()
#     ret['class_name'] = np.array([x['class_name'] for x in batch_list])
#     ret['points'] = torch.from_numpy(np.stack([x['points'] for x in batch_list], axis=0)).float()
#     ret['seg_labels'] = torch.from_numpy(np.stack([x['seg_labels'] for x in batch_list], axis=0)).long()
#     return ret

In [11]:
# from torch.utils.data import Dataset, DataLoader
# import open3d as o3d

# class SinglePLYDataset(Dataset):
#     def __init__(self, ply_path, n_points=2048, with_normals=False):
#         self.ply_path = ply_path
#         self.n_points = n_points
#         self.with_normals = with_normals
#         self._prepare()

#     def _prepare(self):
#         pcd = o3d.io.read_point_cloud(self.ply_path)
#         pts = np.asarray(pcd.points, dtype=np.float32)
#         if self.with_normals:
#             if not pcd.has_normals():
#                 pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=30))
#             feats = np.asarray(pcd.normals, dtype=np.float32)
#         else:
#             feats = None

#         # normalize
#         c = pts.mean(0, keepdims=True)
#         pts_c = pts - c
#         s = np.max(np.linalg.norm(pts_c, axis=1))
#         pts_n = pts_c / (s + 1e-8)

#         # sample/pad
#         N = self.n_points
#         if pts_n.shape[0] >= N:
#             sel = np.random.choice(pts_n.shape[0], N, replace=False)
#         else:
#             pad = np.random.choice(pts_n.shape[0], N - pts_n.shape[0], replace=True)
#             sel = np.concatenate([np.arange(pts_n.shape[0]), pad])

#         self.sel = sel
#         self.pts_n = pts_n[sel]                    # (N,3)
#         self.feats = feats[sel] if feats is not None else None

#     def __len__(self): return 1

#     def __getitem__(self, idx):
#         if self.feats is None:
#             x = self.pts_n.T                        # (3,N)
#         else:
#             x = np.concatenate([self.pts_n, self.feats], axis=1).T  # (3+C, N)
#         return torch.from_numpy(x).float()

# # Usage
# ds = SinglePLYDataset("../point_cloud_data/shoe/shoe-poisson-sampled.ply", n_points=2500, with_normals=False)
# dl = DataLoader(ds, batch_size=1, shuffle=False)

# for batch in dl:                      # batch: (B, C, N)
#     batch = batch.to(device)
#     print(batch.shape)
#     with torch.no_grad():
#         out = classifier(batch)            # or model(batch, onehot) for part-seg

In [12]:
# model_preds, x, y = out

In [13]:
# print(model_preds.shape)
# print(x.shape)
# print(y.shape)

In [14]:
# from visual_utils import plot_pc_data3d, plot_bboxes_3d

# # Random test sample
# # test_sample = test_set[2047]
# # batch_dict = collate_fn([test_sample])

# # Get model predictions
# # x = batch_dict['points'].transpose(1, 2).to(device)
# # model_preds, _, _ = classifier(x)
# pred_part_labels = torch.argmax(model_preds, axis=1).detach().cpu().numpy()[0]

# # points = test_sample['points']
# points = ds[0]
# # part_labels = test_sample['seg_labels']

# PCD_SCENE=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), aspectmode='data')
# NUM_PARTS = 16
# PART_COLORS = np.random.choice(range(255),size=(NUM_PARTS,3))

# # plot results
# fig = make_subplots(rows=1, cols=2, specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]], column_widths=[0.5, 0.5],
#                     subplot_titles=('Part Labels', 'Part Predictions'))

# # ground truth part labels
# part_label_plots = plot_pc_data3d(x=points[:,0], y=points[:,1], z=points[:,2], apply_color_gradient=False, 
#                                   color=PART_COLORS[pred_part_labels - 1], marker_size=2)

# # predicted part labels
# pred_part_label_plots = plot_pc_data3d(x=points[:,0], y=points[:,1], z=points[:,2], apply_color_gradient=False, 
#                                   color=PART_COLORS[pred_part_labels - 1], marker_size=2)

# fig.update_layout(template="plotly_dark", scene=PCD_SCENE, scene2=PCD_SCENE, height = 720, width = 1280,
#                 title='PointNet Semantic Segmentation', title_x=0.5, title_y=0.97, margin=dict(r=0, b=0, l=0, t=0))
# fig.add_trace(part_label_plots, row=1, col=1)
# fig.add_trace(pred_part_label_plots, row=1, col=2)
# #fig.add_trace(plot_pc_data3d(x=test_sample['points'][:,0], y=test_sample['points'][:,1], z=test_sample['points'][:,2]), row=1, col=1)
# #fig.add_trace(go.Bar(x=list(class_name_id_map.keys()), y=pred_class_probs, showlegend=False), row=1, col=2)
# fig.show()

In [15]:
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import numpy as np
import torch

class SinglePLYDataset(Dataset):
    """Dataset of length one that serves a single PLY point cloud.

    The cloud is read once at construction, centred on its mean, scaled by
    its largest centred norm, and sampled (or padded with repeated points)
    to exactly ``n_points`` points.

    Args:
        ply_path: Path of the PLY file to read.
        n_points: Number of points in the returned sample.
        with_normals: When true, normals are appended as three extra channels
            (estimated with a 30-nearest-neighbour search if the file has
            none).
        seed: Seed of the private ``RandomState`` used for sampling/padding.

    Attributes set by preparation:
        sel: Indices into the file's points, length ``n_points``.
        center, scale: Normalisation parameters.
        pts_norm: (N, 3) normalised coordinates of the selected points.
        pts_denorm: (N, 3) original coordinates of the selected points.
        feats: (N, 3) normals of the selected points, or ``None``.

    Raises:
        ValueError: If the file yields no points.
    """

    def __init__(self, ply_path, n_points=2048, with_normals=False, seed=0):
        self.ply_path = ply_path
        self.n_points = n_points
        self.with_normals = with_normals
        self.rng = np.random.RandomState(seed)
        self._prepare()

    def _prepare(self):
        """Read the PLY file and build the normalised, fixed-size sample."""
        pcd = o3d.io.read_point_cloud(self.ply_path)
        pts = np.asarray(pcd.points, dtype=np.float32)              # (M,3)
        # Raise explicitly instead of asserting: asserts vanish under -O.
        if pts.shape[0] == 0:
            raise ValueError("No points in PLY.")

        if self.with_normals:
            if not pcd.has_normals():
                pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=30))
            feats = np.asarray(pcd.normals, dtype=np.float32)       # (M,3)
        else:
            feats = None

        self._set_points(pts, feats)

    def _set_points(self, pts, feats):
        """Normalise ``pts`` (M,3), pick ``n_points`` indices, store results."""
        # normalize to unit sphere
        center = pts.mean(0, keepdims=True)
        pts_c = pts - center
        scale = np.max(np.linalg.norm(pts_c, axis=1))
        pts_n = pts_c / (scale + 1e-8)                               # eps guards scale == 0

        # sample or pad to N
        N = self.n_points
        M = pts_n.shape[0]
        if M >= N:
            sel = self.rng.choice(M, N, replace=False)
        else:
            # Keep every original point once, then pad with random repeats.
            pad = self.rng.choice(M, N - M, replace=True)
            sel = np.concatenate([np.arange(M), pad])

        self.sel = sel
        self.center = center
        self.scale = scale
        self.pts_norm = pts_n[sel]                                   # (N,3) normalized
        # Index the untouched input rather than undoing the normalisation:
        # that round trip (divide by scale + eps, multiply by scale) does not
        # reproduce the original coordinates exactly.
        self.pts_denorm = pts[sel]                                   # (N,3) original coords (selected)
        self.feats = feats[sel] if feats is not None else None

    def __len__(self): return 1

    def __getitem__(self, idx):
        """Return the sample as a float tensor of shape (C, N).

        Raises:
            IndexError: If ``idx`` does not address the single item. Without
                this, plain iteration (``for x in ds``) would never stop,
                because it relies on ``IndexError`` to detect the end.
        """
        if idx not in (0, -1):
            raise IndexError(f"index {idx} is out of range for a dataset of length 1")

        # (C, N) as most PointNet variants expect (B, C, N)
        if self.feats is None:
            x = self.pts_norm.T                                      # (3,N)
        else:
            x = np.concatenate([self.pts_norm, self.feats], axis=1).T  # (3+C, N)
        return torch.from_numpy(x).float()


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [16]:
import torch
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt

# Run one forward pass over the single PLY cloud and reduce the output to one
# integer label per point. `ds`, `pred` and `num_parts` are read by the
# export cell that follows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Dataset / Loader
# n_points is tied to the global `num_points` the model was constructed with.
ds = SinglePLYDataset("../point_cloud_data/shoe/shoe-poisson-sampled.ply", n_points=num_points, with_normals=False, seed=0)
dl = DataLoader(ds, batch_size=1, shuffle=False)

# 2) Your model (replace this with your actual loaded model)
# Example: model = get_pn2_partseg_model(num_classes=50, normal_channel=False)
model = classifier.to(device).eval()

# 3) Optional: ShapeNetPart coarse-class one-hot
coarse = ["Airplane","Bag","Cap","Car","Chair","Earphone","Guitar","Knife",
          "Lamp","Laptop","Motorbike","Mug","Pistol","Rocket","Skateboard","Table"]
onehot = torch.zeros(1, len(coarse), device=device)
# Pick one *if your forward requires it*. If not needed, you can skip onehot entirely.
onehot[0, coarse.index("Chair")] = 1.0  # arbitrary pick to satisfy the interface

# 4) Inference (robust to common output shapes / returns)
batch = next(iter(dl)).to(device)             # (1, C, N)
with torch.no_grad():
    # Try the two-argument (points, one-hot) interface first. PointNetDenseSeg
    # in this notebook takes a single tensor, so that call raises TypeError
    # and the one-argument fallback runs.
    # NOTE(review): a TypeError raised *inside* a two-argument forward would
    # also land in this fallback and be masked.
    try:
        out = model(batch, onehot)
    except TypeError:
        out = model(batch)

# Unpack if model returns a tuple
# (PointNetDenseSeg returns (x_sem, x_inst, trans); element 0 is the semantic head.)
logits = out[0] if isinstance(out, (tuple, list)) else out  # (B,*,*) tensor

# Make logits shape = (B, N, P)
if logits.dim() != 3:
    raise RuntimeError(f"Unexpected logits dim={logits.dim()}, expected 3.")
B, A, C = logits.shape
N = ds.pts_norm.shape[0]

# Decide which axis is the point axis by comparing against the known N.
# If both axes equal N, the first branch wins and the layout is assumed to
# already be (B, N, P).
if A == N:
    per_point = logits                                      # (B, N, P)
elif C == N:
    per_point = logits.transpose(1, 2)                      # (B, N, P) from (B, P, N)
else:
    # Neither axis matches the number of points that were fed in, so the layout
    # cannot be determined — fail loudly instead of guessing.
    raise RuntimeError(f"Cannot infer (N,P) axes from shape {tuple(logits.shape)} vs N={N}")

# Argmax to get labels (0..P-1)
# squeeze(0) drops the batch axis, which has size 1 because batch_size=1 above.
pred = per_point.argmax(dim=2).squeeze(0).cpu().numpy()     # (N,)
num_parts = per_point.shape[-1]
print(f"Per-point labels: shape={pred.shape}, unique={np.unique(pred)}; P={num_parts}")


Per-point labels: shape=(100000,), unique=[0 2 3]; P=16


In [17]:
# Color palette: one fixed-seed random RGB triple (values in [0, 1)) per part
# index, looked up per point through the predicted labels.
rng = np.random.RandomState(0)
palette = rng.rand(num_parts, 3)
colors = palette[pred]                                      # (N,3)

# Save a colored PLY (original coordinates, not normalized)
pcd_out = o3d.geometry.PointCloud()
# Vector3dVector is documented for float64 (n, 3) arrays; the dataset stores
# float32, so convert explicitly rather than relying on implicit conversion.
pcd_out.points = o3d.utility.Vector3dVector(np.asarray(ds.pts_denorm, dtype=np.float64))  # (N,3)
pcd_out.colors = o3d.utility.Vector3dVector(colors)
out_path = "../point_cloud_data/shoe/shoe-segmented.ply"
# write_point_cloud reports failure through its boolean return value instead
# of raising, so check it before announcing success.
if not o3d.io.write_point_cloud(out_path, pcd_out):
    raise IOError(f"Failed to write point cloud to {out_path}")
print("Saved:", out_path)

# Static preview (works in headless notebooks)
# fig = plt.figure(figsize=(6,6))
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(ds.pts_norm[:,0], ds.pts_norm[:,1], ds.pts_norm[:,2], s=2, c=colors)
# ax.set_title("Predicted parts (normalized coords)")
# ax.set_axis_off()
# plt.show()

Saved: ../point_cloud_data/shoe/shoe-segmented.ply
