<a href="https://colab.research.google.com/github/dovkess/FGCVBreedDetection/blob/main/YOLO_funcs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import cell. Run this cell once up front to avoid import problems later on.
!pip install tensorflow
from google.colab import drive
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import VGG16, InceptionResNetV2, InceptionV3, NASNetLarge#, PNASNetLarge
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, BatchNormalization, GlobalAveragePooling2D
from keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.preprocessing import image
import time
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pickle
import numpy as np
from tensorflow.keras.models import load_model
import os
import glob
import random
import math
import matplotlib.pyplot as plt
from PIL import Image
%pip install lime
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries

!pip install ultralytics
from ultralytics import YOLO
import io
import base64

# Mount Google Drive. Only needed if the images are stored in Drive.
drive.mount('/content/drive')


In [2]:
def get_model(model_version=8, model_size='x'):
    '''
    Build a YOLO model from a weights file name.

    The file name is assembled as ``yolov<model_version><model_size>.pt``,
    so the defaults select ``yolov8x.pt``.

    Args:
        model_version: Version inserted into the weights file name.
        model_size: Size suffix inserted into the weights file name.

    Returns:
        The YOLO object created from that weights file name.
    '''
    weights_name = f'yolov{model_version}{model_size}.pt'
    return YOLO(weights_name)

In [3]:
def count_dogs(model, image_path, dir=False, g=False, batch_size=1, device='cuda'):
    '''
    Count dogs, cats and other animals detected by YOLO.

    Args:
        model: The YOLO model to use for detection.
        image_path (str): Path to a single image, or to a directory when
            ``dir`` or ``g`` is set.
        dir (bool): If True, ``image_path`` is passed to the model as-is
            instead of being opened as a single image. (The name shadows the
            builtin; it is kept for backward compatibility.)
        g (bool): If True, ``'<image_path>/*'`` is passed to the model. Takes
            precedence over ``dir``.
        batch_size (int): Batch size forwarded to the model.
        device (str): Device forwarded to the model. Defaults to 'cuda'.

    Returns:
        list (int): Number of dogs (class 16) per result.
        list (int): Number of cats (class 15) per result.
        list (int): Number of other animals (classes 17-23) per result.
        list (str): The ``path`` attribute of each result.
    '''
    dog_id, cat_id = 16, 15
    other_animal_ids = set(range(17, 24))

    # Only open the file when it really is a single image: with g (or dir) set
    # the path may be a directory, and opening it would fail.
    if g:
        source = '{}/*'.format(image_path)
    elif dir:
        source = image_path
    else:
        source = Image.open(image_path).convert("RGB")

    # Perform detection
    results = model(source, device=device, batch=batch_size)

    dogs, cats, animals, paths = [], [], [], []
    for result in results:
        # tolist() yields plain Python numbers, so every count is a real int
        # (the class ids come back as floats; 16.0 == 16 still matches).
        class_ids = result.boxes.cls.tolist()
        dogs.append(class_ids.count(dog_id))
        cats.append(class_ids.count(cat_id))
        animals.append(sum(1 for class_id in class_ids if class_id in other_animal_ids))
        paths.append(result.path)
    return dogs, cats, animals, paths

In [4]:
def detect_dogs_and_draw_boxes(model, image_path):
    """
    Find the highest-confidence dog detection in an image.

    Despite the name, nothing is drawn: the function only runs detection and
    returns the box coordinates.

    Args:
        model: The YOLO model to use for detection.
        image_path (str): The path to the image file.

    Returns:
        torch.Tensor or None: The bounding box (x1, y1, x2, y2) of the detected dog
                              with the highest confidence. Returns None if no dogs are detected.
    """
    dog_class_id = 16.

    # Read the image; convert() returns a loaded copy, so the file can be closed.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    # Perform detection (single image -> single result)
    boxes = model(img)[0].boxes
    dog_mask = boxes.cls == dog_class_id
    if not dog_mask.any():
        return None

    # Zero out the confidence of non-dog detections so argmax picks a dog.
    dog_confs = boxes.conf * dog_mask
    return boxes.xyxy[dog_confs.argmax()]

In [5]:
def get_bounding_box_dogs_and_cats(model, image_path, only_animals=True, good_class=(15, 16)):
    """
    Reads an image and returns the union bounding box of the "good" animals.

    Since YOLO is not always right about classification, the selection falls
    back in stages: boxes of ``good_class`` first; if there are none, boxes of
    any animal class (14-23); if there are none either, every detected box.

    Args:
        model: The YOLO model to use for detection.
        image_path (str): The path to the image file.
        only_animals (bool): Currently unused; kept for backward compatibility.
        good_class (list or tuple): Class ids to consider as good.

    Returns:
        The union bounding box [x1, y1, x2, y2] of the selected detections, or
        None when nothing was detected or when 7 or more fallback animals
        were counted.
        Number of good objects detected in the image.
        Number of animals counted in the fallback stage (0 when good objects
        were found).
    """
    # COCO ids 14 (bird) .. 23 (giraffe); 24 is "backpack" and must not count.
    animal_class_ids = set(range(14, 24))
    max_animals = 7

    # Read the image; convert() returns a loaded copy, so the file can be closed.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    # Perform detection
    boxes = model(img)[0].boxes
    class_ids = boxes.cls.tolist()

    good_indices = [i for i, class_id in enumerate(class_ids) if class_id in good_class]
    animal_indices = []
    if good_indices:
        selected = good_indices
    else:
        animal_indices = [i for i, class_id in enumerate(class_ids) if class_id in animal_class_ids]
        # Last resort: no animal at all, so take every detection.
        selected = animal_indices or range(len(class_ids))

    good_counter = len(good_indices)
    animals_counter = len(animal_indices)

    # Nothing detected: return None instead of the infinite sentinel box, which
    # downstream cropping cannot handle.
    if not selected or animals_counter >= max_animals:
        return None, good_counter, animals_counter

    bounding_box = [math.inf, math.inf, -math.inf, -math.inf]
    for i in selected:
        x1, y1, x2, y2 = boxes.xyxy[i]
        bounding_box = [min(bounding_box[0], x1), min(bounding_box[1], y1),
                        max(bounding_box[2], x2), max(bounding_box[3], y2)]
    return bounding_box, good_counter, animals_counter


