# Introduction

In this notebook we perform the comparison experiments with a numpy implementation of HIO which can also be used to use the  Gerchberg-Saxton and Input-Output. In the follwing we compute all three algorithms for each dataset. 

In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

from dataset import *
from train_abs import *
from test_abs import *
from utils import *
from skimage.transform import rotate
from skimage.metrics import structural_similarity, mean_squared_error, peak_signal_noise_ratio
from skimage.util import crop
from skimage.registration import phase_cross_correlation
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

cuda
Loaded util functions


# Load data

In [2]:
x_train_mnist, x_test_mnist = load_MNIST(32)
x_train_emnist, x_test_emnist = load_EMNIST(32)
x_train_fmnist, x_test_fmnist = load_FMNIST(32)
x_train_cifar, x_test_cifar = load_CIFAR10(32, channel = -1)
x_train_svhn, x_test_svhn = load_SVHN(32,channel = -1)

Loaded MNIST dataset: x_train(100000, 32, 32), x_valid(140000, 32, 32)
Loaded EMNIST dataset: x_train(100000, 32, 32), x_valid(24800, 32, 32)
Loaded FashionMNIST gray dataset: x_train(60000, 32, 32), x_valid(10000, 32, 32)
Files already downloaded and verified
Files already downloaded and verified
using Gray image
Loaded CIFAR10 gray dataset: x_train(50000, 32, 32), x_valid(10000, 32, 32)
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
using Gray image
Loaded SVHN gray dataset: x_train(73257, 32, 32), x_valid(26032, 32, 32)


### ATTENTION: In requires a lot of computational time to rerun this notebook, if the whole datasets are used

In [3]:
n_test_mnist = 10000 # as the authors use also only 10000
n_test_emnist = len(x_test_emnist)
n_test_fmnist = len(x_test_fmnist)
n_test_cifar = len(x_test_cifar)
n_test_svhn = len(x_test_svhn)

x_test_mnist = x_test_mnist[:n_test_mnist]

# Import methods for experiments

In [4]:
# import augmentation method
np.random.seed(417)
def image_magnitudes_oversample(image, pad):
    image_padded = np.pad(image, pad, 'constant')
    magnitudes_oversampled = np.abs(np.fft.fft2(image_padded))
    mask = np.pad(np.ones_like(image), pad, 'constant')
    return magnitudes_oversampled, mask

# psnr without register
def psnr_crop(img, gt, pad):
    img = crop(img, pad)
    img[img>1]=1
    mini = min(np.min(img), np.min(gt))
    maxi = max(np.max(img), np.max(gt))
    dst = maxi - mini
    psnr = peak_signal_noise_ratio(gt, img, data_range= dst)
    return psnr

# psnr with register
def psnr_register_flip(img, gt, pad):
    img = crop(img, pad)
    img[img>1]=1
    img = rotate(img, 180)
    s,_,_ = phase_cross_correlation(gt, img)
    shift = (int(s[0]) , int(s[1]))
    img = np.roll(img, shift, axis=0)
    mini = min(np.min(img), np.min(gt))
    maxi = max(np.max(img), np.max(gt))
    dst = maxi - mini
    psnr = peak_signal_noise_ratio(gt, img, data_range= dst)
    return psnr

