In [None]:
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# import matplotlib.pyplot as plt
import tqdm
import cv2
import pandas as pd

from street_signs_identify.utils.x_image_loader import XImageLoader
import street_signs_identify.utils.check_overlap as co

from street_signs_identify.recognizer.ocr_recognizer import EasyOCRRecognizer
from street_signs_identify.detector.craft_detector import CRAFTDetector

In [None]:
# Detector instance; the extraction loop calls it on each ground-truth image.
craft_detector = CRAFTDetector()
# Checkpoint handed to the first recognizer. The path is relative, so it presumably
# resolves against the kernel's working directory — TODO confirm.
model_ckpt = "./street_signs_identify/recognizer/saved_models/epoch_6701.pth"
# Recognizer built from the checkpoint above; the "ch" prefix suggests Chinese text — TODO confirm.
ch_recognizer = EasyOCRRecognizer(model_ckpt)
# Second recognizer, configured with lang_list=["en"] and no explicit checkpoint.
en_recognizer = EasyOCRRecognizer(lang_list=["en"])

In [None]:
# Name of the dataset folder to process; every path below is derived from it.
folder_name = "2Penghu"

# Loader over the source images; "x_labels.csv" is passed as its annotation file.
data_path = f"./dataset/{folder_name}"
x_image_loader = XImageLoader(data_path, annotation_file="x_labels.csv")

# Root output folder (crops + labels.csv) and one sub-folder per debug rendering.
output_folder = f"./dataset/cropped_{folder_name}"
output_folder_gt = f"{output_folder}/ref_gt"
output_folder_filtered_pred = f"{output_folder}/filtered_pred"
output_folder_craft_pred = f"{output_folder}/craft_pred"

# Create the whole output tree up front; existing folders are left untouched.
for _folder in (
    output_folder,
    output_folder_gt,
    output_folder_filtered_pred,
    output_folder_craft_pred,
):
    os.makedirs(_folder, exist_ok=True)


In [None]:
# Show the loader's image names; this is the same collection the extraction loop iterates over.
# NOTE(review): _image_names is a private attribute of XImageLoader — prefer a public accessor if one exists.
x_image_loader._image_names

In [None]:
def _save_rgb_image(path, rgb_image):
    """Write an RGB image to ``path`` with OpenCV and report a failed write.

    Args:
        path: Destination file path.
        rgb_image: Image array in RGB channel order; it is converted to BGR
            before being handed to ``cv2.imwrite``.

    Returns:
        bool: The value returned by ``cv2.imwrite`` (``False`` when nothing
        was written).
    """
    written = cv2.imwrite(path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
    if not written:
        # cv2.imwrite reports failure through its return value instead of raising;
        # tqdm.write prints the message without breaking the progress bar.
        tqdm.tqdm.write(f"warning: could not write {path}")
    return written


# Running count of saved crops; it is also the crop's file name and its row index in labels.csv.
global_index = 0

# Columns that are always present in labels.csv, even when no crop is produced.
base_columns = ["filename", "ch_pred", "ch_score", "en_pred", "en_score", "src_file", "category"]

# One dict per saved crop. The frame is built once after the loop instead of
# being enlarged cell by cell with .loc, which re-allocates on every new row/column.
records = []

# NOTE(review): _image_names is a private attribute of XImageLoader — prefer a public accessor if one exists.
for image_name in tqdm.tqdm(x_image_loader._image_names):
    gt_image = x_image_loader[image_name]
    detected_image = craft_detector(gt_image.image)
    overlapping_info = co.check_overlapping_area(gt_image.info, detected_image.info)
    if overlapping_info is None:
        # No overlap information for this image: nothing is saved for it.
        continue

    # Debug renderings: ground-truth boxes, then every detected box.
    _save_rgb_image(os.path.join(output_folder_gt, image_name), gt_image.draw_all_box())
    _save_rgb_image(os.path.join(output_folder_craft_pred, image_name), detected_image.draw_all_box())

    # Keep only the detections listed in the overlap table and render them as well.
    detected_image = detected_image.keeps_index(overlapping_info["index_that"])
    _save_rgb_image(os.path.join(output_folder_filtered_pred, image_name), detected_image.draw_all_box())

    # itertuples keeps each column's own dtype; iterrows would turn the integer
    # index columns into floats whenever another column (e.g. "area") is float.
    for match in overlapping_info.itertuples(index=False):
        detected_instance = detected_image[match.index_that]
        if not detected_instance.legal:
            continue

        recognized_ch = ch_recognizer(detected_instance)
        recognized_en = en_recognizer(detected_instance)

        label = gt_image.info.loc[match.index_this, "label"]
        category = gt_image.info.loc[match.index_this, "category"]

        crop_name = f"{global_index}.jpg"
        record = {
            "filename": crop_name,
            "ch_pred": recognized_ch.info,
            "ch_score": recognized_ch.score,
            "en_pred": recognized_en.info,
            "en_score": recognized_en.score,
            "src_file": image_name,
            "category": category,
            "overlap_pixels": match.area,
        }
        # The ground-truth label may span several lines; each line gets its own words_<n> column.
        for idx, words in enumerate(label.split('\n')):
            record[f"words_{idx}"] = words
        records.append(record)

        _save_rgb_image(os.path.join(output_folder, crop_name), detected_instance.image)
        global_index += 1

# Fixed columns first, then the extra ones (overlap_pixels, words_<n>) in order of first appearance.
columns = list(base_columns)
for record in records:
    for key in record:
        if key not in columns:
            columns.append(key)

label_df = pd.DataFrame(records, columns=columns)
label_df.to_csv(os.path.join(output_folder, "labels.csv"))