In [10]:
import yaml
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from torchvision.utils import draw_segmentation_masks

from data import (
    create_test_dataset, create_test_dataloader
)

# Load the dataset config and build the test dataset / dataloader from it.
ds_cfg_fp = 'configs/dataset/suim.yaml'
# Use a context manager so the config file handle is closed right after parsing
# instead of being left open until garbage collection.
# NOTE(review): yaml.FullLoader can construct more than plain data types; if this
# config could ever come from an untrusted source, prefer yaml.safe_load.
with open(ds_cfg_fp, 'r') as ds_cfg_file:
    ds_cfg = yaml.load(ds_cfg_file, yaml.FullLoader)
test_ds = create_test_dataset('seg', ds_cfg)
test_dl = create_test_dataloader(test_ds)

# One colour string per color_map key, in sorted key order.
# NOTE(review): assumes the keys are hex colour strings without the leading '#'
# and that sorted key order matches the class-index order -- confirm against the config.
color_map = ds_cfg['test']['color_map']
colors = ['#'+c for c in sorted(color_map.keys())]

In [11]:
def show(imgs):
    """Plot one image tensor, or a list of them, side by side in a single row.

    Args:
        imgs: a tensor, or a list of tensors. Each tensor is detached, moved to
            the CPU and its axes are reordered with ``transpose(1, 2, 0)``
            before being handed to ``imshow`` (presumably channel-first images
            -- confirm with the caller).
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    count = len(images)

    # squeeze=False keeps `axs` two-dimensional even when there is one image.
    fig, axs = plt.subplots(ncols=count, squeeze=False)
    fig.set_size_inches(4 * count, 4)

    for ax, tensor in zip(axs[0], images):
        pixels = tensor.detach().cpu().numpy().transpose(1, 2, 0)
        ax.imshow(pixels)
        # Hide ticks and tick labels so only the image is visible.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [12]:
def calc_mIOU(pred, label, num_classes):
    """Mean intersection-over-union over the classes that occur in ``label``.

    Args:
        pred: score tensor with the class axis at dim 1, e.g. ``(N, C, H, W)``
            or ``(N, C)``. The predicted class is ``pred.argmax(1)``.
        label: tensor laid out like ``pred`` whose ``argmax(1)`` is taken as the
            ground-truth class (presumably one-hot -- confirm with the dataset).
        num_classes: class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The mean IoU (numpy float) over the classes that have at least one
        ground-truth element. Classes absent from ``label`` are skipped. NaN is
        returned when no evaluated class is present.
    """
    # argmax over raw scores picks the same class as argmax over softmax
    # probabilities, so the softmax is skipped. reshape(-1) flattens everything;
    # no squeeze is needed (the former squeeze(1) failed on (N, C) inputs).
    pred = pred.argmax(1).reshape(-1)
    label = label.argmax(1).reshape(-1)

    present_iou_list = []
    for sem_class in range(num_classes):
        pred_inds = pred == sem_class
        target_inds = label == sem_class
        target_count = target_inds.sum().item()
        if target_count == 0:
            # Class does not occur in the ground truth: leave it out of the mean.
            continue
        intersection_now = (pred_inds & target_inds).sum().item()
        # target_count > 0 guarantees a non-zero union.
        union_now = pred_inds.sum().item() + target_count - intersection_now
        present_iou_list.append(intersection_now / union_now)

    if not present_iou_list:
        # Avoid numpy's "Mean of empty slice" RuntimeWarning.
        return np.float64('nan')
    return np.mean(present_iou_list)

In [None]:
# Overlay the masks on the images of the first test batch only (note the
# `break` at the end of the loop body).
for batch in test_dl:
    imgs = batch['img']
    masks = batch['mask']
    imgs_with_mask = []
    for img, mask in zip(imgs, masks):
        # NOTE(review): assumes img holds floats in [0, 1] -- confirm against the
        # dataset transforms; normalised inputs would need de-normalising first.
        img = (img * 255).to(torch.uint8)
        # Per-pixel class index, computed once per mask. argmax over the raw
        # mask equals argmax over its softmax, so no softmax is applied.
        class_map = mask.argmax(0)
        # One boolean plane per class, in class-index order, matching `colors`.
        boolean_mask = torch.stack([class_map == i for i in range(len(color_map))])
        img_with_mask = draw_segmentation_masks(img, boolean_mask, alpha=0.5, colors=colors)
        imgs_with_mask.append(img_with_mask)
    show(imgs_with_mask)
    break