# Interactive demo to load a trained model for page extraction and apply it to a randomly selected file

#### 1. Get the annotated sample dataset, which already contains the `images` and `labels` folders. Unzip it in the `demo` folder: the cells below read the images from `pages/test_a1/images`.

In [None]:
# -nc skips the download when pages.zip is already present, and -n keeps already
# extracted files without prompting, so this cell can be re-run safely.
! wget -nc https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip
! unzip -n pages.zip

#### 2. Download the provided model and unzip it in `demo/model`

In [None]:
# -nc skips the download when model.zip is already present, and -n keeps already
# extracted files without prompting, so this cell can be re-run safely.
! wget -nc https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip
! unzip -n model.zip

#### 3. Run the code step by step

In [None]:
import os
import cv2
from glob import glob
import numpy as np
import random
import tensorflow as tf
from imageio import imread, imsave

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from dh_segment.io import PAGE
from dh_segment.inference import LoadedModel
from dh_segment.post_processing import boxes_detection, binarization

In [None]:
def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray:
    """
    Turns the page probabilities output by the network into a cleaned binary mask.

    The map is thresholded with ``binarization.thresholding`` and the result is then
    passed through ``binarization.cleaning_binary`` with a kernel size of 5.

    :param probs: array with values in range [0, 1]
    :param threshold: threshold in [0, 1]; if negative, Otsu's adaptive threshold will be used
    :return: binary mask
    """
    thresholded = binarization.thresholding(probs, threshold)
    return binarization.cleaning_binary(thresholded, kernel_size=5)

Define input and output directories / files

In [None]:
# Prefer a locally exported model and fall back to the downloaded one.
model_dir = 'page_model/export'
if not os.path.exists(model_dir):
    model_dir = 'model/'
assert os.path.exists(model_dir), \
    'No model found in page_model/export or model/: download and unzip the provided model first'

# Sorted so that the list of candidates does not depend on the arbitrary order returned by glob.
input_files = sorted(glob(os.path.join('pages', 'test_a1', 'images/*')))
# Fail here with a clear message rather than later, when a file is drawn from an empty list.
assert input_files, \
    'No input image found in pages/test_a1/images: download and unzip the dataset first'

In [None]:
# Exported images are written to output_dir, PAGE XML files to its 'page_xml' sub-folder
output_dir = './processed_images'
output_pagexml_dir = os.path.join(output_dir, 'page_xml')
for directory in (output_dir, output_pagexml_dir):
    os.makedirs(directory, exist_ok=True)

Start a tensorflow session

In [None]:
# Open the TensorFlow session for this notebook.
# NOTE(review): presumably LoadedModel relies on this session being the default one - confirm
# against dh_segment.inference.
session = tf.InteractiveSession()

Select a random image

In [None]:
# Draw a single image at random among the candidate files
file_to_process = random.choice(input_files)

Load the model

In [None]:
# Load the model stored in model_dir.
# NOTE(review): predict_mode='filename' presumably makes predict() take an image path
# rather than a decoded array (a path is what gets passed to it below) - confirm.
m = LoadedModel(model_dir, predict_mode='filename')

Predict each pixel's label

In [None]:
# Predict the label probabilities of every pixel of the selected image
prediction_outputs = m.predict(file_to_process)
probs = prediction_outputs['probs'][0]
original_shape = prediction_outputs['original_shape']

probs = probs[:, :, 1]  # Take only class '1' (class 0 is the background, class 1 is the page)
# Normalize to be in [0, 1]. Skipped when the map is all zeros, where dividing by the
# maximum would fill it with NaNs.
max_prob = np.max(probs)
if max_prob > 0:
    probs = probs / max_prob

# Binarize the predictions
page_bin = page_make_binary_mask(probs)

# Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes)
bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False),
                          tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST)

Show the probability map and binarized mask

In [None]:
# Side-by-side view: probability map on the left, binarized mask on the right
fig, (probs_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 10))

probs_ax.imshow(probs, cmap='gray')
probs_ax.set_axis_off()
probs_ax.set_title('Probability map')

mask_ax.imshow(page_bin, cmap='gray')
mask_ax.set_axis_off()
mask_ax.set_title('Binary mask')

Find quadrilateral enclosing the page

In [None]:
# find_boxes is given the full-resolution mask as uint8 and asked for at most one box
page_mask = bin_upscaled.astype(np.uint8, copy=False)
pred_page_coords = boxes_detection.find_boxes(page_mask, mode='min_rectangle', n_max_boxes=1)

In [None]:
# Draw the detected page box on the original image (it is exported further down)
original_img = imread(file_to_process, pilmode='RGB')
if pred_page_coords is not None:
    cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5)
else:
    # Report the file that was processed
    print('No box found in {}'.format(file_to_process))

In [None]:
# Display the original image, including the page box if one was drawn on it
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(original_img)

Export image and create page region and XML file

In [None]:
# splitext strips only the extension, so file names containing several dots keep
# distinct output names
basename = os.path.splitext(os.path.basename(file_to_process))[0]
imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img)

if pred_page_coords is not None:
    page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :]))
    page_xml = PAGE.Page(image_filename=file_to_process,
                         image_width=original_shape[1], image_height=original_shape[0],
                         page_border=page_border)
    xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename))
    page_xml.write_to_file(xml_filename, creator_name='PageExtractor')
else:
    # Without a detected box there is no border to describe, so no PAGE XML is written
    print('No box found in {}, PAGE XML not written'.format(file_to_process))

#### 4. Have a look at the results in ``demo/processed_images``