In [123]:
import numpy as np
from PIL import Image
import torch
from torch.nn import Unfold
from skimage.transform import rotate

def im2col(img, block_size, stride=1):
    """Rearrange every sliding block of ``img`` into one row of a 2-D matrix.

    Parameters
    ----------
    img : array_like
        Input of shape ``(H, W)``, ``(C, H, W)`` or ``(N, C, H, W)``.
        Lower-rank inputs are promoted by prepending singleton axes.
    block_size : int or (int, int)
        Block height and width ``(filter_h, filter_w)``; a single int is
        used for both.
    stride : int, optional
        Step between neighbouring blocks along both spatial axes (default 1).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(N * out_h * out_w, C * filter_h * filter_w)``.
        Rows are ordered by sample, then block row, then block column; each
        row holds the block's values channel by channel.  If the block does
        not fit inside the image the result has zero rows.

    Raises
    ------
    ValueError
        If ``img`` is not 2-, 3- or 4-dimensional, or ``stride`` is < 1.
    """
    img = np.asarray(img)
    if not 2 <= img.ndim <= 4:
        raise ValueError(
            f"img must have 2, 3 or 4 dimensions, got shape {img.shape}"
        )
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    # Promote (H, W) and (C, H, W) inputs to the canonical (N, C, H, W) layout.
    while img.ndim < 4:
        img = np.expand_dims(img, 0)

    if isinstance(block_size, (int, np.integer)):
        block_size = (block_size, block_size)
    filter_h, filter_w = block_size

    N, C, H, W = img.shape
    NN, CC, HH, WW = img.strides
    # Clamp at zero so an oversized block yields an empty result instead of
    # a "negative dimensions" error from as_strided.
    out_h = max((H - filter_h) // stride + 1, 0)
    out_w = max((W - filter_w) // stride + 1, 0)

    # Build a zero-copy 6-D view (N, out_h, out_w, C, filter_h, filter_w);
    # astype() then copies it, so the returned array never aliases img.
    col = np.lib.stride_tricks.as_strided(
        img,
        shape=(N, out_h, out_w, C, filter_h, filter_w),
        strides=(NN, stride * HH, stride * WW, CC, HH, WW),
    ).astype(float)
    return col.reshape(N * out_h * out_w, C * filter_h * filter_w)

def w_idx2a_idx(Ashape, tshape, i):
    """Convert a flat sliding-window index into 2-D window coordinates.

    Windows are assumed to be numbered in row-major order with stride 1,
    i.e. the position along the last axis varies fastest.

    Parameters
    ----------
    Ashape : (int, int)
        Spatial shape of the searched array; only the second entry is used.
    tshape : (int, int)
        Spatial shape of the template; only the second entry is used.
    i : int
        Flat window index.

    Returns
    -------
    numpy.ndarray
        ``[position along first axis, position along last axis]``.
    """
    _, a_last = Ashape
    _, t_last = tshape

    # Number of valid window positions along the last axis.
    positions_per_line = a_last - t_last + 1

    major = i // positions_per_line
    minor = i % positions_per_line
    return np.array([major, minor])

In [200]:
# Prepare the overlay: a rockhopper penguin that will be hidden in the base photo.

####### Overlay
base_scale = 10        # both images are shrunk by this factor for the search
ol_scale = base_scale

ol = Image.open('rockhopper.png').convert('RGBA')
ol = ol.resize([40, 80])  # Scale the penguin down to hiding size
ol_resized = ol.resize(np.array(ol.size) // ol_scale)

# Debugging aid (disabled): overwrite the RGB channels of `ol` with a single
# solid colour, e.g. (255, 0, 0), to make the pasted overlay easy to spot.

# Reverse the axes of the low-res overlay: (height, width, RGBA) -> (RGBA, width, height).
t = np.array(ol_resized).transpose()
alpha = t[-1, :, :]

# True wherever the overlay is fully transparent.
shape_mask = alpha == 0
# Same mask flattened and repeated once per colour channel, so it lines up
# with the flattened RGB template.
mask = np.tile(alpha.flatten(), 3) == 0

# Keep only the RGB channels as a float tensor.
t = torch.tensor(t[0:3, :, :], dtype=torch.float32)
flat_t = t.flatten()
t_size = tuple(t.shape[1:])



###### Base Image

base = Image.open('falls.jpeg').convert('RGBA')
base_resized = base.resize(np.array(base.size) // base_scale)

# Debugging aid (disabled): stamp the overlay silhouette in solid red at a
# fixed offset of the base pixels -- presumably to check that the search
# finds a known location.

# Make placement matrix: reverse the axes of the low-res base so colour
# channels come first, keep the RGB channels and add a leading batch axis.
base_chw = torch.tensor(np.array(base_resized).transpose(), dtype=torch.float32)
A = base_chw[0:3, :, :].unsqueeze(0)

In [201]:
# Show the template's spatial size (last two dims of the overlay tensor `t`).
t_size

(4, 8)

In [202]:
# Report the PIL (width, height) of the overlay and of the base image.
print(ol.size, base.size, sep='\n')

(40, 80)
(2560, 1707)


In [203]:
# Slide a template-sized kernel over the base image and collect every window.
unfold = Unfold(kernel_size=t_size)
patches = unfold(A)                            # (1, values per window, number of windows)
windows = patches.squeeze(0).transpose(0, 1)   # one flattened window per row

n_windows, window_len = windows.shape
print(f'There are {n_windows} windows')
print(f'There are {n_windows * window_len} window elements')

There are 41239 windows
There are 3958944 window elements


In [204]:
# Score each window by its L2 distance to the flattened template.
template_vec = t.flatten()
difference = windows - template_vec

# Entries selected by `mask` must not contribute to the distance.
difference[:, mask] = 0

norms = np.linalg.norm(difference, axis=1)
optimal_window_idx = np.argmin(norms)   # window with the smallest distance

In [205]:
# Translate the flat window index into 2-D coordinates in the low-res base.
base_spatial = (A.shape[2], A.shape[3])
template_spatial = (t.shape[1], t.shape[2])
optimal_index = w_idx2a_idx(base_spatial, template_spatial, optimal_window_idx)
print(f'Optimal index is {optimal_index}')

Optimal index is [129  28]


In [206]:
# Paste the full-size overlay into the full-size base. The low-res match
# position is scaled back up by `base_scale`; the overlay's own alpha channel
# is the paste mask. Note: this modifies `base` in place.
full_res_corner = tuple(optimal_index * base_scale)
base.paste(ol, full_res_corner, mask=ol)
base.show()

In [207]:
# Same paste on the low-res images, using the match position directly.
# Note: this modifies `base_resized` in place.
low_res_corner = tuple(optimal_index)
base_resized.paste(ol_resized, low_res_corner, mask=ol_resized)
base_resized.show()