In [1]:
import cv2
import numpy as np
import os
import shutil
import sys

# ===================== Configuration =====================

# Input root holding the Grad-CAM heatmaps, already split into train/val.
# This script only reads from here. Expected layout:
#     <GRADCAM_OUTPUT_BASE_DIR>/train/<class folder>/<image file>
#     <GRADCAM_OUTPUT_BASE_DIR>/val/<class folder>/<image file>
# Point this at wherever your heatmaps actually live.
GRADCAM_OUTPUT_BASE_DIR = "../data/bounding_box/"  # ADJUST IF NEEDED

# Output root for the generated YOLO dataset. This script writes here:
#     <YOLO_DATASET_BASE_DIR>/train/images/<image file>
#     <YOLO_DATASET_BASE_DIR>/train/labels/<image stem>.txt
#     <YOLO_DATASET_BASE_DIR>/val/images/<image file>
#     <YOLO_DATASET_BASE_DIR>/val/labels/<image stem>.txt
YOLO_DATASET_BASE_DIR = "../data/yolo_bbox_dataset"  # ADJUST IF NEEDED

# Class folder name -> YOLO class id, where the id is the position in the tuple.
# Every name must equal a class folder name under train/ and val/ character
# for character (spaces included). "Cerscospora" is presumably spelled the way
# the folder on disk is; the usual genus spelling is "Cercospora". Rename the
# folder and this entry together if you correct it.
CLASS_MAP = {
    folder_name: class_id
    for class_id, folder_name in enumerate(
        ("Cerscospora", "Healthy", "Leaf rust", "Miner", "Phoma")
    )
}

# ---- Heatmap -> bounding-box extraction ----
# Intensity (0-255) a heatmap pixel must exceed to count as "activated".
# Expect to tune this for your heatmaps.
BBOX_THRESHOLD = 150
# BGR channel index to threshold: 2 = red (common for 'hot'/'jet' colormaps),
# 1 = green, 0 = blue. Use None to threshold a grayscale conversion instead.
THRESHOLD_CHANNEL = 2
# Activated regions whose contour area (pixels) is below this are treated as noise.
MIN_CONTOUR_AREA = 50

# --- Helper Function: Extract Bounding Box from Heatmap Image ---

def extract_bbox_from_heatmap(heatmap_image_path, threshold_value, threshold_channel=None, min_area=50):
    """
    Return one pixel-space box that encloses every significant hot region of a heatmap.

    Steps:
    1. Load the image in BGR colour.
    2. Binarise one channel (or a grayscale conversion) with ``cv2.THRESH_BINARY``.
       Pixels strictly above ``threshold_value`` become foreground.
    3. Find the outer contours of the foreground.
    4. Drop contours whose area is below ``min_area``.
    5. Fit a single axis-aligned rectangle around all remaining contour points.

    Args:
        heatmap_image_path (str): Path of the image file to read.
        threshold_value (int | float): Binarisation threshold (0-255 for 8-bit images).
        threshold_channel (int | None): 0, 1 or 2 to threshold that BGR channel.
            Any other value, including None, thresholds a grayscale conversion.
        min_area (int | float): Minimum ``cv2.contourArea`` for a contour to count.

    Returns:
        tuple: ``(bbox, (height, width))``.
            - ``bbox`` is ``(x, y, w, h)`` in pixels, or ``None`` when no contour
              survives the area filter.
            - ``(height, width)`` is the image size, or ``(None, None)`` when the
              image could not be read or processing failed.

    Failures are printed and signalled through the return value instead of
    raised, so one bad file does not abort a batch run.
    """
    try:
        img = cv2.imread(heatmap_image_path)
        if img is None:
            print(f"Warning: Could not read image {heatmap_image_path}")
            return None, (None, None)

        height, width = img.shape[:2]
        if height == 0 or width == 0:
            print(f"Warning: Invalid image dimensions (0) for {heatmap_image_path}")
            return None, (None, None)

        # Pick the channel to threshold; anything other than 0/1/2 means grayscale.
        if threshold_channel in (0, 1, 2):
            channel = img[:, :, threshold_channel]
        else:
            channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Activated regions become white (255), everything else black (0).
        _, binary_mask = cv2.threshold(channel, threshold_value, 255, cv2.THRESH_BINARY)

        # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
        # (contours, hierarchy). Indexing [-2] picks the contours on both, where
        # 2-tuple unpacking raised on 3.x and the broad except below hid it.
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        # Ignore small specks; no surviving contour means "no box" (also covers no contours at all).
        valid_contours = [cnt for cnt in contours if cv2.contourArea(cnt) >= min_area]
        if not valid_contours:
            return None, (height, width)

        # One rectangle around the points of ALL significant contours.
        bbox = cv2.boundingRect(np.vstack(valid_contours))
        return bbox, (height, width)  # (x, y, w, h), (img_height, img_width)

    except Exception as e:  # best-effort per file: report and let the caller continue
        print(f"Error processing heatmap {heatmap_image_path}: {e}")
        return None, (None, None)

