# In [1]:
"""PointNet
Reference:
https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_utils.py

"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class STN3d(nn.Module):
    """Input spatial transformer (T-Net) from PointNet.

    Regresses one 3x3 matrix per sample from a channel-first point tensor of
    shape (B, channel, N). The regressed values are added to the identity, so
    a final layer that outputs zeros yields the identity transform.
    """

    def __init__(self, channel=3):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        # Flattened 3x3 identity, shape (1, 9). It is a non-persistent buffer, so:
        # - it follows .to()/.cuda()/.half() together with the module, avoiding a
        #   host-to-device copy on every forward;
        # - it adds no state_dict key, so existing checkpoints still load strictly.
        self.register_buffer('iden', torch.eye(3, dtype=torch.float32).reshape(1, 9), persistent=False)

    def forward(self, x):
        """Return a (B, 3, 3) transform for channel-first input x of shape (B, channel, N)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # max-pool over points -> (B, 1024, 1)
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # (1, 9) identity broadcasts over the batch. .to() is a no-op once the
        # module already lives on x's device/dtype.
        x = x + self.iden.to(device=x.device, dtype=x.dtype)
        return x.view(-1, 3, 3)


class STNkd(nn.Module):
    """Feature spatial transformer (T-Net) from PointNet.

    Regresses one k x k matrix per sample from a channel-first feature tensor
    of shape (B, k, N). The regressed values are added to the identity, so a
    final layer that outputs zeros yields the identity transform.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k
        # Flattened k x k identity, shape (1, k*k). It is a non-persistent buffer,
        # so it moves/casts with the module and adds no state_dict key (existing
        # checkpoints still load strictly).
        self.register_buffer('iden', torch.eye(self.k, dtype=torch.float32).reshape(1, self.k * self.k),
                             persistent=False)

    def forward(self, x):
        """Return a (B, k, k) transform for channel-first input x of shape (B, k, N)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # max-pool over points -> (B, 1024, 1)
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # (1, k*k) identity broadcasts over the batch.
        x = x + self.iden.to(device=x.device, dtype=x.dtype)
        return x.view(-1, self.k, self.k)


