# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for Waymo</span>  

In [1]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from model.WaymoSegmentationDataset import WaymoSegmentationDataset
from model.RangeViTSegmentationModel import RangeViTSegmentationModel



In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Validation split, iterated in a fixed order one sample at a time.
val_dataset = WaymoSegmentationDataset(
    '../WoD/validation',
    training=False,
)
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=1,
    num_workers=1,
)



In [None]:
# Use torchmetrics or do manually
from torchmetrics.classification import MulticlassJaccardIndex
# create a metric and put it on gpu
iou_metric = MulticlassJaccardIndex(num_classes=20, average=None, ignore_index=0).to(device)
# Use torchmetrics to compute class-wise Precision and Recall
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
precision_metric = MulticlassPrecision(num_classes=20, average=None, ignore_index=0).to(device)
recall_metric = MulticlassRecall(num_classes=20, average=None, ignore_index=0).to(device)

num_classes = 20
in_channels = 9 # range, x, y, z, intensity, flag, R, G, B
model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)



In [4]:
def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _mean_over(scores, mask):
    """Mean of ``scores[mask]`` as a Python float; NaN when the mask selects nothing."""
    selected = scores[mask.to(scores.device)]
    return selected.mean().item() if selected.numel() > 0 else float('nan')


def eval_model(model, loader, iou_metric, precision_metric, recall_metric, device=None):
    """Evaluate a segmentation model and report mean IoU, precision and recall.

    The three metric objects must expose ``reset()``, ``update(preds, target)``
    and ``compute()``, where ``compute()`` returns one score per class (e.g.
    torchmetrics ``Multiclass*`` metrics with ``average=None``).

    Class means are taken over the classes that occur at least once in the
    ground truth, excluding ``iou_metric.ignore_index`` when that attribute is
    set. A class that is present but never predicted correctly therefore
    counts as 0 rather than being dropped from the mean.

    Args:
        model: network whose output holds per-class logits on dim 1.
        loader: iterable of ``(imgs, labels)`` batches.
        iou_metric, precision_metric, recall_metric: per-class metric objects,
            expected to live on the same device as the batches.
        device: where each batch is moved. Defaults to the device of the
            model's parameters.

    Returns:
        dict with float values under ``'mIoU'``, ``'mPrecision'`` and ``'mRecall'``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        device = _model_device(model)
    ignore_index = getattr(iou_metric, 'ignore_index', None)
    metrics = (iou_metric, precision_metric, recall_metric)

    model.eval()
    for metric in metrics:
        metric.reset()

    gt_counts = None  # per-class ground-truth pixel counts, allocated on the first batch
    present = None    # bool mask of classes included in the means
    with torch.no_grad():
        batch_bar = tqdm(loader, desc="Evaluating", leave=False)
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)

            preds = outputs.argmax(dim=1)
            for metric in metrics:
                metric.update(preds, labels)

            # Track which classes actually appear in the ground truth.
            num_classes = outputs.shape[1]
            if gt_counts is None:
                gt_counts = torch.zeros(num_classes, dtype=torch.long, device=labels.device)
            flat = labels.flatten().long()
            flat = flat[(flat >= 0) & (flat < num_classes)]
            gt_counts += torch.bincount(flat, minlength=num_classes)

            present = gt_counts > 0
            if ignore_index is not None and 0 <= ignore_index < num_classes:
                present[ignore_index] = False

            # Only IoU is shown live. Precision and recall are computed once, after the loop.
            batch_bar.set_postfix(mIoU=_mean_over(iou_metric.compute(), present))

    if gt_counts is None:
        raise ValueError("eval_model: loader yielded no batches")

    results = {
        'mIoU': _mean_over(iou_metric.compute(), present),
        'mPrecision': _mean_over(precision_metric.compute(), present),
        'mRecall': _mean_over(recall_metric.compute(), present),
    }
    print(f"mIoU: {results['mIoU']:.4f}, mPrecision: {results['mPrecision']:.4f}, mRecall: {results['mRecall']:.4f}")
    return results


In [5]:
# Load pre-trained weights when a checkpoint is present; otherwise say so explicitly.
pretrain_path = 'range_vit_waymo.pth'
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    # map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
    # opened on a CPU-only machine) onto the evaluation device.
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
    # or pass weights_only=True on torch versions that support it.
    model.load_state_dict(torch.load(pretrain_path, map_location=device))
else:
    print(f"No checkpoint found at {pretrain_path}; evaluating randomly initialised weights")

Loading pre-trained model from range_vit_waymo.pth


In [6]:
# Evaluate the current model weights on the validation split.
eval_model(
    model,
    val_loader,
    iou_metric=iou_metric,
    precision_metric=precision_metric,
    recall_metric=recall_metric,
)

Evaluating:   0%|          | 0/5976 [00:00<?, ?it/s]

RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=MulticlassPrecision(...)` try to do `metric=MulticlassPrecision(...).to(device)` where device corresponds to the device of the input.