In [544]:
import cv2
import numpy as np
from numpy.lib.stride_tricks import as_strided
from itertools import product

In [548]:
# Function to crop images based on patch_size and stride
def crop_images(data, patch_size, stride=1):
    """Trim every image so a strided patch grid fits it with no leftover border.

    Each image keeps its top-left corner and is cut to
    ``crop_h = H - (H - p_h) % stride`` rows and
    ``crop_w = W - (W - p_w) % stride`` columns, so that
    ``crop_h - p_h`` and ``crop_w - p_w`` are multiples of ``stride``.

    Parameters
    ----------
    data : np.ndarray
        Indexed as ``data[b][j]``: ``data.shape[0]`` groups, each holding
        ``data.shape[1]`` images of shape ``(H, W)`` or ``(H, W, C)``.
    patch_size : (int, int)
        Patch height and width.
    stride : int
        Step between neighbouring patch origins.

    Returns
    -------
    np.ndarray
        ``(batch, pair, crop_h, crop_w[, C])`` when all images share a shape.
        NOTE(review): images smaller than ``patch_size`` produce a meaningless
        crop size; callers should avoid them.
    """
    p_h, p_w = patch_size

    def crop(image):
        i_h, i_w = image.shape[:2]
        crop_h = i_h - (i_h - p_h) % stride
        crop_w = i_w - (i_w - p_w) % stride
        # Slice the argument itself, not the notebook-global `img`.
        return image[:crop_h, :crop_w]

    cropped_images = [np.array([crop(image) for image in group]) for group in data]
    return np.array(cropped_images)

In [545]:
# Function to create patches from groups (e.g. pairs) of images
def extract_patches(data, patch_size, stride = 1, random_state=None):
    """Extract strided patches from every image of a batch of image groups.

    Parameters
    ----------
    data : np.ndarray
        Indexed as ``data[b][j]``: ``data.shape[0]`` groups, each holding
        ``data.shape[1]`` images (e.g. an input/target pair). Each image is
        ``(H, W)`` or ``(H, W, C)``; all images must share the same shape.
    patch_size : (int, int)
        Patch height and width.
    stride : int
        Step between neighbouring patch origins along both spatial axes.
    random_state : ignored
        Unused; kept for backward compatibility.

    Returns
    -------
    np.ndarray
        Shape ``(batch * n_patches, pair, p_h, p_w[, C])`` with
        ``n_patches = ((H - p_h)//stride + 1) * ((W - p_w)//stride + 1)``.
        ``result[k, j]`` is the k-th patch (row-major over patch origins) of
        image ``j`` in its group, so same-location patches stay paired. The
        channel axis is dropped for single-channel images.

    Raises
    ------
    ValueError
        If ``stride < 1`` or an image is smaller than ``patch_size``.
    """
    p_h, p_w = patch_size
    if stride < 1:
        raise ValueError("stride must be a positive integer")

    def get_patches(arr, patch_shape):
        # Zero-copy strided view of shape (n_h, n_w, 1, p_h, p_w, C): the leading
        # axes step over patch origins by `stride`, the trailing axes walk one patch.
        extraction_step = (stride,) * arr.ndim
        slices = tuple(slice(None, None, st) for st in extraction_step)
        indexing_strides = arr[slices].strides

        patch_indices_shape = ((np.array(arr.shape) - np.array(patch_shape)) //
                               np.array(extraction_step)) + 1

        shape = tuple(patch_indices_shape) + tuple(patch_shape)
        strides = tuple(indexing_strides) + tuple(arr.strides)
        return as_strided(arr, shape=shape, strides=strides)

    def _ex_pt(image):
        i_h, i_w = image.shape[:2]
        if i_h < p_h or i_w < p_w:
            raise ValueError(
                f"image of size {(i_h, i_w)} is smaller than patch_size {(p_h, p_w)}")
        # Treat (H, W) images as (H, W, 1) so one code path handles both layouts.
        image = image.reshape((i_h, i_w, -1))
        n_colors = image.shape[-1]

        patches = get_patches(image, patch_shape=(p_h, p_w, n_colors))
        # Flatten the patch-origin axes (copies, since the view is non-contiguous).
        patches = patches.reshape(-1, p_h, p_w, n_colors)
        # Remove the color dimension if useless. Use -1: the patch count depends on stride.
        if n_colors == 1:
            return patches.reshape(-1, p_h, p_w)
        return patches

    image_patches = []
    for group in data:
        # (pair, n_patches, ...) -> (n_patches, pair, ...). Swap axes rather than
        # reshape so patch k of every group member ends up in the same row.
        group_patches = np.array([_ex_pt(image) for image in group])
        image_patches.extend(np.swapaxes(group_patches, 0, 1))

    return np.array(image_patches)

