# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
The dataset should be structured as follows:
```
sequences/
├── 00/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
├── 01/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
```



In [1]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from model.KITTISegmentationDataset import KITTISegmentationDataset
from model.RangeViTSegmentationModel import RangeViTSegmentationModel

from segmentation_models_pytorch.losses import FocalLoss, LovaszLoss


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# dataset = KITTISegmentationDataset('../sequences',['03'], training=True)
dataset = KITTISegmentationDataset('../sequences',['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'], training=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

val_dataset = KITTISegmentationDataset('../sequences',['08'], training=False)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)



In [3]:
# Use torchmetrics or do manually
from torchmetrics.classification import MulticlassJaccardIndex
# create a metric and put it on gpu
metric = MulticlassJaccardIndex(num_classes=20, average=None, ignore_index=0).to(device)

num_classes = 20
in_channels = 9 # range, x, y, z, intensity, flag, R, G, B
num_epochs = 60
model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)
# criterion = LovaszLoss(mode='multiclass', ignore_index=0, per_image=False)
focal = FocalLoss(mode='multiclass', ignore_index=0)
lovasz = LovaszLoss(mode='multiclass', ignore_index=0, per_image=False)
def criterion(outputs, targets):
    return focal(outputs, targets) + lovasz(outputs, targets)
optimizer = optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.01, betas=(0.9, 0.999))

from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)

In [4]:
def train_one_epoch(model, loader, optimizer, criterion, metric,epoch):
    model.train()
    total_loss = 0.0
    metric.reset()  # Reset the IoU metric for the next epoch
    batch_bar = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)
    for imgs, labels in batch_bar:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        preds = outputs.argmax(dim=1)
        metric.update(preds, labels)
        ious = metric.compute()
        mean_iou = torch.mean(ious[ious != 0])
        loss.backward()
        optimizer.step()
        batch_bar.set_postfix(loss=loss.item(), mIoU=mean_iou.item())
        total_loss += loss.item()
    print(f"Epoch [{epoch+1}] Loss: {total_loss/len(loader):.4f}, mIoU: {mean_iou.item():.4f}")


In [5]:
def eval_model(model, loader, criterion, metric):
    model.eval()
    total_loss = 0.0
    metric.reset()  # Reset the IoU metric for the evaluation
    with torch.no_grad():
        batch_bar = tqdm(loader, desc="Evaluating", leave=False)
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            metric.update(preds, labels)
            ious = metric.compute()
            mean_iou = torch.mean(ious[ious != 0])
            batch_bar.set_postfix(loss=loss.item(), mIoU=mean_iou.item())
            total_loss += loss.item()
    print(f"Evaluation Loss: {total_loss/len(loader):.4f}, mIoU: {mean_iou.item():.4f}")


In [6]:
# Load the model if you have a pre-trained one
pretrain_path = 'range_vit_segmentation.pth'
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    model.load_state_dict(torch.load('range_vit_segmentation.pth'))

Loading pre-trained model from range_vit_segmentation.pth


In [7]:
### Train the model
best_val_mIoU = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    train_one_epoch(model, loader, optimizer, criterion, metric,epoch)
    if epoch % 5 == 0: # Evaluate every 5 epochs
        eval_model(model, val_loader, criterion, metric)
        ious = metric.compute()
        current_val_mIoU = torch.mean(ious[ious != 0]).item()
        if current_val_mIoU > best_val_mIoU:
            best_val_mIoU = current_val_mIoU
            torch.save(model.state_dict(), 'range_vit_segmentation.pth')
    scheduler.step()


Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [1] Loss: 0.4353, mIoU: 0.7152


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.7990, mIoU: 0.4686


Training Epoch 2:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [2] Loss: 0.2776, mIoU: 0.8450


Training Epoch 3:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [3] Loss: 0.2801, mIoU: 0.8453


Training Epoch 4:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [4] Loss: 0.2628, mIoU: 0.8577


Training Epoch 5:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [5] Loss: 0.2525, mIoU: 0.8647


Training Epoch 6:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [6] Loss: 0.2482, mIoU: 0.8661


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.8381, mIoU: 0.4565


