<a href="https://colab.research.google.com/github/bforsbe/SK2534/blob/main/Diffraction_sim.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Stand-alone demo: superpose spherical waves from a few point emitters placed
# along z = 0, and show the field magnitude with its phase overlaid.
x_range = 1
z_range = 5
density = 100  # grid samples per unit length

# 1. Define wave sources
num_sources = 3
source_z = np.zeros(num_sources)
source_x = np.linspace(0, x_range, num_sources)
frequency = 2  # Assuming same frequency for all sources
wavelength = 1 / frequency  # Assuming wave speed is 1
epsilon = 1e-9  # Small value to prevent division by zero

# 2. Calculate wave amplitude on the sampling grid (built once)
x = np.linspace(0, x_range, x_range * density)
z = np.linspace(0, z_range, z_range * density)
X, Z = np.meshgrid(x, z)

total_amplitude = np.zeros_like(X, dtype=complex)
for sx, sz in zip(source_x, source_z):
    distance = np.sqrt((X - sx)**2 + (Z - sz)**2)
    # Add epsilon to distance to avoid division by zero (spherical 1/r decay)
    total_amplitude += np.exp(1j * 2 * np.pi * distance / wavelength) / (distance + epsilon)

# 3. Normalize the phase for the cyclic colormap
phase = np.angle(total_amplitude)
normalized_phase = (phase + np.pi) / (2 * np.pi)  # Normalize phase to [0, 1]

magnitude = np.abs(total_amplitude)
# Grid points that land exactly on an emitter have magnitude ~1/epsilon, which
# would squash the rest of the colour scale into a single colour; cap the
# scale at the 99th percentile instead.
magnitude_vmax = np.percentile(magnitude, 99)

# 4. Visualize the wave summation: magnitude underneath, phase overlaid
extent = [x.min(), x.max(), z.min(), z.max()]
plt.figure(figsize=(8, 6))
amplitude_image = plt.imshow(magnitude, extent=extent, origin='lower', aspect='auto',
                             cmap='viridis', vmax=magnitude_vmax)
plt.imshow(normalized_phase, extent=extent, origin='lower', aspect='auto',
           cmap='twilight', alpha=0.7)

# 5. Add labels and title
plt.xlabel('x')
plt.ylabel('z')
plt.title('Wave Summation with Phase Coloring')

# 6. Display the plot. The colorbar is tied explicitly to the amplitude image;
# without a mappable it would describe the last image drawn (the phase overlay).
plt.colorbar(amplitude_image, label='Amplitude')
plt.show()

In [23]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown

def binary_filter(x_coords, center, width):
    """Return an integer 0/1 mask that is 1 where x lies in the closed window.

    The window is ``[center - width/2, center + width/2]``; both edges count
    as inside.
    """
    lower = center - width/2
    upper = center + width/2
    inside = (x_coords >= lower) & (x_coords <= upper)
    return np.where(inside, 1, 0)

def single_slit(x_coords, width):
    """Return a 0/1 mask for one slit of ``width`` centred on the span of ``x_coords``."""
    lo = np.min(x_coords)
    span = np.max(x_coords) - lo
    return binary_filter(x_coords, lo + span / 2, width)

def double_slit(x_coords, width, separation):
    """Return a 0/1 mask for two slits of ``width`` centred on the span of ``x_coords``.

    The slit centres sit ``separation`` apart (centre to centre), symmetric
    about the midpoint of the span. The result has the dtype of ``x_coords``
    and is clipped to [0, 1] in case the slits overlap.
    """
    lo = np.min(x_coords)
    mid = lo + (np.max(x_coords) - lo) / 2
    half_sep = separation/2
    mask = np.zeros_like(x_coords)
    for slit_center in (mid - half_sep, mid + half_sep):
        mask += binary_filter(x_coords, slit_center, width)
    return np.clip(mask, 0, 1)


