In [1]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List
from shapely.geometry import Point, LineString
from ipywidgets import interact, FloatSlider, Label
from IPython.display import display, HTML

In [2]:
def snell(theta1: float, n1: float, n2: float, unit: Optional[str] = 'rad') -> float:
    ''' Snell's law, n1 sin(theta1) = n2 sin(theta2), solved for theta2.

    Parameters
    ----------

    theta1: float
        Angle of incidence, default unit is radian. Numpy arrays are also
        accepted (all operations are numpy ufuncs).
    n1: float
        Refractive index of medium 1
    n2: float
        Refractive index of medium 2
    unit: str, optional
        Unit of theta1 and theta2, either 'rad' or 'deg'. None is treated
        as 'rad'.

    Returns
    -------

    theta2: float
        Angle of refraction, in the same unit as theta1. NaN when no
        refracted ray exists, i.e. when |n1 sin(theta1) / n2| > 1
        (total internal reflection).

    Raises
    ------

    ValueError
        If ``unit`` is not 'rad', 'deg' or None.
    '''
    if unit not in (None, 'rad', 'deg'):
        raise ValueError(f"unit must be 'rad' or 'deg', got {unit!r}")

    theta1_rad = np.deg2rad(theta1) if unit == 'deg' else theta1

    # Under total internal reflection the arcsin argument exceeds 1 and numpy
    # returns NaN. That NaN is the intended "no refracted ray" signal, so the
    # accompanying "invalid value" RuntimeWarning is suppressed.
    with np.errstate(invalid='ignore'):
        theta2 = np.arcsin(n1 * np.sin(theta1_rad) / n2)

    return np.rad2deg(theta2) if unit == 'deg' else theta2
    

def normalize_line(line, desired_length):
    ''' Rescale a line about its centroid to a given length

    Every vertex is pushed away from (or pulled towards) the centroid by the
    common factor ``desired_length / line.length``. The centroid therefore
    stays where it is, and for a positive ``desired_length`` the line keeps
    its orientation.

    Parameters
    ----------

    line: LineString
        Line to rescale. It must have non-zero length, because the scale
        factor divides by it.
    desired_length: float
        Length of the returned line.

    Returns
    -------

    scaled_line: LineString
        A new, rescaled line; ``line`` itself is left untouched.
    '''
    factor = desired_length / line.length
    pivot = line.centroid
    cx, cy = pivot.x, pivot.y

    def _stretch(vertex):
        # Move one vertex radially w.r.t. the centroid.
        vx, vy = vertex
        return ((vx - cx) * factor + cx, (vy - cy) * factor + cy)

    return LineString(list(map(_stretch, line.coords)))
    

def get_incidence_ray(theta1: float) -> LineString:
    ''' Build the incident ray as a length-10 segment ending at the origin

    The ray comes in from the upper half-plane (y > 0). For an angle in
    (0, pi/2) its far end lies to the left of the normal (x < 0).

    Parameters
    ----------

    theta1: float
        Angle of incidence in radian, measured from the normal (the y axis).

    Returns
    -------

    incidence_ray: LineString
        Segment from the far end of the ray to the origin (0, 0).
    '''
    rise = 10
    origin = Point(0, 0)

    # First place the far end `rise` units above the origin, shifted left by
    # the horizontal run implied by tan(theta1)...
    far_x, far_y = 0 - np.tan(theta1) * rise, rise

    # ...then pull it back onto the circle of radius 10 so the ray length does
    # not depend on the angle.
    span = LineString([Point(far_x, far_y), origin]).length
    far_x, far_y = far_x / span * 10, far_y / span * 10

    return LineString([Point(far_x, far_y), origin])

