# <span style="color:red; font-weight:bold; ">A clean and modern RangeViT implementation for Waymo</span>  

In [1]:
import torch.optim as optim
import torch
import numpy as np

import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from model.WaymoSegmentationDataset import WaymoSegmentationDataset
from model.RangeViTSegmentationModel import RangeViTSegmentationModel

from segmentation_models_pytorch.losses import FocalLoss, LovaszLoss


In [2]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset rooted at '../WoD/validation', constructed with training=False.
val_dataset = WaymoSegmentationDataset('../WoD/validation', training=False)

# One sample per batch, no shuffling, a single worker process.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)



In [3]:
# Use torchmetrics or do manually
from torchmetrics.classification import MulticlassJaccardIndex

# Problem dimensions, defined before anything that consumes them so the
# metric, the model and the losses cannot drift apart.
num_classes = 20
in_channels = 9 # range, x, y, z, intensity, flag, R, G, B
num_epochs = 60
# Label id handed as ignore_index to the metric and to both loss terms.
ignore_index = 0

# create a metric and put it on gpu
# average=None makes the metric return one IoU score per class.
metric = MulticlassJaccardIndex(
    num_classes=num_classes, average=None, ignore_index=ignore_index
).to(device)

model = RangeViTSegmentationModel(n_classes=num_classes, in_channels=in_channels).to(device)

# The two loss terms that `criterion` adds together.
focal = FocalLoss(mode='multiclass', ignore_index=ignore_index)
lovasz = LovaszLoss(mode='multiclass', ignore_index=ignore_index, per_image=False)
def criterion(outputs, targets):
    """Return the sum of the focal and Lovasz losses for one batch.

    Both terms are computed by the module-level ``focal`` and ``lovasz``
    objects, each called as ``loss(outputs, targets)``.
    """
    focal_term = focal(outputs, targets)
    lovasz_term = lovasz(outputs, targets)
    return focal_term + lovasz_term
# AdamW over every parameter of the model.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=4e-4,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

from torch.optim.lr_scheduler import CosineAnnealingLR
# Cosine annealing with T_max=num_epochs and a minimum learning rate of 0.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=0,
)

In [4]:
def eval_model(model, loader, criterion, metric):
    """Evaluate ``model`` on ``loader`` and print the mean loss and mIoU.

    Args:
        model: Module called as ``model(imgs)``; predictions are taken as
            ``argmax(dim=1)`` of its output.
        loader: Sized iterable yielding ``(imgs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss
            tensor that supports ``.item()``.
        metric: Object exposing ``reset()``, ``update(preds, labels)`` and
            ``compute()``; ``compute()`` is expected to return a tensor of
            per-class IoU scores.

    Returns:
        None. Progress is shown on a tqdm bar and the final figures are
        printed.

    Notes:
        Batches are moved to the module-level ``device``.
        Classes whose accumulated IoU is exactly zero are left out of the
        mean — presumably to drop the ignored and absent classes. This also
        hides classes that are present but never predicted correctly.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Guard: with no batches there is no loss to average and no mIoU
        # to report, so say so instead of failing.
        print("Evaluation skipped: loader is empty.")
        return

    model.eval()
    metric.reset()  # Reset the IoU metric for the evaluation
    total_loss = 0.0
    mean_iou = 0.0
    with torch.no_grad():
        batch_bar = tqdm(loader, desc="Evaluating", leave=False)
        for imgs, labels in batch_bar:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)

            batch_loss = criterion(outputs, labels).item()
            total_loss += batch_loss

            # The metric accumulates across batches, so the value shown on
            # the bar is the running mIoU over everything seen so far.
            metric.update(outputs.argmax(dim=1), labels)
            ious = metric.compute()
            nonzero_ious = ious[ious != 0]
            # Averaging an empty selection would yield NaN; report 0.0.
            mean_iou = nonzero_ious.mean().item() if nonzero_ious.numel() > 0 else 0.0
            batch_bar.set_postfix(loss=batch_loss, mIoU=mean_iou)
    print(f"Evaluation Loss: {total_loss/num_batches:.4f}, mIoU: {mean_iou:.4f}")


In [5]:
# Load the model if you have a pre-trained one
pretrain_path = 'range_vit_segmentation4645.pth'
if os.path.exists(pretrain_path):
    print(f"Loading pre-trained model from {pretrain_path}")
    # Load the same file that was just checked and announced. map_location
    # lets a checkpoint saved on another device load onto the current one.
    model.load_state_dict(torch.load(pretrain_path, map_location=device))
else:
    # Make the fallback visible: without a checkpoint the evaluation below
    # runs on whatever weights the model currently holds.
    print(f"No checkpoint found at {pretrain_path}; using the model's current weights")

In [6]:
# Run one evaluation pass over the validation loader; eval_model prints the
# mean loss and mIoU.
eval_model(model, val_loader, criterion, metric)

Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Evaluation Loss: 7.8585, mIoU: 0.0249
