## Loading data

In [None]:
import os
# Expose only GPU 0 to this process. This is set before the remaining imports
# so it is already in place by the time CUDA gets initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from datasets.utils import visualize
from datasets.rellis_3d import Rellis3D as Dataset
from torch.utils.data import DataLoader
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
# Rellis-3D class names handed to the dataset. Index 0 is 'void'; the list
# order presumably matches the mask channel order (TODO confirm in Dataset).
CLASSES = (
    'void dirt grass tree pole water '
    'sky vehicle object asphalt building '
    'log person fence bush concrete '
    'barrier puddle mud rubble'
).split()
# Device used for evaluation and for the prediction loop below.
DEVICE = 'cuda'
# NOTE(review): IMG_SIZE is not referenced by the cells below, which pass
# crop_size=(704, 960) to the dataset instead.
IMG_SIZE = (352, 640)

## Test best saved model

In [None]:
# Load the best saved checkpoint. The file holds a whole pickled model (it is
# used directly as a model below), not a state_dict. Unpickling can execute
# arbitrary code, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto DEVICE. Without it, a checkpoint
# saved on another CUDA index would fail to load, because only GPU 0 is
# visible (CUDA_VISIBLE_DEVICES='0').
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full pickled modules; pass weights_only=False there if needed.
# Alternative: checkpoint from a local training run.
# best_model = torch.load('./best_model.pth')
best_model = torch.load(
    '../config/weights/smp/PSPNet_resnext50_32x4d_704x960_lr0.0001_bs6_epoch18_Rellis3D_iou_0.73.pth',
    map_location=DEVICE,
)

In [None]:
# Criteria used by the validation epoch below:
#   Dice loss  - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
#   IoU metric - https://en.wikipedia.org/wiki/Jaccard_index
# IoU is computed with threshold=0.5 (predictions are binarized at 0.5).
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

In [None]:
# Quantitative evaluation. NOTE: despite the `test_*` names, this runs on the
# 'val' split; the 'test' split is only used for the visual checks below.
test_dataset = Dataset(split='val', classes=CLASSES, crop_size=(704, 960))
test_dataloader = DataLoader(test_dataset)

test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    device=DEVICE,
    loss=loss,
    metrics=metrics,
)

# `logs` holds the loss/metric values returned by the run.
logs = test_epoch.run(test_dataloader)

## Visualize predictions

In [None]:
from hrnet.core.function import convert_label, convert_color
import yaml

# Label/colour settings; cfg["color_map"] is read by mask_to_colormap.
# A `with` block closes the file instead of leaking the handle.
with open('../config/rellis.yaml', 'r') as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Qualitative checks use the 'test' split (the metric run above used 'val').
test_dataset = Dataset(classes=CLASSES, crop_size=(704, 960), split='test')

def mask_to_colormap(mask, cfg):
    """Turn a channels-first class-score mask into a colour image.

    Args:
        mask: array whose axis 0 indexes classes (assumed (C, H, W) — TODO
            confirm); the winning class per pixel is taken with argmax.
        cfg: parsed config dict; its "color_map" entry is passed to
            ``convert_color``.

    Returns:
        Whatever ``convert_color`` produces for the relabelled mask.
    """
    class_idx = np.argmax(mask, axis=0).astype(np.uint8)
    # The shift happens in uint8, so class index 0 ('void') wraps to 255.
    # NOTE(review): presumably 255 is the ignore label expected by
    # convert_label's inverse mapping (the True flag) — confirm in hrnet.
    label = convert_label(class_idx - 1, True)
    return convert_color(label, cfg["color_map"])


# Show a few random test samples: input image, prediction and ground truth.
num_samples = min(5, len(test_dataset))
# Sample without replacement so the same image is never shown twice.
for n in np.random.choice(len(test_dataset), size=num_samples, replace=False):
    image, gt_mask = test_dataset[n]
    # Undo the dataset's mean/std normalisation to get a displayable image.
    image_vis = np.uint8(255 * (image * test_dataset.std + test_dataset.mean))

    gt_mask = gt_mask.squeeze()

    # Move the last axis to the front (presumably HWC -> CHW) and add a batch dim.
    x_tensor = torch.from_numpy(image.transpose([2, 0, 1])).to(DEVICE).unsqueeze(0)
    pred = best_model.predict(x_tensor)
    # Keep the raw per-class scores. mask_to_colormap takes an argmax over
    # channels, and rounding first would zero every score below 0.5. Pixels
    # where no class reaches 0.5 would then all land on class 0 ('void')
    # instead of the most likely class.
    pr_mask = pred.squeeze().cpu().numpy()

    pred_colormap = mask_to_colormap(pr_mask, cfg)
    gt_colormap = mask_to_colormap(gt_mask, cfg)

    visualize(image=image_vis, pred=pred_colormap, gt=gt_colormap)