In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [None]:
CLASSES = [
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
    'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
    'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
    'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
    'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

CLASS2IND = {v: i for i, v in enumerate(CLASSES)}
IND2CLASS = {v: k for k, v in CLASS2IND.items()}


PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

csv_path="../../outputs/ensemble_soft_output.csv"
output_dir="./segmentation_visualizations"
base_image_dir="../../data/test"

In [21]:
def rle_decode(mask_rle: str, shape: tuple[int, int]) -> np.ndarray:
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

def label2rgb(label: np.ndarray, height: int, width: int) -> np.ndarray:
    image = np.zeros((height, width, 3), dtype=np.uint8)
    for i, class_mask in enumerate(label):
        image[class_mask == 1] = PALETTE[i]
    return image

def load_image_paths():
    image_paths = {}
    for root, _, files in os.walk(base_image_dir):
        for f in files:
            if f.endswith('.png') or f.endswith('.jpg'):
                image_paths[f] = os.path.join(root, f)
    return image_paths

def visualize_all_segmentations():
    df = pd.read_csv(csv_path)
    os.makedirs(output_dir, exist_ok=True)
    grouped = df.groupby("image_name")    
    image_paths = load_image_paths()  
    
    for image_name, group in grouped:
        image_path = image_paths.get(image_name) 
        if image_path is None:
            print(f"Image not found: {image_name}")
            continue
        
        image = np.array(Image.open(image_path).convert('RGB'))
        height, width = image.shape[:2]
        mask = np.zeros((len(PALETTE), height, width), dtype=np.uint8)
        
        for _, row in group.iterrows():
            class_name, rle = row['class'], row['rle']
            if pd.isna(rle):
                continue
            decoded_mask = rle_decode(rle, (height, width))
            mask[CLASS2IND[class_name]] = decoded_mask
        
        color_mask = label2rgb(mask, height, width)
        blended = cv2.addWeighted(image, 0.5, color_mask, 0.5, 0)
        
        output_path = os.path.join(output_dir, f"{image_name.split('.')[0]}_seg.png")
        plt.imsave(output_path, blended)
        print(f"Saved segmentation visualization for {image_name}")


In [None]:
visualize_all_segmentations()

In [None]:

def display_segmentation_pair(image_path: str, segmentation_path: str) -> None:
    # 원본 이미지 로드
    image = np.array(Image.open(image_path).convert('RGB'))
    # 세그멘테이션 이미지 로드
    segmentation = np.array(Image.open(segmentation_path).convert('RGB'))
    
    # 플롯 설정
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    axes[1].imshow(segmentation)
    axes[1].set_title("Segmentation")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

def visualize_all_images_with_segmentations():
    segmentation_files = [f for f in os.listdir(output_dir) if f.endswith('_seg.png')]
    
    for i, seg_file in enumerate(segmentation_files[:5]): 
        image_name = seg_file.replace('_seg.png', '.png')  
        image_path = image_paths.get(image_name)  
        segmentation_path = os.path.join(output_dir, seg_file)
        
        if image_path is None:
            print(f"Original image not found: {image_name}")
            continue
        
        display_segmentation_pair(image_path, segmentation_path)


image_paths = load_image_paths()  

visualize_all_images_with_segmentations()


### 학습 이미지 중에서 잘 분류하지 못하는 것들은 무엇이 있을까?

train dataset으로 결과를 확인해보자!

In [1]:
img_path = "../../data/train/DCM"
label_path = "../../data/train/outputs_json"
top_k = 5
config_path = "../../outputs/dev_smp_unet_kh"
pth_path = "../../outputs/dev_smp_unet_kh/smp_unet_best_model.pth"
threshold = 0.5

In [2]:
import sys
import os
sys.path.append("/data/ephemeral/home/kwak/level2-cv-semanticsegmentation-cv-18-lv3")
from src.models.model_utils import *
from src.datasets.dataset import XRayDataset
from src.utils.augmentation import get_transform
from src.utils.metrics import get_metric_function
from torch.utils.data import Dataset, DataLoader, Subset
from src.utils.rle_convert import encode_mask_to_rle
import pandas as pd
import torch.nn.functional as F
import glob
from tqdm import tqdm
import yaml

def get_config(config_folder):
    config = {}

    config_folder = os.path.join(config_folder,'*.yaml')
    
    config_files = glob.glob(config_folder)
    
    for file in config_files:
        with open(file, 'r') as f:
            config.update(yaml.safe_load(f))
    
    if config['device'] == 'cuda' and not torch.cuda.is_available():
        print('using cpu now...')
        config['device'] = 'cpu'

    return config


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


데이터셋 받아오기

In [3]:
config = get_config(config_path)

classes = config['classes']
CLASS2IND = {v: i for i, v in enumerate(classes)}
IND2CLASS = {v: k for k, v in CLASS2IND.items()}

transform = get_transform(config['data'], is_train=False)

_dataset = XRayDataset(
    image_root=img_path,
    label_root=label_path,
    classes=config['classes'],
    mode='val',
    transforms=transform
)

_dataloader = DataLoader(
    _dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    drop_last=False
)

모델 불러오기

In [4]:
device = torch.device(config['device'])
model = get_model(config['model'], config['classes']).to(device)

checkpoint = torch.load(pth_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

criterion = get_criterion(config['train']['criterion']['name'])
metric_fn = get_metric_function(config['train']['metric']['name'])

Segmentation 수행하기

In [None]:
total_loss = 0
dice_list = []

rles = []
gt_rles = []
filename_and_class = []

with torch.no_grad():
    for i, batch in enumerate(tqdm(_dataloader, desc="calculating")):
        inputs, masks = batch
        inputs, masks = inputs.to(device), masks.to(device)
        
        outputs = model(inputs)
        
        if isinstance(outputs, tuple):
                logits, logits1, logits2 = outputs
                use_multiple_outputs = True
        else: 
            logits = outputs['out'] if isinstance(outputs, dict) and 'out' in outputs else outputs
            use_multiple_outputs = False

        logits_h, logits_w = logits.size(-2), logits.size(-1)
        labels_h, labels_w = masks.size(-2), masks.size(-1)

        #출력과 레이블의 크기가 다른 경우 출력 텐서를 레이블의 크기로 보간
        if logits_h != labels_h or logits_w != labels_w:
            logits = F.interpolate(logits, size=(labels_h, labels_w), mode="bilinear", align_corners=False)
            if use_multiple_outputs:
                logits1 = F.interpolate(logits1, size=(labels_h, labels_w), mode="bilinear", align_corners=False)
                logits2 = F.interpolate(logits2, size=(labels_h, labels_w), mode="bilinear", align_corners=False)

        probs = torch.sigmoid(logits)
        # threshold 추가해서 기준 치 이상만 label로 분류 
        # outputs = (probs > threshold).detach().cpu()
        # masks = masks.detach().cpu()
        outputs = (probs > threshold)
        dice = metric_fn.calculate(outputs, masks).detach().cpu()
        
        dice_list.append(dice)
        
        # pd에 저장하기 위함
        outputs = outputs.detach().cpu().numpy()
        masks = masks.detach().cpu().numpy()
        
        for idx, (output, mask) in enumerate(zip(outputs, masks)):
            file_path, _ = _dataset.get_filepath(idx*i)
            
            for c, (seg_m, gt_m) in enumerate(zip(output, mask)):
                rle = encode_mask_to_rle(seg_m)
                gt_rle = encode_mask_to_rle(gt_m)
                rles.append(rle)
                gt_rles.append(gt_rle)
                filename_and_class.append(f"{IND2CLASS[c]}:{file_path}")

classes, file_path = zip(*[x.split(":") for x in filename_and_class])
df_pred = pd.DataFrame({
    "file_path": file_path,
    "class": classes,
    "rle": rles,
    "gt_rle": gt_rles,
})

# dice_list = sorted(dice_list, key=lambda x: (-x, x))

dices = torch.cat(dice_list, 0)

dices_per_img = torch.mean(dices, 1)

file_path, _ = _dataset.get_all_path()
df_dice = pd.DataFrame(dices, columns=[IND2CLASS[i] for i in range(29)])
df_dice.insert(0, "file_path", file_path)

calculating: 100%|██████████| 20/20 [09:57<00:00, 29.88s/it]


In [9]:
vis_csv_path = os.path.join(config_path, "vis_csv")
df_pred_path = os.path.join(vis_csv_path, "pred.csv")
df_dice_path = os.path.join(vis_csv_path, "dice.csv")

os.makedirs(vis_csv_path, exist_ok=True)

df_pred.to_csv(df_pred_path, index=False)
df_dice.to_csv(df_dice_path, index=False)

그 중 dice score 값이 제일 낮은 결과값들 가져오기

In [None]:
values, indices = torch.topk(dices_per_img, k=top_k, largest=False)
values, indices

In [None]:
sorted_indices = torch.argsort(dices_per_img)
indices_ = sorted_indices[:10]

l_v = dices_per_img[indices_]
l_v

In [14]:
for v, i in zip(values, indices):
    print(v)
    print(_dataset.get_filepath(i))

tensor(0.0067)
('ID070/image1661736042863.png', 'ID070/image1661736042863.json')
tensor(0.0068)
('ID444/image1666144319702.png', 'ID444/image1666144319702.json')
tensor(0.0069)
('ID506/image1666747111320.png', 'ID506/image1666747111320.json')
tensor(0.0070)
('ID480/image1666661091576.png', 'ID480/image1666661091576.json')
tensor(0.0070)
('ID004/image1661144724044.png', 'ID004/image1661144724044.json')


In [17]:
sorted_indices

tensor([ 23, 119, 145, 135,   3, 131,  27, 144, 147,  45, 155,  97,  29,  25,
        115,  11,  13,   1, 149, 153, 151,  15,  57,  83, 137,  75, 154,  51,
         91,  49,  21,  17,  71,  56, 123,  73,  33, 157,  39, 129, 103,   5,
         95,  44,  31,   2, 139,  54, 133,  61, 152,  41, 159,  22,  19,  89,
         35, 111,  10,  55, 121, 117,  59,  37, 107, 118,  87, 130,  99,  43,
        150,   0,  96,  24, 141, 125, 146, 134,  12,  53, 105,  26, 114,   7,
         74,   9, 148,  47,  32,  63, 127,  50,  30,  77,  90, 138,  81,  60,
         67,  34,  42, 143,  79,  69, 136, 101,  58,  28, 128, 116,  62,  72,
        158,   6, 120,  85,  48,  94,  14,  65,  20, 142,  16,  52, 109,  86,
        102,  82,  40, 122, 126, 140,  70,   4, 106, 132,   8,  38, 110,  80,
        124,  66,  68,  88,  46, 108, 100, 113, 156,  64,  18,  36,  93,  76,
         84, 112,  78,  98,  92, 104])