## 좌우 반전(horizontal flip)된 이미지와 JSON 라벨 파일 저장
- 원본과 같은 폴더에 저장: 이미지는 `train/DCM/` 아래 각 ID 폴더(예: ID000), JSON은 `train/outputs_json/` 아래 같은 ID 폴더
- 원본 파일명 뒤에 '_flip'이 붙은 형태, 즉 `<원본이름>_flip.png`와 `<원본이름>_flip.json`으로 저장됩니다.

In [None]:
import os
import cv2
import albumentations as A
import matplotlib.pyplot as plt
import copy
import numpy as np
import json
import torch
from PIL import Image

# Root folder that is scanned (recursively) for the .png images.
IMAGE_ROOT = "/opt/ml/input/data/train/DCM"
# Label JSONs sit under this root at the same relative path as their image,
# with the extension swapped to .json.
LABEL_ROOT = "/opt/ml/input/data/train/outputs_json"

# RGB colour per class; entry i is used for CLASSES[i] when masks are rendered.
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

# Class names; an annotation's "label" is looked up here to find its mask channel.
# 19 finger bones (finger-1 .. finger-19) followed by the wrist/forearm bones.
CLASSES = [f'finger-{n}' for n in range(1, 20)] + [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid',
    'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

# Every .png under IMAGE_ROOT as a path relative to IMAGE_ROOT, deduplicated
# and sorted (expected to be 800 images for the training set).
pngs = sorted({
    os.path.relpath(os.path.join(dirpath, filename), start=IMAGE_ROOT)
    for dirpath, _subdirs, filenames in os.walk(IMAGE_ROOT)
    for filename in filenames
    if os.path.splitext(filename)[1].lower() == ".png"
})

def label2rgb(label, palette=None):
    """Render stacked per-class binary masks as one RGB image.

    Args:
        label: masks laid out as (num_classes, H, W); pixel (y, x) belongs to
            class i where ``label[i, y, x] == 1``. Anything ``np.asarray``
            accepts works (NumPy arrays, CPU torch tensors).
        palette: sequence of RGB triples, entry i colouring class i. Defaults
            to the module-level PALETTE.

    Returns:
        uint8 array of shape (H, W, 3). Pixels of no class stay black; where
        classes overlap, the higher class index wins because later masks
        overwrite earlier ones.

    Raises:
        IndexError: ``label`` has more classes than ``palette`` has colours.
    """
    if palette is None:
        palette = PALETTE
    label = np.asarray(label)
    image = np.zeros(label.shape[1:] + (3,), dtype=np.uint8)
    for i, class_mask in enumerate(label):
        image[class_mask == 1] = palette[i]
    return image

def show(idx):
    """Plot image ``pngs[idx]`` and its class masks beside their horizontally flipped versions.

    Layout is a 2x2 grid. The top row shows the original image and the
    flipped image. The bottom row shows colour-coded masks (via label2rgb)
    drawn from the original polygons and from the horizontally mirrored
    polygons. This is a visual check that the mirrored polygons line up with
    the flipped image.

    Args:
        idx: position in ``pngs`` (0 .. len(pngs) - 1).

    Raises:
        FileNotFoundError: the image or its label JSON cannot be read.
        ValueError: the label JSON names a class that is not in CLASSES.
    """
    image_name = pngs[idx]
    image_path = os.path.join(IMAGE_ROOT, image_name)
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"could not read image: {image_path}")
    height, width = image.shape[:2]

    # The label sits under LABEL_ROOT at the same relative path with a .json
    # extension; splitext strips only the extension (split('.') would cut at
    # the first dot anywhere in the path).
    label_path = os.path.join(LABEL_ROOT, os.path.splitext(image_name)[0] + ".json")
    with open(label_path, "r", encoding="utf-8") as f:
        annotations = json.load(f)["annotations"]

    # One binary mask per class, stacked as (NC, H, W): the layout label2rgb iterates over.
    label = np.zeros((len(CLASSES), height, width), dtype=np.uint8)
    flipped_label = np.zeros_like(label)
    for ann in annotations:
        class_ind = CLASSES.index(ann["label"])
        points = np.array(ann["points"])
        # A horizontal flip moves pixel column x to column width - 1 - x.
        flipped_points = points.copy()
        flipped_points[:, 0] = width - 1 - flipped_points[:, 0]

        # Polygon -> mask. A later annotation of the same class replaces the
        # earlier mask, as before.
        for target, poly in ((label, points), (flipped_label, flipped_points)):
            mask = np.zeros((height, width), dtype=np.uint8)
            cv2.fillPoly(mask, [poly], 1)
            target[class_ind] = mask

    flipped_image = cv2.flip(image, 1)  # flipCode=1: mirror around the vertical axis

    fig, ax = plt.subplots(2, 2, figsize=(24, 24), sharex=True, sharey=True)

    # cv2 loads images in BGR order while matplotlib expects RGB.
    ax[0][0].set_title(image_name, fontsize=30)
    ax[0][0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    ax[0][1].set_title("Transformed Image", fontsize=30)
    ax[0][1].imshow(cv2.cvtColor(flipped_image, cv2.COLOR_BGR2RGB))

    ax[1][0].imshow(label2rgb(label))
    ax[1][1].imshow(label2rgb(flipped_label))

    fig.subplots_adjust(wspace=0, hspace=0)
    plt.show()

def save_transformed_file():
    for idx in range(len(pngs)):
        image_name = pngs[idx]
        print(str(idx)+' : '+image_name)
        image_path = os.path.join(IMAGE_ROOT, image_name)
        image = cv2.imread(image_path)
        label_name = image_name.split('.')[0]+'.json'
        label_path = os.path.join(LABEL_ROOT, label_name)
        
        # process a label of shape (H, W, NC)
        label_shape = tuple(image.shape[:2]) + (len(CLASSES), )
        label = np.zeros(label_shape, dtype=np.uint8)
        transformed_label = np.zeros(label_shape, dtype=np.uint8)
        
        # read label file
        with open(label_path, "r") as f:
            data = json.load(f)
        annotations = data["annotations"]
        new_ann = []
        
        # iterate each class
        for ann in annotations:
            new_point = []
            c = ann["label"]
            class_ind = CLASSES.index(c)
            point = np.array(ann["points"])
            transformed_point = copy.deepcopy(point)
            # horizontal flip
            for i in range(len(transformed_point)):  
                transformed_point[i] = np.array([abs(transformed_point[i][0]-2048+1), transformed_point[i][1]])
                new_point.append([int(transformed_point[i][0]), int(transformed_point[i][1])])
            # polygon to mask
            class_label = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(class_label, [point], 1)
            label[..., class_ind] = class_label
            transformed_class_label = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(transformed_class_label, [transformed_point], 1)
            transformed_label[..., class_ind] = transformed_class_label
            new_ann.append({
                'id': ann["id"],
                'type': 'poly_seg',
                'attributes': {},
                'points': new_point,
                'label': c})
    
        # 이미지 변환 후 저장
        transform = A.Compose([
            A.HorizontalFlip(p=1.0),
        ])
        transformed_image = transform(image=image)["image"]
        transformed_image = Image.fromarray(transformed_image)
        transformed_image.save(image_path.split('.')[0]+'_filp.png')
    
        label = label.transpose(2, 0, 1)
        label = torch.from_numpy(label).float()
        transformed_label = transformed_label.transpose(2, 0, 1)
        transformed_label = torch.from_numpy(transformed_label).float()
    
        # json 파일 저장
        with open(label_path.split('.')[0]+'_filp.json', 'w') as f:
            json.dump({'annotations':new_ann}, f)

## 예시 이미지 출력

In [None]:
show(0)  # TODO: pick the image to preview - any index in range(len(pngs)), i.e. 0~799 for 800 images

## 파일 저장하기
소요시간 : 약 30분

In [None]:
save_transformed_file()  # flips every image listed in pngs and writes the flipped .png/.json next to the originals