In [1]:
import collections
import collections.abc
import json
import math
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from scipy.io import loadmat
from tqdm import tqdm

import lib.utils.data as torchdata
from dataset import TestDataset
from lib.nn import user_scattered_collate
from lib.nn.modules import SynchronizedBatchNorm2d
from lib.utils import as_numpy
from utils import colorEncode, find_recursive

In [2]:
# Colour table read from the MATLAB file; it is passed to `colorEncode` in
# `visualize_result` (presumably one RGB row per class index -- TODO confirm).
colors = loadmat('data/color150.mat')['colors']

In [3]:
# Set to True to force inference on the CPU; read by the device-selection cell.
cpu_only = True

In [4]:
# Use the first GPU only when it was requested (cpu_only is False) AND CUDA is
# actually available; otherwise fall back to the CPU so that later `.to(device)`
# calls do not fail on machines without a GPU.
device = torch.device(
    'cpu' if cpu_only or not torch.cuda.is_available() else 'cuda:0'
)

### Utils

In [5]:
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """Build a 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
        has_bias: whether the convolution carries a bias term (default False).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=has_bias)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Stack a 3x3 convolution, a synchronized batch norm and an in-place ReLU.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also the batch-norm width).
        stride: stride of the convolution (default 1).

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    stages = [
        conv3x3(in_planes, out_planes, stride),
        SynchronizedBatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stages)

In [6]:
class Arguments(object):
    """Plain container for the settings used by this notebook.

    Every constructor argument except ``base_model_name`` is stored unchanged
    as an attribute of the same name. Three checkpoint paths are derived from
    them:

    * ``weights_encoder`` -- ``<model_dir>/<model_path>/encoder<suffix>``
    * ``weights_decoder`` -- ``<model_dir>/<model_path>/decoder<suffix>``
    * ``weights_base``    -- ``<model_dir>/<base_model_name>``

    The list-valued arguments (``imgSize`` and ``test_imgs``) are copied, so
    mutating them on one instance never leaks into the shared default values
    or into the caller's list.
    """

    def __init__(
        self,
        arch_decoder='ppm_deepsup',
        arch_encoder='resnet50dilated',
        base_model_name='resnet50-imagenet.pth',
        batch_size=1,
        fc_dim=2048,
        gpu=0,
        imgMaxSize=1000,
        imgSize=[300, 400, 500, 600],
        model_dir='./pretrained/',
        model_path='./baseline-resnet50dilated-ppm_deepsup/',
        num_class=150,
        num_val=-1,
        padding_constant=8,
        result='./test_pics/segmented_images/',
        segm_downsampling_rate=8,
        suffix='_epoch_20.pth',
        test_imgs=['./test_pics/1.png'],
    ):
        self.arch_decoder = arch_decoder
        self.arch_encoder = arch_encoder
        self.batch_size = batch_size
        self.fc_dim = fc_dim
        self.gpu = gpu
        self.imgMaxSize = imgMaxSize
        # Copy so the mutable default (or the caller's list) is never aliased.
        self.imgSize = list(imgSize)
        self.model_dir = model_dir
        self.model_path = model_path
        self.num_class = num_class
        self.num_val = num_val
        self.padding_constant = padding_constant
        self.result = result
        self.segm_downsampling_rate = segm_downsampling_rate
        self.suffix = suffix
        self.test_imgs = list(test_imgs)

        # Join path components instead of concatenating strings so that a
        # `model_dir` without a trailing separator still yields valid paths.
        weights_dir = os.path.join(self.model_dir, self.model_path)
        self.weights_encoder = os.path.join(weights_dir, 'encoder' + self.suffix)
        self.weights_decoder = os.path.join(weights_dir, 'decoder' + self.suffix)
        self.weights_base = os.path.join(self.model_dir, base_model_name)

    def __repr__(self):
        """Return the three derived checkpoint paths rendered as a dict string."""
        return str({
            'base': self.weights_base,
            'encoder': self.weights_encoder,
            'decoder': self.weights_decoder
        })

In [7]:
# Instantiate the notebook settings and show every stored attribute.
args = Arguments(
    suffix='_epoch_20.pth',
    model_path='baseline-resnet50dilated-ppm_deepsup',
)
vars(args)

{'arch_decoder': 'ppm_deepsup',
 'arch_encoder': 'resnet50dilated',
 'batch_size': 1,
 'fc_dim': 2048,
 'gpu': 0,
 'imgMaxSize': 1000,
 'imgSize': [300, 400, 500, 600],
 'model_dir': './pretrained/',
 'model_path': 'baseline-resnet50dilated-ppm_deepsup',
 'num_class': 150,
 'num_val': -1,
 'padding_constant': 8,
 'result': './test_pics/segmented_images/',
 'segm_downsampling_rate': 8,
 'suffix': '_epoch_20.pth',
 'test_imgs': ['./test_pics/1.png'],
 'weights_encoder': './pretrained/baseline-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
 'weights_decoder': './pretrained/baseline-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
 'weights_base': './pretrained/resnet50-imagenet.pth'}

### Model architecture

#### Bottleneck

In [8]:
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv plus a skip path.

    Each convolution is followed by a ``SynchronizedBatchNorm2d``; a ReLU follows
    the first two and the final sum. The last convolution widens the features to
    ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: number of channels of the two inner convolutions.
        stride: stride of the 3x3 convolution (default 1).
        downsample: optional module applied to the input to produce the
            residual; when ``None`` the input itself is used as the residual.
    """

    # Ratio between the block's output channels and `planes`.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the output width from `expansion` instead of repeating the
        # literal 4, so it always agrees with `planes * block.expansion` used
        # by ResNet._make_layer when it builds the downsample branch.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages and add the (optionally projected) input."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input when a downsample module was supplied so that it
        # can be summed with `out`.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

