In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import constants
from scipy.integrate import odeint
import multiprocess as mp
from tqdm.notebook import tqdm

from pybkit.amo.atom import Yb171

In [None]:
def gaussian_trap_potential(
    trap_depth: float,
    w0: float,
    zR: float,
    x: float, 
    y: float, 
    z: float, 
    dz_x: float = 0,
    dz_y: float = 0
) -> "float | np.ndarray":
    """Potential energy of a focused Gaussian-beam trap at a position relative to its center.

    U(x, y, z) = -trap_depth * exp(-2 x^2 / (w0^2 s_x)) * exp(-2 y^2 / (w0^2 s_y)) / sqrt(s_x s_y)
    with s_i = 1 + (z - dz_i)^2 / zR^2.

    The potential is -trap_depth at the focus and tends to 0 far from it.
    x, y and z broadcast, so any of them may be arrays.

    The arguments dz_x and dz_y shift the x- and y-foci along z independently.
    This models astigmatism in the trap during movement.

    Args:
        trap_depth: depth of the trap; the result has the same (energy) units.
        w0: beam waist, in the same length unit as x and y.
        zR: Rayleigh range, in the same length unit as z.
        x, y, z: position(s) relative to the trap center.
        dz_x, dz_y: axial focus offsets of the x and y beam profiles.

    Returns:
        The potential at (x, y, z): a scalar, or an array of the broadcast input shape.
    """
    # Squared beam-width growth factors (w(z) / w0)^2 for each transverse axis.
    sx = 1 + (z - dz_x)**2 / zR**2
    sy = 1 + (z - dz_y)**2 / zR**2
    qx = np.sqrt(sx)
    qy = np.sqrt(sy)
    expx = np.exp(-2 * x**2 / w0**2 / sx)
    expy = np.exp(-2 * y**2 / w0**2 / sy)
    return -trap_depth * expx * expy / qx / qy

def gaussian_trap_potential_gradient(
    trap_depth: float,
    w0: float,
    zR: float,
    x: float, 
    y: float, 
    z: float, 
    dz_x: float = 0,
    dz_y: float = 0
) -> np.ndarray:
    """Gradient (dU/dx, dU/dy, dU/dz) of `gaussian_trap_potential` at a position.

    Takes the same arguments as `gaussian_trap_potential`, including the astigmatism
    offsets dz_x and dz_y. The force on a particle is minus this gradient.

    Returns:
        np.ndarray of shape (3,) for scalar inputs, or (3, *broadcast_shape) for array inputs.
    """
    U = gaussian_trap_potential(trap_depth, w0, zR, x, y, z, dz_x, dz_y)
    qx = np.sqrt(1 + (z - dz_x)**2 / zR**2)
    qy = np.sqrt(1 + (z - dz_y)**2 / zR**2)
    # dq/dz for q = sqrt(1 + (z - dz)^2 / zR^2).
    dqxdz = 1 / qx * (z - dz_x) / zR**2
    dqydz = 1 / qy * (z - dz_y) / zR**2
    # Transverse derivatives only act on the Gaussian factors.
    dUdx = - U * 4 * x / w0**2 / qx**2
    dUdy = - U * 4 * y / w0**2 / qy**2
    # z enters through q, both in the 1/q prefactor and in the Gaussian width:
    # d ln U / dq = -1/q + 4 x^2 / (w0^2 q^3)  (and likewise for y).
    dUdz = U * (-1 / qx + 4 * x**2 / w0**2 / qx**3) * dqxdz + \
           U * (-1 / qy + 4 * y**2 / w0**2 / qy**3) * dqydz
    return np.array([dUdx, dUdy, dUdz])

def harmonic_trap_potential(trap_depth, radial_freq, axial_freq, x, y, z, mass=None):
    """Harmonic approximation of the trap potential around its center.

    U = -trap_depth + m/2 * [(2 pi f_r)^2 (x^2 + y^2) + (2 pi f_z)^2 z^2]

    Args:
        trap_depth: potential at the trap center is -trap_depth (SI units assumed, J).
        radial_freq: radial trap frequency f_r [Hz] (x and y).
        axial_freq: axial trap frequency f_z [Hz] (z).
        x, y, z: position(s) relative to the center [m]; may be arrays.
        mass: particle mass [kg]; defaults to Yb171.mass when None.

    Returns:
        The potential at (x, y, z), broadcast over array inputs.
    """
    if mass is None:
        mass = Yb171.mass
    radial_potential = 0.5 * mass * (2 * np.pi * radial_freq)**2 * (x**2 + y**2)
    axial_potential = 0.5 * mass * (2 * np.pi * axial_freq)**2 * z**2
    return -trap_depth + radial_potential + axial_potential

