In [None]:
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
# ============ TODO: configure the run to visualise ============
exp = '[model]deeplabv3_res101_V1'  # experiment folder name under ./predictions
read_csv_file = os.path.join('./predictions', exp, 'submission.csv')
IMAGE_ROOT = "/opt/ml/input/data/test/DCM"  # show() searches this tree recursively for images
# ==============================================================
df = pd.read_csv(read_csv_file)
df.head()

In [None]:
# One RGB colour per class; PALETTE[k] is the colour used for CLASSES[k]
# (label2rgb and the scatter plots in show() index PALETTE by class id).
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

# Class names in the order their rows appear per image in the submission CSV
# (show() slices len(CLASSES) consecutive rows per image).
CLASSES = [
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
    'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
    'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
    'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
    'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

# Guard against the two tables drifting apart when classes are added/removed.
assert len(PALETTE) == len(CLASSES), "PALETTE must have exactly one colour per class"

def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded string into a binary mask.

    Args:
        rle: space-separated "start length start length ..." string with
            1-based start positions into the row-major flattened mask.
            ``None``, NaN (what ``pd.read_csv`` yields for an empty cell) and
            blank strings are treated as an empty mask.
        height: number of rows of the output mask.
        width: number of columns of the output mask.

    Returns:
        ``(mask, points)``. ``mask`` is a ``(height, width)`` uint8 array with
        1 on every encoded pixel. ``points`` is ``[rows, cols]``: two arrays
        holding the row and column of every run's start pixel followed by
        every run's (exclusive) end pixel. ``points`` is ``[]`` for an empty
        mask.
    """
    # `rle == None` misses the float NaN pandas uses for empty cells.
    if pd.isna(rle) or not str(rle).strip():
        return np.zeros((height, width), dtype=np.uint8), []

    tokens = rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    # Row/column of each flat index; use the real width, not a fixed 2048.
    rows = np.concatenate((starts // width, ends // width))
    cols = np.concatenate((starts % width, ends % width))

    mask = np.zeros(height * width, dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1
    return mask.reshape(height, width), [rows, cols]
    
def label2rgb(label):
    """Paint a stack of binary masks into one RGB image using PALETTE.

    Args:
        label: array whose first axis indexes masks (as built in show(),
            presumably ``(num_masks, H, W)``). Mask ``k`` is painted with
            ``PALETTE[k]`` wherever it equals 1; later masks overwrite earlier
            ones where they overlap.

    Returns:
        uint8 array of shape ``label.shape[1:] + (3,)``.
    """
    spatial_shape = label.shape[1:]
    rgb = np.zeros((*spatial_shape, 3), dtype=np.uint8)
    for k in range(len(label)):
        rgb[label[k] == 1] = PALETTE[k]
    return rgb
    
def _find_image(image_name):
    """Return the path of PNG ``image_name`` relative to IMAGE_ROOT (searched recursively)."""
    if os.path.splitext(image_name)[1].lower() == ".png":
        for root, _dirs, files in os.walk(IMAGE_ROOT):
            if image_name in files:
                return os.path.relpath(os.path.join(root, image_name), IMAGE_ROOT)
    raise FileNotFoundError(f"{image_name!r} not found under {IMAGE_ROOT}")


def _load_prediction(row_start, height, width):
    """Load one image plus its ``len(CLASSES)`` RLE rows of ``df`` starting at ``row_start``.

    Returns:
        ``(image_file, image, preds, points)``:
        ``image_file`` is the path relative to IMAGE_ROOT; ``image`` is the RGB
        image; ``preds`` is a ``(len(CLASSES), height, width)`` uint8 stack
        indexed by class id (classes without a prediction stay all-zero, so
        label2rgb colours match the class); ``points`` is a list of
        ``(class_id, [rows, cols])`` for the non-empty classes.
    """
    rows = df.iloc[row_start:row_start + len(CLASSES)]
    if rows.empty:
        raise IndexError(f"row {row_start} is past the end of the submission ({len(df)} rows)")

    image_file = _find_image(rows['image_name'].iloc[0])
    image = cv2.imread(os.path.join(IMAGE_ROOT, image_file))
    if image is None:
        raise FileNotFoundError(f"cv2 could not read {image_file!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; matplotlib expects RGB

    preds = np.zeros((len(CLASSES), height, width), dtype=np.uint8)
    points = []
    for i, rle in enumerate(rows['rle']):
        # pandas stores empty cells as NaN; pass None so decoding yields an empty mask.
        pred, point = decode_rle_to_mask(rle if isinstance(rle, str) else None,
                                         height=height, width=width)
        if np.max(pred) != 0:
            preds[i] = pred
            points.append((i, point))
    return image_file, image, preds, points


def show(idx, height=2048, width=2048):
    """Plot the predictions for the ``idx``-th pair of test images.

    Assumes ``df`` holds ``len(CLASSES)`` consecutive rows per image in
    ``CLASSES`` order, so image ``k`` occupies rows
    ``[k * len(CLASSES), (k + 1) * len(CLASSES))``. Images ``2 * idx`` and
    ``2 * idx + 1`` are drawn side by side: the top row shows each image with
    the run start/end points of every class scattered on it, the bottom row
    the colour-coded masks.

    Args:
        idx: index of the image pair to show.
        height, width: size the RLE masks are decoded to (default 2048x2048).

    Raises:
        FileNotFoundError: an image named in ``df`` is missing or unreadable.
        IndexError: ``idx`` is beyond the rows of ``df``.
    """
    num_classes = len(CLASSES)
    samples = [_load_prediction((idx * 2 + k) * num_classes, height, width) for k in range(2)]

    fig, ax = plt.subplots(2, 2, figsize=(24, 24), sharex=True, sharey=True)
    for col, (image_file, image, preds, points) in enumerate(samples):
        ax[0][col].set_title(image_file, fontsize=30)
        ax[0][col].imshow(image)
        for i, (x, y) in points:
            # x holds row indices and y column indices, hence scatter(y, x).
            ax[0][col].scatter(y, x, s=0.1, c=[np.array(PALETTE[i]) / 255])
        ax[1][col].imshow(label2rgb(preds))

    fig.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    plt.close(fig)  # free the large figure when called in a loop

In [None]:
# TODO: pick the image-pair indices to inspect (valid range 0~149).
for pair_idx in range(110, 115):
    show(pair_idx)