#### Resnet base

The encoder is built on top of a network that is pretrained for image classification on ImageNet.

In [9]:
class ResNet(nn.Module):
    """ResNet classifier with a three-convolution stem.

    The stem is three 3x3 conv/BN/ReLU groups (64, 64 and 128 channels) followed
    by a max pool; four residual stages, an average pool and a linear
    classifier come after it.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence with the number of blocks in each of the four stages.
        num_classes: output size of the final linear layer (default 1000).
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Channel count entering the next residual stage; the stem outputs 128.
        self.inplanes = 128

        # Stem
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = SynchronizedBatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(64, 64)
        self.bn2 = SynchronizedBatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = conv3x3(64, 128)
        self.bn3 = SynchronizedBatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classification head
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation: normal(0, sqrt(2 / fan_out)) for convolutions,
        # unit scale and zero shift for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, SynchronizedBatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage made of ``blocks`` instances of ``block``.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + batch-norm projection for the skip
        path. ``self.inplanes`` is updated to the stage's output width.
        """
        stage_width = planes * block.expansion

        projection = None
        if stride != 1 or self.inplanes != stage_width:
            projection = nn.Sequential(
                nn.Conv2d(self.inplanes, stage_width,
                          kernel_size=1, stride=stride, bias=False),
                SynchronizedBatchNorm2d(stage_width),
            )

        stage = [block(self.inplanes, planes, stride, projection)]
        self.inplanes = stage_width
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension before the linear layer.
        x = x.view(x.size(0), -1)
        return self.fc(x)


#### Encoder

