In [1]:
%load_ext line_profiler

In [2]:
import cv2
import numpy as np
import knn

import random

In [3]:
DEFINED = 0
UNDEFINED = 255

In [5]:
def create_pattern(mask: np.ndarray, origin: tuple, k: int) -> list[tuple]:
    origin_y, origin_x = origin
    neighbors = knn.nn_circular_native(mask, origin, k)
    neighbors_relative = []
    for y, x in neighbors:
        neighbors_relative.append((y - origin_y, x - origin_x))
    return neighbors_relative


In [7]:
def generate_candidates(mask: np.ndarray, neighbors_relative: list[tuple], n: int) -> list[tuple]:
    height, width = mask.shape
    candidates = []
    while len(candidates) < n:
        y, x = np.random.randint(0, height), np.random.randint(0, width)
        if mask[y, x] == UNDEFINED:
            continue
        for y_n, x_n in neighbors_relative:
            if y + y_n < 0 or y + y_n >= height or x + x_n < 0 or x + x_n >= width or mask[y + y_n, x + x_n] == UNDEFINED:
                break
        else:
            candidates.append((y, x))
    return candidates

In [8]:
def choose_candidate(image: np.ndarray, origin: tuple, candidates: list[tuple], neighbors: list[tuple]) -> tuple:
    origin_y, origin_x = origin
    distances = []
    for candidate_y, candidate_x in candidates:
        distance = 0
        for neighbor_y, neighbor_x in neighbors:
            distance += np.sqrt(np.sum(np.power(image[origin_y + neighbor_y, origin_x + neighbor_x] - image[candidate_y + neighbor_y, candidate_x + neighbor_x], 2)))
        distances.append(distance)
    return candidates[np.argmin(distances)]

In [18]:
def inpainting(image: np.ndarray, mask: np.ndarray, n_neighbors: int = 10, n_candidates: int = 100):
    image[mask == UNDEFINED] = UNDEFINED
    undefined_pixels = np.argwhere(mask == UNDEFINED).tolist()
    iteration = 0
    while undefined_pixels:
        origin = random.choice(undefined_pixels)
        pattern = create_pattern(mask, origin, n_neighbors)
        candidates = generate_candidates(mask, pattern, n_candidates)
        candidate = choose_candidate(image, origin, candidates, pattern)
        image[origin[0], origin[1]] = image[candidate[0], candidate[1]]
        undefined_pixels.remove(origin)

        if iteration == 100:
            break
        if iteration % 100 == 0:
            cv2.imwrite(f"/tmp/inpainting/iteration{iteration:04}.png", image)
        iteration += 1
    cv2.imwrite(f"/tmp/inpainting/iteration{iteration:04}.png", image)

In [20]:
image01 = cv2.imread("data/image_01.jpg")
mask01 = cv2.imread("data/mask_01.png", cv2.IMREAD_GRAYSCALE)

%lprun -f inpainting inpainting(image01, mask01)

Timer unit: 1e-09 s

Total time: 0.480802 s
File: /tmp/ipykernel_83585/2782660982.py
Function: inpainting at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def inpainting(image: np.ndarray, mask: np.ndarray, n_neighbors: int = 10, n_candidates: int = 100):
     2         1     164928.0 164928.0      0.0      image[mask == UNDEFINED] = UNDEFINED
     3         1    1360194.0    1e+06      0.3      undefined_pixels = np.argwhere(mask == UNDEFINED).tolist()
     4         1        300.0    300.0      0.0      iteration = 0
     5       101      24121.0    238.8      0.0      while undefined_pixels:
     6       101     304045.0   3010.3      0.1          origin = random.choice(undefined_pixels)
     7       101    4956859.0  49077.8      1.0          pattern = create_pattern(mask, origin, n_neighbors)
     8       101  137572039.0    1e+06     28.6          candidates = generate_candidates(mask, pattern, n_candidates