In [None]:
import numpy as np
import imageio
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage import exposure
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

def convert(img, target_type_min=0, target_type_max=255, target_type=np.uint8):
    """Linearly rescale *img* to ``[target_type_min, target_type_max]`` and cast.

    The minimum of *img* maps to ``target_type_min`` and the maximum to
    ``target_type_max``; values in between are scaled linearly and then
    truncated by the cast to ``target_type``.

    Args:
        img: array (or array-like) of numeric values.
        target_type_min: value the input minimum is mapped to.
        target_type_max: value the input maximum is mapped to.
        target_type: numpy dtype of the returned array.

    Returns:
        Array of the same shape as *img* with dtype ``target_type``. A
        constant input (no dynamic range) yields an array filled with
        ``target_type_min``.
    """
    img = np.asarray(img)
    imin = img.min()
    imax = img.max()

    if imax == imin:
        # With no dynamic range the scale factor would divide by zero and
        # produce inf/NaN, whose cast to an integer dtype is undefined.
        return np.full(img.shape, target_type_min, dtype=target_type)

    a = (target_type_max - target_type_min) / (imax - imin)
    b = target_type_max - a * imax
    # Clip guards against floating-point overshoot just outside the target
    # range before the (truncating) dtype cast.
    new_img = np.clip(a * img + b, target_type_min, target_type_max)
    return new_img.astype(target_type)


def fft_decompose(img):
    """Split the centred 2-D spectrum of *img* into magnitude and phase.

    The FFT is taken over axes 0 and 1 only, so any trailing axes (e.g.
    colour channels) are transformed independently. The spectrum is then
    shifted so the zero-frequency bin sits at index ``n // 2`` along each of
    those two axes.

    Returns:
        ``(amplitude, phase)`` arrays; phase is in radians.
    """
    spatial_axes = (0, 1)
    spectrum = np.fft.fft2(img, axes=spatial_axes)
    centred = np.fft.fftshift(spectrum, axes=spatial_axes)
    amplitude = np.abs(centred)
    phase = np.angle(centred)
    return amplitude, phase

def _band_bounds(n, lam):
    """Return ``(start, stop)`` slice bounds of a band around index ``int(n / 2) + 1``.

    The band extends ``lam * n / 200`` bins on each side (about ``lam``
    percent of ``n`` in total), is clamped to ``[0, n]``, and both ends are
    truncated with ``int``.
    """
    center = int(n / 2) + 1
    half = lam * n / 200
    return int(max(center - half, 0)), int(min(center + half, n))


def mutate(src, trg, lambda_u, view_image=False, view_amp=False, lambda_l=0.0, gamma=1.0, save=None):
    """Blend the central amplitude spectrum of *src* toward that of *trg*.

    Both images are Fourier-transformed over their first two axes (see
    ``fft_decompose``). Inside a centred band whose size is set by
    ``lambda_u``, src's amplitude is replaced by the Blackman-window-weighted
    mix ``(1 - win) * src_amp + win * trg_amp``. src's phase is kept
    unchanged. The result is inverse-transformed, its real part is taken,
    and it is rescaled to uint8 with ``convert``.

    Args:
        src: image array, ``(H, W)`` or ``(H, W, C)`` with any channel count.
        trg: image supplying the amplitude. Band bounds are computed from
            src's shape only, so trg is expected to share src's spatial size.
        lambda_u: band size in percent; each side of the band extends
            ``lambda_u * dim / 200`` bins.
        view_image: if True, plot src and trg with matplotlib.
        view_amp: if True, plot ``log(1 + amplitude)`` of src and trg.
        lambda_l: if > 0, the inner band of this size has src's original
            amplitude restored after blending.
        gamma: gamma passed to ``exposure.adjust_gamma`` for the saved file
            only; the returned array is NOT gamma-corrected.
        save: optional output path, written with ``imageio.imwrite``.

    Returns:
        uint8 array with the same shape as *src*.
    """
    # Only the spatial axes are used, so grayscale (H, W) and multi-channel
    # (H, W, C) inputs share the same code path.
    h, w = src.shape[:2]
    # Even dimensions are made odd before computing band bounds. For even n,
    # int((n - 1) / 2) + 1 == n // 2, the fftshift zero-frequency index.
    # NOTE(review): for odd n the band center is n // 2 + 1, one bin past DC;
    # kept unchanged to preserve existing outputs.
    if h % 2 == 0:
        h -= 1
    if w % 2 == 0:
        w -= 1

    if view_image:
        plt.figure()
        plt.imshow(src)
        plt.title('src')

        plt.figure()
        plt.imshow(trg)
        plt.title('trg')

    src_amp, src_ang = fft_decompose(src)
    trg_amp, _ = fft_decompose(trg)  # trg's phase is never used

    if view_amp:
        plt.figure()
        plt.imshow(convert(np.log(1 + src_amp)))
        plt.title('log(1+src_amp)')

        plt.figure()
        plt.imshow(convert(np.log(1 + trg_amp)))
        plt.title('log(1+trg_amp)')

    if lambda_l > 0:
        # Snapshot src's innermost amplitudes so they survive the blend below.
        topo, bottomo = _band_bounds(h, lambda_l)
        lefto, righto = _band_bounds(w, lambda_l)
        lower_frequencies = src_amp[topo:bottomo, lefto:righto].copy()

    top, bottom = _band_bounds(h, lambda_u)
    left, right = _band_bounds(w, lambda_u)

    trg_band = trg_amp[top:bottom, left:right]
    win_h, win_w = trg_band.shape[:2]
    win = np.outer(np.blackman(win_h), np.blackman(win_w))
    # Append singleton axes so the 2-D window broadcasts over however many
    # channels the images have (previously hard-coded to exactly three).
    win = win.reshape(win.shape + (1,) * (src_amp.ndim - 2))

    src_band = src_amp[top:bottom, left:right]
    src_amp[top:bottom, left:right] = np.multiply(1 - win, src_band) + np.multiply(trg_band, win)

    if lambda_l > 0:
        src_amp[topo:bottomo, lefto:righto] = lower_frequencies

    # Recombine the modified amplitude with src's original phase.
    src_fft = src_amp * np.exp(1j * src_ang)
    spatial = np.fft.ifft2(np.fft.ifftshift(src_fft, axes=(0, 1)), axes=(0, 1))
    new_src = convert(np.real(spatial))

    if save is not None:
        imageio.imwrite(save, exposure.adjust_gamma(new_src, gamma))
    return new_src