In [10]:
class ResnetDilated(nn.Module):
    """Encoder that reuses a ResNet's layers with strides swapped for dilation.

    Args:
        orig_resnet: network providing ``conv1``..``conv3``, ``bn1``..``bn3``,
            ``relu1``..``relu3``, ``maxpool`` and ``layer1``..``layer4``. Its
            modules are modified in place and shared with this encoder; its
            average pool and fully-connected head are not used.
        dilate_scale: 8 dilates ``layer3`` (rate 2) and ``layer4`` (rate 4);
            16 dilates only ``layer4`` (rate 2); any other value leaves the
            layers untouched.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        if dilate_scale == 8:
            dilation_plan = ((orig_resnet.layer3, 2), (orig_resnet.layer4, 4))
        elif dilate_scale == 16:
            dilation_plan = ((orig_resnet.layer4, 2),)
        else:
            dilation_plan = ()
        for stage, rate in dilation_plan:
            stage.apply(
                lambda module, rate=rate: self._nostride_dilate(module, dilate=rate)
            )

        # Take over every layer of the given ResNet except AvgPool and FC.
        reused = (
            'conv1', 'bn1', 'relu1',
            'conv2', 'bn2', 'relu2',
            'conv3', 'bn3', 'relu3',
            'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4',
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def _nostride_dilate(self, m, dilate):
        """Rewrite a convolution in place: remove its stride, add dilation.

        Modules whose class name does not contain ``'Conv'`` are ignored.
        A stride-2 convolution becomes stride 1 and, if it is 3x3, gets
        dilation/padding ``dilate // 2``; any other 3x3 convolution gets
        dilation/padding ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        was_strided = m.stride == (2, 2)
        if was_strided:
            # the convolution with stride
            m.stride = (1, 1)

        if m.kernel_size == (3, 3):
            # other convolutions use the full dilation rate
            rate = dilate // 2 if was_strided else dilate
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Encode ``x``.

        Returns:
            A list with the outputs of ``layer1``..``layer4`` when
            ``return_feature_maps`` is true, otherwise a one-element list with
            the ``layer4`` output.
        """
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)

        return stage_outputs if return_feature_maps else [x]

#### Decoder

In [11]:
class PPMDeepsup(nn.Module):
    """Pyramid-pooling decoder with an auxiliary (deep supervision) head.

    Args:
        num_class: number of output classes.
        fc_dim: channel count of the last encoder feature map.
        use_softmax: when true, ``forward`` resizes the scores to ``segSize``
            and returns softmax probabilities; otherwise it returns
            log-softmax outputs of the main and the auxiliary head.
        pool_scales: output sizes of the adaptive average pools in the pyramid.
    """

    def __init__(
        self, num_class=150, fc_dim=4096, use_softmax=False, pool_scales=(1, 2, 3, 6)
    ):
        super().__init__()
        self.use_softmax = use_softmax

        # One pooling branch per scale: pool -> 1x1 conv -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            )
            for scale in pool_scales
        ])
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # Main head consumes the encoder features concatenated with every
        # pyramid branch (512 channels each).
        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512,
                      kernel_size=3, padding=1, bias=False),
            SynchronizedBatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        """Decode the encoder feature maps.

        Args:
            conv_out: list of encoder feature maps; the last entry feeds the
                pyramid and the second-to-last feeds the auxiliary head.
            segSize: output size used for the bilinear resize when
                ``use_softmax`` is true.

        Returns:
            Softmax scores resized to ``segSize`` when ``use_softmax`` is true,
            otherwise the tuple ``(main_log_probs, aux_log_probs)``.
        """
        top_features = conv_out[-1]
        feature_size = top_features.size()
        spatial = (feature_size[2], feature_size[3])

        pyramid = [top_features]
        for branch in self.ppm:
            pooled = branch(top_features)
            pyramid.append(nn.functional.interpolate(
                pooled, spatial, mode='bilinear', align_corners=False))

        x = self.conv_last(torch.cat(pyramid, 1))

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            return nn.functional.softmax(x, dim=1)

        # deep supervision branch on the second-to-last feature map
        aux = self.cbr_deepsup(conv_out[-2])
        aux = self.dropout_deepsup(aux)
        aux = self.conv_last_deepsup(aux)

        x = nn.functional.log_softmax(x, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)

        return (x, aux)

### Load pretrained models

#### Base model

In [12]:
# Build the backbone and load the base weights onto the CPU.
# strict=False lets load_state_dict tolerate key mismatches.
base_model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])
base_model.load_state_dict(
    state_dict=torch.load(
        args.weights_base, map_location=lambda storage, loc: storage
    ),
    strict=False,
)
base_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (conv2): Conv2d(

#### Encoder

In [13]:
# Wrap the backbone as a dilated encoder, place it on the target device and
# load the encoder checkpoint (tensors are deserialised on the CPU first).
encoder = ResnetDilated(orig_resnet=base_model, dilate_scale=8).to(device)
encoder.load_state_dict(
    state_dict=torch.load(
        args.weights_encoder, map_location=lambda storage, loc: storage
    ),
    strict=False,
)
encoder

ResnetDilated(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (conv2): 

#### Decoder

In [14]:
# Build the decoder in inference mode (softmax output), place it on the target
# device and load the decoder checkpoint (deserialised on the CPU first).
decoder = PPMDeepsup(num_class=150, fc_dim=2048, use_softmax=True).to(device)
decoder.load_state_dict(
    state_dict=torch.load(
        args.weights_decoder, map_location=lambda storage, loc: storage
    ),
    strict=False,
)
decoder

PPMDeepsup(
  (ppm): ModuleList(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (3): ReLU(inplace)
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (3): ReLU(inplace)
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_size=3)
      (1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (3): ReLU(inplace)
    )
    (3): Sequential(
      (0): AdaptiveAvgPool2d(output_size=6)
      (1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias

### Segmentation Model

In [15]:
class SegmentationModuleBase(nn.Module):
    """Base module offering a pixel-accuracy helper for segmentation models."""

    def __init__(self):
        super().__init__()

    def pixel_acc(self, pred, label):
        """Fraction of valid pixels whose arg-max class equals the label.

        Args:
            pred: score tensor whose dimension 1 indexes the classes.
            label: integer label tensor; negative entries are ignored.

        Returns:
            A float tensor: correct / (valid + 1e-10).
        """
        predicted_class = torch.max(pred, dim=1)[1]
        is_valid = label >= 0
        is_hit = (predicted_class == label) & is_valid

        n_correct = is_hit.long().sum()
        n_valid = is_valid.long().sum()
        # The epsilon keeps the division defined when no pixel is valid.
        return n_correct.float() / (n_valid.float() + 1e-10)

In [16]:
class SegmentationModule(SegmentationModuleBase):
    """Encoder/decoder pair with a loss, usable for training and inference.

    Args:
        net_enc: encoder; called with ``return_feature_maps=True``.
        net_dec: decoder; called on the encoder output.
        crit: loss function applied to ``(prediction, label)``.
        deep_sup_scale: weight of the auxiliary loss; ``None`` disables deep
            supervision, in which case the decoder output is used as-is.
    """

    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
        super().__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit = crit
        self.deep_sup_scale = deep_sup_scale

    def forward(self, feed_dict, *, segSize=None):
        """Run the model on ``feed_dict['img_data']``.

        With ``segSize`` given (inference) the decoder prediction is returned.
        Without it (training) ``feed_dict['seg_label']`` is also required and
        the tuple ``(loss, pixel_accuracy)`` is returned.
        """
        features = self.encoder(feed_dict['img_data'], return_feature_maps=True)

        # inference
        if segSize is not None:
            return self.decoder(features, segSize=segSize)

        # training
        with_deepsup = self.deep_sup_scale is not None
        if with_deepsup:  # use deep supervision technique
            pred, pred_deepsup = self.decoder(features)
        else:
            pred = self.decoder(features)

        target = feed_dict['seg_label']
        loss = self.crit(pred, target)
        if with_deepsup:
            aux_loss = self.crit(pred_deepsup, target)
            loss = loss + aux_loss * self.deep_sup_scale

        return loss, self.pixel_acc(pred, target)

In [17]:
# Assemble the full model; labels equal to -1 are ignored by the loss.
crit = nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_enc=encoder, net_dec=decoder, crit=crit)
segmentation_module.to(device)

SegmentationModule(
  (encoder): ResnetDilated(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace)
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=

### Dataset

In [18]:
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        """Return the sample at ``index``; must be overridden by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples; must be overridden by subclasses."""
        raise NotImplementedError

    def __add__(self, other):
        """Concatenate this dataset with ``other``."""
        # `ConcatDataset` is neither defined nor imported in this notebook, so
        # the bare name raised NameError. It is taken from the already-imported
        # `torchdata` (lib.utils.data) instead.
        # NOTE(review): assumes lib.utils.data exports ConcatDataset -- confirm.
        return torchdata.ConcatDataset([self, other])

In [19]:
class BaseDataset(Dataset):
    """Common sample-list parsing and image preprocessing for datasets.

    Args:
        odgt: either a list of sample records, or the path of a text file
            holding one JSON record per line.
        opt: options object providing ``imgSize``, ``imgMaxSize`` and
            ``padding_constant``.
        **kwargs: forwarded to :meth:`parse_input_list`
            (``max_sample``, ``start_idx``, ``end_idx``).
    """

    def __init__(self, odgt, opt, **kwargs):
        # parse options
        self.imgSize = opt.imgSize
        self.imgMaxSize = opt.imgMaxSize

        # max down sampling rate of network to avoid rounding during conv or pooling
        self.padding_constant = opt.padding_constant

        # parse the input list
        self.parse_input_list(odgt, **kwargs)

        # Per-channel mean subtraction; the unit std leaves the scale unchanged.
        # NOTE(review): the channel order of these means presumably matches the
        # order of the images passed to img_transform -- confirm.
        self.normalize = transforms.Normalize(
            mean=[102.9801, 115.9465, 122.7717],
            std=[1., 1., 1.])

    def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
        """Populate ``self.list_sample`` and ``self.num_sample``.

        Args:
            odgt: list of records, or path to a file with one JSON record per
                line (blank lines are skipped).
            max_sample: if positive, keep only the first ``max_sample`` records.
            start_idx, end_idx: if both are non-negative, keep only the slice
                ``[start_idx:end_idx]`` (applied after ``max_sample``).

        Raises:
            TypeError: if ``odgt`` is neither a list nor a string.
            AssertionError: if no sample remains.
        """
        if isinstance(odgt, list):
            self.list_sample = odgt
        elif isinstance(odgt, str):
            # Close the file deterministically instead of leaking the handle.
            with open(odgt, 'r') as odgt_file:
                self.list_sample = [
                    json.loads(line) for line in odgt_file if line.strip()
                ]
        else:
            # Previously this fell through and failed later with an
            # AttributeError on the missing `list_sample`.
            raise TypeError(
                'odgt must be a list of records or a file path, got {}'.format(
                    type(odgt).__name__))

        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]
        if start_idx >= 0 and end_idx >= 0:     # divide file list
            self.list_sample = self.list_sample[start_idx:end_idx]

        self.num_sample = len(self.list_sample)
        assert self.num_sample > 0
        print('# samples: {}'.format(self.num_sample))

    def img_transform(self, img):
        """Convert an HxWxC array to a normalized float32 CxHxW tensor."""
        # image to float
        img = img.astype(np.float32)
        # move the channel axis first: (H, W, C) -> (C, H, W)
        img = img.transpose((2, 0, 1))
        img = self.normalize(torch.from_numpy(img.copy()))
        return img

    def round2nearest_multiple(self, x, p):
        """Round integer ``x`` up to the smallest multiple of ``p`` that is >= ``x``."""
        return ((x - 1) // p + 1) * p

In [20]:
class TestDataset(BaseDataset):
    """Dataset yielding each image at several scales for multi-scale inference.

    Note: this definition shadows the ``TestDataset`` imported from ``dataset``
    at the top of the notebook.

    Each item is a dict with:
        * ``'img_ori'``  -- copy of the image as loaded by ``cv2.imread``;
        * ``'img_data'`` -- list with one normalized tensor (leading batch
          dimension of 1) per entry of ``self.imgSize``;
        * ``'info'``     -- the image path taken from the record.
    """

    def __init__(self, odgt, opt, **kwargs):
        super(TestDataset, self).__init__(odgt, opt, **kwargs)

    def __getitem__(self, index):
        """Load the image of record ``index`` and build its resized versions.

        Raises:
            IOError: if the image cannot be read.
        """
        this_record = self.list_sample[index]
        # load image
        image_path = this_record['fpath_img']
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None; fail with a clear
            # message instead of an AttributeError on `img.shape` below.
            raise IOError('Could not read image: {}'.format(image_path))

        ori_height, ori_width, _ = img.shape

        img_resized_list = []
        for this_short_size in self.imgSize:
            # Scale so the short side reaches `this_short_size`, unless that
            # would push the long side beyond `imgMaxSize`.
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # to avoid rounding in network
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)

            # resize (cv2 expects the size as (width, height))
            img_resized = cv2.resize(img.copy(), (target_width, target_height))

            # image transform, then add a leading batch dimension
            img_resized = self.img_transform(img_resized)
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        output = dict()
        output['img_ori'] = img.copy()
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        """Return the number of parsed samples."""
        return self.num_sample

### Workflow

In [21]:
def async_copy_to(obj, dev, main_stream=None):
    """Recursively copy every tensor found in ``obj`` to device ``dev``.

    Args:
        obj: a tensor, a mapping, a sequence, or any other object.
        dev: target device passed to ``Tensor.to``.
        main_stream: optional stream; when given, ``record_stream`` is called
            on every copied tensor with it.

    Returns:
        The copied tensor; a ``dict`` for mappings and a ``list`` for sequences
        (with their items converted recursively); any other object unchanged.
        Strings and bytes are returned unchanged.
    """
    if torch.is_tensor(obj):
        v = obj.to(dev, non_blocking=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    # Strings are Sequences whose items are strings again, so recursing into
    # them would never terminate -- treat them as leaves.
    if isinstance(obj, (str, bytes)):
        return obj
    # The ABC aliases on `collections` itself were removed in Python 3.10;
    # they live in `collections.abc`.
    if isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    if isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    return obj

In [22]:
def get_labels(data, with_percentage=False, min_percentage=0):
    '''
    List the distinct labels of `data`, most frequent first.

    args:
        data: array of labels (any shape).
        with_percentage: also return each label's share of all elements.
        min_percentage: keep only labels whose share (in percent, rounded to
            two decimals) is strictly greater than this value.
    returns:
        object array of [label, percentage] rows when `with_percentage` is
        true, eg: [5, 13.32]; otherwise a 1-D object array of int labels.
        The array is empty when no label passes the filter.
    '''
    flat = np.asarray(data).ravel()
    total = flat.size

    # Count every label in a single pass instead of rescanning per label.
    unique_labels, counts = np.unique(flat, return_counts=True)

    labels = []
    for label, count in zip(unique_labels, counts):
        percentage = round(int(count) / total * 100, 2)
        if percentage > min_percentage:
            labels.append([label, percentage])

    # Descending by percentage; equal percentages are ordered by descending
    # label, matching the previous stable-sort-then-reverse behaviour.
    labels.sort(key=lambda entry: (entry[1], entry[0]), reverse=True)

    if not with_percentage:
        # A plain comprehension also works for an empty result, where slicing
        # a 1-D empty array with [:, 0] used to raise IndexError.
        labels = [int(label) for label, _ in labels]

    return np.array(labels, dtype=object)

In [23]:
def get_label_definition(labels, dictionary):
    '''
    Append the dictionary row of each label to the label entry.

    args:
        labels: iterable of either scalar label indices or rows whose first
            element is the label index (eg [index, percentage]).
        dictionary: indexable by label index; each row holds the fields to
            append.
    returns:
        color index,
        percentage of color in the image (when present in `labels`),
        color code in RGB,
        color code in HEX,
        label name
    eg: [5, 13.32, '(120, 120, 80)', '#787850', 'ceiling']
    '''
    definitions = []
    for label in labels:
        # Scalars (Python ints as well as NumPy integer scalars) are the index
        # themselves; for row-like entries the index is the first element.
        index = label if np.ndim(label) == 0 else label[0]
        definitions.append(np.append(label, dictionary[index]))
    return np.array(definitions)

In [24]:
def visualize_result(data, pred, args):
    """Save the image and its colour-coded prediction side by side.

    Args:
        data: tuple ``(img, info)`` -- the image array and its file path.
        pred: predicted label map, colour-coded via ``colorEncode`` with the
            module-level ``colors`` table.
        args: object whose ``result`` attribute is the output directory.

    Returns:
        The concatenated visualization array that was written to disk.

    Raises:
        IOError: if OpenCV reports that the file could not be written.
    """
    (img, info) = data

    # prediction
    pred_color = colorEncode(pred, colors).astype(np.uint8)

    # aggregate images (joined along axis 1) and save
    im_vis = np.concatenate((img, pred_color), axis=1)

    # cv2.imwrite does not create missing directories and only reports the
    # failure through its return value, so create the folder up front.
    if args.result:
        os.makedirs(args.result, exist_ok=True)

    img_name = os.path.basename(info)
    out_path = os.path.join(args.result, img_name.replace('.jpg', '.png'))
    if not cv2.imwrite(out_path, im_vis):
        raise IOError('Failed to write visualization to {}'.format(out_path))
    return im_vis

In [25]:
def test(segmentation_module, loader, args):
    """Run multi-scale inference on the first batch produced by ``loader``.

    The softmax scores of every resized version of the image are averaged
    (each weighted by ``1 / len(args.imgSize)``) and the arg-max class per
    pixel is taken.

    NOTE(review): only the FIRST batch is processed because the function
    returns from inside the loop; callers unpack a single ``(pred, batch)``
    pair, so this behaviour is kept.

    Args:
        segmentation_module: model called as ``model(feed_dict, segSize=...)``.
        loader: iterable of batches; element 0 of each batch is the sample dict.
        args: options providing ``num_class`` and ``imgSize``.

    Returns:
        ``(pred, batch_data)`` -- the predicted label map as a NumPy array and
        the sample dict it was computed from; ``None`` if the loader is empty.
    """
    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    try:
        for batch_data in loader:
            # process data
            batch_data = batch_data[0]
            segSize = (batch_data['img_ori'].shape[0],
                       batch_data['img_ori'].shape[1])
            img_resized_list = batch_data['img_data']

            with torch.no_grad():
                scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
                scores = async_copy_to(scores, device)

                for img in img_resized_list:
                    feed_dict = batch_data.copy()
                    feed_dict['img_data'] = img
                    # Drop the non-tensor entries before moving to the device.
                    del feed_dict['img_ori']
                    del feed_dict['info']
                    feed_dict = async_copy_to(feed_dict, device)

                    # forward pass
                    pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                    scores = scores + pred_tmp / len(args.imgSize)

                _, pred = torch.max(scores, dim=1)
                pred = as_numpy(pred.squeeze(0).cpu())

            # The bar was previously never advanced nor closed.
            pbar.update(1)
            return pred, batch_data
    finally:
        pbar.close()

    # Empty loader: nothing to predict.
    return None


In [26]:
# Build the sample records from the configured image paths rather than
# repeating the path literal, so changing `args.test_imgs` is enough.
list_test = [{'fpath_img': img_path} for img_path in args.test_imgs]
dataset_test = TestDataset(list_test, args, max_sample=-1)
loader_test = torchdata.DataLoader(
    dataset_test,
    batch_size=args.batch_size,
    shuffle=False,
    collate_fn=user_scattered_collate,
    num_workers=5,
    drop_last=True
)
# `test` returns the prediction for the first batch only.
pred, batch_data = test(segmentation_module, loader_test, args)

  0%|                                                                                                                                                         | 0/1 [00:00<?, ?it/s]

# samples: 1


In [27]:
# Save the original image next to its colour-coded prediction.
img = visualize_result(
    data=(batch_data['img_ori'], batch_data['info']),
    pred=pred,
    args=args,
)

In [28]:
# Class metadata table; the three selected columns form the per-class lookup
# used by get_label_definition.
object150_df = pd.read_csv('./data/object150_info.csv')
labels_dictionary = object150_df.loc[
    :, ['Color_Code (R,G,B)', 'Color_Code(hex)', 'Name']
].values
object150_df.head()

Unnamed: 0,Idx,Ratio,Train,Val,Stuff,"Color_Code (R,G,B)",Color_Code(hex),Color,Name
0,1,0.1576,11664,1172,1,"(120, 120, 120)",#787878,,wall
1,2,0.1072,6046,612,1,"(180, 120, 120)",#B47878,,building;edifice
2,3,0.0878,8265,796,1,"(6, 230, 230)",#06E6E6,,sky
3,4,0.0621,9336,917,1,"(80, 50, 50)",#503232,,floor;flooring
4,5,0.048,6678,641,0,"(4, 200, 3)",#04C803,,tree


In [29]:
# Labels present in the prediction with their share of the image, largest first.
labels = get_labels(data=pred, with_percentage=True, min_percentage=0)
labels

array([[0, 18.69],
       [5, 13.32],
       [23, 11.96],
       [14, 9.59],
       [24, 9.1],
       [3, 8.66],
       [30, 7.06],
       [49, 5.02],
       [64, 3.74],
       [15, 2.65],
       [39, 2.34],
       [27, 1.39],
       [18, 1.31],
       [66, 1.15],
       [135, 0.94],
       [19, 0.75],
       [36, 0.68],
       [10, 0.54],
       [22, 0.52],
       [67, 0.17],
       [142, 0.14],
       [82, 0.13],
       [148, 0.09],
       [97, 0.04]], dtype=object)

In [30]:
# Attach colour codes and class names to every detected label.
get_label_definition(labels=labels, dictionary=labels_dictionary)

array([[0, 18.69, '(120, 120, 120)', '#787878', 'wall'],
       [5, 13.32, '(120, 120, 80)', '#787850', 'ceiling'],
       [23, 11.96, '(11, 102, 255)', '#0B66FF', 'sofa;couch;lounge'],
       [14, 9.59, '(8, 255, 51)', '#08FF33', 'door;double;door'],
       [24, 9.1, '(255, 7, 71)', '#FF0747', 'shelf'],
       [3, 8.66, '(80, 50, 50)', '#503232', 'floor;flooring'],
       [30, 7.06, '(8, 255, 214)', '#08FFD6', 'armchair'],
       [49, 5.02, '(0250, 10, 15)', '#FA0A0F',
        'fireplace;hearth;open;fireplace'],
       [64, 3.74, '(0, 255, 112)', '#00FF70',
        'coffee;table;cocktail;table'],
       [15, 2.65, '(255, 6, 82)', '#FF0652', 'table'],
       [39, 2.34, '(255, 194, 7)', '#FFC207', 'cushion'],
       [27, 1.39, '(220, 220, 220)', '#DCDCDC', 'mirror'],
       [18, 1.31, '(255, 51, 7)', '#FF3307',
        'curtain;drape;drapery;mantle;pall'],
       [66, 1.15, '(255, 0, 0)', '#FF0000', 'flower'],
       [135, 0.94, '(0, 255, 204)', '#00FFCC', 'vase'],
       [19, 0.75, '(2