# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
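The dataset should be organized as follows, with one directory per sequence that holds the LiDAR scans (`velodyne/`) and a matching label file for each scan (`labels/`):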
```
sequences/
├── 03/
│   ├── velodyne/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
│   ├── labels/
│   │   ├── 000000.label
│   │   ├── 000001.label
```



In [None]:
## DataLoader

import os
import torch
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as transforms

class KITTISegmentationDataset(Dataset):
    """Dataset of raw binary image/label pairs for semantic segmentation.

    Each sample is identified by a file name without extension. For a sample
    ``name`` the dataset reads:

    * ``<root_dir>/<name>.bin``  -- ``float32`` values reshaped to
      ``(H, W, 8)`` and returned as a ``[8, H, W]`` float tensor;
    * ``<label_dir>/<name>.bin`` -- ``int64`` values reshaped to ``(H, W)``
      and returned as a ``[H, W]`` long tensor;

    where ``(H, W)`` is ``input_size``.

    NOTE(review): ``__getitem__`` expects label files with a ``.bin``
    extension, i.e. presumably already-projected label maps rather than raw
    SemanticKITTI ``.label`` files -- confirm against the preprocessing step.

    Args:
        root_dir: Directory containing the input ``.bin`` files.
        label_dir: Directory containing the label ``.bin`` files.
        input_size: ``(H, W)`` of the stored images and label maps.
        transform: Optional callable applied to the image tensor only.
        file_list: Optional iterable of sample names without extension. When
            omitted, every ``*.bin`` file found in ``root_dir`` is used, in
            sorted order.
    """

    def __init__(self, root_dir, label_dir, input_size=(256, 256), transform=None,
                 file_list=None):
        self.root_dir = root_dir
        self.label_dir = label_dir
        if file_list is None:
            # No explicit sample list: fall back to every .bin file present
            # in root_dir. Sorting keeps the sample order deterministic.
            file_list = sorted(
                os.path.splitext(entry)[0]
                for entry in os.listdir(root_dir)
                if entry.endswith(".bin")
            )
        self.file_list = list(file_list)  # list of filenames without extension
        self.input_size = input_size
        self.transform = transform

    @staticmethod
    def readPCD(path):
        """Read a point cloud stored as flat ``float32`` values.

        Returns:
            Array of shape ``(N, 4)`` (four ``float32`` values per point).
        """
        pcd = np.fromfile(path, dtype=np.float32).reshape(-1, 4)
        return pcd

    @staticmethod
    def readLabel(path):
        """Read a label file stored as flat ``int32`` values.

        Returns:
            Tuple ``(sem_label, inst_label)``: the lower 16 bits and the
            upper 16 bits of every label word, respectively.
        """
        label = np.fromfile(path, dtype=np.int32)
        sem_label = label & 0xFFFF  # semantic label in lower half
        inst_label = label >> 16  # instance id in upper half
        return sem_label, inst_label

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_list)

    def _read_image(self, path, shape, dtype=np.float32):
        """Read a raw binary file and reshape it to ``shape``.

        Raises:
            ValueError: If the file size does not match ``shape``.
        """
        return np.fromfile(path, dtype=dtype).reshape(shape)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            Tuple ``(img, label)``: a ``[8, H, W]`` float tensor and a
            ``[H, W]`` long tensor.
        """
        fname = self.file_list[idx]
        img_path = os.path.join(self.root_dir, f"{fname}.bin")
        label_path = os.path.join(self.label_dir, f"{fname}.bin")

        # Load binary data
        img = self._read_image(img_path, (*self.input_size, 8))  # shape [H, W, 8]
        label = self._read_image(label_path, self.input_size, dtype=np.int64)  # shape [H, W]

        # from_numpy shares the freshly-read buffer instead of copying it.
        img = torch.from_numpy(img).permute(2, 0, 1).float()  # to [8, H, W]
        label = torch.from_numpy(label).long()                # [H, W]

        if self.transform is not None:
            img = self.transform(img)

        return img, label