In [1]:
import os
import json

# Annotation JSON files for the MVTec Screws dataset, one per split.
train_file = "/datasets/mvtec/mvtec_screws_train.json"
val_file = "/datasets/mvtec/mvtec_screws_val.json"
test_file = "/datasets/mvtec/mvtec_screws_test.json"

# Lookup from an annotation's category_id to the label string.
# Insertion order is kept exactly as originally written.
category_names = {
    7: "nut",
    3: "wood_screw",
    2: "lag_wood_screw",
    8: "bolt",
    6: "black_oxide_screw",
    5: "shiny_screw",
    4: "short_wood_screw",
    1: "long_lag_screw",
    9: "large_nut",
    11: "nut2",
    10: "nut1",
    12: "machine_screw",
    13: "short_machine_screw",
}

import numpy as np
def xywha2xy4(xywha):  # a represents the angle(rad), clockwise, a=0 along the X axis
    """Return the four corner points of a rotated rectangle.

    Args:
        xywha: 5-element sequence ``(x, y, w, h, a)`` -- centre coordinates,
            width, height and rotation angle in radians.

    Returns:
        A list of four ``[x, y]`` pairs. Before rotation the corners are
        ordered (-w/2, -h/2), (w/2, -h/2), (w/2, h/2), (-w/2, h/2) relative
        to the centre; each is rotated by ``a`` and then shifted to (x, y).
    """
    center_x, center_y, width, height, angle = xywha
    half_w, half_h = width / 2, height / 2
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    # Corner offsets from the centre, one row per corner.
    offsets = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])
    # 2x2 rotation matrix applied to every offset.
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])

    rotated = rotation.dot(offsets.T).T
    return (rotated + [center_x, center_y]).tolist()

