https://github.com/genkuroki/public/blob/main/0036/TensorOperations.jl.ipynb の続き

In [1]:
using TensorOperations
using BenchmarkTools

In [2]:
"""
    refA_Gs(A, G, N)

Build the two sides of a cyclic ("ring") tensor contraction over `N` factors.

Returns the tuple `(refA, Gs)` of expressions, where

- `refA` is `A[j1, j2, …, jN]`, and
- `Gs` is the left-nested product
  `G[1][i1, j1, i2] * G[2][i2, j2, i3] * … * G[N][iN, jN, i1]`.

`A` and `G` are the symbols (or expressions) spliced in as the output array and as the
collection of factors. The third index of factor `k` is `i(k+1)`, wrapping around to `i1`
for `k == N`, so every `i` index occurs exactly twice; this also holds for `N == 1`,
where the single factor is `G[1][i1, j1, i1]`.

# Throws
- `ArgumentError` if `N` is not an integer `>= 1`.
"""
function refA_Gs(A, G, N)
    N isa Integer && N >= 1 ||
        throw(ArgumentError("N must be an integer >= 1, got $(repr(N))"))
    # Expression for the k-th factor; its last index wraps back to i1 when k == N.
    factor(k) = :($G[$k][$(Symbol(:i, k)), $(Symbol(:j, k)), $(Symbol(:i, mod1(k + 1, N)))])
    # Keep the product left-nested, ((f1 * f2) * f3) * …, exactly as a manual loop would.
    Gs = foldl((acc, k) -> :($acc * $(factor(k))), 2:N; init=factor(1))
    # Use the caller-supplied name for the output array instead of a fixed `:A`.
    refA = Expr(:ref, A, (Symbol(:j, k) for k in 1:N)...)
    return refA, Gs
end

"""
    @multitrace G N

Expand to a `@tensor` contraction of the `N` factors `G[1], …, G[N]`, where the left-hand
side `A[j1, …, jN]` is assigned with `:=` and the right-hand side is the cyclic product
produced by `refA_Gs`.

`N` is consumed at macro-expansion time to build the index expressions, so it has to be
written as a literal at the call site. `G` is evaluated once in the caller's scope and
bound to a local inside a `let` block; the value of the `@tensor` statement is the value
of the macro call.
"""
macro multitrace(G, N)
    lhs, rhs = refA_Gs(:A, :G, N)
    factors = esc(G)  # evaluate the user's expression in the caller's scope
    return quote
        let G = $factors
            @tensor $lhs := $rhs
        end
    end
end

"""
    @multitrace A G N

Three-argument form of `@multitrace`: expand to a `@tensor` contraction of the `N` factors
`G[1], …, G[N]` whose left-hand side `A[j1, …, jN]` is assigned with `=`, i.e. the result
is stored into the array passed as `A`.

`N` is consumed at macro-expansion time, so it has to be written as a literal. `A` and `G`
are each evaluated once in the caller's scope and bound to locals inside a `let` block.
"""
macro multitrace(A, G, N)
    lhs, rhs = refA_Gs(:A, :G, N)
    # Evaluate both user expressions in the caller's scope.
    target = esc(A)
    factors = esc(G)
    return quote
        let A = $target, G = $factors
            @tensor $lhs = $rhs
        end
    end
end

@multitrace (macro with 2 methods)

In [3]:
"""
    multr(G, ::Val{N})
    multr(G)

Contract the `N` factors `G[1], …, G[N]` with `@tensor`, assigning
`A[j1, …, jN] := G[1][i1, j1, i2] * … * G[N][iN, jN, i1]` and returning the result.

The index expressions are generated from the type parameter `N`. The one-argument method
takes `N` from `length(G)`.
"""
@generated function multr(G, ::Val{N}) where N
    lhs, rhs = refA_Gs(:A, :G, N)
    return :(@tensor $lhs := $rhs)
end

function multr(G)
    # Lift the number of factors into the type domain so the generated method can see it.
    return multr(G, Val(length(G)))
end

"""
    multr!(A, G, ::Val{N})
    multr!(A, G)

Contract the `N` factors `G[1], …, G[N]` with `@tensor`, storing the result into `A` via
`A[j1, …, jN] = G[1][i1, j1, i2] * … * G[N][iN, jN, i1]`.

The index expressions are generated from the type parameter `N`. The two-argument method
takes `N` from `length(G)`.
"""
@generated function multr!(A, G, ::Val{N}) where N
    lhs, rhs = refA_Gs(:A, :G, N)
    return :(@tensor $lhs = $rhs)
end

function multr!(A, G)
    # Lift the number of factors into the type domain so the generated method can see it.
    return multr!(A, G, Val(length(G)))
end

multr! (generic function with 2 methods)

In [4]:
# i[k]: sizes of the summed (first/third) axes; j[k]: sizes of the kept (second) axes.
i = [100, 8, 6, 5]
j = [6, 7, 5, 4]
# H[k] has size (i[k], j[k], i[k+1]); the last factor wraps around to i[1].
H = map(1:4) do k
    randn(i[k], j[k], i[mod1(k + 1, 4)])
end
size.(H)

4-element Vector{Tuple{Int64, Int64, Int64}}:
 (100, 6, 8)
 (8, 7, 6)
 (6, 5, 5)
 (5, 4, 100)

In [5]:
# Contract the tensors in H with the macro form; the trailing literal 4 is the number of
# factors used to build the index expressions at macro-expansion time.
A = @multitrace H 4
# Show the type and size of the result.
typeof(A), size(A)

(Array{Float64, 4}, (6, 7, 5, 4))

In [6]:
# Same contraction through the generated function; the factor count comes from length(H).
B = multr(H)
# Compare against the macro result.
A == B

true

In [7]:
# Uninitialised Float64 output buffer with one axis per entry of j.
C = Array{Float64}(undef, j...)
# Fill the buffer with the contraction of H.
multr!(C, H)
# All three variants are compared for equality.
A == B && B == C

true

In [8]:
# Print the factor sizes alongside the timings below.
@show size.(H)

# Time the three variants on H; `$` interpolates the global variables into the benchmark.
A = @btime @multitrace $H 4
B = @btime multr($H)
# multr! writes into the existing buffer C.
C = @btime multr!($C, $H)
# Compare the three results.
A == B == C

size.(H) = [(100, 6, 8), (8, 7, 6), (6, 5, 5), (5, 4, 100)]
  223.600 μs (170 allocations: 21.44 KiB)
  226.000 μs (172 allocations: 21.56 KiB)
  224.400 μs (170 allocations: 14.83 KiB)


true

In [9]:
# K is H rotated by one position: K[1] = H[end], K[2] = H[1], and so on.
K = circshift(H, 1)
@show size.(K)
# Uninitialised Float64 output buffer sized by the second axis of every factor in K.
F = Array{Float64}(undef, size.(K, 2)...)

# Time the three variants on the rotated factor list.
D = @btime @multitrace $K 4
E = @btime multr($K)
F = @btime multr!($F, $K)
D == E && E == F

size.(K) = [(5, 4, 100), (100, 6, 8), (8, 7, 6), (6, 5, 5)]
  32.800 μs (100 allocations: 13.44 KiB)
  33.500 μs (104 allocations: 13.72 KiB)
  31.900 μs (100 allocations: 6.89 KiB)


true

In [10]:
# K lists the factors of H starting from the last one, so D's axes are A's axes in the
# order (4, 1, 2, 3); after permuting A the two results are compared approximately.
permutedims(A, (4, 1, 2, 3)) ≈ D

true