In [6]:
def get_classes(img, model):
    """
    Run the model on an image and report which classes were detected.

    Args:
        img: The image input handed directly to the model.
        model: The YOLO model to use for detection.

    Returns:
        list of class names, one per detection in the first result.
        the id -> name mapping (``names``) of the first result.
    """
    first_result = model(img)[0]
    id_to_name = first_result.names
    detected_names = [id_to_name[class_id] for class_id in first_result.boxes.cls.tolist()]
    return detected_names, id_to_name


In [7]:
def crop_to_square_around_dog(original_img, box, dbg=False):
    '''
    Cut a square region centred on a bounding box out of an image.

    The square's side is the longer side of the box. Any part of the square
    that falls outside the image is filled with black (zero padding).

    Args:
        original_img (PIL.Image): The original image.
        box (list): The bounding box as (x1, y1, x2, y2).
        dbg (bool): Whether to print debug information.

    Returns:
        PIL.Image: The cropped and padded image, or None if cropping or
        padding raised an exception.
    '''
    left, top, right, bottom = box
    if dbg:
        print(f'boxes: {box}')

    # Square centred on the box, as large as the box's longer side.
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    side = max(right - left, bottom - top)
    half_side = side / 2
    sq_left, sq_top = center_x - half_side, center_y - half_side
    sq_right, sq_bottom = center_x + half_side, center_y + half_side

    img_width, img_height = original_img.size

    # How far the square sticks out of the image on each edge.
    pad_left, pad_top = max(0, -sq_left), max(0, -sq_top)
    pad_right = max(0, sq_right - img_width)
    pad_bottom = max(0, sq_bottom - img_height)

    # The part of the square that actually lies inside the image.
    in_left, in_top = max(0, sq_left), max(0, sq_top)
    in_right, in_bottom = min(img_width, sq_right), min(img_height, sq_bottom)
    crop_region = tuple(int(v) for v in (in_left, in_top, in_right, in_bottom))

    try:
        inside_part = original_img.crop(crop_region)

        # Black square canvas big enough for the crop plus its padding.
        canvas_width = int(in_right - in_left + pad_left + pad_right)
        canvas_height = int(in_bottom - in_top + pad_top + pad_bottom)
        canvas_side = max(canvas_width, canvas_height)
        canvas = Image.new('RGB', (canvas_side, canvas_side), (0, 0, 0))

        # Offset the crop by the left/top padding so it keeps its place in the square.
        canvas.paste(inside_part, (int(pad_left), int(pad_top)))
        return canvas
    except Exception as err:
        print(f"Error during cropping or padding: {err}")
        return None


In [8]:
def crop_and_save(model, img_path, new_path, dbg=False, crop_func=detect_dogs_and_draw_boxes):
    '''
    Detect a bounding box in an image, crop a square around it and save it.

    Args:
        model (YOLO): The YOLO model to use for detection.
        img_path (str): The path to the original image.
        new_path (str): The path to save the new image.
        dbg (bool): If True, print debug information and display the original
            and cropped images instead of saving.
        crop_func (function): Called as ``crop_func(model, img_path)``. May
            return either a bare box / None, or a tuple whose first element
            is the box / None (e.g. ``(box, good_count, animal_count)``).

    Returns:
        bool: Whether the cropping was successful.
    '''
    original_img = Image.open(img_path).convert("RGB")

    # The detection helpers disagree on their return shape (bare box vs. a
    # (box, counts...) tuple); accept both instead of always unpacking three.
    detection = crop_func(model, img_path)
    dog_results = detection[0] if isinstance(detection, tuple) else detection

    if dog_results is None:
        print("No valid dog bounding box result provided: {}".format(img_path))
        return False

    cropped_dog_image = crop_to_square_around_dog(original_img, dog_results, dbg=dbg)
    if cropped_dog_image is None:
        print("Failed to crop and pad the image: {}".format(img_path))
        return False

    # For debugging, display the images and skip writing to disk.
    if dbg:
        display(original_img)
        display(cropped_dog_image)
        return True

    cropped_dog_image.save(new_path)
    return True


In [9]:
def show_image(img_path):
    '''
    Open an image file, convert it to RGB and display it.

    Args:
        img_path (str): The path to the image.
    '''
    rgb_image = Image.open(img_path).convert("RGB")
    display(rgb_image)