In [None]:
import os
import pickle
from typing import List, Dict

import numpy as np
import torch
from tqdm.notebook import tqdm

from rcnn_utils import predict, detach_pred, get_object_detection_model

In [None]:
# Pick the compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)

# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (this call does nothing if CUDA has not been initialized).
torch.cuda.empty_cache()

In [None]:
def get_image_predictions(sub_images: List, rcnn_model, device: torch.device = device) -> List:
    """Generate count predictions for every sub-image of one image with the given RCNN model.

    Args:
        sub_images (List): List of sub-images predicted to contain a seal. Each entry is
            indexed with ``[0]`` to obtain the input passed to ``predict``
            (NOTE(review): presumably ``(image_tensor, ...)`` pairs - confirm against the pickles).
        rcnn_model (PyTorch model): RCNN model that will make count predictions.
        device (torch.device | str, optional): Device to run inference on, e.g. ``"cuda"``
            or ``"cpu"``. Defaults to the notebook-level ``device`` chosen at definition time.

    Returns:
        List: One detached prediction per sub-image, in the same order as ``sub_images``.
            An empty input yields an empty list.
    """
    # detach_pred presumably detaches the prediction tensors from the device/graph so the
    # results can be pickled and do not hold GPU memory - see rcnn_utils to confirm.
    return [
        detach_pred(predict(rcnn_model, sub_image[0], device))
        for sub_image in sub_images
    ]


def generate_predictions(rcnn_model, seal_sub_images: Dict, write_name: str, write_path: str, device: torch.device = device) -> None:
    """Generate RCNN count predictions on sub-images and pickle them to disk.

    The result is written to ``<write_path>/<write_name>_predictions.pkl`` as a dictionary
    mapping each image name to the list returned by ``get_image_predictions``.

    Args:
        rcnn_model (PyTorch model): PyTorch model that will make count predictions.
        seal_sub_images (Dict): Dictionary containing sub-images
            (key: image name, value: list of seal sub-images).
        write_name (str): Prefix of the output file name (also used as the progress-bar label).
        write_path (str): Directory where the predictions file is stored; created if missing.
        device (torch.device | str, optional): Device to run inference on, e.g. ``"cuda"``
            or ``"cpu"``. Defaults to the notebook-level ``device`` chosen at definition time.
    """
    # Create the output directory before the (slow) prediction loop so a missing folder
    # fails immediately instead of after every prediction has been computed.
    os.makedirs(write_path, exist_ok=True)
    output_file = os.path.join(write_path, f"{write_name}_predictions.pkl")

    # Generate RCNN predictions per image and collect them in a dictionary.
    rcnn_predictions_per_image = {
        image_name: get_image_predictions(sub_images, rcnn_model, device=device)
        for image_name, sub_images in tqdm(seal_sub_images.items(), desc=write_name)
    }

    # Save dictionary containing RCNN predictions.
    with open(output_file, "wb") as f:
        pickle.dump(rcnn_predictions_per_image, f)

In [None]:
# Load Models
# Directory holding the trained RCNN weights (single source of truth for all three paths).
MODELS_DIR = "../../Models"

# NOTE(review): these models are only used for inference below; if get_object_detection_model
# or rcnn_utils.predict does not already put them in eval mode, call .eval() here - confirm.
rcnn_unfrozen = get_object_detection_model(
    version=1,
    path=f"{MODELS_DIR}/rcnn_resnet_v1_unfrozen_transformations_step_50_with_backbone_weights_50",
).to(device)
rcnn_frozen_v1 = get_object_detection_model(version=1, path=f"{MODELS_DIR}/rcnn_trial1_50").to(device)
rcnn_frozen_v2 = get_object_detection_model(version=2, path=f"{MODELS_DIR}/rcnn_trial3_50").to(device)

In [None]:
def _load_pickle(pickle_path):
    """Read and return the single object stored in the pickle file at ``pickle_path``."""
    # SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
    with open(pickle_path, "rb") as fp:
        return pickle.load(fp)


# Sub-images produced by the seal detector, one pickle per data split.
sub_images_path = "../../../seal_detector/Generated Data"

training_sub_images_path = f"{sub_images_path}/training_seals_pytorch.pkl"
validation_sub_image_path = f"{sub_images_path}/validation_seals_pytorch.pkl"
testing_sub_image_path = f"{sub_images_path}/testing_seals_pytorch.pkl"

training_sub_images = _load_pickle(training_sub_images_path)
validation_sub_images = _load_pickle(validation_sub_image_path)
testing_sub_images = _load_pickle(testing_sub_image_path)

In [None]:
# Parallel lists: entry i of data_set_names labels the split stored at entry i of data_sets
# and is used in the output file names.
data_set_names = ["training", "validation", "testing"]
data_sets = [
    training_sub_images,
    validation_sub_images,
    testing_sub_images,
]

In [None]:
write_path = "../../Generated Data"

# Every model is run over every split; the key becomes the output file-name prefix.
models_by_prefix = {
    "unfrozen": rcnn_unfrozen,
    "frozen_v1": rcnn_frozen_v1,
    "frozen_v2": rcnn_frozen_v2,
}

# data_sets and data_set_names are parallel lists; fail loudly instead of misaligning.
assert len(data_sets) == len(data_set_names), "data_sets and data_set_names must have the same length"

for data_set, data_set_name in zip(data_sets, data_set_names):
    for model_prefix, rcnn_model in models_by_prefix.items():
        generate_predictions(rcnn_model, data_set, f"{model_prefix}_{data_set_name}", write_path)