In [1]:
using Revise

In [416]:
using Distributed
# Target size of the process pool, as counted by `nprocs()`
procs_to_use = 20

# Top the pool up to the target; nothing is requested once the pool already exceeds it
n_missing_procs = procs_to_use - nprocs()
if n_missing_procs >= 0
    addprocs(n_missing_procs)
end

@everywhere using
    QuantumStates,
    OpticalBlochEquations,
    DifferentialEquations,
    UnitsToValue,
    LinearAlgebra,
    Printf,
    Plots,
    Random,
    StatsBase,
    Distributions,
    StructArrays,
    StaticArrays,
    StructArrays,
    LoopVectorization,
    Parameters,
    MutableNamedTuples

# Physical constants, defined on every process.
# NOTE(review): units are not stated in this file; unit remarks below are assumptions to confirm.
@everywhere @consts begin
    # Wavelength (presumably in metres — TODO confirm)
    λ = 626e-9
    # 2π × 6.4e6 (presumably the transition linewidth in rad/s — TODO confirm);
    # 1/Γ is used as the time unit of the simulation further down
    Γ = 2π* 6.4e6
    # Mass: 57 "u", converted by `@with_unit`
    m = @with_unit 57 "u"
    # Wavenumber for λ; 1/k is used as the length unit of the simulation further down
    k = 2π / λ
    # μ_B / h scaled by 1e-4 (presumably Hz per gauss — TODO confirm)
    _μB = (μ_B / h) * 1e-4
end
;

In [1541]:
# Mutable set of simulation parameters, defined on every process.
# Some of these defaults (B_gradient, s1, s2, voltage, s_ramp_to) are overwritten in the next cell.
# NOTE(review): units are not stated in this file; unit remarks below are assumptions to confirm.
@everywhere sim_params = @params_mnt begin
    
    # B-field parameters during blue MOT
    # B_gradient  : multiplied by 1e2 / k in update_H_and_∇H (presumably given in G/cm — confirm)
    # B_offset    : constant (x, y, z) field added on top of the gradient term
    # B_ramp_time : divided by 1/Γ in update_H_and_∇H to get the ramp duration in solver time units
    B_gradient = 50
    B_offset = (0e-3, 0e-3, 300e-3)
    B_ramp_time = 30e-3
    
    # Laser parameters
    # All of these are forwarded positionally to `define_lasers`; their exact meaning is
    # defined in define_lasers.jl, which is not visible here.
    s1 = 2.18
    s2 = 1.82  
    s_ramp_time = 30e-3
    s_ramp_to = 0.7
    pol_imbalance = 0.01
    s_imbalance = (0.0, 0.05, -0.05)
    retro_loss = 0.02
    off_center = (2, -2, -2, 2, 2, 2) .* 1e-3
    pointing_error = (0,0,0,0,0,0)
    pol1_x = [1,0,0]
    pol2_x = [0,0,1]
    voltage = -1.4
    aom_freq = 53.55
    
    # ODT parameters
    # ODT_intensity, ODT_pol : passed to `get_H_ODT`
    # ODT_size               : (x, y, z) widths of the Gaussian envelope exp(-2 x^2 / w^2)
    #                          evaluated in update_H_and_∇H
    # ODT_position           : trap centre as [x, z]; entry 1 is compared with r[1] and
    #                          entry 2 with r[3] in update_H_and_∇H
    # ODT_revs, ODT_motion_t_start, ODT_motion_t_stop, ODT_rmax : not read in this file;
    #                          presumably consumed by `update_ODT_center!` — confirm
    ODT_intensity = 13.3*1e3/(50e-4)^2
    ODT_size = (30e-6, 2e-3, 30e-6)
    ODT_position = [0.,0.]
    ODT_revs = 6
    ODT_motion_t_start = 0.0
    ODT_motion_t_stop = 120e-3
    ODT_pol = [0,1,0]
    ODT_rmax = 200e-6
    
end
;

In [1542]:
# @everywhere begin
#     # SF cooling parameters
#     sim_params.B_gradient = 0
#     sim_params.s1 = 6.0
#     sim_params.s2 = 0.0
#     sim_params.voltage = 8.0
#     sim_params.s_ramp_to = 1.0
# end

