In [None]:
using ITensors
using StatsBase
using ProgressMeter
using JLD
using Dates

# DMRG settings, read as globals by the `dmrg` calls in `compute_delta`.
global const nsweeps::Int = 100  # sweeps per DMRG run (upper bound when the observer stops early)
global const maxdim::Int = 1000  # maximum MPS bond dimension
global const cutoff::Float64 = 1e-9  # truncation cutoff
global const noise::Vector{Float64} = [1e-4]  # noise schedule; presumably the last entry is reused for later sweeps — confirm

# Convergence observer configured with energy_tol=1e-6 (early stop on a small sweep-to-sweep energy change).
# NOTE(review): this single instance is shared by every dmrg call that passes it. If
# DMRGObserver keeps its energy history across runs, the convergence test may compare a new
# sweep against a previous run's energy — verify, or construct a fresh observer per call.
global const observer::DMRGObserver{Float64} = DMRGObserver(energy_tol=1e-6)

# Golden-section search ratios; note rho2 == rho^2 == 1 - rho mathematically.
global const rho::Float64 = (sqrt(5) - 1) / 2  # 1 / phi
global const rho2::Float64 = (3 - sqrt(5)) / 2;  # 1 / phi^2

In [None]:
# Accumulate the expansion of (Z[out_a] - Z[in_a]) * (Z[out_b] - Z[in_b]) into `J`.
# Edge sites that do not exist are `nothing` and are skipped, as are diagonal (p == q)
# products (Z^2 = 1 for Pauli Z, so they would only add a constant).
function _add_flow_product!(J, out_a, in_a, out_b, in_b)
    for (p, q, w) in ((out_a, out_b, 1), (in_a, in_b, 1), (in_a, out_b, -1), (out_a, in_b, -1))
        if p !== nothing && q !== nothing && p != q
            J[p, q] += w
        end
    end
    return J
end

"""
    build_HAos(Ad, s, t, A=50)

Build the Ising penalty Hamiltonian `HA` on one qubit per directed edge of the graph with
adjacency matrix `Ad`. Edge site `k` is the `k`-th nonzero of `Ad` in column-major order.

The quadratic couplings are the expansion (without diagonal Z^2 terms) of the square of the
net flow `sum_x (Z[v->x] - Z[x->v])` for `v = s`, for `v = t`, and for every intermediate
node `i` not in `(s, t)`. Absent edges contribute nothing. For every node `i` the linear
fields `-4 Z[s->i] + 4 Z[i->s] + 4 Z[t->i] - 4 Z[i->t]` are added.

Returns `(os, J, K)`, all scaled by `A / 4`:
- `os`: an `OpSum` with `J[i, j] * Z_i Z_j` for `i < j` and `K[i] * Z_i`, nonzero terms only;
- `J`: the symmetrised `N x N` coupling matrix (`J + J'`);
- `K`: the length-`N` field vector.
"""
function build_HAos(Ad::AbstractMatrix, s::Int, t::Int, A=50)
    n = size(Ad, 1)

    edges = Tuple.(findall(!iszero, Ad))
    N = length(edges)

    # O(1) edge -> site lookup; replaces repeated linear `findfirst` scans over `edges`.
    site = Dict{eltype(edges), Int}(e => k for (k, e) in enumerate(edges))
    edge_site(u, v) = get(site, (u, v), nothing)

    J = zeros(N, N)
    K = zeros(N)

    for i in 1:n, j in 1:n
        # Hs and Ht: squared net flow out of the source and out of the target.
        for v in (s, t)
            _add_flow_product!(J, edge_site(v, j), edge_site(j, v),
                               edge_site(v, i), edge_site(i, v))
        end

        # Hij: squared net flow through each intermediate node i (self pairs excluded).
        (i == s || i == t || i == j) && continue
        for k in 1:n
            k == i && continue
            _add_flow_product!(J, edge_site(i, j), edge_site(j, i),
                               edge_site(i, k), edge_site(k, i))
        end
    end

    # Linear terms from the source and target constraints.
    for i in 1:n
        for (e, w) in ((edge_site(s, i), -4), (edge_site(i, s), 4),
                       (edge_site(t, i), 4), (edge_site(i, t), -4))
            e === nothing || (K[e] += w)
        end
    end

    # J[p, q] and J[q, p] both describe the same Z_p Z_q product; fold them together.
    J = J .+ transpose(J)

    os = OpSum()
    for i in 1:N, j in i+1:N
        if J[i, j] != 0
            os += J[i, j], "Z", i, "Z", j
        end
    end

    for i in 1:N
        if K[i] != 0
            os += K[i], "Z", i
        end
    end

    os *= A / 4
    J *= A / 4
    K *= A / 4
    return os, J, K
