# Fourier Synthesis 

The purpose of this script is to showcase image reconstruction via Fourier synthesis. More information can be found at: 
https://thepythoncodingbook.com/2021/08/30/2d-fourier-transform-in-python-and-fourier-synthesis-of-images/

In [None]:
# LIBRARIES
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Importing image
# Path of the image to reconstruct; it is passed to plt.imread further down.
image_filename = "____" # Placeholder: replace with the path to your image file


# FUNCTIONS
# Fourier transform function
def calculate_2dft(input):
    """Return the centred 2D Fourier transform of a centred array.

    The input is treated as having its origin at the middle of the array:
    ``ifftshift`` moves that origin to index (0, 0) as ``fft2`` expects, and
    ``fftshift`` moves the zero-frequency term of the result back to the middle.

    Parameters
    ----------
    input : array_like
        2D array to transform.

    Returns
    -------
    numpy.ndarray
        Complex array with the zero-frequency term at the centre.
    """
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(input)))

# Inverse Fourier transform function
def calculate_2dift(input):
    """Return the real part of the centred inverse 2D Fourier transform.

    Parameters
    ----------
    input : array_like
        2D spectrum with the zero-frequency term at the centre of the array.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse transform, with its origin shifted back to
        the centre of the array.
    """
    fft = np.fft
    # Undo the centring of the spectrum so the zero-frequency term sits at
    # index (0, 0), invert the transform, then re-centre the spatial result.
    spatial = fft.fftshift(fft.ifft2(fft.ifftshift(input)))
    # Only the real part is returned; any imaginary component is discarded.
    return spatial.real

# Calculating distance from center
def calculate_distance_from_centre(coords, centre):
    """Return the Euclidean distance of a point from (centre, centre).

    Parameters
    ----------
    coords : sequence
        Point whose first two items are its two coordinates.
    centre : number
        Coordinate of the centre; the same value is used for both axes.

    Returns
    -------
    numpy.float64
        sqrt(d0**2 + d1**2), where d0 and d1 are the offsets from the centre.
    """
    offset_0 = coords[0] - centre
    offset_1 = coords[1] - centre
    squared_distance = offset_0 ** 2 + offset_1 ** 2
    return np.sqrt(squared_distance)

# Finding symmetric coordinates
def find_symmetric_coordinates(coords, centre):
    """Return the point mirrored through (centre, centre).

    Parameters
    ----------
    coords : sequence
        Point whose first two items are its two coordinates.
    centre : number
        Coordinate of the centre; the same value is used for both axes.

    Returns
    -------
    tuple
        Two-item tuple lying the same distance from the centre as ``coords``
        but on the opposite side of it.
    """
    # Offset from the point to the centre, applied again past the centre.
    step_0 = centre - coords[0]
    step_1 = centre - coords[1]
    return (centre + step_0, centre + step_1)

# Plotting function: shows one grating next to the running reconstruction
def display_plots(individual_grating, reconstruction, idx):
    """Draw the current grating and the reconstruction so far side by side.

    Parameters
    ----------
    individual_grating : array_like
        Image shown in the left-hand panel (subplot 121).
    reconstruction : array_like
        Image shown in the right-hand panel (subplot 122).
    idx : int
        Number of terms used so far; shown in the figure title.
    """
    panels = ((121, individual_grating), (122, reconstruction))
    for position, picture in panels:
        plt.subplot(position)
        # Clear the panel first: the same axes are reused on every call, and
        # without this each imshow would stack another image on top of the
        # previous ones, making every redraw slower than the last.
        plt.cla()
        plt.imshow(picture)
        # cla() resets the axes, so the axis decorations are hidden again.
        plt.axis("off")
    plt.suptitle(f"Terms: {idx}")
    # Short pause so the GUI event loop can redraw the figure.
    plt.pause(0.01)

    
# IMAGE PROCESSING    
# Converting image to array
image = plt.imread(image_filename)


# Convert to grayscale
# Colour images are 3-D arrays: average the first three channels and ignore any alpha channel.
# Images that are already grayscale are 2-D and are used as they are (indexing a third axis
# on them would raise an IndexError).
if image.ndim == 3:
    image = image[:, :, :3].mean(axis=2)


# Array dimensions (array is square) and centre pixel
# Use smallest of the dimensions and ensure it's odd
smallest_side = min(image.shape)
array_size = smallest_side - 1 + smallest_side % 2


# Crop image so it's a square image
image = image[:array_size, :array_size]
centre = (array_size - 1) // 2


# Get all coordinate pairs in the left half of the array, including the column at the centre of the array (which
# includes the centre pixel)
coords_left_half = ((x, y) for x in range(array_size) for y in range(centre + 1))


# Sort points based on distance from center
coords_left_half = sorted(
    coords_left_half,
    key=lambda point: calculate_distance_from_centre(point, centre),
)


# Setting color map to grayscale
plt.set_cmap("gray")


# Fourier transform of image
ft = calculate_2dft(image)


# Show grayscale image and its Fourier transform
plt.subplot(121)
plt.imshow(image)
plt.axis("off")
plt.subplot(122)
# Log scale so that the weaker terms of the transform are visible too
plt.imshow(np.log(abs(ft)))
plt.axis("off")
plt.pause(2)


# Reconstructing image
fig = plt.figure()

# Step 1
# Set up empty arrays for final image and individual gratings
rec_image = np.zeros(image.shape)
individual_grating = np.zeros(image.shape, dtype="complex")
idx = 0

# Every step with an index below display_all_until is displayed
display_all_until = 20
# After this, skip which steps to display using the display_step value
display_step = 10
# Work out index of next step to display
next_display = display_all_until + display_step
# Index of the most recent step that was drawn (0 means nothing drawn yet)
last_displayed = 0

# Step 2
for coords in coords_left_half:
    # Central column: only include if points in top half of the central column. The points in
    # the bottom half are the symmetric partners of those in the top half, so they are already
    # covered.
    if coords[1] == centre and coords[0] > centre:
        continue

    idx += 1
    symm_coords = find_symmetric_coordinates(coords, centre)
    # Step 3
    # Copy values from Fourier transform into individual_grating for the pair of points in current iteration
    individual_grating[coords] = ft[coords]
    individual_grating[symm_coords] = ft[symm_coords]

    # Step 4
    # Calculate inverse Fourier transform to give the reconstructed grating. Add this reconstructed
    # grating to the reconstructed image
    rec_grating = calculate_2dift(individual_grating)
    rec_image += rec_grating

    # Clear individual_grating array, ready for next iteration
    individual_grating[coords] = 0
    individual_grating[symm_coords] = 0

    # Don't display every step
    if idx < display_all_until or idx == next_display:
        if idx > display_all_until:
            next_display += display_step
            # Accelerate animation the further the iteration runs by increasing display_step
            display_step += 10
        display_plots(rec_grating, rec_image, idx)
        last_displayed = idx

# The accelerating schedule above usually skips the last iterations, so draw the final state
# explicitly: the figure left on screen is then the complete reconstruction with the true
# number of terms.
if idx > last_displayed:
    display_plots(rec_grating, rec_image, idx)

# Display images
plt.show()