# Overwrite part of the shared parameter set on every process.
# `lasers` is built from `sim_params` in a later cell, so this cell has to run before it.
# The commented-out variant above differs only in `voltage` (8.0 instead of 4.0).
@everywhere begin
    # SF cooling parameters
    sim_params.B_gradient = 0       # zero gradient: only B_offset remains in update_H_and_∇H
    sim_params.s1 = 6.0
    sim_params.s2 = 0.0
    sim_params.voltage = 4.0
    sim_params.s_ramp_to = 1.0
end
;

### Calculate transition dipole moments

In [1543]:
@everywhere begin
    include("define_CaOH_states.jl")
    X_states, A_states = define_CaOH_states()

    # Simulation basis: every X state followed by the first four A states
    states = vcat(X_states, A_states[1:4])

    # Transition dipole moments between all pairs of basis states
    d = tdms_between_states(states, states)
end
;

### Define lasers

In [1544]:
# Build the laser configuration on every process from the current `sim_params`.
# The arguments are positional, so their order must match the signature in define_lasers.jl.
@everywhere begin
    include("define_lasers.jl")
    lasers = define_lasers(
        states,
        sim_params.s1,
        sim_params.s2,
        sim_params.s_ramp_time,
        sim_params.s_ramp_to,
        sim_params.pol_imbalance,
        sim_params.s_imbalance,
        sim_params.retro_loss,
        sim_params.off_center,
        sim_params.pointing_error,
        sim_params.pol1_x,
        sim_params.pol2_x,
        sim_params.voltage,
        sim_params.aom_freq
    )
end
;

### Define Hamiltonian for the ODT-molecule interaction

In [1545]:
# Build the ODT interaction Hamiltonian on every process from the basis states and the
# ODT intensity / polarization stored in `sim_params`.
# NOTE(review): `update_ODT_center!`, `ODT_as` and `ODT_τs` used in later cells are presumably
# also defined by this include — confirm.
@everywhere begin
    include("define_ODT_Hamiltonian.jl")
    H_ODT = get_H_ODT(states, X_states, A_states, sim_params.ODT_intensity, sim_params.ODT_pol)
end
;

### Define Zeeman Hamiltonian

In [1546]:
# Load the Zeeman Hamiltonian definitions on every process.
# NOTE(review): presumably this defines Zeeman_x_mat / Zeeman_y_mat / Zeeman_z_mat, which are
# stored in `extra_data` further down — confirm.
@everywhere include("define_Zeeman_Hamiltonian.jl")
;

### Run simulation

