In [14]:
import os
from pathlib import Path
import tensorflow as tf
import pandas as pd
import numpy as np
from utilities import encode_one_data_set, decode_batch_predictions, num_to_char
from test_config import TestConfiguration
# Project test settings object; its MAX_LABEL_LENGTH attribute is read further
# down when the transcriptions are padded to a fixed width.
config = TestConfiguration()

# Name of the test image set. It is used both as the folder name inside the
# Google Cloud Storage bucket and as the local directory the images and the
# metadata CSV are read from.
# NOTE(review): the trailing comments name config attributes; presumably these
# literals are stand-ins for them - confirm which source is authoritative.
IMAGE_SET_NAME = 'standley_4058_test' # config.IMAGE_SET_NAME
# File name of the metadata CSV inside the image set folder.
METADATA_FILENAME = 'words_metadata.csv' # config.METADATA_FILE_NAME

## Load the test set locally

In [None]:
# Download the test set from Google Cloud Storage into the current working
# directory.
# NOTE(review): presumably only needed when the IMAGE_SET_NAME folder is not
# already on disk - confirm before re-running, the copy may be large.
IMAGE_BUCKET = 'fmnh_datasets'

# gs:// URI of the image set folder inside the bucket.
storage_path = f'gs://{IMAGE_BUCKET}/{IMAGE_SET_NAME}/'
# IPython shell escape: $storage_path is expanded from the Python variable.
# `-m` lets gsutil run the copy in parallel, `cp -r` copies recursively.
!gsutil -m cp -r $storage_path .

In [15]:
# Build the test dataset from the image set folder on disk:
#   1. read the metadata CSV and derive each word image's file name,
#   2. collect every .png/.jpg under the folder (recursively),
#   3. look up the transcription of each image and pad it to a fixed width,
#   4. hand the paths and padded labels to encode_one_data_set().
data_dir = Path(IMAGE_SET_NAME)

metadata = pd.read_csv(Path(data_dir, METADATA_FILENAME))
# Keep only the file name of each stored image path. Backslash is the
# separator the metadata was written with; forward slashes are accepted as
# well so metadata exported with POSIX-style paths resolves the same way.
metadata['word_image_basenames'] = metadata['image_location'].map(
    lambda location: location.replace('\\', '/').rsplit('/', 1)[-1]
)
print(metadata.head())

# Sorted so the image order (and therefore the label order) is deterministic.
images = sorted(
    str(path)
    for pattern in ('*.png', '*.jpg')
    for path in data_dir.rglob(pattern)
)
print(f'\nNumber of images in test set: {len(images)}\n')

image_basenames = [os.path.basename(path) for path in images]

# Every image must map to exactly one metadata row. Check that up front and
# name the offending files instead of failing on an opaque size error.
rows_per_basename = metadata['word_image_basenames'].value_counts().to_dict()
unmatched = sorted(
    {name for name in image_basenames if rows_per_basename.get(name, 0) != 1}
)
if unmatched:
    raise ValueError(
        f'{len(unmatched)} image file name(s) do not have exactly one row in '
        f'{METADATA_FILENAME}, e.g. {unmatched[:5]}'
    )

# One dictionary lookup per image instead of a full scan of `metadata` per image.
transcription_by_basename = dict(
    zip(metadata['word_image_basenames'], metadata['transcription'])
)
# Right-pad with spaces to config.MAX_LABEL_LENGTH so all labels share a width.
labels = [
    str(transcription_by_basename[name]).ljust(config.MAX_LABEL_LENGTH)
    for name in image_basenames
]

test_images = np.array(images)
test_labels = np.array(labels)
print(f'Testing images ({test_images.shape[0]}) and labels ({test_labels.shape[0]}) loaded.')

test_dataset = encode_one_data_set(test_images, test_labels)

                 id    barcode  block  paragraph  word gcv_identification  \
