In [None]:
import numpy as np
import os
import pickle 
import torch
import torch.nn as nn

from torchvision import transforms
from tqdm.notebook import tqdm
from torchvision.models import efficientnet_v2_m
from typing import List
from skimage import io

In [None]:
# Pick the compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)

# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (does nothing if CUDA has not been initialised).
torch.cuda.empty_cache()

In [None]:
def get_file_names(path:str) -> List[str]:
    """Get the names of the files in a directory, with their extension removed.

    Only the final extension is stripped, so "site.2019.JPG" -> "site.2019".
    Hidden entries (names starting with ".", e.g. ".DS_Store") are skipped,
    and names that differ only by extension are returned once.

    Args:
        path (str): Path to directory containing images

    Returns:
        List[str]: Sorted list of unique file names without their extension.
    """
    # splitext (rather than split(".")[0]) keeps dots that are part of the name.
    # Sorting gives a deterministic order; iterating a set does not.
    return sorted({
        os.path.splitext(entry)[0]
        for entry in os.listdir(path)
        if not entry.startswith(".")
    })


def _tile_starts(length:int, size:int) -> List[int]:
    """Start offsets of `size`-long windows covering [0, length); the last window is flush with the end."""
    if length <= 0:
        return []
    if length <= size:
        # Dimension shorter than one tile: a single (shorter) crop from 0,
        # instead of a negative start index that would wrap around.
        return [0]
    starts = list(range(0, length - size, size))
    # Final tile is shifted back so it ends exactly at `length` (may overlap the previous one).
    starts.append(length - size)
    return starts


def get_seal_sub_images(img_path:str, model, size:int=150, device:str=device) -> List:
    """Using a CNN model, get the sub-images of an image that are predicted to contain a seal.

    The image is cut into `size` x `size` tiles, scanned row by row from left
    to right. Along each axis the last tile is shifted back so it ends at the
    image border. Each tile is classified independently.

    Args:
        img_path (str): Path to the image.
        model (Pytorch Model): CNN classifier; class index 0 is treated as
            "no seal" and any other predicted class as "seal".
        size (int, optional): Sub-image side length in pixels. Defaults to 150.
        device (str, optional): Device to run the model on, "cuda" or "cpu".
            Defaults to the notebook-level `device`.

    Returns:
        List: Sub-images (slices of the loaded image array) predicted to contain a seal.
    """
    img = io.imread(img_path, plugin="matplotlib")
    y_len, x_len = img.shape[0], img.shape[1]

    model.to(device)
    transform = transforms.ToTensor()

    seal_sub_images = []
    # Inference only: no_grad avoids building an autograd graph for every crop.
    with torch.no_grad():
        for y1 in _tile_starts(y_len, size):
            for x1 in _tile_starts(x_len, size):
                cropped_original = img[y1:y1 + size, x1:x1 + size]
                cropped = transform(cropped_original).unsqueeze(0).to(device)

                predicted_class = int(model(cropped).argmax(dim=1)[0])

                # Class 0 = no seal; for the 2-class head used here, class 1 = seal.
                if predicted_class != 0:
                    seal_sub_images.append(cropped_original)

    return seal_sub_images

def get_predicted_sub_images(cnn, training_path:str, valid_path:str, test_path:str, write_path:str=None, version:str="", device:str=device, size:int=150, extension:str=".JPG"):
    """Determine (and optionally save) the sub-images predicted to contain seals
    for the training, validation, and testing datasets.

    If a write path is given, one pickle file per dataset is saved using the
    naming convention

       {write_path}/{training_seals, validation_seals, testing_seals}_{version}.pkl

    Args:
        cnn (Pytorch Model): CNN model used for seal prediction.
        training_path (str): Directory containing the training images.
        valid_path (str): Directory containing the validation images.
        test_path (str): Directory containing the test images.
        write_path (str, optional): Directory where the pickles are saved. It is
            created if missing. Defaults to None, in which case nothing is written.
        version (str, optional): Suffix added to the saved file names. Defaults to "".
        device (str, optional): Device to run the model on, "cuda" or "cpu".
            Defaults to the notebook-level `device`.
        size (int, optional): Side length of the square sub-images. Defaults to 150.
        extension (str, optional): Extension appended to each name returned by
            `get_file_names` to build the image path. Defaults to ".JPG".

    Returns:
        Tuple[dict, dict, dict]: (training, validation, testing) dictionaries,
        each mapping an image name (without extension) to the list of
        sub-images predicted to contain a seal.
    """
    # (output file prefix, image directory, progress-bar label) for each dataset.
    datasets = [
        ("training_seals", training_path, "Determining Images with Seals on Training Data"),
        ("validation_seals", valid_path, "Determining Images with Seals on Validation Data"),
        ("testing_seals", test_path, "Determining Images with Seals on Testing Data"),
    ]

    # List all directories first so a bad path fails before any slow inference runs.
    names_per_dataset = [get_file_names(dir_path) for _, dir_path, _ in datasets]

    results = []
    for (_, dir_path, desc), names in zip(datasets, names_per_dataset):
        seal_sub_images = {}
        for name in tqdm(names, desc=desc):
            # os.path.join works whether or not dir_path ends with a separator.
            img_path = os.path.join(dir_path, name + extension)
            seal_sub_images[name] = get_seal_sub_images(img_path, cnn, size, device)
        results.append(seal_sub_images)

    # Save Files
    if write_path is not None:
        os.makedirs(write_path, exist_ok=True)
        for (prefix, _, _), seal_sub_images in zip(datasets, results):
            with open(os.path.join(write_path, f"{prefix}_{version}.pkl"), "wb") as f:
                pickle.dump(seal_sub_images, f)

    seal_sub_images_training, seal_sub_images_validation, seal_sub_images_testing = results
    return seal_sub_images_training, seal_sub_images_validation, seal_sub_images_testing

In [None]:
# Load the EfficientNetV2-M architecture (no pretrained weights requested here;
# the trained weights are loaded from the checkpoint below).
MODEL_WEIGHTS_PATH = "../Models/Pytorch/ImageClassifierPytorch9"
efficientnet = efficientnet_v2_m()

# Change num of classes to 2. Read in_features from the existing layer
# instead of hard-coding it, so the head always matches the backbone.
efficientnet.classifier[1] = nn.Linear(
    in_features=efficientnet.classifier[1].in_features, out_features=2
)

# Load trained weights. map_location lets a checkpoint saved on a GPU load on
# a CPU-only machine. NOTE(review): on torch >= 1.13, consider weights_only=True
# if the checkpoint may come from an untrusted source.
state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=device)
efficientnet.load_state_dict(state_dict)

# Inference mode (disables dropout, uses running BatchNorm statistics), then move to the device.
efficientnet.eval()
efficientnet = efficientnet.to(device)

In [None]:
# Output directory for the pickled seal sub-images, and the suffix added to their file names.
version = "pytorch"
seal_path = "../Generated Data"

# Image directory for each dataset split.
training_path = "../../Training, Val, and Test Images/Training Images/"
valid_path = "../../Training, Val, and Test Images/Validation Images/"
test_path = "../../Training, Val, and Test Images/Test Images/"

# Classify every tile of every image in each split and save the results under seal_path.
train_seals, val_seals, test_seals = get_predicted_sub_images(
    efficientnet,
    training_path=training_path,
    valid_path=valid_path,
    test_path=test_path,
    write_path=seal_path,
    version=version,
    device=device,
)