def get_refraction_ray(theta2: float) -> LineString:
    ''' Get the refraction ray

    Parameters
    ----------

    theta2: float
        Angle of refraction in radian, measured from the normal (the y axis).
        A NaN angle (no refracted ray, e.g. total internal reflection) is not
        special-cased here: it propagates into the computed coordinates, so
        callers should check for it before calling.

    Returns
    -------

    refraction_ray: LineString
        Segment of length 10 from a point below the interface (y < 0) to the
        origin (0, 0). For an angle in (0, pi/2) the far end lies at x > 0.
    '''
    x1 = 0
    y1 = 0
    y2 = -10
    # Horizontal offset of the point 10 units below the origin along the
    # refracted direction. Since y2 < 0 this is negative for theta2 > 0; the
    # sign is flipped when the result is built so the ray continues to +x.
    x2 = x1 + np.tan(theta2) * (y2 - y1)

    # Rescale the far end so the segment always has length 10, independent
    # of the angle.
    segment_length = LineString([Point(x2, y2), Point(x1, y1)]).length
    y2 = y2 / segment_length * 10
    x2 = x2 / segment_length * 10

    return LineString([Point(-x2, y2), Point(x1, y1)])

def get_wavefronts(ray: LineString, n: float, wavelength: Optional[float] = 1,
                   front_length: float = 1) -> List[LineString]:
    ''' Get wavefront segments drawn perpendicular to a ray

    Wavefronts are placed every ``wavelength / n`` along the ray, starting at
    its first coordinate and never beyond its last one.

    Parameters
    ----------

    ray: LineString
        The ray. The wavefront orientation is taken from its first segment.
    n: float
        Refractive index of the medium the ray travels in.
    wavelength: float, optional
        Wavelength before dividing by ``n``; the spacing between consecutive
        wavefronts is ``wavelength / n``.
    front_length: float, optional
        Length of each drawn wavefront segment.

    Returns
    -------

    wavefronts: list of LineString
        One three-point segment (end, centre on the ray, end) per wavefront.
        Empty when the ray has zero or non-finite length (for instance NaN
        coordinates from a refraction angle that does not exist).

    Raises
    ------

    ValueError
        If ``wavelength / n`` is not a positive, finite number, or if the
        ray's first segment has zero length (no direction to be
        perpendicular to).
    '''
    spacing = wavelength / n if n != 0 else np.inf
    if not (np.isfinite(spacing) and spacing > 0):
        raise ValueError(
            f'wavefront spacing wavelength / n must be positive and finite, '
            f'got wavelength={wavelength!r}, n={n!r}')

    ray_length = ray.length
    # A NaN/degenerate ray has nothing to attach wavefronts to.
    if not np.isfinite(ray_length) or ray_length == 0:
        return []

    dx = ray.coords[1][0] - ray.coords[0][0]
    dy = ray.coords[1][1] - ray.coords[0][1]
    segment = np.hypot(dx, dy)
    if segment == 0:
        raise ValueError('first segment of ray has zero length')

    # (dy, -dx) is perpendicular to the ray direction; scale it to half the
    # requested front length so each front is centred on the ray.
    half = 0.5 * front_length
    ux, uy = dy / segment * half, -dx / segment * half

    # Number of fronts that fit on [0, ray_length]. The tiny tolerance keeps
    # the far endpoint when ray_length is an exact multiple of the spacing
    # despite float rounding. This avoids the clamped duplicates that
    # anchors past the end of the ray used to produce.
    n_fronts = int(np.floor(ray_length / spacing + 1e-9)) + 1

    wavefronts = []
    for distance in spacing * np.arange(n_fronts):
        centre = ray.interpolate(distance)
        cx, cy = centre.x, centre.y
        wavefronts.append(LineString([(cx + ux, cy + uy), (cx, cy), (cx - ux, cy - uy)]))

    return wavefronts

In [3]:
theta1_initial = 0
wavelength_inc = 1
n1 = 1.0
n2_initial = 1.5
theta2 = snell(np.deg2rad(theta1_initial), n1, n2_initial)

# n2 must stay strictly positive: the plot divides wavelength_inc by n2 to get
# the wavelength in medium 2, which fails for n2 = 0. Values below n1 remain
# selectable (they can give total internal reflection).
N2_MIN = 0.1
n2_slider = FloatSlider(value=n2_initial, min=N2_MIN, max=10, step=0.1, description=r'n_2:')
theta1_slider = FloatSlider(value=theta1_initial, min=0, max=90, step=1, description='Theta1 [deg]:')


