Skip to content

Commit

Permalink
Merge pull request #58 from iitis/lp/memoize
Browse files Browse the repository at this point in the history
Add examples with `@memoize`
  • Loading branch information
bartekGardas committed Feb 11, 2021
2 parents da1b158 + 75f11f1 commit 30e0fc6
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Expand Up @@ -7,10 +7,12 @@ version = "0.0.1"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LowRankApprox = "898213cb-b102-5a47-900c-97e73b919f73"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
55 changes: 55 additions & 0 deletions benchmarks/memoize_example.jl
@@ -0,0 +1,55 @@
using Memoize, LRUCache
using SpinGlassPEPS, TensorOperations, TensorCast
using LinearAlgebra

@memoize function left_env2::AbstractMPS, idx::NTuple)
l = length(idx)
if l == 0
L = [1.]
else
m = idx[l]
new_idx = idx[1:l-1]
L_old = left_env2(ϕ, new_idx)
M = ϕ[l]
@reduce L[x] := sum(α) L_old[α] * M[α, $m, x]
end
return L
end

@memoize function right_env2::AbstractMPS, W::AbstractMPO, idx::NTuple)
l = length(idx)
L = length(ϕ)
if l == 0
R = fill(1., 1, 1)
else
m = idx[1]
new_idx = idx[2:l]
R_old = right_env2(ϕ, W, new_idx)
M = ϕ[L-l+1]
= W[L-l+1]
@reduce R[x, y] := sum(α, β, γ) M̃[y, $m, β, γ] * M[x, γ, α] * R_old[α, β]
end
return R
end

@memoize function left_env3::AbstractMPS, ψ::AbstractMPS)
l = length(ψ)
T = promote_type(eltype(ψ), eltype(ϕ))

L = Vector{Matrix{T}}(undef, l+1)
L[1] = ones(eltype(ψ), 1, 1)

for i 1:l
M = ψ[i]
= conj.(ϕ[i])

C = L[i]
@tensor C[x, y] := M̃[β, σ, x] * C[β, α] * M[α, σ, y] order = (α, β, σ)
L[i+1] = C
end
L
end

ϕ = randn(MPS{Float64}, 10, 10, 10);
ψ = randn(MPS{Float64}, 10, 10, 10);
W = randn(MPO{Float64}, 10, 10, 10);
3 changes: 2 additions & 1 deletion src/SpinGlassPEPS.jl
Expand Up @@ -8,7 +8,8 @@ module SpinGlassPEPS
using CSV
using Logging
using StatsBase

using Memoize, LRUCache

using DocStringExtensions
const product = Iterators.product

Expand Down
2 changes: 1 addition & 1 deletion src/base.jl
Expand Up @@ -110,7 +110,7 @@ function tensor(ψ::MPS, state::Union{Vector, NTuple})
for (A, σ) zip(ψ, state)
C *= A[:, idx(σ), :]
end
tr(C)
C[]
end

function tensor::MPS)
Expand Down
28 changes: 28 additions & 0 deletions src/contractions.jl
Expand Up @@ -49,6 +49,19 @@ function left_env(ϕ::AbstractMPS, ψ::AbstractMPS)
L
end

@memoize function left_env::AbstractMPS, σ::Union{Vector, NTuple})
l = length(σ)
if l == 0
L = [1.]
else
m = idx(σ[l])
= left_env(ϕ, σ[1:l-1])
M = ϕ[l]
@reduce L[x] := sum(α) L̃[α] * M[α, $m, x]
end
L
end

# NOT tested yet
function right_env::AbstractMPS, ψ::AbstractMPS)
L = length(ψ)
Expand All @@ -68,6 +81,21 @@ function right_env(ϕ::AbstractMPS, ψ::AbstractMPS)
R
end

@memoize function right_env::AbstractMPS, W::AbstractMPO, σ::Union{Vector, NTuple})
l = length(σ)
k = length(ϕ)
if l == 0
R = ones(1, 1)
else
m = idx(σ[1])
= right_env(ϕ, W, σ[2:l])
M = ϕ[k-l+1]
= W[k-l+1]
@reduce R[x, y] := sum(α, β, γ) M̃[y, $m, β, γ] * M[x, γ, α] * R̃[α, β]
end
R
end


"""
$(TYPEDSIGNATURES)
Expand Down
36 changes: 36 additions & 0 deletions test/contractions.jl
Expand Up @@ -35,4 +35,40 @@ end
@test abs(dot(ϕ, ψ)) <= norm(ϕ) * norm(ψ)
end


@testset "left_env correctly contracts MPS for a given configuration" begin
D = 10
d = 2
sites = 5
T = ComplexF64

ψ = randn(MPS{T}, sites, D, d)
σ = 2 * (rand(sites) .< 0.5) .- 1

@test tensor(ψ, σ) left_env(ψ, σ)[]
end

@testset "right_env correctly contracts MPO with MPS for a given configuration" begin
D = 10
d = 2
sites = 5
T = Float64

ψ = randn(MPS{T}, sites, D, d)
W = randn(MPO{T}, sites, D, d)

σ = 2 * (rand(sites) .< 0.5) .- 1

ϕ = MPS(T, sites)
for (i, A) enumerate(W)
m = idx(σ[i])
@cast B[x, s, y] := A[x, $m, y, s]
ϕ[i] = B
end

@test dot(ψ, ϕ) right_env(ψ, W, σ)[]
end



end

0 comments on commit 30e0fc6

Please sign in to comment.