In [None]:
# python native
import os

# external library
import cv2
import numpy as np
import pandas as pd

# visualization
import matplotlib.pyplot as plt


# Path to the submission CSV whose predicted masks are visualized below.
csv_path = '/data/ephemeral/home/submissions/tta_aug.csv'

# Number of images to display, plus specific file names that must be included.
# If image_n is larger than the number of matching file names, the remaining
# slots are filled with randomly chosen images.
image_n = 3
file_name = [
    'image1661389291522.png',
    'image1663724556918.png',
    'image1664154289655.png',
]

# Load the submission CSV into a DataFrame.
df = pd.read_csv(csv_path)

# Root directory searched (recursively) for the test PNG images.
IMAGE_ROOT = "/data/ephemeral/home/data/test/DCM"

# Class names: 19 'finger-N' labels followed by 10 named labels (29 total).
CLASSES = [
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5', 'finger-6',
    'finger-7', 'finger-8', 'finger-9', 'finger-10', 'finger-11', 'finger-12',
    'finger-13', 'finger-14', 'finger-15', 'finger-16', 'finger-17',
    'finger-18', 'finger-19',
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid',
    'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]


def _collect_pngs(root):
    """Return the set of ``.png`` files under *root*, as paths relative to *root*.

    The extension check is case-insensitive.
    """
    found = set()
    for dirpath, _subdirs, filenames in os.walk(root):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() == ".png":
                found.add(os.path.relpath(os.path.join(dirpath, filename), start=root))
    return found


pngs = _collect_pngs(IMAGE_ROOT)

def encode_mask_to_rle(mask):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order and encoded as space-separated
    ``start length`` pairs, where ``start`` is a 1-based index into the
    flattened mask.

    Args:
        mask: numpy array expected to hold only 1 (mask) and 0 (background);
            runs are delimited by changes in value.

    Returns:
        str: the encoded runs, or ``""`` when the mask has no 1-pixels.
    """
    flat = mask.flatten()
    # Zero-pad both ends so every run has a rising edge and a falling edge.
    padded = np.concatenate([[0], flat, [0]])
    # 1-based positions (into the unpadded mask) where the value changes.
    edges = np.flatnonzero(padded[:-1] != padded[1:]) + 1
    # Convert every second edge (a run end) into a run length, in place.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded string into a binary mask.

    Inverse of ``encode_mask_to_rle``: ``rle`` is a space-separated list of
    ``start length`` pairs whose starts are 1-based indices into the
    row-major flattened mask.

    Args:
        rle: the encoded string. A non-string value (e.g. the ``NaN`` pandas
            produces for an empty CSV field) or a blank string is treated as
            an empty mask.
        height: number of rows in the output mask.
        width: number of columns in the output mask.

    Returns:
        np.ndarray of shape ``(height, width)`` and dtype ``uint8`` holding 1
        on encoded pixels and 0 elsewhere.

    Raises:
        ValueError: if ``rle`` has an odd number of tokens or a token is not
            an integer.
    """
    mask = np.zeros(height * width, dtype=np.uint8)

    # Empty predictions are written as "" and read back by pandas as NaN;
    # both mean "no pixels", not an error.
    if not isinstance(rle, str) or not rle.strip():
        return mask.reshape(height, width)

    tokens = rle.split()
    if len(tokens) % 2:
        raise ValueError(f"RLE must contain start/length pairs, got {len(tokens)} tokens")

    values = np.array([int(tok) for tok in tokens], dtype=int)
    starts = values[0::2] - 1  # convert 1-based starts to 0-based indices
    ends = starts + values[1::2]

    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1

    return mask.reshape(height, width)

# Each image occupies len(CLASSES) consecutive rows of the submission CSV.
# NOTE(review): this assumes an image's rows are contiguous and in CLASSES
# order (PALETTE colors are assigned by row position) -- confirm against the
# code that wrote the CSV.
num_classes = len(CLASSES)
if len(df) % num_classes != 0:
    raise ValueError(
        f"CSV has {len(df)} rows, which is not a multiple of {num_classes} classes"
    )
num_images = len(df) // num_classes

# Resolution at which the predicted masks are decoded.
MASK_HEIGHT = 2048
MASK_WIDTH = 2048

# Index the PNGs by bare file name once so each lookup is exact and O(1)
# (a substring scan over every path could match the wrong file).
png_by_name = {os.path.basename(p): p for p in pngs}


def _resolve_image_path(image_name):
    """Return the full path of the PNG matching ``image_name``.

    ``image_name`` is tried first as a path relative to IMAGE_ROOT, then as a
    bare file name anywhere under IMAGE_ROOT.

    Raises:
        FileNotFoundError: if no PNG under IMAGE_ROOT matches.
    """
    if image_name in pngs:
        rel_path = image_name
    else:
        rel_path = png_by_name.get(os.path.basename(image_name))
    if rel_path is None:
        raise FileNotFoundError(f"No PNG under {IMAGE_ROOT} matches {image_name!r}")
    return os.path.join(IMAGE_ROOT, rel_path)


# Collect the image indices (n) whose file name was explicitly requested.
wanted = set(file_name)
image_list = [
    n for n in range(num_images)
    if df.iloc[n * num_classes]['image_name'] in wanted
]

# Fill the remaining slots with distinct random images. Capping at num_images
# prevents an infinite loop when image_n exceeds the number of images.
target_count = min(image_n, num_images)
while len(image_list) < target_count:
    random_number = np.random.randint(0, num_images)
    if random_number not in image_list:
        image_list.append(random_number)

print('Image ID: ', image_list)

# One RGB color per mask channel; channel i is drawn with PALETTE[i].
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]


def label2rgb(label):
    """Render stacked binary masks as one RGB image.

    Args:
        label: array of shape (C, H, W); channel i is a binary mask drawn in
            PALETTE[i], so C must not exceed len(PALETTE).

    Returns:
        uint8 array of shape (H, W, 3). Overlaps are not blended: where masks
        overlap, the later channel's color wins.
    """
    image_size = label.shape[1:] + (3, )
    rgb = np.zeros(image_size, dtype=np.uint8)
    for i, class_label in enumerate(label):
        rgb[class_label == 1] = PALETTE[i]
    return rgb


for n in image_list:
    first_row = n * num_classes
    image_name = df.iloc[first_row]['image_name']
    image_path = _resolve_image_path(image_name)

    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"cv2.imread could not read {image_path}")
    # cv2 loads color images as BGR; matplotlib's imshow expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Decode this image's per-class masks into a (num_classes, H, W) stack.
    preds = np.stack(
        [
            decode_rle_to_mask(df.iloc[i]['rle'], height=MASK_HEIGHT, width=MASK_WIDTH)
            for i in range(first_row, first_row + num_classes)
        ],
        0,
    )

    # Left: the input image; right: the colorized predicted masks.
    fig, ax = plt.subplots(1, 2, figsize=(24, 12))
    ax[0].imshow(image)
    ax[1].imshow(label2rgb(preds))

    plt.show()