def get_dota_fmt(file):
    """Convert one annotation JSON file into DOTA-style text lines per image.

    The JSON is expected to hold an ``"images"`` list (entries with ``"id"``
    and ``"file_name"``) and an ``"annotations"`` list (entries with
    ``"category_id"``, ``"image_id"`` and a 5-element ``"bbox"``).

    Args:
        file: Path of the annotation JSON file.

    Returns:
        dict mapping each image ``file_name`` to a list of strings of the form
        ``"x1 y1 x2 y2 x3 y3 x4 y4 <label> 0"``. Images without annotations
        map to an empty list.

    Raises:
        KeyError: If an annotation refers to an image id missing from
            ``"images"`` or to a category id missing from ``category_names``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(file, "r") as f:
        label_file = json.load(f)

    image_name_lines = {}
    image_id_names = {}
    for img in label_file["images"]:
        image_name_lines[img["file_name"]] = []
        image_id_names[img["id"]] = img["file_name"]

    for ant in label_file["annotations"]:
        category_id = ant["category_id"]
        image_id = ant["image_id"]

        # The bbox is labelled as (cy, cx, ...), not (cx, cy, ...).
        cy, cx, w, h, a = ant["bbox"]

        # The angle is supposed to lie in [-pi, pi], but some annotations fall
        # outside that range. Keep the sign and reduce the magnitude modulo pi.
        # A centred rectangle maps onto itself under a half-turn, so this can
        # only change which corner is listed first, not the box itself.
        if a != 0:
            a = (a / abs(a)) * (abs(a) % np.pi)

        # Move the angle into [0, 2*pi) with the opposite sign convention so
        # it matches what xywha2xy4 expects (clockwise).
        if a < 0:
            a *= -1
        elif a > 0:
            a = 2*np.pi-a

        xy4 = xywha2xy4((cx, cy, w, h, a))

        label = category_names[category_id]

        # DOTA line: eight integer coordinates, the label, then a difficulty
        # flag that is always 0. int() truncates toward zero.
        coords = [int(value) for xy in xy4 for value in xy]
        line = " ".join(map(str, coords + [label, 0]))
        image_name_lines[image_id_names[image_id]].append(line)
    return image_name_lines

# Convert the three annotation files in train -> val -> test order.
train_image_name_lines, val_image_name_lines, test_image_name_lines = (
    get_dota_fmt(split_file) for split_file in (train_file, val_file, test_file)
)

In [3]:
def write_to_labels(image_name_lines, out_dir="/datasets/mvtec/annfiles"):
    """Write one ``.txt`` label file per image.

    Args:
        image_name_lines: dict mapping an image file name to the list of
            label lines (without trailing newline) for that image.
        out_dir: Directory that receives the label files; created if missing.
            Defaults to the location this notebook has always written to.

    Each output file is named after the image with its extension replaced by
    ``.txt`` and contains one label line per row. Existing files are
    overwritten.
    """
    os.makedirs(out_dir, exist_ok=True)
    for image_name, lines in image_name_lines.items():
        # splitext removes only the real extension. str.rstrip(".png") would
        # strip any trailing '.', 'p', 'n' or 'g' characters and so mangle
        # names such as "snapping.png".
        stem = os.path.splitext(image_name)[0]
        txt_path = os.path.join(out_dir, stem + ".txt")
        with open(txt_path, "w") as f:
            f.writelines(line + "\n" for line in lines)
# Emit label files for every split, in train -> val -> test order.
for split_lines in (train_image_name_lines, val_image_name_lines, test_image_name_lines):
    write_to_labels(split_lines)

In [5]:
import glob 
import shutil

# Source directories holding all label files and all images.
anns_path = "/datasets/mvtec/annfiles"
imgs_path = "/datasets/mvtec/images"
# Destination roots for the two output splits.
trainval_path = "/datasets/mvtec/trainval"
test_path = "/datasets/mvtec/test"
import os

# Each split root gets an "images" and an "annfiles" sub-directory.
trainval_imgs_path, trainval_anns_path = (
    os.path.join(trainval_path, leaf) for leaf in ("images", "annfiles")
)
test_imgs_path, test_anns_path = (
    os.path.join(test_path, leaf) for leaf in ("images", "annfiles")
)

# Create the four destination directories; already-existing ones are fine.
for split_dir in (trainval_imgs_path, trainval_anns_path, test_imgs_path, test_anns_path):
    os.makedirs(split_dir, exist_ok=True)

import random

# Fixed seed for the split. glob returns files in arbitrary order, so the
# list is sorted first and then shuffled with a dedicated, seeded generator.
# This puts the same files in the same split on every run; an unseeded
# shuffle could leave one file in both trainval and test after a re-run,
# because earlier copies are never removed.
SPLIT_SEED = 0

txtfiles = sorted(glob.glob(os.path.join(anns_path, "*.txt")))
random.Random(SPLIT_SEED).shuffle(txtfiles)

# Pair every label file with the same-named .png in imgs_path. Building the
# path from the base name avoids rewriting other parts of the path that
# happen to contain "annfiles" or ".txt".
imgfiles = [
    os.path.join(imgs_path, os.path.splitext(os.path.basename(txtfile))[0] + ".png")
    for txtfile in txtfiles
]

# 80 % of the files go to trainval, the remainder to test.
n = len(txtfiles)
n_train = int(n*0.8)

In [6]:
def _copy_pairs(txt_paths, img_paths, anns_dir, imgs_dir):
    """Copy each (label, image) pair into ``anns_dir`` / ``imgs_dir``."""
    for txt_src, img_src in zip(txt_paths, img_paths):
        # The destination name is whatever follows the last "/" of the source.
        txt_name = txt_src.rsplit("/", 1)[-1]
        img_name = img_src.rsplit("/", 1)[-1]
        shutil.copy(txt_src, os.path.join(anns_dir, txt_name))
        shutil.copy(img_src, os.path.join(imgs_dir, img_name))


# The first n_train entries form the trainval split; the rest form the test split.
train_txtfiles, test_txtfiles = txtfiles[:n_train], txtfiles[n_train:]
train_imgfiles, test_imgfiles = imgfiles[:n_train], imgfiles[n_train:]

_copy_pairs(train_txtfiles, train_imgfiles, trainval_anns_path, trainval_imgs_path)
_copy_pairs(test_txtfiles, test_imgfiles, test_anns_path, test_imgs_path)

In [3]:
import torch
# Load the object serialized at this relative path and bind it to `a`.
# NOTE(review): nothing in this view uses `a` afterwards, and the path is
# relative ("../datasets") while the other cells use absolute "/datasets" --
# confirm this scratch cell is still needed and points at the intended file.
# NOTE(review): torch.load can unpickle arbitrary objects; only load files
# from a trusted source.
a=torch.load("../datasets/mvtec_balanced.pth")