In [1547]:
# Rebuild the Hamiltonian `H` in place for position `r` and time `t`, and return the
# gradient of the Gaussian ODT envelope.
#
# Arguments
#   H : array with separate `.re` / `.im` component arrays; every entry is overwritten.
#   p : parameter object; this function reads `p.sim_params`, `p.extra_data` and `p.k`.
#   r : position; it is compared with ODT coordinates multiplied by `p.k`, i.e. it is
#       expressed in units of 1/k.
#   t : time; it is compared with `B_ramp_time / (1/Γ)`, i.e. it is expressed in units of 1/Γ.
#
# Returns
#   SVector{3, Float64} with the partial derivatives of `scalar_ODT` (the Gaussian envelope
#   of the trap) with respect to r[1], r[2], r[3].
#
# Side effects
#   Overwrites `H`, copies `H_ODT_static` into `p.extra_data.H_ODT`, and calls
#   `update_ODT_center!` (presumably moving `p.sim_params.ODT_position`, which is read
#   immediately afterwards — confirm).
@everywhere function update_H_and_∇H(H, p, r, t)
    
    # Define a ramping magnetic field
    Zeeman_Hz = p.extra_data.Zeeman_Hz
    Zeeman_Hx = p.extra_data.Zeeman_Hx
    Zeeman_Hy = p.extra_data.Zeeman_Hy
    
    # Ramp duration in the same units as `t` (multiples of 1/Γ)
    τ_bfield = p.sim_params.B_ramp_time / (1/Γ)
    # Fraction of the ramp completed, capped at 1. A non-positive ramp time means the
    # field is at full strength from the start; without this guard a zero ramp time
    # gives 0/0 = NaN at t = 0, which would turn every entry of H into NaN.
    scalar = τ_bfield > 0 ? min(t / τ_bfield, 1.0) : 1.0
    
    # Field gradient per unit of r (the 1e2 / k factor rescales B_gradient to units of 1/k);
    # the z gradient has twice the magnitude of the x and y gradients
    gradient_x = -scalar * p.sim_params.B_gradient * 1e2 / k / 2
    gradient_y = +scalar * p.sim_params.B_gradient * 1e2 / k / 2
    gradient_z = -scalar * p.sim_params.B_gradient * 1e2 / k
    
    # Local field: linear gradient term plus constant offset
    Bx = gradient_x * r[1] + p.sim_params.B_offset[1]
    By = gradient_y * r[2] + p.sim_params.B_offset[2]
    Bz = gradient_z * r[3] + p.sim_params.B_offset[3]
    
    # Zeeman part of the Hamiltonian: B · (Zeeman_Hx, Zeeman_Hy, Zeeman_Hz)
    @turbo for i in eachindex(H)
        H.re[i] = Bz * Zeeman_Hz.re[i] + Bx * Zeeman_Hx.re[i] + By * Zeeman_Hy.re[i]
        H.im[i] = Bz * Zeeman_Hz.im[i] + Bx * Zeeman_Hx.im[i] + By * Zeeman_Hy.im[i]
    end
    
    # Update the Hamiltonian for the molecule-ODT interaction
    # (the working copy is reset from the static reference on every call)
    H_ODT = p.extra_data.H_ODT
    H_ODT_static = p.extra_data.H_ODT_static
    @turbo for i in eachindex(H_ODT)
       H_ODT.re[i] = H_ODT_static.re[i]
       H_ODT.im[i] = H_ODT_static.im[i]
    end
    
    # Update the ODT position
    update_ODT_center!(p.sim_params, p.extra_data, t)
    # update_ODT_center_circle!(p.sim_params, p.extra_data, t)
    
    # Trap centre and widths converted to units of 1/k so they can be compared with `r`
    ODT_x = p.sim_params.ODT_position[1] / (1 / p.k)
    ODT_z = p.sim_params.ODT_position[2] / (1 / p.k)
    
    ODT_size = p.sim_params.ODT_size .* p.k
    
    # Gaussian envelope of the trap at `r`; the trap is centred at y = 0
    scalar_ODT = exp(-2(r[1]-ODT_x)^2/ODT_size[1]^2) * exp(-2r[2]^2/ODT_size[2]^2) * exp(-2(r[3]-ODT_z)^2/ODT_size[3]^2)
    
    # Add the position-weighted ODT interaction on top of the Zeeman term
    @turbo for i in eachindex(H)
        H.re[i] += H_ODT.re[i] * scalar_ODT
        H.im[i] += H_ODT.im[i] * scalar_ODT
    end
    
    # Derivative of exp(-2 x^2 / w^2) is (-4 x / w^2) * exp(-2 x^2 / w^2), per component
    ∇H = SVector{3, Float64}(
        (-4(r[1]-ODT_x) / ODT_size[1]^2) * scalar_ODT,
        (-4r[2] / ODT_size[2]^2) * scalar_ODT,
        (-4(r[3]-ODT_z) / ODT_size[3]^2) * scalar_ODT
    )
    
    return ∇H
end
;

In [1548]:
# Auxiliary data passed to the solver parameters as `extra_data` and read back in
# `update_H_and_∇H` through `p.extra_data`. Defined on every process.
@everywhere extra_data = MutableNamedTuple(
    # Zeeman matrices, combined with the field components Bx, By, Bz in update_H_and_∇H
    Zeeman_Hx = Zeeman_x_mat,
    Zeeman_Hy = Zeeman_y_mat,
    Zeeman_Hz = Zeeman_z_mat,
    # Reference copy of the ODT Hamiltonian; update_H_and_∇H only reads it
    H_ODT_static = StructArray(H_ODT),
    # Working copy; update_H_and_∇H resets it from H_ODT_static on every call
    H_ODT = deepcopy(StructArray(H_ODT)),
    # Not read in this file; presumably used by `update_ODT_center!` — confirm
    ODT_as = ODT_as,
    ODT_τs = ODT_τs
)
;

