## general_trash split class update

In [1]:
import os
import json
import shutil

In [2]:
def process_folders(base_folder, folders, new_json_file, merge_categories=None,
                    source_json='../../dataset/train.json'):
    """Copy the source annotation JSON and re-label annotations folder by folder.

    Each entry of ``folders`` is a sub-directory of ``base_folder`` whose file
    names are expected to look like ``<image_id>_<annotation_id>.<ext>``. Every
    annotation identified by such a file name is moved to the category named
    after the folder (or after ``merge_categories[folder]`` when given). A
    category that does not exist yet is appended with supercategory
    ``"General trash"``.

    Args:
        base_folder: Directory that contains one sub-directory per folder name.
        folders: Folder names to process, in order.
        new_json_file: Path of the JSON file to create (overwritten if present).
        merge_categories: Optional mapping ``folder name -> category name`` used
            to place several folders into one shared category.
        source_json: Annotation file that is copied to ``new_json_file`` before
            it is modified. Defaults to the original training annotations.

    Raises:
        FileNotFoundError: If ``source_json`` or one of the folders is missing.
    """
    # A bare file name has no directory part, and os.makedirs('') would raise.
    output_dir = os.path.dirname(new_json_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Work on a copy so the source annotation file is never modified.
    shutil.copyfile(source_json, new_json_file)

    with open(new_json_file, 'r', encoding='utf-8') as json_fp:
        data = json.load(json_fp)

    categories = data['categories']

    # Index once instead of scanning every annotation for every image file.
    # setdefault keeps the first occurrence, like the original linear search.
    annotation_index = {}
    for annotation in data['annotations']:
        annotation_index.setdefault((annotation['image_id'], annotation['id']), annotation)

    category_id_by_name = {}
    for category in categories:
        category_id_by_name.setdefault(category['name'], category['id'])

    # New ids continue after the largest existing id. Using len(categories)
    # would collide with an existing id when ids are not 0-based and contiguous.
    next_category_id = max((category['id'] for category in categories), default=-1) + 1

    name_overrides = merge_categories or {}

    for folder in folders:
        category_name = name_overrides.get(folder, folder)

        category_id = category_id_by_name.get(category_name)
        if category_id is None:
            category_id = next_category_id
            next_category_id += 1
            category_id_by_name[category_name] = category_id
            categories.append({
                "id": category_id,
                "name": category_name,
                "supercategory": "General trash"
            })

        for entry in os.listdir(os.path.join(base_folder, folder)):
            parts = entry.split('_')
            try:
                image_id = int(parts[0])
                annotation_id = int(parts[1].split('.')[0])  # drop the file extension
            except (IndexError, ValueError):
                # Not an "<image_id>_<annotation_id>.<ext>" entry
                # (e.g. .DS_Store or .ipynb_checkpoints): nothing to re-label.
                continue

            annotation = annotation_index.get((image_id, annotation_id))
            if annotation is not None:
                annotation['category_id'] = category_id

    with open(new_json_file, 'w', encoding='utf-8') as json_fp:
        json.dump(data, json_fp, indent=4)

# Base settings
base_folder = '../../general_images_classified_from_original_label'
all_folders = [
    'box_tape', 'wastepaper', 'cigarette_packet', 'straw', 'binder',
    'paper_piece', 'leaflet', 'business_card', 'etc',
]
# The fifth '_'-separated token of base_folder prefixes the output directory name.
save_folder = base_folder.split('_')[4] + '_general_trash_json'
output_dir = f'../../{save_folder}'

# One new class at a time: each of the first five folders gets its own file.
for variant, single_folder in enumerate(all_folders[:5], start=1):
    process_folders(base_folder, [single_folder], f'{output_dir}/train_class_11_{variant}.json')

# Cumulative variants: the first 2, 3, 4 and 5 folders become new classes.
for folder_count in range(2, 6):
    process_folders(
        base_folder,
        all_folders[:folder_count],
        f'{output_dir}/train_class_{10 + folder_count}_1.json',
    )
# ... (call process_folders in the same pattern for the remaining cases)
# e.g. process_folders(base_folder, ['box_tape', 'wastepaper'], 'dataset/train_class_12_1.json')

# Group several folders into one shared category
merged_categories = dict.fromkeys(['paper_piece', 'leaflet', 'business_card'], 'paper_piece')
process_folders(
    base_folder,
    all_folders[:8],
    f'{output_dir}/train_class_16_1.json',
    merge_categories=merged_categories,
)

# Process every folder except 'etc', each as its own category
folders_except_etc = [name for name in all_folders if name != 'etc']
process_folders(base_folder, folders_except_etc, f'{output_dir}/train_class_20.json')


## Merge category

In [4]:
def merge_categories(json_file, categories_to_merge, final_category_name):
    """Merge several categories of an annotation JSON file into one, in place.

    All categories whose name is in ``categories_to_merge`` are collapsed into
    the one with the smallest id among them. Annotations that pointed at any of
    the merged categories are re-pointed to that id, the other merged
    categories are removed, and the surviving category is renamed to
    ``final_category_name``. Ids of the remaining categories are not renumbered.

    Args:
        json_file: Path of the JSON file to rewrite in place.
        categories_to_merge: Names of the categories to merge.
        final_category_name: Name given to the surviving category.

    Raises:
        ValueError: If none of ``categories_to_merge`` exists in the file. The
            file is left untouched in that case.
    """
    with open(json_file, 'r', encoding='utf-8') as json_fp:
        data = json.load(json_fp)

    # A set gives O(1) membership tests in the annotation loop below.
    merged_ids = {cat['id'] for cat in data['categories'] if cat['name'] in categories_to_merge}
    if not merged_ids:
        # Fail with a clear message instead of min()'s "empty sequence" error.
        raise ValueError(
            f"none of the categories {list(categories_to_merge)!r} exist in {json_file!r}"
        )

    # The surviving category is the one with the smallest id.
    final_category_id = min(merged_ids)

    for annotation in data['annotations']:
        if annotation['category_id'] in merged_ids:
            annotation['category_id'] = final_category_id

    # Keep unrelated categories plus the survivor, then rename the survivor.
    data['categories'] = [
        cat for cat in data['categories']
        if cat['id'] not in merged_ids or cat['id'] == final_category_id
    ]
    for cat in data['categories']:
        if cat['id'] == final_category_id:
            cat['name'] = final_category_name

    with open(json_file, 'w', encoding='utf-8') as json_fp:
        json.dump(data, json_fp, indent=4)

# Annotation file that is rewritten in place.
# NOTE(review): the split-class cell writes 'train_class_20.json' under
# '../../original_general_trash_json/'; confirm this path is the intended input.
json_file = '../../dataset/train_class20.json'

# Name of the category that remains after the merge
final_category_name = 'paper_piece'

# Categories to merge
categories_to_merge = ['box_tape', 'paper_piece']

merge_categories(
    json_file,
    categories_to_merge,
    final_category_name,
)