# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
The dataset should be structured as follows:
```
sequences/
├── 00/
│   └── preprocess/
│       ├── 000000.bin
│       ├── 000001.bin
│       └── ...
├── 01/
│   └── preprocess/
│       ├── 000000.bin
│       ├── 000001.bin
│       └── ...
└── ...
```



In [None]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from model.KITTISegmentationDataset import KITTISegmentationDataset
from model.RangeViTSegmentationModel import RangeViTSegmentationModel

from segmentation_models_pytorch.losses import FocalLoss, LovaszLoss


In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Training data: sequences 00-07, 09 and 10, shuffled, in batches of 32.
dataset = KITTISegmentationDataset(
    '../sequences',
    ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'],
    training=True,
)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Validation data: sequence 08 only, kept in order, in batches of 16.
val_dataset = KITTISegmentationDataset(
    '../sequences',
    ['08'],
    training=False,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)



In [None]:
# Use torchmetrics or do manually
from torchmetrics.classification import MulticlassJaccardIndex

# Hyperparameters shared by the metric, the model and the scheduler.
num_classes = 20
in_channels = 9 # range, x, y, z, intensity, flag, R, G, B
num_epochs = 60

# Per-class IoU (average=None) with label 0 ignored; the metric lives on
# the same device as the model outputs it is updated with.
metric = MulticlassJaccardIndex(
    num_classes=num_classes,
    average=None,
    ignore_index=0,
)
metric = metric.to(device)

model = RangeViTSegmentationModel(
    n_classes=num_classes,
    in_channels=in_channels,
)
model = model.to(device)

# The two loss terms; both skip pixels labelled 0.
# criterion = LovaszLoss(mode='multiclass', ignore_index=0, per_image=False)
focal = FocalLoss(mode='multiclass', ignore_index=0)
lovasz = LovaszLoss(mode='multiclass', ignore_index=0, per_image=False)
def criterion(outputs, targets):
    """Return the focal loss plus the Lovasz loss for one batch.

    Both terms are evaluated on the same `outputs` and `targets` using
    the module-level `focal` and `lovasz` loss objects.
    """
    focal_term = focal(outputs, targets)
    lovasz_term = lovasz(outputs, targets)
    return focal_term + lovasz_term
from torch.optim.lr_scheduler import CosineAnnealingLR

# AdamW over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=0.0004,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)

# Cosine schedule spanning num_epochs scheduler steps, annealing towards 0.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=0,
)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, metric,epoch):
    """Train `model` for one pass over `loader` and report loss and mIoU.

    Args:
        model: Network to optimise; switched to training mode here.
        loader: Iterable yielding `(imgs, labels)` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable `criterion(outputs, labels)` returning the loss.
        metric: Running IoU metric; reset at the start of the epoch and
            updated with the argmax predictions of every batch.
        epoch: Zero-based epoch index, used only for progress/log text.

    Returns:
        tuple[float, float]: `(average loss per batch, running mIoU after
        the last batch)`. Both are NaN when `loader` yields no batches.
    """
    model.train()
    metric.reset()  # Start the running IoU from scratch for this epoch
    total_loss = 0.0
    num_batches = 0
    mean_iou = float("nan")  # stays NaN if the loader yields nothing
    batch_bar = tqdm(loader, desc=f"Training Epoch {epoch+1}", leave=False)
    for imgs, labels in batch_bar:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metric bookkeeping needs no gradients.
        with torch.no_grad():
            metric.update(outputs.argmax(dim=1), labels)
            ious = metric.compute()
            # NOTE(review): the `ious != 0` mask presumably removes the
            # ignored class, but it also removes every class whose IoU is
            # exactly 0, which raises the reported mean - confirm intended.
            mean_iou = torch.mean(ious[ious != 0]).item()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        batch_bar.set_postfix(loss=batch_loss, mIoU=mean_iou)

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}] Loss: {avg_loss:.4f}, mIoU: {mean_iou:.4f}")
    return avg_loss, mean_iou