In [5]:
#import retrieval algorithm
def fienup_phase_retrieval(mag, mask=None, beta=0.8, 
                           steps=200, mode='hybrid', verbose=True):
    """
    Implementation of Fienup's phase-retrieval methods. This function
    implements the input-output, the output-output and the hybrid method.
    
    Note: Mode 'output-output' and beta=1 results in 
    the Gerchberg-Saxton algorithm.
    
    Parameters:
        mag: Measured magnitudes of Fourier transform
        mask: Binary array indicating where the image should be
              if padding is known
        beta: Positive step size
        steps: Number of iterations
        mode: Which algorithm to use
              (can be 'input-output', 'output-output' or 'hybrid')
        verbose: If True, progress is shown
    
    Returns:
        x: Reconstructed image
    
    Author: Tobias Uelwer
    Date: 30.12.2018
    
    References:
    [1] E. Osherovich, Numerical methods for phase retrieval, 2012,
        https://arxiv.org/abs/1203.4756
    [2] J. R. Fienup, Phase retrieval algorithms: a comparison, 1982,
        https://www.osapublishing.org/ao/abstract.cfm?uri=ao-21-15-2758
    [3] https://github.com/cwg45/Image-Reconstruction
    """
    
    assert beta > 0, 'step size must be a positive number'
    assert steps > 0, 'steps must be a positive number'
    assert mode == 'input-output' or mode == 'output-output'\
        or mode == 'hybrid',\
    'mode must be \'input-output\', \'output-output\' or \'hybrid\''
    
    if mask is None:
        mask = np.ones(mag.shape)
        
    assert mag.shape == mask.shape, 'mask and mag must have same shape'
    
    # sample random phase and initialize image x 
    y_hat = mag*np.exp(1j*2*np.pi*np.random.rand(*mag.shape))
    x = np.zeros(mag.shape)
    
    # previous iterate
    x_p = None
        
    # main loop
    for i in range(1, steps+1):
        # show progress
        if i % 100 == 0 and verbose: 
            print("step", i, "of", steps)
        
        # inverse fourier transform
        y = np.real(np.fft.ifft2(y_hat))
        
        # previous iterate
        if x_p is None:
            x_p = y
        else:
            x_p = x 
        
        # updates for elements that satisfy object domain constraints
        if mode == "output-output" or mode == "hybrid":
            x = y
            
        # find elements that violate object domain constraints 
        # or are not masked
        indices = np.logical_or(np.logical_and(y<0, mask), 
                                np.logical_not(mask))
        
        # updates for elements that violate object domain constraints
        if mode == "hybrid" or mode == "input-output":
            x[indices] = x_p[indices]-beta*y[indices] 
        elif mode == "output-output":
            x[indices] = y[indices]-beta*y[indices] 
        
        # fourier transform
        x_hat = np.fft.fft2(x)
        
        # satisfy fourier domain constraints
        # (replace magnitude with input magnitude)
        y_hat = mag*np.exp(1j*np.angle(x_hat))
    return x

# MNIST

## HIO

In [6]:
N, height, width = x_test_mnist.shape
pad = height//2
x_test_mnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])
print(x_test_mnist_prep.shape)

(10000, 2, 64, 64)


In [7]:
np.random.seed(417)
n_test = n_test_mnist
psnr_list_max = []

for i in tqdm(range(n_test)):
    mag, m = x_test_mnist_prep[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)
    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max.append(psnr)
    
psnr_list_max = np.array(psnr_list_max)
print(np.mean(psnr_list_max), np.std(psnr_list_max))

100%|██████████| 10000/10000 [09:09<00:00, 18.21it/s]

10.533217767885693 3.8119663142058





## Gerchberg-Saxton

In [8]:
N, height, width = x_test_mnist.shape
pad = height//2
x_test_mnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])
print(x_test_mnist_prep_u.shape)

(10000, 2, 64, 64)


In [9]:
np.random.seed(417)
n_test = n_test_mnist
psnr_list_max_u = []

for i in tqdm(range(n_test)):
    mag, m = x_test_mnist_prep_u[i]
    rec = fienup_phase_retrieval(mag, steps=100, beta=1, mask=m, mode='output-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_u.append(psnr)
    
psnr_list_max_u = np.array(psnr_list_max_u)
print(np.mean(psnr_list_max_u), np.std(psnr_list_max_u))

100%|██████████| 10000/10000 [08:58<00:00, 18.57it/s]

9.817387072364363 2.4410112407637867





## Input-Output

In [10]:
N, height, width = x_test_mnist.shape
pad = height//2
x_test_mnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])
print(x_test_mnist_prep_io.shape)

(10000, 2, 64, 64)


In [11]:
np.random.seed(417)
n_test = n_test_mnist
psnr_list_max_io = []

for i in tqdm(range(n_test)):
    mag, m = x_test_mnist_prep_io[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_io.append(psnr)
    
psnr_list_max_io = np.array(psnr_list_max_io)
print(np.mean(psnr_list_max_io), np.std(psnr_list_max_io))

100%|██████████| 10000/10000 [08:45<00:00, 19.01it/s]

9.799093051755992 1.3515794111192863





# EMNSIT

## HIO

In [12]:
N, height, width = x_test_emnist.shape
pad = height//2
x_test_emnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])
print(x_test_emnist_prep.shape)

(24800, 2, 64, 64)


In [13]:
np.random.seed(417)
n_test = n_test_emnist
psnr_list_max_emnist = []

for i in tqdm(range(n_test)):
    mag, m = x_test_emnist_prep[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)
    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_emnist.append(psnr)
    
psnr_list_max_emnist = np.array(psnr_list_max_emnist)
print(np.mean(psnr_list_max_emnist), np.std(psnr_list_max_emnist))

100%|██████████| 24800/24800 [22:44<00:00, 18.18it/s]

10.809815104812078 3.934378529537481





