# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
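The dataset should be structured as follows (note: the loader in this notebook reads its 8-channel colorized `.bin` scans from a `velodyne/colorized/` sub-folder of the sequence, and the matching `.label` files from `labels/`):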
```
sequences/
├── 03/
│   ├── velodyne/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
│   ├── labels/
│   │   ├── 000000.label
│   │   ├── 000001.label
```



In [None]:
### Projection

import numpy as np

class ScanProjection(object):
    '''
    Project a 3D point cloud onto a 2D range image.

    The column index comes from the yaw angle of each point. The row index is
    obtained by counting yaw wrap-arounds along the point order, so the input
    is assumed to be stored scan line by scan line (NOTE(review): this holds
    for raw Velodyne scans; confirm for any re-ordered or filtered input).

    Adapted from A. Milioto et al. https://github.com/PRBonn/lidar-bonnetal
    '''

    def __init__(self, proj_w, proj_h):
        # Width and height (in pixels) of the projected image.
        self.proj_w = proj_w
        self.proj_h = proj_h

    def doProjection(self, pointcloud: np.ndarray, label: np.ndarray):
        '''
        Project a labelled point cloud into image space.

        Args:
            pointcloud: array of shape (N, C); the first three columns are
                x, y, z. All C columns are copied into the output image.
            label: array of shape (N,) with one class id per point.

        Returns:
            proj_tensor: float32 array of shape (proj_h, proj_w, C + 1). The
                first C channels are the point features, the last channel is
                the range. Pixels without a point are filled with -1.
            proj_label: int32 array of shape (proj_h, proj_w). Pixels without
                a point are 0.

        When several points fall into the same pixel, the nearest one wins.
        '''
        # make sure pointcloud and label are the same length
        assert pointcloud.shape[0] == label.shape[0], "Pointcloud and label must have the same number of points"

        # range (Euclidean distance to the sensor) of every point
        depth = np.linalg.norm(pointcloud[:, :3], 2, axis=1)
        x = pointcloud[:, 0]
        y = pointcloud[:, 1]

        # horizontal angle, normalised to [0.0, 1.0]
        yaw = -np.arctan2(y, -x)
        proj_x = 0.5 * (yaw / np.pi + 1.0)

        # A jump from the right edge (> 0.8) back to the left edge (< 0.2)
        # between consecutive points marks the start of a new image row.
        new_row = np.nonzero((proj_x[1:] < 0.2) & (proj_x[:-1] > 0.8))[0] + 1
        row_start = np.zeros(proj_x.shape[0], dtype=np.int64)
        row_start[new_row] = 1
        proj_y = np.clip(np.cumsum(row_start), 0, self.proj_h - 1)

        # Scale to image width; the small offset keeps yaw == pi inside the
        # last column before flooring. Clamp for use as an index.
        proj_x = np.clip(
            np.floor(proj_x * self.proj_w - 0.001), 0, self.proj_w - 1
        ).astype(np.int64)

        # Pick exactly one point per pixel: the nearest one. NumPy gives no
        # guarantee about the write order of fancy assignment with repeated
        # indices, so the winner is selected explicitly instead of relying on
        # "last write wins".
        nearest_first = np.argsort(depth)
        flat_pixel = (proj_y * self.proj_w + proj_x)[nearest_first]
        # return_index yields the first occurrence of each pixel, i.e. the
        # nearest point that maps to it.
        _, first_hit = np.unique(flat_pixel, return_index=True)
        keep = nearest_first[first_hit]
        rows = proj_y[keep]
        cols = proj_x[keep]

        # fill the projected images; -1 marks "no point" for features/range
        proj_range = np.full((self.proj_h, self.proj_w), -1, dtype=np.float32)
        proj_range[rows, cols] = depth[keep]

        proj_pointcloud = np.full((self.proj_h, self.proj_w, pointcloud.shape[1]), -1, dtype=np.float32)
        proj_pointcloud[rows, cols] = pointcloud[keep]

        proj_label = np.full((self.proj_h, self.proj_w), 0, dtype=np.int32)
        proj_label[rows, cols] = label[keep]

        # stack point features and range: (proj_h, proj_w, C) + (proj_h, proj_w)
        # -> (proj_h, proj_w, C + 1), range being the last channel
        proj_tensor = np.concatenate((proj_pointcloud, proj_range[..., np.newaxis]), axis=-1)
        return proj_tensor, proj_label
    

In [None]:
### DataLoader

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class KITTISegmentationDataset(Dataset):
    """Range-image dataset for one SemanticKITTI-style sequence directory.

    Each item is a pair ``(img, label)``:

    * ``img``   -- float tensor ``[C, H, W]``; C is the 8 point-cloud columns
      read by :meth:`readPCD` plus the range appended by the projection.
    * ``label`` -- long tensor ``[H, W]`` with class ids remapped through
      ``learning_map``.

    Args:
        root_dir: sequence directory containing ``velodyne/colorized/*.bin``
            and ``labels/*.label``.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.pc_dir = os.path.join(root_dir, 'velodyne/colorized')
        self.label_dir = os.path.join(root_dir, 'labels')
        # File stems (no extension) of all scans. os.listdir returns entries
        # in arbitrary order, so sort to make dataset indices reproducible.
        self.file_list = sorted(
            f[:-4] for f in os.listdir(self.pc_dir) if f.endswith('.bin')
        )
        # Setup the projection parameters
        self.projection = ScanProjection(proj_w=2048, proj_h=64)
        # Learning map: converts the original semantic ids to a smaller,
        # contiguous set of training classes (0 is the ignored class).
        self.learning_map = {0: 0, 1: 0, 10: 1, 11: 2, 13: 5, 15: 3, 16: 5, 18: 4, 20: 5,
            30: 6, 31: 7, 32: 8, 40: 9, 44: 10, 48: 11, 49: 12, 50: 13,
            51: 14, 52: 0, 60: 9, 70: 15, 71: 16, 72: 17, 80: 18, 81: 19,
            99: 0, 252: 1, 253: 7, 254: 6, 255: 8, 256: 5, 257: 5, 258: 4, 259: 5}
        # Lookup table indexed by original id; ids missing from the map
        # (but <= max_key) fall back to 0.
        self.max_key = max(self.learning_map.keys())
        self.map_array = np.zeros((self.max_key + 1,), dtype=np.int32)
        for key, value in self.learning_map.items():
            self.map_array[key] = value

    @staticmethod
    def readPCD(path):
        """Read a binary scan as a float32 array of shape (N, 8).

        Columns: x, y, z, intensity, flag, R, G, B.
        """
        pcd = np.fromfile(path, dtype=np.float32).reshape(-1, 8)
        return pcd

    @staticmethod
    def readLabel(path):
        """Read a ``.label`` file and split it into semantic and instance ids.

        Each entry is a 32-bit word: the lower 16 bits hold the semantic
        label, the upper 16 bits the instance id. The file is read as
        *unsigned* so that an instance id with its top bit set is not turned
        into a negative number by the right shift.

        Returns:
            (sem_label, inst_label): two int32 arrays of shape (N,).
        """
        raw = np.fromfile(path, dtype=np.uint32)
        sem_label = (raw & 0xFFFF).astype(np.int32)  # semantic label in lower half
        inst_label = (raw >> 16).astype(np.int32)    # instance id in upper half
        return sem_label, inst_label

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def proj(pc):
        """Identity placeholder; the real projection is done in __getitem__."""
        return pc

    def __getitem__(self, idx):
        fname = self.file_list[idx]
        pc_path = os.path.join(self.pc_dir, f"{fname}.bin")
        label_path = os.path.join(self.label_dir, f"{fname}.label")

        # Load binary data
        pc = self.readPCD(pc_path)              # [N, 8]
        label, _ = self.readLabel(label_path)   # [N], one id per point
        # Map the original ids to training classes.
        # NOTE(review): an id larger than max_key raises IndexError here.
        label = self.map_array[label]
        img, label = self.projection.doProjection(pc, label)  # [H, W, C], [H, W]
        img = torch.from_numpy(img).permute(2, 0, 1).float()  # to [C, H, W]
        label = torch.from_numpy(label).long()                # [H, W]

        return img, label

In [None]:
from torch.utils.data import Dataset, DataLoader

# Build the dataset for a single sequence and wrap it in a shuffling loader
# that yields one range image per batch.
dataset = KITTISegmentationDataset(root_dir='../SemanticKITTI/dataset/sequences/03')
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)


In [None]:
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

class RangeViTSegmentationModel(nn.Module):
    """ViT backbone plus a small convolutional head for per-pixel classification.

    The input image is resized to the square resolution of the backbone, the
    patch tokens are reshaped into a 2D feature map, classified by the head,
    and the logits are resized back to the input resolution.

    Args:
        in_channels: number of channels of the input range image.
        n_classes: number of output classes of the segmentation head.
    """

    # Input resolution and patch size of 'vit_small_patch16_384'.
    IMG_SIZE = 384
    PATCH_SIZE = 16

    def __init__(self, in_channels, n_classes):
        super().__init__()

        # Plain ViT (no features_only wrapper) used as a token extractor.
        self.backbone = timm.create_model(
            'vit_small_patch16_384',
            pretrained=True,
            in_chans=in_channels,
            num_classes=0,       # drop the classification head of the backbone
            global_pool='',      # keep all tokens instead of pooling them
            features_only=False
        )

        # Read the token width and the number of leading non-patch tokens
        # (e.g. CLS) from the backbone, falling back to the values that hold
        # for vit_small if the attributes are not present.
        feat_dim = getattr(self.backbone, 'num_features', 384)
        self._num_prefix_tokens = getattr(self.backbone, 'num_prefix_tokens', 1)
        hidden_dim = 256
        # Print for debugging
        print(f"ViT feature dimension: {feat_dim}")
        print(f"number of classes: {n_classes}")
        # Segmentation head operating on the [B, feat_dim, h, w] token grid
        self.seg_head = nn.Sequential(
            nn.Conv2d(feat_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, n_classes, kernel_size=1)
        )

        self.original_size = None  # spatial size of the last input seen by forward()

    def forward(self, x):
        """Return logits of shape [B, n_classes, H, W] for input x of shape [B, C, H, W]."""
        # Use a local for the computation; the attribute is only kept so
        # existing code that reads `model.original_size` keeps working.
        original_size = x.shape[2:]
        self.original_size = original_size
        # Resize input to the square resolution the ViT expects
        side = self.IMG_SIZE
        x_resized = F.interpolate(x, size=(side, side), mode='bilinear', align_corners=False)
        # Token sequence from the backbone: [B, prefix + h*w, C]
        feats = self.backbone(x_resized)
        B = x_resized.shape[0]
        h = w = side // self.PATCH_SIZE  # tokens per side of the patch grid
        C = feats.shape[-1]
        # Drop the prefix (CLS) tokens and reshape to [B, C, h, w]
        feats = feats[:, self._num_prefix_tokens:, :].reshape(B, h, w, C).permute(0, 3, 1, 2)

        # Apply segmentation head
        logits = self.seg_head(feats)

        # Resize back to original dimensions
        return F.interpolate(logits, size=original_size, mode='bilinear', align_corners=False)

In [None]:
import torch

def compute_iou(preds, labels, num_classes, ignore_index=None):
    """Compute per-class IoU and its mean over the classes that occur.

    Args:
        preds: [B, H, W] predicted class indices.
        labels: [B, H, W] ground-truth class indices (same device as preds).
        num_classes: total number of classes, ids 0 .. num_classes - 1.
        ignore_index: optional class id to exclude. The class itself is not
            scored, and pixels whose ground truth equals it are left out of
            every other class's intersection and union, so predictions on
            unlabeled pixels are not counted as false positives.

    Returns:
        (mIoU, ious): ``mIoU`` is a Python float (NaN if no class could be
        scored); ``ious`` is a 1-D tensor with one entry per scored class, in
        increasing class order, NaN where a class has an empty union.
    """
    valid = None
    if ignore_index is not None:
        valid = labels != ignore_index

    ious = []
    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue

        pred_cls = (preds == cls)
        label_cls = (labels == cls)
        if valid is not None:
            pred_cls = pred_cls & valid
            label_cls = label_cls & valid

        intersection = (pred_cls & label_cls).sum().float()
        union = (pred_cls | label_cls).sum().float()

        # For a class that is absent from both prediction and ground truth
        # this is 0 / 0 = NaN, which marks the class as "not scored". The
        # result lives on the same device as the inputs, so torch.stack below
        # never mixes CPU and GPU tensors.
        ious.append(intersection / union)

    if not ious:
        # Every class was skipped (e.g. num_classes == 0): nothing to average.
        return float('nan'), torch.empty(0, device=preds.device)

    # Compute mean ignoring NaNs
    ious = torch.stack(ious)
    mIoU = torch.nanmean(ious).item()
    return mIoU, ious

In [None]:
### Train the model
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 20
# Channel order produced by the projection:
# x, y, z, intensity, flag, R, G, B, range (range is the LAST channel)
in_channels = 9
num_epochs = 10
ignore_index = 0  # class id excluded from both the loss and the metric
checkpoint_path = 'range_vit_segmentation.pth'

model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Resume from a previous checkpoint only if one exists, so that a first run
# does not crash. map_location lets a GPU-saved checkpoint load on CPU.
if os.path.isfile(checkpoint_path):
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )

# Training loop
for epoch in range(num_epochs):
    model.train()  # switch layers such as dropout to training behaviour
    epoch_losses = []
    epoch_mious = []
    for imgs, labels in loader:
        imgs = imgs.to(device)      # [B, C, H, W]
        labels = labels.to(device)  # [B, H, W]

        # Pixels that received no point have range -1; make sure they are
        # ignored by the loss and the metric regardless of their label.
        valid_mask = imgs[:, -1, :, :] > 0  # [B, H, W]
        labels = labels.masked_fill(~valid_mask, ignore_index)

        # With no supervised pixel the mean cross-entropy is NaN and would
        # poison the weights, so skip such batches.
        if not (labels != ignore_index).any():
            continue

        optimizer.zero_grad()
        outputs = model(imgs)              # [B, num_classes, H, W]
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                    # gradients w.r.t. all model parameters
        optimizer.step()                   # parameter update

        # The metric needs no gradients.
        with torch.no_grad():
            preds = outputs.argmax(dim=1)  # [B, H, W]
            mIoU, ious = compute_iou(preds, labels, num_classes, ignore_index=ignore_index)

        epoch_losses.append(loss.item())
        epoch_mious.append(mIoU)

    if epoch_losses:
        # Report epoch averages rather than the value of the last batch only.
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        mean_miou = torch.tensor(epoch_mious).nanmean().item()
        print(f"Epoch [{epoch+1}] Loss: {mean_loss:.4f} mIoU: {mean_miou:.4f}")
    else:
        print(f"Epoch [{epoch+1}] no batch with labeled pixels; nothing trained")

In [None]:
# Save only the model parameters (state_dict) to 'range_vit_segmentation.pth'
# in the current working directory, overwriting any existing file.
torch.save(model.state_dict(), 'range_vit_segmentation.pth')