In [91]:
import numpy as np
from sklearn.feature_extraction.image import extract_patches_2d
from skimage.util.shape import view_as_blocks
from skimage.util import view_as_windows
from numpy.lib import stride_tricks
import cv2

np.set_printoptions(precision=3)

class PatchMatch(object):
    def __init__(self, a, b, patch_size):
        assert a.shape == b.shape, "Dimensions were unequal for patch-matching input"
        
        self.orig_shape = a.shape
        self.patch_size = patch_size
        self.offset_from_edge = self.patch_size // 2
        offset_from_edge = self.offset_from_edge
        self.a = a
        self.b = b
        self.nnf = np.ndarray(shape=(2, a.shape[1], a.shape[2]),dtype=np.int32)
        self.nnd = np.ndarray(shape=(a.shape[1], a.shape[2]))

        print(self.a.shape)
        print(self.b.shape)
       
    
        for i in range(0, a.shape[1] ): #loop  through y
            for j in range(0, a.shape[2]): # loop through x
                self.nnf[0][i][j] = int(np.random.randint(0,a.shape[2]))
                self.nnf[1][i][j] = int(np.random.randint(0,a.shape[1]))
                self.nnd[i][j] = self.calculate_distance(j,i,self.nnf[0][i][j],self.nnf[1][i][j])
                
    def clean_coords(self,x,y):
        assert x >= self.offset_from_edge and x < self.orig_shape[2] , "X coordinate off"
        assert y >= self.offset_from_edge and y < self.orig_shape[1] , "Y coordinate off"
        
        
    def calculate_distance(self,ax,ay,bx,by):

        patch_a = self.get_patch_for_coords(self.a,ax,ay)
        patch_b = self.get_patch_for_coords(self.b,bx,by)
        assert patch_a.shape == patch_b.shape, " ax {}  ay {}  bx {}  by {}".format(ax,ay,bx,by)


        diff = patch_a - patch_b
        distances = np.linalg.norm(diff,axis=0)
