In [1]:
using Revise
using PastaQ
using ITensors
using Random
using OptimKit
using Zygote
using Zygote: ChainRulesCore
using BenchmarkTools
using LinearAlgebra
using JLD2
using Flux
using PyCall
using SymPy
using QOS

import mVQE
using mVQE.Hamiltonians: hamiltonian_tfi, hamiltonian_ghz, hamiltonian_aklt_half
using mVQE.ITensorsExtension: projective_measurement
using mVQE: loss, optimize_and_evolve
using mVQE.Circuits: AbstractVariationalCircuit, VariationalCircuitRy, VariationalMeasurement, VariationalMeasurementMC, VariationalMeasurementMCFeedback
using mVQE.Misc: get_ancillas_indices, pprint
using mVQE.Optimizers: OptimizerWrapper
using mVQE.pyflexmps: pfs

┌ Info: Precompiling PastaQ [30b07047-aa8b-4c78-a4e8-24d720215c19]
└ @ Base loading.jl:1664
┌ Info: Precompiling mVQE [fbf8b4f3-d5ee-4fcd-97a7-1cb357585aed]
└ @ Base loading.jl:1664


In [2]:
# Register layout: the mask passed to get_ancillas_indices marks each 6-qubit cell
# as [ancilla, 4 state qubits, ancilla] (cf. the printed indices: ancillas at
# 1, 6, 7, 12, …), so 4 * 10 state qubits give N = 60 qubits in total.
N_state = 4 * 10
state_indices, ancilla_indices, N =
    get_ancillas_indices(N_state, [false, true, true, true, true, false])
hilbert = qubits(N)

# Site indices restricted to the state qubits and to the ancilla qubits.
hilbert_state, hilbert_ancilla = hilbert[state_indices], hilbert[ancilla_indices]

