In [None]:
import os
import json

# external library
import cv2
import numpy as np
from tqdm.auto import tqdm
import albumentations as A

# torch
import torch
from torch.utils.data import Dataset

In [None]:
# Data locations. Override them with the XRAY_IMAGE_ROOT / XRAY_LABEL_ROOT environment
# variables instead of editing this cell; the defaults are the author's local paths.
IMAGE_ROOT = os.environ.get(
    "XRAY_IMAGE_ROOT",
    "/Users/johyewon/Desktop/BoostCamp/Project/4. Semnatic-Segmentation/Code/data/train/DCM",
)
LABEL_ROOT = os.environ.get(
    "XRAY_LABEL_ROOT",
    "/Users/johyewon/Desktop/BoostCamp/Project/4. Semnatic-Segmentation/Code/data/train/outputs_json",
)

In [None]:
# The 29 segmentation targets: finger-1 .. finger-19 followed by the carpal bones,
# the radius and the ulna. A class's position in this list is its channel index
# in the multi-channel label mask.
CLASSES = [f'finger-{n}' for n in range(1, 20)] + [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate',
    'Scaphoid', 'Lunate', 'Triquetrum', 'Pisiform',
    'Radius', 'Ulna',
]

# name -> channel index, and the inverse mapping channel index -> name
CLASS2IND = {name: idx for idx, name in enumerate(CLASSES)}
IND2CLASS = dict(enumerate(CLASSES))


def _list_relative_files(root, extension):
    """Return the paths (relative to ``root``) of every file under ``root`` whose
    extension, lower-cased, equals ``extension``."""
    return {
        os.path.relpath(os.path.join(dirpath, fname), start=root)
        for dirpath, _dirs, files in os.walk(root)
        for fname in files
        if os.path.splitext(fname)[1].lower() == extension
    }


def _stem(path):
    """Path without its final extension; an image and its label share the same stem."""
    return os.path.splitext(path)[0]


pngs = _list_relative_files(IMAGE_ROOT, ".png")
jsons = _list_relative_files(LABEL_ROOT, ".json")

jsons_fn_prefix = {_stem(p) for p in jsons}
pngs_fn_prefix = {_stem(p) for p in pngs}

# Every image needs a label file and vice versa; name the offenders instead of a bare assert.
labels_without_image = sorted(jsons_fn_prefix - pngs_fn_prefix)
images_without_label = sorted(pngs_fn_prefix - jsons_fn_prefix)
assert not labels_without_image, (
    f"{len(labels_without_image)} label file(s) have no matching image, e.g. {labels_without_image[:5]}"
)
assert not images_without_label, (
    f"{len(images_without_label)} image(s) have no matching label file, e.g. {images_without_label[:5]}"
)
# Two files with the same stem (e.g. 'a.png' and 'a.PNG') would make the pairing ambiguous.
assert len(pngs) == len(pngs_fn_prefix) and len(jsons) == len(jsons_fn_prefix), (
    "several files share the same stem; image/label pairing would be ambiguous"
)

# Sort both lists by stem, not by full path. Downstream code assumes pngs[k] and jsons[k]
# describe the same sample, and full-path order can diverge between the two lists because
# the extensions differ ('a.pn.png' < 'a.png' but 'a.json' < 'a.pn.json').
pngs = np.array(sorted(pngs, key=_stem))
jsons = np.array(sorted(jsons, key=_stem))