In [None]:
@everywhere begin
    # Size of the basis and number of excited states in it
    n_states = length(states)
    n_excited = 4

    # Integration window; the solver works in units of 1/Γ, hence the rescaling
    t_start, t_end = 0.0, 40e-3
    t_span = (t_start / (1/Γ), t_end / (1/Γ))

    # One particle starting at the origin (positions are in units of 1/k)
    particle = Particle()
    particle.r = (0, 0, 0) ./ (1/k)

    # Initial state vector: all amplitude in the first basis state
    ψ₀ = zeros(ComplexF64, n_states)
    ψ₀[1] = 1.0
end

# Install TerminalLogger as the global logger (presumably so that the solver's
# `progress=true` output below is rendered as a progress bar — confirm).
using Logging: global_logger
using TerminalLoggers: TerminalLogger
global_logger(TerminalLogger())

# Assemble the parameter object for a single particle; the sixth positional argument is
# the mass divided by ħ k^2 / Γ.
p = schrodinger_stochastic(particle, states, lasers, d, ψ₀, m/(ħ*k^2/Γ), n_excited; sim_params=sim_params, extra_data=extra_data, λ=λ, Γ=Γ, update_H_and_∇H=update_H_and_∇H)

# ODE problem over `t_span`, which is expressed in units of 1/Γ
prob = ODEProblem(ψ_stochastic_potential!, p.ψ, t_span, p)

# Callback that runs `SE_collapse_pol_always!` when `condition` crosses zero upward
# (the downward-crossing handler is `nothing`); the callback itself saves no extra points.
# NOTE(review): `condition` and `SE_collapse_pol_always!` are not defined in this file.
cb = ContinuousCallback(condition, SE_collapse_pol_always!, nothing, save_positions=(false,false))
# `saveat=1000` is in the same units as `t_span`, i.e. one saved point every 1000/Γ
@time sol = DifferentialEquations.solve(prob, alg=DP5(), reltol=5e-4, callback=cb, saveat=1000, maxiters=80000000, progress=true, progress_steps=200000)
;