# All-zero product state on every qubit, and its outer product with the primed copy.
ψ = productstate(hilbert, zeros(Int, N))
ρ = outer(ψ, ψ')
aklts = mVQE.StateFactory.AKLT_halfs(hilbert_state; basis="girvin");
(state_indices, ancilla_indices, N)

([2, 3, 4, 5, 8, 9, 10, 11, 14, 15  …  46, 47, 50, 51, 52, 53, 56, 57, 58, 59], [1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 42, 43, 48, 49, 54, 55, 60], 60)

In [5]:
# AKLT reference states on the state qubits, built in the "girvin" basis.
aklts = mVQE.StateFactory.AKLT_halfs(hilbert[state_indices]; basis="girvin")
# Only the first operator returned for the state-only register is kept.
H = first(hamiltonian_aklt_half(hilbert_state))
# Full-register Hamiltonians, with the AKLT terms placed on the state sublattice.
Htot, Htot_aklt, Htot_spin1 = hamiltonian_aklt_half(hilbert; sublattice=state_indices);

In [4]:
# Ideal Girvin preparation circuit for N_state state qubits, and the correction
# circuit built with its default constructor.
gcirc, corrcirc = (
    mVQE.GirvinProtocol.GirvinCircuitIdeal(N_state),
    mVQE.GirvinProtocol.GirvinCorrCircuit(),
);

In [6]:
# One shot of the protocol:
# 1. run the ideal circuit;
# 2. projectively measure the ancillas;
# 3. turn the outcomes into correction-gate parameters;
# 4. apply the correction circuit;
# 5. evaluate ⟨Htot⟩ on the corrected state (the recorded value is ~1e-15, i.e.
#    zero up to round-off).
ψ2 = gcirc(ψ);
ψp, measurement = mVQE.ITensorsExtension.projective_measurement_sample(ψ2; indices=ancilla_indices)
# Every outcome is shifted down by 1 before use.
# NOTE(review): presumably this maps 1-based outcomes to 0/1 bits; confirm
# against param_correction_gates.
params = mVQE.GirvinProtocol.param_correction_gates(measurement .- 1);
ψp_corr = corrcirc(ψp; params)
inner(ψp_corr, Htot, ψp_corr')

-7.254215219086706e-15 + 1.5806638110778084e-16im

VariationalMeasurementMCFeedback

In [7]:
# Two-stage model: the ideal Girvin circuit, then the Girvin correction circuit.
vmodels = [
    mVQE.GirvinProtocol.GirvinCircuitIdeal(N_state),
    mVQE.GirvinProtocol.GirvinCorrCircuit(Int(N_state / 2)),
]

# Feedback factory: ignores both arguments and always hands back the fixed
# function param_correction_gates.
function g(_, _)
    return mVQE.GirvinProtocol.param_correction_gates
end

model = VariationalMeasurementMCFeedback(vmodels, [g], ancilla_indices);

In [8]:
ψ_aklt = model(ψ);
inner(ψ_aklt, Htot, ψ_aklt')

-7.094734055754784e-15 - 3.0467924655566587e-17im

# Training

In [9]:
# Training setup: two depth-1 Ry ansatz circuits on all N qubits, coupled through
# a trainable Flux Dense feedback layer.
# (The previous Girvin-circuit `vmodels` assignment was dead code, since this line
# overwrote it. Note that removing it also removes its constructor calls, so the
# global RNG stream seen by the circuits below shifts.)
vmodels = [VariationalCircuitRy(N, 1) for _ in 1:2]

# Alternative feedback factory. `model` below does not use it (it uses `dense`),
# but this redefinition replaces any earlier `g(a, b)` method in the session.
g(a, b) = mVQE.GirvinProtocol.GirvinCorrectionNetwork()

# Feedback layer factory mapping `x` inputs to `y` outputs.
# - Weights are Glorot-uniform, scaled down by 10.
# - Biases are uniform in [0, 2π).
dense(x, y) = Flux.Dense(Flux.glorot_uniform(y, x) ./ 10, 2π .* rand(y))

model = VariationalMeasurementMCFeedback(vmodels, [dense], ancilla_indices);

In [41]:
# Measurement model built from the first two circuits in `vmodels`.
model2 = VariationalMeasurementMC(vmodels[begin:(begin + 1)], ancilla_indices)

VariationalMeasurementMC(AbstractVariationalCircuit[VariationalCircuitRy(N=60, depth=1), VariationalCircuitRy(N=60, depth=1)], [1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 42, 43, 48, 49, 54, 55, 60], 1)

In [42]:
# Push ψ through the feedback model, passing maxdim/cutoff through (presumably as
# MPS truncation settings), then evaluate ⟨Htot⟩ on the output state.
ψ_aklt = model(ψ; maxdim=10, cutoff=1e-5);
let out = ψ_aklt
    inner(out, Htot, out')
end

10.555591266516911

In [43]:
# Loss evaluated with only the first sub-circuit of `model`.
loss(ψ, Htot, first(model.vcircuits))

11.069049052327868

In [44]:
# Loss of the full feedback model (same truncation settings as the direct
# evaluation); `loss` is mVQE.loss, brought in by `using mVQE: loss`.
loss(ψ, Htot, model; maxdim=10, cutoff=1e-5)

10.947517980151352

In [45]:
# Show which ITensors source tree mVQE is loading (e.g. to confirm a local fork).
mVQE.ITensors |> pathof

"/Users/alcalde/workprojects/forks/ITensors.jl/src/ITensors.jl"

In [46]:
# TFI Hamiltonian terms built from `state_indices` with parameter 0.1 (presumably
# the field strength), converted to an MPO over the full register.
H_tfi = let terms = hamiltonian_tfi(state_indices, 0.1)
    MPO(terms, hilbert)
end;

In [None]:
# Time one loss + gradient evaluation of model2 on the TFI Hamiltonian, then show
# the loss together with the gradient of the first tracked parameter array.
l, grad = @time(mVQE.loss_and_grad(ψ, H_tfi, model2))
(l, grad[grad.params[1]])

In [52]:
# Same gradient check, but with two depth-2 Ry circuits.
# NOTE(review): the recorded output of this cell shows NaN in every displayed
# gradient entry while the loss is finite; investigate before training at depth 2.
vmodels = map(_ -> VariationalCircuitRy(N, 2), 1:2)
model2 = VariationalMeasurementMC(vmodels, ancilla_indices)
l, grad = @time(mVQE.loss_and_grad(ψ, H_tfi, model2))
(l, grad[grad.params[1]])

  1.955944 seconds (7.28 M allocations: 945.135 MiB, 8.92% gc time, 58.35% compilation time)


(2.7285647129146007, [NaN NaN; NaN NaN; … ; NaN NaN; NaN NaN])

└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617
└ @ Revise /Users/alcalde/.julia/packages/Revise/do2nH/src/packagedef.jl:617


In [38]:
# `getindex(x)` is exactly `x[]`: it returns the sole element of a one-element
# container and throws otherwise.
# NOTE(review): model2 is later rebuilt with two circuits, so this may now throw;
# the recorded output looks stale.
getindex(model2.vcircuits)

1-element Vector{AbstractVariationalCircuit}:
 VariationalCircuitRy(N=60, depth=1)

In [None]:
# Gradient with respect to the first tracked parameter array.
let p = grad.params[1]
    grad[p]
end

In [30]:
# Gradient entry for the parameters of `model`'s first circuit.
# NOTE(review): `grad` is produced from model2 elsewhere in this notebook; if it
# does not track these parameters the lookup may fail — confirm execution order.
let circuit = model.vcircuits[1]
    grad[circuit.params]
end;

In [40]:
# ADAM (step 0.1) wrapped with stopping and logging settings.
# The variable name `optimzer` (sic) is kept because the training call refers to it.
optimzer = let rule = ADAM(0.1)
    OptimizerWrapper(rule; gradtol=1e-3, maxiter=10, verbosity=10)
end

OptimizerWrapper(ADAM(0.1, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), 10, 0.001, 10)

In [95]:
# Train `model` on Htot with samples=10, serial execution and verbose logging.
# NOTE(review): this rebinds the global `ρ` (earlier `outer(ψ, ψ')`) to the third
# return value, which the recorded output shows is an MPS.
# The recorded run also logs ‖∇f‖ = 0.0 at the first iteration.
loss_value, trained_model, ρ, misc = optimize_and_evolve(
    ψ, Htot, model;
    samples=10, optimizer=optimzer, verbose=true, parallel=false,
)

┌ Info: ADAM: iter 1: f = 0.7407407409616755, ‖∇f‖ = 0.0, ‖θ‖ = 20.352104255980546
└ @ mVQE.Optimizers /local/alcalde/workprojects/mVQE/mVQE/src/Optimizers.jl:41


(0.7407407409616755, VariationalMeasurementMCFeedback(AbstractVariationalCircuit[mVQE.GirvinProtocol.GirvinCircuit([0.0 -1.5707963267948966 … 1.9106332362490186 3.141592653589793; 3.141592653589793 -1.5707963267948966 … 1.9106332362490186 3.141592653589793; … ; 3.141592653589793 -1.5707963267948966 … 1.9106332362490186 3.141592653589793; 3.141592653589793 -1.5707963267948966 … 1.9106332362490186 3.141592653589793]), mVQE.Circuits.FeedbackCircuit(mVQE.GirvinProtocol.GirvinCorrCircuit([4.591472781039895 0.2800606390466213 4.146049185538529 2.810284712629163; 6.087338315029476 0.06627693469669528 2.1147417087190084 4.626359206140794; … ; 0.12328778398119601 1.302111353606939 3.7297215037556803 0.6941629319075078; 3.7543774681516697 2.3331982789813597 5.943467702674782 2.692997834408845]), mVQE.Circuits.ReshapeModel(mVQE.GirvinProtocol.param_correction_gates, (20, 4)))], [1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 42, 43, 48, 49, 54, 55, 60], 1), MPS
[1] ((dim=2|id=991|"Link,n=1"), (

In [94]:
# Inspect the diagnostics Dict returned by optimize_and_evolve. Per the recorded
# output it holds "niter", "history", "gradient" and "loss".
# NOTE(review): the recorded run shows niter = 0, and "history" contains values
# like 6.9e-310 that look like uninitialized memory — confirm how it is filled.
misc

Dict{String, Any} with 4 entries:
  "niter"    => 0
  "history"  => [0.0 0.0 0.0; 0.0 0.0 6.92085e-310; … ; 0.0 0.0 0.0; 0.0 0.0 0.…
  "gradient" => Grads(...)
  "loss"     => 0.740741