[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/kornia/tutorials/blob/master/source/extract-combine-patches.ipynb)

# Extracting and Combining Tensor Patches

In this tutorial we will show how you can extract and combine tensor patches using Kornia.

In [1]:
%%capture
%matplotlib inline
!pip install git+https://github.com/kornia/kornia

## Using Modules

In [2]:
import torch

from kornia.contrib import CombineTensorPatches, ExtractTensorPatches

h, w = 8, 8
win = 4
pad = 2

image = torch.randn(2, 3, h, w)
print(image.shape)
tiler = ExtractTensorPatches(window_size=win, stride=win, padding=pad)
merger = CombineTensorPatches(original_size=(h, w), window_size=win, unpadding=pad)
image_tiles = tiler(image)
print(image_tiles.shape)
new_image = merger(image_tiles)
print(new_image.shape)
assert (image == new_image).all()

torch.Size([2, 3, 8, 8])
torch.Size([2, 9, 3, 4, 4])
torch.Size([2, 3, 8, 8])
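The patch tensor is shaped `(B, N, C, window_h, window_w)`, and its second dimension is the patch count `N`. You can predict it by counting window positions per dimension and multiplying the two counts. The sketch below assumes the standard sliding-window formula, the same one `torch.nn.Unfold` uses. The helper names are hypothetical.

```python
def num_windows(size: int, window: int, stride: int, padding: int) -> int:
    """Number of sliding-window positions along one (symmetrically padded) dimension."""
    return (size + 2 * padding - window) // stride + 1


def num_patches(h: int, w: int, window: int, stride: int, padding: int) -> int:
    """Total patch count N: positions along the height times positions along the width."""
    return num_windows(h, window, stride, padding) * num_windows(w, window, stride, padding)


# Values from the cell above: 8x8 image, window 4, stride 4, padding 2
print(num_patches(8, 8, window=4, stride=4, padding=2))  # 9
```

Each dimension has (8 + 4 - 4) // 4 + 1 = 3 positions, so there are 3 * 3 = 9 patches.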


## Using Functions

In [3]:
import torch

from kornia.contrib import combine_tensor_patches, extract_tensor_patches

h, w = 8, 8
win = 4
pad = 2

image = torch.randn(1, 1, h, w)
print(image.shape)
patches = extract_tensor_patches(image, window_size=win, stride=win, padding=pad)
print(patches.shape)
restored_img = combine_tensor_patches(patches, original_size=(h, w), window_size=win, stride=win, unpadding=pad)
print(restored_img.shape)
assert (image == restored_img).all()

torch.Size([1, 1, 8, 8])
torch.Size([1, 9, 1, 4, 4])
torch.Size([1, 1, 8, 8])
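The following sketch illustrates what extraction and combination amount to when `stride == window_size`. It works on a single-channel 2-D list and has no dependencies. It is an illustration under stated assumptions (zero padding, non-overlapping tiles, row-major patch order), not Kornia's implementation. `pad2d`, `extract` and `combine` are hypothetical names.

```python
from typing import List

Grid = List[List[float]]


def pad2d(img: Grid, p: int) -> Grid:
    """Zero-pad a 2-D grid by p cells on every side."""
    width = len(img[0]) + 2 * p
    top = [[0.0] * width for _ in range(p)]
    bottom = [[0.0] * width for _ in range(p)]
    body = [[0.0] * p + list(row) + [0.0] * p for row in img]
    return top + body + bottom


def extract(img: Grid, win: int, pad: int) -> List[Grid]:
    """Cut the padded grid into non-overlapping win x win patches (stride == win)."""
    padded = pad2d(img, pad)
    H, W = len(padded), len(padded[0])
    return [
        [row[c:c + win] for row in padded[r:r + win]]
        for r in range(0, H - win + 1, win)  # patch rows, top to bottom
        for c in range(0, W - win + 1, win)  # patch columns, left to right
    ]


def combine(patches: List[Grid], h: int, w: int, win: int, pad: int) -> Grid:
    """Place patches back in row-major order, then strip the padding."""
    H, W = h + 2 * pad, w + 2 * pad
    cols = W // win  # patches per row; assumes W % win == 0
    out = [[0.0] * W for _ in range(H)]
    for i, patch in enumerate(patches):
        r0, c0 = (i // cols) * win, (i % cols) * win
        for dr, prow in enumerate(patch):
            out[r0 + dr][c0:c0 + win] = prow
    return [row[pad:pad + w] for row in out[pad:pad + h]]


h, w, win, pad = 8, 8, 4, 2
image = [[float(r * w + c) for c in range(w)] for r in range(h)]
patches = extract(image, win, pad)
print(len(patches), len(patches[0]), len(patches[0][0]))  # 9 4 4
assert combine(patches, h, w, win, pad) == image
```

The round trip is exact only because the tiles neither overlap nor leave gaps. This is also why combining is restricted to `stride == window_size`.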


## Important cases to consider

While using these functions, keep the following constraints in mind:

1. The original image dimensions (before extraction) must be divisible by 2.
2. The image dimensions after padding must be divisible by `window_size`.
3. `CombineTensorPatches` only works with `stride == window_size`.
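You can check the three rules up front, before calling Kornia. The helper `check_patch_constraints` below is hypothetical. It encodes the rules exactly as listed above. Kornia's own checks and error messages may differ between versions, so treat this as a sketch.

```python
from typing import List, Tuple, Union

IntPair = Union[int, Tuple[int, int]]


def _pair(x: IntPair) -> Tuple[int, int]:
    """Expand an int into (x, x); pass tuples through unchanged."""
    return (x, x) if isinstance(x, int) else (x[0], x[1])


def check_patch_constraints(size: Tuple[int, int], window: IntPair,
                            padding: IntPair, stride: IntPair) -> List[str]:
    """Return the violated rules (1-3 as listed above); an empty list means all hold."""
    h, w = size
    wh, ww = _pair(window)
    ph, pw = _pair(padding)
    problems = []
    if h % 2 or w % 2:
        problems.append(f"rule 1: size {(h, w)} is not divisible by 2")
    padded = (h + 2 * ph, w + 2 * pw)  # padding is added on both sides
    if padded[0] % wh or padded[1] % ww:
        problems.append(f"rule 2: padded size {padded} is not divisible by {(wh, ww)}")
    if _pair(stride) != (wh, ww):
        problems.append("rule 3: stride != window_size")
    return problems


print(check_patch_constraints((9, 9), 4, 2, 4))    # rules 1 and 2 fail
print(check_patch_constraints((10, 10), 4, 2, 4))  # rule 2 fails
print(check_patch_constraints((10, 10), 4, 1, 4))  # []
```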

We will now examine cases 1 and 2 and how to address them.

In [4]:
def extract_and_combine(image, window_size, padding):
    h, w = image.shape[-2:]
    tiler = ExtractTensorPatches(window_size=window_size, stride=window_size, padding=padding)
    merger = CombineTensorPatches(original_size=(h, w), window_size=window_size, unpadding=padding)
    image_tiles = tiler(image)
    print(f"Shape of tensor patches = {image_tiles.shape}")
    merged_image = merger(image_tiles)
    print(f"Shape of merged image = {merged_image.shape}")
    assert (image == merged_image).all()
    return merged_image

In [5]:
image = torch.randn(2, 3, 9, 9)
_ = extract_and_combine(image, window_size=(4, 4), padding=2)

Shape of tensor patches = torch.Size([2, 9, 3, 4, 4])


NotImplementedError: Original image size must be divisible by 2. Got (9, 9)

To solve this, we can pad the image before extracting tensor patches.
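How much to pre-pad is a matter of modular arithmetic. In Python, the smallest non-negative amount that makes a size divisible by `m` is `(-size) % m`. With `m = 2` and a size of 9, this gives 1, which is the single row and column added in the cell below. `pad_to_multiple` is a hypothetical helper name.

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Smallest non-negative amount to add so that size becomes a multiple of `multiple`."""
    return (-size) % multiple  # Python's % is non-negative for a positive divisor


h, w = 9, 9
extra_h, extra_w = pad_to_multiple(h, 2), pad_to_multiple(w, 2)
print(extra_h, extra_w)  # 1 1
```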

In [6]:
import torch.nn.functional as F

image = torch.randn(2, 3, 9, 9)
print(image.shape)

# Pad the left and top edges by 1 (F.pad order for the last two dims: left, right, top, bottom)
padded_image = F.pad(image, (1,0,1,0))
print(padded_image.shape)

h, w = padded_image.shape[-2:]

torch.Size([2, 3, 9, 9])
torch.Size([2, 3, 10, 10])
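`F.pad` reads its padding tuple starting from the last dimension. For the last two dimensions, `(1, 0, 1, 0)` therefore means (left, right, top, bottom), and with the default constant mode the cell above adds one zero column on the left and one zero row on the top. The sketch below reproduces that ordering on a plain 2-D list. It handles non-negative padding only, and `pad_lrtb` is a hypothetical name.

```python
from typing import List, Tuple

Grid = List[List[float]]


def pad_lrtb(img: Grid, pad: Tuple[int, int, int, int]) -> Grid:
    """Zero-pad a 2-D grid with (left, right, top, bottom), mimicking F.pad's order.

    Only non-negative padding is handled in this sketch.
    """
    left, right, top, bottom = pad
    width = len(img[0]) + left + right
    body = [[0.0] * left + list(row) + [0.0] * right for row in img]
    return ([[0.0] * width for _ in range(top)] + body
            + [[0.0] * width for _ in range(bottom)])


image = [[1.0] * 9 for _ in range(9)]
padded = pad_lrtb(image, (1, 0, 1, 0))
print(len(padded), len(padded[0]))   # 10 10
print(padded[0][:3], padded[1][:3])  # [0.0, 0.0, 0.0] [0.0, 1.0, 1.0]
```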


Now that the image dimensions are divisible by 2, let's try extracting and combining tensor patches again.

In [7]:
_ = extract_and_combine(padded_image, window_size=(4,4), padding=2)

Shape of tensor patches = torch.Size([2, 9, 3, 4, 4])


NotImplementedError: Insufficient padding

Notice that we now run into the second case, i.e. the padded image dimensions must be divisible by `window_size`. From the previous cell:

- original_size = (10, 10) # after we padded by 1
- window_size = (4, 4)
- padding = 2

We can verify that (10 + 2 + 2) % 4 != 0. A simple solution is to reduce the padding by 1, which results in:

- original_size = (10, 10) # after we padded by 1
- window_size = (4, 4)
- padding = 1

Now that (10 + 1 + 1) % 4 == 0, we should be good to go.
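Instead of guessing, you can enumerate the valid symmetric paddings. Padding is added on both sides, so it always increases a dimension by an even amount. For an odd size such as 9 and an even window such as 4, no symmetric padding can work. `valid_paddings` is a hypothetical helper.

```python
from typing import List


def valid_paddings(size: int, window: int, max_pad: int) -> List[int]:
    """Symmetric paddings p in [0, max_pad] for which (size + 2 * p) % window == 0."""
    return [p for p in range(max_pad + 1) if (size + 2 * p) % window == 0]


print(valid_paddings(10, 4, max_pad=5))  # [1, 3, 5]
print(valid_paddings(9, 4, max_pad=5))   # []: 9 + 2p is always odd
```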

In [8]:
prepad_restored_image = extract_and_combine(padded_image, window_size=(4,4), padding=1)

Shape of tensor patches = torch.Size([2, 9, 3, 4, 4])
Shape of merged image = torch.Size([2, 3, 10, 10])


Finally, to recover the original image, we simply need to remove the padding we added earlier:

In [9]:
# Negative padding crops: remove the column and row added on the left and top
restored_image = F.pad(prepad_restored_image, (-1, 0, -1, 0))
print(restored_image.shape)

torch.Size([2, 3, 9, 9])
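Negative entries in `F.pad` remove elements instead of adding them, so `(-1, 0, -1, 0)` drops the first column and the first row. Here this is equivalent to the slice `prepad_restored_image[..., 1:, 1:]`. The dependency-free sketch below shows the same cropping on a small 2-D list, using the (left, right, top, bottom) order. `crop_lrtb` is a hypothetical name.

```python
from typing import List, Tuple

Grid = List[List[float]]


def crop_lrtb(img: Grid, crop: Tuple[int, int, int, int]) -> Grid:
    """Remove (left, right, top, bottom) cells from a 2-D grid, like F.pad with negated values."""
    left, right, top, bottom = crop
    h, w = len(img), len(img[0])
    return [row[left:w - right] for row in img[top:h - bottom]]


padded = [[0.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 2.0, 3.0],
          [0.0, 4.0, 5.0, 6.0],
          [0.0, 7.0, 8.0, 9.0]]
print(crop_lrtb(padded, (1, 0, 1, 0)))  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
```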


In [10]:
# Confirm that the original image and restored image are the same
assert (restored_image == image).all()

## Rectangular images

These functions also work with rectangular images, provided we account for the cases mentioned above.

In [11]:
import torch

rect_image = torch.randn(1, 1, 8, 6)
print(rect_image.shape)

torch.Size([1, 1, 8, 6])


Notice that the original image dimensions (8, 6) are both even, so we only need to ensure that the padded image is divisible by the window size. The image height (8) is already divisible by the window height (4), but the image width (6) is not. To fix this, we only pad the width: `padding=(0, 1)` adds no rows and one column on each side, giving a padded size of (8, 8).
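The per-dimension reasoning above can be automated. `min_symmetric_padding` and `padding_for` are hypothetical helpers, not part of Kornia. They return the smallest symmetric padding that makes each dimension divisible by its window. Only `p < window` needs checking, because the remainder repeats with that period.

```python
from typing import Optional, Tuple


def min_symmetric_padding(size: int, window: int) -> Optional[int]:
    """Smallest p >= 0 with (size + 2 * p) % window == 0, or None if no p works."""
    # (size + 2p) % window repeats with period `window` in p, so p < window suffices.
    for p in range(window):
        if (size + 2 * p) % window == 0:
            return p
    return None


def padding_for(size_hw: Tuple[int, int], window_hw: Tuple[int, int]) -> Tuple[Optional[int], ...]:
    """Per-dimension (height, width) padding for a rectangular image."""
    return tuple(min_symmetric_padding(s, k) for s, k in zip(size_hw, window_hw))


print(padding_for((8, 6), (4, 4)))  # (0, 1)
```

With that padding the padded image is (8, 8), which gives (8 / 4) * (8 / 4) = 4 patches. This agrees with the `torch.Size([1, 4, 1, 4, 4])` output of the next cell.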

In [12]:
restored_image = extract_and_combine(rect_image, window_size=(4,4), padding=(0, 1))

Shape of tensor patches = torch.Size([1, 4, 1, 4, 4])
Shape of merged image = torch.Size([1, 1, 8, 6])


In [13]:
# Confirm that the original image and restored image are the same
assert (restored_image == rect_image).all()