In [None]:
using Revise

In [None]:
import QuantumCollocation as QC
import NamedTrajectories as NT
import TrajectoryIndexingUtils as NTidx
import Plots
import LinearAlgebra as LA
import SparseArrays as SA
import ForwardDiff as FD
using LaTeXStrings

In [None]:
include("utils.jl")
include("system.jl")
include("objectives.jl")
include("constraints.jl")

In [None]:
# Build the shaken-lattice system. V and p_max are the two positional arguments of
# ShakenLatticeSystem1D (presumably lattice depth and momentum cutoff; the constructor
# is not part of this file -- confirm against the included system definition).
V = 10.
p_max = 5
# E_R [kHz] found in Weidner thesis
system = ShakenLatticeSystem1D(V, p_max; acc=0.0, include_acc_derivative=true, sparse=false)
# middle index of statevector where p = 0
mid = system.params[:mid]
dim = system.params[:dim]
#E_R = system.params[:E_R]
#E_R = 1/0.05
#print("time unit $(1/E_R) ms\nE_R = $E_R kHz")

In [None]:
# Display the real part of the drift Hamiltonian stored on the system.
system.H_drift_real

In [None]:
# Uniform time grid with T knot points spanning `duration`.
duration = 2pi * 1.0 # in units of 1/E_R

T = 301
dt = duration / (T-1)
# One identical step per knot point.
dts = zeros(T) .+ dt
# Equal lower and upper bound, i.e. the time step is pinned to dt.
dt_bound = (dt, dt)
# Start time of every step: 0, dt, 2dt, ...
times = cumsum(dts) - dts;

In [None]:
# Per-component bounds for the two control components.
a_bound = fill(1., 2)
dda_bound = fill(1000., 2)

# Alternative initial guesses kept for reference:
#a = vcat(ones(T)', zeros(T)') 
#a = vcat(cos.(2pi * times/duration)', sin.(2pi * times/duration)')
#a = vcat(cos.(1. .+ 0.5*cos.(2π* 12. *times))', sin.(1. .+ 0.5*cos.(2π* 12. *times))')

# phi_guess = 0.5 * (cos.(2pi * 4. * times) + cos.(2pi * 12. * times))
# a = vcat(cos.(phi_guess)', sin.(phi_guess)')

# Active guess: a = (cos(phi), sin(phi)) with phi(t) = 2 sin(11.5 t); a is 2 x T.
a = vcat(cos.(2. *sin.(11.5 *times))', sin.(2. *sin.(11.5 *times))')