#         print("DISTANCE {}".format(distances.size))
        average = np.sum(distances)/(distances.size)
        return average
        
        
    
    def get_patch_for_coords(self, arr, x, y):
        """
        Get a patch of at max patch_size X patch_size.
        x and y denotes the center of the patch

        :arr: an array of dimensions C * H * W 
        :return:
        """
        x,y = self.clip_coords(x,y)

        return arr[ : , y : y+self.offset_from_edge, x:x+self.offset_from_edge]
    
    def propagate(self):
        old_nnd = self.nnd.copy()
        print(self.nnd)
        print("-"*10)
        for i in range(1,self.orig_shape[1]): # loop through ys
            for j in range(1,self.orig_shape[2]):# loop through xs
                current_pos_distance = self.nnd[i][j] 
                
                horiz_x = j - 1
                horiz_y = i
                horiz_dist = self.calculate_distance(j,i,self.nnf[0][horiz_y][horiz_x],self.nnf[1][horiz_y][horiz_x])
                
                vert_x = j 
                vert_y = i - 1  
                vert_dist = self.calculate_distance(j,i,self.nnf[0][vert_y][vert_x],self.nnf[1][vert_y][vert_x])

                best_dist = min(current_pos_distance,horiz_dist,vert_dist)
                
                if best_dist == current_pos_distance:
                    best_x = j
                    best_y = i
                elif best_dist == horiz_dist:
                    best_x = horiz_x
                    best_y = horiz_y
                elif best_dist == vert_dist:
                    best_x = vert_x
                    best_y = vert_y
                    
                rand_d = min(self.a.shape[1]//2, self.a.shape[2]//2)
                
                while rand_d > 0:
                        xmin = max(best_x - rand_d, 0)
                        xmax = min(best_x + rand_d, self.b.shape[2])
                        ymin = max(best_y - rand_d, 0)
                        ymax = min(best_y + rand_d, self.b.shape[1])

                        rand_x = np.random.randint(xmin, xmax)
                        rand_y = np.random.randint(ymin, ymax)
                        val = self.calculate_distance(j,i, rand_x, rand_y)
                        if val < best_dist:
                            best_x, best_y, best_dist = rand_x, rand_y, val
                        rand_d = rand_d // 2    
              
                self.nnf[0][i][j] = best_x 
                self.nnf[1][i][j] = best_y
                self.nnd[i][j] = best_dist
        print(self.nnd)
        print("_"*10)
        print(old_nnd-self.nnd)
            
    
    def get_loss_value(self):
        return np.sum(self.nnd)
    def clip_coords(self,x,y):
        return max(x,0) , max(y,0)
        
    def back_propagate(self):
            old_nnd = self.nnd.copy()
            print(self.nnd)
            print("-"*10)
            for i in range(self.orig_shape[1]-2,0,-1): # loop through ys
                for j in range(self.orig_shape[2]-2,0,-1):# loop through xs
                    current_pos_distance = self.nnd[i][j] 

                    horiz_x = j + 1
                    horiz_y = i
                    horiz_dist = self.calculate_distance(j,i,self.nnf[0][horiz_y][horiz_x],self.nnf[1][horiz_y][horiz_x])

                    vert_x = j 
                    vert_y = i + 1  
                    vert_dist = self.calculate_distance(j,i,self.nnf[0][vert_y][vert_x],self.nnf[1][vert_y][vert_x])

                    best_dist = min(current_pos_distance,horiz_dist,vert_dist)

                    if best_dist == current_pos_distance:
                        best_x = j
                        best_y = i
                    elif best_dist == horiz_dist:
                        best_x = horiz_x
                        best_y = horiz_y
                    elif best_dist == vert_dist:
                        best_x = vert_x
                        best_y = vert_y

                    rand_d = min(self.a.shape[1]//2, self.a.shape[2]//2)

                    while rand_d > 0:
                            xmin = max(best_x - rand_d, 0)
                            xmax = min(best_x + rand_d, self.b.shape[2])
                            ymin = max(best_y - rand_d, 0)
                            ymax = min(best_y + rand_d, self.b.shape[1])

                            rand_x = np.random.randint(xmin, xmax)
                            rand_y = np.random.randint(ymin, ymax)
                            val = self.calculate_distance(j,i, rand_x, rand_y)
                            if val < best_dist:
                                best_x, best_y, best_dist = rand_x, rand_y, val
                            rand_d = rand_d // 2    

                    self.nnf[0][i][j] = best_x 
                    self.nnf[1][i][j] = best_y
                    self.nnd[i][j] = best_dist
            print(self.nnd)
            print("_"*10)
            print(old_nnd-self.nnd)



In [92]:
%timeit
# test_a = np.random.randint(low=0,high=10,size=(7,100, 100))
# test_b = np.random.randint(low=0,high=10,size=(7,100, 100))

x = cv2.imread("/Users/harshvardhangupta/Deep-Image-Analogy/notebooks/Unknown.jpeg").transpose(2,1,0)
y = cv2.imread("/Users/harshvardhangupta/Deep-Image-Analogy/notebooks/Unknown.jpeg").transpose(2,1,0)
# x = x[:,:7,:7]
# y = y[:,:7,:7]


print(x.shape)
print("-")
# test_a = test_a.transpose((2,0,1))
# print(test_a.shape)

pm = PatchMatch(x,y,7)
pm.propagate()
# pm.back_propagate()
# pm.propagate()
# # pm.back_propagate()
# pm.propagate()
# # pm.get_loss_value()
# pm.propagate()



(3, 275, 183)
-
(3, 275, 183)
(3, 275, 183)


AssertionError:  ax 22  ay 0  bx 182  by 92

In [None]:
pm.calculate_distance(3,9,247,77)

In [None]:
pm.get_patch_for_coords(x,0,0)

(3, 275, 183)


In [None]:
x[:,0:3,0:3]

In [None]:
pm.a[:,7+2,8+2]

In [46]:
pm.nnf[:,0,0] = [2,2]

In [47]:
pm.nnf[:,0,0]

array([2, 2], dtype=int32)

In [82]:
pm.nnd

array([[ 359.912,   85.726,  171.248, ...,  197.831,  405.288,  380.669],
       [ 307.176,    0.   ,   66.053, ...,    0.   ,    0.   ,    0.   ],
       [ 239.543,   20.785,   82.735, ...,    0.   ,    0.   ,    0.   ],
       ..., 
       [  93.835,    0.   ,    0.   , ...,    0.   ,    0.   ,    0.   ],
       [ 253.902,    0.   ,    0.   , ...,    0.   ,    0.   ,    0.   ],
       [ 178.978,    0.   ,    0.   , ...,    0.   ,    0.   ,    0.   ]])