0  C0047076F-b4p0w0  C0047076F      4          0     0          Asplenium   
1  C0047076F-b4p0w1  C0047076F      4          0     1           Brakleji   
2  C0047076F-b4p0w2  C0047076F      4          0     2               p.c.   
3  C0047076F-b4p0w3  C0047076F      4          0     3              Eaton   
4  C0047076F-b4p0w4  C0047076F      4          0     4             clefta   

                           zooniverse_image_location  handwritten  \
0  file_resources\processed_images_zooniverse-Ste...         True   
1  file_resources\processed_images_zooniverse-Ste...         True   
2  file_resources\processed_images_zooniverse-Ste...         True   
3  file_resources\processed_images_zooniverse-Ste...         True   
4  file_resources\processed_images_zooniverse-Ste...         True   

  transcription  unclear  seen_count  confidence    status   collector  \
0     Asplenium    False           8       0.875

## Model loading

In [None]:
# Load model from Google Cloud Storage
# Builds the gs:// URI of the model folder, copies it recursively into the
# working directory and loads the local copy with Keras.
# The "Load model from local filesystem" cell assigns the same two names
# (MODEL_NAME, prediction_model), so whichever of the two cells runs last
# decides which model is evaluated.
MODEL_BUCKET = 'iam-model-staging'
# MODEL_NAME is also used to name the predictions CSV written after inference.
MODEL_NAME = 'run_55_all'
model_uri = f'gs://{MODEL_BUCKET}/{MODEL_NAME}/model'
# IPython shell escape: $model_uri is expanded from the Python variable.
!gsutil -m cp -r $model_uri .
# NOTE(review): assumes the copy above lands in ./model - confirm.
prediction_model_filename = Path('./model')
prediction_model = tf.keras.models.load_model(prediction_model_filename)

In [22]:
# Load model from local filesystem
# Alternative to the Google Cloud Storage cell: both assign MODEL_NAME and
# prediction_model, so the cell executed last determines the model in use.
# MODEL_NAME is also used to name the predictions CSV written after inference.
MODEL_NAME = 'fine_tuned-prediction'
# Path is relative to the notebook's working directory.
model_location = Path(f'../transfer_learning/{MODEL_NAME}.model')
prediction_model = tf.keras.models.load_model(model_location)

In [23]:
# Compile the loaded model with a default-configured Adam optimizer. No loss
# or metrics are passed to compile().
# NOTE(review): presumably done only so the restored model is in a compiled
# state before predict() is called - confirm whether this step is required.
opt = tf.keras.optimizers.Adam()
prediction_model.compile(optimizer=opt)

## Prediction generation

In [24]:
def _clean_text(text):
    """Return `text` without '[UNK]' tokens and without any spaces.

    Labels were right-padded with spaces when the dataset was built, so
    removing spaces also removes that padding.
    """
    return text.replace('[UNK]', '').replace(' ', '')


# Run the model over every batch and collect (label, prediction) text pairs.
# Rows are gathered in a plain list and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0, and growing a frame inside a
# loop re-copies every previous row on each iteration.
result_rows = []
# NOTE: `images` and `labels` are rebound here to the tensors of the current
# batch, replacing the path/label lists built when the test set was loaded.
for batch in test_dataset:
    images = batch['image']
    labels = batch['label']
    preds = prediction_model.predict(images)
    pred_texts = [_clean_text(text) for text in decode_batch_predictions(preds)]
    # Map each encoded label back to characters and join them into one string.
    orig_texts = [
        _clean_text(
            tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        )
        for label in labels
    ]
    result_rows.extend(zip(orig_texts, pred_texts))

prediction_results = pd.DataFrame(result_rows, columns=['label', 'prediction'])
print(prediction_results)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs('predictions', exist_ok=True)
# The CSV is named after whichever model-loading cell set MODEL_NAME last.
prediction_results.to_csv(Path('predictions', f'{MODEL_NAME}-predictions.csv'))

           label   prediction
0             .)            "
1          Maxon         Maan
2    delitescens  dillilconan
3       radicans     Pilicana
4           Damp         trtt
..           ...          ...
400     ropinqua    pospirnou
401    macrosora        aarsu
402          Chr          po.
403   Polypodium    Beleuiaem
404       plesio          Pas

[405 rows x 2 columns]
