In [19]:
import os
from os import path as osp
from ipywidgets import interact, Dropdown
import matplotlib.pyplot as plt
import cv2
import numpy as np
import matplotlib.patches as mpatches

In [2]:
# Class names offered in the dropdown. Index 0 ('All') is the "no filter"
# choice: show_img only filters the mask when the selected index is non-zero.
CLASSES = [
    'All',
    'General trash',
    'Paper',
    'Paper pack',
    'Metal',
    'Glass',
    'Plastic',
    'Styrofoam',
    'Plastic bag',
    'Battery',
    'Clothing',
]

# RGB colour of each class, aligned index-for-index with CLASSES.
# show_img reverses each entry to BGR before matching it against the
# OpenCV-loaded mask. The name keeps its original spelling because other
# cells refer to it.
PALLETE = [
    [0, 0, 0],        # All
    [128, 0, 0],      # General trash
    [0, 128, 0],      # Paper
    [128, 128, 0],    # Paper pack
    [0, 0, 128],      # Metal
    [128, 0, 128],    # Glass
    [0, 128, 128],    # Plastic
    [128, 128, 128],  # Styrofoam
    [64, 0, 0],       # Plastic bag
    [192, 0, 0],      # Battery
    [64, 128, 0],     # Clothing
]

In [6]:
# Dataset root and batch to inspect. Both default to the original container
# layout but can be overridden through environment variables, so the notebook
# runs elsewhere without editing this cell.
base_dir = os.environ.get('COPY_PASTE_DIR', '/opt/ml/input/data/copy_paste')
batch_dir = os.environ.get('COPY_PASTE_BATCH', 'batch_02_vt')
img_dir = osp.join(base_dir, batch_dir, 'images')
anno_dir = osp.join(base_dir, batch_dir, 'annotations')

In [7]:
def _visible_files(directory):
    """Return the sorted names of the non-hidden entries in ``directory``."""
    # Dot-entries (e.g. .ipynb_checkpoints, .DS_Store) are skipped: an extra
    # entry in only one folder would shift every later image/mask pair.
    return sorted(name for name in os.listdir(directory) if not name.startswith('.'))


img_files = _visible_files(img_dir)
anno_files = _visible_files(anno_dir)

# Images and masks are paired purely by position in these sorted lists,
# so the two folders must contain the same number of files.
assert len(img_files) == len(anno_files), (
    f'{len(img_files)} files in {img_dir} vs {len(anno_files)} in {anno_dir}'
)

In [47]:
def _read_rgb(path):
    """Read an image with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise
            surface later as an obscure ``cvtColor`` error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _keep_color(mask_rgb, rgb):
    """Return a copy of ``mask_rgb`` where every pixel not exactly equal to ``rgb`` is black."""
    keep = np.all(mask_rgb == np.asarray(rgb, dtype=mask_rgb.dtype), axis=-1)
    filtered = mask_rgb.copy()
    filtered[~keep] = 0
    return filtered


def _legend_handles():
    """Build one legend patch per class, coloured from PALLETE (0-255 scaled to 0-1)."""
    return [
        mpatches.Patch(color=tuple(channel / 255 for channel in rgb), label=name)
        for name, rgb in zip(CLASSES, PALLETE)
    ]


@interact(
    # Tuple bounds are inclusive, so the last valid index is len - 1.
    idx=(0, len(img_files) - 1),
    alpha=(0, 1, 0.1),
    option=Dropdown(
        options=CLASSES,
        value=CLASSES[0],
        description='Class:',
        disabled=False,
    ))
def show_img(idx=0, alpha=0.5, option=None):
    """Display an image, its annotation mask, and the mask overlaid on the image.

    The image and mask are paired by position in ``img_files`` / ``anno_files``.

    Args:
        idx: Position of the image/annotation pair to show.
        alpha: Opacity of the mask drawn over the image in the 'Mix' panel.
        option: Class name from CLASSES. 'All' (or None) shows the mask
            unchanged. Any other class keeps only the pixels whose colour
            exactly equals that class's PALLETE entry and blacks out the rest.
    """
    img = _read_rgb(osp.join(img_dir, img_files[idx]))
    mask = _read_rgb(osp.join(anno_dir, anno_files[idx]))

    # Index 0 ('All') means "no filter"; treat a missing option the same way.
    class_idx = 0 if option is None else CLASSES.index(option)
    if class_idx:
        mask = _keep_color(mask, PALLETE[class_idx])

    fig = plt.figure(figsize=(11, 11))
    fig.suptitle(batch_dir)

    ax_img = fig.add_subplot(2, 2, 1)
    ax_img.set_title(img_files[idx])
    ax_img.imshow(img)

    ax_mask = fig.add_subplot(2, 2, 2)
    ax_mask.set_title(anno_files[idx])
    ax_mask.imshow(mask)
    # Legend sits below the mask panel so it does not cover the image.
    ax_mask.legend(handles=_legend_handles(), loc=(0, -0.7))

    ax_mix = fig.add_subplot(2, 2, 3)
    ax_mix.set_title('Mix')
    ax_mix.imshow(img)
    ax_mix.imshow(mask, alpha=alpha)

    plt.show()

interactive(children=(IntSlider(value=0, description='idx', max=1561), FloatSlider(value=0.5, description='alp…