def _plot_ray_with_wavefronts(ax, ray: LineString, n: float, color: str, label: str) -> None:
    '''Draw ``ray`` as a thick line and its wavefronts as thin lines of the same colour.

    Wavefronts come from ``get_wavefronts(ray, n, wavelength_inc)``.
    '''
    ax.plot(*ray.xy, color=color, linewidth=3, label=label)
    for wavefront in get_wavefronts(ray, n, wavelength_inc):
        ax.plot(*wavefront.xy, color=color, linewidth=1)


@interact(theta1=theta1_slider, n2=n2_slider)
def update_plot(theta1, n2):
    '''Redraw the incident, reflected and refracted rays for the slider values.

    Parameters
    ----------

    theta1: float
        Angle of incidence in degrees.
    n2: float
        Refractive index of medium 2; medium 1 uses the global ``n1``.
    '''
    if n2 <= 0:
        # The wavelength in medium 2 is wavelength_inc / n2, so a non-positive
        # index has no meaningful plot (and n2 == 0 would divide by zero).
        print('n2 must be positive.')
        return

    theta1_rad = np.deg2rad(theta1)
    # If n1*sin(theta1)/n2 > 1, arcsin yields NaN: total internal reflection,
    # i.e. no refracted ray. Silence numpy's warning and handle it below,
    # because the refracted-ray helpers cannot work with a NaN angle.
    with np.errstate(invalid='ignore'):
        theta2 = snell(theta1_rad, n1, n2)
    total_internal_reflection = bool(np.isnan(theta2))

    fig, ax = plt.subplots(figsize=(8, 6))

    # Plot order (incident, refracted, reflected) also fixes the legend order.
    incident_ray = get_incidence_ray(theta1_rad)
    _plot_ray_with_wavefronts(ax, incident_ray, n1, 'red', r'Incident ray')

    if not total_internal_reflection:
        refracted_ray = get_refraction_ray(theta2)
        _plot_ray_with_wavefronts(ax, refracted_ray, n2, 'blue', r'Refracted ray')

    # Reflection: mirror the incident ray about the normal (the y axis).
    reflected_ray = LineString([Point(-x, y) for x, y in incident_ray.coords])
    _plot_ray_with_wavefronts(ax, reflected_ray, n1, 'gold', r'Reflected ray')

    # Shade medium 2, below the interface y = 0.
    medium2 = plt.Rectangle((-5, 0), 20, -10, fc='b', alpha=0.2)
    ax.add_patch(medium2)

    # Centred spines act as interface and normal; no ticks needed.
    ax.set_xlim(-4, 4)
    ax.set_ylim(-4, 4)
    ax.spines['left'].set_position('center')
    ax.spines['bottom'].set_position('center')
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_ticks([])
    ax.set_aspect('equal')

    # Fixed-precision formatting hides float noise from the slider values.
    medium1_text = '\n'.join([
        fr'n$_1$ = {n1:.2f}',
        fr'$\lambda_1$ = {wavelength_inc:.2f}m',
        fr'$\theta_1$ = {theta1:.1f}°',
    ])
    theta2_text = ('total internal reflection' if total_internal_reflection
                   else fr'$\theta_2$ = {np.rad2deg(theta2):.1f}°')
    medium2_text = '\n'.join([
        fr'n$_2$ = {n2:.2f}',
        fr'$\lambda_2$ = {wavelength_inc / n2:.2f}m',
        theta2_text,
    ])
    ax.text(0.1, 0.6, medium1_text, ha='center', va='center', transform=ax.transAxes, fontsize=10)
    ax.text(0.1, 0.4, medium2_text, ha='center', va='center', transform=ax.transAxes, fontsize=10)

    ax.legend()
    fig.tight_layout()
    plt.show()


interactive(children=(FloatSlider(value=0.0, description='Theta1 [deg]:', max=90.0, step=1.0), FloatSlider(val…