# --- Main Function: Generate YOLO Dataset from Pre-Split Grad-CAM Outputs ---

# Lower-case file extensions treated as images; every other file is skipped.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


def _bbox_to_yolo(bbox, img_h, img_w):
    """Convert a pixel ``(x, y, w, h)`` box to YOLO ``(x_center, y_center, width, height)``.

    Each value is divided by the image size and clamped to ``[0.0, 1.0]``.
    """
    x, y, box_w, box_h = bbox

    def clamp(value):
        return max(0.0, min(1.0, value))

    return (
        clamp((x + box_w / 2) / img_w),
        clamp((y + box_h / 2) / img_h),
        clamp(box_w / img_w),
        clamp(box_h / img_h),
    )


def _claim_output_name(filename, class_folder_name, used_stems):
    """Return an output filename whose stem is still free in this split, and reserve it.

    YOLO pairs ``images/<stem>.<ext>`` with ``labels/<stem>.txt`` and all classes
    of a split share one folder, so stems must be unique per split. Stems are
    compared case-insensitively because the default Windows/macOS file systems
    are case-insensitive. The first file keeps its name. Later clashes get a
    ``<class>_`` prefix, plus a numeric suffix if that is still taken.
    """
    stem, ext = os.path.splitext(filename)
    prefix = class_folder_name.replace(" ", "_")
    candidate = stem
    attempt = 0
    while candidate.casefold() in used_stems:
        attempt += 1
        candidate = f"{prefix}_{stem}" if attempt == 1 else f"{prefix}_{stem}_{attempt}"
    used_stems.add(candidate.casefold())
    return candidate + ext


def _remove_if_exists(path):
    """Delete ``path`` if present; return True when the file no longer exists afterwards."""
    try:
        os.remove(path)
    except FileNotFoundError:
        return True
    except OSError as e:
        print(f"Warning: could not remove {path}: {e}")
        return False
    return True


