In [10]:
# Standard library
import os
from pathlib import Path

# Third-party
import numpy as np
import pandas as pd
import tensorflow as tf

# Project-local: dataset encoding / prediction decoding helpers and test-set settings
from utilities import decode_batch_predictions, encode_one_data_set, num_to_char
from test_config import TestConfiguration

# Test-set settings (IMAGE_SET_NAME, METADATA_FILE_NAME, IMAGE_FORMAT, MAX_LABEL_LENGTH) read by later cells.
config = TestConfiguration()

## Load the test set locally

In [5]:
# Cloud Storage locations used by the download cells below.
IMAGE_BUCKET: str = 'fmnh_datasets'      # bucket holding the word-image test set
MODEL_BUCKET: str = 'iam-model-staging'  # bucket holding trained model runs
MODEL_NAME: str = 'run_55_all'           # run folder to evaluate; also names the predictions CSV

In [None]:
# Download test set from Google Cloud Storage
storage_path = f'gs://{IMAGE_BUCKET}/{config.IMAGE_SET_NAME}/'
!gsutil -m cp -r $storage_path .

In [6]:
# Read the test-set metadata and pair every downloaded word image with its transcription.
data_dir = Path(config.IMAGE_SET_NAME)

# keep_default_na=False stops pandas from turning literal words such as "NA", "null" or
# "nan" into NaN (which str() below would otherwise render as the label "nan").
metadata = pd.read_csv(
    Path(data_dir, config.METADATA_FILE_NAME),
    keep_default_na=False,
    dtype={'transcription': str},
)
# image_location holds Windows-style paths (see the output below); normalising the
# separator makes the basename extraction work for either '\' or '/' paths.
metadata['word_image_basenames'] = metadata['image_location'].map(
    lambda location: os.path.basename(location.replace('\\', '/'))
)
print(metadata.head())

images = sorted(str(path) for path in data_dir.rglob(f'*.{config.IMAGE_FORMAT}'))
print(f'\nNumber of images in test set: {len(images)}\n')
if not images:
    raise FileNotFoundError(
        f'No *.{config.IMAGE_FORMAT} images found under {data_dir}; did the download cell run?'
    )

# Index transcriptions by basename once (O(n)) instead of filtering the whole frame per
# image; all matches are kept so missing or duplicated entries can be reported clearly.
transcriptions_by_basename = {}
for basename, transcription in zip(metadata['word_image_basenames'], metadata['transcription']):
    transcriptions_by_basename.setdefault(basename, []).append(transcription)

labels = []
for image_path in images:
    basename = os.path.basename(image_path)
    matches = transcriptions_by_basename.get(basename, [])
    if len(matches) != 1:
        raise ValueError(
            f'Expected exactly one transcription for {basename!r} in the metadata, found {len(matches)}.'
        )
    # Right-pad to a fixed width of MAX_LABEL_LENGTH (presumably what encode_one_data_set
    # expects — TODO confirm); the prediction cell strips spaces again before comparing.
    labels.append(str(matches[0]).ljust(config.MAX_LABEL_LENGTH))

test_images = np.array(images)
test_labels = np.array(labels)
print(f'Testing images ({test_images.shape[0]}) and labels ({test_labels.shape[0]}) loaded.')
test_dataset = encode_one_data_set(test_images, test_labels)

                                    image_location transcription error_value  \
0  resources\words\a01\a01-000u\a01-000u-00-00.png             A          ok   
1  resources\words\a01\a01-000u\a01-000u-00-01.png          MOVE          ok   
2  resources\words\a01\a01-000u\a01-000u-00-02.png            to          ok   
3  resources\words\a01\a01-000u\a01-000u-00-03.png          stop          ok   
4  resources\words\a01\a01-000u\a01-000u-00-04.png           Mr.          ok   

  word_image_basenames  
0   a01-000u-00-00.png  
1   a01-000u-00-01.png  
2   a01-000u-00-02.png  
3   a01-000u-00-03.png  
4   a01-000u-00-04.png  

Number of images in test set: 150

Testing images (150) and labels (150) loaded.


## Model loading

In [7]:
model_uri = f'gs://{MODEL_BUCKET}/{MODEL_NAME}/model'
!gsutil -m cp -r $model_uri .

Copying gs://iam-model-staging/run_55_all/model/keras_metadata.pb...
Copying gs://iam-model-staging/run_55_all/model/run_55_all-training_history.csv...
Copying gs://iam-model-staging/run_55_all/model/variables/variables.data-00000-of-00001...
Copying gs://iam-model-staging/run_55_all/model/saved_model.pb...
Copying gs://iam-model-staging/run_55_all/model/variables/variables.index...    
| [5/7 files][ 83.0 MiB/ 83.0 MiB]  99% Done                                    

In [8]:
# Load the downloaded Keras model directory and attach a fresh Adam optimizer.
# NOTE(review): compile() looks unnecessary for inference-only predict() — confirm before removing.
opt = tf.keras.optimizers.Adam()
prediction_model_filename = Path('model')  # same path as './model'
prediction_model = tf.keras.models.load_model(prediction_model_filename)
prediction_model.compile(optimizer=opt)

## Prediction generation

In [12]:
def _clean_text(text):
    """Remove '[UNK]' tokens and every space (including label padding) from a decoded string."""
    return text.replace('[UNK]', '').replace(' ', '')


# Collect (label, prediction) pairs as plain tuples and build the DataFrame once at the end:
# DataFrame.append in a loop is quadratic and was removed in pandas 2.0.
prediction_records = []
for batch in test_dataset:
    # Distinct names so the test-set `images` / `labels` lists from the loading cell survive.
    batch_images = batch['image']
    batch_labels = batch['label']
    # The dataset is already batched, so run the model once on the whole batch.
    batch_preds = prediction_model.predict_on_batch(batch_images)
    pred_texts = [_clean_text(text) for text in decode_batch_predictions(batch_preds)]
    # Map label ids back to characters and join them into one UTF-8 string per sample.
    orig_texts = [
        _clean_text(tf.strings.reduce_join(num_to_char(label)).numpy().decode('utf-8'))
        for label in batch_labels
    ]
    prediction_records.extend(zip(orig_texts, pred_texts))

prediction_results = pd.DataFrame(prediction_records, columns=['label', 'prediction'])
print(prediction_results)

# to_csv does not create missing directories, so make sure the output folder exists.
predictions_dir = Path('predictions')
predictions_dir.mkdir(parents=True, exist_ok=True)
prediction_results.to_csv(predictions_dir / f'{MODEL_NAME}-predictions.csv')

          label  prediction
0             A           A
1          more        more
2             a           a
3          Foot        Foot
4           and         and
..          ...         ...
145           .           .
146           .           .
147        have        have
148        said        said
149  Government  Government

[150 rows x 2 columns]
