In [None]:
%pip install ultralytics pyyaml -q

In [None]:
import os
import shutil
import yaml
from ultralytics import YOLO


In [None]:
# Root of the source dataset (the folder that contains data.yaml)
original_dataset = "./PPE-1"

# Destination folder for the filtered copy of the dataset
filtered_dataset = "./ppe_filtered"

# Class IDs to keep: 0=Hardhat, 2=Safety_Boots, 3=Safety_Gloves
wanted_classes = [0, 2, 3]


In [None]:
# Lower-cased image extensions that may accompany a label file.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def filter_yolo_labels(original_label_path, output_label_path, wanted_class_ids):
    """Copy only the annotation rows of the wanted classes into a new label file.

    Args:
        original_label_path: Path of the source label file. Each non-blank line
            must start with an integer class id followed by whitespace.
        output_label_path: Path the surviving rows are written to.
        wanted_class_ids: Iterable of integer class ids to keep.

    Returns:
        True if at least one row was kept and the output file was written,
        False otherwise (no output file is created in that case).

    Raises:
        ValueError: If the first field of a non-blank line is not an integer.
    """
    wanted = set(wanted_class_ids)

    with open(original_label_path, 'r') as f:
        lines = f.readlines()

    filtered_lines = []
    for line in lines:
        fields = line.split()
        # Blank lines (for example a stray newline at the end of the file) hold
        # no annotation; indexing fields[0] on them would raise IndexError.
        if fields and int(fields[0]) in wanted:
            filtered_lines.append(line)

    if not filtered_lines:
        return False

    with open(output_label_path, 'w') as f:
        f.writelines(filtered_lines)
    return True


def _index_images(img_dir):
    """Map the stem of every image file in img_dir to its file name.

    Extensions are compared case-insensitively, so 'a.PNG' and 'a.Jpg' are
    found as well. The listing is sorted so the choice is deterministic when
    several images share a stem.
    """
    images_by_stem = {}
    for file_name in sorted(os.listdir(img_dir)):
        stem, ext = os.path.splitext(file_name)
        if ext.lower() in _IMAGE_EXTENSIONS:
            images_by_stem.setdefault(stem, file_name)
    return images_by_stem


def filter_yolo_dataset(original_path, output_path, wanted_ids):
    """Build a copy of a YOLO dataset that keeps only the wanted classes.

    For each of the 'train', 'valid' and 'test' splits found under
    original_path, every label file that has a matching image and at least one
    row of a wanted class is written (filtered) to output_path, and its image
    is copied next to it. Class ids are NOT renumbered here.

    Args:
        original_path: Dataset root containing '<split>/images' and
            '<split>/labels' folders.
        output_path: Root folder of the filtered dataset; created on demand.
        wanted_ids: Iterable of integer class ids to keep.
    """
    for split in ['train', 'valid', 'test']:
        img_dir = os.path.join(original_path, split, 'images')
        lbl_dir = os.path.join(original_path, split, 'labels')

        # A split is only usable when both folders exist; listing a missing
        # labels folder would raise FileNotFoundError.
        if not (os.path.isdir(img_dir) and os.path.isdir(lbl_dir)):
            continue

        out_img_dir = os.path.join(output_path, split, 'images')
        out_lbl_dir = os.path.join(output_path, split, 'labels')
        os.makedirs(out_img_dir, exist_ok=True)
        os.makedirs(out_lbl_dir, exist_ok=True)

        # One directory scan per split instead of several existence checks
        # per label file.
        images_by_stem = _index_images(img_dir)

        count = 0
        for lbl_file in sorted(os.listdir(lbl_dir)):
            if not lbl_file.endswith('.txt'):
                continue

            # Locate the image first so a label without an image is never
            # written to the output (the count below reports pairs).
            img_file = images_by_stem.get(os.path.splitext(lbl_file)[0])
            if img_file is None:
                continue

            orig_lbl = os.path.join(lbl_dir, lbl_file)
            out_lbl = os.path.join(out_lbl_dir, lbl_file)

            if filter_yolo_labels(orig_lbl, out_lbl, wanted_ids):
                shutil.copy2(os.path.join(img_dir, img_file),
                             os.path.join(out_img_dir, img_file))
                count += 1

        print(f'[{split}] {count}개 이미지-라벨 쌍 복사 완료')


In [None]:
def update_class_ids(dataset_path, old_to_new):
    """Rewrite the class id of every annotation row using old_to_new.

    Each '.txt' file under '<dataset_path>/<split>/labels' (for the 'train',
    'valid' and 'test' splits that exist) is rewritten in place: the first
    field of every row is replaced by its mapped id and the remaining fields
    are re-joined with single spaces. Blank lines are dropped.

    Because the rewrite happens in place, calling this a second time on the
    same folder remaps ids that were already converted (or raises KeyError
    when a converted id is not a key of old_to_new). Run it once per freshly
    filtered dataset.

    Args:
        dataset_path: Root folder containing the split sub-folders.
        old_to_new: Dict mapping original integer class ids to new ids.

    Raises:
        KeyError: If a row uses a class id that is not a key of old_to_new.
        ValueError: If the first field of a row is not an integer.
    """
    for split in ['train', 'valid', 'test']:
        lbl_dir = os.path.join(dataset_path, split, 'labels')
        if not os.path.exists(lbl_dir):
            continue

        for lbl_file in os.listdir(lbl_dir):
            if not lbl_file.endswith('.txt'):
                continue

            lbl_path = os.path.join(lbl_dir, lbl_file)
            with open(lbl_path, 'r') as f:
                lines = f.readlines()

            updated = []
            for line in lines:
                parts = line.split()
                if not parts:
                    # A blank line has no class id; parts[0] would raise
                    # IndexError, so skip it.
                    continue
                parts[0] = str(old_to_new[int(parts[0])])
                updated.append(' '.join(parts) + '\n')

            # The file is only rewritten after every row mapped successfully,
            # so a KeyError above leaves the original file untouched.
            with open(lbl_path, 'w') as f:
                f.writelines(updated)

    print('클래스 ID 재조정 완료')


