Skip to content

Commit

Permalink
Debug changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dexter2206 committed Feb 24, 2021
1 parent c460064 commit d2a9825
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 79 deletions.
69 changes: 38 additions & 31 deletions src/PEPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ mutable struct PepsNetwork <: AbstractGibbsNetwork
args::Dict

function PepsNetwork(
m::Int,
n::Int,
m::Int,
n::Int,
fg::MetaDiGraph,
β::Number,
β::Number,
origin::Symbol=:NW,
args_override::Dict=Dict()
)
)

pn = new((m, n))
pn.map, pn.i_max, pn.j_max = peps_indices(m, n, origin)
Expand Down Expand Up @@ -92,7 +92,7 @@ function MPO(::Type{T},

W = MPO(T, peps.j_max)
R = PEPSRow(T, peps, i)

for (j, A) enumerate(R)
v = get(config, j + peps.j_max * (i - 1), nothing)
if v !== nothing
Expand All @@ -113,14 +113,14 @@ function compress(ψ::AbstractMPS, peps::PepsNetwork)
Dcut = peps.args["bond_dim"]
if bond_dimension(ψ) < Dcut return ψ end
compress(ψ, Dcut, peps.args["var_tol"], peps.args["sweeps"])
end
end

@memoize function MPS(
peps::PepsNetwork,
i::Int,
cfg::Dict{Int, Int} = Dict{Int, Int}(),
)
if i > peps.i_max return MPS(I) end
if i > peps.i_max return MPS(I * 1.0) end
W = MPO(peps, i, cfg)
ψ = MPS(peps, i+1, cfg)
compress(W * ψ, peps)
Expand All @@ -129,7 +129,7 @@ end
function contract_network(
peps::PepsNetwork,
config::Dict{Int, Int} = Dict{Int, Int}(),
)
)
ψ = MPS(peps, 1, config)
prod(dropdims(ψ))[]
end
Expand All @@ -143,39 +143,39 @@ end

@inline function _get_local_state(
peps::PepsNetwork,
v::Vector{Int},
i::Int,
v::Vector{Int},
i::Int,
j::Int,
)
k = j + peps.j_max * (i - 1)
)
k = j + peps.j_max * (i - 1)
if k > length(v) || k <= 0 return 1 end
v[k]
end

function generate_boundary(
peps::PepsNetwork,
v::Vector{Int},
i::Int,
peps::PepsNetwork,
v::Vector{Int},
i::Int,
j::Int,
)
)
∂v = zeros(Int, peps.j_max + 1)

