# Prediction Script
Contains function that takes image and makes prediction from model.

## Imports

In [1]:
from tensorflow.keras.layers import StringLookup
import tensorflow as tf
import numpy as np



## Load model

In [2]:
def load_model(dotkeras_path: str):
    prediction_model = tf.keras.models.load_model(dotkeras_path)
    return prediction_model

## Image preprocess functions

In [6]:
AUTOTUNE = tf.data.AUTOTUNE

batch_size = 1
padding_token = 99
image_width = 128
image_height = 32
max_len = 21

charfile = open('./models/characters.txt', 'r')
characters = charfile.read()
characters = characters.split(' ')

# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters.
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)

def distortion_free_resize(image, img_size):
    w, h = img_size
    image = tf.image.resize_with_pad(image, h, w)
    # documentation had a different thing with 
    # tf.image.resize, find padding diff, then tf.pad
    # found this function instead and I guess it works fine

    # lines below convert from vertical to horizontal, then flips to correct orientation
    image = tf.transpose(image, perm=[1, 0, 2])
    image = tf.image.flip_left_right(image)
    return image

def preprocess_image(image_path, img_size=(image_width, image_height)):
    # plot image as test
    import matplotlib.pyplot as plt
    import numpy as np

    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, 1) # 1: output to grayscale. tensor of uint8 or uint16
    image = distortion_free_resize(image, img_size)
    image = tf.cast(image, tf.float32) / 255.0 # cast to float instead of int
    return image

def process_image_labels(image_path, label):
    # calls above functions, gets preprocessed image and label, returns as dict
    image = preprocess_image(image_path)
    return {"image": image, "label": label}

def prepare_dataset(image_paths, labels):
    # calls all functions above, makes tf dataset with image paths, 
    # maps image paths and labels to tf images and tf labels
    # TODO look up AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
        process_image_labels, num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)

2024-01-13 21:52:22.444871: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-01-13 21:52:22.444894: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 24.00 GB
2024-01-13 21:52:22.444901: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 8.00 GB
2024-01-13 21:52:22.444940: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-01-13 21:52:22.444962: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


## Image prediction function(s)

In [7]:
# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = tf.keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_len
    ]
    # Iterate over the results and get back the text.
    output_text = []
    for res in results:
        res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

## Main function

In [8]:
def predict(img_path: str, dotkeras_path: str) -> str:
    dataset = prepare_dataset([img_path], ["null"])
    prediction_model = load_model(dotkeras_path)
    for batch in dataset:
        batch_image = batch["image"]
        pred = prediction_model.predict(batch_image)
        pred_text = decode_batch_predictions(pred)

    return pred_text

In [10]:
predict('./data/words/a01/a01-000u/a01-000u-00-00.png', './models/50_epochs.keras')
predict('./data/words/a02/a02-000/a02-000-00-04.png', './models/50_epochs.keras')



['a']