In [1]:
import sys
sys.path.append('../../')

import pandas as pd 
from glob import glob
import cv2
from inference import decode_rle_to_mask
from dataset import get_xray_classes
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import utils
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 시각화를 위한 팔레트를 설정합니다.
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

# 시각화 함수입니다. 클래스가 2개 이상인 픽셀을 고려하지는 않습니다.
def label2rgb(label):
    image_size = label.shape[1:] + (3, )
    image = np.zeros(image_size, dtype=np.uint8)
    
    for i, class_label in enumerate(label):
        image[class_label == 1] = PALETTE[i]
        
    return image


def seach_matching_paths(image_name, paths):
    matching_paths = [path for path in paths if image_name in path]
    return matching_paths


def get_preds(image_name, df, classes):
    index = df['image_name'] == image_name
    preds = [[] for _ in range(classes['num_class'])] # 행렬 초기화

    for _, row in df[index].iterrows():
        c, rle = row['class'], row['rle']
        idx = classes['class2idx'][c]
        pred = decode_rle_to_mask(rle, height=2048, width=2048)
        preds[idx] = pred

    preds = np.stack(preds, 0)
    return preds
    

def color_norm(color):
    return [i/255. for i in color]

def plot_gt(anns_path, classes, ax, color=None):
    anns = utils.read_json(anns_path)['annotations']

    for ann in anns:
        class_name = ann["label"]
        class_ind = classes['class2idx'][class_name]
        points = np.array(ann["points"])

        if color is None:
            ax.plot(points[:, 0], points[:, 1], color=color_norm(PALETTE[class_ind]))
        else:
            ax.plot(points[:, 0], points[:, 1], color=color)


def get_crop_coord(mask):
    image_size = [2024, 2024]
    margin = [50, 50]

    y, x = np.where(mask)
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)
    x_min, x_max = x_min - margin[0], x_max + margin[0]
    y_min, y_max = y_min - margin[1], y_max + margin[1]

    result = [x_min, x_max, y_min, y_max]
    result = np.clip(result, 0, image_size[0]).astype(int).tolist()
    x_min, x_max, y_min, y_max = result[0], result[1], result[2], result[3]

    return x_min, x_max, y_min, y_max


def draw_rect(coord, ax, edgecolor='blue'):
    x_min, x_max, y_min, y_max = coord
    x_width, y_width = x_max-x_min, y_max-y_min
    rect = Rectangle((x_min, y_min), x_width, y_width, edgecolor=edgecolor, facecolor='none', linewidth=2)
    ax.add_patch(rect)


In [3]:
data_dir_path = '/data/ephemeral/home/data'

image_paths = glob(data_dir_path + '/**/*.png', recursive=True)
anns_paths = glob(data_dir_path + '/**/*.json', recursive=True)
classes = get_xray_classes()

In [4]:
def save_crop_coord(file_path):
    df = pd.read_csv(file_path)
    image_names = df['image_name'].unique()
    df_new = pd.DataFrame(columns=['image_name', 'crop_finger', 'crop_backhand', 'crop_fingerbackhand', 'crop_arm'])

    for i, image_name in enumerate(tqdm(image_names[0:10])):
        # 예측결과 불러오기
        preds = get_preds(image_name, df, classes)

        mask_finger = preds[classes['finger_idx']].sum(axis=0)
        mask_backhand = preds[classes['backhand_idx']].sum(axis=0)
        mask_figerbackhand = preds[classes['fingerbackhand_idx']].sum(axis=0)
        mask_arm = preds[classes['arm_idx']].sum(axis=0)

        crop_finger = get_crop_coord(mask_finger)
        crop_backhand = get_crop_coord(mask_backhand)
        crop_fingerbackhand = get_crop_coord(mask_figerbackhand)
        crop_arm = get_crop_coord(mask_arm)

        df_new.loc[i] = [image_name, crop_finger, crop_backhand, crop_fingerbackhand, crop_arm]

    save_name = os.path.basename(file_path)
    df_new.to_csv(save_name, index=False)
     

In [5]:
file_path_format = '/data/ephemeral/home/Dongjin/level2-cv-semanticsegmentation-cv-02-lv3/Baseline/Dongjin/transformers_1120/trained_models/cont/upernet-convnext-small_cont_size_1024_cont_weight/{mode}_ep_47_vdice_0.9716_upernet-convnext-small_cont_size_1024_cont_weight.csv'
modes = ['valid', 'train', 'test']

for mode in modes:
    file_path = file_path_format.format(mode=mode)
    save_crop_coord(file_path)
    


100%|██████████| 10/10 [00:09<00:00,  1.10it/s]
100%|██████████| 1/1 [00:22<00:00, 22.94s/it]
100%|██████████| 10/10 [00:09<00:00,  1.07it/s]