def _process_split(split, gradcam_base_dir, yolo_base_dir, class_map,
                   bbox_threshold, threshold_channel, min_contour_area):
    """Convert one split folder ('train' or 'val'); return ``(images_copied, labels_written)``."""
    print(f"\nProcessing existing split folder: '{split}'...")

    split_input_dir = os.path.join(gradcam_base_dir, split)
    images_dir = os.path.join(yolo_base_dir, split, 'images')
    labels_dir = os.path.join(yolo_base_dir, split, 'labels')

    if not os.path.isdir(split_input_dir):
        print(f"Warning: Input directory for split '{split}' not found: {split_input_dir}. Skipping.")
        return 0, 0

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)
    print(f"  Output Images Dir: {images_dir}")
    print(f"  Output Labels Dir: {labels_dir}")

    images_copied = 0
    labels_written = 0
    renamed = 0
    used_stems = set()  # case-folded output stems already claimed in this split

    # Sorted so that collision renaming (which file keeps the plain name) is reproducible.
    for class_folder_name in sorted(os.listdir(split_input_dir)):
        class_input_dir = os.path.join(split_input_dir, class_folder_name)
        if not os.path.isdir(class_input_dir):
            continue  # stray file next to the class folders

        if class_folder_name not in class_map:
            print(f"Warning: Class folder '{class_folder_name}' found in '{split}' split is not defined in CLASS_MAP. Skipping.")
            continue
        class_id = class_map[class_folder_name]

        for heatmap_filename in sorted(os.listdir(class_input_dir)):
            if not heatmap_filename.lower().endswith(_IMAGE_EXTENSIONS):
                continue

            source_heatmap_path = os.path.join(class_input_dir, heatmap_filename)
            output_filename = _claim_output_name(heatmap_filename, class_folder_name, used_stems)
            if output_filename != heatmap_filename:
                renamed += 1
            target_image_path = os.path.join(images_dir, output_filename)
            target_label_path = os.path.join(labels_dir, os.path.splitext(output_filename)[0] + '.txt')

            # --- 1. Copy the image ---
            try:
                shutil.copy2(source_heatmap_path, target_image_path)
            except OSError as e:
                print(f"Error copying {source_heatmap_path} to {target_image_path}: {e}")
                continue
            images_copied += 1

            # --- 2. Extract the bounding box ---
            bbox, (h, w) = extract_bbox_from_heatmap(
                source_heatmap_path, bbox_threshold, threshold_channel, min_contour_area
            )
            if bbox is None or h is None or w is None:
                # The image stays without a label, which YOLO trainers typically
                # treat as a background sample. Delete any label an earlier run
                # (e.g. with a different threshold) left for this name, so the
                # image is not silently mislabelled.
                _remove_if_exists(target_label_path)
                continue

            # --- 3. Write the label file ---
            x_center_norm, y_center_norm, width_norm, height_norm = _bbox_to_yolo(bbox, h, w)
            try:
                with open(target_label_path, 'w') as f_label:
                    f_label.write(f"{class_id} {x_center_norm:.6f} {y_center_norm:.6f} {width_norm:.6f} {height_norm:.6f}\n")
            except OSError as e:
                print(f"Error writing label file {target_label_path}: {e}")
                # Do not leave a partial label or an unlabelled copy behind.
                _remove_if_exists(target_label_path)
                if _remove_if_exists(target_image_path):
                    images_copied -= 1
                continue
            labels_written += 1

    print(f"  Finished processing '{split}': {images_copied} images copied, {labels_written} labels generated.")
    if renamed:
        print(f"  Note: {renamed} file(s) in '{split}' were renamed with a class prefix to avoid name collisions.")
    return images_copied, labels_written


def _print_data_yaml(yolo_base_dir, class_map):
    """Print suggested data.yaml contents for the generated dataset."""
    print("\nRecommended content for your YOLO data.yaml file:")
    print("--- data.yaml ---")
    # Prefer a path relative to the working directory unless it climbs out of it
    # ("..."); fall back to absolute, including across Windows drives.
    try:
        abs_yolo_path = os.path.abspath(yolo_base_dir)
        yaml_path = os.path.relpath(abs_yolo_path)
        if yaml_path.startswith(".."):
            yaml_path = abs_yolo_path
    except ValueError:  # relpath raises when the paths are on different drives (Windows)
        yaml_path = os.path.abspath(yolo_base_dir)

    print(f"path: {yaml_path}  # Root directory of the dataset")
    print("train: train/images")
    print("val: val/images")
    print("")
    print("# Classes")
    print("names:")
    # Ordered by class id so the listing is stable.
    for class_name, class_id in sorted(class_map.items(), key=lambda item: item[1]):
        print(f"  {class_id}: {class_name}")
    print("--- End data.yaml ---")