[32mODE   0%|█                                              |  ETA: N/A[39m
[32mODE   2%|█                                              |  ETA: 0:02:25[39m
[32mODE   4%|██                                             |  ETA: 0:02:23[39m
[32mODE   5%|███                                            |  ETA: 0:02:20[39m
[32mODE   7%|████                                           |  ETA: 0:02:17[39m
[32mODE   9%|█████                                          |  ETA: 0:02:14[39m
[32mODE  11%|██████                                         |  ETA: 0:02:13[39m
[32mODE  13%|██████                                         |  ETA: 0:02:10[39m
[32mODE  14%|███████                                        |  ETA: 0:02:07[39m
[32mODE  16%|████████                                       |  ETA: 0:02:04[39m
[32mODE  18%|█████████                                      |  ETA: 0:02:01[39m
[32mODE  20%|██████████                                     |  ETA: 0:01:59[39m
[32mODE  22%|██████

In [None]:
# Saved states and save times of the single-particle run
plot_us, plot_ts = sol.u, sol.t

# The three position components follow the first n_states + n_excited entries of each
# saved state. Dividing by k and multiplying by 1e3 rescales them for plotting
# (presumably from units of 1/k to mm — confirm).
x_trajectories = map(u -> real(u[n_states + n_excited + 1]), plot_us) ./ k .* 1e3
y_trajectories = map(u -> real(u[n_states + n_excited + 2]), plot_us) ./ k .* 1e3
z_trajectories = map(u -> real(u[n_states + n_excited + 3]), plot_us) ./ k .* 1e3
;

In [None]:
# x–z projection of the trajectory inside a square window of half-width `lim`
lim = 0.5
axis_window = (-lim, lim)
plot(x_trajectories, z_trajectories; legend=nothing, xlim=axis_window, ylim=axis_window)

In [None]:
# `n_scatters` counter of the single-particle run divided by the simulated duration `t_end`
prob.p.n_scatters / t_end

### Run simulation for multiple particles in parallel

In [None]:
# Build the ODE problem for one trajectory of the ensemble.
#
# The three arguments are the ones `EnsembleProblem` passes to its `prob_func`; none of
# them is read here. Every call builds a new problem from the global configuration
# (`states`, `sim_params`, `extra_data`, `d`, `t_span`, ...) with a randomly drawn
# starting position.
@everywhere function prob_func(prob, i, repeat)

    # Laser configuration for this trajectory, built from the shared parameters
    trajectory_lasers = define_lasers(
        states,
        sim_params.s1,
        sim_params.s2,
        sim_params.s_ramp_time,
        sim_params.s_ramp_to,
        sim_params.pol_imbalance,
        sim_params.s_imbalance,
        sim_params.retro_loss,
        sim_params.off_center,
        sim_params.pointing_error,
        sim_params.pol1_x,
        sim_params.pol2_x,
        sim_params.voltage,
        sim_params.aom_freq
    )

    # Starting position: three independent normal draws of width `cloud_size`,
    # converted to units of 1/k
    cloud_size = @with_unit 0.05 "mm"
    position_dist = Normal(0, cloud_size)
    molecule = Particle()
    molecule.r = ntuple(_ -> rand(position_dist), 3) ./ (1/k)

    # Initial state vector: all amplitude in the first basis state
    initial_ψ = zeros(ComplexF64, n_states)
    initial_ψ[1] = 1.0

    params = schrodinger_stochastic(molecule, states, trajectory_lasers, d, initial_ψ, m/(ħ*k^2/Γ), n_excited; sim_params=sim_params, extra_data=extra_data, λ=λ, Γ=Γ, update_H_and_∇H=update_H_and_∇H)

    collapse_cb = ContinuousCallback(condition, SE_collapse_pol_always!, nothing, save_positions=(false,false))

    # Solver options are attached to the problem itself, so the ensemble solve picks them up
    return ODEProblem(ψ_stochastic_potential!, params.ψ, t_span, params, callback=collapse_cb, reltol=1e-4, saveat=4000, maxiters=80000000)
end
;

In [None]:
# Ensemble wrapper around the single-particle problem. `prob_func` does not read the
# template `prob`; it builds a fresh problem for every trajectory.
ensemble_prob = EnsembleProblem(prob; prob_func=prob_func)
;

In [None]:
# Number of trajectories to simulate; also used below to size the result arrays
n_molecules = 10
# Solve all trajectories with DP5, using the EnsembleDistributed() ensemble algorithm
@time ensemble_sol = solve(ensemble_prob, DP5(), EnsembleDistributed(); trajectories=n_molecules)
;

In [None]:
# One coordinate history and one time axis per molecule
x_trajectories = Vector{Vector{Float64}}(undef, n_molecules)
y_trajectories = Vector{Vector{Float64}}(undef, n_molecules)
z_trajectories = Vector{Vector{Float64}}(undef, n_molecules)
times = Vector{Vector{Float64}}(undef, n_molecules)

for i ∈ 1:n_molecules
    molecule_sol = ensemble_sol[i]
    saved_states = molecule_sol.u

    # Position components follow the first n_states + n_excited entries of each saved
    # state; they are divided by k and multiplied by 1e3 (presumably 1/k units → mm)
    x_trajectories[i] = map(u -> real(u[n_states + n_excited + 1]), saved_states) ./ k .* 1e3
    y_trajectories[i] = map(u -> real(u[n_states + n_excited + 2]), saved_states) ./ k .* 1e3
    z_trajectories[i] = map(u -> real(u[n_states + n_excited + 3]), saved_states) ./ k .* 1e3

    # Solver times are in units of 1/Γ; multiply by 1/Γ to undo that scaling
    times[i] = molecule_sol.t .* (1/Γ)
end

# Per molecule: the list of [x, y, z] points along its trajectory
trajectories = [
    [[x, y, z] for (x, y, z) ∈ zip(x_trajectories[i], y_trajectories[i], z_trajectories[i])]
    for i ∈ 1:n_molecules
]
;
;

In [None]:
# z coordinate against time for every molecule, overlaid on one figure
lim = 0.3
plot()
for i ∈ 1:n_molecules
    # Time axis scaled by 1e3 for display
    scaled_times = times[i] .* 1e3
    plot!(scaled_times, z_trajectories[i]; legend=nothing, ylim=(-lim, lim))
end
plot!()

In [None]:
# x–z projection of all trajectories, drawn semi-transparently inside a square window
lim = 0.5
axis_window = (-lim, lim)
plot(x_trajectories, z_trajectories; legend=nothing, xlim=axis_window, ylim=axis_window, alpha=0.2)

In [None]:
using Serialization
# serialize("300e-6 radius spiral.jl", ensemble_sol)

In [None]:
# Ensemble average of each molecule's `n_scatters` counter divided by its last saved time
mean(ensemble_sol[i].prob.p.n_scatters ./ times[i][end] for i ∈ 1:n_molecules)

In [None]:
# Display the trap centre currently stored in the parameters of the single-particle problem.
# `captured_in_ODT` calls `update_ODT_center!` on these same parameters, so the value
# depends on which cells ran last.
prob.p.sim_params.ODT_position

In [None]:
# Number of molecules inside the trap at t = 40e-3.
# `captured_in_ODT` is defined in the last cell of this file; run that cell first.
captured_in_ODT(trajectories, times, 40e-3, prob)

In [None]:
# Times at which the captured count is evaluated (same units as the third argument of
# `captured_in_ODT`)
ts = 0:5e-3:20e-3

# `captured_in_ODT` returns an integer count, so use a concretely typed vector instead
# of an untyped `[]` (which would be a Vector{Any})
captured = Int[]

for t ∈ ts
    push!(captured, captured_in_ODT(trajectories, times, t, prob))
end

In [None]:
# Captured count as a function of the evaluation time
plot(ts, captured)

In [None]:
"""
    captured_in_ODT(trajectories, times, t, prob)

Count the trajectories that lie inside the ODT box at time `t`.

# Arguments
- `trajectories`: one entry per molecule, each a list of `[x, y, z]` points.
- `times`: one sorted time axis per molecule, aligned with `trajectories`.
- `t`: time at which to test; it is divided by `1/Γ` before being passed to
  `update_ODT_center!` and compared directly with the entries of `times`.
- `prob`: problem whose `p.sim_params` / `p.extra_data` describe the trap.

# Returns
The number of trajectories whose point at the first saved time `>= t` is within
`ODT_size` of the trap centre along x and z, and within `ODT_size[2]` of zero along y.
If `t` lies beyond the last saved time, the final saved point is used. Empty
trajectories are never counted.

# Notes
- Calls `update_ODT_center!` on `prob.p.sim_params`, so the stored trap position is
  changed as a side effect.
- Trap position and size are multiplied by 1e3 before the comparison, matching the 1e3
  scaling applied when the trajectories were extracted.
"""
function captured_in_ODT(trajectories, times, t, prob)
    sim_params = prob.p.sim_params

    # Move the trap to where it is at time `t` (the update routine takes time in units of 1/Γ)
    update_ODT_center!(sim_params, prob.p.extra_data, t / (1/Γ))
    ODT_position = sim_params.ODT_position
    ODT_size = sim_params.ODT_size

    n = 0
    for (i, trajectory) ∈ enumerate(trajectories)
        isempty(trajectory) && continue

        # `searchsortedfirst` returns lastindex + 1 when `t` is larger than every saved
        # time (easily hit at t == t_end through floating-point rounding), so clamp to
        # the last saved point instead of indexing out of bounds
        traj_idx = min(searchsortedfirst(times[i], t), lastindex(trajectory))
        point = trajectory[traj_idx]

        if (abs(point[1] - ODT_position[1] * 1e3) <= ODT_size[1] * 1e3) &&
            (abs(point[2]) <= ODT_size[2] * 1e3) &&
            (abs(point[3] - ODT_position[2] * 1e3) <= ODT_size[3] * 1e3)
            n += 1
        end
    end
    return n
end
;