Import libraries and mount Google Drive

In [1]:
import cv2
import numpy as np
from scipy import ndimage
from google.colab import drive
from google.colab.patches import cv2_imshow
# Mount Google Drive at /content/drive: the example images and the output
# folder used by run_test further down are read from / written under it.
drive.mount('/content/drive')

Mounted at /content/drive


A few helper functions used by the `inpainter` class.

In [2]:
# Static helper functions

def imshow(image):
    """
    Render an image in the notebook output.

    Delegates the drawing to Colab's ``cv2_imshow`` and afterwards calls
    ``cv2.waitKey(0)``.
    """
    frame = image
    cv2_imshow(frame)
    cv2.waitKey(0)

def get_data(img, patch):
    """
    Return the part of ``img`` covered by ``patch``.

    ``patch`` is a 2x2 range ``[[y_start, y_end], [x_start, x_end]]`` with
    exclusive ends (the format built by ``inpainter.get_patch``). Only the
    first two axes are sliced, so any colour-channel axis is kept whole.

    The result comes from basic slicing and is therefore a numpy *view*:
    writing into it modifies ``img`` in place, which ``inpainter.fill_patch``
    relies on to update the image, confidence and fill-range arrays.
    """
    (y_start, y_end), (x_start, x_end) = patch
    return img[y_start:y_end, x_start:x_end]

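The `inpainter` class. When choosing the source patch used to fill a target patch, `patch_diff_w` weights the colour difference between the two patches and `patch_dist_w` weights the spatial distance between them. The default configuration is:
```
patch_size=9, patch_diff_w = 1, patch_dist_w = 1
```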

In [3]:
class inpainter():
    """
    Patch-based image inpainter.

    Pixels whose mask value is above 177 form the region to fill. Each
    iteration of ``run_inpaint``:

    1. extracts the fill front (the outer contour of the unfilled region),
    2. scores every front pixel with ``priority = confidence * data term``,
    3. takes the highest-priority front pixel as the target,
    4. searches for the fully-known patch minimising
       ``patch_diff_w * colour difference + patch_dist_w * spatial distance``,
    5. copies that patch into the still-unknown pixels of the target patch.
    """

    def __init__(self, image, mask, patch_size=9, patch_diff_w = 1, patch_dist_w = 1, verbose=True):
        """
        Parameters
        ----------
        image : 3-channel BGR image (converted with ``COLOR_BGR2GRAY``).
        mask : single-channel image with the same height and width as
            ``image``; pixels with value > 177 are filled.
        patch_size : patches span ``patch_size // 2`` pixels on each side of
            their centre, i.e. the effective side is ``2 * (patch_size // 2) + 1``.
        patch_diff_w : weight of the colour-difference term of the match score.
        patch_dist_w : weight of the spatial-distance term of the match score.
        verbose : when True (default) show intermediate images and print
            progress while running; set False to only show the final result.

        Raises
        ------
        ValueError : if ``mask`` and ``image`` differ in height/width.
        """
        if mask.shape[:2] != image.shape[:2]:
            raise ValueError(
                f"mask shape {mask.shape[:2]} does not match image shape {image.shape[:2]}")

        # Properties that stay fixed for the whole run.
        self.image = image.astype('uint8')
        # Binarise the mask: 1 = pixel to fill, 0 = known pixel.
        self.mask = np.where(mask > 177, 1, 0).astype('uint8')
        self.h, self.w = mask.shape
        self.patch_size = patch_size
        self.patch_diff_w = patch_diff_w
        self.patch_dist_w = patch_dist_w
        self.verbose = verbose

        # Properties updated on every iteration.
        self.image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Working image: unknown pixels are blanked to black until filled.
        self.fill_image = np.where(np.dstack([self.mask] * 3) == 1,
                                   np.zeros(shape=image.shape, dtype="uint8"), image)
        self.fill_range = np.copy(self.mask)   # 1 = still unknown, 0 = known
        self.fill_front = None                 # (N, 2) array of (row, col) points
        self.confidence = (self.mask == 0).astype('float')
        self.data = np.zeros(shape=mask.shape)
        self.priority = None
        if self.verbose:
            imshow(self.fill_image)

    def save_output(self, filename):
        """
        Write the current ``fill_image`` to ``filename``.

        Raises
        ------
        OSError : if ``cv2.imwrite`` reports failure (e.g. the target
            directory does not exist) instead of failing silently.
        """
        if not cv2.imwrite(filename, self.fill_image):
            raise OSError(f"could not write image to {filename!r}")

    def get_masked_img(self):
        """
        Return the grayscale image as a ``np.ma`` array whose mask is the
        current unfilled region.
        """
        return np.ma.masked_array(self.image_gray, mask=self.fill_range)

    def get_gradient(self):
        """
        Compute the gradient of the known part of the grayscale image.

        Returns
        -------
        (gradient, g_y, g_x) : ``g_y``/``g_x`` are the row/column derivatives
            with masked (unknown-influenced) entries set to 0; ``gradient`` is
            their magnitude rescaled so its maximum is 255, or all zeros when
            the image has no gradient at all (avoids a 0/0 -> NaN).
        """
        masked_img = self.get_masked_img()
        g_y, g_x = np.gradient(masked_img)
        g_y = np.ma.filled(g_y, 0)
        g_x = np.ma.filled(g_x, 0)
        gradient = np.sqrt(g_x * g_x + g_y * g_y)
        peak = gradient.max()
        if peak > 0:
            gradient = gradient * (255 / peak)
        return gradient, g_y, g_x

    def update_fill_front(self):
        """
        Recompute ``self.fill_front``: the (row, col) points of every external
        contour of the unfilled region, concatenated contour by contour.
        """
        contours, _ = cv2.findContours(self.fill_range, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        if self.verbose:
            # Visual aid only: contours are traced on the nonzero (unfilled)
            # pixels, which fill_patch overwrites before the run finishes.
            cv2.drawContours(self.fill_image, contours, -1, (0, 255, 0), 1)
        # OpenCV returns (x, y) points; flip to (row, col) for numpy indexing.
        self.fill_front = np.concatenate(
            [contour.reshape(-1, 2)[:, ::-1] for contour in contours], axis=0)

    def get_patch(self, point):
        """
        Return the patch range centred at ``point`` (row, col), clipped to the
        image, as ``[[y_start, y_end], [x_start, x_end]]`` with exclusive ends.
        """
        k = self.patch_size // 2
        patch_range = [[max(0, point[0] - k), min(point[0] + k + 1, self.h)],
                       [max(0, point[1] - k), min(point[1] + k + 1, self.w)]]
        return np.array(patch_range)

    def update_img_gray(self):
        """
        Recompute the grayscale image from the current ``fill_image``.
        """
        self.image_gray = cv2.cvtColor(self.fill_image, cv2.COLOR_BGR2GRAY)

    def update_C(self):
        """
        Update the confidence term of every fill-front point.

        C(p) = sum of the confidences of the *known* pixels in the patch
        around p, divided by the (border-clipped) patch area. Unknown pixels
        are excluded because their stored confidence is stale. All new values
        are computed before any is written, so the result does not depend on
        the order of the front points.
        """
        source_confidence = self.confidence * (1 - self.fill_range)
        new_values = []
        for pt in self.fill_front:
            patch = self.get_patch(pt)
            area = (patch[0, 1] - patch[0, 0]) * (patch[1, 1] - patch[1, 0])
            new_values.append(get_data(source_confidence, patch).sum() / area)
        self.confidence[self.fill_front[:, 0], self.fill_front[:, 1]] = new_values

    def get_normal(self):
        """
        Return an (h, w, 2) array of unit normals of the fill region, stored
        as (y, x) components; zero where the mask gradient vanishes.
        """
        g_x = cv2.Scharr(self.fill_range, cv2.CV_64F, 1, 0)
        g_y = cv2.Scharr(self.fill_range, cv2.CV_64F, 0, 1)
        normal = np.dstack([g_y, g_x])
        norm = np.sqrt(g_x * g_x + g_y * g_y)[..., np.newaxis]
        norm[norm == 0] = 1  # leave zero vectors as zero instead of dividing by 0
        return normal / norm

    def get_isophote(self):
        """
        Return an (h, w, 2) array holding, for each fill-front point, the
        largest-magnitude image gradient within its patch rotated by 90
        degrees, stored as (y, x) = (g_x, -g_y). Other entries are 0.
        """
        grad, g_y, g_x = self.get_gradient()
        isophote = np.zeros(shape=(self.h, self.w, 2))
        for pt in self.fill_front:
            patch = self.get_patch(pt)
            g_y_patch = get_data(g_y, patch)
            g_x_patch = get_data(g_x, patch)
            grad_patch = get_data(grad, patch)

            max_patch_pos = np.unravel_index(grad_patch.argmax(), grad_patch.shape)

            isophote[pt[0], pt[1], 0] = g_x_patch[max_patch_pos]
            isophote[pt[0], pt[1], 1] = -g_y_patch[max_patch_pos]

        return isophote

    @staticmethod
    def _data_term(normal, isophote, a):
        """
        Data term |isophote . normal| / a, evaluated per pixel over the last
        axis of two (..., 2) arrays.
        """
        return np.abs(np.sum(normal * isophote, axis=-1)) / a

    def update_D(self, a=255):
        """
        Update the data term: the absolute dot product of each point's
        isophote with the fill-front normal, normalised by ``a``.
        """
        normal = self.get_normal()
        isophote = self.get_isophote()
        self.data = self._data_term(normal, isophote, a)

    def update_priority(self):
        """
        Refresh the confidence and data terms, then store in ``self.priority``
        the value ``confidence * data`` for each fill-front point (in
        ``fill_front`` order).
        """
        self.update_C()
        self.update_D()
        priority = self.confidence * self.data * self.fill_range
        self.priority = np.array([priority[pt[0], pt[1]] for pt in self.fill_front])

    def get_target_point(self):
        """
        Return the fill-front point with the highest priority (first one on ties).
        """
        return self.fill_front[self.priority.argmax()]

    def _valid_source_origins(self, patch_h, patch_w):
        """
        Return, in row-major order, the (row, col) top-left corners of every
        ``patch_h`` x ``patch_w`` window that contains no unfilled pixel.
        """
        # Summed-area table with a leading zero row/column, so that
        # table[r, c] == fill_range[:r, :c].sum().
        table = np.zeros((self.h + 1, self.w + 1), dtype=np.int64)
        table[1:, 1:] = self.fill_range.astype(np.int64).cumsum(axis=0).cumsum(axis=1)
        window_sums = (table[patch_h:, patch_w:] - table[:-patch_h, patch_w:]
                       - table[patch_h:, :-patch_w] + table[:-patch_h, :-patch_w])
        rows, cols = np.nonzero(window_sums == 0)
        return list(zip(rows.tolist(), cols.tolist()))

    def find_best_match(self, target_point):
        """
        Find the fully-known patch that best matches the patch centred at
        ``target_point``.

        Candidates have the same (possibly border-clipped) size as the target
        patch; the score is ``patch_diff_w * colour difference +
        patch_dist_w * top-left-corner distance``, and the first minimum in
        row-major order wins.

        Raises
        ------
        ValueError : if the image contains no fully-known patch of that size.
        """
        target_patch = self.get_patch(target_point)
        patch_h = target_patch[0, 1] - target_patch[0, 0]
        patch_w = target_patch[1, 1] - target_patch[1, 0]
        if self.verbose:
            print(f"target: {target_point}")
            imshow(get_data(self.fill_image, target_patch))

        best_patch = None
        best_distance = float('inf')
        best_match_pt = (0, 0)

        for j, i in self._valid_source_origins(patch_h, patch_w):
            dst_patch = np.array([[j, j + patch_h], [i, i + patch_w]])
            patch_diff = self.get_patch_difference(target_patch, dst_patch)
            patch_dist = self.get_patch_distance(target_patch, dst_patch)
            distance = self.patch_diff_w * patch_diff + self.patch_dist_w * patch_dist
            if distance < best_distance:
                best_distance = distance
                best_patch = dst_patch
                best_match_pt = (j, i)

        if best_patch is None:
            raise ValueError(
                f"no fully-known {patch_h}x{patch_w} source patch exists for target {target_point}")

        if self.verbose:
            print(f"best_match: {best_match_pt}")
            imshow(get_data(self.fill_image, best_patch))
        return best_patch

    def get_patch_distance(self, target_patch, dst_patch):
        """
        Euclidean distance between the top-left corners of two patch ranges.
        """
        return (((dst_patch[:, 0] - target_patch[:, 0]) ** 2).sum()) ** 0.5

    def get_patch_difference(self, target_patch, dst_patch):
        """
        Colour difference of two equally sized patches: the L2 norm of the
        pixel differences, counted only over pixels already known in the
        target patch.
        """
        patch_h = target_patch[0, 1] - target_patch[0, 0]
        patch_w = target_patch[1, 1] - target_patch[1, 0]
        patch_mask = (1 - get_data(self.fill_range, target_patch)).reshape(patch_h, patch_w, 1).repeat(3, axis=2)
        target_data = get_data(self.fill_image, target_patch).astype("float64") * patch_mask
        dst_data = get_data(self.fill_image, dst_patch).astype("float64") * patch_mask

        return (((target_data - dst_data) ** 2).sum()) ** 0.5

    def fill_patch(self, target_point):
        """
        Fill the unknown pixels of the patch centred at ``target_point`` from
        the best-matching source patch, give them the target point's
        confidence, and mark them as known.
        """
        target_patch = self.get_patch(target_point)
        coord_to_fill = np.where(get_data(self.fill_range, target_patch) == 1)

        # get_data returns views, so the assignments below write through to
        # fill_image / confidence / fill_range.
        dst_patch = self.find_best_match(target_point)
        dst_data = get_data(self.fill_image, dst_patch)
        target_data = get_data(self.fill_image, target_patch)
        target_data[coord_to_fill[0], coord_to_fill[1]] = dst_data[coord_to_fill[0], coord_to_fill[1]]

        target_confidence = get_data(self.confidence, target_patch)
        target_confidence[coord_to_fill[0], coord_to_fill[1]] = self.confidence[target_point[0], target_point[1]]

        target_fill_range = get_data(self.fill_range, target_patch)
        target_fill_range[coord_to_fill[0], coord_to_fill[1]] = 0

    def run_inpaint(self):
        """
        Fill patches until no unknown pixel remains, then display the result.

        Each iteration fills at least the target pixel itself (a fill-front
        point is always unknown), so the loop terminates.
        """
        while self.fill_range.sum() != 0:
            self.update_fill_front()
            self.update_priority()

            target_point = self.get_target_point()
            self.fill_patch(target_point)
            self.update_img_gray()

            if self.verbose:
                imshow(self.fill_image)
                print(f"{self.fill_range.sum()} pixels remaining")

        imshow(self.fill_image)



    

Below is the function that runs the inpainter on a given test case. Pass it a test-case number, a patch size, the weight for the patch difference, and the weight for the patch distance. You may also need to change the test-case (examples) directory and the output directory to match your own Google Drive. We have prepared 10 test examples in the `examples` folder, and you can add more to it.

In [4]:
def run_test(test_num, patch_size, patch_diff_w, patch_dist_w,
             examples_dir='/content/drive/MyDrive/CSC420-PROJECT/examples',
             output_dir='/content/drive/MyDrive/CSC420-PROJECT/output'):
    """
    Inpaint example ``test_num`` and save the result.

    Reads ``img{test_num}.png`` (colour) and ``mask{test_num}.png``
    (grayscale) from ``examples_dir``, runs the inpainter with the given
    patch size and weights, and writes
    ``result{test_num}-{patch_size}-{patch_diff_w}-{patch_dist_w}.png`` to
    ``output_dir``.

    Raises
    ------
    FileNotFoundError : if either input image cannot be read.
    """
    img_path = f'{examples_dir}/img{test_num}.png'
    mask_path = f'{examples_dir}/mask{test_num}.png'
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path, 0)
    # cv2.imread returns None (rather than raising) for missing/unreadable files.
    for path, loaded in ((img_path, img), (mask_path, mask)):
        if loaded is None:
            raise FileNotFoundError(f"could not read image: {path}")

    painter = inpainter(img, mask, patch_size, patch_diff_w, patch_dist_w)
    painter.run_inpaint()
    painter.save_output(f'{output_dir}/result{test_num}-{patch_size}-{patch_diff_w}-{patch_dist_w}.png')

For example, to run the first test case with patch size 9, patch-difference weight 1, and patch-distance weight 0.5, run:

In [None]:
run_test(test_num=1, patch_size=9, patch_diff_w=1, patch_dist_w=0.5)