In [21]:
import os
import absl.logging

import PIL.Image
import numpy as np

from tensorflow import keras

from functions.ciou import ciou_loss, ciou_metric
from functions.loading_data import SMALLER_HEIGHT, SMALLER_WIDTH
from functions.drawing import draw_rectangle

# Only let absl log records at ERROR level or above through; presumably this is to
# hide noisy INFO/WARNING output from TensorFlow/Keras (e.g. while loading the model).
absl.logging.set_verbosity(absl.logging.ERROR)

In [15]:
# Saved model to evaluate; its name is also reused further down for the output folder.
best_name = 'roi_detection_iou_26_65'
best_path = os.path.join('best_performing_models', best_name)

# The saved model refers to the CIoU loss and metric by name, so Keras needs the
# actual Python callables to deserialize it.
best_model = keras.models.load_model(
    best_path,
    custom_objects=dict(ciou_metric=ciou_metric, ciou_loss=ciou_loss),
)

In [16]:
def _walk_files(root_path: str):
    """Yield ``(dir_path, file_name)`` for every file under ``root_path`` in sorted order.

    ``os.walk`` lists entries in arbitrary, filesystem-dependent order. Sorting makes
    the result reproducible across machines and keeps ``get_names`` and ``get_paths``
    aligned element-for-element.
    """
    for dir_path, dir_names, file_names in os.walk(root_path):
        # Sorting in place steers the (top-down) walk into subdirectories alphabetically.
        dir_names.sort()
        for file_name in sorted(file_names):
            yield dir_path, file_name


def get_names(root_path: str) -> list[str]:
    """Return, for every file under ``root_path``, the text before its first ``.``."""
    return [file_name.split('.')[0] for _, file_name in _walk_files(root_path)]


def get_paths(path: str) -> list[str]:
    """Return the full path of every file under ``path``, searched recursively."""
    return [os.path.join(dir_path, file_name) for dir_path, file_name in _walk_files(path)]


base_dir = os.path.join('..', 'data', 'images_original_inception_resnet_v2_200x150_splitted')
valid_dir = os.path.join(base_dir, 'validation')

In [22]:
def get_images_array(paths: list[str], scale: float = 1. / 255) -> np.ndarray:
    """Load images from disk and stack them into one float32 batch multiplied by ``scale``.

    Args:
        paths: Image file paths; row ``i`` of the result is the image at ``paths[i]``.
        scale: Multiplier applied to every pixel value. The default ``1/255`` gives
            the same values as ``keras.layers.Rescaling(1./255)`` (cast to float32,
            then multiply).

    Returns:
        A ``float32`` array of shape ``(len(paths), *image_shape)``, or an empty
        ``(0,)`` float32 array when ``paths`` is empty.

    Raises:
        ValueError: if the images do not all share one shape (e.g. mixed sizes or
            colour modes), since they cannot form a single batch.
    """
    if not paths:
        return np.empty((0,), dtype=np.float32)

    # Cast once so the multiply happens in float32, exactly like the Keras layer.
    factor = np.float32(scale)
    rows = []
    for path in paths:
        # `with` closes the file handle once the pixels have been copied out.
        with PIL.Image.open(path) as image:
            rows.append(np.asarray(image).astype(np.float32) * factor)

    return np.stack(rows)


def get_name(path: str) -> str:
    """Return the file-name stem of ``path`` with any remaining dots replaced by ``_``.

    Examples: ``'../data/x/img.001.jpg'`` -> ``'img_001'``, ``'photo.jpg'`` -> ``'photo'``.
    A name without an extension is returned unchanged.
    """
    # splitext removes only the final extension and tolerates names with no dot.
    # Joining a bare string with '_' would interleave '_' between characters instead.
    stem, _extension = os.path.splitext(os.path.basename(path))
    return stem.replace('.', '_')


# Load the validation split as one rescaled batch; X_valid[i] is the image at valid_paths[i].
valid_paths = get_paths(path=valid_dir)
X_valid = get_images_array(paths=valid_paths)

In [23]:
# Keep only the first model output; its rows are read below as box coordinates
# ordered (top, right, bottom, left). Presumably an (N, 4) array — TODO confirm.
ys_valid = best_model.predict(X_valid)[0]

# Scale vertical coordinates (top, bottom) by SMALLER_HEIGHT and horizontal ones
# (right, left) by SMALLER_WIDTH. This assumes the model emits coordinates
# normalised to that image size — TODO confirm.
ys_valid[:, [0, 2]] = ys_valid[:, [0, 2]] * SMALLER_HEIGHT
ys_valid[:, [1, 3]] = ys_valid[:, [1, 3]] * SMALLER_WIDTH



In [26]:
base_test_path = os.path.join('..', 'data', 'tests', best_name)
# The output folder may not exist on a fresh checkout; create it before saving into it.
os.makedirs(base_test_path, exist_ok=True)

# Each prediction row belongs to the image at the same index in valid_paths.
assert len(valid_paths) == len(ys_valid), 'predictions and image paths are misaligned'

for path, coords in zip(valid_paths, ys_valid):
    new_path = os.path.join(base_test_path, f'{get_name(path)}.jpg')
    # Prediction columns are ordered (top, right, bottom, left).
    top, right, bottom, left = coords[:4]

    # Manage the opened file AND the converted copy, so the source file handle
    # is closed too (previously only the in-memory copy was closed).
    with PIL.Image.open(path) as source, source.convert('RGB').convert('L') as img:
        # Presumably draws the (left, top, right, bottom) box on img and writes it
        # to new_path — TODO confirm against functions.drawing.draw_rectangle.
        draw_rectangle(img, (left, top, right, bottom), new_path)