class PointNetEncoder(nn.Module):
    """PointNet encoder producing global or per-point features.

    Input can be given in three ways:
    - a points-last tensor ``pos`` of shape (B, N, C), which is transposed to
      channel-first (B, C, N) internally;
    - an explicit channel-first tensor ``x``;
    - a dict-like ``pos`` carrying the channel-first features under key ``'x'``.
    """

    def __init__(self,
                 in_channels: int,
                 input_transform: bool=True,
                 feature_transform: bool=True,
                 is_seg: bool=False,
                 **kwargs
                 ):
        """Build the encoder layers.

        Args:
            in_channels (int): number of input channels per point.
            input_transform (bool, optional): whether to apply an STN3d transform to the first
                three (coordinate) channels. Defaults to True.
            feature_transform (bool, optional): whether to apply an STNkd(64) transform to the
                64-channel point features. Defaults to True.
            is_seg (bool, optional): only sets ``out_channels`` (1024 + 64 for segmentation,
                1024 for classification). Defaults to False.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.stn = STN3d(in_channels) if input_transform else None
        self.conv0_1 = torch.nn.Conv1d(in_channels, 64, 1)
        self.conv0_2 = torch.nn.Conv1d(64, 64, 1)

        self.conv1 = torch.nn.Conv1d(64, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn0_1 = nn.BatchNorm1d(64)
        self.bn0_2 = nn.BatchNorm1d(64)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fstn = STNkd(k=64) if feature_transform else None
        self.out_channels = 1024 + 64 if is_seg else 1024

    @staticmethod
    def _resolve_input(pos, x):
        """Return the channel-first (B, C, N) feature tensor for the given inputs."""
        if hasattr(pos, 'keys'):
            # A dict's 'x' entry wins; otherwise fall back to the explicit x argument.
            x = pos.get('x', x)
        if x is None:
            if hasattr(pos, 'keys'):
                raise KeyError("dict input must provide channel-first features under key 'x'")
            x = pos.transpose(1, 2).contiguous()
        return x

    def _embed(self, x):
        """Apply the optional input transform, two shared MLP layers and the optional feature transform.

        Args:
            x: channel-first tensor of shape (B, D, N).

        Returns:
            (x, trans_feat): x has shape (B, 64, N); trans_feat is the (B, 64, 64)
            feature transform, or None when feature_transform is disabled.
        """
        _, D, _ = x.size()
        if self.stn is not None:
            trans = self.stn(x)
            x = x.transpose(2, 1)
            if D > 3:
                # Only the first three channels are rotated; the extra channels pass through.
                feature = x[:, :, 3:]
                x = x[:, :, :3]
            x = torch.bmm(x, trans)
            if D > 3:
                x = torch.cat([x, feature], dim=2)
            x = x.transpose(2, 1)
        x = F.relu(self.bn0_1(self.conv0_1(x)))
        x = F.relu(self.bn0_2(self.conv0_2(x)))

        if self.fstn is not None:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)
        else:
            trans_feat = None
        return x, trans_feat

    def forward_cls_feat(self, pos, x=None):
        """Return max-pooled global features of shape (B, 1024)."""
        x = self._resolve_input(pos, x)
        x, _ = self._embed(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # NOTE: unlike forward_seg_feat there is no ReLU after bn3 here; left
        # unchanged so outputs of existing (pretrained) weights are preserved.
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        return x.view(-1, 1024)

    def forward_seg_feat(self, pos, x=None):
        """Return ``(pos, feats)`` where feats is (B, 64 + 1024, N).

        The 1024 channels are the global feature repeated for every point and
        concatenated after the 64-channel per-point features.
        """
        x = self._resolve_input(pos, x)
        N = x.size(2)
        pointfeat, _ = self._embed(x)
        x = F.relu(self.bn1(self.conv1(pointfeat)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024, 1).repeat(1, 1, N)
        return pos, torch.cat([pointfeat, x], 1)

    def forward(self, x, features=None):
        """Return global features (B, 1024).

        ``features``, when given, is used as the explicit channel-first input
        instead of transposing ``x``.
        """
        return self.forward_cls_feat(x, features)

# In [4]:
import torch
import os


# Load checkpoint onto the GPU when available, otherwise onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load("./pre-trained/pointnet_pre_trainder.pth", map_location=device)
state_dict = checkpoint['model']

# Keep only the "encoder."-prefixed weights (drops classifier weights) and strip the prefix
new_state_dict = {
    k[len("encoder."):]: v for k, v in state_dict.items() if k.startswith("encoder.")
}

# Initialize encoder
encoder = PointNetEncoder(in_channels=4)

# Load filtered state_dict (strict: every encoder parameter must be present)
encoder.load_state_dict(new_state_dict)
encoder.eval()
encoder.to(device)


# PointNetEncoder(
#   (stn): STN3d(
#     (conv1): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
#     (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
#     (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
#     (fc1): Linear(in_features=1024, out_features=512, bias=True)
#     (fc2): Linear(in_features=512, out_features=256, bias=True)
#     (fc3): Linear(in_features=256, out_features=9, bias=True)
#     (relu): ReLU()
#     (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   )
#   (conv0_1): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
#   (conv0_2): Conv1d(64, 64, kernel_size=(1,), stride=(1
# ... (output truncated)

# In [5]:
import torch
import torch.nn as nn

# Reuse the pretrained encoder from the checkpoint-loading cell if it has been
# run; only fall back to a freshly initialised (random) encoder otherwise.
if "encoder" not in globals():
    encoder = PointNetEncoder(in_channels=4)
encoder.eval()
encoder.to("cuda" if torch.cuda.is_available() else "cpu")

# Simple classifier head on top of the 1024-dim global feature
num_classes = 15
classifier = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)

# Wrapper to combine encoder + classifier
class PointNetWithHead(nn.Module):
    """Encoder + classification head for channel-first (B, C, N) input.

    The input tensor is handed to the encoder both as ``pos`` and as the
    explicit ``x`` keyword, so the encoder consumes it as-is rather than
    transposing it.
    """

    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        return self.classifier(self.encoder.forward_cls_feat(x, x=x))


model = PointNetWithHead(encoder, classifier)
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Generate artificial channel-first (B, C, N) point cloud data on the model's device
B, C, N = 8, 4, 1024
fake_data = torch.randn(B, C, N, device=next(model.parameters()).device)

# Forward pass
with torch.no_grad():
    outputs = model(fake_data)

print("Output shape:", outputs.shape)  # should be (B, num_classes)
print("Output:", outputs)


# Output shape: torch.Size([8, 15])
# Output: tensor([[-0.0148,  0.0100,  0.0501,  0.0233, -0.0017,  0.0111, -0.0213,  0.0035,
#          -0.0051,  0.0198, -0.0114,  0.0101, -0.0501, -0.0183, -0.0177],
#         [-0.0154,  0.0077,  0.0513,  0.0231, -0.0014,  0.0109, -0.0223,  0.0045,
#          -0.0048,  0.0187, -0.0084,  0.0122, -0.0494, -0.0183, -0.0180],
#         [-0.0156,  0.0080,  0.0508,  0.0236, -0.0035,  0.0103, -0.0226,  0.0065,
#          -0.0035,  0.0197, -0.0093,  0.0097, -0.0519, -0.0178, -0.0179],
#         [-0.0180,  0.0073,  0.0485,  0.0217, -0.0018,  0.0135, -0.0235,  0.0046,
#          -0.0047,  0.0175, -0.0099,  0.0080, -0.0491, -0.0173, -0.0177],
#         [-0.0160,  0.0076,  0.0483,  0.0227, -0.0017,  0.0117, -0.0213,  0.0040,
#          -0.0036,  0.0179, -0.0108,  0.0086, -0.0477, -0.0190, -0.0193],
#         [-0.0178,  0.0092,  0.0512,  0.0226, -0.0030,  0.0123, -0.0247,  0.0071,
#          -0.0035,  0.0180, -0.0103,  0.0115, -0.0489, -0.0193, -0.0166],
#         [-0.0153,  0.0074,  0.0503
# ... (output truncated)

# In [6]:
import json
import torch
from torch.utils.data import Dataset
import numpy as np

class ConeClusterDataset(Dataset):
    """Fixed-size point clusters with integer labels, loaded from JSON files.

    Each JSON file maps frame keys (ignored) to frames of the form
    ``{"clusters": [{"points": [{"x", "y", "z", "i"}, ...], "label": int}, ...]}``.
    A sample is a ``(num_points, 4)`` float32 tensor of ``[x, y, z, intensity]``
    (points-last, not transposed) and a scalar ``long`` label.
    """

    def __init__(self, json_files, num_points=128):
        """
        json_files: list of paths to your JSON cluster files
        num_points: fixed number of points per cluster
        """
        self.samples = []
        self.num_points = num_points

        for file_path in json_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for frame in data.values():
                for cluster in frame['clusters']:
                    label = cluster['label']
                    # Label 255 is folded into class 0.
                    # NOTE(review): presumably an "unknown" marker; confirm.
                    if label == 255:
                        label = 0
                    # Parse once here instead of on every __getitem__ call.
                    self.samples.append((self._to_array(cluster['points']), label))

    @staticmethod
    def _to_array(points):
        """Convert point dicts to a centroid-centred, intensity-scaled (N, 4) float32 array."""
        pts = np.array([[p['x'], p['y'], p['z'], p['i']] for p in points],
                       dtype=np.float32).reshape(-1, 4)  # reshape keeps empty clusters 2-D
        if len(pts):
            # Normalize coordinates relative to the centroid (the mean of an empty cluster would be NaN)
            pts[:, :3] -= pts[:, :3].mean(axis=0, keepdims=True)
        # Scale intensity by 1/255. NOTE(review): assumes 8-bit intensities; confirm.
        pts[:, 3] /= 255.0
        return pts

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        pts, label = self.samples[idx]

        # Downsample (random subset, re-drawn on each access) or zero-pad to num_points.
        # Both branches produce a new array, so the cached sample is never mutated.
        n = pts.shape[0]
        if n >= self.num_points:
            pts = pts[np.random.choice(n, self.num_points, replace=False)]
        else:
            pad = np.zeros((self.num_points - n, pts.shape[1]), dtype=np.float32)
            pts = np.vstack([pts, pad])

        # Returned as (num_points, 4); no transpose is done here.
        return torch.from_numpy(pts), torch.tensor(label, dtype=torch.long)



# In [8]:
from torch.utils.data import DataLoader

# Cone-cluster dataset built from the two labelled recordings, served in
# shuffled batches of 16 clusters (128 points each).
dataset = ConeClusterDataset(
    [
        "../data/fsg_accel_2024_08_16-11_06_03_recovered_filtered_0_labels.json",
        "../data/skidpad_2025-08-06-17_26_12_merged_filtered_0_labels.json",
    ],
    num_points=128,
)
loader = DataLoader(dataset, shuffle=True, batch_size=16)


# In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

# Assume PointNetEncoder and ConeClusterDataset are defined as before

# ----------------------------
# 1. Load Pretrained Encoder
# ----------------------------
encoder = PointNetEncoder(in_channels=4)  # 4 channels: x, y, z, intensity
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load("../models/pre-trained/pointnet_pre_trainder.pth", map_location="cpu")
state_dict = checkpoint['model']

# Keep only the encoder weights and strip their leading "encoder." prefix.
# Filtering on the prefix, rather than str.replace on every key, avoids two problems:
# - rewriting "encoder." appearing elsewhere in a key;
# - carrying head weights along.
new_state_dict = {
    k[len("encoder."):]: v for k, v in state_dict.items() if k.startswith("encoder.")
}
# strict=False tolerates extra checkpoint keys. A *missing* encoder weight,
# though, would silently leave that layer randomly initialised, so fail loudly.
missing_keys, unexpected_keys = encoder.load_state_dict(new_state_dict, strict=False)
if missing_keys:
    raise RuntimeError(f"Checkpoint is missing encoder weights: {missing_keys}")
if unexpected_keys:
    print(f"Ignoring unexpected checkpoint keys: {unexpected_keys}")
encoder.eval()

# ----------------------------
# 2. Add a classifier head
# ----------------------------
num_classes = 4  # adjust to your labels
classifier = nn.Sequential(
    nn.Linear(1024, 512),  # 1024 = size of the encoder's global feature
    nn.ReLU(),
    nn.Linear(512, num_classes)
)

class PointNetWithHead(nn.Module):
    """Encoder global features followed by a classification head.

    ``x`` is passed to ``encoder.forward_cls_feat`` as ``pos`` only, with no
    explicit features. The encoder therefore chooses the layout; this notebook's
    encoder transposes a points-last (B, N, C) tensor itself.
    """

    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder, self.classifier = encoder, classifier

    def forward(self, x):
        global_feat = self.encoder.forward_cls_feat(x)
        return self.classifier(global_feat)

model = PointNetWithHead(encoder, classifier)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# ----------------------------
# 3. Prepare Dataset & Dataloader
# ----------------------------
train_files = ["../data/fsg_accel_2024_08_16-11_06_03_recovered_filtered_0_labels.json", "../data/skidpad_2025-08-06-17_26_12_merged_filtered_0_labels.json"]  # replace with your paths
train_dataset = ConeClusterDataset(train_files, num_points=128)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)

# ----------------------------
# 4. Training Setup
# ----------------------------
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# ----------------------------
# 5. Training Loop
# ----------------------------
num_epochs = 5
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; re-weight by batch size to get a per-sample average later
        running_loss += loss.item() * data.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Average over the samples actually seen. drop_last=True skips the final partial
    # batch, so dividing by len(train_dataset) would understate the loss.
    # max(..., 1) avoids ZeroDivisionError when no full batch exists.
    epoch_loss = running_loss / max(total, 1)
    epoch_acc = correct / max(total, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")


# Epoch 1/5, Loss: 0.7318, Accuracy: 0.7347
# Epoch 2/5, Loss: 0.5335, Accuracy: 0.8109
# Epoch 3/5, Loss: 0.4876, Accuracy: 0.8262
# Epoch 4/5, Loss: 0.4569, Accuracy: 0.8383
# Epoch 5/5, Loss: 0.4487, Accuracy: 0.8401