def harmonic_trap_potential_gradient(radial_freq, axial_freq, x, y, z, mass=None):
    """Gradient (dU/dx, dU/dy, dU/dz) of `harmonic_trap_potential`.

    The constant offset -trap_depth does not contribute, so trap_depth is not a parameter.

    Args:
        radial_freq: radial trap frequency [Hz].
        axial_freq: axial trap frequency [Hz].
        x, y, z: position(s) relative to the center [m].
        mass: particle mass [kg]; defaults to Yb171.mass when None.

    Returns:
        np.ndarray of shape (3,) for scalar inputs, or (3, ...) for array inputs.
    """
    if mass is None:
        mass = Yb171.mass
    radial_stiffness = mass * (2 * np.pi * radial_freq)**2
    axial_stiffness = mass * (2 * np.pi * axial_freq)**2
    dx_grad = radial_stiffness * x
    dy_grad = radial_stiffness * y
    dz_grad = axial_stiffness * z
    return np.array([dx_grad, dy_grad, dz_grad])

In [None]:
# Trap parameters: beam waist [m], wavelength [m] and depth (10 MHz expressed as an energy).
w0 = 0.5e-6
wavelength = 486e-9
trap_depth = constants.h * 10e6
# Rayleigh range of the focused beam.
zR = np.pi * w0**2 / wavelength

# Small-oscillation (harmonic) trap frequencies of the Gaussian potential, in Hz:
#   omega_r^2 = 4 U / (m w0^2)
#   omega_z^2 = 2 U lambda^2 / (pi^2 m w0^4)   (equivalently 2 U / (m zR^2))
radial_freq = np.sqrt(4 * trap_depth / (Yb171.mass * w0**2)) / (2 * np.pi)
axial_freq = np.sqrt(2 * trap_depth * wavelength**2 / (np.pi**2 * Yb171.mass * w0**4)) / (2 * np.pi)

print(radial_freq, axial_freq, sep='\n')

In [None]:
# Gaussian trap vs. its harmonic approximation along each axis
# (the other two coordinates held at 0).

xs = np.linspace(-10, 10, 1000) * 1e-6
ys = np.linspace(-10, 10, 1000) * 1e-6
zs = np.linspace(-10, 10, 1000) * 1e-6

fig, ax = plt.subplots()

# Full Gaussian potential: solid lines.
ax.plot(xs * 1e6, gaussian_trap_potential(trap_depth, w0, zR, xs, 0, 0) / constants.h / 1e6, color='C0', label='x (Gaussian)')
ax.plot(ys * 1e6, gaussian_trap_potential(trap_depth, w0, zR, 0, ys, 0) / constants.h / 1e6, color='C1', label='y (Gaussian)')
ax.plot(zs * 1e6, gaussian_trap_potential(trap_depth, w0, zR, 0, 0, zs) / constants.h / 1e6, color='C2', label='z (Gaussian)')

# Harmonic approximation: dashed lines in the matching colour, labelled separately
# so the legend distinguishes the two models.
ax.plot(xs * 1e6, harmonic_trap_potential(trap_depth, radial_freq, axial_freq, xs, 0, 0) / constants.h / 1e6, color='C0', ls='--', label='x (harmonic)')
ax.plot(ys * 1e6, harmonic_trap_potential(trap_depth, radial_freq, axial_freq, 0, ys, 0) / constants.h / 1e6, color='C1', ls='--', label='y (harmonic)')
ax.plot(zs * 1e6, harmonic_trap_potential(trap_depth, radial_freq, axial_freq, 0, 0, zs) / constants.h / 1e6, color='C2', ls='--', label='z (harmonic)')

ax.set_xlabel('Displacement [um]')
ax.set_ylabel('Potential [MHz]')
ax.set_title('Gaussian trap vs. harmonic approximation')
ax.set_ylim(-15, 5)
ax.legend();

In [None]:
# Restoring acceleration a = -grad(U) / m of the Gaussian trap along each axis,
# with the other two coordinates held at 0.

xs = np.linspace(-5, 5, 1000) * 1e-6
ys = np.linspace(-5, 5, 1000) * 1e-6
zs = np.linspace(-5, 5, 1000) * 1e-6

fig, ax = plt.subplots()

# On each scan axis the two other gradient components vanish (they are
# proportional to the zero coordinates), so the signed component along the scan
# axis is the full acceleration. Plotting it keeps the restoring direction,
# which a negated norm would hide.
ax.plot(xs * 1e6, -gaussian_trap_potential_gradient(trap_depth, w0, zR, xs, 0, 0)[0] / Yb171.mass, label='x')
ax.plot(ys * 1e6, -gaussian_trap_potential_gradient(trap_depth, w0, zR, 0, ys, 0)[1] / Yb171.mass, label='y')
ax.plot(zs * 1e6, -gaussian_trap_potential_gradient(trap_depth, w0, zR, 0, 0, zs)[2] / Yb171.mass, label='z')

ax.set_xlabel('Displacement [um]')
ax.set_ylabel('Acceleration [$m/s^2$]')
ax.set_title('Restoring acceleration of the Gaussian trap')
ax.legend();

In [None]:
%matplotlib widget

def generate_maxwell_boltzmann_sample(T, radial_freq, axial_freq, num_samples,
                                      sigma_z=None, mass=None, rng=None):
    """Draw thermal (Maxwell-Boltzmann) initial conditions for atoms in a harmonic trap.

    Each velocity component is Gaussian with width sigma_v = sqrt(k_B T / m).
    Each position component is Gaussian with width sigma_v / (2 pi f) for that
    axis's trap frequency f.

    Args:
        T: temperature [K].
        radial_freq: radial trap frequency [Hz]; sets the x and y position widths.
        axial_freq: axial trap frequency [Hz]; sets the z position width.
        num_samples: number of atoms to draw.
        sigma_z: optional override of the axial position width [m]. When None,
            the thermal width from axial_freq is used.
        mass: particle mass [kg]; defaults to Yb171.mass when None.
        rng: optional numpy Generator for reproducible draws. When None, the
            global np.random state is used.

    Returns:
        (position, velocity): two arrays of shape (num_samples, 3), in m and m/s.
    """
    if mass is None:
        mass = Yb171.mass
    sampler = np.random if rng is None else rng
    sigma_v = np.sqrt(constants.k * T / mass)
    sigma_r = sigma_v / (radial_freq * 2 * np.pi)
    if sigma_z is None:
        sigma_z = sigma_v / (axial_freq * 2 * np.pi)
    position = np.array([
        sampler.normal(loc=0, scale=sigma_r, size=num_samples),
        sampler.normal(loc=0, scale=sigma_r, size=num_samples),
        sampler.normal(loc=0, scale=sigma_z, size=num_samples)])
    velocity = sampler.normal(loc=0, scale=sigma_v, size=(3, num_samples))
    return position.T, velocity.T

ts = np.linspace(0, 20, int(1e4)) * 1e-6
dt = ts[1] - ts[0]
print(f'dt = {dt/1e-9: .3f} ns')

# Trap-depth modulation frequencies to compare; 0 gives an unmodulated trap.
trapmod_freqs = [radial_freq, 2*radial_freq, 0]

# Potential used both for the dynamics AND for the final-energy histogram, so the
# two stay consistent. Set to True to use the full Gaussian potential instead of
# its harmonic approximation.
use_gaussian_trap = False

# Seed the global RNG so the sampled atoms (and later stochastic cells) are reproducible.
np.random.seed(0)


def square_wave_factor(f_mod, t):
    """Trap-depth multiplier 1 + sign(sin(2 pi f_mod t)).

    Switches between 0 and 2 (exactly 1 at zero crossings), so its time average is 1.
    For f_mod == 0 it is identically 1 (static trap). Works on scalar or array t.
    """
    return 1 + np.sign(np.sin(2 * np.pi * f_mod * t))


def trap_gradient(x, y, z):
    """Gradient of the selected (unmodulated) trap potential."""
    if use_gaussian_trap:
        return gaussian_trap_potential_gradient(trap_depth, w0, zR, x, y, z)
    return harmonic_trap_potential_gradient(radial_freq, axial_freq, x, y, z)


def trap_potential(x, y, z):
    """Selected (unmodulated) trap potential."""
    if use_gaussian_trap:
        return gaussian_trap_potential(trap_depth, w0, zR, x, y, z)
    return harmonic_trap_potential(trap_depth, radial_freq, axial_freq, x, y, z)


fig, ax = plt.subplots(figsize=(8, 8), nrows=len(trapmod_freqs), ncols=4)

p0, v0 = generate_maxwell_boltzmann_sample(
    T=5e-6,
    radial_freq=radial_freq,
    axial_freq=axial_freq,
    num_samples=50)
num_atoms = p0.shape[0]

# Pool(0) is invalid, so always keep at least one worker (single-core machines).
num_workers = max(1, mp.cpu_count() - 1)

min_E, max_E = np.inf, -np.inf

for i, f_trapmod in enumerate(trapmod_freqs):

    # f_mod / rhs are bound as defaults at definition time, so worker processes
    # use this row's frequency rather than whatever the global loop variable holds.
    def derivative(t, var, f_mod=f_trapmod):
        x, y, z, vx, vy, vz = var
        acc = -square_wave_factor(f_mod, t) * trap_gradient(x, y, z) / Yb171.mass
        return [vx, vy, vz, *acc]

    def _simulate(sample_idx, rhs=derivative):
        y0 = np.concatenate([p0[sample_idx], v0[sample_idx]])
        return odeint(func=rhs, y0=y0, t=ts, tfirst=True, hmax=dt)

    with mp.Pool(num_workers) as pool:
        sol_list = list(tqdm(pool.imap(_simulate, range(num_atoms)), total=num_atoms))

    # Column 0: trap depth actually applied in the dynamics (0 .. 2x depth),
    # plus a sinusoid at the radial trap frequency for reference.
    ax[i, 0].set_title(r'$f_\mathrm{trapmod}$' + f' = {f_trapmod/1e3:.0f} kHz')
    ax[i, 0].plot(ts * 1e6, trap_depth / constants.h / 1e6 * square_wave_factor(f_trapmod, ts))
    ax[i, 0].plot(ts * 1e6, trap_depth / constants.h / 1e6 * 0.5 * (1 + np.sin(2 * np.pi * radial_freq * ts)))
    ax[i, 0].set_ylabel('Depth [MHz]')

    # Columns 1-2: radial distance and axial position of every atom.
    for sol in sol_list:
        position = sol[:, :3]
        ax[i, 1].plot(ts * 1e6, np.hypot(position[:, 0], position[:, 1]) * 1e6, color='C0', alpha=0.05)
        ax[i, 2].plot(ts * 1e6, position[:, 2] * 1e6, color='C0', alpha=0.05)
    ax[i, 1].set_ylabel('radial [um]')
    ax[i, 2].set_ylabel('axial [um]')
    ax[i, 2].set_ylim(-0.3, 0.3)

    # Column 3: histogram of each atom's total energy at the final time, evaluated
    # in the static (unmodulated) version of the same potential used for the dynamics.
    final_state = np.array([sol[-1] for sol in sol_list])
    p_final, v_final = final_state[:, :3], final_state[:, 3:]
    kinetic_energy = 0.5 * Yb171.mass * np.sum(v_final**2, axis=1)
    potential_energy = trap_potential(p_final[:, 0], p_final[:, 1], p_final[:, 2])
    energy = np.asarray(kinetic_energy + potential_energy, dtype=float) / (constants.h * 1e6)
    min_E = min(min_E, energy.min())
    max_E = max(max_E, energy.max())
    ax[i, 3].hist(energy, bins=30)
    ax[i, 3].set_ylabel('Counts')

# Shared energy axis so the histograms of different rows are comparable.
for row in ax:
    row[3].set_xlim(min_E, max_E)

for col in range(3):
    ax[-1, col].set_xlabel('Time [us]')
ax[-1, 3].set_xlabel('Energy [MHz]')
fig.tight_layout()

