# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for Waymo</span>  

In [1]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from model.WaymoSegmentationDataset import WaymoSegmentationDataset
from model.RangeViTSegmentationModel import RangeViTSegmentationModel



In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Validation split, read in its stored order, one sample per batch.
val_dataset = WaymoSegmentationDataset('../WoD/validation', training=False)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)



In [3]:
# Class-wise segmentation metrics come from torchmetrics.
from torchmetrics.classification import (
    MulticlassJaccardIndex,
    MulticlassPrecision,
    MulticlassRecall,
)

# Defined once and reused below so the metrics and the model can never
# disagree on the size of the label space.
num_classes = 20
in_channels = 9 # range, x, y, z, intensity, flag, R, G, B

# average=None keeps one score per class instead of a single aggregate;
# ignore_index=0 tells torchmetrics to skip targets labelled 0.
metric_kwargs = dict(num_classes=num_classes, average=None, ignore_index=0)

# Create the metrics and put them on the same device as the model outputs.
iou_metric = MulticlassJaccardIndex(**metric_kwargs).to(device)
precision_metric = MulticlassPrecision(**metric_kwargs).to(device)
recall_metric = MulticlassRecall(**metric_kwargs).to(device)

model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)



In [4]:
def _nonzero_mean(scores):
    """Average the non-zero entries of a per-class score tensor.

    Returns a 0-dim zero tensor when every entry is zero, rather than the NaN
    that ``torch.mean`` produces for an empty selection.
    """
    kept = scores[scores != 0]
    if kept.numel() == 0:
        return scores.new_zeros(())
    return torch.mean(kept)


def eval_model(model, loader, iou_metric, precision_metric, recall_metric):
    """Run ``model`` over ``loader`` and report class-wise IoU, precision and recall.

    The three metric objects are reset, updated with the arg-max prediction of
    every batch, and computed once after the last batch. A running mIoU is
    shown on the progress bar, and the final means are printed.

    Caveat: the reported means average only the classes whose score is
    non-zero. That drops the zero entries seen for ignored classes, but it
    also drops any class the model gets entirely wrong, so the means can be
    optimistic. Inspect the returned per-class tensors for the full picture.

    Args:
        model: Callable module; ``model(imgs)`` must return a tensor whose
            dimension 1 indexes classes (presumably per-pixel logits; confirm
            against the model definition).
        loader: Iterable yielding ``(imgs, labels)`` tensor pairs. May be empty.
        iou_metric: Metric exposing ``reset()``, ``update(preds, target)`` and
            ``compute()``; ``compute()`` must return a per-class tensor.
        precision_metric: Same interface as ``iou_metric``.
        recall_metric: Same interface as ``iou_metric``.

    Returns:
        Tuple ``(ious, precisions, recalls)`` holding the values returned by
        each metric's ``compute()`` after all batches.

    Note:
        Batches are moved to the notebook-level ``device``, and the model is
        left in eval mode.
    """
    model.eval()
    # Start from a clean state so repeated calls do not accumulate.
    iou_metric.reset()
    precision_metric.reset()
    recall_metric.reset()

    with torch.no_grad():
        batch_bar = tqdm(loader, desc="Evaluating", leave=False)
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)

            # Class prediction = index of the largest score along dim 1.
            preds = outputs.argmax(dim=1)
            iou_metric.update(preds, labels)
            precision_metric.update(preds, labels)
            recall_metric.update(preds, labels)

            # Only the running mIoU is displayed, so precision and recall are
            # not recomputed per batch.
            running_miou = _nonzero_mean(iou_metric.compute())
            batch_bar.set_postfix(mIoU=running_miou.item())

        # Computed after the loop so the results exist even for an empty loader.
        ious = iou_metric.compute()
        precisions = precision_metric.compute()
        recalls = recall_metric.compute()

    mean_iou = _nonzero_mean(ious)
    mean_precision = _nonzero_mean(precisions)
    mean_recall = _nonzero_mean(recalls)

    print(f"mIoU: {mean_iou.item():.4f}, mPrecision: {mean_precision.item():.4f}, mRecall: {mean_recall.item():.4f}")

    return ious, precisions, recalls


In [5]:
# Load the model if you have a pre-trained one
pretrain_path = 'range_vit_waymo.pth'
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    model.load_state_dict(torch.load(pretrain_path))

Loading pre-trained model from range_vit_waymo.pth


In [6]:
# Evaluate on the validation split; keep the per-class tensors for the report below.
ious, precisions, recalls = eval_model(
    model=model,
    loader=val_loader,
    iou_metric=iou_metric,
    precision_metric=precision_metric,
    recall_metric=recall_metric,
)

Evaluating:   0%|          | 0/5976 [00:00<?, ?it/s]

mIoU: 0.8126, mPrecision: 0.8968, mRecall: 0.8803


In [7]:
# print metrics of each class
for cls in range(num_classes):
    print(f"Class {cls}: IoU={ious[cls]:.4f}, Precision={precisions[cls]:.4f}, Recall={recalls[cls]:.4f}")      

Class 0: IoU=0.0000, Precision=0.0000, Recall=0.0000
Class 1: IoU=0.9562, Precision=0.9756, Recall=0.9797
Class 2: IoU=0.7992, Precision=0.8783, Recall=0.8987
Class 3: IoU=0.9020, Precision=0.9348, Recall=0.9626
Class 4: IoU=0.9314, Precision=0.9731, Recall=0.9560
Class 5: IoU=0.9180, Precision=0.9433, Recall=0.9716
Class 6: IoU=0.8599, Precision=0.9347, Recall=0.9148
Class 7: IoU=0.7001, Precision=0.8625, Recall=0.7881
Class 8: IoU=0.3713, Precision=0.6273, Recall=0.4764
Class 9: IoU=0.9451, Precision=0.9707, Recall=0.9729
Class 10: IoU=0.0000, Precision=0.0000, Recall=0.0000
Class 11: IoU=0.8848, Precision=0.9405, Recall=0.9372
Class 12: IoU=0.5909, Precision=0.7318, Recall=0.7543
Class 13: IoU=0.9309, Precision=0.9625, Recall=0.9659
Class 14: IoU=0.0000, Precision=0.0000, Recall=0.0000
Class 15: IoU=0.8915, Precision=0.9437, Recall=0.9416
Class 16: IoU=0.7328, Precision=0.8545, Recall=0.8372
Class 17: IoU=0.0000, Precision=0.0000, Recall=0.0000
Class 18: IoU=0.7990, Precision=0.9087