Training Epoch 7:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [7] Loss: 0.2555, mIoU: 0.8633


Training Epoch 8:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [8] Loss: 0.2507, mIoU: 0.8637


Training Epoch 9:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [9] Loss: 0.2304, mIoU: 0.8785


Training Epoch 10:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [10] Loss: 0.2246, mIoU: 0.8835


Training Epoch 11:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [11] Loss: 0.2350, mIoU: 0.8781


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.8340, mIoU: 0.4825


Training Epoch 12:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [12] Loss: 0.2216, mIoU: 0.8874


Training Epoch 13:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [13] Loss: 0.2255, mIoU: 0.8845


Training Epoch 14:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [14] Loss: 0.2125, mIoU: 0.8926


Training Epoch 15:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [15] Loss: 0.2070, mIoU: 0.8952


Training Epoch 16:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [16] Loss: 0.2126, mIoU: 0.8928


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.8779, mIoU: 0.4609


Training Epoch 17:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [17] Loss: 0.2074, mIoU: 0.8947


Training Epoch 18:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [18] Loss: 0.2053, mIoU: 0.8970


Training Epoch 19:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [19] Loss: 0.2011, mIoU: 0.8997


Training Epoch 20:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [20] Loss: 0.1932, mIoU: 0.9035


Training Epoch 21:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [21] Loss: 0.1953, mIoU: 0.9006


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.8739, mIoU: 0.4654


Training Epoch 22:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [22] Loss: 0.1976, mIoU: 0.9011


Training Epoch 23:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [23] Loss: 0.1887, mIoU: 0.9053


Training Epoch 24:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [24] Loss: 0.1869, mIoU: 0.9076


Training Epoch 25:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [25] Loss: 0.1825, mIoU: 0.9093


Training Epoch 26:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [26] Loss: 0.1840, mIoU: 0.9076


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.9062, mIoU: 0.4947


Training Epoch 27:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [27] Loss: 0.1815, mIoU: 0.9083


Training Epoch 28:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [28] Loss: 0.1782, mIoU: 0.9122


Training Epoch 29:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [29] Loss: 0.1758, mIoU: 0.9123


Training Epoch 30:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [30] Loss: 0.1738, mIoU: 0.9139


Training Epoch 31:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [31] Loss: 0.1715, mIoU: 0.9147


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.9463, mIoU: 0.4913


Training Epoch 32:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [32] Loss: 0.1703, mIoU: 0.9155


Training Epoch 33:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [33] Loss: 0.1688, mIoU: 0.9164


Training Epoch 34:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [34] Loss: 0.1666, mIoU: 0.9180


Training Epoch 35:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [35] Loss: 0.1655, mIoU: 0.9191


Training Epoch 36:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [36] Loss: 0.1612, mIoU: 0.9203


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.9619, mIoU: 0.4700


Training Epoch 37:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [37] Loss: 0.1615, mIoU: 0.9207


Training Epoch 38:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [38] Loss: 0.1574, mIoU: 0.9217


Training Epoch 39:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [39] Loss: 0.1573, mIoU: 0.9224


Training Epoch 40:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [40] Loss: 0.1553, mIoU: 0.9233


Training Epoch 41:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [41] Loss: 0.1552, mIoU: 0.9238


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.9773, mIoU: 0.4973


Training Epoch 42:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [42] Loss: 0.1529, mIoU: 0.9246


Training Epoch 43:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [43] Loss: 0.1528, mIoU: 0.9252


Training Epoch 44:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [44] Loss: 0.1501, mIoU: 0.9260


Training Epoch 45:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [45] Loss: 0.1489, mIoU: 0.9263


Training Epoch 46:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [46] Loss: 0.1478, mIoU: 0.9273


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 0.9855, mIoU: 0.4666


Training Epoch 47:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [47] Loss: 0.1481, mIoU: 0.9275


Training Epoch 48:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [48] Loss: 0.1462, mIoU: 0.9284


Training Epoch 49:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [49] Loss: 0.1439, mIoU: 0.9284


