In [21]:
import os
import json
import pandas as pd
import torch
from ensemble_boxes import weighted_boxes_fusion
import cv2
from tqdm import tqdm
file_paths = [os.path.join('predictions', f) for f in os.listdir('predictions') if f.endswith('.csv')]
ocr_results = []
LANG = {'zh': 'chinese_receipt', 'ja': 'japanese_receipt', 'th': 'thai_receipt', 'vi':'vietnamese_receipt'}
data_dir = 'data'

for path in tqdm(file_paths, desc="Processing CSV files"):
    with open(path, 'r') as file:
        data = json.load(file)
        for image_id, image_data in tqdm(data['images'].items(), desc="Processing Images", leave=False):
            # 이미지 경로 생성
            lang_key = image_id.split('.')[1][:2]
            image_path = os.path.join(data_dir, LANG.get(lang_key, 'unknown_receipt'), 'img', 'test', f"{image_id}")

            # 이미지 크기 읽기
            if os.path.exists(image_path):
                img = cv2.imread(image_path)
                image_height, image_width = img.shape[:2]

                for word_id, word_data in image_data['words'].items():
                    if 'points' in word_data:
                        points = word_data['points']
                        x_coords = [p[0] for p in points]
                        y_coords = [p[1] for p in points]
                        bbox = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
                        confidence = word_data.get('confidence', 1.0)

                        # 각 박스에 대해 정규화 수행
                        normalized_bbox = [
                            bbox[0] / image_width,
                            bbox[1] / image_height,
                            bbox[2] / image_width,
                            bbox[3] / image_height
                        ]

                        # 정규화된 좌표가 [0, 1] 범위를 벗어나지 않도록 클램프
                        normalized_bbox = [max(0, min(1, coord)) for coord in normalized_bbox]

                        ocr_results.append({
                            'image_id': image_id,
                            'word_id': word_id,
                            'bbox': normalized_bbox,
                            'confidence': confidence,
                            'image_width': image_width,
                            'image_height': image_height
                        })

ocr_df = pd.DataFrame(ocr_results)
results_after_wbf = []

for image_id in ocr_df['image_id'].unique():
    image_detections = ocr_df[ocr_df['image_id'] == image_id]
    image_width = image_detections['image_width'].iloc[0]
    image_height = image_detections['image_height'].iloc[0]
    
    # WBF를 위한 데이터 변환
    boxes_list = [image_detections['bbox'].tolist()]
    scores_list = [image_detections['confidence'].tolist()]
    labels_list = [[1] * len(image_detections)]  # 모든 박스에 같은 라벨
    
    # WBF 수행 (IoU 임계값 0.6)
    boxes, scores, labels = weighted_boxes_fusion(
        boxes_list, scores_list, labels_list, iou_thr=0.1, skip_box_thr=0.0001
    )
    
    # 원래 크기로 변환 및 결과 저장
    for idx, (box, score) in enumerate(zip(boxes, scores)):
        p_xmin = box[0] * image_width
        p_ymin = box[1] * image_height
        p_xmax = box[2] * image_width
        p_ymax = box[3] * image_height
        points = [[p_xmin, p_ymin], [p_xmax, p_ymin], [p_xmax, p_ymax], [p_xmin, p_ymax]]
        
        results_after_wbf.append({
            'image_id': image_id,
            'word_id': str(idx),
            'points': points,
            'confidence': score
        })

wbf_result = {'images': {}}
for result in results_after_wbf:
    image_id = result['image_id']
    if image_id not in wbf_result['images']:
        wbf_result['images'][image_id] = {'words': {}}
    wbf_result['images'][image_id]['words'][result['word_id']] = {
        'points': result['points'],
        'confidence': result['confidence']
    }

# WBF 결과 JSON 파일로 저장
with open('predictions/ocr_results_after_wbf.csv', 'w') as f:
    json.dump(wbf_result, f, indent=4)

print("WBF 결과가 'ocr_results_after_wbf.json'에 저장되었습니다.")


Processing CSV files: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


WBF 결과가 'ocr_results_after_wbf.json'에 저장되었습니다.