In [None]:
def create_yaml(orig_yaml, out_yaml, wanted_ids, out_path):
    """Write a data.yaml for the filtered dataset and return the id remapping.

    The class names are read from the 'names' entry of orig_yaml. The kept
    classes are renumbered 0..len(wanted_ids)-1 in the order given by
    wanted_ids, and the new file lists only their names.

    Args:
        orig_yaml: Path of the original data.yaml.
        out_yaml: Path the new data.yaml is written to.
        wanted_ids: Sequence of original class ids to keep.
        out_path: Value stored under the 'path' key of the new file.

    Returns:
        Dict mapping each original class id to its new id.
    """
    with open(orig_yaml, 'r', encoding='utf-8') as src:
        source_cfg = yaml.safe_load(src)
    source_names = source_cfg['names']

    # The position in wanted_ids becomes the new class id.
    kept_names = [source_names[class_id] for class_id in wanted_ids]
    id_remap = {class_id: position for position, class_id in enumerate(wanted_ids)}

    filtered_cfg = {
        'path': out_path,
        'train': 'train/images',
        'val': 'valid/images',
        'test': 'test/images',
        'names': kept_names,
        'nc': len(kept_names),
    }

    with open(out_yaml, 'w', encoding='utf-8') as dst:
        yaml.dump(filtered_cfg, dst, default_flow_style=False, allow_unicode=True)

    print(f'새 data.yaml 생성: {kept_names}')
    return id_remap


In [None]:
print('=== 데이터 필터링 시작 ===')

# Start from an empty output folder. Label files are only (re)written when they
# contain a wanted class, and update_class_ids rewrites ids in place, so files
# left over from an earlier run with a different class selection would be
# remapped a second time and end up with wrong class ids.
if os.path.abspath(filtered_dataset) == os.path.abspath(original_dataset):
    raise ValueError('filtered_dataset must differ from original_dataset')
if os.path.isdir(filtered_dataset):
    shutil.rmtree(filtered_dataset)

filter_yolo_dataset(original_dataset, filtered_dataset, wanted_classes)

orig_yaml = os.path.join(original_dataset, 'data.yaml')
new_yaml = os.path.join(filtered_dataset, 'data.yaml')

# create_yaml returns the old-id -> new-id mapping that the labels must follow.
mapping = create_yaml(orig_yaml, new_yaml, wanted_classes, filtered_dataset)
update_class_ids(filtered_dataset, mapping)
print('=== 필터링 완료 ===\n')


In [None]:
# Ultralytics publishes the YOLO11 nano weights as 'yolo11n.pt' (no 'v');
# 'yolov11n.pt' is not a downloadable asset and fails to load.
model = YOLO('yolo11n.pt')

# Train on the filtered dataset described by new_yaml; patience=10 stops early
# when validation metrics have not improved for 10 epochs.
results = model.train(
    data=new_yaml,
    epochs=100,
    imgsz=640,
    batch=16,
    workers=2,
    name='ppe_detection',
    patience=10,
    save=True,
    plots=True
)


In [None]:
from IPython.display import Image, display

# Read the plots from the folder this training run actually wrote to.
# Ultralytics appends a number to the run name when the folder already exists
# (ppe_detection, ppe_detection2, ...), so a hard-coded path can point at a
# different, older run. Requires `model` to be the object that was just trained.
run_dir = str(model.trainer.save_dir)

display(Image(filename=os.path.join(run_dir, 'results.png')))
display(Image(filename=os.path.join(run_dir, 'confusion_matrix.png')))


In [None]:
# Load the best checkpoint of the run that was just trained instead of a
# hard-coded run folder ('ppe_detection2' only exists after a second run).
# When this cell is re-run, `model` is already the reloaded checkpoint and has
# no trainer, so fall back to the path it was loaded from.
trainer = getattr(model, 'trainer', None)
best_weights = str(trainer.best) if trainer is not None else model.ckpt_path
model = YOLO(best_weights)

metrics = model.val(data=new_yaml)

print(f'mAP50: {metrics.box.map50:.3f}')
print(f'mAP50-95: {metrics.box.map:.3f}')


In [None]:
# Folder holding the test-split images of the filtered dataset
test_img_dir = os.path.join(filtered_dataset, 'test/images')

# Prediction settings: annotated images are saved to runs/detect/test_results,
# reusing that folder if it already exists.
predict_options = dict(
    source=test_img_dir,
    save=True,
    conf=0.25,
    project='runs/detect',
    name='test_results',
    exist_ok=True,
)
results = model.predict(**predict_options)

print(f'테스트 결과 저장 위치: runs/detect/test_results')


In [None]:
import glob

# glob returns paths in arbitrary order; sort so the same three samples are
# shown on every run.
test_images = sorted(glob.glob(os.path.join(test_img_dir, '*.jpg')))[:3]

for img_path in test_images:
    results = model(img_path)
    # show() draws the annotated image itself; a separate plot() call whose
    # return value is discarded only renders the image a second time.
    results[0].show()


In [None]:
# Copy the checkpoint the evaluated `model` was loaded from, rather than a
# hard-coded run folder that may not exist or may belong to another run.
shutil.copy(model.ckpt_path, './ppe_best_model.pt')
print('모델이 ./ppe_best_model.pt로 저장되었습니다')
