In [None]:
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
import torch
# ================== TODO ==================
# Dataset locations. Edit the defaults below, or set the IMAGE_ROOT /
# LABEL_ROOT environment variables to override them without editing the notebook.
IMAGE_ROOT = os.environ.get("IMAGE_ROOT", "/opt/ml/input/data/train/DCM")
LABEL_ROOT = os.environ.get("LABEL_ROOT", "/opt/ml/input/data/train/outputs_json")
# ==========================================

# One RGB colour per entry of CLASSES, in the same order (indexed by class index).
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

# Annotation label names; a label's position in this list is its mask channel index.
CLASSES = [
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
    'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
    'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
    'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
    'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

# Colours are looked up as PALETTE[class_index], so every class needs one;
# fail here instead of with an IndexError in the middle of plotting.
assert len(PALETTE) >= len(CLASSES), "PALETTE must provide a colour for every class"

# Every file under IMAGE_ROOT whose extension is ".png" (case-insensitive),
# stored relative to IMAGE_ROOT, de-duplicated and sorted so that a given
# index always refers to the same file across runs.
pngs = sorted(set(
    os.path.relpath(os.path.join(dirpath, filename), start=IMAGE_ROOT)
    for dirpath, _subdirs, filenames in os.walk(IMAGE_ROOT)
    for filename in filenames
    if os.path.splitext(filename)[-1].lower() == ".png"
))

def label2rgb(label):
    image_size = label.shape[1:] + (3, )
    image = np.zeros(image_size, dtype=np.uint8)
    for i, class_label in enumerate(label):
        image[class_label == 1] = PALETTE[i]
    return image
    
def _load_annotated_image(image_name):
    """Read one image and rasterise its polygon annotations.

    Args:
        image_name: path of the PNG relative to IMAGE_ROOT; the annotation file
            is the same relative path under LABEL_ROOT with a ".json" extension.

    Returns:
        (image, label, points):
          image  - RGB uint8 array of shape (H, W, 3).
          label  - uint8 array of shape (len(CLASSES), H, W); channel c is 1
                   inside the polygon annotated with CLASSES[c].
          points - list of (class_index, vertices) with vertices shaped (N, 2)
                   as (x, y) pairs, in annotation-file order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        ValueError: if an annotation label is not in CLASSES.
    """
    image_path = os.path.join(IMAGE_ROOT, image_name)
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads channels as BGR; matplotlib's imshow expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # splitext (not split('.')) so dots elsewhere in the path are preserved.
    label_path = os.path.join(LABEL_ROOT, os.path.splitext(image_name)[0] + '.json')
    with open(label_path, "r") as f:
        annotations = json.load(f)["annotations"]

    height, width = image.shape[:2]
    label = np.zeros((len(CLASSES), height, width), dtype=np.uint8)
    points = []
    for ann in annotations:
        class_ind = CLASSES.index(ann["label"])
        vertices = np.asarray(ann["points"], dtype=np.float64).reshape(-1, 2)
        points.append((class_ind, vertices))
        if len(vertices) == 0:
            continue
        # cv2.fillPoly requires int32 vertex coordinates.
        polygon = np.round(vertices).astype(np.int32)
        class_mask = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(class_mask, [polygon], 1)
        # Assignment (not OR): a later polygon of the same class replaces the earlier one.
        label[class_ind] = class_mask
    return image, label, points


def show(idx, show_points=True):
    """Plot a pair of consecutive images from ``pngs`` with their annotations.

    Columns show pngs[2*idx] and pngs[2*idx + 1]. The top row is the image,
    optionally overlaid with the annotation polygon vertices; the bottom row
    is the rasterised mask rendered with label2rgb.

    Args:
        idx: pair index (negative values index from the end, like a list).
        show_points: draw polygon vertices on the top row; set False to
            view only the masks.

    Raises:
        IndexError: if either image index is out of range for ``pngs``.
        FileNotFoundError: if an image cannot be read.
    """
    # Resolve both names before creating the figure so a bad index does not
    # leave an empty figure behind.
    image_names = [pngs[2 * idx], pngs[2 * idx + 1]]

    fig, ax = plt.subplots(2, 2, figsize=(24, 24), sharex=True, sharey=True)
    for col, image_name in enumerate(image_names):
        print(image_name+' : Annotation 생성중 (points가 많아서 시간이 걸립니다)')
        image, label, points = _load_annotated_image(image_name)

        ax[0][col].set_title(image_name, fontsize=30)
        ax[0][col].imshow(image)
        if show_points:
            for class_ind, vertices in points:
                # One scatter call per polygon rather than one per vertex.
                ax[0][col].scatter(
                    vertices[:, 0], vertices[:, 1], s=1,
                    c=[np.array(PALETTE[class_ind]) / 255],
                )
        ax[1][col].imshow(label2rgb(label))

    fig.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    plt.close(fig)  # release the figure when called repeatedly in a loop

In [None]:
# TODO: number of image pairs to preview. Valid pair indices are
# 0 .. len(pngs) // 2 - 1 (0~399 for the full training set).
NUM_PAIRS_TO_SHOW = 10
# Clamp to the pairs that actually exist so a small dataset does not raise IndexError.
for pair_idx in range(min(NUM_PAIRS_TO_SHOW, len(pngs) // 2)):
    show(pair_idx)