사용법: 

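Roboflow의 데이터셋 Export 옵션에서 YOLO v5 PyTorch 형식으로 내려받은 후, 압축 파일을 원하는 디렉토리 안의 `data` 폴더에 해제합니다 (코드는 `./data/data.yaml`, `./data/train/`, `./data/test/` 경로를 사용합니다).
box_to_keypoin.ipynb 파일을 `data` 폴더와 같은 디렉토리에 넣은 후 실행합니다. 주의: 실행하면 각 `images` 폴더의 이미지 파일 이름이 `0.jpg`, `1.jpg`, … 로 바뀌므로 원본을 백업해 두세요.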

In [161]:
import json
import os
import cv2 as cv
import matplotlib.pyplot as plt
import re

In [162]:
yml_path = './data/data.yaml'
# Expected keypoint labels (every class except the 'cat' bounding-box class).
# NOTE: the comma after 'tail_5' matters -- without it Python concatenates the
# adjacent literals into 'tail_5spine' and the list silently loses an entry.
std_labels = ['tail_0', 'tail_1', 'tail_2', 'tail_3', 'tail_4', 'tail_5', 'spine', 'neck', 'nose', 'fu', 'hu']
num_points = len(std_labels) + 1  # 12 classes in total, including 'cat'

# Read the class names registered in the Roboflow data.yaml.
with open(yml_path, 'r', encoding='utf-8') as fp:
    raw = fp.readlines()
    # Locate the `names: [...]` line by content instead of assuming it is
    # always line 6; fall back to the original fixed index if none is found.
    names_line = next((line for line in raw if line.lstrip().startswith('names:')), None)
    if names_line is None:
        names_line = raw[5]
    result = re.findall("'.+'", names_line)
    labels = [i.replace('\'', '').strip() for i in result[0].split(',')]
    print("해당 yaml에 등록된 라벨: ", labels)


해당 yaml에 등록된 라벨:  ['cat', 'fu', 'hu', 'neck', 'nose', 'spine', 'tail_0', 'tail_1', 'tail_2', 'tail_3', 'tail_4', 'tail_5']


In [163]:
# Map class id -> label name for every keypoint class; id 0 (the bbox class) is skipped.
label_dict = dict(enumerate(labels[1:], start=1))

In [164]:
# keypoint좌표 자동 변경
# keypoint 형식 = [x좌표, y좌표, 라벨번호, visibility(0 or 1)]
def converter(file_labels:str, file_image:str):
    keypoints = []
    img = cv.imread(file_image)
    img_w, img_h = img.shape[1], img.shape[0]

    with open(file_labels) as f:
        lines_txt = f.readlines()
        lines = []
        for line in lines_txt:
            lines.append([int(line.split()[0])] + [round(float(coord), 5) for coord in line.split()[1:]])

    for line in lines:
        if line[0] == 0:
            x_c = round(line[1] * img_w)
            y_c = round(line[2] * img_h)
            w = round(line[3] * img_w)
            h = round(line[4] * img_h)
            
            bboxes = ([round(x_c - w/2), round(y_c - h/2), round(x_c + w/2), round(y_c + h/2)])
        
        else:
            kp_id = line[0]
            x_c = round(line[1] * img_w)
            y_c = round(line[2] * img_h)
            keypoints.append([x_c, y_c, kp_id, 1]) # 기본으로 visibility는 1로 체크

    keypoints = sorted(keypoints, key=lambda x: x[2]) # keypoint id기준으로 상향정렬

    return bboxes, keypoints


# 사진+keypoint+bbox 시각화
def visualize(image_path:str, keypoints:list, bboxes:list, save=False):
    top_left_corner, bottom_right_corner = tuple([bboxes[0], bboxes[1]]), tuple([bboxes[2], bboxes[3]])
    image = cv.imread(image_path)
    img = cv.rectangle(image, top_left_corner, bottom_right_corner, (0, 255, 0), 3)
    
    for kp_idx, kp in enumerate(keypoints):
        center = tuple([kp[0], kp[1]])
        img = cv.circle(img, center, 5, (255,0,0), 5)
        img = cv.putText(img, " " + label_dict[kp[2]], center, cv.FONT_HERSHEY_SIMPLEX, 1.0, (255,0,0), 2)
    
    plt.figure(figsize=(15, 15))
    plt.imshow(img)


def dump2json(bboxes:list, keypoints: list, json_path: str):
    """Serialize one image's annotation to *json_path*.

    The file contains a single object ``{"bboxes": ..., "keypoints": ...}``
    (keys in that order), overwriting any existing file.
    """
    payload = {'bboxes': bboxes, 'keypoints': keypoints}
    with open(json_path, "w") as out_file:
        json.dump(payload, out_file)


def main(IMAGE_PATH:str, LABELS_PATH:str, ANNOTATIONS_PATH:str):
    """Convert every image's YOLO labels to JSON and renumber the images.

    For each ``<stem>.jpg`` in IMAGE_PATH (processed in sorted name order),
    reads ``LABELS_PATH/<stem>.txt``, writes ``ANNOTATIONS_PATH/<index>.json``
    and renames the image to ``IMAGE_PATH/<index>.jpg`` so that image and
    annotation share the same index. Label files are not renamed.

    Raises:
        FileExistsError: if a rename would overwrite a different existing image.
    """
    # os.path.splitext (rather than split('.jpg')) keeps stems that contain
    # ".jpg" intact and lets non-jpg entries (e.g. .DS_Store) be skipped;
    # sorting makes the index assignment reproducible across runs.
    image_stems = sorted(
        stem for stem, ext in map(os.path.splitext, os.listdir(IMAGE_PATH)) if ext == '.jpg'
    )

    for index, stem in enumerate(image_stems):
        label_path = os.path.join(LABELS_PATH, stem + ".txt")
        image_path = os.path.join(IMAGE_PATH, stem + ".jpg")
        bboxes, keypoints = converter(label_path, image_path)
        dump2json(bboxes, keypoints, os.path.join(ANNOTATIONS_PATH, str(index) + '.json'))

        new_path = os.path.join(IMAGE_PATH, str(index) + '.jpg')
        if new_path == image_path:
            continue  # already carries its target name
        if os.path.exists(new_path):
            # On POSIX os.rename silently replaces the destination; refuse
            # instead of destroying an image that has not been processed yet.
            raise FileExistsError(f"Refusing to overwrite existing image: {new_path}")
        os.rename(image_path, new_path)

        

In [165]:
TRAIN_PATH = './data/train/'
TEST_PATH = './data/test/'

# Create the annotations/ output folder for each split.
# exist_ok=True makes this a no-op when the folder already exists and avoids
# the check-then-create race of an os.path.exists() guard.
for i in [TRAIN_PATH, TEST_PATH]:
    try:
        os.makedirs(os.path.join(i, 'annotations'), exist_ok=True)
    except OSError as err:
        print("Error while creating the data directory", err)

In [166]:
# Run the conversion on both splits: writes annotations/<n>.json and renames
# images/<stem>.jpg to images/<n>.jpg for each split.
for split_root in (TRAIN_PATH, TEST_PATH):
    main(split_root + 'images/', split_root + 'labels/', split_root + 'annotations/')

In [160]:
# Flag generated annotation files whose keypoint count or bbox shape is wrong.
# NOTE: only the train split is checked, and failed_files stores bare file
# names (relative to TRAIN_PATH/annotations) for a later clean-up cell.
EXPECTED_NUM_KEYPOINTS = 11  # one per keypoint label (every class except 'cat')
EXPECTED_BBOX_LEN = 4        # [x_min, y_min, x_max, y_max]
failed_files = []
no = 0  # NOTE(review): unused in this cell; kept in case a later cell relies on it
for j in [TRAIN_PATH]:
    for i in sorted(os.listdir(os.path.join(j, "annotations"))):
        file = os.path.join(j, "annotations", i)
        with open(file, encoding='utf-8') as fp:
            data = json.load(fp)
            if (len(data['keypoints']) != EXPECTED_NUM_KEYPOINTS
                    or len(data['bboxes']) != EXPECTED_BBOX_LEN):
                failed_files.append(i)
                print(data['bboxes'])
                print(data['keypoints'])
                print(f"형식 문제 있음. 오류가 있는 파일은 {i}이고, 육안점검을 추천하고, 자동으로 문제 있는 파일들을 현재 로컬 저장소에서 삭제하고자 할 경우 다음 셀을 실행하세요")