In [1]:
import os

import tensorflow as tf
import tensorflow_datasets as tfds
import PIL
import PIL.Image
from matplotlib import pyplot as plt

from merge.evaluation import load_model
import datasets.ICDAR.ICDAR
import datasets.FinTabNet.FinTabNet
from utils.visualization import create_merge_result_image

In [None]:
# Restore the ICDAR merge model from its checkpoint; `m` is used by
# get_predictions below.
# NOTE(review): the meaning of the second positional argument (False) is
# defined in merge.evaluation.load_model and is not visible here — confirm
# before changing it.
checkpoint_path = 'checkpoints/merge_icdar.ckpt'
m = load_model(checkpoint_path, False)

In [None]:
# Test split of the ICDAR merge dataset, shuffled with a fixed seed so the
# preview below shows the same tables on every run.
# reshuffle_each_iteration=False keeps the order identical on every iteration,
# so re-running the plotting cell alone reproduces the saved figure instead of
# drawing a fresh sample.
# Note: with a shuffle buffer of 128, the first elements drawn come from
# roughly the first 128 records of the split, not uniformly from all of it.
ds = tfds.load('icdar_merge', split='test')
ds = ds.shuffle(128, seed=42, reshuffle_each_iteration=False)

In [None]:
def get_predictions(ds_element, model=None):
    """Run the merge model on a single (unbatched) dataset element.

    Args:
        ds_element: one element of the merge dataset; a mapping that must
            contain every key in ``input_keys`` below.
        model: callable taking a dict of batched tensors and returning a dict
            of tensors. Defaults to the notebook-global ``m``; pass a model
            explicitly to avoid depending on that global.

    Returns:
        Tuple of NumPy arrays ``(h_positions, v_positions, cells_grid_rects)``
        taken from the model outputs. NOTE(review): presumably these still
        carry the size-1 batch dimension added here — confirm against the
        model's output spec.
    """
    if model is None:
        model = m

    input_keys = (
        'image',
        'horz_split_points_probs',
        'vert_split_points_probs',
        'horz_split_points_binary',
        'vert_split_points_binary',
    )
    # The model expects a batch, so wrap the single element in a batch of one.
    inputs = {key: tf.expand_dims(ds_element[key], 0) for key in input_keys}

    outputs = model(inputs)
    return (
        outputs['h_positions'].numpy(),
        outputs['v_positions'].numpy(),
        outputs['cells_grid_rects'].numpy(),
    )

In [None]:
# Render merge-model predictions for a sample of test tables in a 5x2 grid and
# save the figure to images/.
fig, axes = plt.subplots(5, 2, figsize=(16, 32))
panels = list(axes.flat)

# Take exactly as many elements as there are panels, so the grid shape is the
# single place that controls how many examples are shown.
n_drawn = 0
for element, ax in zip(ds.take(len(panels)), panels):
    h_positions, v_positions, cells = get_predictions(element)
    image = PIL.Image.fromarray(element['image'].numpy())
    debug_image = create_merge_result_image(
        image, h_positions, v_positions, cells)

    # Tick marks and labels carry no information for an image panel.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(debug_image)
    n_drawn += 1

# If the split yielded fewer elements than panels, blank out the leftover axes
# instead of leaving empty boxes with default ticks.
for ax in panels[n_drawn:]:
    ax.axis('off')

fig.subplots_adjust(wspace=0.1, hspace=0.1)

output_path = 'images/merge_model_predictions.png'
# savefig does not create missing directories, so ensure images/ exists first.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, bbox_inches='tight')
plt.show()