In [None]:
class XRayDataset(Dataset):
    """Hand X-ray segmentation dataset yielding ``(image, label)`` tensor pairs.

    Args:
        filenames: image paths relative to ``IMAGE_ROOT``.
        labelnames: annotation JSON paths relative to ``LABEL_ROOT``;
            ``labelnames[i]`` must belong to ``filenames[i]``.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            when ``is_train`` is True, otherwise as ``transforms(image=...)``;
            it must return a dict with an ``"image"`` (and, in training, ``"mask"``) entry.
        is_train: when True the mask is transformed together with the image.

    ``__getitem__`` returns float tensors in channel-first layout: the image as
    (C, H, W) scaled by 1/255, and the label as (len(CLASSES), H, W) with 0/1 values.
    Loading is best-effort: on a data error it prints a message and returns
    ``(None, None)``, so callers must check for ``None``.
    """

    def __init__(self, filenames, labelnames, transforms=None, is_train=False):
        self.filenames = filenames
        self.labelnames = labelnames
        self.is_train = is_train
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def _load_image(self, image_name):
        """Read an image with OpenCV and divide it by 255."""
        image_path = os.path.join(IMAGE_ROOT, image_name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None, not by raising.
            raise FileNotFoundError(f"could not read image: {image_path}")
        return image / 255.

    def _load_label(self, label_name, image_name, height, width):
        """Rasterize the polygon annotations into a (height, width, len(CLASSES)) uint8 mask.

        Returns None (after printing an error) if the annotation list is malformed.
        """
        label_path = os.path.join(LABEL_ROOT, label_name)
        with open(label_path, "r", encoding="utf-8") as f:
            annotations = json.load(f)["annotations"]

        # Sanity check: `annotations` must be a flat list with one entry per annotated object.
        if np.array(annotations).shape != (len(annotations),):
            print(f"Error: {image_name} has wrong annotation shape")
            return None

        label = np.zeros((height, width, len(CLASSES)), dtype=np.uint8)
        for ann in annotations:
            c = ann["label"]
            class_ind = CLASS2IND.get(c)
            if class_ind is None:
                print(f"Unknown class label: {c} for {image_name}")
                continue

            # cv2.fillPoly requires int32 vertex coordinates.
            points = np.asarray(ann["points"], dtype=np.int32)

            # Fill the polygon into a scratch mask, then OR it into this class's channel.
            class_label = np.zeros((height, width), dtype=np.uint8)
            cv2.fillPoly(class_label, [points], 1)
            label[..., class_ind] = np.maximum(label[..., class_ind], class_label)

        return label

    def __getitem__(self, item):
        # Indexed outside the try block on purpose: an out-of-range index must raise
        # IndexError (this also ends plain `for x in dataset` iteration) instead of
        # being swallowed into (None, None).
        image_name = self.filenames[item]
        label_name = self.labelnames[item]

        try:
            image = self._load_image(image_name)
            label = self._load_label(label_name, image_name, *image.shape[:2])
            if label is None:
                return None, None

            if self.transforms is not None:
                inputs = {"image": image, "mask": label} if self.is_train else {"image": image}
                result = self.transforms(**inputs)

                image = result["image"]
                # Outside training the mask is NOT transformed and keeps the original image size.
                label = result["mask"] if self.is_train else label

            # (H, W, C) -> (C, H, W), the layout PyTorch models expect.
            image = torch.from_numpy(image.transpose(2, 0, 1)).float()
            label = torch.from_numpy(label.transpose(2, 0, 1)).float()

            return image, label

        except Exception as e:
            # Best-effort loading: report which sample failed and let the caller skip it.
            print(f"Error: {image_name}: {e}")
            return None, None


In [None]:
# Overlay color per class: PALETTE[i] paints channel i of the label mask, so entries
# follow the order of CLASSES (annotated below).
# NOTE(review): the tuples read like RGB, but the overlay is blended into an image
# loaded with cv2.imread and saved with cv2.imwrite (BGR order) — confirm the
# intended channel order if the exact colors matter.
PALETTE = [
    (220, 20, 60),      # finger-1
    (119, 11, 32),      # finger-2
    (0, 0, 142),        # finger-3
    (0, 0, 230),        # finger-4
    (106, 0, 228),      # finger-5
    (0, 60, 100),       # finger-6
    (0, 80, 100),       # finger-7
    (0, 0, 70),         # finger-8
    (0, 0, 192),        # finger-9
    (250, 170, 30),     # finger-10
    (100, 170, 30),     # finger-11
    (220, 220, 0),      # finger-12
    (175, 116, 175),    # finger-13
    (250, 0, 30),       # finger-14
    (165, 42, 42),      # finger-15
    (255, 77, 255),     # finger-16
    (0, 226, 252),      # finger-17
    (182, 182, 255),    # finger-18
    (0, 82, 0),         # finger-19
    (120, 166, 157),    # Trapezium
    (110, 76, 0),       # Trapezoid
    (174, 57, 255),     # Capitate
    (199, 100, 0),      # Hamate
    (72, 0, 118),       # Scaphoid
    (255, 179, 240),    # Lunate
    (0, 125, 92),       # Triquetrum
    (209, 0, 151),      # Pisiform
    (188, 208, 182),    # Radius
    (0, 220, 176),      # Ulna
]

In [None]:
# Utility: render a multi-channel class mask as a color image.
def label2rgb(label, palette=None):
    """Paint every class channel of ``label`` with its palette color.

    Args:
        label: array of shape (H, W, C); a pixel belongs to class ``i`` where
            ``label[:, :, i] == 1``.
        palette: sequence of at least C color triplets; defaults to ``PALETTE``.

    Returns:
        uint8 array of shape (H, W, 3). Overlaps are not blended: classes are
        painted in index order, so where masks overlap the highest index wins.

    Raises:
        ValueError: if ``palette`` has fewer colors than ``label`` has channels.
    """
    if palette is None:
        palette = PALETTE

    num_classes = label.shape[2]
    if num_classes > len(palette):
        raise ValueError(
            f"label has {num_classes} channels but the palette only has {len(palette)} colors"
        )

    image = np.zeros(label.shape[:2] + (3,), dtype=np.uint8)
    for i in range(num_classes):
        image[label[:, :, i] == 1] = palette[i]

    return image

In [None]:
# Output directory for the overlay images; override with the XRAY_VIS_ROOT environment variable.
VISUALIZATION_ROOT = os.environ.get(
    "XRAY_VIS_ROOT",
    '/Users/johyewon/Desktop/BoostCamp/Project/4. Semnatic-Segmentation/Code/Pred_visualization',
)

In [None]:
# Draw the labels over the image and save the result.

def _mask_anchor(mask):
    """Return an (x, y) pixel at which to print a class name for a binary uint8 mask, or None.

    Uses the centroid of the largest external contour. When that contour has zero
    area (e.g. a 1-pixel-wide sliver) its moment centroid would divide by zero, so
    the mean position of all mask pixels is used instead.
    """
    # [-2] picks the contour list under both the OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple) return shapes.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None

    largest = max(contours, key=cv2.contourArea)
    moments = cv2.moments(largest)
    if moments["m00"] != 0:
        return int(moments["m10"] / moments["m00"]), int(moments["m01"] / moments["m00"])

    ys, xs = np.nonzero(mask)
    return int(xs.mean()), int(ys.mean())


def visualize_images_and_labels(image, label, output_filename):
    """Blend the class masks over the image, write each class name on its mask, and save it.

    Args:
        image: tensor of shape (C, H, W); values are multiplied by 255 for display,
            so they are expected in [0, 1] (values outside are clipped).
        label: tensor of shape (num_classes, H, W) with 0/1 masks.
        output_filename: destination path passed to ``cv2.imwrite``.

    Prints an error and returns without writing when ``image`` or ``label`` is None
    or when ``cv2.imwrite`` reports failure.
    """
    if image is None or label is None:
        print(f"Error: Image or label is None. Cannot visualize {output_filename}.")
        return

    # (C, H, W) tensors -> (H, W, C) NumPy arrays
    image = image.numpy().transpose(1, 2, 0)
    label = label.numpy().transpose(1, 2, 0)

    label_overlay = label2rgb(label)

    # Clip before the cast so out-of-range values saturate instead of wrapping around.
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    alpha = 0.5  # weight of the label overlay (0.0 ~ 1.0)
    output_image = cv2.addWeighted(image, 1, label_overlay, alpha, 0)

    # Print each present class's name on its mask.
    for i in range(label.shape[2]):
        mask = label[:, :, i].astype(np.uint8)
        if not mask.any():
            continue
        anchor = _mask_anchor(mask)
        if anchor is not None:
            cv2.putText(output_image, IND2CLASS[i], anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)

    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(output_filename, output_image):
        print(f"Error: failed to write {output_filename}")
        return

    print(f"Image with mask saved to {output_filename}")

In [None]:
# Create the visualization output root; exist_ok makes this a race-free no-op when it already exists.
os.makedirs(VISUALIZATION_ROOT, exist_ok=True)

In [None]:
# Visualize every sample whose path contains an ID between FIRST_ID and LAST_ID (inclusive).
# The dataset IDs run from ID001 to ID548; this run starts at 20 (presumably to resume a
# partial run) — set FIRST_ID = 1 to process everything.
FIRST_ID, LAST_ID = 20, 548
RESIZE_TO = 512  # side length images and masks are resized to before visualization

# Loop-invariant setup, built once instead of on every iteration.
transform = A.Resize(RESIZE_TO, RESIZE_TO)
train_filenames = list(pngs)    # use all pngs (no train/val split)
train_labelnames = list(jsons)  # use all jsons (no train/val split)

for i in tqdm(range(FIRST_ID, LAST_ID + 1)):
    id_folder = f"ID{i:03d}"

    # Select by index from the paired lists so each image keeps its own label file
    # (relies on the listing cell sorting pngs and jsons in matching order).
    id_indices = [k for k, filename in enumerate(train_filenames) if id_folder in filename]
    if not id_indices:
        continue  # nothing to draw for this ID, so don't create an empty folder
    id_filenames = [train_filenames[k] for k in id_indices]
    id_labelnames = [train_labelnames[k] for k in id_indices]

    id_output_folder = os.path.join(VISUALIZATION_ROOT, id_folder)
    os.makedirs(id_output_folder, exist_ok=True)

    dataset = XRayDataset(id_filenames, id_labelnames, transforms=transform, is_train=True)

    for idx, (filename, labelname) in enumerate(zip(id_filenames, id_labelnames)):
        try:
            image, label = dataset[idx]

            # XRayDataset returns (None, None) when a sample cannot be loaded.
            if image is None or label is None:
                print(f"Warning: Unable to load data from {filename} or {labelname}, skipping.")
                continue

            # splitext drops only the final extension; split('.') would cut at the first dot.
            stem = os.path.splitext(os.path.basename(filename))[0]
            output_filename = os.path.join(id_output_folder, f"{stem}_vis.png")

            visualize_images_and_labels(image, label, output_filename)

        except Exception as e:
            # Best-effort batch job: log the failure and move on to the next sample.
            print(f"Error processing {filename} or {labelname}: {e}")