def generate_yolo_dataset_from_gradcam(gradcam_base_dir, yolo_base_dir, class_map, bbox_threshold, threshold_channel, min_contour_area):
    """
    Build a YOLO detection dataset from Grad-CAM heatmaps that are already split
    into 'train' and 'val' folders.

    For each split present under ``gradcam_base_dir``, every image file in a
    class folder listed in ``class_map`` is handled as follows:
    - It is copied to ``<yolo_base_dir>/<split>/images``.
    - If a box can be extracted from the heatmap, a one-line label
      ``<class_id> <x_center> <y_center> <width> <height>`` (normalized to
      [0, 1]) is written to ``<yolo_base_dir>/<split>/labels/<stem>.txt``.
    - If no box is extracted, the image is still copied but gets no label, and
      any label file with the same name left by a previous run is deleted.

    All classes of a split share one flat folder. Files whose stems would
    collide are renamed with a class-name prefix rather than overwriting each
    other. This covers the same stem in two class folders, the same stem with a
    different extension, and names that differ only in case. Folders and files
    are visited in sorted order, so the renaming is reproducible.

    NOTE(review): the heatmap files themselves become the training images. If
    they are pure heatmaps rather than overlays on the original photos, the
    detector will be trained on heatmaps; confirm this is intended.

    Args:
        gradcam_base_dir (str): Base directory containing 'train' and 'val' folders of Grad-CAM outputs.
        yolo_base_dir (str): Base directory for the output YOLO dataset (created if missing).
        class_map (dict): Mapping of class folder names to integer class ids.
        bbox_threshold (int): Threshold passed to ``extract_bbox_from_heatmap``.
        threshold_channel (int | None): Channel index, or None for grayscale.
        min_contour_area (int): Minimum contour area passed to ``extract_bbox_from_heatmap``.

    Returns:
        None. Progress, a summary, and suggested data.yaml contents are printed.
    """
    print("-" * 50)
    print("Starting YOLO Dataset Generation")
    print(f"Reading pre-split Grad-CAM heatmaps from: {gradcam_base_dir}")
    print(f"Writing YOLO dataset to: {yolo_base_dir}")
    print("-" * 50)

    os.makedirs(yolo_base_dir, exist_ok=True)

    total_images_processed = 0
    total_labels_generated = 0
    for split in ('train', 'val'):
        images, labels = _process_split(
            split, gradcam_base_dir, yolo_base_dir, class_map,
            bbox_threshold, threshold_channel, min_contour_area,
        )
        total_images_processed += images
        total_labels_generated += labels

    print("-" * 50)
    print("YOLO Dataset Generation Summary:")
    print(f"  Total images processed/copied: {total_images_processed}")
    print(f"  Total bounding box labels generated: {total_labels_generated}")
    print(f"  Dataset location: {yolo_base_dir}")
    print("-" * 50)

    _print_data_yaml(yolo_base_dir, class_map)


# --- Execution ---
if __name__ == "__main__":
    # The source root is mandatory: without it there is nothing to convert.
    if not os.path.isdir(GRADCAM_OUTPUT_BASE_DIR):
        print(f"Error: Source Grad-CAM directory not found: {GRADCAM_OUTPUT_BASE_DIR}")
        print("Please ensure this path is correct and contains 'train' and 'val' subfolders.")
        sys.exit(1)

    # A missing split is not fatal: the generator skips it with its own warning.
    # Say up front exactly which split folder(s) are absent.
    missing_splits = [
        split for split in ('train', 'val')
        if not os.path.isdir(os.path.join(GRADCAM_OUTPUT_BASE_DIR, split))
    ]
    if missing_splits:
        print(f"Warning: Missing split folder(s) {', '.join(missing_splits)} inside {GRADCAM_OUTPUT_BASE_DIR}.")
        print("Only the split folders that exist will be processed.")

    generate_yolo_dataset_from_gradcam(
        gradcam_base_dir=GRADCAM_OUTPUT_BASE_DIR,
        yolo_base_dir=YOLO_DATASET_BASE_DIR,
        class_map=CLASS_MAP,
        bbox_threshold=BBOX_THRESHOLD,
        threshold_channel=THRESHOLD_CHANNEL,
        min_contour_area=MIN_CONTOUR_AREA
    )

    print("\nScript finished.")

--------------------------------------------------
Starting YOLO Dataset Generation
Reading pre-split Grad-CAM heatmaps from: ../data/bounding_box/
Writing YOLO dataset to: ../data/yolo_bbox_dataset
--------------------------------------------------

Processing existing split folder: 'train'...
  Output Images Dir: ../data/yolo_bbox_dataset\train\images
  Output Labels Dir: ../data/yolo_bbox_dataset\train\labels
  Finished processing 'train': 46836 images copied, 46640 labels generated.

Processing existing split folder: 'val'...
  Output Images Dir: ../data/yolo_bbox_dataset\val\images
  Output Labels Dir: ../data/yolo_bbox_dataset\val\labels
  Finished processing 'val': 5853 images copied, 5797 labels generated.
--------------------------------------------------
YOLO Dataset Generation Summary:
  Total images processed/copied: 52689
  Total bounding box labels generated: 52437
  Dataset location: ../data/yolo_bbox_dataset
--------------------------------------------------

Recommende