In [1]:
import os
import cv2
import numpy as np
from pathlib import Path
from contextlib import contextmanager

from ocrd import Resolver,Workspace
from ocrd_utils import polygon_from_points, MIMETYPE_PAGE
from ocrd_modelfactory import page_from_file
from matplotlib import pyplot as plt

from skimage import morphology
from skimage.color import rgb2gray
from skimage.filters import window, difference_of_gaussians
from scipy.fftpack import fft2, fftshift
from scipy import ndimage as ndi
from skimage.registration import phase_cross_correlation
from skimage.transform import warp_polar, warp, rotate, rescale,AffineTransform, SimilarityTransform



In [2]:
project_dir = Path.cwd() / 'data' / 'kant' 
ws = Workspace(Resolver(), str(project_dir))


In [3]:
"""
General helper functions 
"""

@contextmanager
def working_directory(path):
    """ 
    chdir(path) and back again
    """
    prev_cwd = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev_cwd)
        

def get_page(pageId, fileGrp):
    """
    Get PAGE-Xml PcGts root element for a pageId and fileGrp
    """
    with working_directory(project_dir):
        file = list(ws.mets.find_files(pageId=pageId, fileGrp=fileGrp, mimetype=MIMETYPE_PAGE))[0]
        pcgts = page_from_file(file)
        return pcgts


def get_polys(elements):
    """
    Collect the polygon outlines of elements in a numpy array
    """
    polys = []
    for el in elements:
        polys.append(np.array(polygon_from_points(el.get_Coords().points), dtype='int32'))
    return polys


def line_mask(page):
    """
    Draw a textline mask for page
    """
    mask = np.zeros((page.get_imageHeight(), page.get_imageWidth()), dtype='uint8')
    elements = [line for text_region in page.get_TextRegion() for line in text_region.get_TextLine() ]
    polys = get_polys(elements)
    cv2.fillPoly(mask, polys, 1)     
    return mask


def region_mask(page):
    """
    Draw a textline mask for page
    """
    mask = np.zeros((page.get_imageHeight(), page.get_imageWidth()), dtype='uint8')
    elements = [text_region for text_region in page.get_TextRegion()]
    polys = get_polys(elements)
    cv2.fillPoly(mask, polys, 1)     
    return mask

def shrink(images):
    """
    crop all images in a list of images to the same (smallest) width/height  
    """
    padded_images = []
    shape = np.min([image.shape for image in images], axis=0)
    for image in images:
        #padded_image = container.copy()
        padded_image = image[:shape[0],:shape[1]].copy()
        padded_images.append(padded_image)
        
    return padded_images

def post_process(mask):
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15,5)), iterations=1)
    #mx = np.min(mask.shape)
    #mask = mask[:mx,:mx]
    return mask


def align_simple(template, page):
    shift, error, phase = phase_cross_correlation(template, page, upsample_factor=20)
    return shift[1], shift[0], 0, 1, error
    
    
def align_polar(template, page):
    center = (100,100)
    radius = np.linalg.norm(np.array(template.shape[:2])-center)

    template_polar = warp_polar(template, radius=radius, scaling='log', center = center)
    page_polar = warp_polar(page, radius=radius, scaling='log', center = center)    
        
    (shift_py,shift_px), error, phase = phase_cross_correlation(template_polar, page_polar, upsample_factor=50)

    # Calculate scale factor from translation
    rotation = -shift_py
    klog = radius / np.log(radius)
    shift_scale = np.exp(shift_px / klog)

    page_rs = rescale(rotate(page.copy(), rotation), shift_scale)
    
    template,page_rs = shrink([template,page_rs])


    shift, error, phase = phase_cross_correlation(template, page_rs, upsample_factor=20)
        
    return shift[1], shift[0], rotation, shift_scale, error


def _preprocess_fft(image):
    image = difference_of_gaussians(image.copy(), 5, 20)
    wimage = image * window('hann', image.shape)
    # work with shifted FFT magnitudes
    image_fs = np.abs(fftshift(fft2(wimage)))
    return image_fs
    

def align_polar_fft(template, page, cut = 8):
    template_fs = _preprocess_fft(template)
    page_fs = _preprocess_fft(page)

    shape = page_fs.shape
    radius = shape[0] // cut  # only take lower frequencies

    template_p_fs = warp_polar(template_fs, radius=radius, output_shape=shape,scaling='log', order=0)
    page_p_fs = warp_polar(page_fs, radius=radius, output_shape=shape,scaling='log', order=0)

    template_p_fs = template_p_fs[:shape[0] // 2, :]  # only use half of FFT
    page_p_fs = page_p_fs[:shape[0] // 2, :]
    shifts, error, phasediff = phase_cross_correlation(template_p_fs, page_p_fs, upsample_factor=100)
            
    shiftr, shiftc = shifts[:2]
    rotation = -(360 / shape[1]) * shiftr
    klog = shape[1] / np.log(radius)
    shift_scale = 1/np.exp(shiftc / klog)
    
    
    page_rs = rescale(rotate(page.copy(),rotation),shift_scale)
    
    template,page_rs = shrink([template,page_rs])
    
    shift, error, phase = phase_cross_correlation(template, page_rs, upsample_factor=20)
    
    return shift[1], shift[0], rotation, shift_scale, error


In [4]:
ws.mets.physical_pages

['phys_0007',
 'phys_0008',
 'phys_0009',
 'phys_0010',
 'phys_0011',
 'phys_0012',
 'phys_0013',
 'phys_0014',
 'phys_0015',
 'phys_0016',
 'phys_0017',
 'phys_0018',
 'phys_0019',
 'phys_0020']

In [5]:
%matplotlib widget
page_id = 'phys_0007'
template_page_id = 'phys_0009'
page_seg = get_page(template_page_id,'OCR-D-SEG-LINE')
page = get_page(page_id,'OCR-D-SEG-LINE')

template_mask = post_process(line_mask(page_seg.get_Page()))
page_mask = post_process(line_mask(page.get_Page()))

#page_mask = rescale(rotate(ndi.shift(page_mask, (0,120)),-12.0),1.04)

template_mask, page_mask = shrink([template_mask, page_mask])

plt.imshow(template_mask, alpha=1)
plt.imshow(page_mask, alpha=0.7)




Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f77415eaed0>

In [6]:
shiftx, shifty, rotation, scale, error = align_simple(template_mask, page_mask)
display((shiftx, shifty, -rotation, 1/scale, error))

shiftx, shifty, rotation, scale, error = align_polar(template_mask, page_mask)
display((shiftx, shifty, -rotation, 1/scale, error))

shiftx, shifty, rotation, scale, error = align_polar_fft(template_mask, page_mask, cut = 8)
display((shiftx, shifty, -rotation, 1/scale, error))


(16.6, -12.25, 0, 1.0, 0.7368726272592879)

(23.8, -0.1, -0.1, 1.0093738400203942, 0.7285860336246065)

(36.45, 18.9, 0.014824982841455043, 1.02273391579434, 0.72667429093788)

In [7]:
%matplotlib widget

page_mask_rec = ndi.shift(rescale(rotate(page_mask,rotation),scale),(shifty,shiftx))

template_mask_display, page_mask_rec = shrink([template_mask, page_mask_rec])

plt.imshow(template_mask_display, alpha=1)
plt.imshow(page_mask_rec, alpha=0.7)



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f7740542f10>