In [None]:
import numpy as np
import os
import pickle
import pandas as pd
import torch.nn as nn
import torch

from skimage import io
from tqdm.notebook import tqdm
from torchvision.models import efficientnet_v2_m
from torchvision import transforms
from typing import List, Dict

In [None]:
# Pick the compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)
# Release any cached GPU memory held by torch's allocator before loading models.
torch.cuda.empty_cache()

In [None]:
# Directory (relative to the current working directory) that receives the centroid pickles.
centriod_directory = "Centroids"

# exist_ok=True makes this cell idempotent without a separate check-then-create step,
# and raises immediately if the name is already taken by a regular file.
os.makedirs(centriod_directory, exist_ok=True)

In [None]:
def _tile_bounds(length:int, step:int) -> List:
    """Split the range [0, length) into windows of width `step`.

    Windows are placed back to back starting at 0. If a strip narrower than `step`
    is left over at the end, one extra window is added that ends exactly at `length`
    (so it overlaps the previous window). When `length` is smaller than `step` the
    single window is clamped to start at 0.

    Args:
        length (int): size of the axis in pixels
        step (int): window width and stride, must be positive

    Returns:
        List: List of (start, stop) index pairs in increasing order
    """
    bounds = []
    start = 0
    while start + step <= length:
        bounds.append((start, start + step))
        start += step
    if start < length:
        # Leftover strip: look at the last `step` pixels instead of a partial window.
        # max() keeps the start from going negative when the axis is shorter than step.
        bounds.append((max(0, length - step), length))
    return bounds


def get_seal_sub_images(img_path:str, water_classifier, model, step:int, device=device) -> List:
    """Goes through an image and determines the center coordinates of step x step sub-images predicted to contain seals

    The image is scanned like reading a book (rows top to bottom, left to right within a row).
    Each sub-image is first reduced to its mean colour and passed to the water classifier;
    only sub-images whose water prediction is below 0.5 are passed to the CNN.
    The model is moved to `device` but its train/eval mode is left as the caller set it.

    Args:
        img_path (str): path to image
        water_classifier (Scikit learn model): water classifier, called via .predict on a (1, 3) array of mean channel values
        model (Pytorch model): CNN predicting whether a seal is present in the sub-image or not
        step (int): side length of the sub-images and distance that the boxes iterate
        device (torch.device, optional): device the model and sub-images are moved to. Defaults to device.

    Raises:
        ValueError: if step is not positive or the image is not a 3-channel (H, W, 3) array

    Returns:
        List: List of (x, y) coordinates of predicted seal sub-images
    """
    if step <= 0:
        # A non-positive step would never advance the scan window
        raise ValueError("step must be a positive integer")

    centroids = []

    img = io.imread(img_path, plugin="matplotlib")
    if img.ndim != 3 or img.shape[2] != 3:
        # The mean-colour feature below assumes exactly three channels
        raise ValueError(f"Expected an image of shape (H, W, 3), got {img.shape}")

    x_len = img.shape[1]
    y_len = img.shape[0]

    model.to(device)
    transform = transforms.Compose([transforms.ToTensor()])

    # Column windows are identical for every row, so compute them once
    x_bounds = _tile_bounds(x_len, step)

    # Iterate through y axis (Up to Down)
    for y1, y2 in _tile_bounds(y_len, step):
        # Iterate through x axis (left to right)
        for x1, x2 in x_bounds:
            # Sub image for models
            sub_image = np.array(img[y1:y2, x1:x2])

            # Mean value per channel, shaped (1, 3): one sample for the water classifier.
            # reshape(-1, 3) works for any window size, not only 150x150.
            mean_colour = sub_image.reshape(-1, 3).mean(axis=0)
            sub_image_water_classifier = np.array([mean_colour])

            # Model predictions
            water_prediction = water_classifier.predict(sub_image_water_classifier)
            if water_prediction < .5:
                # Get subimage for CNN (adds the batch dimension)
                sub_image_cnn = (
                    transform(sub_image)
                    .unsqueeze(0)
                    .to(device)
                    )

                # Inference only: skip building the autograd graph
                with torch.no_grad():
                    logits = model(sub_image_cnn)

                # Get seal prediction (index of the highest scoring class)
                seal_prediction = np.argmax(
                    logits
                        .cpu()
                        .detach()
                        .numpy(),
                    axis = 1
                )[0]

                if seal_prediction > .5:
                    x_center = (x1 + x2) / 2
                    y_center = (y1 + y2) / 2
                    centroids.append((x_center, y_center))

    return centroids


def get_file_names(path:str) -> List[str]:
    """Get name (without extension) of all files located at the specified path

    Files that share a name but differ in extension are reported once.
    Hidden entries (names starting with ".") and sub-directories are skipped,
    since neither can be turned into a valid image name.

    Args:
        path (str): Path to directory containing image files

    Returns:
        List[str]: Sorted list of unique file names with the last extension removed
    """
    names = set()
    for entry in os.listdir(path):
        if entry.startswith(".") or not os.path.isfile(os.path.join(path, entry)):
            continue
        # splitext strips only the final extension, so "a.b.JPG" stays "a.b"
        names.add(os.path.splitext(entry)[0])
    # Sorted so the processing order is the same on every run
    return sorted(names)


def get_predicted_sub_images(cnn, water_classifier, path:str, write_path:str=None, version:str="", step:int=150) -> Dict:
    """For each image file in the specified directory, predicts the location of seals and returns centroids of sub-images

    Every name returned by get_file_names(path) is expected to exist as "<name>.JPG" inside path.

    Args:
        cnn (Pytorch model): CNN model to predict if seals are within the sub-image
        water_classifier (Scikit learn model): Water classifier to discard majority water images
        path (str): Path to directory containing image files
        write_path (str, optional): Path to directory where centroids should be stored. Nothing is written when None. Defaults to None.
        version (str, optional): Additional label for centroid file. Defaults to "".
        step (int, optional): Side length and stride of the scanned sub-images, passed to get_seal_sub_images. Defaults to 150.

    Returns:
        Dict: Maps each file name (without extension) to the list of (x, y) centroids
            returned by get_seal_sub_images for that image
    """
    seal_centroids = {}

    file_names = get_file_names(path)

    # Get Centroids
    for file_name in tqdm(file_names, desc = "Determining Images with Seals"):
        # os.path.join works whether or not path ends with a separator
        img_path = os.path.join(path, f"{file_name}.JPG")
        seal_centroids[file_name] = get_seal_sub_images(img_path, water_classifier, cnn, step)

    # Save file
    if write_path is not None:
        if write_path:
            # Make sure the target directory exists so a long run is not lost at the dump
            os.makedirs(write_path, exist_ok=True)
        with open(os.path.join(write_path, f"seals_centroids_{version}.pkl"), "wb") as f:
            pickle.dump(seal_centroids, f)

    return seal_centroids

In [None]:
# Load CNN
efficientnet = efficientnet_v2_m()
# Replace the final classifier layer with a 2-output head, reusing the existing input width
efficientnet.classifier[1] = nn.Linear(in_features=efficientnet.classifier[1].in_features, out_features=2)
# Forward slashes resolve on every OS; backslashes in a normal string are invalid escape
# sequences and only form a path on Windows.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
efficientnet.load_state_dict(
    torch.load("../../../seal_detector/Models/PyTorch/ImageClassifierPytorch9", map_location=device)
)
# Inference mode (affects layers such as dropout / batch norm)
efficientnet.eval()
efficientnet = efficientnet.to(device)

In [None]:
# Load water classifier
# Forward slashes resolve on every OS; the context manager closes the file even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open("../../../seal_detector/Models/water_classifier/water_classifier", "rb") as f:
    water_classifier = pickle.load(f)

In [None]:
# Directories holding the images of each data split.
# Each path ends with "/" so that a file name can be appended directly.
training_path, validation_path, testing_path = (
    "../../../Training, Val, and Test Images/Training Images/",
    "../../../Training, Val, and Test Images/Validation Images/",
    "../../../Training, Val, and Test Images/Test Images/",
)

In [None]:
# Directory that receives one centroid pickle per data split
write_path = "Centroids"

# Both models are shared by every split
models = (efficientnet, water_classifier)

# Run the detector over each split in turn: training, validation, testing
centroids_training = get_predicted_sub_images(
    *models,
    training_path,
    write_path=write_path,
    version="training",
)
centroids_validation = get_predicted_sub_images(
    *models,
    validation_path,
    write_path=write_path,
    version="validation",
)
centroids_testing = get_predicted_sub_images(
    *models,
    testing_path,
    write_path=write_path,
    version="testing",
)