In [None]:
def eval_model(model, loader, criterion, metric):
    """Evaluate `model` on `loader` without gradients; report loss and mIoU.

    Args:
        model: Network to evaluate; switched to eval mode here (it is not
            switched back afterwards).
        loader: Iterable yielding `(imgs, labels)` batches.
        criterion: Callable `criterion(outputs, labels)` returning the loss.
        metric: Running IoU metric; reset at the start, so after this call
            it holds the statistics of this evaluation pass only.

    Returns:
        tuple[float, float]: `(average loss per batch, mIoU after the last
        batch)`. Both are NaN when `loader` yields no batches.
    """
    model.eval()
    metric.reset()  # Reset the IoU metric for the evaluation
    total_loss = 0.0
    num_batches = 0
    mean_iou = float("nan")  # stays NaN if the loader yields nothing
    with torch.no_grad():
        batch_bar = tqdm(loader, desc="Evaluating", leave=False)
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            batch_loss = criterion(outputs, labels).item()

            metric.update(outputs.argmax(dim=1), labels)
            ious = metric.compute()
            # NOTE(review): the `ious != 0` mask presumably removes the
            # ignored class, but it also removes every class whose IoU is
            # exactly 0, which raises the reported mean - confirm intended.
            mean_iou = torch.mean(ious[ious != 0]).item()

            total_loss += batch_loss
            num_batches += 1
            batch_bar.set_postfix(loss=batch_loss, mIoU=mean_iou)

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Evaluation Loss: {avg_loss:.4f}, mIoU: {mean_iou:.4f}")
    return avg_loss, mean_iou


In [None]:
### Train the model
# Checkpoint file: loaded at start-up when present, and overwritten whenever
# a validation pass beats the best validation mIoU seen so far.
pretrain_path = 'range_vit_segmentation.pth'
eval_every = 5  # validate every N epochs, and always after the last epoch


def _metric_miou(metric):
    """Return the mean of the non-zero per-class IoUs held by `metric`."""
    ious = metric.compute()
    return torch.mean(ious[ious != 0]).item()


best_val_mIoU = float('-inf')
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # weights_only restricts unpickling to tensors and plain containers.
    model.load_state_dict(
        torch.load(pretrain_path, map_location=device, weights_only=True)
    )
    # Score the loaded weights first so the existing checkpoint is only
    # overwritten by a model that actually validates better.
    eval_model(model, val_loader, criterion, metric)
    best_val_mIoU = _metric_miou(metric)

# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    train_one_epoch(model, loader, optimizer, criterion, metric,epoch)

    is_last_epoch = epoch == num_epochs - 1
    if epoch % eval_every == 0 or is_last_epoch:
        eval_model(model, val_loader, criterion, metric)
        # Read the metric only right after a validation pass: eval_model
        # resets it first, so here it holds validation statistics. On other
        # epochs it still holds the training statistics of train_one_epoch
        # and must not be used for model selection.
        current_val_mIoU = _metric_miou(metric)
        if current_val_mIoU > best_val_mIoU:
            best_val_mIoU = current_val_mIoU
            torch.save(model.state_dict(), pretrain_path)

    scheduler.step()


In [None]:
# pretrain_path = 'range_vit_segmentation.pth'
# torch.save(model.state_dict(), pretrain_path)

In [None]:
# Validation with the best model
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only restricts unpickling to tensors and plain containers.
state_dict = torch.load(
    'range_vit_segmentation.pth', map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.eval()
metric.reset()  # Reset the IoU metric for validation
with torch.no_grad():
    for images, targets in val_loader:
        images = images.to(device)
        targets = targets.to(device)
        outputs = model(images)
        # Only the IoU is reported here, so no loss is computed. Class
        # predictions are taken explicitly, the same way eval_model does.
        metric.update(outputs.argmax(dim=1), targets)
    ious = metric.compute()
    # Mean over the classes whose IoU is non-zero.
    val_mIoU = torch.mean(ious[ious != 0]).item()
    print(f"Validation mIoU: {val_mIoU:.4f}")


In [None]:
# print structure of model
# print(model)

In [None]:
# clear gpu memory
# torch.cuda.empty_cache()