Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
349 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Interactive demo to load a trained model for page extraction and apply it to a randomly selected file" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 1. Get the annotated sample dataset, which already contains the folders images and labels. Unzip it into `demo/pages_sample`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! wget https://github.com/dhlab-epfl/dhSegment/releases/download/untagged-b55f9aa4fff5efd4b1b8/pages_sample.zip\n", | ||
"! unzip pages_sample.zip" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 2. Download the provided model (download and unzip it in `demo/model`)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip\n", | ||
"! unzip model.zip" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 3. Run the code step by step" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import cv2\n", | ||
"from glob import glob\n", | ||
"import numpy as np\n", | ||
"import random\n", | ||
"import tensorflow as tf\n", | ||
"from imageio import imread, imsave" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"%matplotlib inline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dh_segment.io import PAGE\n", | ||
"from dh_segment.inference import LoadedModel\n", | ||
"from dh_segment.post_processing import boxes_detection, binarization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray:\n", | ||
" \"\"\"\n", | ||
" Computes the binary mask of the detected Page from the probabilities outputed by network\n", | ||
" :param probs: array with values in range [0, 1]\n", | ||
" :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used\n", | ||
" :return: binary mask\n", | ||
" \"\"\"\n", | ||
"\n", | ||
" mask = binarization.thresholding(probs, threshold)\n", | ||
" mask = binarization.cleaning_binary(mask, kernel_size=5)\n", | ||
" return mask" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Define input and output directories / files" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_dir = 'page_model/export'\n", | ||
"if not os.path.exists(model_dir):\n", | ||
" model_dir = 'model/'\n", | ||
"assert(os.path.exists(model_dir))\n", | ||
"\n", | ||
"input_files = glob(os.path.join('pages_sample', 'images/*'))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"output_dir = './processed_images'\n", | ||
"os.makedirs(output_dir, exist_ok=True)\n", | ||
"# PAGE XML format output\n", | ||
"output_pagexml_dir = os.path.join(output_dir, 'page_xml')\n", | ||
"os.makedirs(output_pagexml_dir, exist_ok=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Start a tensorflow session" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"session = tf.InteractiveSession()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Select a random image" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"file_to_process = random.sample(input_files, 1)[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Load the model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"m = LoadedModel(model_dir, predict_mode='filename')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Predict each pixel's label" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# For each image, predict each pixel's label\n", | ||
"prediction_outputs = m.predict(file_to_process)\n", | ||
"probs = prediction_outputs['probs'][0]\n", | ||
"original_shape = prediction_outputs['original_shape']\n", | ||
"\n", | ||
"probs = probs[:, :, 1] # Take only class '1' (class 0 is the background, class 1 is the page)\n", | ||
"probs = probs / np.max(probs) # Normalize to be in [0, 1]\n", | ||
"\n", | ||
"# Binarize the predictions\n", | ||
"page_bin = page_make_binary_mask(probs)\n", | ||
"\n", | ||
"# Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes)\n", | ||
"bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False),\n", | ||
" tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Show the probability map and binarized mask" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"plt.figure(figsize=(10,10))\n", | ||
"plt.subplot(1,2,1)\n", | ||
"plt.imshow(probs, cmap='gray')\n", | ||
"plt.axis('off')\n", | ||
"plt.title('Probability map')\n", | ||
"plt.subplot(1,2,2)\n", | ||
"plt.imshow(page_bin, cmap='gray')\n", | ||
"plt.axis('off')\n", | ||
"plt.title('Binary mask')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Find quadrilateral enclosing the page" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pred_page_coords = boxes_detection.find_boxes(bin_upscaled.astype(np.uint8, copy=False),\n", | ||
" mode='min_rectangle', n_max_boxes=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Draw page box on original image and export it. Add also box coordinates to the txt file\n", | ||
"original_img = imread(file_to_process, pilmode='RGB')\n", | ||
"if pred_page_coords is not None:\n", | ||
" cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5)\n", | ||
"else:\n", | ||
" print('No box found in {}'.format(filename))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"plt.figure(figsize=(10,10))\n", | ||
"plt.imshow(original_img)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Export image and create page region and XML file" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"basename = os.path.basename(file_to_process).split('.')[0]\n", | ||
"imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img)\n", | ||
"\n", | ||
"page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :]))\n", | ||
"page_xml = PAGE.Page(image_filename=file_to_process, image_width=original_shape[1], image_height=original_shape[0], page_border=page_border)\n", | ||
"xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename))\n", | ||
"page_xml.write_to_file(xml_filename, creator_name='PageExtractor')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 4. Have a look at the results in ``demo/processed_images``" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python [conda env:dhsegment]", | ||
"language": "python", | ||
"name": "conda-env-dhsegment-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.5.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |