In [None]:
import QuantumCollocation as QC
import NamedTrajectories as NT
import TrajectoryIndexingUtils as NTidx
import LinearAlgebra as LA
import SparseArrays as SA
import ForwardDiff as FD
import Plots
import Interpolations as IP
using LaTeXStrings
import JLD2

In [None]:
include("utils.jl")
include("system.jl")
include("constraints.jl")
include("objectives.jl")

In [None]:
# Lattice depth and truncation passed to ShakenLatticeSystem1D
# (units/meaning defined in system.jl — TODO confirm).
V = 10.
trunc = 13
# Build the shaken-lattice system in the Bloch basis. include_V_derivative=true
# appears to append the ∂_V state block (the second half of the state, plotted
# later as "Diff states") used for V-robustness — TODO confirm in system.jl.
system = ShakenLatticeSystem1D(V, trunc; bloch_basis=true, bloch_transformation_trunc=3*trunc, include_V_derivative=true)
# middle index of statevector where p = 0
mid = system.params[:mid]
dim = system.params[:dim]
# E_R [kHz] found in Weidner thesis
E_R = system.params[:E_R]
#E_R = 1/0.05
#print("time unit $(1/E_R) ms\nE_R = $E_R kHz")

In [None]:
# Display system.params[:bloch_energies] (presumably the Bloch-band energies).
system.params[:bloch_energies]

In [None]:
# Total evolution time, in units of 1/E_R.
duration = 2pi * 1.0

# Number of knot points and the uniform timestep between them.
T = 601
dt = duration / (T - 1)
dts = fill(dt, T)
# Used as the :dts bound of the trajectory; equal lower and upper values pin
# the timestep at dt.
dt_bound = (dt, dt)
# Start time of every knot point, so times[1] == 0.
times = cumsum(dts) .- dts
duration

In [None]:
# Bound on the I/Q controls (only used by a commented-out constraint below).
a_bound = fill(1.0, 2)
# Bound on the shaking phase phi.
phi_bound = [1.0 * pi]
#dphi_bound = [100.]
#ddphi_bound = [3000.]

# Initial guess for the phase: uniform random in [0, 1) at each knot point (unseeded).
#phi = collect(sin.(11.5 *times)')
phi = rand(1, T)
# I/Q controls derived from the phase: row 1 is cos(phi), row 2 is sin(phi).
a = [cos.(phi); sin.(phi)]

# Finite-difference first and second time derivatives of the phase.
dphi = NT.derivative(phi, dts)
ddphi = NT.derivative(dphi, dts)