In [None]:
def generate_maxwell_boltzmann_sample(T, radial_freq, axial_freq, num_samples,
                                      sigma_z=0.5e-6, mass=None, rng=None):
    """Draw thermal initial conditions with a fixed (non-thermal) axial position width.

    This redefinition shadows the thermal-width version defined earlier in the
    notebook. Velocities and radial positions are still Maxwell-Boltzmann
    distributed: sigma_v = sqrt(k_B T / m) and sigma_r = sigma_v / (2 pi f_r).
    The axial position width is sigma_z, a previously hard-coded constant that
    is now a parameter. axial_freq is accepted for signature compatibility but
    is not used.

    Args:
        T: temperature [K].
        radial_freq: radial trap frequency [Hz].
        axial_freq: unused; kept so calls match the thermal version.
        num_samples: number of atoms to draw.
        sigma_z: axial position width [m] (default 0.5 um).
        mass: particle mass [kg]; defaults to Yb171.mass when None.
        rng: optional numpy Generator for reproducible draws. When None, the
            global np.random state is used.

    Returns:
        (position, velocity): two arrays of shape (num_samples, 3), in m and m/s.
    """
    if mass is None:
        mass = Yb171.mass
    sampler = np.random if rng is None else rng
    sigma_v = np.sqrt(constants.k * T / mass)
    sigma_r = sigma_v / (radial_freq * 2 * np.pi)
    position = np.array([
        sampler.normal(loc=0, scale=sigma_r, size=num_samples),
        sampler.normal(loc=0, scale=sigma_r, size=num_samples),
        sampler.normal(loc=0, scale=sigma_z, size=num_samples)])
    velocity = sampler.normal(loc=0, scale=sigma_v, size=(3, num_samples))
    return position.T, velocity.T

# Draw one atom at 5 uK and unpack it into 1-D initial position / velocity vectors.
p0, v0 = generate_maxwell_boltzmann_sample(
    T=5e-6,
    radial_freq=radial_freq,
    axial_freq=axial_freq,
    num_samples=1,
)
p0 = p0[0]
v0 = v0[0]
# For a deterministic start instead, e.g.:
#   p0 = np.array([20, 0, 100]) * 1e-9
#   v0 = np.array([0, 0, 0])

In [None]:
%matplotlib widget

# Number of output time steps (not atoms) for the single-atom trajectory.
num_steps = int(1e6)
ts = np.linspace(0, 50, num_steps) * 1e-3
dt = ts[1] - ts[0]
print(f'dt = {dt/1e-9: .3f} ns')

# Sinusoidal trap-depth modulation frequencies [Hz]; 0 = static trap.
trapmod_freqs = [300e3, 1e6, 0]

# White acceleration noise [m/s^2]: one sample per output time step, held
# constant between steps. The same realisation is reused for every modulation
# frequency, so the runs differ only in the modulation.
accel_noise_std = 10
noise = np.random.normal(0, accel_noise_std, size=(3, num_steps))

xs_arr, ys_arr, zs_arr = [], [], []
vxs_arr, vys_arr, vzs_arr = [], [], []

for f_trapmod in trapmod_freqs:

    print(f_trapmod)

    # f_mod is bound at definition time instead of being read from the global loop variable.
    def derivative(t, var, f_mod=f_trapmod):
        x, y, z, vx, vy, vz = var
        # Noise sample of the first output time >= t, clamped to the last sample.
        tidx = min(np.searchsorted(ts, t), num_steps - 1)
        depth_factor = (1 + np.sin(2 * np.pi * f_mod * t)) if f_mod != 0 else 1
        acc = -depth_factor * gaussian_trap_potential_gradient(trap_depth, w0, zR, x, y, z) / Yb171.mass + noise[:, tidx]
        return [vx, vy, vz, *acc]

    sol = odeint(func=derivative, y0=np.array([*p0, *v0]), t=ts, tfirst=True, hmax=dt)

    # sol columns are x, y, z, vx, vy, vz.
    for store, column in zip((xs_arr, ys_arr, zs_arr, vxs_arr, vys_arr, vzs_arr), sol.T):
        store.append(column)

In [None]:
def trim(arr, N):
    """Drop trailing elements of ``arr`` so its length is a multiple of ``N``.

    Raises:
        ValueError: if N is not positive (the original failed with ZeroDivisionError for N=0).
    """
    if N <= 0:
        raise ValueError(f'N must be a positive integer, got {N}')
    return arr[:len(arr) - (len(arr) % N)]

def _group_reduce(arr, N, reducer):
    """Apply ``reducer`` over consecutive, non-overlapping groups of N samples.

    The incomplete trailing group (len(arr) % N elements) is discarded.
    """
    grouped = np.asarray(trim(arr, N)).reshape(-1, N)
    return reducer(grouped, axis=1)

def group_max(arr, N):
    """Maximum of each consecutive group of N samples (trailing remainder dropped)."""
    return _group_reduce(arr, N, np.max)

def group_min(arr, N):
    """Minimum of each consecutive group of N samples (trailing remainder dropped)."""
    return _group_reduce(arr, N, np.min)