# on the left below
for k 1:j-1
∂v[k] = generate_boundary(peps,
∂v[k] = generate_boundary(peps,
(i, k), (i+1, k),
_get_local_state(peps, v, i, k))
end

# on the left at the current row
∂v[j] = generate_boundary(peps,
(i, j-1), (i, j),
∂v[j] = generate_boundary(peps,
(i, j-1), (i, j),
_get_local_state(peps, v, i, j-1))

# on the right above
for k j:peps.j_max
∂v[k+1] = generate_boundary(peps,
(i-1, k), (i, k),
(i-1, k), (i, k),
_get_local_state(peps, v, i-1, k))
end
∂v
Expand All @@ -184,28 +184,28 @@ end
@inline function _contract(
A::Array{T, 5},
M::Array{T, 3},
L::Vector{T},
L::Vector{T},
R::Matrix{T},
∂v::Vector{Int},
) where {T <: Number}

l, u = ∂v
@cast Ã[r, d, σ] := A[$l, $u, r, d, σ]
@tensor prob[σ] := L[x] * M[x, d, y] *
@tensor prob[σ] := L[x] * M[x, d, y] *
Ã[r, d, σ] * R[y, r] order = (x, d, r, y)
prob
end

function _normalize_probability(prob::Vector{T}) where {T <: Number}
# exceptions (negative pdo, etc)
# exceptions (negative pdo, etc)
# will be added here later
prob / sum(prob)
end
end

function conditional_probability(
peps::PepsNetwork,
v::Vector{Int},
)
)
i, j = get_coordinates(peps, length(v)+1)
∂v = generate_boundary(peps, v, i, j)

Expand All @@ -215,9 +215,12 @@ function conditional_probability(
L = left_env(ψ, ∂v[1:j-1])
R = right_env(ψ, W, ∂v[j+2:peps.j_max+1])
A = generate_tensor(peps, (i, j))

prob = _contract(A, ψ[j], L, R, ∂v[j:j+1])
_normalize_probability(prob)
println("pre CP ", prob)
ret = _normalize_probability(prob)
println("CP ", ret)
ret
end

_bond_energy(pn::AbstractGibbsNetwork,
Expand All @@ -238,10 +241,14 @@ function update_energy(

σkj = _get_local_state(network, σ, i-1, j)
σil = _get_local_state(network, σ, i, j-1)

_bond_energy(network, (i, j), (i, j-1), σil) +
_bond_energy(network, (i, j), (i-1, j), σkj) +
_local_energy(network, (i, j))

a = _bond_energy(network, (i, j), (i, j-1), σil)
b = _bond_energy(network, (i, j), (i-1, j), σkj)
c =_local_energy(network, (i, j))
println(size(a))
println(size(b))
println(size(c))
return a + b + c
end

function peps_indices(m::Int, n::Int, origin::Symbol=:NW)
Expand Down
5 changes: 3 additions & 2 deletions src/contractions.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export left_env, right_env, dot!

# --------------------------- Conventions -----------------------
#
#
# MPS MPS* MPO left env right env
# 2 2 2 - 1 2 -
# 1 - A - 3 1 - B - 3 1 - W - 3 L R
Expand Down Expand Up @@ -83,7 +83,8 @@ end

@memoize function right_env::AbstractMPS, W::AbstractMPO, σ::Union{Vector, NTuple})
l = length(σ)
k = length(ϕ)
#k = length(ϕ)
k = length(W)
if l == 0
R = ones(1, 1)
else
Expand Down
42 changes: 21 additions & 21 deletions src/network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ mutable struct NetworkGraph
end

mc = ne(fg)
if count < mc
error("Factor and Ising graphs are incompatible. Edges: $(count) vs $(mc).")
if count < mc
error("Factor and Ising graphs are incompatible. Edges: $(count) vs $(mc).")
end
ng
end
end

function _get_projector(
fg::MetaDiGraph,
v::Int,
fg::MetaDiGraph,
v::Int,
w::Int,
)
if has_edge(fg, w, v)
Expand All @@ -48,8 +48,8 @@ end

for w ng.nbrs[v]
pv = _get_projector(fg, v, w)
if pv === nothing
pv = ones(length(loc_exp), 1)
if pv === nothing
pv = ones(length(loc_exp), 1)
end
@cast tensor[(c, γ), σ] |= tensor[c, σ] * pv[σ, γ]
push!(dim, size(pv, 2))
Expand All @@ -61,51 +61,51 @@ end
fg = ng.factor_graph
if has_edge(fg, w, v)
_, e, _ = get_prop(fg, w, v, :split)
return exp.(-ng.β .* e')
return exp.(-ng.β .* (e' .- minimum(e)))
elseif has_edge(fg, v, w)
_, e, _ = get_prop(fg, v, w, :split)
return exp.(-ng.β .* e)
return exp.(-ng.β .* (e .- minimum(e)))
else
return ones(1, 1)
end
end

function generate_boundary(
ng::NetworkGraph,
v::Int,
w::Int,
ng::NetworkGraph,
v::Int,
w::Int,
state::Int
)
fg = ng.factor_graph
if v vertices(fg) return 1 end

pv = _get_projector(fg, v, w)
if pv === nothing
pv = ones(get_prop(fg, v, :loc_dim), 1)
if pv === nothing
pv = ones(get_prop(fg, v, :loc_dim), 1)
end
findfirst(x -> x > 0, pv[state, :])
end

function bond_energy(
ng::NetworkGraph,
v::Int,
u::Int,
ng::NetworkGraph,
u::Int,
v::Int,
σ::Int,
)
)
fg = ng.factor_graph

if has_edge(fg, u, v)
if has_edge(fg, u, v)
get_prop(fg, u, v, :edge).J[:, σ]
elseif has_edge(fg, v, u)
get_prop(fg, v, u, :edge).J[σ, :]
else
zeros(get_prop(fg, v, :loc_dim))
zeros(get_prop(fg, u, :loc_dim))
end
end

function local_energy(
ng::NetworkGraph,
v::Int,
ng::NetworkGraph,
v::Int,
)
get_prop(ng.factor_graph, v, :loc_en)
end

0 comments on commit d2a9825

Please sign in to comment.