end


"""
    build_HBos(N, W, B=1)

Return an `OpSum` holding one single-site field term `(B / 2) * W[j] * Z_j` for every
site `j = 1:N`. `W` must be indexable at `1:N`.
"""
function build_HBos(N::Int, W, B=1)
    field = B / 2
    terms = OpSum()
    for site in Base.OneTo(N)
        terms .+= field * W[site], "Z", site
    end
    return terms
end


"""
    build_Hxos(N)

Return the driver `OpSum` with a `-1 * X_j` term on every site `j = 1:N`.
"""
function build_Hxos(N::Int)
    driver = OpSum()
    site = 1
    while site <= N
        driver .+= -1, "X", site
        site += 1
    end
    return driver
end;

In [None]:
# Return `psi` unless it is the empty default `MPS()`, in which case draw a random
# "Up"/"Dn" product state on `sites`.
function _init_or_random(sites, N::Int, psi)
    length(psi) == 0 || return psi
    return MPS(sites, StatsBase.sample(["Up", "Dn"], N))
end

"""
    compute_delta(sites, N, s, HA, HB, Hx; psi0_init=MPS(), psi1_init=MPS(), energy_tol=1e-6)

Compute the gap between the two lowest DMRG states of `H(s) = (1 - s) * Hx + s * (HA + HB)`.

The first run targets the ground state starting from `psi0_init`. The second targets the
lowest state orthogonal to it (penalty `weight=2`), starting from `psi1_init`. An empty
initial state (the default) is replaced by a random product state. DMRG settings come from
the globals `nsweeps`, `maxdim`, `cutoff` and `noise`.

Each run gets its own `DMRGObserver(energy_tol=energy_tol)`. The observer records per-sweep
energies and tests convergence on the last two. If one observer were shared, the excited-state
run could be judged converged against the ground-state energy. That happens precisely when
the gap is small, which is the regime being searched.

Returns `(gap, psi_low, psi_high)` with `gap = |E1 - E0|` and the states ordered by energy.
"""
function compute_delta(sites::Vector{Index{Int64}}, N::Int, s::Float64, HA::MPO, HB::MPO, Hx::MPO;
                       psi0_init=MPS(), psi1_init=MPS(), energy_tol::Real=1e-6)
    H = (1-s)*Hx + s*(HA + HB)

    psi0_init = _init_or_random(sites, N, psi0_init)
    E0, psi0 = dmrg(H, psi0_init; nsweeps, maxdim, cutoff=cutoff, noise=noise,
                    eigsolve_krylovdim=4, outputlevel=0,
                    observer=DMRGObserver(energy_tol=energy_tol))

    # Random draw (if any) happens after the ground-state run, as before.
    psi1_init = _init_or_random(sites, N, psi1_init)
    E1, psi1 = dmrg(H, [psi0], psi1_init; nsweeps, maxdim, cutoff=cutoff, noise=noise,
                    eigsolve_krylovdim=5, outputlevel=0, weight=2,
                    observer=DMRGObserver(energy_tol=energy_tol))

    return E1 > E0 ? (E1 - E0, psi0, psi1) : (E0 - E1, psi1, psi0)
end


# Product-state MPS with "Up" on every edge site listed in `path` and "Dn" elsewhere.
function _path_state(sites, N::Int, path)
    state = fill("Dn", N)
    for e in path
        state[e] = "Up"
    end
    return MPS(sites, state)
end