## Gerchberg-Saxton

In [14]:
N, height, width = x_test_emnist.shape
pad = height//2
x_test_emnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])
print(x_test_emnist_prep_u.shape)

(24800, 2, 64, 64)


In [15]:
np.random.seed(417)
n_test = n_test_emnist
psnr_list_max_emnist_u = []

for i in tqdm(range(n_test)):
    mag, m = x_test_emnist_prep_u[i]
    rec = fienup_phase_retrieval(mag, steps=100,beta=1, mask=m, mode='output-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_emnist_u.append(psnr)
    
psnr_list_max_emnist_u = np.array(psnr_list_max_emnist_u)
print(np.mean(psnr_list_max_emnist_u), np.std(psnr_list_max_emnist_u))

100%|██████████| 24800/24800 [22:40<00:00, 18.23it/s]

9.985711223304026 2.4161293402906083





## Input-Output

In [16]:
N, height, width = x_test_emnist.shape
pad = height//2
x_test_emnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])
print(x_test_emnist_prep_io.shape)

(24800, 2, 64, 64)


In [17]:
np.random.seed(417)
n_test = n_test_emnist
psnr_list_max_emnist_io = []

for i in tqdm(range(n_test)):
    mag, m = x_test_emnist_prep_io[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_emnist_io.append(psnr)
    
psnr_list_max_emnist_io = np.array(psnr_list_max_emnist_io)
print(np.mean(psnr_list_max_emnist_io), np.std(psnr_list_max_emnist_io))

100%|██████████| 24800/24800 [22:03<00:00, 18.73it/s]

9.852396506159906 1.4600540026659212





# FMNIST

## HIO

In [18]:
N, height, width = x_test_fmnist.shape
pad = height//2
x_test_fmnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])
print(x_test_fmnist_prep.shape)

(10000, 2, 64, 64)


In [19]:
np.random.seed(417)
n_test = n_test_fmnist
psnr_list_max_fmnist = []

for i in tqdm(range(n_test)):
    mag, m = x_test_fmnist_prep[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)
    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_fmnist.append(psnr)
    
psnr_list_max_fmnist = np.array(psnr_list_max_fmnist)
print(np.mean(psnr_list_max_fmnist), np.std(psnr_list_max_fmnist))

100%|██████████| 10000/10000 [09:14<00:00, 18.02it/s]

14.061361823805145 8.536005081683598





## Gerchberg-Saxton

In [20]:
N, height, width = x_test_fmnist.shape
pad = height//2
x_test_fmnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])
print(x_test_fmnist_prep_u.shape)

(10000, 2, 64, 64)


In [21]:
np.random.seed(417)
n_test = n_test_fmnist
psnr_list_max_fmnist_u = []

for i in tqdm(range(n_test)):
    mag, m = x_test_fmnist_prep_u[i]
    rec = fienup_phase_retrieval(mag, steps=100, beta=1, mask=m, mode='output-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_fmnist_u.append(psnr)
    
psnr_list_max_fmnist = np.array(psnr_list_max_fmnist_u)
print(np.mean(psnr_list_max_fmnist_u), np.std(psnr_list_max_fmnist_u))

100%|██████████| 10000/10000 [09:06<00:00, 18.30it/s]

11.245463183268203 3.632738956327884





## Input-Output

In [22]:
N, height, width = x_test_fmnist.shape
pad = height//2
x_test_fmnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])
print(x_test_fmnist_prep_io.shape)

(10000, 2, 64, 64)


In [23]:
np.random.seed(417)
n_test = n_test_fmnist
psnr_list_max_fmnist_io = []

for i in tqdm(range(n_test)):
    mag, m = x_test_fmnist_prep_io[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_fmnist_io.append(psnr)
    
psnr_list_max_fmnist_io = np.array(psnr_list_max_fmnist_io)
print(np.mean(psnr_list_max_fmnist_io), np.std(psnr_list_max_fmnist_io))

100%|██████████| 10000/10000 [08:56<00:00, 18.63it/s]

8.737355971097722 2.625021556657955





# SVHN

## HIO

In [24]:
N, height, width = x_test_svhn.shape
pad = height//2
x_test_svhn_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])
print(x_test_svhn_prep.shape)

(26032, 2, 64, 64)


In [36]:
np.random.seed(417)
n_test = n_test_svhn
psnr_list_max_svhn = []

for i in tqdm(range(n_test)):
    mag, m = x_test_svhn_prep[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)
    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_svhn.append(psnr)
    
psnr_list_max_svhn = np.array(psnr_list_max_svhn)
print(np.mean(psnr_list_max_svhn), np.std(psnr_list_max_svhn))

