# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for SemanticKITTI in PyTorch 2.4</span>  

## <span style="font-weight:bold">1. DataLoader</span>

### 1.1 Dataset Structure
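The dataset root (the `sequences/` folder whose path is passed to `KITTISegmentationDataset` in the data-loading cell) should be structured as follows: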
```
sequences/
├── 00/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
├── 01/
│   ├── preprocess/
│   │   ├── 000000.bin
│   │   ├── 000001.bin
```

Libraries required: torch, torchvision, timm, tqdm, numpy, segmentation_models_pytorch (provides `FocalLoss`)



In [None]:
# Environment setup cell (notebook-only: the last two lines are IPython magics).
#
# One-off shell commands, kept commented out; uncomment and run as needed.
# Clone the WebVersion branch of the repository:
# !git clone -b WebVersion https://github.com/haiquangdinh/HDRangeViT.git
# Configure the git identity used for commits:
# !git config --global user.email "haiquangdinh@gmail.com"
# !git config --global user.name "Hai Dinh"
# Commit all changes and push them to the WebVersion branch:
# !git add . && git commit -m "Message" && git push origin WebVersion

# Install the dependencies listed in requirements.txt:
# !pip install -r requirements.txt

# Manually reload a module after editing it (usually unnecessary while the
# autoreload extension enabled at the bottom of this cell is active):
# import importlib
# importlib.reload(mymodule)

# Where the notebook is running. The data-loading cell uses this flag to choose
# the dataset paths, the batch size and the number of DataLoader workers.
# True: running on RunPod; False: running locally.
is_runpod = False

# Let IPython automatically reload imported modules before running each cell,
# so edits to the project .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn

from tqdm.notebook import tqdm
from segmentation_models_pytorch.losses import FocalLoss

In [None]:
from KITTISegmentationDataset import KITTISegmentationDataset
from RangeViTSegmentationModel import RangeViTSegmentationModel
from Evaluation import compute_iou

In [None]:
# Data loading: choose the per-environment settings from `is_runpod`, then build
# the training and validation Dataset/DataLoader pairs once for either case.
#
# RunPod cloud RTX 4090 (~10 it/s): DataParallel is probably unnecessary since
# training takes about 3 hours. The GPU is powerful, so use a larger batch size
# and more workers so that data loading does not become the bottleneck.
# Local Legion computer (~2 it/s): batch_size and num_workers are 1 because of
# limited resources.
#
# Full training split: ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'].
# Until the overfitting issue is resolved, train on a subset of it only.
if is_runpod:
    data_root = '../dataset/sequences'
    train_sequences = ['00', '01']
    batch_size, num_workers = 16, 16
else:
    data_root = '../SemanticKITTI/dataset/sequences'
    train_sequences = ['00']
    batch_size, num_workers = 1, 1
# Held-out validation sequence (excluded from the full training split above).
val_sequences = ['08']
# Checkpoint path: loaded before training if present, overwritten on improvement.
pretrain_path = '../range_vit_segmentation_noRGB_patch.pth'

dataset = KITTISegmentationDataset(data_root, train_sequences)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
dataset_val = KITTISegmentationDataset(data_root, val_sequences)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)



In [None]:
### Train the model
# Builds the RangeViT model, optionally resumes from `pretrain_path`, then runs
# one training pass over `loader` and one validation pass over `loader_val` per
# epoch, saving the weights whenever the validation mIoU improves.


def _run_epoch(model, data_loader, criterion, num_classes, device, optimizer=None, desc=None):
    """Run one pass over ``data_loader`` and return batch-averaged metrics.

    When ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in evaluation mode and gradient
    tracking is disabled.

    Returns:
        (mean_loss, mean_mIoU, mean_acc), each averaged over batches.
    """
    is_train = optimizer is not None
    # Set the mode on every pass: the validation pass leaves the model in eval
    # mode, so each training pass must switch it back to train mode.
    model.train(is_train)
    total_loss = total_mIoU = total_acc = 0.0
    batch_bar = tqdm(data_loader, desc=desc, leave=False)
    with torch.set_grad_enabled(is_train):
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)      # [B, C, H, W]
            labels = labels.to(device)  # [B, H, W]
            outputs = model(imgs)       # [B, num_classes, H, W]
            loss = criterion(outputs, labels)

            preds = outputs.argmax(dim=1)  # [B, H, W]
            mIoU, _ious, acc = compute_iou(preds, labels, num_classes)

            if is_train:
                optimizer.zero_grad()
                loss.backward()   # gradients of the loss w.r.t. all model parameters
                optimizer.step()  # parameter update

            batch_bar.set_postfix(loss=loss.item(), mIoU=mIoU.item(), acc=acc.item())
            total_loss += loss.item()
            total_mIoU += mIoU.item()
            total_acc += acc.item()

    num_batches = max(len(data_loader), 1)  # avoid dividing by zero on an empty loader
    return total_loss / num_batches, total_mIoU / num_batches, total_acc / num_batches


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
num_classes = 20
# Number of input channels per range-image pixel; must match what
# KITTISegmentationDataset yields (presumably range, x, y, z, intensity —
# TODO confirm against the dataset class).
in_channels = 5
num_epochs = 60
model = RangeViTSegmentationModel(num_classes=num_classes, in_channels=in_channels).to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel!")
    model = nn.DataParallel(model)
# The underlying network. Checkpoints are saved from and loaded into this module,
# so they carry no DataParallel 'module.' prefix and work with any GPU count.
base_model = model.module if isinstance(model, nn.DataParallel) else model

# Focal loss; pixels labelled 0 are ignored.
criterion = FocalLoss(mode='multiclass', ignore_index=0)
# criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.0004)

# Resume from a previous checkpoint if one exists.
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    # The checkpoint is a plain state dict, so refuse to unpickle arbitrary objects.
    state_dict = torch.load(pretrain_path, map_location=device, weights_only=True)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key with
    # 'module.'; strip it so they load into the unwrapped network.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    base_model.load_state_dict(state_dict)

# NOTE(review): when resuming, the first epoch always overwrites the checkpoint
# because best_val_mIoU starts at 0.0. Evaluate the loaded model first if the
# existing checkpoint must never be replaced by a worse one.
best_val_mIoU = 0.0
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    train_loss, train_mIoU, train_acc = _run_epoch(
        model, loader, criterion, num_classes, device,
        optimizer=optimizer, desc=f"Training Epoch {epoch+1}",
    )
    print(f"Epoch [{epoch+1}] Train Loss: {train_loss:.4f}, Train mIoU: {train_mIoU:.4f}, Train Acc: {train_acc:.4f}")

    val_loss, val_mIoU, val_acc = _run_epoch(
        model, loader_val, criterion, num_classes, device, desc="Evaluating",
    )
    print(f"Validation Loss: {val_loss:.4f}, Validation mIoU: {val_mIoU:.4f}, Validation Acc: {val_acc:.4f}")
    if val_mIoU > best_val_mIoU:
        best_val_mIoU = val_mIoU
        print('saving better model...')
        torch.save(base_model.state_dict(), pretrain_path)