In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, ifft2, fftshift

def create_ground_truth(size=256):
    """Build a synthetic test image and a rectangular support mask.

    The image is composed, in order, of a centred disc of value 0.8 with
    radius ``0.2 * size``, a centred 60x100 pixel rectangle of value 0.5
    (overwriting the disc where they overlap), and an added Gaussian bump
    with peak 0.3 at the centre. The result is clipped to ``[0, 1]``.

    The rectangle and support dimensions are fixed pixel counts and do not
    scale with ``size``; they are clipped to the image bounds when the image
    is smaller than the shape.

    Args:
        size: Side length in pixels of the square image.

    Returns:
        A tuple ``(image, support)`` of two ``(size, size)`` float arrays.
        ``support`` is 1 inside a centred 160x200 pixel rectangle and 0
        elsewhere.
    """
    image = np.zeros((size, size))
    center = size // 2

    # Pixel coordinates relative to the image centre.
    y, x = np.ogrid[-center:size - center, -center:size - center]
    disc = x**2 + y**2 <= (size * 0.2)**2
    image[disc] = 0.8

    # Lower slice bounds are clamped to 0: a negative start would be
    # interpreted as "count from the end" and silently select the wrong
    # region for small images. Upper bounds past the edge are clipped by
    # slicing itself.
    image[max(center - 30, 0):center + 30,
          max(center - 50, 0):center + 50] = 0.5

    gaussian = np.exp(-(x**2 + y**2) / (0.2 * size))
    image += 0.3 * gaussian

    image = np.clip(image, 0, 1)

    support = np.zeros_like(image)
    support[max(center - 80, 0):center + 80,
            max(center - 100, 0):center + 100] = 1

    return image, support

def create_measurement(image, sampling_rate=0.65):
    """Simulate noisy Fourier-magnitude measurements and a sampling mask.

    Args:
        image: 2-D real array to measure.
        sampling_rate: Controls the radius of the disc of sampled
            frequencies around DC:
            ``radius = int(sqrt(sampling_rate) * rows // 2)``.

    Returns:
        A tuple ``(M_measured, K)``. ``M_measured`` holds the non-negative
        noisy magnitudes of ``fft2(image)`` for every frequency (not only
        those selected by ``K``). ``K`` is a boolean mask of the same shape,
        in FFT layout (DC at index ``[0, 0]``), that is True for sampled
        frequencies.
    """
    M_true = np.abs(fft2(image))

    h, w = image.shape
    radius = int(np.sqrt(sampling_rate) * h // 2)

    # Signed integer frequency indices, already in FFT (unshifted) order.
    # Building them this way keeps the mask symmetric under frequency
    # negation for odd sizes too; for even sizes it matches a centred grid
    # followed by fftshift.
    freq_rows = np.fft.ifftshift(np.arange(h) - h // 2)
    freq_cols = np.fft.ifftshift(np.arange(w) - w // 2)
    K = freq_rows[:, None]**2 + freq_cols[None, :]**2 <= radius**2

    # Additive Gaussian noise scaled to 2% of the mean magnitude. Taking the
    # absolute value folds negative results back, so no further clamping is
    # needed to keep the magnitudes non-negative.
    noise = 0.02 * np.random.randn(*M_true.shape) * M_true.mean()
    M_measured = np.abs(M_true + noise)

    return M_measured, K

def reconstruct(M, K, support, max_iter=300, beta=0.85):
    """Iteratively recover an image from partial Fourier magnitudes.

    Each iteration replaces the magnitude of the current estimate's spectrum
    by the measured one on the sampled frequencies ``K`` (keeping the
    estimated phase), transforms back, and then enforces the image-domain
    constraints (non-negativity and zero outside ``support``). The first 70%
    of iterations use a feedback update controlled by ``beta``; the remaining
    iterations project directly onto the constraints.

    Progress is printed every 50 iterations.

    Args:
        M: 2-D array of measured Fourier magnitudes in FFT layout. Only the
            entries selected by ``K`` are used.
        K: Boolean mask, same shape as ``M``, marking measured frequencies.
        support: Array, same shape as ``M``, that is 0 where the image is
            known to vanish and non-zero (1) elsewhere.
        max_iter: Number of iterations to run.
        beta: Feedback parameter of the first-phase update.

    Returns:
        A tuple ``(image_est, errors)``: the final image estimate and a list
        with one relative magnitude error (on ``K``) per iteration.
    """
    image_est = np.random.rand(*M.shape) * support
    errors = []

    # Loop invariants.
    M_known = M[K]
    M_known_mean = np.mean(M_known)
    outside_support = support == 0
    feedback_iters = max_iter * 0.7

    for i in range(max_iter):
        F_est = fft2(image_est)
        # Fourier-domain constraint: impose the measured magnitude only on
        # the sampled frequencies. Unsampled frequencies keep the current
        # estimate, since no measurement constrains them.
        F_est[K] = M_known * np.exp(1j * np.angle(F_est[K]))

        image_new = np.real(ifft2(F_est))

        if i < feedback_iters:
            # Pixels violating the image-domain constraints get a feedback
            # update; all others accept the new value.
            # NOTE(review): the textbook HIO update is
            # `image_est - beta * image_new`; this variant is kept as
            # written - confirm it is intentional.
            violated = (image_new < 0) | outside_support
            image_est = np.where(
                violated,
                beta * image_est - (1 - beta) * image_new,
                image_new,
            )
        else:
            # Plain projection onto non-negativity and the support.
            image_est = np.clip(image_new, 0, None) * support

        known_mags = np.abs(fft2(image_est))[K]
        error = np.mean(np.abs(known_mags - M_known)) / M_known_mean
        errors.append(error)

        if i % 50 == 0:
            print(f'Iter {i}, Error: {error:.6f}')

    return image_est, errors

# Simulate a measurement of the synthetic image and reconstruct from it.
image_gt, support = create_ground_truth(256)
M_measured, K = create_measurement(image_gt)
image_rec, errors = reconstruct(M_measured, K, support)

plt.figure(figsize=(15, 10))

# One entry per image panel of the 2x3 grid: (position, data, colormap).
# Top row: ground truth, sampling mask (shifted so DC is centred),
# reconstruction. Bottom row: log-magnitude spectra of ground truth and
# reconstruction.
image_panels = (
    (1, image_gt, 'gray'),
    (2, fftshift(K), 'gray'),
    (3, image_rec, 'gray'),
    (4, np.log(1 + fftshift(np.abs(fft2(image_gt)))), 'viridis'),
    (5, np.log(1 + fftshift(np.abs(fft2(image_rec)))), 'viridis'),
)
for position, panel, colormap in image_panels:
    plt.subplot(2, 3, position)
    plt.imshow(panel, cmap=colormap)

# Last panel: per-iteration error curve.
plt.subplot(2, 3, 6)
plt.plot(errors)

plt.tight_layout()
plt.show()