100%|██████████| 26032/26032 [23:41<00:00, 18.32it/s]

31.89992718180047 16.449739338830508





## Gerchberg-Saxton

In [37]:
N, height, width = x_test_svhn.shape
pad = height//2
x_test_svhn_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])
print(x_test_svhn_prep_u.shape)

(26032, 2, 64, 64)


In [38]:
np.random.seed(417)
n_test = n_test_svhn
psnr_list_max_svhn_u = []

for i in tqdm(range(n_test)):
    mag, m = x_test_svhn_prep_u[i]
    rec = fienup_phase_retrieval(mag, steps=100,beta=1, mask=m,mode='output-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_svhn_u.append(psnr)
    
psnr_list_max_svhn_u = np.array(psnr_list_max_svhn_u)
print(np.mean(psnr_list_max_svhn_u), np.std(psnr_list_max_svhn_u))

100%|██████████| 26032/26032 [23:14<00:00, 18.66it/s]

17.888926745678102 3.77238934951962





## Input-Output

In [39]:
N, height, width = x_test_svhn.shape
pad = height//2
x_test_svhn_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])
print(x_test_svhn_prep_io.shape)

(26032, 2, 64, 64)


In [40]:
np.random.seed(417)
n_test = n_test_svhn
psnr_list_max_svhn_io = []

for i in tqdm(range(n_test)):
    mag, m = x_test_svhn_prep_io[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m,mode='input-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_svhn_io.append(psnr)
    
psnr_list_max_svhn_io = np.array(psnr_list_max_svhn_io)
print(np.mean(psnr_list_max_svhn_io), np.std(psnr_list_max_svhn_io))

100%|██████████| 26032/26032 [22:39<00:00, 19.15it/s]

6.680730588725579 1.8482517180509195





# CIFAR

## HIO

In [41]:
N, height, width = x_test_cifar.shape
pad = height//2
x_test_cifar_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_cifar])
print(x_test_cifar_prep.shape)

(10000, 2, 64, 64)


In [42]:
np.random.seed(417)
n_test = n_test_cifar
psnr_list_max_cifar = []

for i in tqdm(range(n_test)):
    mag, m = x_test_cifar_prep[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)
    psnr1 = psnr_crop(rec, x_test_cifar[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_cifar[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_cifar.append(psnr)
    
psnr_list_max_cifar = np.array(psnr_list_max_cifar)
print(np.mean(psnr_list_max_cifar), np.std(psnr_list_max_cifar))

100%|██████████| 10000/10000 [08:53<00:00, 18.75it/s]

28.337736982769762 13.922134641919158





## Gerchberg-Saxton

In [43]:
N, height, width = x_test_cifar.shape
pad = height//2
x_test_cifar_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_cifar])
print(x_test_cifar_prep_u.shape)

(10000, 2, 64, 64)


In [44]:
np.random.seed(417)
n_test = n_test_cifar
psnr_list_max_cifar_u = []

for i in tqdm(range(n_test)):
    mag, m = x_test_cifar_prep_u[i]
    rec = fienup_phase_retrieval(mag, steps=100, beta=1, mask=m, mode='output-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_cifar[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_cifar[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_cifar_u.append(psnr)
    
psnr_list_max_cifar_u = np.array(psnr_list_max_cifar_u)
print(np.mean(psnr_list_max_cifar_u), np.std(psnr_list_max_cifar_u))

100%|██████████| 10000/10000 [08:48<00:00, 18.91it/s]

16.384498253497863 3.081373773670048





## Input-Output

In [45]:
N, height, width = x_test_cifar.shape
pad = height//2
x_test_cifar_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_cifar])
print(x_test_cifar_prep_io.shape)

(10000, 2, 64, 64)


In [46]:
np.random.seed(417)
n_test = n_test_cifar
psnr_list_max_cifar_io = []

for i in tqdm(range(n_test)):
    mag, m = x_test_cifar_prep_io[i]
    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)
    psnr1 = psnr_crop(rec, x_test_cifar[i], pad)
    psnr2 = psnr_register_flip(rec, x_test_cifar[i], pad)
    psnr = max(psnr1, psnr2)
    psnr_list_max_cifar_io.append(psnr)
    
psnr_list_max_cifar_io = np.array(psnr_list_max_cifar_io)
print(np.mean(psnr_list_max_cifar_io), np.std(psnr_list_max_cifar_io))

100%|██████████| 10000/10000 [08:44<00:00, 19.06it/s]

7.794638984232505 1.7285389535306375