Training Epoch 50:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [50] Loss: 0.1433, mIoU: 0.9293


Training Epoch 51:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [51] Loss: 0.1439, mIoU: 0.9291


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 1.0054, mIoU: 0.4938


Training Epoch 52:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [52] Loss: 0.1436, mIoU: 0.9295


Training Epoch 53:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [53] Loss: 0.1428, mIoU: 0.9294


Training Epoch 54:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [54] Loss: 0.1433, mIoU: 0.9297


Training Epoch 55:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [55] Loss: 0.1416, mIoU: 0.9301


Training Epoch 56:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [56] Loss: 0.1420, mIoU: 0.9303


Evaluating:   0%|          | 0/509 [00:00<?, ?it/s]

Evaluation Loss: 1.0095, mIoU: 0.4627


Training Epoch 57:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [57] Loss: 0.1430, mIoU: 0.9295


Training Epoch 58:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [58] Loss: 0.1416, mIoU: 0.9303


Training Epoch 59:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [59] Loss: 0.1402, mIoU: 0.9308


Training Epoch 60:   0%|          | 0/598 [00:00<?, ?it/s]

Epoch [60] Loss: 0.1400, mIoU: 0.9305


In [8]:
# copy weights when two models doesn't have exact same architecture
# old_dict = torch.load('range_vit_segmentation_4616.pth')
# from model.model_utils import approximately_clone_state_dict
# new_dict = approximately_clone_state_dict(model.state_dict(), old_dict)
# model.load_state_dict(new_dict)
# torch.save(model.state_dict(), 'range_vit_segmentation.pth')

In [15]:
val_dataset = KITTISegmentationDataset('../sequences',['00'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/4541 [00:00<?, ?it/s]

Evaluation Loss: 0.1446, mIoU: 0.9437


In [16]:
val_dataset = KITTISegmentationDataset('../sequences',['01'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1101 [00:00<?, ?it/s]

Evaluation Loss: 0.2052, mIoU: 0.8457


In [17]:
val_dataset = KITTISegmentationDataset('../sequences',['02'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/4661 [00:00<?, ?it/s]

Evaluation Loss: 0.1664, mIoU: 0.9472


In [18]:
val_dataset = KITTISegmentationDataset('../sequences',['03'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/801 [00:00<?, ?it/s]

Evaluation Loss: 0.1475, mIoU: 0.9076


In [19]:
val_dataset = KITTISegmentationDataset('../sequences',['04'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/271 [00:00<?, ?it/s]

Evaluation Loss: 0.1674, mIoU: 0.9146


In [20]:
val_dataset = KITTISegmentationDataset('../sequences',['05'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/2761 [00:00<?, ?it/s]

Evaluation Loss: 0.1709, mIoU: 0.9410


In [21]:
val_dataset = KITTISegmentationDataset('../sequences',['06'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1101 [00:00<?, ?it/s]

Evaluation Loss: 0.1358, mIoU: 0.9575


In [22]:
val_dataset = KITTISegmentationDataset('../sequences',['07'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1101 [00:00<?, ?it/s]

Evaluation Loss: 0.1332, mIoU: 0.9619


In [23]:
val_dataset = KITTISegmentationDataset('../sequences',['08'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/4071 [00:00<?, ?it/s]

Evaluation Loss: 0.9823, mIoU: 0.4645


In [24]:
val_dataset = KITTISegmentationDataset('../sequences',['09'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1591 [00:00<?, ?it/s]

Evaluation Loss: 0.1594, mIoU: 0.9538


In [25]:
val_dataset = KITTISegmentationDataset('../sequences',['10'], training=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1201 [00:00<?, ?it/s]

Evaluation Loss: 0.1888, mIoU: 0.9340


In [26]:
# Validation with the best model
# model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)
# model.load_state_dict(torch.load('range_vit_segmentation.pth'))
# eval_model(model, val_loader, criterion, metric)


In [27]:
# torch.save(model.state_dict(), 'range_vit_segmentation.pth')

In [28]:
# print structure of model
# print(model)

In [29]:
# clear cuda memory

torch.cuda.empty_cache()