# 1 x T row matrix holding the grid times.
acc = collect(times')

# First and second time derivatives of the controls, as computed by NT.derivative.
da = NT.derivative(a, dts)
# NOTE(review): `da[end, :]` addresses the last ROW (second control component) at all
# time steps, whereas the commented-out line below addresses the last COLUMN (final
# time step) -- confirm which one is intended.
da[end, :] .= 1.
# da[:, end] = da[:, end-1]
dda = NT.derivative(da, dts)
# dda[:, end] = dda[:, end-1] = dda[:, end-2]
# NOTE(review): same row-vs-column question as for `da` above.
dda[end, :] .= 0.;

# Previously saved trajectories for the two interferometer segments (split, mirror).
Z_split = NT.load_traj("interferometer/split_victor.jld2")
Z_mirror = NT.load_traj("interferometer/mirror_victor.jld2")

In [None]:
# Reload the split/mirror trajectories from the "_opt2" files; this replaces any
# Z_split / Z_mirror loaded earlier.
Z_split = NT.load_traj("interferometer/split_victor_opt2.jld2")
Z_mirror = NT.load_traj("interferometer/mirror_victor2_opt2.jld2")

In [None]:
# Combine both segments into a single control matrix and step vector
# (get_interferometer is defined outside this file).
a, dts = get_interferometer(Z_split, Z_mirror, [])

In [None]:
# Alternative starting point: a previously saved full trajectory.
Z = NT.load_traj("./interferometer/save.jld2")

In [None]:
# Take controls and time steps from the loaded trajectory (overrides the cell above).
a = Z.a
dts = vec(Z.dts)

In [None]:
Z_split.T

In [None]:
# Set the controls to (1, 0) from column 176 onwards.
# NOTE(review): if `Z.a` is a view into the trajectory data this also mutates Z -- confirm.
a[:,176:end] .= [1., 0.]

In [None]:
dts

In [None]:
Plots.plot(dts)

In [None]:
# Rebuild a uniform grid with the same number of points as the loaded steps.
duration = sum(dts) # in units of 1/E_R

T = length(dts)
# NOTE(review): duration is the sum over all T steps but is divided by T-1, so the
# new uniform steps sum to duration * T/(T-1) -- confirm this is intended.
dt = duration / (T-1)
dts = zeros(T) .+ dt
dt_bound = (dt, dt)
# Start time of every step, beginning at 0.
times = cumsum(dts) - dts;

In [None]:
# Per-component bounds for the two control components.
a_bound = fill(1., 2)
dda_bound = fill(1000., 2)

# 1 x T row matrix holding the grid times.
acc = collect(times')

# First and second time derivatives of the current controls, as computed by NT.derivative.
da = NT.derivative(a, dts)
# da[:, end] = da[:, end-1]
dda = NT.derivative(da, dts)
# dda[:, end] = dda[:, end-1] = dda[:, end-2]

In [None]:
acc

In [None]:
# Control components versus time.
Plots.plot(times, a')

In [None]:
# Second derivative of the controls versus time.
Plots.plot(times, dda')

In [None]:
Plots.plot(times)

In [None]:
# Phase of the complex number a1 + i*a2 at every time step, then cleaned in place by
# phi_mod_clean! (defined outside this file; presumably removes 2pi wrap-arounds).
phi_guess = angle.(a[1,:] + im*a[2,:])
phi_mod_clean!(phi_guess)

In [None]:
Plots.plot(times, phi_guess)

In [None]:
# Initial state: Bloch state with lvl=0 of the system.
psi0 = get_bloch_state(system; lvl=0)
# For an accelerated system the state vector is extended by `dim` zeros
# (presumably the slot for the derivative of the state w.r.t. acceleration -- confirm).
if system.params[:accelerated]
    append!(psi0, zeros(dim))
end

In [None]:
# Populations of the first `dim` entries, labelled -p_max:p_max.
Plots.bar(-p_max:p_max, abs2.(psi0[1:dim]))

In [None]:
# Real-valued (isomorphic) representation of the initial ket.
psi0_iso = QC.ket_to_iso(psi0)

In [None]:
# Matrix whose columns are the Bloch states for lvl = 0 .. dim-1.
bloch_states = hcat([get_bloch_state(system; lvl=i) for i=0:dim-1]...)

In [None]:
time_flight = 2pi * 2.

In [None]:
Z_mirror.T

In [None]:
# Each jump is a (time index, duration) pair; `cuts` collects the time indices.
jumps = [(Z_split.T, time_flight), (Z_split.T+Z_mirror.T, time_flight)]
#jumps = Tuple{Int, Float64}[]
cuts = [jump[1] for jump in jumps]
full_times = get_times(dts, jumps)
G = get_shaken_lattice_propagator(system, times, jumps, 10000)

# NOTE(review): this second definition overrides the jumps/cuts/full_times/G computed
# just above with hard-coded indices 100 and 200.
jumps = [(100, time_flight), (200, time_flight)]
#jumps = Tuple{Int, Float64}[]
cuts = [jump[1] for jump in jumps]
full_times = get_times(dts, jumps)
G = get_shaken_lattice_propagator(system, times, jumps, 10000)

In [None]:
G

In [None]:
# Convert every entry of G from its iso-operator form to a complex operator.
U = [QC.iso_vec_to_operator(QC.iso_operator_to_iso_vec(g)) for g in G]

In [None]:
# Square root of the magnitudes, shown with a flipped y axis.
Plots.heatmap(sqrt.(abs.(U[1])), yflip=true)

In [None]:
Plots.heatmap(sqrt.(abs.(U[2])), yflip=true)

In [None]:
# Rollout of the initial state under the controls, including the jumps.
psi_iso = shaken_lattice_rollout(psi0_iso, a, dts, system, jumps, G)

# NOTE(review): this plain rollout (no jumps) overwrites the result of the line above.
psi_iso = QC.rollout(psi0_iso, a, dts, system)#; integrator=exp)

In [None]:
bloch_states

In [None]:
# Adjoint of the block-diagonal matrix built from two copies of bloch_states.
B = blockdiagonal(bloch_states, bloch_states)'

In [None]:
# Complex kets for every time step (one per column), transformed with B, then populations.
psi = hcat([QC.iso_to_ket(psi_iso[:,t]) for t=1:T]...)
psi = B * psi
pops = abs2.(psi)

In [None]:
"""
    format_plot(p, times=times, jumps=jumps, full_times=full_times)

Decorate the plot `p` in place and return it.

- The x ticks are placed at roughly every tenth entry of `times` and labelled with the
  corresponding entries of `full_times`, rounded to one digit.
- The x limits are set to the first and last entry of `times`.
- For every `(cut, jump_time)` pair in `jumps`, a red vertical line is drawn at
  `times[cut]` and the rounded `jump_time` is written next to it near the top.

The defaults of `times`, `jumps` and `full_times` are the globals of the same name.
"""
function format_plot(
    p,
    times=times,
    jumps=jumps,
    full_times=full_times,
)
    T = length(times)
    # Aim for about ten ticks; the stride is clamped to 1 so that grids with fewer than
    # ten points do not produce a zero-step range.
    stride = max(1, div(T, 10))
    tick_labels = string.(round.(full_times[1:stride:end]; digits=1))
    Plots.xticks!(p, (times[1:stride:end], tick_labels))
    Plots.xlims!(p, (times[1], times[end]))
    rel_y = 0.98
    for (cut, jump_time) in jumps
        # Relative position inside the x limits set above (valid for any times[1]).
        rel_x = (times[cut] - times[1]) / (times[end] - times[1])
        Plots.vline!(p, [times[cut]], color=:red, label=nothing)
        # Annotate `p` explicitly instead of whatever plot happens to be current.
        Plots.annotate!(p, ((rel_x, rel_y), (string(round(jump_time; digits=1)), 10, :red)))
    end
    Plots.xlabel!(p, L"t \, [1/\omega_R]")
    return p
end

In [None]:
# Populations in the first `dim` rows of `pops` versus time, one labelled line per row.
p = Plots.plot(times, pops[1:dim,:]', labels=(-p_max:1:p_max)', legend=:outertopright, size=(500,300))
Plots.plot!(
    p,
    title=L"Bloch state population evolution $|B\rangle$"
)
format_plot(p)

In [None]:
# Same plot for the remaining rows of `pops`.
Plots.plot(times, pops[dim+1:end,:]', labels=(-p_max:1:p_max)', legend=:outertopright)

In [None]:
# Components of the trajectory: state, controls, control derivatives and time steps.
comps = (
    psi_iso = psi_iso,
    a = a,
    acc = acc,
    da = da,
    dda = dda,
    dts = dts
)
# Values fixed at the first time step.
initial = (
    psi_iso = psi0_iso,
    a = [1.; 0.],
    #da = da0
)
# Values fixed at the last time step.
final = (;
    a = [1.; 0.],
    #da = zeros(2)
)
goal = (;)
# Bounds applied to whole components (the bound on :a is disabled).
bounds = (
    #a = a_bound,
    dda = dda_bound,
    dts = dt_bound,
)

# Drop the previous trajectory and run the garbage collector before building a new one.
Z_guess = nothing
GC.gc()
# :dda is the control component and :dts the time-step component.
Z_guess = NT.NamedTrajectory(
    comps;
    controls=(:dda),
    timestep=:dts,
    bounds=bounds,
    initial=initial,
    final=final,
    goal=goal
)

In [None]:
# NOTE(review): this replaces the Z_guess constructed above by a copy of Z.
Z_guess = copy(Z)

# Target populations: those of the lvl=0 Bloch state.
pops_goal = abs2.(get_bloch_state(system; lvl=0))

dim = system.params[:dim]
# Objective on components 1:dim of :psi_iso at time index T, with constant-in-time
# targets pops_goal compared against abs2 of the state (NameComponentPathObjective is
# defined outside this file; argument meaning inferred from the call -- confirm).
J = NameComponentPathObjective(
    :psi_iso,
    1:dim,
    [T],
    [time -> pop for pop in pops_goal],
    [x -> abs2.(x) for _=1:dim],
    fill(1., dim, 1);
    is_iso=true
)

In [None]:
# Normalization constant used by Fisher: (8*pi*(t/2)^2)^2 with t the last entry of
# full_times (presumably a Mach-Zehnder reference value -- confirm).
MZFI = (8pi*(full_times[end]/2)^2)^2

In [None]:
"""
    Fisher(psi, dpsi; reg=0.0, normalization=MZFI)

Return `sum(D.^2 ./ (P .+ reg)) / normalization`, where `P = abs2.(psi)` and
`D = 2 * real.(conj.(psi) .* dpsi)`.

If `dpsi` is the derivative of `psi` with respect to a parameter, `D` is the derivative
of the populations `P`, so the sum has the form of a classical Fisher information of the
population distribution.

# Keywords
- `reg`: non-negative offset added to the populations in the denominator. With the
  default `0.0`, an entry with exactly zero population evaluates to `0/0 = NaN`.
- `normalization`: value the result is divided by; defaults to the global `MZFI`.
"""
function Fisher(psi, dpsi; reg=0.0, normalization=MZFI)
    P = abs2.(psi)
    # d|psi|^2 = 2 Re(conj(psi) * dpsi)
    D = 2*real.(conj.(psi) .* dpsi)
    F = (1 ./ (P .+ reg))' * D.^2
    return F / normalization
end 


In [None]:
# Loss on a stacked vector: entries 1:dim are passed to Fisher as psi, entries
# dim+1:2*dim as dpsi; the sign is flipped so that minimizing maximizes Fisher.
fisher_loss = psi_dpsi -> -Fisher(psi_dpsi[1:dim], psi_dpsi[dim+1:2*dim])

In [None]:
# Fisher value (sign flipped back) at the last time step of the current rollout.
-fisher_loss(QC.iso_to_ket(psi_iso[:,end]))

In [None]:
# -0.5 * log10(Fisher) of the same stacked vector.
log_sensitivity_loss = psi_dpsi -> -0.5 * log10(Fisher(psi_dpsi[1:dim], psi_dpsi[dim+1:2*dim]))

In [None]:
log_sensitivity_loss(QC.iso_to_ket(psi_iso[:,end]))

In [None]:
# Fisher value at every time step, plotted with a red vertical line at each cut.
fisher_evol = [-fisher_loss(QC.iso_to_ket(psi_iso[:,t])) for t=1:T];
p = Plots.plot(times, fisher_evol)
Plots.vline!(p, times[cuts], color=:red, label="cut")

# Scan over the free-flight duration: for each value, rebuild the jumps and propagators,
# roll out the current controls and record the loss at the final time step.
flight_times = 2pi * collect(1:40)

# Result container; it has to exist before the loop pushes into it.
fishers = Float64[]
# NOTE(review): in a notebook/REPL the loop body reassigns the globals jumps, cuts,
# full_times, G and psi_iso, so later cells see the values of the last iteration.
for flight_time in flight_times
    println(flight_time)
    jumps = [(211, flight_time), (632, flight_time)]
    cuts = [jump[1] for jump in jumps]
    full_times = get_times(dts, jumps)
    G = get_shaken_lattice_propagator(system, times, jumps, 10000)
    psi_iso = shaken_lattice_rollout(psi0_iso, a, dts, system, jumps, G)
    # fisher_loss slices a complex ket into (psi, dpsi), so the iso-vector returned by
    # the rollout is converted first, as at every other call site.
    push!(fishers, fisher_loss(QC.iso_to_ket(psi_iso[:,end])))
end

In [None]:
# Recorded values (sign flipped) next to the reference curve (4pi)^2 * t^4.
# NOTE(review): Fisher divides by MZFI while the reference curve is not normalized --
# confirm that the two curves are meant to be comparable.
Plots.plot(flight_times, hcat(-fishers, (4pi)^2 * flight_times.^4))

In [None]:
# Earlier objective variants kept for reference:
#J = QC.QuantumObjective(name=:psi_iso, goals=QC.ket_to_iso(vcat(get_bloch_state(system; lvl=0), zeros(dim))), loss=:InfidelityLoss, Q=1e2)
#J = QC.QuantumObjective(name=:psi_iso, goals=QC.ket_to_iso(get_bloch_state(system; lvl=3)), loss=:InfidelityLoss, Q=1e2)

# J += NameComponentObjective(
#     :psi_iso,
#     [1:2*dim...],
#     [1:T...],
#     fisher_loss,
#     ([1:T...] ./ T).^2;
#     is_iso=true
# )

# Active objective: fisher_loss on components 1:2*dim of :psi_iso at time index T.
# This replaces any J defined earlier.
J = NameComponentObjective(
    :psi_iso,
    [1:2*dim...],
    [T],
    fisher_loss;
    is_iso=true
)

# Add a quadratic regularizer on :dda with weight 1e-8/T.
J += QC.QuadraticRegularizer(:dda, Z_guess, 1e-8/T)

In [None]:
# Objective value at the initial guess.
J.L(Z_guess.datavec, Z_guess)

In [None]:
# Drop the previous integrators and run the garbage collector before rebuilding.
integrators = nothing
GC.gc()
integrators = [
    # Order-4 Pade integrator for :psi_iso, driven by (:a, :acc) with time step :dts.
    QC.QuantumStatePadeIntegrator(
        system,
        :psi_iso,
        (:a, :acc),
        :dts;
        order=4
    ),
    # Links :a to :da (presumably enforcing da as the time derivative of a).
    QC.DerivativeIntegrator(
        :a,
        :da,
        :dts,
        Z_guess
    ),
    # Links :da to :dda in the same way.
    QC.DerivativeIntegrator(
        :da,
        :dda,
        :dts,
        Z_guess
    )
]

In [None]:
# Dynamics built from the integrators; the cut indices are passed via `cuts`.
dynamics = QC.QuantumDynamics(
    integrators,
    Z_guess;
    cuts=cuts
)

In [None]:
# Constraint list; the constructors below are defined outside this file, so the
# descriptions are inferred from names and arguments -- confirm against their definitions.
constraints = [
    # Constraint on the :a components with value 1.0 (presumably |a| = 1).
    OmegaAbsConstraint(1.0, Z_guess, Z_guess.components[:a]),
    # One group of link constraints per (cut index, propagator) pair, splatted into the list.
    vcat([get_link_constraints(
        :psi_iso, 
        Z_guess, 
        c, 
        g, 
        (; a=[1.0,0.0]), 
        (; a=[1.0,0.0]); 
        hard_equality_constraint=true)
        for (c, g) in zip(cuts, G)]...)...,
    # Constraint on component 1 of :acc that takes the jumps into account.
    TimeAffineLinearControlConstraint(:acc, 1, Z_guess; jumps=jumps),
    #custom_bounds_constraint(:a, Z_guess, vcat(cuts, cuts .+ 1), a_bound)
    # NameComponentPathConstraint(
    #     :a,
    #     [3],
    #     Z_guess,
    #     [t -> t],
    #     [x -> x]
    # )
]

In [None]:
# Ipopt options
options = QC.Options(
    max_iter=200,
)

In [None]:
# defining quantum control problem
# (the previous problem is dropped and garbage-collected first)
prob = nothing
GC.gc()
prob = QC.QuantumControlProblem(
    system, 
    Z_guess, 
    J, 
    dynamics;
    constraints=constraints,
    options=options,
)

In [None]:
# Run the optimizer; the problem is updated in place.
QC.solve!(prob)

In [None]:
# Trajectory stored on the solved problem.
Z = prob.trajectory

In [None]:
Z.psi_iso

In [None]:
#psi_iso_rollout = QC.rollout(psi0_iso, vcat(Z.a, Z.acc'), dts, system; integrator=exp)
#psi_iso_rollout = Z.psi_iso
# Roll out the optimized controls with the jumps, using `exp` as integrator.
psi_iso_rollout = shaken_lattice_rollout(psi0_iso, Z.a, dts, system, jumps, G; integrator=exp)
# Complex kets for every time step (one per column) and their populations.
psi = hcat([QC.iso_to_ket(psi_iso_rollout[:,t]) for t=1:T]...)
pops = abs2.(psi)

In [None]:
# Transform both halves of the state with bloch_states' and take populations.
psi_bloch = blockdiagonal(bloch_states', bloch_states') * psi
pops_bloch = abs2.(psi_bloch)

In [None]:
"""
    format_plot(p, times=times, jumps=jumps, full_times=full_times)

Decorate the plot `p` in place and return it.

- The x ticks are placed at roughly every tenth entry of `times` and labelled with the
  corresponding entries of `full_times`, rounded to one digit.
- The x limits are set to the first and last entry of `times`.
- For every `(cut, jump_time)` pair in `jumps`, a red vertical line is drawn at
  `times[cut]` and the rounded `jump_time` is written next to it near the top.

The defaults of `times`, `jumps` and `full_times` are the globals of the same name.
"""
function format_plot(
    p,
    times=times,
    jumps=jumps,
    full_times=full_times,
)
    T = length(times)
    # Aim for about ten ticks; the stride is clamped to 1 so that grids with fewer than
    # ten points do not produce a zero-step range.
    stride = max(1, div(T, 10))
    tick_labels = string.(round.(full_times[1:stride:end]; digits=1))
    Plots.xticks!(p, (times[1:stride:end], tick_labels))
    Plots.xlims!(p, (times[1], times[end]))
    rel_y = 0.98
    for (cut, jump_time) in jumps
        # Relative position inside the x limits set above (valid for any times[1]).
        rel_x = (times[cut] - times[1]) / (times[end] - times[1])
        Plots.vline!(p, [times[cut]], color=:red, label=nothing)
        # Annotate `p` explicitly instead of whatever plot happens to be current.
        Plots.annotate!(p, ((rel_x, rel_y), (string(round(jump_time; digits=1)), 10, :red)))
    end
    Plots.xlabel!(p, L"t \, [1/\omega_R]")
    return p
end

In [None]:
# Populations in the first `dim` rows of `pops` versus time.
p = Plots.plot(times, pops[1:dim,:]', labels=(-p_max:1:p_max)', legend=:outertopright)
Plots.plot!(
    p,
    title=L"Momentum state population evolution $|n\rangle$"
)
format_plot(p)

In [None]:
# Same plot for the remaining rows of `pops`.
p = Plots.plot(times, pops[dim+1:end,:]', labels=(-p_max:1:p_max)', legend=:outertopright)
format_plot(p)

In [None]:
# Populations after the Bloch-state transformation, first `dim` rows.
p = Plots.plot(times, pops_bloch[1:dim,:]', labels=(0:dim-1)', legend=:outertopright)
Plots.plot!(
    p,
    title=L"Bloch state population evolution $|B\rangle$"
)
format_plot(p)

In [None]:
# Remaining rows of the Bloch-transformed populations.
p = Plots.plot(times, pops_bloch[dim+1:end,:]', labels=(0:dim-1)', legend=:outertopright)
format_plot(p)

In [None]:
# Optimized controls and their companions versus time.
Plots.plot(times, Z.a')

In [None]:
Plots.plot(times, Z.dda')

In [None]:
Plots.plot(times, Z.acc')

In [None]:
# Phase of a1 + i*a2 for the optimized controls, cleaned in place by phi_mod_clean!.
phi = angle.(Z.a[1,:] + 1im * Z.a[2,:])
phi_mod_clean!(phi)

In [None]:
# Nine y ticks from -pi to pi with matching LaTeX labels.
pi_ticks = LinRange(-pi, pi, 9)
pi_lbls = [L"-\pi", L"-3\pi/4", L"-\pi/2", L"-\pi/4", L"0", L"\pi/4", L"\pi/2", L"3\pi/4", L"\pi"]
# Optimized phase together with the initial guess (semi-transparent).
p = Plots.plot(times, phi, ylim=(-pi, pi), yticks=(pi_ticks, pi_lbls), label="opt")
Plots.plot!(p, times, phi_guess, alpha=0.5, label="guess")
format_plot(p)
Plots.plot!(
    p,
    title=L"Phase protocol $\varphi(t)$",
    ylabel=L"\varphi"
)

In [None]:
# Fisher value at every time step of the optimized rollout.
fisher_evol = [-fisher_loss(QC.iso_to_ket(psi_iso_rollout[:,t])) for t=1:T];
p = Plots.plot(times, fisher_evol, label=nothing)
format_plot(p)
# The x label set by format_plot is replaced here.
Plots.plot!(
    p,
    title=L"Fisher information $F(a)$",
    xlabel=L"t \, [1/\nu_R]",
    ylabel=L"F(a)"
)

In [None]:
# Log-sensitivity at every time step; the first point is left out of the plot.
log_sensitivity_evol = [log_sensitivity_loss(QC.iso_to_ket(psi_iso_rollout[:,t])) for t=1:T]
p = Plots.plot(times[2:end], log_sensitivity_evol[2:end], ylim=(-5., 5.), label=nothing)
format_plot(p)
Plots.plot!(
    p,
    title=L"Log-Sensitivity $\log_{10}(\delta a) = -0.5 \, \log_{10}(F(a))$",
    ylabel=L"\log_{10}(\delta a)"
)

In [None]:
# Final Fisher value and the corresponding 1/sqrt(F).
F = fisher_evol[end]
F, 1/sqrt(F)

In [None]:
#freqs = collect(12.464:0.001:12.468)
# Spectrum of the phase protocol on a frequency grid; fourier_time_freq is defined
# outside this file and receives the grid divided by 2pi.
freqs = collect(0.:0.1:40.0)
phi_ft = fourier_time_freq(phi, times, freqs/2pi);

In [None]:
Plots.plot(freqs, abs2.(phi_ft))#, ylims=(0.000895, 0.0009))

In [None]:
import JLD2

In [None]:
jumps

In [None]:
# Save the optimized trajectory to disk.
JLD2.save("./interferometer/176-5.0_352-5.0.jld2", Z)

In [None]:
# Use the optimized trajectory as the next initial guess (same object, not a copy).
Z_guess = Z

### Fisher range over a

In [None]:
# 81 acceleration values between -0.001 and 0.001.
acc_range = LinRange(-0.001, 0.001, 81)

In [None]:
# Untyped result container for the scan over acc_range.
fisher_vals = []

In [None]:
jumps

In [None]:
# Scan over the acceleration: rebuild the system for each value, roll out the optimized
# controls and record the Fisher value at the final time step.
# NOTE(review): in a notebook/REPL the loop reassigns the global `system`, so later
# cells see the system built for the last value of acc_range -- confirm this is intended.
for acc_val in acc_range
    println(acc_val)
    system = ShakenLatticeSystem1D(V, p_max; acc=acc_val, include_acc_derivative=true)
    psi_iso_final = shaken_lattice_rollout(psi0_iso, Z.a, dts, system, jumps, 10000; integrator=exp)[:,end]
    # fisher_loss slices a complex ket into (psi, dpsi), so the iso-vector returned by
    # the rollout is converted first, as at every other call site.
    push!(fisher_vals, -fisher_loss(QC.iso_to_ket(psi_iso_final)))
end


In [None]:
# Fisher values of the scan versus acceleration.
# NOTE(review): Fisher already divides by MZFI, so this divides by MZFI a second time --
# confirm that this is intended.
Plots.plot(acc_range, fisher_vals/MZFI)

In [None]:
jumps

In [None]:
# Controls and time steps with the jumps expanded (get_controls_dts is defined outside
# this file; the last argument is the same 10000 passed to the propagator helper).
a_full, dts_full = get_controls_dts(Z.a, vec(Z.dts), jumps, 10000)

In [None]:
# Phase of a1 + i*a2 and start time of every step on the expanded grid.
phi_full = angle.(a_full[1,:] + im*a_full[2,:])
times_full = cumsum(dts_full) - dts_full

In [None]:
Plots.plot(times_full, phi_full)

In [None]:
# Rollout with 10000 passed in place of precomputed propagators.
psi_iso_full = shaken_lattice_rollout(psi0_iso, Z.a, vec(Z.dts), system, jumps, 10000)

In [None]:
T_full = length(dts_full)

In [None]:
# Complex kets for every step of the expanded grid and their populations.
psi_full = hcat([QC.iso_to_ket(psi_iso_full[:,t]) for t=1:T_full]...)
pops_full = abs2.(psi_full)

In [None]:
# Bloch-state transformation of the first `dim` rows only.
psi_bloch_full = bloch_states' * psi_full[1:dim,:]
pops_bloch_full = abs2.(psi_bloch_full)

In [None]:
Plots.plot(times_full, pops_full[1:dim,:]')

In [None]:
Plots.plot(times_full, pops_full[dim+1:end,:]')

In [None]:
# NOTE(review): the labels are hard-coded as 0:8 while other plots use 0:dim-1 --
# confirm that they match the number of rows of pops_bloch_full.
Plots.plot(times_full, pops_bloch_full', label=(0:8)', legend=:outertopright)

In [None]:
# All rows of pops_bloch versus time.
p = Plots.plot(times, pops_bloch', labels=(0:dim-1)', legend=:outertopright)
Plots.plot!(
    p,
    title=L"Bloch state population evolution $|b\rangle$"
)
format_plot(p)

# Keep a copy of the optimized trajectory and use it as the next initial guess.
Z_save = copy(Z)
Z_guess = Z_save

## roll out repetition

In [None]:
"""
    get_repeated_controls(a::AbstractMatrix, dts::AbstractVector, N::Int)

Tile the controls `a` (one column per time step) and the time steps `dts` `N` times.

Returns `(a_long, dts_long, times_long)`, where `times_long` holds the start time of
every step of the tiled grid. In the `n`-th repetition, row 3 of `a_long` is shifted by
`(n - 1) * times_long[length(dts)]`.

NOTE(review): row 3 is hard-coded, so `a` needs at least three rows whenever `N >= 2`;
it presumably carries a time-like control -- confirm against the callers.
"""
function get_repeated_controls(a::AbstractMatrix, dts::AbstractVector, N::Int)
    n_steps = length(dts)
    a_long = repeat(a, 1, N)
    dts_long = repeat(dts, N)
    times_long = cumsum(dts_long) .- dts_long
    for rep in 1:(N - 1)
        # Columns belonging to repetition rep + 1.
        cols = (rep * n_steps + 1):((rep + 1) * n_steps)
        shift = rep * times_long[n_steps]
        a_long[3, cols] .+= shift
    end
    return (a_long, dts_long, times_long)
end
# Trajectory convenience method: forwards `Z.a` and the flattened `Z.dts` to the
# matrix/vector method of the same name.
get_repeated_controls(Z::NT.NamedTrajectory, N::Int) =
    get_repeated_controls(Z.a, vec(Z.dts), N)

"""
    get_repeated_controls_alternated(a::AbstractMatrix, dts::AbstractVector, N::Int)

Tile the controls `a` (one column per time step) and the time steps `dts` `N` times,
with rows 1:2 of every even-numbered repetition replaced by the time-reversed rows 1:2
of `a`.

Returns `(a_long, dts_long, times_long)`, where `times_long` holds the start time of
every step of the tiled grid. In the `n`-th repetition, row 3 of `a_long` is shifted by
`(n - 1) * times_long[length(dts)]`; row 3 is not time-reversed.

NOTE(review): row 3 is hard-coded, so `a` needs at least three rows whenever `N >= 2`;
it presumably carries a time-like control -- confirm against the callers.
"""
function get_repeated_controls_alternated(a::AbstractMatrix, dts::AbstractVector, N::Int)
    n_steps = length(dts)
    a_long = repeat(a, 1, N)
    dts_long = repeat(dts, N)
    # Columns occupied by the k-th repetition.
    block(k) = ((k - 1) * n_steps + 1):(k * n_steps)
    for k in 2:2:N
        a_long[1:2, block(k)] = reverse(a[1:2, :], dims=2)
    end
    times_long = cumsum(dts_long) .- dts_long
    for k in 2:N
        a_long[3, block(k)] .+= (k - 1) * times_long[n_steps]
    end
    return (a_long, dts_long, times_long)
end
# Trajectory convenience method: forwards `Z.a` and the flattened `Z.dts` to the
# matrix/vector method of the same name.
get_repeated_controls_alternated(Z::NT.NamedTrajectory, N::Int) =
    get_repeated_controls_alternated(Z.a, vec(Z.dts), N)


In [None]:
# Number of repetitions of the protocol.
N = 5

In [None]:
# Tile the optimized trajectory N times.
a_long, dts_long, times_long = get_repeated_controls(Z, N)

In [None]:
Plots.plot(times_long, a_long')

In [None]:
# Roll out the tiled controls with `exp` as integrator.
psi_iso_long = QC.rollout(psi0_iso, a_long, dts_long, system; integrator=exp)

In [None]:
# Kets and populations for every step of the tiled grid (N*T columns).
psi_long = hcat([QC.iso_to_ket(psi_iso_long[:,t]) for t=1:N*T]...)
pops_long = hcat([abs2.(QC.iso_to_ket(psi_iso_long[:,t])) for t=1:N*T]...)

In [None]:
Plots.plot(times_long, pops_long[1:dim,:]', labels=(-p_max:1:p_max)', legend=:outertopright)

In [None]:
Plots.plot(times_long, pops_long[dim+1:end,:]', labels=(-p_max:1:p_max)', legend=:outertopright)

In [None]:
# Population-weighted sum over -p_max:p_max (times 2) plus a term linear in time
# proportional to the system's :acc parameter.
P_expect = 2*pops_long[1:dim,:]' * collect(-p_max:p_max) + 1/4*system.params[:acc] * times_long
Plots.plot(times_long, P_expect)

In [None]:
# Fisher value at every step of the tiled rollout.
fisher_evol = [-fisher_loss(QC.iso_to_ket(psi_iso_long[:,t])) for t=1:N*T];
Plots.plot(times_long, fisher_evol)

In [None]:
# NOTE(review): in file order this cell runs before sensitivity_evol is first assigned
# in the next cell.
sensitivity_evol

In [None]:
# 1/sqrt(F) at every step.
sensitivity_evol = 1 ./ sqrt.(fisher_evol);
Plots.plot(times_long[1:end], sensitivity_evol[1:end], ylims=(0.0, 1.0))

In [None]:
# Same analysis for Z_guess. The third return value is not captured here, so
# times_long keeps the value from the earlier call.
a_long, dts_long = get_repeated_controls(Z_guess, N)

In [None]:
Plots.plot(times_long, a_long')

In [None]:
# Rollout with the default integrator this time.
psi_iso_long = QC.rollout(psi0_iso, a_long, dts_long, system)#; integrator=exp)

In [None]:
psi_long = hcat([QC.iso_to_ket(psi_iso_long[:,t]) for t=1:N*T]...)
pops_long = hcat([abs2.(QC.iso_to_ket(psi_iso_long[:,t])) for t=1:N*T]...)

In [None]:
Plots.plot(times_long, pops_long[1:dim,:]', labels=(-p_max:1:p_max)', legend=:outertopright)

In [None]:
Plots.plot(times_long, pops_long[dim+1:end,:]', labels=(-p_max:1:p_max)', legend=:outertopright)

In [None]:
P_expect = 2*pops_long[1:dim,:]' * collect(-p_max:p_max) + 1/4*system.params[:acc] * times_long
Plots.plot(times_long, P_expect)

In [None]:
fisher_evol = [-fisher_loss(QC.iso_to_ket(psi_iso_long[:,t])) for t=1:N*T];
Plots.plot(times_long, fisher_evol)

In [None]:
sensitivity_evol = 1 ./ sqrt.(fisher_evol);
Plots.plot(times_long, sensitivity_evol, ylims=(0.0, 1.0))

In [None]:
sensitivity_evol