def plot_wave_summation(num_sources, wavelength, filter_type, slit_width, slit_separation):
    """Simulate and plot the field radiated by a line of filtered point emitters.

    ``num_sources`` emitters are spaced evenly along z = 0 over [0, 1]. Each
    radiates ``exp(2j*pi*r/wavelength) / r``, scaled by the selected aperture
    filter evaluated at its position. Emitters whose filter value is 0 are
    skipped. The figure shows the phase and the squared magnitude of the summed
    field over a 1 x 5 region, with the aperture filter drawn under each.

    Parameters
    ----------
    num_sources : int
        Number of emitters placed along z = 0.
    wavelength : float
        Wavelength shared by all emitters (must be non-zero).
    filter_type : str
        'None', 'Single Slit' or 'Double Slit'. Any other value is treated
        like 'None' (all emitters on).
    slit_width : float
        Width of each slit; used by the slit filters.
    slit_separation : float
        Centre-to-centre slit distance; used by 'Double Slit' only.
    """
    x_range = 1
    z_range = 5
    density = 100  # grid samples per unit length
    filter_samples = 200  # resolution of the filter preview curve

    # 1. Define wave sources on the z = 0 line
    source_x = np.linspace(0, x_range, num_sources)
    source_z = np.zeros(num_sources)
    epsilon = 1e-9  # Small value to prevent division by zero

    # Evaluate the chosen filter at the emitters and on a fine grid for plotting
    x_filter = np.linspace(0, x_range, filter_samples)
    if filter_type == 'Single Slit':
        source_amplitudes = single_slit(source_x, slit_width)
        filter_values_plot = single_slit(x_filter, slit_width)
    elif filter_type == 'Double Slit':
        source_amplitudes = double_slit(source_x, slit_width, slit_separation)
        filter_values_plot = double_slit(x_filter, slit_width, slit_separation)
    else:  # 'None' or an unrecognised option: every emitter fully on
        source_amplitudes = np.ones_like(source_x)
        filter_values_plot = np.ones_like(x_filter)

    # 2. Sum the spherical waves of the active (non-zero) emitters over the grid
    x = np.linspace(0, x_range, x_range * density)
    z = np.linspace(0, z_range, z_range * density)
    X, Z = np.meshgrid(x, z)
    total_amplitude = np.zeros_like(X, dtype=complex)
    for i in np.flatnonzero(source_amplitudes > 0):
        distance = np.sqrt((X - source_x[i])**2 + (Z - source_z[i])**2)
        # epsilon keeps the 1/r decay finite where a grid point sits on an emitter
        total_amplitude += (source_amplitudes[i]
                            * np.exp(1j * 2 * np.pi * distance / wavelength)
                            / (distance + epsilon))

    # 3. Phase normalized to [0, 1] for the cyclic colormap
    normalized_phase = (np.angle(total_amplitude) + np.pi) / (2 * np.pi)

    # Fixed upper limit for the intensity colour scale. The field grows like
    # 1/r near each emitter, so an automatic limit would be set by those points.
    intensity_vmax = 1.0

    # 4. Two field panels on top, two short filter panels underneath
    fig, axes = plt.subplots(2, 2, figsize=(12, 9), gridspec_kw={'height_ratios': [4, 1]})
    (ax1, ax2), (ax3, ax4) = axes
    extent = [x.min(), x.max(), z.min(), z.max()]

    im1 = ax1.imshow(normalized_phase, extent=extent, origin='lower', aspect='auto', cmap='twilight')
    im2 = ax2.imshow(np.abs(total_amplitude)**2, extent=extent, origin='lower', aspect='auto',
                     cmap='Blues', vmax=intensity_vmax)
    for ax, im, title, cbar_label in (
        (ax1, im1, 'Wave Phase (Modulated Emitters)', 'Normalized Phase'),
        (ax2, im2, 'Wave Squared Amplitude (Modulated Emitters)', 'Squared Amplitude'),
    ):
        # Emitter markers take their alpha from the filter, so blocked ones vanish
        ax.scatter(source_x, source_z, color='black', s=50, alpha=source_amplitudes, label='Sources')
        ax.set_xlabel('x')
        ax.set_ylabel('z')
        ax.set_title(title)
        fig.colorbar(im, ax=ax, label=cbar_label)
        ax.legend()

    # Bottom row: the same filter curve under each field panel
    for ax in (ax3, ax4):
        ax.plot(x_filter, filter_values_plot)
        ax.set_xlabel('x')
        ax.set_ylabel('Filter Value')
        ax.set_title('Amplitude Modulation Filter')
        # Small margin so filter values of exactly 0 or 1 are not hidden on the spines
        ax.set_ylim(-0.05, 1.05)

    plt.tight_layout()
    plt.show()