def plot(n=1):
    """Plot the peak-to-peak oscillation amplitude of each simulated trajectory over time.

    Each trajectory in xs_arr / ys_arr / zs_arr (one per entry of trapmod_freqs)
    is cut into consecutive windows of ``n`` trap periods. x and y use the radial
    period and z uses the axial period. The max-minus-min excursion of each
    window is plotted at the window's start time.

    Top row: change of the peak-to-peak amplitude relative to the first window.
    Bottom row: the peak-to-peak amplitude itself.

    Reads the globals radial_freq, axial_freq, dt, ts, trapmod_freqs and
    xs_arr / ys_arr / zs_arr produced by the simulation cell.

    Args:
        n: window length in trap oscillation periods (may be fractional).
            Defaults to 1, i.e. one point per oscillation.

    Raises:
        ValueError: if n is not positive.
    """
    if n <= 0:
        raise ValueError(f'n must be positive, got {n}')

    radial_period_step = int(1 / radial_freq / dt)
    axial_period_step = int(1 / axial_freq / dt)
    # A window needs at least one sample; int() alone can give 0 for small n.
    N_radial = max(1, int(n * radial_period_step))
    N_axial = max(1, int(n * axial_period_step))

    # Start time [ms] of each complete window.
    t_radial = trim(ts, N_radial)[::N_radial] * 1e3
    t_axial = trim(ts, N_axial)[::N_axial] * 1e3

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10, 4), sharex=True)

    for i, f_mod in enumerate(trapmod_freqs):
        xpp = group_max(xs_arr[i], N_radial) - group_min(xs_arr[i], N_radial)
        ypp = group_max(ys_arr[i], N_radial) - group_min(ys_arr[i], N_radial)
        zpp = group_max(zs_arr[i], N_axial) - group_min(zs_arr[i], N_axial)
        ax[0, 0].plot(t_radial, 1e9 * (xpp - xpp[0]), 'o-', ms=3, label=f'{f_mod/1e3} kHz')
        ax[0, 1].plot(t_radial, 1e9 * (ypp - ypp[0]), 'o-', ms=3)
        ax[0, 2].plot(t_axial, 1e9 * (zpp - zpp[0]), 'o-', ms=3)
        ax[1, 0].plot(t_radial, 1e9 * xpp, 'o-', ms=3)
        ax[1, 1].plot(t_radial, 1e9 * ypp, 'o-', ms=3)
        ax[1, 2].plot(t_axial, 1e9 * zpp, 'o-', ms=3)

    for col, axis_name in enumerate(['x', 'y', 'z']):
        ax[0, col].set_title(axis_name)
        ax[1, col].set_xlabel('Time [ms]')
    ax[0, 0].set_ylabel('Peak-peak delta [nm]')
    ax[1, 0].set_ylabel('Peak-peak amplitude [nm]')

    # Only the top-left panel carries labels, so the figure legend lists each frequency once.
    fig.legend(loc="upper center", ncol=len(trapmod_freqs))
    fig.tight_layout(rect=[0, 0, 1, 0.93])

In [None]:
# Peak-to-peak amplitude with one trap period per window.
plot(n=1)

In [None]:
# Same diagnostic with 10-period windows: fewer points, each spanning more oscillations.
plot(n=10)

In [None]:
# Trap potential [MHz] at a point displaced 20 nm in x and y and 100 nm in z.
gaussian_trap_potential(
    trap_depth, w0, zR,
    x=20e-9, y=20e-9, z=100e-9,
) / constants.h / 1e6

In [None]:
# Total mechanical energy of the single-atom trajectory for one modulation frequency.
# The Gaussian potential matches the one used in the noisy simulation's dynamics.
i = 2
f_mod = trapmod_freqs[i]
print(f_mod / 1e3)

Ks_radial = 0.5 * Yb171.mass * (vxs_arr[i]**2 + vys_arr[i]**2)
Ks_axial = 0.5 * Yb171.mass * vzs_arr[i]**2
Us = gaussian_trap_potential(trap_depth, w0, zR, xs_arr[i], ys_arr[i], zs_arr[i])
Es_total = Ks_axial + Ks_radial + Us
Es = Es_total / constants.h / 1e6

fig, ax = plt.subplots()
# Label with this run's frequency, not the leftover loop variable f_trapmod.
ax.plot(ts * 1e3, Es, label=f'{f_mod / 1e3} kHz', c=f'C{i}')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Total energy [MHz]')
ax.legend();