In [12]:
from functools import partial
from multiprocessing import Pool
from imageio import imread
import matplotlib.pyplot as plt
import numpy as np
import cvxpy as cp
from inpainting import inpaint
# Seed NumPy's legacy global RNG so reruns are reproducible. No cell shown here
# draws random numbers itself; presumably this is for `inpaint` — confirm.
np.random.seed(21)

In [13]:
def partial_ravel(arr, lo=0, hi=-1):
    """Merge (ravel) the consecutive axes ``lo .. hi-1`` of ``arr`` into one axis.

    Axes before ``lo`` and from ``hi`` onward are kept unchanged. ``hi=-1`` (the
    default) is a sentinel meaning "through the last axis", so
    ``partial_ravel(a, 2)`` turns shape ``(A, B, C, D)`` into ``(A, B, C*D)``.
    A negative ``lo`` counts from the end, as in ordinary indexing.

    Args:
        arr: array exposing ``shape``, ``ndim`` and ``reshape`` (e.g. ``np.ndarray``).
        lo: first axis to merge (inclusive).
        hi: axis just past the last one to merge (exclusive); ``-1`` means ``arr.ndim``.

    Returns:
        ``arr`` reshaped so that axes ``[lo, hi)`` form a single axis.

    Raises:
        ValueError: if ``[lo, hi)`` is not a non-empty range of valid axes.
    """
    ndim = arr.ndim
    if hi == -1:
        hi = ndim
    if lo < 0:
        lo += ndim
    if not 0 <= lo < hi <= ndim:
        raise ValueError(
            f'invalid axis range [{lo}, {hi}) for array with {ndim} dims')
    # Spell the merged length out instead of using -1: reshape cannot infer a
    # -1 dimension when the array has zero size.
    merged = int(np.prod(arr.shape[lo:hi]))
    # Keep *all* trailing axes (shape[hi:]); slicing shape[hi:-1] would drop the
    # last axis and fold it into the merged one whenever hi != -1.
    return arr.reshape(arr.shape[:lo] + (merged,) + arr.shape[hi:])

In [14]:
def paste_region(arr, region, val):
    """Return a copy of ``arr`` in which the entries selected by ``region`` equal ``val``.

    The caller's ``arr`` is never modified.
    """
    result = arr.copy()  # work on a copy so the input stays untouched
    result[region] = val
    return result

In [15]:
def region_to_mask(shape, region):
    """Build a boolean array of ``shape`` that is True exactly at ``region``."""
    selected = np.zeros(shape, dtype=bool)
    selected[region] = True
    return selected

In [19]:
def _imshow(ax, img):
    ax.imshow(img)
    ax.set_axis_off()

In [20]:
def plot_compare(corrupted, recovered, region,
                 save_path='../images/readme/watermark_results.png'):
    """Plot corrupted vs. recovered images and their difference over ``region``.

    Layout on a 7x4 grid: ``corrupted`` (left) and ``recovered`` (right) each
    span 6 rows x 2 columns; the absolute difference restricted to ``region``
    sits centred in the bottom row.

    Args:
        corrupted: image before repair (anything ``imshow`` accepts).
        recovered: repaired image, same shape and dtype as ``corrupted``.
        region: index expression (e.g. a tuple of slices) for the repaired area.
        save_path: file the figure is written to; ``None`` skips saving.
    """
    corrupted = np.asarray(corrupted)
    recovered = np.asarray(recovered)
    # Subtract in float: for uint8 inputs ``a - b`` wraps around (10 - 20 == 246),
    # which would paint bright artefacts wherever the recovery is brighter.
    difference = np.abs(np.subtract(corrupted[region], recovered[region],
                                    dtype=np.float64))
    if np.issubdtype(corrupted.dtype, np.integer):
        # Back to the input's integer dtype so imshow keeps the 0-255 scale for
        # integer RGB data (float RGB is expected in [0, 1] and gets clipped).
        difference = difference.astype(corrupted.dtype)

    fig = plt.figure(figsize=(11, 6))
    # Every panel is 2 grid columns wide; only location and row span vary.
    ax = partial(plt.subplot2grid, shape=(7, 4), colspan=2, fig=fig)
    _imshow(ax(loc=(0, 0), rowspan=6), corrupted)
    _imshow(ax(loc=(0, 2), rowspan=6), recovered)
    _imshow(ax(loc=(6, 1), rowspan=1), difference)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)

In [21]:
def main(basedir='../images/watermarked'):
    """Inpaint a rectangular area of a photo and save a before/after figure.

    Steps:
      1. Load ``stock0.jpg`` .. ``stock2.jpg`` from ``basedir`` and stack them on a
         new trailing axis; the first photo is the one that gets repaired.
      2. Inside a fixed rectangle, collect the (row, col) positions — relative to
         the rectangle — of pixels where some channel of some photo lies outside
         the open interval (lo, hi). These positions are passed to ``inpaint``.
         NOTE(review): whether ``inpaint`` treats them as known or unknown
         pixels is defined in the ``inpainting`` module — confirm there.
      3. Inpaint each slice along the cropped image's last axis in a process pool.
      4. Round/clip the result to uint8, paste it back and save a comparison figure.

    Args:
        basedir: directory containing the ``stock{i}.jpg`` inputs.
    """
    fnames = [f'{basedir}/stock{i}.jpg' for i in range(3)]
    # np.stack requires every photo to share one shape; the photo index becomes
    # the last axis. Assumes each photo is (H, W, C) — TODO confirm (no grayscale).
    all_img = np.stack([imread(fname) for fname in fnames], axis=-1)
    sample_img = all_img[..., 0]

    # True where *every* channel of *every* photo is strictly inside (lo, hi).
    # For uint8 images hi=256 never excludes anything; only the lower bound bites.
    lo, hi = 60, 256
    mask = partial_ravel((all_img > lo) & (all_img < hi), 2).all(-1)
    # Rectangle to repair (rows r1:r2, cols c1:c2); presumably the watermark's
    # bounding box in these particular photos.
    r1, r2, c1, c2 = 340, 380, 135, 475
    region = (slice(r1, r2), slice(c1, c2))
    mask &= region_to_mask(mask.shape, region)
    rows, cols = np.where(~mask[region])
    cropped_img = sample_img[region]

    task = partial(inpaint, rows=rows, cols=cols)
    # Move the last (channel) axis to the front so iteration yields 2-D slices.
    data = np.moveaxis(cropped_img, -1, 0)
    with Pool(3) as pool:
        recovered = pool.map(task, data)
    recovered = np.stack(recovered, -1)
    # Round and clip before casting: out-of-range floats (e.g. -1.0 or 256.0 from
    # the solver) would otherwise typically wrap around in the uint8 cast, and
    # plain astype truncates instead of rounding.
    recovered = np.clip(np.rint(recovered), 0, 255).astype(np.uint8)
    recovered_img = paste_region(sample_img, region, recovered)
    plot_compare(sample_img, recovered_img, region)

In [22]:
# Run the whole pipeline: load the photos, inpaint the chosen rectangle of the
# first one, and save the before/after comparison figure.
main()