<a href="https://colab.research.google.com/github/mlamb-226/Waterbird_classify/blob/main/Application.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import and download

In [None]:
import os
import sys
import csv
import requests
import torch
import torchvision.ops as ops
import matplotlib.pyplot as plt
import re
from PIL import Image, ImageDraw, ImageFont
from ipyfilechooser import FileChooser
import ipywidgets as widgets
from IPython.display import display, clear_output
from collections import Counter
import gdown
!pip install -q ultralytics
from ultralytics import YOLO
import pandas as pd
from tqdm.notebook import tqdm
!pip install -q rasterio
import rasterio
Image.MAX_IMAGE_PIXELS = None

# Google Drive direct-download URL of the model weights fetched below.
model_url = "https://drive.google.com/uc?export=download&id=15uqVk1X8zkWpmTnnB_QaLJcG0G2LO1Xg"
# Weights path used when the "CO-DETR" model type is selected in the UI cell;
# left empty here, so no CO-DETR weights are configured by this notebook.
model_weights_codetr = ""
# Local destination of the downloaded weights; this path is later passed to YOLO(...).
model_weights_yolo = "/content/model.pt"
# Download the weights to model_weights_yolo (progress output is shown because quiet=False).
gdown.download(model_url, model_weights_yolo, quiet=False)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m91.9 MB/s[0m eta [36m0:00:00[0m
[?25h

Downloading...
From: https://drive.google.com/uc?export=download&id=15uqVk1X8zkWpmTnnB_QaLJcG0G2LO1Xg
To: /content/model.pt
100%|██████████| 5.76M/5.76M [00:00<00:00, 207MB/s]


'/content/model.pt'

## Functions

In [None]:
def convert_images_to_jpg(input_folder, output_folder):
    """
    Converts every regular file in the input folder to .jpg format and saves it in the output folder.
    .tif/.tiff files are read with rasterio (bands 1-3 become R, G, B; files with fewer
    than 3 bands fall back to band 1 as grayscale). All other files are opened with PIL.

    Files that cannot be converted are reported and skipped; they do not abort the run.

    NOTE(review): raster bands are handed to Pillow exactly as read. This presumably
    expects 8-bit data; other dtypes are likely rejected by Pillow and end up reported
    as conversion errors - confirm against your imagery.

    Parameters:
    - input_folder (str): Directory containing the original images.
    - output_folder (str): Directory to save the converted images (created if missing).

    Returns:
    - List of converted .jpg image paths (each output path appears once).
    """
    os.makedirs(output_folder, exist_ok=True)
    converted_images = []
    failed_files = []

    # Sorted so the processing order (and therefore the returned list) is deterministic.
    for image_file in sorted(os.listdir(input_folder)):
        image_path = os.path.join(input_folder, image_file)
        if not os.path.isfile(image_path):
            continue  # Sub-directories are ignored

        try:
            if image_file.lower().endswith((".tif", ".tiff")):
                # Handle .tif files using rasterio
                with rasterio.open(image_path) as src:
                    if src.count >= 3:  # Ensure at least 3 bands for RGB
                        bands = tuple(Image.fromarray(src.read(band)) for band in (1, 2, 3))
                        img = Image.merge("RGB", bands)
                    else:
                        # Fallback to single-band grayscale
                        img = Image.fromarray(src.read(1)).convert("RGB")
            else:
                # Handle other formats using PIL; convert() returns a new image,
                # so the source file handle can be closed right away.
                with Image.open(image_path) as source_img:
                    img = source_img.convert("RGB")

            # Save as .jpg
            base_name = os.path.splitext(image_file)[0]
            new_filename = f"{base_name}.jpg"
            new_path = os.path.join(output_folder, new_filename)
            img.save(new_path, "JPEG")

            if new_path in converted_images:
                # Two inputs share a base name (e.g. a.png and a.tif): the later one
                # replaces the earlier output, so do not list the path twice.
                print(f"Warning: {image_file} overwrote an earlier conversion named {new_filename}.")
            else:
                converted_images.append(new_path)
        except Exception as e:
            failed_files.append(image_file)
            print(f"Error converting {image_file}: {e}")

    # Only claim full success when nothing failed.
    if failed_files:
        print(
            f"Converted {len(converted_images)} image(s) to .jpg in '{output_folder}'; "
            f"{len(failed_files)} file(s) could not be converted."
        )
    else:
        print(f"All images converted to .jpg and saved in '{output_folder}'.")
    return converted_images


def _tile_origins(length, tile_size, step_size):
    """Return the ordered, de-duplicated tile start offsets along one image axis."""
    # Never start past the point where a full tile still fits, and never below 0
    # (the latter happens when the image is smaller than one tile).
    last_origin = max(length - tile_size, 0)
    origins = []
    for start in range(0, length, step_size):
        origin = min(start, last_origin)
        if not origins or origins[-1] != origin:
            origins.append(origin)
    return origins


def tile_img(image_path, output_folder, tile_size=640, overlap=290):
    """
    Crops an image into tiles of specified size with overlap and saves them to an output folder.

    Tiles that would run past the right/bottom edge are shifted back so they keep the full
    tile size. Each distinct tile position is written once. If the image is smaller than
    a tile, a single tile anchored at (0, 0) is written; the area beyond the image is
    expected to be filled by Pillow's crop (black padding - confirm for your Pillow version).

    Parameters:
    - image_path (str): Path to the input image file.
    - output_folder (str): Directory to save the cropped image tiles.
    - tile_size (int, optional): Size of each tile (width and height in pixels). Default is 640.
    - overlap (int, optional): Overlap between tiles in pixels; must be smaller than tile_size. Default is 290.

    Returns:
    - None. Saves each tile as '<image name>_tile_<left>_<top>.png' in the output folder.

    Raises:
    - ValueError: If tile_size is not positive or overlap is not smaller than tile_size.
    """
    # Validate before touching the file system so bad arguments have no side effects.
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if overlap >= tile_size:
        raise ValueError(f"overlap ({overlap}) must be smaller than tile_size ({tile_size})")

    # Calculate step size based on desired overlap
    step_size = tile_size - overlap

    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Extract the base name of the image file (without extension)
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    tile_id = 0
    with Image.open(image_path) as img:
        img_width, img_height = img.size
        left_origins = _tile_origins(img_width, tile_size, step_size)
        top_origins = _tile_origins(img_height, tile_size, step_size)

        for top in top_origins:
            for left in left_origins:
                tiled_img = img.crop((left, top, left + tile_size, top + tile_size))

                # The offsets in the filename are what later maps detections back
                # to full-image coordinates, so they must be non-negative integers.
                tile_filename = f"{base_filename}_tile_{left}_{top}.png"
                tiled_img.save(os.path.join(output_folder, tile_filename))
                tile_id += 1

    print(f"Tiling completed. {tile_id} tiles saved in '{output_folder}'.")

# Loaded YOLO models keyed by weights path, so the weights are read once rather than
# once per tile. Re-running this cell resets the cache (e.g. after replacing the file).
_YOLO_MODEL_CACHE = {}


def _get_yolo_model(model_weights):
    """Return the YOLO model for `model_weights`, loading it only on first use."""
    model = _YOLO_MODEL_CACHE.get(model_weights)
    if model is None:
        model = YOLO(model_weights)
        _YOLO_MODEL_CACHE[model_weights] = model
    return model


def run_inference(framework, model_config=None, model_weights=None, image_path=None, img_size=640, conf_threshold=0.2, device='cuda:0'):
    """
    Unified inference function for MMDetection and YOLO models.

    Parameters:
    - framework (str): Specify the framework to use ('mmdet' or 'yolo').
    - model_config (str): Path to the model configuration file (required for MMDetection).
    - model_weights (str): Path to the model weights file (required for both MMDetection and YOLO).
    - image_path (str): Path to the image file on which to perform inference.
    - img_size (int): Input size for YOLO model. Default is 640.
    - conf_threshold (float): Confidence threshold for YOLO predictions. Default is 0.2.
    - device (str): Device passed to the MMDetection inferencer. Default is 'cuda:0'.
      It is not forwarded to YOLO, which is left to pick its own device.

    Returns:
    - results: A dictionary with the following structure:
      {
        'labels': [...],  # List of class IDs
        'scores': [...],  # List of confidence scores
        'bboxes': [...],  # List of bounding boxes: [min_x, min_y, max_x, max_y]
      }

    Raises:
    - ValueError: If the framework is unsupported or required arguments are missing.
    """
    if framework == 'mmdet':
        if model_config is None or model_weights is None:
            raise ValueError("For MMDetection, 'model_config' and 'model_weights' must be provided.")

        # NOTE(review): DetInferencer is not imported anywhere in the visible notebook,
        # so this branch raises NameError unless it is imported elsewhere - confirm.
        inferencer = DetInferencer(model=model_config, weights=model_weights, device=device)

        # Run inference on the image
        results = inferencer(image_path)

        # Extract and return a unified dictionary
        predictions = results['predictions'][0]  # Predictions for the single input image
        return {
            'labels': predictions['labels'],
            'scores': predictions['scores'],
            'bboxes': predictions['bboxes']
        }

    elif framework == 'yolo':
        if model_weights is None:
            raise ValueError("For YOLO, 'model_weights' must be provided.")

        # Reuse the already-loaded model when this weights file was seen before
        model = _get_yolo_model(model_weights)

        # Run predictions on the image (non-streaming: a list of results is returned)
        results = model.predict(source=image_path, imgsz=img_size, conf=conf_threshold, save=False, verbose=False)

        # Parse YOLO results to return a unified dictionary
        labels, scores, bboxes = [], [], []
        for result in results:
            if result.boxes is not None:
                labels.extend(result.boxes.cls.cpu().numpy().astype(int).tolist())
                scores.extend(result.boxes.conf.cpu().numpy().tolist())
                bboxes.extend(result.boxes.xyxy.cpu().numpy().tolist())

        return {
            'labels': labels,
            'scores': scores,
            'bboxes': bboxes
        }

    else:
        raise ValueError("Unsupported framework. Choose 'mmdet' or 'yolo'.")


def adjust_bbox(all_tile_results, iou_threshold=0.6):
    """
    Adjusts bounding boxes from multiple tiles' inference results to align with the original image
    coordinates and applies NMS to remove duplicates across all tiles.

    Every tile passed in is shifted into one shared coordinate frame, so all tiles in a
    single call should come from the same source image. NMS is run on boxes and scores
    only; class labels are not taken into account when suppressing overlaps.

    Parameters:
    - all_tile_results (list of tuples): List where each tuple contains:
        - tile_results (dict): Dictionary containing:
            - 'labels' (list): Class labels for each bounding box.
            - 'scores' (list): Confidence scores for each bounding box.
            - 'bboxes' (list of lists): Bounding boxes in the tile's coordinates (format: [x1, y1, x2, y2]).
        - tile_filename (str): The filename of the tile image, ending in '_tile_{left}_{top}.png'.
    - iou_threshold (float): IoU threshold for NMS.

    Returns:
    - final_bboxes (Tensor): Adjusted bounding boxes after merging, shape (N, 4).
    - final_scores (Tensor): Adjusted confidence scores after merging.
    - final_labels (Tensor): Adjusted class labels after merging.

    Raises:
    - ValueError: If a tile filename does not end in '_tile_{left}_{top}.png'.
    """
    tile_name_pattern = re.compile(r'_tile_(\d+)_(\d+)\.png$')
    all_bboxes, all_scores, all_labels = [], [], []

    for tile_results, tile_filename in all_tile_results:
        # Extract left and top coordinates from the filename
        match = tile_name_pattern.search(tile_filename)
        if not match:
            # Braces are doubled so the placeholder text is printed literally
            # instead of being evaluated as (undefined) variables.
            raise ValueError(
                f"Filename {tile_filename} does not match expected format 'tile_{{left}}_{{top}}.png'"
            )

        tile_left = int(match.group(1))
        tile_top = int(match.group(2))

        # Shift this tile's boxes by the tile's offset in the full image
        for x1, y1, x2, y2 in tile_results['bboxes']:
            all_bboxes.append([x1 + tile_left, y1 + tile_top, x2 + tile_left, y2 + tile_top])

        # Accumulate scores and labels
        all_scores.extend(tile_results['scores'])
        all_labels.extend(tile_results['labels'])

    # Nothing detected: return correctly shaped empty tensors without calling NMS
    if len(all_bboxes) == 0:
        return (
            torch.empty((0, 4), dtype=torch.float32),
            torch.empty(0, dtype=torch.float32),
            torch.empty(0, dtype=torch.int64),
        )

    bboxes = torch.tensor(all_bboxes, dtype=torch.float32)
    scores = torch.tensor(all_scores, dtype=torch.float32)
    labels = torch.tensor(all_labels, dtype=torch.int64)

    # Apply NMS on combined adjusted bounding boxes
    keep_indices = ops.nms(bboxes, scores, iou_threshold)
    return bboxes[keep_indices], scores[keep_indices], labels[keep_indices]

def save_final_results(class_map, image_path, final_bboxes, final_scores, final_labels, output_folder, score_threshold=0.2, csv_file="results.csv"):
    """
    Saves the final bounding boxes, labels, and scores visualized on the original image,
    adds a legend to indicate species colors, and writes the results to a CSV file.

    Box and legend colors come from the module-level `species_color_map`, which must be
    defined before this function is called.

    Parameters:
    - class_map (dict): Mapping of class names to class indices.
    - image_path (str): Path to the original image.
    - final_bboxes (Tensor): Adjusted bounding boxes after merging.
    - final_scores (Tensor): Confidence scores for bounding boxes.
    - final_labels (Tensor): Class labels for bounding boxes.
    - output_folder (str): Folder to save the resulting image (created if missing).
    - score_threshold (float): Minimum score threshold for a bounding box to be saved.
    - csv_file (str): Name of the CSV file, inside output_folder, that rows are appended to.

    Returns:
    - None. Writes 'result_<image name>' and appends one CSV row per kept detection.
    """
    # Invert the name -> index map once instead of searching it for every box.
    # setdefault keeps the first name if two names share an index.
    index_to_name = {}
    for name, index in class_map.items():
        index_to_name.setdefault(index, name)

    os.makedirs(output_folder, exist_ok=True)

    # convert() returns a new image, so the source file handle can be closed right away
    with Image.open(image_path) as source_img:
        img = source_img.convert("RGB")
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()

    # The CSV is opened in append mode; the header is written only to a new/empty file
    csv_path = os.path.join(output_folder, csv_file)
    is_new_file = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    image_name = os.path.basename(image_path)

    # Keep track of species present for the legend
    species_present = set()

    text_y_offset = 5  # Gap between the label text and the top of the box

    with open(csv_path, mode='a', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        if is_new_file:
            csv_writer.writerow(['Image Name', 'Species', 'X1', 'Y1', 'X2', 'Y2'])

        for bbox, score, label in zip(final_bboxes, final_scores, final_labels):
            if not score >= score_threshold:  # Apply score threshold
                continue

            x1, y1, x2, y2 = bbox.tolist()
            label_idx = int(label)

            # Get class name and color; unknown indices get a generic name, default color is white
            class_name = index_to_name.get(label_idx, f"class_{label_idx}")
            color = species_color_map.get(class_name, (255, 255, 255))

            # Draw bounding box
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

            # Measure the label text
            text = f"{class_name} {score:.2f}"
            text_bbox = draw.textbbox((x1, y1), text, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Place the label just above the box; if that would fall off the top of
            # the image, place it just inside the box instead.
            label_top = y1 - text_height - text_y_offset
            if label_top < 0:
                label_top = y1

            # Black background behind white text
            draw.rectangle([x1, label_top, x1 + text_width + 4, label_top + text_height], fill="black")
            draw.text((x1 + 2, label_top + 1), text, fill="white", font=font)

            # Add species to the set for the legend
            species_present.add(class_name)

            # Save to CSV file
            csv_writer.writerow([image_name, class_name, x1, y1, x2, y2])

    # Add a legend to the image, bottom-left. The start position is derived from the
    # same row spacing used to advance, so the last row ends 10 px above the bottom.
    legend_box_size = 20
    legend_row_step = 30
    legend_margin = 10
    legend_x = legend_margin
    legend_height = legend_row_step * max(len(species_present) - 1, 0) + legend_box_size
    legend_y = max(img.height - legend_margin - legend_height, 0)
    for species in sorted(species_present):
        color = species_color_map.get(species, (255, 255, 255))  # Default to white
        # Draw a colored box for the species
        draw.rectangle([legend_x, legend_y, legend_x + legend_box_size, legend_y + legend_box_size], fill=color)
        # Add the species name next to the box
        draw.text((legend_x + 30, legend_y), species, fill=(0, 0, 0), font=font)  # Black text for the legend
        legend_y += legend_row_step  # Move to the next line for the next species

    # Save the image with results
    output_path = os.path.join(output_folder, f"result_{image_name}")
    img.save(output_path)

    print(f"Results saved to {output_path} and CSV file updated at {csv_path}")



## UI

In [None]:
# Global flag to control the application loop.
# Set to True by the Stop button handler and checked between pipeline steps.
stop_flag = False

# Class map for species: class name -> integer class index.
# Predicted class indices are looked up in reverse to recover the name.
class_map = {
    'LAGU-Breeding': 0,
    'BRPE-Breeding': 1,
    'BRPE-Chick': 2,
    'Flying': 3,
    'ROYT': 4,
    'CATE': 5,
    'SATE': 6,
    'Tern Spp.': 7,
    'White Wader': 8,
    'REEG': 9,
    'TRHE': 10,
    'GBHE': 11,
    'BCNH': 12,
    'ROSP': 13,
    'Other Spp.': 14,
}

# Define a color map for species: class name -> RGB tuple used for boxes and the legend.
# Keys must match the keys of class_map; names missing here are drawn in white.
species_color_map = {
    'LAGU-Breeding': (255, 0, 0), # Red
    'BRPE-Breeding': (139, 69, 19), # Brown
    'BRPE-Chick': (255, 105, 180), # Pink
    'Flying': (0, 255, 0),  # Green
    'ROYT': (255, 255, 0),  # Yellow
    'CATE': (128, 0, 128),  # Purple
    'SATE': (0, 255, 255),  # Cyan
    'Tern Spp.': (255, 165, 0), # Orange
    'White Wader': (255, 255, 255),# White
    'REEG': (0, 128, 0),  # Dark Green
    'TRHE': (128, 128, 0),  # Olive
    'GBHE': (0, 0, 255),  # Blue
    'BCNH': (0, 0, 0),  # Black
    'ROSP': (123, 104, 238),  # Medium Slate Blue
    'Other Spp.': (210, 180, 140),  # Tan
}

# File and Folder Choosers (directories only, all starting at /content)
# Folder holding the input images to process.
img_folder_chooser = FileChooser("/content")
img_folder_chooser.title = '<b>Select Image Folder:</b>'
img_folder_chooser.show_only_dirs = True

# Folder the image tiles are written to and read back from for inference.
tile_folder_chooser = FileChooser("/content")
tile_folder_chooser.title = '<b>Select Tile Folder:</b>'
tile_folder_chooser.show_only_dirs = True

# Folder receiving the annotated result images and the results CSV.
output_folder_chooser = FileChooser("/content")
output_folder_chooser.title = '<b>Select Output Folder:</b>'
output_folder_chooser.show_only_dirs = True

# Model selector; only "YOLO" is offered. Its lower-cased value is passed to
# run_inference as the framework name.
model_weights_widget = widgets.Dropdown(
    options=["YOLO"],
    value="YOLO",
    description="Model:"
)

run_button = widgets.Button(
    description="Run",
    button_style="success"
)

stop_button = widgets.Button(
    description="Stop",
    button_style="danger"
)

# Output area that run_application clears and prints its progress into.
output = widgets.Output()

# Stop function to set the flag
def stop_application(b):
    """
    Click handler for the Stop button: raises the global stop flag.

    The confirmation message is printed inside the `output` widget, like every other
    message of the application, so it appears in the same place in all front ends.

    NOTE(review): run_application runs synchronously inside the Run button's click
    handler, so this handler presumably cannot execute until that call returns.
    Confirm whether Stop can actually interrupt an in-progress run in your front end.

    Parameters:
    - b: The button instance passed by ipywidgets (unused).
    """
    global stop_flag
    stop_flag = True
    with output:
        print("Application stopped by user.")

# Bind the stop button: clicking it calls stop_application, which sets stop_flag
stop_button.on_click(stop_application)

# Run function
def run_application(b):
    """
    Click handler for the Run button: runs the whole detection pipeline.

    Steps: convert the input images to .jpg, tile them, run inference on the tiles,
    map detections back to full-image coordinates with NMS, then save annotated
    images, the CSV and the per-species counts.

    Detections are kept separate per source image: only the tiles cut from an image
    contribute to that image's boxes, CSV rows and counts. Only the images converted
    in this run are processed, so leftovers from earlier runs in the converted or
    tile folders are ignored (unless they carry the same image name).

    Parameters:
    - b: The button instance passed by ipywidgets (unused).
    """
    global stop_flag
    stop_flag = False  # Reset the flag at the start of each run

    # Pipeline thresholds, named in one place so they are easy to find and tune
    conf_threshold = 0.25   # Minimum confidence for a detection on a tile
    iou_threshold = 0.6     # IoU above which overlapping boxes are merged by NMS
    score_threshold = 0.2   # Minimum score for a box to be drawn, saved and counted

    with output:
        clear_output()
        print("Starting application...")

        # Get selected paths
        img_folder = img_folder_chooser.selected_path
        tile_folder = tile_folder_chooser.selected_path
        output_folder = output_folder_chooser.selected_path
        model_type = model_weights_widget.value

        if not img_folder or not tile_folder or not output_folder:
            print("Please select all required folders before running the application.")
            return

        # Ensure folders exist
        os.makedirs(tile_folder, exist_ok=True)
        os.makedirs(output_folder, exist_ok=True)
        print("Converting images to .jpg format...")
        converted_folder = os.path.join("/content", "converted_images")
        converted_images = convert_images_to_jpg(img_folder, converted_folder)
        if not converted_images:
            print("No valid images found for conversion. Please check your input folder.")
            return

        # Work only on what was converted in this run (order kept, duplicates dropped)
        image_paths = list(dict.fromkeys(converted_images))

        # Step 1: Tile
        print("Tiling images...")
        for image_path in tqdm(image_paths, desc="Tiling Images"):
            if stop_flag:
                print("Application stopped during tiling.")
                return
            tile_img(image_path, tile_folder)

        # Step 2: Inference, grouped by the image each tile was cut from
        print("Running inference...")
        tile_filenames = sorted(os.listdir(tile_folder))
        tiles_by_image = {}
        for image_path in image_paths:
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            # Tiles are named '<base>_tile_<left>_<top>.png'
            tile_pattern = re.compile(rf"{re.escape(base_name)}_tile_-?\d+_-?\d+\.png")
            tiles_by_image[image_path] = [
                name for name in tile_filenames if tile_pattern.fullmatch(name)
            ]

        framework = model_type.lower()
        model_weights = model_weights_codetr if model_type == "CO-DETR" else model_weights_yolo
        total_tiles = sum(len(names) for names in tiles_by_image.values())

        tile_results_by_image = {}
        with tqdm(total=total_tiles, desc="Inference on Tiles") as progress:
            for image_path, names in tiles_by_image.items():
                image_tile_results = []
                for tile_filename in names:
                    if stop_flag:
                        print("Application stopped during inference.")
                        return

                    tile_results = run_inference(
                        framework=framework,
                        model_weights=model_weights,
                        image_path=os.path.join(tile_folder, tile_filename),
                        conf_threshold=conf_threshold
                    )
                    image_tile_results.append((tile_results, tile_filename))
                    progress.update(1)
                tile_results_by_image[image_path] = image_tile_results

        # Step 3: Adjust bounding boxes and apply NMS, one image at a time
        print("Adjusting bounding boxes...")
        detections_by_image = {}
        for image_path in image_paths:
            if stop_flag:
                print("Application stopped during bounding box adjustment.")
                return
            detections_by_image[image_path] = adjust_bbox(
                tile_results_by_image[image_path], iou_threshold=iou_threshold
            )

        # Step 4: Save results and count species
        print("Saving results and counting species...")
        index_to_species = {}
        for species_name, class_index in class_map.items():
            index_to_species.setdefault(class_index, species_name)

        species_counts = Counter()
        for image_path in image_paths:
            if stop_flag:
                print("Application stopped during result saving.")
                return

            final_bboxes, final_scores, final_labels = detections_by_image[image_path]
            save_final_results(
                class_map, image_path, final_bboxes, final_scores, final_labels, output_folder,
                score_threshold=score_threshold
            )

            # Count the same detections that were drawn and written to the CSV
            for label, score in zip(final_labels, final_scores):
                if score >= score_threshold:
                    label_idx = int(label)
                    species_counts[index_to_species.get(label_idx, f"class_{label_idx}")] += 1

        # Print species counts
        print("\nFinal Counts by Species:")
        for species, count in species_counts.items():
            print(f"{species}: {count}")

        print("\nApplication completed!")

# Bind the run button: clicking it calls run_application
run_button.on_click(run_application)

# Display the UI with the stop button.
# Widgets are stacked vertically; the output area at the bottom shows progress messages.
display(
    widgets.VBox([
        img_folder_chooser,
        tile_folder_chooser,
        output_folder_chooser,
        model_weights_widget,
        run_button,
        stop_button,
        output
    ])
)

VBox(children=(FileChooser(path='/content', filename='', title='<b>Select Image Folder:</b>', show_hidden=Fals…

## Zip the result

In [None]:
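# Zip the output folder into results.zip (this cell only creates the archive; download it from the file browser afterwards)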
!zip -r results.zip /content/3 # replace with the name of your output folder


updating: content/3/ (stored 0%)
updating: content/3/results.csv (deflated 75%)
updating: content/3/result_CEPI_Yearly_2024-1-1.jpg (deflated 0%)
  adding: content/3/result_DJI_20210517104824_0042.jpg (deflated 0%)


## Clear folders for reuse

In [None]:
import shutil

def empty_folder(folder_path):
    """
    Removes everything inside a folder while keeping the folder itself.

    Files and symbolic links are unlinked, sub-directories are removed recursively.
    A failure on one entry is reported and the remaining entries are still processed.

    Parameters:
    - folder_path (str): Path to the folder to be emptied.

    Returns:
    - None. A message is printed when the folder does not exist or once it was processed.
    """
    if os.path.exists(folder_path):
        for entry_name in os.listdir(folder_path):
            item_path = os.path.join(folder_path, entry_name)
            try:
                unlinkable = os.path.isfile(item_path) or os.path.islink(item_path)
                if unlinkable:
                    # Plain files and symbolic links
                    os.unlink(item_path)
                elif os.path.isdir(item_path):
                    # Real directories, including their content
                    shutil.rmtree(item_path)
            except Exception as e:
                print(f"Failed to delete {item_path}: {e}")

        print(f"The folder '{folder_path}' has been emptied.")
    else:
        print(f"The folder '{folder_path}' does not exist.")

# Example usage
# folder_to_empty = "/content/1"  # Replace with your folder path
# empty_folder(folder_to_empty)
# Empty each working folder in turn so the notebook can be re-run from a clean state.
for folder_to_empty in ("/content/2", "/content/3"):  # Replace with your folder paths
    empty_folder(folder_to_empty)

The folder '/content/2' has been emptied.
The folder '/content/3' has been emptied.
