In [None]:
# IPython line magic: render matplotlib figures inline in the notebook output.
%matplotlib inline 

In [None]:
import PIL
# `import PIL` alone does not load the Image submodule; import it explicitly so
# `PIL.Image.fromarray` does not depend on another package having imported it.
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

# NOTE(review): IcdarMerge is not referenced by name in the visible cells; it is
# presumably imported so that 'icdar_merge' is registered with tfds -- confirm.
from datasets.ICDAR.ICDAR import IcdarMerge
from utils.visualization import create_merge_result_image
from table.markup_table import Table

# Load the project's custom TensorFlow ops from a shared library. The path is
# relative, so the notebook must be run from a working directory that contains
# ops/ops.so.
ops_module = tf.load_op_library('ops/ops.so')

In [None]:
# Train split of the 'icdar_merge' dataset, shuffled through a 128-element
# buffer with a fixed seed.
ds = tfds.load('icdar_merge', split='train').shuffle(buffer_size=128, seed=42)

In [None]:
def has_cell_to_merge(element):
    """Tell whether a dataset element has any merge flag set.

    Args:
        element: mapping providing 'merge_right_mask' and 'merge_down_mask'
            entries (presumably boolean tensors, as `tf.reduce_any` requires
            boolean input -- confirm against the dataset features).

    Returns:
        Scalar boolean tensor that is True when either mask contains at least
        one True value.
    """
    # Use tf.logical_or rather than Python `or`: `or` needs to coerce a tensor
    # to a Python bool, which only works eagerly or when AutoGraph rewrites the
    # expression. tf.logical_or is valid in every tracing mode.
    return tf.logical_or(
        tf.reduce_any(element['merge_right_mask']),
        tf.reduce_any(element['merge_down_mask']))

# Count the elements for which has_cell_to_merge() is True by folding over the
# whole dataset (each match adds 1 to the running int32 state).
num_of_tables_with_spanning_cells = ds.reduce(
    0, lambda state, element: state + tf.cast(has_cell_to_merge(element), tf.int32))

# The ratio is a 0..1 fraction; scale it by 100 so that it matches the '%' sign
# in the message. An empty dataset is reported as 0% instead of dividing by zero.
num_of_tables = len(ds)
spanning_percent = (
    100.0 * int(num_of_tables_with_spanning_cells) / num_of_tables
    if num_of_tables else 0.0)
print('Tables with spanning cells: {:.1f}%'.format(spanning_percent))

In [None]:
# Preview grid: one panel per dataset element, each showing the debug image
# built by create_merge_result_image and titled with the table id.
n_rows, n_cols = 10, 2
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 32))
free_axes = iter(axes.ravel())

# The number of samples is derived from the grid so the two cannot drift apart.
# `ds.take(...)` must stay the FIRST zip argument: zip pulls from it first, so
# when the dataset runs out no axis is consumed and lost.
for element, ax in zip(ds.take(axes.size), free_axes):
    table_image = PIL.Image.fromarray(element['image'].numpy())
    # Custom ops from ops/ops.so; judging by their names they turn the binary
    # split masks into split positions and the merge masks into cell
    # rectangles -- confirm against the op definitions.
    h_positions = ops_module.intervals_centers(element['horz_split_points_binary'])
    v_positions = ops_module.intervals_centers(element['vert_split_points_binary'])
    cells = ops_module.infer_cells_grid_rects(
        element['merge_right_mask'], element['merge_down_mask'])
    debug_image = create_merge_result_image(
        table_image,
        h_positions.numpy(),
        v_positions.numpy(),
        cells.numpy()
    )
    table_id = Table.from_tensor(element['markup_table']).id
    ax.set_title('table_id = {}'.format(table_id))
    # Tick marks carry no information for an image preview.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(debug_image)

# If the dataset yielded fewer elements than there are panels, hide the
# leftover axes instead of showing empty ticked frames.
for ax in free_axes:
    ax.set_visible(False)
plt.show()