In [546]:
# Function to reassemble an image from its patches
def reconstruct_patches(patches, image_size, stride=1):
    """Rebuild an image by averaging patches placed on a strided grid.

    Patch origins are visited row-major: rows 0, stride, ... up to
    ``H - p_h`` and columns 0, stride, ... up to ``W - p_w``. Patches are
    consumed in that order; extra patches or extra origins are ignored.
    Overlapping pixels are averaged, and pixels covered by no patch are 0.

    Parameters
    ----------
    patches : np.ndarray
        ``(n, p_h, p_w[, C])`` stack of patches.
    image_size : tuple
        Output shape, ``(H, W)`` or ``(H, W, C)``.
    stride : int
        Step between patch origins; should match the one used for extraction.

    Returns
    -------
    np.ndarray
        float32 array of shape ``image_size``.
    """
    height, width = image_size[:2]
    patch_h, patch_w = patches.shape[1:3]
    accum = np.zeros(image_size)
    coverage = np.zeros(image_size)

    origins = product(range(0, height - patch_h + 1, stride),
                      range(0, width - patch_w + 1, stride))
    for patch, (top, left) in zip(patches, origins):
        window = (slice(top, top + patch_h), slice(left, left + patch_w))
        coverage[window] += 1
        accum[window] += patch

    averaged = np.zeros_like(accum, dtype=np.float32)
    return np.divide(accum, coverage, out=averaged, where=coverage != 0)

In [530]:
# Patch size as (height, width) in pixels, unpacked as (p_h, p_w) by the helpers above.
PATCH_SIZE = (256, 256)
# Step, in pixels, between neighbouring patch origins along both spatial axes.
STRIDE = 100

In [531]:
# Sample image from the Places365 validation set (path relative to the notebook).
IMAGE_PATH = './Places365/val_large/Places365_val_00000001.jpg'
img = cv2.imread(IMAGE_PATH)
# cv2.imread signals a missing or unreadable file by returning None instead of raising.
assert img is not None, f"could not read image: {IMAGE_PATH}"

In [532]:
# Loaded image as (height, width, channels).
img.shape

(772, 512, 3)

In [533]:
# Toy dataset: 3 groups, each a pair of images -> shape (batch, pair, H, W, C).
# Both pair members are the same image, so pairing mistakes would not be visible here.
train = np.array([[img, img], [img, img], [img, img]])

In [534]:
train.shape

(3, 2, 772, 512, 3)

In [535]:
# np.apply_along_axis(crop_image, 2, train, PATCH_SIZE, STRIDE)
# NOTE(review): dead experiment; `crop_image` is not defined in this notebook.

In [536]:
# Trim each image so (H - p_h) and (W - p_w) become multiples of STRIDE.
images_cropped = crop_images(train, PATCH_SIZE, STRIDE)

In [537]:
images_cropped.shape

(3, 2, 756, 456, 3)

In [538]:
# Patches come from the uncropped `train`. extract_patches floors the number of patch
# origins, so this yields the same patches as passing `images_cropped`.
patches_img = extract_patches(train, PATCH_SIZE, STRIDE)

In [539]:
patches_img.shape

(54, 2, 256, 256, 3)

In [540]:
# cv2.imwrite('./patch_results/_og.png', img)

In [541]:
# cv2.imwrite('./patch_results/og.png', img)

In [542]:
# h, w, c = img.shape

In [543]:
# print(img.shape)

In [382]:
for idx, patch in enumerate(patches):
    cv2.imwrite('./patch_results/patch_1_'+str(idx+1)+'.png', patch)

In [383]:
re_img = reconstruct_patches(patches, (h, w, c))

In [384]:
re_img.shape

(712, 500, 3)

In [385]:
cv2.imwrite('./patch_results/og1.png', re_img)

True