# Output Visualization

In [1]:
import os
import json
import cv2
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

In [2]:
# Input locations:
#   IMAGE_ROOT      - test image directory; each subdirectory holds image files (see the image_dict cell)
#   OUTPUT_CSV_ROOT - prediction CSV with one RLE string per (image_name, class) row
IMAGE_ROOT, OUTPUT_CSV_ROOT = (
    '/data/ephemeral/home/data/test/DCM/',
    '../swin_unet_output_1.csv',
)

In [3]:
# Segmentation class names. The list length sets how many consecutive CSV rows
# make up one image in the visualization loop; rows are presumably in this order.
CLASSES = [
    # finger bones
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
    'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
    'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
    'finger-16', 'finger-17', 'finger-18', 'finger-19',
] + [
    # wrist and forearm bones
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid',
    'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

In [4]:
# Per-class display colors, as (R, G, B) tuples since label2rgb builds an RGB image.
# Entry i colors mask channel i in label2rgb; the trailing indices assume the
# channels follow CLASSES order.
PALETTE = [
    (220, 20, 60),    # 0  finger-1
    (119, 11, 32),    # 1  finger-2
    (0, 0, 142),      # 2  finger-3
    (0, 0, 230),      # 3  finger-4
    (106, 0, 228),    # 4  finger-5
    (0, 60, 100),     # 5  finger-6
    (0, 80, 100),     # 6  finger-7
    (0, 0, 70),       # 7  finger-8
    (0, 0, 192),      # 8  finger-9
    (250, 170, 30),   # 9  finger-10
    (100, 170, 30),   # 10 finger-11
    (220, 220, 0),    # 11 finger-12
    (175, 116, 175),  # 12 finger-13
    (250, 0, 30),     # 13 finger-14
    (165, 42, 42),    # 14 finger-15
    (255, 77, 255),   # 15 finger-16
    (0, 226, 252),    # 16 finger-17
    (182, 182, 255),  # 17 finger-18
    (0, 82, 0),       # 18 finger-19
    (120, 166, 157),  # 19 Trapezium
    (110, 76, 0),     # 20 Trapezoid
    (174, 57, 255),   # 21 Capitate
    (199, 100, 0),    # 22 Hamate
    (72, 0, 118),     # 23 Scaphoid
    (255, 179, 240),  # 24 Lunate
    (0, 125, 92),     # 25 Triquetrum
    (209, 0, 151),    # 26 Pisiform
    (188, 208, 182),  # 27 Radius
    (0, 220, 176),    # 28 Ulna
]

# Utility: paint a stack of binary class masks into a single color image.
# Overlaps are not resolved: where masks overlap, the later channel's color wins.
def label2rgb(label, palette=None):
    """Colorize a stack of per-class binary masks.

    Args:
        label: Array indexed as label[i] -> 2-D mask for class channel i,
            i.e. shape (num_classes, H, W). Pixels equal to 1 in channel i
            are painted with palette[i]; any other value is left untouched.
        palette: Sequence of 3-component colors, at least one per channel.
            Defaults to the module-level PALETTE (given in RGB order).

    Returns:
        np.ndarray of shape (H, W, 3) and dtype uint8. Pixels covered by no
        mask stay (0, 0, 0). Components are written in palette order, so the
        result is RGB when the palette is RGB.

    Raises:
        ValueError: if label has more channels than the palette has colors.
    """
    if palette is None:
        palette = PALETTE
    if len(label) > len(palette):
        raise ValueError(
            f"label has {len(label)} channels but palette only has {len(palette)} colors"
        )

    image = np.zeros(label.shape[1:] + (3,), dtype=np.uint8)
    for i, class_label in enumerate(label):
        # later channels overwrite earlier ones where masks overlap
        image[class_label == 1] = palette[i]
    return image

# Restore a binary mask map from its run-length-encoded (RLE) result.
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded string into a binary mask.

    Args:
        rle: Whitespace-separated "start length start length ..." pairs.
            Starts are 1-based offsets into the flattened row-major mask.
            An empty/blank string, None, or any non-string value (e.g. the
            float NaN pandas produces for an empty CSV cell) means "no pixels".
        height: Number of rows of the output mask.
        width: Number of columns of the output mask.

    Returns:
        np.ndarray of shape (height, width) and dtype uint8: 1 inside the
        encoded runs, 0 elsewhere.
    """
    img = np.zeros(height * width, dtype=np.uint8)

    # An empty prediction is read back from CSV as NaN (a float), which has no
    # .split(); treat it, None and "" as an all-background mask.
    if not isinstance(rle, str) or not rle.strip():
        return img.reshape(height, width)

    s = rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # convert 1-based starts to 0-based
    lengths = np.asarray(s[1::2], dtype=int)
    ends = starts + lengths

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1

    return img.reshape(height, width)

In [5]:
# Store image paths in a dictionary: file name -> full path.
# IMAGE_ROOT is expected to hold one level of subdirectories, each containing image files.
# Directories and files are visited in sorted order, so if two subdirectories share a
# file name, the path from the later subdirectory wins.
# NOTE(review): image_dict is not referenced by the visualization cell below.
image_dict = {
    fname: os.path.join(IMAGE_ROOT, subdir, fname)
    for subdir in sorted(os.listdir(IMAGE_ROOT))
    for fname in sorted(os.listdir(os.path.join(IMAGE_ROOT, subdir)))
}
    

In [6]:
# Load the predictions; each row carries an 'image_name' and an RLE mask string in 'rle'.
# Empty 'rle' cells are read back as NaN.
df = pd.read_csv(filepath_or_buffer=OUTPUT_CSV_ROOT)

In [7]:
# Decode each image's per-class RLE masks, colorize them and save one image file per input.
# Assumes the CSV lists len(CLASSES) consecutive rows per image, presumably in CLASSES order.
MASK_HEIGHT, MASK_WIDTH = 2048, 2048  # size the RLE strings are decoded to
save_dir = './visualization'
os.makedirs(save_dir, exist_ok=True)  # create once, before the loop

num_classes = len(CLASSES)
num_images, leftover = divmod(len(df), num_classes)
assert leftover == 0, f"expected a multiple of {num_classes} rows, got {len(df)}"

for i in range(num_images):
    rows = df.iloc[i * num_classes:(i + 1) * num_classes]
    names = rows['image_name'].unique()
    # Guard the "consecutive rows per image" assumption instead of silently mixing images.
    assert len(names) == 1, (
        f"rows {i * num_classes}..{(i + 1) * num_classes - 1} span several images: {list(names)}"
    )
    file_name = names[0]

    preds = np.stack(
        [decode_rle_to_mask(rle, height=MASK_HEIGHT, width=MASK_WIDTH) for rle in rows['rle']],
        0,
    )
    pred_img = label2rgb(preds)

    # label2rgb builds an RGB image, but cv2.imwrite treats 3-channel data as BGR;
    # convert first so the saved colors match PALETTE.
    out_path = os.path.join(save_dir, file_name)
    if not cv2.imwrite(out_path, cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)):
        # cv2.imwrite signals failure through its return value instead of raising
        raise IOError(f"cv2.imwrite failed for {out_path}")