In [None]:
# For Jupyter
%matplotlib notebook
# Alternative: use %matplotlib inline; the HTML(ani.to_jshtml()) output created below then shows the animation

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Physical constants
hbar = 1.0545718e-34  # Reduced Planck constant (J·s)
k_B = 1.380649e-23    # Boltzmann constant (J/K)

# Atom parameters (Rubidium-87)
m = 1.44e-25          # Mass (kg)
lambda_ = 780e-9      # Wavelength (m)
k = 2 * np.pi / lambda_  # Wavevector (1/m)
gamma = 2 * np.pi * 6.065e6  # Linewidth (1/s)
I_sat = 35.8          # Saturation intensity (W/m^2)

# Laser parameters
I = 2 * I_sat         # Intensity (W/m^2) for s=2
s = I / I_sat         # Saturation parameter
delta = -gamma        # Detuning (rad/s): minus one linewidth

# Simulation parameters
dt = 1e-8             # Time step (s)
t_max = 4e-4          # 0.4 ms for more cooling
# Round instead of truncating: t_max / dt is a float quotient, and a result
# such as 39999.999999999996 must not silently drop a step.
N = int(round(t_max / dt))   # Number of steps
x = 0.0               # Initial position (m)
v = 5.0               # Initial velocity (m/s)

# Arrays to store data
# Sample i is recorded after i integration steps, i.e. at time i * dt.
# (np.linspace(0, t_max, N) would space the samples t_max / (N - 1) apart,
# which does not match the integrator's step.)
t = np.arange(N) * dt
x_data = np.zeros(N)
v_data = np.zeros(N)
x_data[0] = x
v_data[0] = v

# Photon scattering rates for the two beams
def scattering_rate(v, delta, s, k, gamma):
    """Return the pair of scattering rates for the two laser beams.

    Each rate is a saturated Lorentzian in the effective detuning, which is
    the laser detuning shifted by the Doppler term ``k * v`` with opposite
    signs for the two beams.

    Args:
        v: Atom velocity.
        delta: Laser detuning.
        s: Saturation parameter.
        k: Wavevector magnitude.
        gamma: Linewidth.

    Returns:
        Tuple ``(rate_left, rate_right)``: ``rate_left`` is evaluated at
        detuning ``delta + k * v`` and ``rate_right`` at ``delta - k * v``.
    """
    shift = k * v
    peak = (gamma / 2) * s

    def _lorentzian(detuning):
        # Same arithmetic order as writing the expression out in full.
        return peak / (1 + s + 4 * detuning**2 / gamma**2)

    return _lorentzian(delta + shift), _lorentzian(delta - shift)

# Time integration: each step applies the mean radiation-pressure force and
# one randomly drawn recoil kick, then advances the position with the
# already-updated velocity.
recoil_dv = hbar * k / m  # velocity change from one photon recoil (m/s)
for i in range(1, N):
    rate_left, rate_right = scattering_rate(v, delta, s, k, gamma)
    total_rate = rate_left + rate_right
    F = hbar * k * (rate_right - rate_left)
    # Probability of a scattering event within this step; it is split evenly
    # between a +recoil and a -recoil kick.
    prob = total_rate * dt
    kick = np.random.choice(
        [recoil_dv, -recoil_dv, 0],
        p=[prob/2, prob/2, 1-prob],
    )
    v += (F / m) * dt + kick
    x += v * dt
    x_data[i], v_data[i] = x, v

# Diagnostics: largest |x| and smallest signed x (both in mm), final velocity
max_x = np.abs(x_data).max() * 1e3
min_x = x_data.min() * 1e3
print(f"Max x displacement: {max_x:.4f} mm")
print(f"Min x displacement: {min_x:.4f} mm")
print(f"Final velocity: {v_data[-1]:.4f} m/s")
if not np.any(x_data):
    print("Warning: x_data is all zeros, animation will show no movement")

