In [6]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift, ifftshift

# Parameters
N = 64  # Smaller size of the matrix for simplicity
num_iterations = 10  # Maximum number of iterations for demonstration
save_every_n = 1  # Save only every n-th image
threshold = 1e-1  # Convergence threshold

# Create a Gaussian amplitude profile for A0(x, y)
x = np.linspace(-1, 1, N)
y = np.linspace(-1, 1, N)
X, Y = np.meshgrid(x, y)
sigma = 0.3  # Width of the Gaussian
A0 = np.exp(-(X**2 + Y**2) / (2 * sigma**2))  # Gaussian amplitude profile

# Save the initial amplitude A0
plt.figure()
plt.title("Initial Amplitude Profile A0")
plt.xlabel("X")
plt.ylabel("Y")
plt.imshow(A0, cmap='gray')
plt.colorbar(label="Amplitude")
plt.savefig("initial_amplitude_A0.png")
plt.close()

# Create a target intensity pattern in a triangular shape
target_intensity = np.zeros((N, N))
triangle_coords = [
    (N // 2, N // 4),          # Top spot
    (N // 4, 3 * N // 4),      # Bottom-left spot
    (3 * N // 4, 3 * N // 4)   # Bottom-right spot
]

# Set intensity to 1 at the specified triangle coordinates
for x, y in triangle_coords:
    target_intensity[x, y] = 1

# Save the target intensity pattern
plt.figure()
plt.title("Target Intensity Pattern - Triangular Array of Spots")
plt.xlabel("X")
plt.ylabel("Y")
plt.imshow(target_intensity, cmap='gray')
plt.colorbar(label="Intensity")
plt.savefig("target_intensity_pattern.png")
plt.close()

# Initialize the SLM phase with random values in the range [0, 0.2 * 2π]
phase_slm = np.random.uniform(0, 0.2 * 2 * np.pi, (N, N))
field_slm = A0 * np.exp(1j * phase_slm)

# Gerchberg-Saxton Algorithm Loop with Convergence Check
for i in range(num_iterations):
    # Forward Fourier transform to focal plane
    field_focal = fftshift(fft2(ifftshift(field_slm)))
    amplitude_focal = np.abs(field_focal)
    phase_focal = np.angle(field_focal)

    # Calculate the error between achieved amplitude and target amplitude
    error = np.mean((amplitude_focal - np.sqrt(target_intensity)) ** 2)
    print(f"Iteration {i+1}: Error = {error:.6e}")
    
    # Check if the error is below the convergence threshold
    current_threshold = threshold / (i + 1)
    if error < current_threshold:
        print(f"Converged after {i+1} iterations with error {error:.6e}")
        break

    # Save the achieved focal plane amplitude and target sqrt(intensity) comparison only every n-th iteration
    if i % save_every_n == 0:
        plt.figure(figsize=(12, 5))
        
        # Achieved Focal Plane Amplitude plot
        plt.subplot(1, 2, 1)
        plt.title(f"Achieved Focal Plane Amplitude - Iteration {i+1}")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.imshow(amplitude_focal, cmap='gray')
        plt.colorbar(label="Amplitude")
        
        # Target Amplitude plot
        plt.subplot(1, 2, 2)
        plt.title("Square Root of Target Intensity")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.imshow(np.sqrt(target_intensity), cmap='gray')
        plt.colorbar(label="Target Amplitude")
        
        # Add MSE as annotation on the figure
        plt.figtext(0.5, 0.01, f"MSE: {error:.6e}", ha="center", fontsize=12)

        # Save the figure for the current iteration
        plt.savefig(f"focal_plane_comparison_iteration_{i+1}.png")
        plt.close()

    # Impose target intensity in the focal plane by updating only the amplitude
    field_focal = np.sqrt(target_intensity) * np.exp(1j * phase_focal)

    # Inverse Fourier transform back to the SLM plane
    field_slm = ifftshift(ifft2(fftshift(field_focal)))
    
    # Update the phase in the SLM plane, keeping A0(x, y) as amplitude
    phase_slm = np.angle(field_slm)
    field_slm = A0 * np.exp(1j * phase_slm)
else:
    print("Reached maximum iterations without full convergence.")


Iteration 1: Error = 2.805466e+02
Iteration 2: Error = 2.800895e+02
Iteration 3: Error = 2.800794e+02
Iteration 4: Error = 2.800782e+02
Iteration 5: Error = 2.800780e+02
Iteration 6: Error = 2.800780e+02
Iteration 7: Error = 2.800780e+02
Iteration 8: Error = 2.800780e+02
Iteration 9: Error = 2.800780e+02
Iteration 10: Error = 2.800780e+02
Reached maximum iterations without full convergence.