"""
    find_delta_min(sites, N, p0, p1, HA, HB, Hx; tol=1e-4, smin=0.0, smax=0.6, warmup=(0.7, 0.5))

Golden-section search for the `s` in `[smin, smax]` that minimises the gap returned by
`compute_delta`. The search assumes the gap is unimodal on the bracket.

The two DMRG states start as product states with "Up" on the edge sites in `p0` (ground-state
guess) and `p1` (excited-state guess). They are first carried through the `warmup` values of
`s`, in order, and every later evaluation is warm-started from the previous one. A timestamp
is printed before each warm-up and initial evaluation.

Returns `(sc, dmin)`: the midpoint of the final bracket (width `<= tol`) and the gap there.

Throws `ArgumentError` if `tol <= 0` or `smin >= smax`.
"""
function find_delta_min(sites::Vector{Index{Int64}}, N::Int, p0::Vector{Union{Nothing, Int64}},
                        p1::Vector{Union{Nothing, Int64}}, HA::MPO, HB::MPO, Hx::MPO;
                        tol::Real=1e-4, smin::Real=0.0, smax::Real=0.6, warmup=(0.7, 0.5))
    tol > 0 || throw(ArgumentError("tol must be positive, got $tol"))
    smin < smax || throw(ArgumentError("need smin < smax, got [$smin, $smax]"))

    stamp() = println(Dates.format(now(), "HH:MM"))

    a, b = float(smin), float(smax)
    h = b - a
    c = a + rho2 * h
    d = a + rho * h

    # Smallest n with rho^n * h <= tol. After k loop iterations the sub-bracket returned
    # below ((a, d) or (c, b)) has width rho^(k+1) * h, so n - 1 iterations already meet
    # the tolerance; one more would cost two extra DMRG runs for no required gain.
    n = ceil(Int, log(tol / h) / log(rho))

    psi0 = _path_state(sites, N, p0)
    psi1 = _path_state(sites, N, p1)

    for s in warmup
        stamp()
        _, psi0, psi1 = compute_delta(sites, N, float(s), HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
    end

    stamp()
    yc, psi0, psi1 = compute_delta(sites, N, c, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
    stamp()
    yd, psi0, psi1 = compute_delta(sites, N, d, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
    stamp()

    for _ in 1:(n - 1)
        h *= rho
        if yc < yd
            # Minimum lies in [a, d]: the old c becomes the new d.
            b, d, yd = d, c, yc
            c = a + rho2 * h
            yc, psi0, psi1 = compute_delta(sites, N, c, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
        else
            # Minimum lies in [c, b]: the old d becomes the new c.
            a, c, yc = c, d, yd
            d = a + rho * h
            yd, psi0, psi1 = compute_delta(sites, N, d, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
        end
    end

    sc = yc < yd ? (a + d) / 2 : (b + c) / 2
    dmin, _, _ = compute_delta(sites, N, sc, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)

    return sc, dmin
end;

In [None]:
# Session-wide ITensors / Strided configuration (global library switches, set once).
# Switch off Strided.jl's threading for general tensor operations and for matrix
# multiplication — presumably to avoid oversubscription with BLAS threads; confirm.
ITensors.Strided.disable_threads()
ITensors.Strided.disable_threaded_mul()

# Turn on ITensors' combine-contract and contraction-sequence optimisation options
# (see the ITensors docs for exactly what each switch changes).
ITensors.enable_combine_contract()
ITensors.enable_contraction_sequence_optimization()

In [None]:
n = 9
graph_file = "data/graph_n=$n.jld"

# Graph data: adjacency matrix plus the source and target node labels.
Ad, source, target = (load(graph_file, key) for key in ("Ad", "source", "target"))

# One qubit per directed edge: the nonzeros of Ad in column-major order.
edges = Tuple.(findall(!iszero, Ad))
N = length(edges)
sites = siteinds("S=1/2", N)

A, B = N, 1

os_HA, J, K = build_HAos(Ad, source, target, A)
os_Hx = build_Hxos(N)

HA, Hx = MPO(os_HA, sites), MPO(os_Hx, sites)

# Spectrum of the coupling matrix (rounded to 3 digits) and the multiplicity of its
# first eigenvalue.
e, V = eigen(J)
e = round.(e, digits=3)
nd = count(==(e[1]), e)

# Random edge weights randn()/6 + 0.5 clamped to [0, 1]. An edge whose reverse appeared
# earlier copies that edge's weight instead of drawing a new one (unseeded RNG).
W = let weights = zeros(N), seen = Dict{eltype(edges), Int}()
    for (k, edge) in enumerate(edges)
        partner = get(seen, reverse(edge), nothing)
        weights[k] = partner === nothing ? randn() / 6 + 0.5 : weights[partner]
        seen[edge] = k
    end
    clamp!(weights, 0, 1)
end

os_HB = build_HBos(N, W, B)
HB = MPO(os_HB, sites);

In [None]:
using PyCall

py"""
from itertools import islice

import networkx as nx

def two_sp(n, s, t, edges, W):
    # Return the two lightest simple s -> t paths (lists of 0-based node labels) in the
    # directed graph whose i-th edge edges[i] = (u, v) has weight W[i].
    # The graph is built from `edges` itself (not from a fixed graph family) so the paths
    # can only use edges that exist in the caller's edge list; nodes 0..n-1 always exist.
    G = nx.DiGraph()
    G.add_nodes_from(range(n))
    G.add_weighted_edges_from((u, v, W[i]) for i, (u, v) in enumerate(edges))

    paths = list(islice(nx.shortest_simple_paths(G, s, t, weight="weight"), 2))
    if len(paths) < 2:
        raise ValueError(f"need two simple paths from {s} to {t}, found {len(paths)}")
    p0, p1 = paths

    return p0, p1
"""

In [None]:
# networkx uses 0-based node labels: shift the edge endpoints, source and target down by one.
py_edges = [(u - 1, v - 1) for (u, v) in edges]
p0, p1 = py"two_sp"(n, source - 1, target - 1, py_edges, W)

# Map each node path back to 1-based labels, then to the indices (qubit sites) of its
# consecutive edges in `edges`; an edge missing from `edges` becomes `nothing`.
p0, p1 = map((p0, p1)) do path
    nodes = path .+ 1
    hops = [(nodes[k], nodes[k + 1]) for k in 1:(length(nodes) - 1)]
    indexin(hops, edges)
end;

In [None]:
# Time the full golden-section search twice: the first timing also includes JIT
# compilation, so the second one is the representative run time.
@time find_delta_min(sites, N, p0, p1, HA, HB, Hx)
@time find_delta_min(sites, N, p0, p1, HA, HB, Hx)

In [None]:
# Sample the gap on s = 0.99, 0.98, ..., 0.0 (decreasing), warm-starting every DMRG pair
# from the states found at the previous s.
S = [1-0.01*i for i in 1:100]

# Initial product states: "Up" on the edges of the two paths p0 and p1.
ground_state = fill("Dn", N)
for e in p0
    ground_state[e] = "Up"
end

excited_state = fill("Dn", N)
for e in p1
    excited_state[e] = "Up"
end

psi0 = MPS(sites, ground_state)
psi1 = MPS(sites, excited_state)

# dminS is filled back to front so dminS[k] is the gap at reverse(S)[k]. Its length is
# derived from S, so changing the sampling grid cannot index out of bounds.
# (In a notebook's soft scope, the loop's assignments update the global psi0/psi1.)
dminS = zeros(length(S))
@showprogress for (i, s) in enumerate(S)
    @time dminS[end + 1 - i], psi0, psi1 = compute_delta(sites, N, s, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
end

In [None]:
using Plots

# Sampled gap |E1 - E0| against s in increasing order (dminS is stored in reverse(S) order).
plot(reverse(S), dminS; xlabel="s", ylabel="E1 - E0", label="gap")

In [None]:
# s at the smallest sampled gap; dminS[k] belongs to reverse(S)[k] == S[end + 1 - k].
S[end + 1 - argmin(dminS)]

In [None]:
# Smallest sampled gap.
reduce(min, dminS)

In [None]:
# Default golden-section bracket [a, b] and its interior points c (at rho2) and d (at rho),
# matching the defaults used by find_delta_min.
a, b = 0, 0.6          # bracket ends
h = b - a              # bracket width
c = a + rho2 * h       # lower interior point
d = a + rho * h        # upper interior point

In [None]:
# Manual warm-start chain: follow the two lowest states from s = 0.7 down to s = 0.26,
# timing each (ground state + first excited state) DMRG pair.
ground_state = fill("Dn", N)
for e in p0
    ground_state[e] = "Up"
end

excited_state = fill("Dn", N)
for e in p1
    excited_state[e] = "Up"
end

psi0 = MPS(sites, ground_state)
psi1 = MPS(sites, excited_state)

# Each step starts from the previous step's states; in a notebook's soft scope the loop
# updates the global psi0/psi1.
for s in (0.7, 0.5, 0.35, 0.3, 0.28, 0.26)
    @time _, psi0, psi1 = compute_delta(sites, N, s, HA, HB, Hx, psi0_init=psi0, psi1_init=psi1)
end
