Skip to content

Commit

Permalink
updated demo
Browse files Browse the repository at this point in the history
  • Loading branch information
solivr committed Nov 1, 2018
1 parent 91540f2 commit 9889c7d
Show file tree
Hide file tree
Showing 2 changed files with 349 additions and 2 deletions.
4 changes: 2 additions & 2 deletions demo.py
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm

from dh_segment.io import PAGE
from dh_segment.network import LoadedModel
from dh_segment.inference import LoadedModel
from dh_segment.post_processing import boxes_detection, binarization

# To output results in PAGE XML format (http://www.primaresearch.org/schema/PAGE/gts/pagecontent/2013-07-15/)
Expand Down Expand Up @@ -96,7 +96,7 @@ def format_quad_to_string(quad):

# Create page region and XML file
page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :]))
page_xml = PAGE.Page(filename, image_width=original_shape[1], image_height=original_shape[0],
page_xml = PAGE.Page(image_filename=filename, image_width=original_shape[1], image_height=original_shape[0],
page_border=page_border)
xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename))
page_xml.write_to_file(xml_filename, creator_name='PageExtractor')
Expand Down
347 changes: 347 additions & 0 deletions demo/interactive_demo.ipynb
@@ -0,0 +1,347 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Interactive demo to load a trained model for page extraction and apply it to a randomly selected file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. Get the annotated sample dataset, which already contains the folders images and labels. Unzip it into `demo/pages_sample`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! wget https://github.com/dhlab-epfl/dhSegment/releases/download/untagged-b55f9aa4fff5efd4b1b8/pages_sample.zip\n",
"! unzip pages_sample.zip"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. Download the provided model (download and unzip it in `demo/model`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip\n",
"! unzip model.zip"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Run the code step by step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"from glob import glob\n",
"import numpy as np\n",
"import random\n",
"import tensorflow as tf\n",
"from imageio import imread, imsave"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dh_segment.io import PAGE\n",
"from dh_segment.inference import LoadedModel\n",
"from dh_segment.post_processing import boxes_detection, binarization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray:\n",
" \"\"\"\n",
" Computes the binary mask of the detected Page from the probabilities outputed by network\n",
" :param probs: array with values in range [0, 1]\n",
" :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used\n",
" :return: binary mask\n",
" \"\"\"\n",
"\n",
" mask = binarization.thresholding(probs, threshold)\n",
" mask = binarization.cleaning_binary(mask, kernel_size=5)\n",
" return mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define input and output directories / files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_dir = 'page_model/export'\n",
"if not os.path.exists(model_dir):\n",
" model_dir = 'model/'\n",
"assert(os.path.exists(model_dir))\n",
"\n",
"input_files = glob(os.path.join('pages_sample', 'images/*'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_dir = './processed_images'\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"# PAGE XML format output\n",
"output_pagexml_dir = os.path.join(output_dir, 'page_xml')\n",
"os.makedirs(output_pagexml_dir, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Start a tensorflow session"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"session = tf.InteractiveSession()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Select a random image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_to_process = random.sample(input_files, 1)[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m = LoadedModel(model_dir, predict_mode='filename')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict each pixel's label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For each image, predict each pixel's label\n",
"prediction_outputs = m.predict(file_to_process)\n",
"probs = prediction_outputs['probs'][0]\n",
"original_shape = prediction_outputs['original_shape']\n",
"\n",
"probs = probs[:, :, 1] # Take only class '1' (class 0 is the background, class 1 is the page)\n",
"probs = probs / np.max(probs) # Normalize to be in [0, 1]\n",
"\n",
"# Binarize the predictions\n",
"page_bin = page_make_binary_mask(probs)\n",
"\n",
"# Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes)\n",
"bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False),\n",
" tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Show the probability map and binarized mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.subplot(1,2,1)\n",
"plt.imshow(probs, cmap='gray')\n",
"plt.axis('off')\n",
"plt.title('Probability map')\n",
"plt.subplot(1,2,2)\n",
"plt.imshow(page_bin, cmap='gray')\n",
"plt.axis('off')\n",
"plt.title('Binary mask')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Find quadrilateral enclosing the page"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_page_coords = boxes_detection.find_boxes(bin_upscaled.astype(np.uint8, copy=False),\n",
" mode='min_rectangle', n_max_boxes=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw page box on original image and export it. Add also box coordinates to the txt file\n",
"original_img = imread(file_to_process, pilmode='RGB')\n",
"if pred_page_coords is not None:\n",
" cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5)\n",
"else:\n",
" print('No box found in {}'.format(filename))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(original_img)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Export image and create page region and XML file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"basename = os.path.basename(file_to_process).split('.')[0]\n",
"imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img)\n",
"\n",
"page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :]))\n",
"page_xml = PAGE.Page(image_filename=file_to_process, image_width=original_shape[1], image_height=original_shape[0], page_border=page_border)\n",
"xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename))\n",
"page_xml.write_to_file(xml_filename, creator_name='PageExtractor')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4. Have a look at the results in ``demo/processed_images``"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:dhsegment]",
"language": "python",
"name": "conda-env-dhsegment-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9889c7d

Please sign in to comment.