In [1]:
from PIL import Image
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Target (width, height) that images are resized to before being fed to the model.
input_shape = (128, 128)

In [3]:
# Class names indexed by the thresholded model output: index 0 is chosen for
# predictions below 0.5, index 1 otherwise.
categories = ['Cat', 'Dog']
# Load the saved Keras model; the path is relative to the notebook's working directory.
model = load_model('../ai_models/cat_dog_squared_10.keras')

In [4]:
# Print the layer-by-layer summary of the loaded model.
model.summary()

In [5]:
def make_square(image):
    """Return a square RGB version of ``image``, padded with black and centered.

    The image is first normalised to RGB so that the result always has three
    channels. Previously only non-square inputs ended up as RGB (because they
    were pasted onto an RGB canvas), while square RGBA / grayscale / palette
    inputs were returned in their original mode.

    Args:
        image: A PIL image of any mode.

    Returns:
        An RGB PIL image whose side equals ``max(width, height)``. An input
        that is already square and RGB is returned as-is (same object).
    """
    # Normalise the mode up front so both code paths return the same mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size

    # Already square: nothing to pad
    if width == height:
        return image

    # The square side is the longer of the two dimensions
    new_size = max(width, height)

    # Black square canvas
    new_image = Image.new("RGB", (new_size, new_size), color=(0, 0, 0))

    # Top-left corner that centers the original on the canvas
    paste_position = ((new_size - width) // 2, (new_size - height) // 2)
    new_image.paste(image, paste_position)

    return new_image

In [6]:
file_path = '../samples/ex_8.png'

# Open inside a context manager so the file handle is released, and force RGB
# so the array always has 3 channels regardless of the file's mode
# (PNG files are often RGBA, palette or grayscale).
with Image.open(file_path) as img_file:
    img = img_file.convert('RGB')

# Pad to a square, then resize to the model's input size
new_image = make_square(img)
img_resized = new_image.resize(input_shape, Image.Resampling.LANCZOS)

# Add a leading batch axis and scale pixel values from [0, 255] to [0, 1]
img_result = np.expand_dims(np.array(img_resized), axis=0)
img_result = img_result / 255.0

print(f'shape: {img_result.shape}')

shape: (1, 128, 128, 3)


In [7]:
result = model.predict(img_result, verbose=1)

# predict returns an array (a single value for this one-image batch);
# compare a plain scalar against the threshold rather than the array itself.
score = float(np.ravel(result)[0])
label = categories[0] if score < 0.5 else categories[1]

print(f'result: {result}')
print(f'label: {label}')

factor_scale = 3

# Derive the preview size from the actual input instead of hard-coding 128
src_height, src_width = img_result.shape[1:3]
WIDTH = int(src_width * factor_scale)
HEIGHT = int(src_height * factor_scale)

# Build the preview in its own variable so img_result is left untouched and
# this cell can be re-run without feeding a resized image back to the model.
# The array comes from PIL (RGB) while OpenCV displays BGR, so the channel
# axis is reversed; ascontiguousarray gives OpenCV a plain memory layout.
preview = np.ascontiguousarray(img_result[0][..., ::-1])
preview = cv2.resize(preview, (WIDTH, HEIGHT))

text = f"Label: {label}{' '*4}Result: {result}"
position = (30, 30)
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.8
color = (0, 0, 0)
thickness = 2

cv2.putText(preview, text, position, font, font_scale, color, thickness)
# Blocks until a key is pressed in the OpenCV window
cv2.imshow('', preview)
cv2.waitKey(0)
cv2.destroyAllWindows()

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 106ms/step
result: [[0.83611274]]
label: Dog