# Final temperature from the mean squared velocity over the tail of the run
# (the last N // 10 samples, rounded up when N is not a multiple of 10)
tail_v = v_data[-N//10:]
v_squared = np.mean(tail_v**2)
T = m * v_squared / k_B
doppler_T = hbar * gamma / (2 * k_B)
print(f"Final temperature: {T*1e6:.2f} µK")
print(f"Doppler limit: {doppler_T * 1e6:.2f} µK")

# Static overview figure: the full velocity trace against time in ms
static_fig, static_ax = plt.subplots(figsize=(10, 5))
static_ax.plot(t * 1e3, v_data, label="Velocity")
static_ax.set_xlabel("Time (ms)")
static_ax.set_ylabel("Velocity (m/s)")
static_ax.set_title("Velocity Evolution in 1D Laser Cooling")
static_ax.grid(True)
static_ax.legend()
plt.show()

# Animation figure: position panel on top, velocity panel below
fig, (ax_pos, ax_vel) = plt.subplots(2, 1, figsize=(14, 10))
# The layout is set in exactly one place. Passing hspace through gridspec_kw
# as well would give two competing values for the same spacing.
plt.subplots_adjust(hspace=0.4, top=0.95, bottom=0.08, left=0.1, right=0.95)

# Position subplot (x axis in mm; the y axis is only a stage for the markers)
xlim_range = max(max_x, 0.05) * 1.1
ax_pos.set_xlim(-xlim_range, xlim_range)
ax_pos.set_ylim(-0.7, 0.7)
ax_pos.set_xlabel("Position (mm)")
ax_pos.set_title("Atom Motion in 1D Laser Cooling")
scat = ax_pos.scatter([x_data[0] * 1e3], [0], color='red', s=50, label="Atom")
trail, = ax_pos.plot([], [], 'r-', alpha=0.3, label="Trail")
temp_text = ax_pos.text(0.05, 0.9, '', transform=ax_pos.transAxes, fontsize=12)
# For these horizontal arrows the head length is measured in x data units
# (mm). Tie it to the axis range; the default (1.5 * head_width = 0.15) does
# not scale and overshoots the axes when the range is small.
arrow_dx = xlim_range / 2
arrow_head = xlim_range / 10
ax_pos.arrow(-xlim_range, 0.5, arrow_dx, 0, color='blue', width=0.05,
             head_width=0.1, head_length=arrow_head, label="Left Laser")
ax_pos.arrow(xlim_range, -0.5, -arrow_dx, 0, color='green', width=0.05,
             head_width=0.1, head_length=arrow_head, label="Right Laser")
ax_pos.legend()

# Velocity subplot
ax_vel.set_xlim(0, t_max * 1e3)
# Keep 0.1 m/s of headroom below zero and above the initial velocity, derived
# from the data rather than from a hard-coded starting value.
v_low = min(np.min(v_data), -0.1)
v_high = max(np.max(v_data), v_data[0] + 0.1)
ax_vel.set_ylim(v_low, v_high)
ax_vel.set_xlabel("Time (ms)")
ax_vel.set_ylabel("Velocity (m/s)")
ax_vel.set_title("Velocity Evolution")
vel_line, = ax_vel.plot([], [], 'b-', label="Velocity")
ax_vel.grid(True)
ax_vel.legend()

def init():
    """Put every animated artist into its starting state.

    The atom marker is placed at the first recorded position (in mm), both
    line artists are emptied, and the temperature label is cleared.

    Returns:
        Tuple of the artists ``(scat, trail, temp_text, vel_line)``.
    """
    start_mm = x_data[0] * 1e3
    scat.set_offsets([[start_mm, 0]])
    for line in (trail, vel_line):
        line.set_data([], [])
    temp_text.set_text('')
    return scat, trail, temp_text, vel_line

def animate(i):
    """Update the artists to show the simulation state at sample index ``i``.

    Moves the atom marker, extends the position trail and the velocity curve
    up to and including sample ``i``, and refreshes the running-temperature
    label. Indices past the end of the recorded data leave the artists
    untouched.

    Args:
        i: Index into ``x_data`` / ``v_data`` / ``t`` supplied by
            ``FuncAnimation``.

    Returns:
        Tuple of the artists ``(scat, trail, temp_text, vel_line)``.
    """
    if i >= len(x_data):
        return scat, trail, temp_text, vel_line

    x_mm = x_data[i] * 1e3
    scat.set_offsets([[x_mm, 0]])
    trail.set_data(x_data[:i+1] * 1e3, np.zeros(i+1))

    # Running temperature over at most the 100 most recent samples
    v_window = v_data[max(0, i - 99):i + 1]
    T_running = m * np.mean(v_window**2) / k_B * 1e6  # µK
    temp_text.set_text(f'Temperature: {T_running:.2f} µK')

    vel_line.set_data(t[:i+1] * 1e3, v_data[:i+1])

    if i % (frame_skip * 10) == 0:  # Print every 10th frame
        print(f"Rendering frame {i}/{len(x_data)}, x={x_mm:.4f} mm, v={v_data[i]:.4f} m/s, T={T_running:.2f} µK")

    # No plt.draw() / plt.pause() here: FuncAnimation redraws the figure
    # itself after each callback. Pausing inside the callback re-enters the
    # GUI event loop and slows down every frame, including while saving.
    return scat, trail, temp_text, vel_line

# Animate only every frame_skip-th sample to keep the frame count manageable
frame_skip = 50  # ~800 frames when there are 40000 samples
frames = range(0, len(x_data), frame_skip)
ani = FuncAnimation(fig, animate, init_func=init, frames=frames,
                    interval=30, blit=False)

# Save animation as GIF. This is best effort: a failure is reported and the
# rest of the cell still runs.
print("Saving animation...")
try:
    ani.save('atom_motion.gif', writer='pillow', fps=20)
except Exception as e:
    print(f"Failed to save as GIF: {e}")
else:
    print("Animation saved as atom_motion.gif")

# Try HTML for Jupyter
# Import display explicitly: it is only available as an implicit global
# inside an IPython session, so relying on that fails elsewhere.
from IPython.display import display
try:
    html_ani = HTML(ani.to_jshtml())
    print("HTML animation created - displaying below")
    display(html_ani)
except Exception as e:
    print(f"Failed to create HTML: {e}")

try:
    plt.show()
except Exception as e:
    print(f"Display failed: {e}. Check HTML output.")