# Create interactive widgets: one control per plot_wave_summation parameter.
# interact() re-runs plot_wave_summation whenever a control changes.
interact(
    plot_wave_summation,
    **{
        'num_sources': IntSlider(min=20, max=200, step=10, value=20, description='Number of Sources:'),
        'wavelength': FloatSlider(min=0.01, max=0.2, step=0.01, value=0.1, description='Wavelength:'),
        'filter_type': Dropdown(options=['None', 'Single Slit', 'Double Slit'], value='Double Slit', description='Filter Type:'),
        'slit_width': FloatSlider(min=0.05, max=1.0, step=0.05, value=0.1, description='Slit Width:'),
        'slit_separation': FloatSlider(min=0.1, max=0.5, step=0.05, value=0.3, description='Slit Separation:'),
    },
);

interactive(children=(IntSlider(value=20, description='Number of Sources:', max=200, min=20, step=10), FloatSl…

In [43]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown

def binary_filter(x_coords, center, width):
    """Return an integer 0/1 mask that is 1 where x lies in the closed window.

    The window is ``[center - width/2, center + width/2]``; both edges count
    as inside.
    """
    lower = center - width/2
    upper = center + width/2
    inside = (x_coords >= lower) & (x_coords <= upper)
    return np.where(inside, 1, 0)

def single_slit(x_coords, width):
    """Return a 0/1 mask for one slit of ``width`` centred on the span of ``x_coords``."""
    lo = np.min(x_coords)
    span = np.max(x_coords) - lo
    return binary_filter(x_coords, lo + span / 2, width)

def double_slit(x_coords, width, separation):
    """Return a 0/1 mask for two slits of ``width`` centred on the span of ``x_coords``.

    The slit centres sit ``separation`` apart (centre to centre), symmetric
    about the midpoint of the span. The result has the dtype of ``x_coords``
    and is clipped to [0, 1] in case the slits overlap.
    """
    lo = np.min(x_coords)
    mid = lo + (np.max(x_coords) - lo) / 2
    half_sep = separation/2
    mask = np.zeros_like(x_coords)
    for slit_center in (mid - half_sep, mid + half_sep):
        mask += binary_filter(x_coords, slit_center, width)
    return np.clip(mask, 0, 1)


def plot_wave_summation(num_sources, wavelength, filter_type, slit_width, slit_separation):
    """Simulate and plot the field radiated by a line of filtered point emitters.

    ``num_sources`` emitters are spaced evenly along z = 0 over [0, 1]. Each
    radiates ``exp(2j*pi*r/wavelength) / r``, scaled by the selected aperture
    filter evaluated at its position. Emitters whose filter value is 0 are
    skipped. The figure shows the phase and the squared magnitude of the summed
    field over a 1 x 5 region, with the aperture filter drawn under each and
    horizontally aligned with the panel above it.

    Parameters
    ----------
    num_sources : int
        Number of emitters placed along z = 0.
    wavelength : float
        Wavelength shared by all emitters (must be non-zero).
    filter_type : str
        'None', 'Single Slit' or 'Double Slit'. Any other value is treated
        like 'None' (all emitters on).
    slit_width : float
        Width of each slit; used by the slit filters.
    slit_separation : float
        Centre-to-centre slit distance; used by 'Double Slit' only.
    """
    x_range = 1
    z_range = 5
    density = 100  # grid samples per unit length
    filter_samples = 200  # resolution of the filter preview curve

    # 1. Define wave sources on the z = 0 line
    source_x = np.linspace(0, x_range, num_sources)
    source_z = np.zeros(num_sources)
    epsilon = 1e-9  # Small value to prevent division by zero

    # Evaluate the chosen filter at the emitters and on a fine grid for plotting
    x_filter = np.linspace(0, x_range, filter_samples)
    if filter_type == 'Single Slit':
        source_amplitudes = single_slit(source_x, slit_width)
        filter_values_plot = single_slit(x_filter, slit_width)
    elif filter_type == 'Double Slit':
        source_amplitudes = double_slit(source_x, slit_width, slit_separation)
        filter_values_plot = double_slit(x_filter, slit_width, slit_separation)
    else:  # 'None' or an unrecognised option: every emitter fully on
        source_amplitudes = np.ones_like(source_x)
        filter_values_plot = np.ones_like(x_filter)

    # 2. Sum the spherical waves of the active (non-zero) emitters over the grid
    x = np.linspace(0, x_range, x_range * density)
    z = np.linspace(0, z_range, z_range * density)
    X, Z = np.meshgrid(x, z)
    total_amplitude = np.zeros_like(X, dtype=complex)
    for i in np.flatnonzero(source_amplitudes > 0):
        distance = np.sqrt((X - source_x[i])**2 + (Z - source_z[i])**2)
        # epsilon keeps the 1/r decay finite where a grid point sits on an emitter
        total_amplitude += (source_amplitudes[i]
                            * np.exp(1j * 2 * np.pi * distance / wavelength)
                            / (distance + epsilon))

    # 3. Phase normalized to [0, 1] for the cyclic colormap, and intensity
    normalized_phase = (np.angle(total_amplitude) + np.pi) / (2 * np.pi)
    intensity = np.abs(total_amplitude)**2
    # Cap the intensity colour scale at the brightest point on the far
    # (z = z_range) row, so the pattern stays visible despite the large values
    # next to the emitters.
    intensity_vmax = np.max(intensity[-1, :])

    # 4. Two field panels on top, two short filter panels underneath
    fig, axes = plt.subplots(2, 2, figsize=(12, 9), gridspec_kw={'height_ratios': [4, 1]})
    (ax1, ax2), (ax3, ax4) = axes
    extent = [x.min(), x.max(), z.min(), z.max()]

    im1 = ax1.imshow(normalized_phase, extent=extent, origin='lower', aspect='auto', cmap='twilight')
    im2 = ax2.imshow(intensity, extent=extent, origin='lower', aspect='auto',
                     cmap='Blues_r', vmax=intensity_vmax)
    for ax, im, title, cbar_label in (
        (ax1, im1, 'Wave Phase (Modulated Emitters)', 'Normalized Phase'),
        (ax2, im2, 'Wave Squared Amplitude (Modulated Emitters)', 'Squared Amplitude'),
    ):
        # Emitter markers take their alpha from the filter, so blocked ones vanish
        ax.scatter(source_x, source_z, color='black', s=50, alpha=source_amplitudes, label='Sources')
        ax.set_xlabel('x')
        ax.set_ylabel('z')
        ax.set_title(title)
        fig.colorbar(im, ax=ax, label=cbar_label)
        ax.legend()

    # Positions are read after the colorbars have narrowed the top panels
    pos1 = ax1.get_position()
    pos2 = ax2.get_position()

    # Bottom row: the same filter curve, aligned horizontally with the panel above
    for ax, top_pos in ((ax3, pos1), (ax4, pos2)):
        ax.plot(x_filter, filter_values_plot)
        ax.set_xlabel('x')
        ax.set_ylabel('Filter Value')
        ax.set_title('Amplitude Modulation Filter')
        # Small margin so filter values of exactly 0 or 1 are not hidden on the spines
        ax.set_ylim(-0.05, 1.05)
        own_pos = ax.get_position()
        ax.set_position([top_pos.x0, own_pos.y0, top_pos.width, own_pos.height])

    # plt.tight_layout() is deliberately not called: re-running the subplot
    # layout would discard the manual alignment above.
    plt.show()

# Create interactive widgets: one control per plot_wave_summation parameter.
# interact() re-runs plot_wave_summation whenever a control changes.
interact(
    plot_wave_summation,
    **{
        'num_sources': IntSlider(min=20, max=200, step=10, value=20, description='Number of Sources:'),
        'wavelength': FloatSlider(min=0.002, max=0.08, step=0.002, value=0.01, description='Wavelength:'),
        'filter_type': Dropdown(options=['None', 'Single Slit', 'Double Slit'], value='Double Slit', description='Filter Type:'),
        'slit_width': FloatSlider(min=0.05, max=0.3, step=0.05, value=0.1, description='Slit Width:'),
        'slit_separation': FloatSlider(min=0.1, max=0.3, step=0.05, value=0.2, description='Slit Separation:'),
    },
);

interactive(children=(IntSlider(value=20, description='Number of Sources:', max=200, min=20, step=10), FloatSl…

In [49]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown

def binary_filter(x_coords, center, width):
    """Return an integer 0/1 mask that is 1 where x lies in the closed window.

    The window is ``[center - width/2, center + width/2]``; both edges count
    as inside.
    """
    lower = center - width/2
    upper = center + width/2
    inside = (x_coords >= lower) & (x_coords <= upper)
    return np.where(inside, 1, 0)

def single_slit(x_coords, width):
    """Return a 0/1 mask for one slit of ``width`` centred on the span of ``x_coords``."""
    lo = np.min(x_coords)
    span = np.max(x_coords) - lo
    return binary_filter(x_coords, lo + span / 2, width)

def double_slit(x_coords, width, separation):
    """Return a 0/1 mask for two slits of ``width`` centred on the span of ``x_coords``.

    The slit centres sit ``separation`` apart (centre to centre), symmetric
    about the midpoint of the span. The result has the dtype of ``x_coords``
    and is clipped to [0, 1] in case the slits overlap.
    """
    lo = np.min(x_coords)
    mid = lo + (np.max(x_coords) - lo) / 2
    half_sep = separation/2
    mask = np.zeros_like(x_coords)
    for slit_center in (mid - half_sep, mid + half_sep):
        mask += binary_filter(x_coords, slit_center, width)
    return np.clip(mask, 0, 1)

def _screen_sin_theta(x, z):
    """Return sin(theta) for screen positions ``x`` at distance ``z`` from the aperture.

    theta is measured from the axis through the midpoint of the sampled span,
    ``(min(x) + max(x)) / 2``. This matches the centring that single_slit and
    double_slit use for the same span.
    """
    x = np.asarray(x, dtype=float)
    # Use the true midpoint; the old ``(max - min) / 2`` is only correct when x starts at 0.
    x_shifted = x - (np.min(x) + np.max(x)) / 2
    return x_shifted / np.sqrt(x_shifted**2 + z**2)


# Theoretical diffraction pattern for a single slit
def theoretical_single_slit(x, wavelength, slit_width, z):
    """Return the Fraunhofer single-slit intensity, normalised to 1 on axis.

    Intensity is ``(sin(a) / a)**2`` with
    ``a = pi * slit_width * sin(theta) / wavelength``. Here
    ``sin(theta) = x' / sqrt(x'**2 + z**2)`` and x' is measured from the
    midpoint of the sampled x-span.

    Parameters
    ----------
    x : array_like
        Screen positions.
    wavelength : float
        Wavelength (non-zero).
    slit_width : float
        Slit width.
    z : float
        Aperture-to-screen distance (non-zero).

    Returns
    -------
    numpy.ndarray
        Intensity at each position in ``x``.
    """
    sin_theta = _screen_sin_theta(x, z)
    # np.sinc(t) = sin(pi t) / (pi t) and is exactly 1 at t = 0, so the centre
    # needs no special-casing.
    return np.sinc(slit_width * sin_theta / wavelength)**2


# Theoretical diffraction pattern for a double slit
def theoretical_double_slit(x, wavelength, slit_width, slit_separation, z):
    """Return the Fraunhofer double-slit intensity, normalised to 1 on axis.

    This is the single-slit envelope (see theoretical_single_slit) multiplied by
    the two-slit interference term ``cos(beta)**2``, with
    ``beta = pi * slit_separation * sin(theta) / wavelength``.

    Parameters
    ----------
    x : array_like
        Screen positions.
    wavelength : float
        Wavelength (non-zero).
    slit_width : float
        Width of each slit.
    slit_separation : float
        Centre-to-centre distance between the slits.
    z : float
        Aperture-to-screen distance (non-zero).

    Returns
    -------
    numpy.ndarray
        Intensity at each position in ``x``.
    """
    envelope = theoretical_single_slit(x, wavelength, slit_width, z)
    beta = np.pi * slit_separation * _screen_sin_theta(x, z) / wavelength
    return envelope * np.cos(beta)**2


def plot_wave_summation(num_sources, wavelength, filter_type, slit_width, slit_separation):
    """Simulate, plot and compare against theory the field of filtered point emitters.

    ``num_sources`` emitters are spaced evenly along z = 0 over [0, 4]. Each
    radiates ``exp(2j*pi*r/wavelength) / r``, scaled by the selected aperture
    filter evaluated at its position. Emitters whose filter value is 0 are
    skipped. The figure has three rows:

    * the phase and the squared magnitude of the summed field over a 4 x 5 region;
    * the aperture filter under each;
    * the phase and squared magnitude along the far row (z = 5).

    For the slit filters, the squared-magnitude profile is overlaid with the
    theoretical Fraunhofer pattern, rescaled to the same peak.

    Parameters
    ----------
    num_sources : int
        Number of emitters placed along z = 0.
    wavelength : float
        Wavelength shared by all emitters (must be non-zero).
    filter_type : str
        'None', 'Single Slit' or 'Double Slit'. Any other value is treated
        like 'None' (all emitters on, no theory curve).
    slit_width : float
        Width of each slit; used by the slit filters.
    slit_separation : float
        Centre-to-centre slit distance; used by 'Double Slit' only.
    """
    x_range = 4
    z_range = 5
    density = 100  # grid samples per unit length
    filter_samples = 200  # resolution of the filter preview curve

    # 1. Define wave sources on the z = 0 line
    source_x = np.linspace(0, x_range, num_sources)
    source_z = np.zeros(num_sources)
    epsilon = 1e-9  # Small value to prevent division by zero

    # Evaluate the chosen filter at the emitters and on a fine grid for plotting
    x_filter = np.linspace(0, x_range, filter_samples)
    if filter_type == 'Single Slit':
        source_amplitudes = single_slit(source_x, slit_width)
        filter_values_plot = single_slit(x_filter, slit_width)
    elif filter_type == 'Double Slit':
        source_amplitudes = double_slit(source_x, slit_width, slit_separation)
        filter_values_plot = double_slit(x_filter, slit_width, slit_separation)
    else:  # 'None' or an unrecognised option: every emitter fully on
        source_amplitudes = np.ones_like(source_x)
        filter_values_plot = np.ones_like(x_filter)

    # 2. Sum the spherical waves of the active (non-zero) emitters over the grid
    x = np.linspace(0, x_range, x_range * density)
    z = np.linspace(0, z_range, z_range * density)
    X, Z = np.meshgrid(x, z)
    total_amplitude = np.zeros_like(X, dtype=complex)
    for i in np.flatnonzero(source_amplitudes > 0):
        distance = np.sqrt((X - source_x[i])**2 + (Z - source_z[i])**2)
        # epsilon keeps the 1/r decay finite where a grid point sits on an emitter
        total_amplitude += (source_amplitudes[i]
                            * np.exp(1j * 2 * np.pi * distance / wavelength)
                            / (distance + epsilon))

    # 3. Phase normalized to [0, 1] for the cyclic colormap, and intensity
    normalized_phase = (np.angle(total_amplitude) + np.pi) / (2 * np.pi)
    intensity = np.abs(total_amplitude)**2
    # Profiles along the far (z = z_range) row, i.e. the diffraction pattern
    far_phase = np.angle(total_amplitude[-1, :])
    far_intensity = intensity[-1, :]
    # Cap the intensity colour scale at the brightest far-row value so the
    # pattern stays visible despite the large values next to the emitters.
    intensity_vmax = np.max(far_intensity)

    # 4. Field panels on top, filter panels in the middle, far-row profiles at the bottom
    fig, axes = plt.subplots(3, 2, figsize=(12, 12), gridspec_kw={'height_ratios': [4, 1, 1]})
    (ax1, ax2), (ax3, ax4), (ax5, ax6) = axes
    extent = [x.min(), x.max(), z.min(), z.max()]

    im1 = ax1.imshow(normalized_phase, extent=extent, origin='lower', aspect='auto', cmap='twilight')
    im2 = ax2.imshow(intensity, extent=extent, origin='lower', aspect='auto',
                     cmap='Blues_r', vmax=intensity_vmax)
    for ax, im, title, cbar_label in (
        (ax1, im1, 'Wave Phase (Modulated Emitters)', 'Normalized Phase'),
        (ax2, im2, 'Wave Squared Amplitude (Modulated Emitters)', 'Squared Amplitude'),
    ):
        # Emitter markers take their alpha from the filter, so blocked ones vanish
        ax.scatter(source_x, source_z, color='black', s=50, alpha=source_amplitudes, label='Sources')
        ax.set_xlabel('x')
        ax.set_ylabel('z')
        ax.set_title(title)
        fig.colorbar(im, ax=ax, label=cbar_label)
        ax.legend()

    # Positions are read after the colorbars have narrowed the top panels
    pos1 = ax1.get_position()
    pos2 = ax2.get_position()

    def align_under(ax, top_pos):
        """Give ``ax`` the horizontal extent of ``top_pos``, keeping its own vertical extent."""
        own_pos = ax.get_position()
        ax.set_position([top_pos.x0, own_pos.y0, top_pos.width, own_pos.height])

    # Middle row: the same filter curve under each field panel
    for ax, top_pos in ((ax3, pos1), (ax4, pos2)):
        ax.plot(x_filter, filter_values_plot)
        ax.set_xlabel('x')
        ax.set_ylabel('Filter Value')
        ax.set_title('Amplitude Modulation Filter')
        # Small margin so filter values of exactly 0 or 1 are not hidden on the spines
        ax.set_ylim(-0.05, 1.05)
        align_under(ax, top_pos)

    # Bottom left: phase along the far row
    ax5.plot(x, far_phase, color='purple')
    ax5.set_xlabel('x')
    ax5.set_ylabel('Phase')
    ax5.set_title(f'Phase at z={z_range}')
    ax5.set_xlim([x.min(), x.max()])  # Ensure x-limits match main plots
    align_under(ax5, pos1)

    # Bottom right: intensity along the far row (the diffraction pattern)
    ax6.plot(x, far_intensity, color='orange', label='Calculated')
    ax6.set_xlabel('x')
    ax6.set_ylabel('Squared Amplitude')
    ax6.set_title(f'Squared Amplitude at z={z_range} (Diffraction Pattern)')
    ax6.set_xlim([x.min(), x.max()])  # Ensure x-limits match main plots
    align_under(ax6, pos2)

    # Overlay the theoretical pattern for the slit filters, rescaled to the simulated peak
    if filter_type == 'Single Slit':
        theoretical_intensity = theoretical_single_slit(x, wavelength, slit_width, z_range)
    elif filter_type == 'Double Slit':
        theoretical_intensity = theoretical_double_slit(x, wavelength, slit_width, slit_separation, z_range)
    else:
        theoretical_intensity = None

    if theoretical_intensity is not None:
        theory_peak = np.max(theoretical_intensity)
        if theory_peak > 1e-9:  # leave an (almost) all-zero curve unscaled
            theoretical_intensity = theoretical_intensity * (np.max(far_intensity) / theory_peak)
        ax6.plot(x, theoretical_intensity, color='green', linestyle='--', label='Theoretical')

    ax6.legend()

    # plt.tight_layout() is deliberately not called: re-running the subplot
    # layout would discard the manual alignment above.
    plt.show()

# Create interactive widgets: one control per plot_wave_summation parameter.
# interact() re-runs plot_wave_summation whenever a control changes; the
# float sliders show three decimals.
interact(
    plot_wave_summation,
    **{
        'num_sources': IntSlider(min=20, max=400, step=10, value=200, description='Number of Sources:'),
        'wavelength': FloatSlider(min=0.002, max=0.08, step=0.002, value=0.02, description='Wavelength:', readout_format='.3f'),
        'filter_type': Dropdown(options=['None', 'Single Slit', 'Double Slit'], value='Double Slit', description='Filter Type:'),
        'slit_width': FloatSlider(min=0.02, max=0.3, step=0.02, value=0.06, description='Slit Width:', readout_format='.3f'),
        'slit_separation': FloatSlider(min=0.1, max=0.8, step=0.05, value=0.5, description='Slit Separation:', readout_format='.3f'),
    },
);

interactive(children=(IntSlider(value=200, description='Number of Sources:', max=400, min=20, step=10), FloatS…