# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
The dataset should be structured as follows:
```
sequences/
├── 00/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
├── 01/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
```

Libraries required: timm, torch, tqdm



In [1]:
# install dependencies in the requirements.txt file
# !pip install -r requirements.txt

In [2]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from timm.models.vision_transformer import PatchEmbed


In [3]:
### Projection
class ScanProjection(object):
    '''
    Project the 3D point cloud to 2D data with range projection

    Adapted from A. Milioto et al. https://github.com/PRBonn/lidar-bonnetal
    '''

    def __init__(self, proj_w, proj_h):
        # params of proj img size
        self.proj_w = proj_w
        self.proj_h = proj_h


    def doProjection(self, pointcloud: np.ndarray):

        # get depth of all points
        depth = np.linalg.norm(pointcloud[:, :3], 2, axis=1)
        # get point cloud components
        x = pointcloud[:, 0]
        y = pointcloud[:, 1]
        z = pointcloud[:, 2]
        # label is the last column of pointcloud
        label = pointcloud[:,-1]
        # remove the last column from pointcloud
        pointcloud = pointcloud[:, :-1]
        # remove flag, R, G, and B from pointcloud
        pointcloud = pointcloud[:, :-4]  # now only has [x, y, z, intensity]
        # get angles of all points
        yaw = -np.arctan2(y, -x)
        proj_x = 0.5 * (yaw / np.pi + 1.0)  # in [0.0, 1.0]
        #breakpoint()
        new_raw = np.nonzero((proj_x[1:] < 0.2) * (proj_x[:-1] > 0.8))[0] + 1
        proj_y = np.zeros_like(proj_x)
        proj_y[new_raw] = 1
        proj_y = np.cumsum(proj_y)
        # scale to image size using angular resolution
        proj_x = proj_x * self.proj_w - 0.001

        # round and clamp for use as index
        proj_x = np.maximum(np.minimum(
            self.proj_w - 1, np.floor(proj_x)), 0).astype(np.int32)

        proj_y = np.maximum(np.minimum(
            self.proj_h - 1, np.floor(proj_y)), 0).astype(np.int32)

        # order in decreasing depth
        indices = np.arange(depth.shape[0])
        order = np.argsort(depth)[::-1]
        depth = depth[order]
        indices = indices[order]
        pointcloud = pointcloud[order]
        proj_y = proj_y[order]
        proj_x = proj_x[order]
        label = label[order]

        # get projection result
        proj_range = np.full((self.proj_h, self.proj_w), -1, dtype=np.float32)
        proj_range[proj_y, proj_x] = depth

        proj_pointcloud = np.full((self.proj_h, self.proj_w, pointcloud.shape[1]), -1, dtype=np.float32)
        proj_pointcloud[proj_y, proj_x] = pointcloud

        proj_idx = np.full((self.proj_h, self.proj_w), -1, dtype=np.int32)
        proj_idx[proj_y, proj_x] = indices

        proj_label = np.full((self.proj_h, self.proj_w), 0, dtype=np.int32)
        proj_label[proj_y, proj_x] = label

        # create proj_tensor with cascade proj_pointcloud and proj_range
        # proj_pointcloud has size (64, 2048, 4)
        # proj_range has size (64, 2048)
        proj_tensor = np.concatenate((proj_range[..., np.newaxis], proj_pointcloud), axis=-1) # [range, x, y, z, flag, R, G, B]
        return proj_tensor, proj_label
    

In [4]:
### DataLoader
class KITTISegmentationDataset(Dataset):
    def __init__(self, root_dir, sequences):
        self.root_dir = root_dir
        self.file_list = []
        for seq in sequences:
            seq_dir = os.path.join(root_dir, seq)
            assert os.path.exists(seq_dir), f"Sequence {seq} does not exist in {root_dir}"
            file_list = []
            pc_dir = os.path.join(seq_dir, 'preprocess')
            # Get the list of files (full path) in the point cloud directory
            file_list = [os.path.join(pc_dir, f) for f in os.listdir(pc_dir) if f.endswith('.bin')]
            self.file_list.extend(file_list)
        # Setup the projection parameters
        self.projection = ScanProjection(proj_w=2048, proj_h=64)
        # Define the learning map for semantic labels
        # This map is used to convert the original labels to a smaller set of classes
        self.learning_map = {0: 0, 1: 0, 10: 1, 11: 2, 13: 5, 15: 3, 16: 5, 18: 4, 20: 5,
            30: 6, 31: 7, 32: 8, 40: 9, 44: 10, 48: 11, 49: 12, 50: 13,
            51: 14, 52: 0, 60: 9, 70: 15, 71: 16, 72: 17, 80: 18, 81: 19,
            99: 0, 252: 1, 253: 7, 254: 6, 255: 8, 256: 5, 257: 5, 258: 4, 259: 5}
        # Create a mapping array with size large enough to cover the largest key
        self.max_key = max(self.learning_map.keys())
        self.map_array = np.zeros((self.max_key + 1,), dtype=np.int32)
        # Fill the mapping array with the learning map values
        for key, value in self.learning_map.items():
            self.map_array[key] = value
            
    # Read the point cloud data from binary files
    @staticmethod
    def readPCD(path):
        pcd = np.fromfile(path, dtype=np.float32).reshape(-1, 9) # 9 channels: x, y, z, intensity, flag, R, G, B, label
        return pcd
  
    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        pc_path = self.file_list[idx]

        # Load binary data
        pc = self.readPCD(pc_path)  # x, y, z, intensity
        img, label = self.projection.doProjection(pc) # shape [H, W, C]
        # Map the labels using the learning map
        label = self.map_array[label]  # map to smaller set of classes
        img = torch.tensor(img).permute(2, 0, 1).float()  # to [C, H, W]
        label = torch.tensor(label).long()                # [H, W]
        # Normalize the tensor
        mean = torch.tensor([12.12, 10.88, 0.23, -1.04, 0.21])
        std = torch.tensor([12.32, 11.47, 6.91, 0.86, 0.16])
        img = (img - mean[:, None, None]) / std[:, None, None]
        return img, label