In [None]:
# Plot the initial-guess I/Q controls over time.
Plots.plot(times, a')

In [None]:
# Plot the initial-guess phase.
Plots.plot(times, phi')

In [None]:
# Plot the finite-difference first derivative of the phase.
Plots.plot(times, dphi')

In [None]:
# Plot the finite-difference second derivative of the phase.
Plots.plot(times, ddphi')

# Equal superposition of the entries at mid-1 and mid+1; the entry at `mid`
# stays zero from `zeros`. NOTE: the next cell overwrites psi0.
psi0 = zeros(dim)
psi0[mid-1] = sqrt(0.5)
psi0[mid+1] = sqrt(0.5)
# Qualified call: `import LinearAlgebra as LA` does not put `normalize!` in scope.
LA.normalize!(psi0)

In [None]:
# Initial state: cavity_state(0, dim) (presumably Bloch level 0) followed by a
# zero block of length dim, presumably the initial ∂_V component of the state.
#psi0 = get_bloch_state(system; lvl=0)
psi0 = vcat(QC.cavity_state(0, dim), zeros(dim))

In [None]:
# Bar plot of the populations of the first `dim` entries of psi0.
#Plots.bar(-p_max:p_max, abs2.(psi0))
Plots.bar(0:dim-1, abs2.(psi0[1:dim]))

In [None]:
# Convert the complex state to QuantumCollocation's real "iso" representation.
psi0_iso = QC.ket_to_iso(psi0)

In [None]:
# Roll out the initial-guess controls `a` from psi0_iso using timesteps dts,
# with `exp` passed as the integrator (presumably exact exponentiation per step).
psi_iso = QC.rollout(psi0_iso, a, dts, system; integrator=exp)

In [None]:
# Add uniform [0, 1) noise to every entry of the rolled-out states before they are
# used as the state guess. NOTE: the perturbed columns are no longer normalized.
psi_iso += rand(size(psi_iso)...)

In [None]:
# Convert each knot-point column back to a complex ket and compute populations.
psi = hcat([QC.iso_to_ket(psi_iso[:,t]) for t=1:T]...)
pops = abs2.(psi)

In [None]:
# Plot populations of the first `dim` entries (these include the added noise).
#Plots.plot(times, pops[:,:]', labels=(-p_max:1:p_max)', legend=:outertopright)#, xlim=(0.0, 2.0))
Plots.plot(times, pops[1:dim,:]', labels=(0:dim-1)', legend=:outertopright)#, xlim=(0.0, 2.0))

In [None]:
# Components of the initial trajectory guess. The :dphi/:ddphi derivative
# components are currently disabled.
comps = (
    psi_iso = psi_iso,
    a = a,
    phi = phi,
    # dphi = dphi,
    # ddphi = ddphi,
    dts = dts
)
# Boundary conditions: the state starts at psi0_iso; phi starts and ends at 0.
initial = (
    psi_iso = psi0_iso,
    phi = [0.],
    #dphi = [0.]
)
final = (;
    phi = [0.],
    #dphi = [0.]
)
goal = (;)
# Box bounds on phi and on the (pinned) timesteps.
bounds = (
    phi = phi_bound,
    # ddphi = ddphi_bound,
    dts = dt_bound
)

# Drop any previous guess and collect garbage before rebuilding
# (presumably to limit memory use when the cell is re-run).
Z_guess = nothing
GC.gc()
Z_guess = NT.NamedTrajectory(
    comps;
    # `(:phi)` is just the Symbol :phi; the trailing comma makes this an explicit
    # one-element tuple, the same form used for several controls.
    controls=(:phi,),
    timestep=:dts,
    bounds=bounds,
    initial=initial,
    final=final,
    goal=goal
)

In [None]:
# Weight matrix for the quadratic state objective J1. Start from the identity on
# the `dim` main entries and zero the 1-based entries 8 and 9 (labels 7 and 8 in
# the 0-based plot legends). Presumably this excludes those two states from the
# penalty — confirm against QuadraticObjective. Then lift to the real iso form.
R = 1.0 * collect(LA.I(dim))
R[8,8] = R[9,9] = 0.
R = QC.QuantumSystems.iso(R)
# NOTE(review): the indices in the commented line below may predate the iso lift.
#R[8,8] = R[9+dim,9+dim] = 0. # this should get populations right AND fix y-z greatcircle
R

In [None]:
# sinc_kernel is defined outside this notebook (presumably utils.jl) and is
# evaluated with argument 50. on the guess's timesteps; presumably it returns a
# T×T low-pass filter matrix K — TODO confirm.
kernel = sinc_kernel(50., vec(Z_guess.dts))
# I - K keeps what the kernel removes; (I-K)'(I-K) turns it into a quadratic form.
convolver = LA.I(Z_guess.T) - kernel
convolver = convolver' * convolver

In [None]:
# Target state: cavity_state(3, dim) (Bloch level 3) padded with a zero block of
# length dim to match psi0. Only referenced by commented-out objectives below.
state_goal = vcat(QC.cavity_state(3, dim), zeros(dim))

In [None]:
# Earlier objective variants, kept for reference (disabled).
#J = NameComponentObjective(:psi_iso, [1:dim...], [1:T...], x -> 1 - abs2.(x' * state_goal), ([1:T...] ./ T) .* 100.; is_iso=true)
#J1 = QC.QuantumObjective(name=:psi_iso, goals=QC.ket_to_iso(state_goal), loss=:InfidelityLoss, Q=100.0)
#J += QC.QuadraticRegularizer(:dts, Z_guess, 0.01)
#J += NameComponentQuadraticRegularizer(:a, [2], Z_guess, [0.001])

# State objective: quadratic form with weight R and Q=200.0 on the psi_iso entries
# vcat(1:dim, (2dim+1):3dim). Assuming ket_to_iso stacks [real; imag] of the
# 2dim-long ket, these are the real and imaginary parts of the first dim
# (non-derivative) entries. QuadraticObjective comes from objectives.jl. Unlike
# J4, it is called without a time-index argument — TODO confirm which knot
# points it then covers.
J1 = QuadraticObjective(:psi_iso, Z_guess, R, vcat(1:dim, (2dim+1):3dim); Q=200.0)
# J2 = NameComponentObjective(:psi_iso, [8, 9], [T], x -> real(x[1]'*x[2])^2, [100.0]; is_iso=true)
# J4 = QC.QuadraticRegularizer(:ddphi, Z_guess, 1e-8/T)
# J5 = QC.QuadraticRegularizer(:dts, Z_guess, 1e1/T)
# Regularizer on ket entry 13 of psi_iso (is_iso=true), weight 1e1/T.
J5 = NameComponentQuadraticRegularizer(:psi_iso, [13], Z_guess, [1e1/T]; is_iso=true)

# convolution
# Quadratic penalty on phi with the convolver matrix over all knot points
# (Q=50.0/T); presumably penalizes the part of phi the sinc kernel filters out.
J4 = QuadraticObjective(:phi, Z_guess, convolver, [1], 1:Z_guess.T; Q=50.0/T)

# Regularizer on ket entries (dim+1):2dim, each weighted 1e4/T/dim. Presumably
# these are the ∂_V half of the state, so this term promotes robustness in V.
J6 = NameComponentQuadraticRegularizer(:psi_iso, (dim+1):2*dim, Z_guess, fill(1e4/T/dim, dim); is_iso=true)

J = J1 + J4 + J5 + J6

In [None]:
# Evaluate the total objective on the initial guess.
J.L(Z_guess.datavec, Z_guess)

In [None]:
# Evaluate the J6 term alone on the initial guess.
J6.L(Z_guess.datavec, Z_guess)

In [None]:
# Drop any previous integrators and collect garbage before rebuilding them.
integrators = nothing
GC.gc()
# Dynamics: a 4th-order Pade integrator that propagates :psi_iso under the
# controls :a with timesteps :dts. The derivative integrators for
# :phi -> :dphi -> :ddphi are disabled, matching the disabled trajectory components.
integrators = [
    QC.QuantumStatePadeIntegrator(
        system,
        :psi_iso,
        :a,
        :dts;
        order=4
    ),
    # QC.DerivativeIntegrator(
    #     :phi,
    #     :dphi,
    #     :dts,
    #     Z_guess
    # ),
    # QC.DerivativeIntegrator(
    #     :dphi,
    #     :ddphi,
    #     :dts,
    #     Z_guess
    # )
]

In [None]:
# Extra constraints (defined outside this notebook, presumably constraints.jl):
# - IQPhiConstraint links the :a controls to :phi; the initial guess used
#   a = [cos(phi); sin(phi)] — TODO confirm the exact form.
# - FinalYZGreatCircleConstraint acts on psi_iso entries 8 and 9 (the same
#   entries zeroed in R); presumably it constrains their final state to a
#   Y-Z great circle.
constraints = [
    IQPhiConstraint(:a, :phi, Z_guess),
    FinalYZGreatCircleConstraint(:psi_iso, [8, 9], Z_guess)
    #LinearSincConvolutionConstraint(:phi, :dts, Z_guess, 60.)
    #OmegaAbsConstraint(1.0, Z_guess),
    #PhiSincConvolutionConstraint(:a, :dts, Z_guess, 80.),
    #PhiFunctionBoundConstraint(phase_bound, Z_guess),
    #TimeSymmetricControlConstraint(:a, Z_guess)
    #custom_bounds_constraint(:a, Z_guess, Int[], a_bound)
]

In [None]:
# Ipopt options
# (iteration cap for a single solve! call)
options = QC.Options(
    max_iter=100,
)

In [None]:
# defining quantum control problem
# Drop any previous problem and collect garbage first, presumably to limit memory use.
prob = nothing
GC.gc()
prob = QC.QuantumControlProblem(
    system, 
    Z_guess, 
    J, 
    integrators;
    constraints=constraints,
    options=options,
)

In [None]:
# Run the optimizer on prob (mutating; the result is read from prob.trajectory below).
QC.solve!(prob)

In [None]:
# Drop the previous result, then fetch the optimized trajectory.
Z = nothing
GC.gc()
Z = prob.trajectory

In [None]:
# Inspect the optimized state trajectory (iso representation).
Z.psi_iso

In [None]:
# Re-simulate the optimized controls Z.a from psi0_iso (integrator=exp), so the
# result can be compared with the collocation states stored in Z.psi_iso.
psi_iso_rollout = QC.rollout(psi0_iso, Z.a, Z.dts, system; integrator=exp)
psi_rollout = hcat([QC.iso_to_ket(psi_iso_rollout[:,t]) for t=1:T]...)
pops_rollout = abs2.(psi_rollout)

# Kets and populations of the optimized trajectory itself; the plots below use these.
psi = hcat([QC.iso_to_ket(Z.psi_iso[:,t]) for t=1:T]...)
pops = abs2.(psi)

In [None]:
# Final state of the optimized trajectory.
psi[:,end]

In [None]:
# NOTE(review): this divides J1 by 100, but J1 is built with Q=200.0 (the
# commented-out QuantumObjective used Q=100.0). Confirm the normalization
# before reading this as a fidelity.
1 - J1.L(Z.datavec, Z)/100

In [None]:
# Value of the phi convolution penalty J4 on the solution.
J4.L(Z.datavec, Z)

In [None]:
# Recompute knot-point start times from the optimized timesteps.
times = cumsum(Z.dts[1,:]) - Z.dts[1,:]

In [None]:
# Populations of the first `dim` entries of the optimized trajectory.
p = Plots.plot(times, pops[1:dim,:]', labels=(0:dim-1)', legend=:outertopright, size=(500, 300))
Plots.xaxis!(p, 
    xlabel=L"$t$ $[1/\omega_R]$"
)
Plots.yaxis!(p, 
    #ylabel=L"population of momentum state $|p\rangle$"
    ylabel="Bloch state population"
)
Plots.title!(p, "Splitting shaking sequence")

In [None]:
# |.|^2 of the second half of the state (entries dim+1:end), presumably ∂_V psi.
p = Plots.plot(times, pops[dim+1:end,:]', labels=(0:dim-1)', legend=:outertopright, size=(500, 300))
Plots.xaxis!(p, 
    xlabel=L"$t$ $[1/\omega_R]$"
)
Plots.yaxis!(p, 
    #ylabel=L"population of momentum state $|p\rangle$"
    ylabel="Bloch state population"
)
Plots.title!(p, "Splitting shaking sequence - Diff states")

In [None]:
# d/dV |psi_i|^2 = 2 Re(psi_i * conj(∂_V psi_i)), assuming entries dim+1:end hold ∂_V psi.
∂pops = 2*real.(psi[1:dim,:].*conj.(psi[dim+1:end,:]));

In [None]:
# Plot ∂_V of each of the first `dim` populations over time.
p = Plots.plot(times, ∂pops', labels=(0:dim-1)', legend=:outertopright, size=(500, 300))
Plots.xaxis!(p,
    xlabel=L"$t$ $[1/\omega_R]$"
)
Plots.yaxis!(p,
    # These are derivatives of populations, not populations.
    ylabel=L"$\partial_V$ Bloch state population"
)
Plots.title!(p, L"Splitting shaking sequence - $\partial_V Pops$")

In [None]:
# Plot the last row of pops over time. Not displayed: it is not the last
# expression of the cell.
Plots.plot(times, pops[end,:])

# NOTE(review): `bloch_states` is not defined anywhere in this notebook; it must come
# from an included file or this errors. `system` is built with bloch_basis=true,
# so psi may already be in the Bloch basis — confirm this projection is still wanted.
blochs = bloch_states' * psi
bloch_pops = abs2.(blochs)

# NOTE(review): this x-label uses \nu_R while the other plots use \omega_R.
p = Plots.plot(times, bloch_pops[:,:]', labels=(0:dim-1)', legend=:outertopright, size=(500, 300))
Plots.xaxis!(p, 
    xlabel=L"$t$ $[1/\nu_R]$"
)
Plots.yaxis!(p, 
    ylabel=L"population of Bloch state $|B\rangle$"
)
Plots.title!(p, "Splitting shaking sequence")

In [None]:
# Plot the optimized I/Q controls.
Plots.plot(times, Z.a')

In [None]:
# Plot the optimized phase.
Plots.plot(times, Z.phi')

In [None]:
# Finite-difference derivatives of the optimized phase. These use the optimized
# trajectory's own timesteps, consistent with `times` above; the timesteps can
# drift slightly from the initial `dts` if the solver relaxes the pinned bounds.
dphi = NT.derivative(Z.phi, vec(Z.dts))
ddphi = NT.derivative(dphi, vec(Z.dts));

In [None]:
# Plot the first derivative of the optimized phase.
Plots.plot(times, dphi')

In [None]:
# Plot the second derivative of the optimized phase.
Plots.plot(times, ddphi')

In [None]:
# Publication-style plot of the optimized phase phi(t).
p = Plots.plot(times, Z.phi', linecolor=:blue, label=L"\varphi(t)", size=(500, 300))
Plots.xaxis!(p, 
    xlabel=L"$t$ $[1/\omega_R]$"
)
Plots.yaxis!(p, 
    ylabel="shaking amplitude"
)
Plots.title!(p, "Splitting shaking sequence")

In [None]:
# Reuse the solution as the initial guess (presumably for re-running the problem cells).
Z_guess = Z

In [None]:
# JLD2 is already imported at the top of the notebook; re-importing keeps this
# cell runnable on its own.
import JLD2

# Duration of the optimized trajectory: the sum of all timesteps except the last,
# i.e. the start time of the final knot point (times[end], up to rounding).
duration = sum(Z.dts) - Z.dts[end]
# Duration in multiples of 2π, rounded to two decimals; used as the file-name prefix.
duration_wr = round(duration/2pi; digits=2)
# NOTE(review): the header says "in %", but the raw J1.L value is written without
# scaling — confirm the intended units.
s = """
Final infidelity (Bloch 3) in %:
$(J1.L(Z.datavec, Z))
"""
write("$(duration_wr)wr.txt", s)

In [None]:
# JLD2.save(filename, value) has no method for a bare (non-Dict) value, so store the
# single trajectory with save_object (read it back with JLD2.load_object).
# mkpath makes sure the target directory exists.
mkpath("interferometer")
JLD2.save_object("interferometer/split_bloch78_Vrobust.jld2", Z)

In [None]:
# Knot points included in the spectrum (currently all of them).
slice = 1:T #div(2*T,3):T

In [None]:
# Frequency grid 0:0.1:60 (in units of ω_R, per the plot label). It is divided by 2pi
# before being passed, presumably converting angular to ordinary frequency;
# fourier_time_freq is defined outside this notebook — TODO confirm its convention.
freqs = collect(0:0.1:60)
phi_ft = fourier_time_freq(Z.phi[1,slice], times[slice], freqs/2pi);

In [None]:
# Plot |FT(phi)|^2 against the angular frequency grid.
p = Plots.plot(freqs, abs2.(phi_ft), size=(500, 300), label=nothing)
Plots.xaxis!(p, 
    xlabel=L"$\omega$ $[\omega_R]$"
)
Plots.yaxis!(p, 
    ylabel="Fourier amplitude"
)
Plots.title!(p, "Shaking protocol spectrum")