In [5]:
# Hardware to run on; Uncomment appropriate lines
# runpod cloud RTX 4090: ~ 10 it/s: paralell might not needed since training takes about 3 hours.
# Powerful GPU so increase the batch size for faster training, the num_workers also increase so that the data loading is not a bottleneck
dataset = KITTISegmentationDataset('./dataset/sequences',['00','01','02','03','04','05','06','07','09','10'])
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=16)
dataset_val = KITTISegmentationDataset('./dataset/sequences',['08'])
loader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, num_workers = 16)
# local Legion computer: ~ 2 it/s
# batch_size and num_workers are set to 1 due to limited resources
# dataset = KITTISegmentationDataset('../SemanticKITTI/dataset/sequences',['00','01','02','03','04','05','06','07','09','10'])
# loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
# dataset_val = KITTISegmentationDataset('../SemanticKITTI/dataset/sequences',['08'])
# loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=1)

In [6]:

class RangeViTSegmentationModel(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.input_height = 64
        self.input_width = 2048
        self.patch_height = 2
        self.patch_width = 8

        self.backbone = timm.create_model(
            'vit_small_patch16_384',
            pretrained=True,
            in_chans=in_channels,
            num_classes=0,
            global_pool='',
            features_only=False
        )

        # Override patch embedding
        self.backbone.patch_embed = PatchEmbed(
            img_size=(self.input_height, self.input_width),
            patch_size=(self.patch_height, self.patch_width),
            in_chans=in_channels,
            embed_dim=self.backbone.embed_dim
        )

        self.grid_h, self.grid_w = self.backbone.patch_embed.grid_size  # (32, 256)
        self.num_patches = self.grid_h * self.grid_w
        print(f"Grid size: {self.grid_h} x {self.grid_w}, Patches: {self.num_patches}")

        expected_tokens = 1 + self.num_patches
        if self.backbone.pos_embed.shape[1] != expected_tokens:
            self.update_pos_embed()

        self.seg_head = nn.Sequential(
            nn.Conv2d(self.backbone.embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

        self.original_size = None

    def update_pos_embed(self):
        old_pos_embed = self.backbone.pos_embed
        cls_token = old_pos_embed[:, :1, :]
        patch_pos = old_pos_embed[:, 1:, :]

        # Original pretrained ViT size was 24x24 : (384x384)/(16x16)
        patch_pos = patch_pos.reshape(1, 24, 24, -1).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(patch_pos, size=(self.grid_h, self.grid_w), mode='bilinear', align_corners=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, self.num_patches, -1)
        new_pos_embed = torch.cat([cls_token, patch_pos], dim=1)
        self.backbone.pos_embed = nn.Parameter(new_pos_embed)

    def forward(self, x):
        B = x.shape[0]
        self.original_size = x.shape[2:]  # Expect (64, 2048)

        # DO NOT resize
        feats = self.backbone(x)  # [B, 8193, C]
        C = feats.shape[-1]
        feats = feats[:, 1:, :].reshape(B, self.grid_h, self.grid_w, C).permute(0, 3, 1, 2)

        logits = self.seg_head(feats)  # [B, num_classes, 32, 256]
        logits = F.interpolate(logits, size=self.original_size, mode='bilinear', align_corners=False)
        return logits

In [7]:
def compute_iou(preds, labels, num_classes):
    ious = []
    correct = (preds == labels)
    accuracy = correct.sum().float() / labels.numel()

    for cls in range(num_classes):
        # Get binary predictions and labels for this class
        pred_cls = (preds == cls)
        label_cls = (labels == cls)

        # Intersection and Union
        intersection = (pred_cls & label_cls).sum().float()
        union = (pred_cls | label_cls).sum().float()

        if union == 0:
            ious.append(torch.tensor(float('nan'), device=device))  # undefined for this class
        else:
            ious.append(intersection / union)

    # Mean IoU (excluding NaNs)
    ious_tensor = torch.stack(ious)
    mIoU = torch.nanmean(ious_tensor)

    return mIoU, ious_tensor, accuracy

In [8]:
### Train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
num_classes = 20
in_channels = 5 # range, x, y, z, intensity, flag, R, G, B
num_epochs = 60
model = RangeViTSegmentationModel(num_classes=num_classes, in_channels=in_channels).to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel!")
    model = nn.DataParallel(model)

model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.0004)
# Load the model if you have a pre-trained one
pretrain_path = 'range_vit_segmentation_noRGB_patch.pth'
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    model.load_state_dict(torch.load(pretrain_path, map_location=device))
# Training loop
best_val_mIoU = 0.0
model.train() # a switch that tells the model to be in training mode. It doesn't actually perform any training computations itself
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    batch_bar = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)
    average_loss = 0.0
    average_acc = 0.0
    average_mIoU = 0.0
    for imgs, labels in batch_bar:
        imgs = imgs.to(device)                # [B, C, H, W]
        labels = labels.to(device)             # [B, H, W]
        optimizer.zero_grad()
        # actually perform the training step
        outputs = model(imgs)                 # [B, num_classes, H, W]
        loss = criterion(outputs, labels)     # Compute raw loss

        preds = outputs.argmax(dim=1)         # [B, H, W]
        mIoU, ious, acc = compute_iou(preds, labels, num_classes) 
        loss.backward()  # Calculates gradients of the loss with respect to all model parameters
        optimizer.step() # Updates Parameter 
        batch_bar.set_postfix(loss=loss.item(), mIoU=mIoU.item(), acc=acc.item())
        average_loss += loss.item()
        average_acc += acc.item()
        average_mIoU += mIoU.item()
        
    print(f"Epoch [{epoch+1}] Train Loss: {average_loss/len(loader):.4f}, Train mIoU: {average_mIoU/len(loader):.4f}, Train Acc: {average_acc/len(loader):.4f}")

    model.eval()  # <-- switch to eval mode
    with torch.no_grad():  # turn off gradient tracking for speed and memory
        average_loss = 0.0
        average_acc = 0.0
        average_mIoU = 0.0
        batch_bar = tqdm(loader_val, desc=f"Evaluating", leave=False)
        for imgs, labels in batch_bar:
    
            imgs = imgs.to(device)                # [B, C, H, W]
            labels = labels.to(device)             # [B, H, W]
    
            outputs = model(imgs)                 # [B, num_classes, H, W]
            loss = criterion(outputs, labels)     # Compute raw loss
    
            preds = outputs.argmax(dim=1)         # [B, H, W]
            mIoU, ious, acc = compute_iou(preds, labels, num_classes) 
            batch_bar.set_postfix(loss=loss.item(), mIoU=mIoU.item(), acc=acc.item())
            average_loss += loss.item()
            average_acc += acc.item()
            average_mIoU += mIoU.item()
            
        print(f"Validation Loss: {average_loss/len(loader_val):.4f}, Validation mIoU: {average_mIoU/len(loader_val):.4f}, Validation Acc: {average_acc/len(loader_val):.4f}")
        val_mIoU = average_mIoU/len(loader_val)
        if val_mIoU > best_val_mIoU:
            best_val_mIoU = val_mIoU
            print('saving better model...')
            torch.save(model.state_dict(), pretrain_path)

cuda


model.safetensors:   0%|          | 0.00/88.8M [00:00<?, ?B/s]

Grid size: 32 x 256, Patches: 8192
Using 4 GPUs with DataParallel!


Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/1196 [00:00<?, ?it/s]

Epoch [1] Train Loss: 0.4803, Train mIoU: 0.2786, Train Acc: 0.6775


Evaluating:   0%|          | 0/255 [00:00<?, ?it/s]

Validation Loss: 0.4488, Validation mIoU: 0.2529, Validation Acc: 0.6863
saving better model...


Training Epoch 2:   0%|          | 0/1196 [00:00<?, ?it/s]

Epoch [2] Train Loss: 0.2558, Train mIoU: 0.3818, Train Acc: 0.7377


Evaluating:   0%|          | 0/255 [00:00<?, ?it/s]

Validation Loss: 0.4263, Validation mIoU: 0.2726, Validation Acc: 0.6968
saving better model...


Training Epoch 3:   0%|          | 0/1196 [00:00<?, ?it/s]

Epoch [3] Train Loss: 0.2015, Train mIoU: 0.4310, Train Acc: 0.7522


Evaluating:   0%|          | 0/255 [00:00<?, ?it/s]

Validation Loss: 0.4496, Validation mIoU: 0.2797, Validation Acc: 0.6993
saving better model...


Training Epoch 4:   0%|          | 0/1196 [00:00<?, ?it/s]

Epoch [4] Train Loss: 0.1770, Train mIoU: 0.4560, Train Acc: 0.7587


Evaluating:   0%|          | 0/255 [00:00<?, ?it/s]

Validation Loss: 0.4504, Validation mIoU: 0.2876, Validation Acc: 0.7029
saving better model...


Training Epoch 5:   0%|          | 0/1196 [00:00<?, ?it/s]

KeyboardInterrupt: 