Rough sketch of a proposal

In [1]:
using TensorKit, TensorOperations, BenchmarkTools, Strided 

In [2]:
using LinearAlgebra
BLAS.set_num_threads(1)
Strided.set_num_threads(1)

1

In [3]:
using Base.Threads
nthreads()

1

In [4]:
# I didn't want to parse the generated @tensor code again, so I want to simply run the tensor contraction code and record which contractions/additions/... are being done.
# For this, I used a "symbolictensormap" that keeps track of its structure and type

struct SymbolicTensorMap{A,B}
    structure::B
    SymbolicTensorMap(a,b) = new{a,typeof(b)}(b)
end
ttype(d::SymbolicTensorMap{A,B}) where {A,B} = A

TensorOperations.scalartype(a::SymbolicTensorMap) = TensorOperations.scalartype(ttype(a))

In [5]:
function fast_init(codom::ProductSpace{S,N₁},
    dom::ProductSpace{S,N₂},A) where {S<:IndexSpace,N₁,N₂}

    I = sectortype(S)
    if I == Trivial
        d1 = dim(codom)
        d2 = dim(dom)

        return function init_trivial()
            data = initializer(undef,(d1, d2))
            return TensorMap{S,N₁,N₂,Trivial,A,Nothing,Nothing}(data, codom, dom)
        end
    end
    blocksectoriterator = blocksectors(codom ← dom)
    rowr, rowdims = TensorKit._buildblockstructure(codom, blocksectoriterator)
    colr, coldims = TensorKit._buildblockstructure(dom, blocksectoriterator)
    
    
    F₁ = TensorKit.fusiontreetype(I, N₁)
    F₂ = TensorKit.fusiontreetype(I, N₂)
   
    ds = TensorKit.SectorDict{I,A}
    function init_sym()
        if !isreal(I)
            data = TensorKit.SectorDict(c => complex(A(undef,(rowdims[c], coldims[c])))
                            for c in blocksectoriterator)
        else
            data = TensorKit.SectorDict(c => A(undef,(rowdims[c], coldims[c])) for c in blocksectoriterator)
        end
        return TensorMap{S,N₁,N₂,I,ds,F₁,F₂}(data, codom, dom, rowr, colr)
    end

end

fast_init (generic function with 1 method)

In [6]:
# I go through the generated @tensor code once, and generate two sets of code. One which takes in symbolictensormaps and will be run in the constructor of the struct, one which will be run while actually applying the struct and executing the contraction

function subsplit(ex)
    a = Any[ex.head]
    b = Any[ex.head]
    c = []
    for (sa,sb,sc) in split_execution.(ex.args)
        append!(c,sc)
        push!(a,sa)
        push!(b,sb)
    end
    return (Expr(a...),Expr(b...),c)
end

function split_execution(ex::Expr)
    splitmap = Dict(GlobalRef(TensorOperations,:tensorcontract!) => (create_mediated_tensorcontract!,mediated_tensorcontract!),
                    GlobalRef(TensorOperations,:tensoralloc_contract) => (create_mediated_tensoralloc_contract,mediated_tensoralloc_contract),
                    GlobalRef(TensorOperations,:tensoradd!) => (create_mediated_tensoradd!,mediated_tensoradd!),
                    GlobalRef(TensorOperations,:tensoralloc_add) => (create_mediated_tensoralloc_add,mediated_tensoralloc_add),
                    GlobalRef(TensorOperations,:tensortrace!) => (create_mediated_tensortrace!,mediated_tensortrace!),)

    if ex.head == :(=) && length(ex.args) == 2
        if ex.args[2] isa Expr && ex.args[2].head == :call
            t = ex.args[2].args[1]

            if t in keys(splitmap)
                (mapped_1,mapped_2) = splitmap[t]
                nvar = gensym()
                a = quote
                    ($(ex.args[1]),$(nvar)) = $(mapped_1)($(ex.args[2].args[2:end]...))
                end
                b = quote
                    $(ex.args[1]) = $(mapped_2)($(nvar),$(ex.args[2].args[2:end]...))
                end
                return (a,b,[nvar])
            end
        end

        return subsplit(ex)
    elseif ex.head in (:block,)
        subsplit(ex)
    elseif ex isa Expr
        @show ex.head, ex.args
        return (ex,ex,[])
    end
end
split_execution(ex::Symbol) = (ex,ex,[])

split_execution (generic function with 2 methods)

In [7]:
function generate_permute_table(elt,sp_src,sp_dst, p1::IndexTuple{N₁},p2::IndexTuple{N₂}) where {N₁,N₂}
    
    blocksectoriterator_src = blocksectors(sp_src);
    rowr_src, rowdims = TensorKit._buildblockstructure(codomain(sp_src), blocksectoriterator_src)
    colr_src, coldims = TensorKit._buildblockstructure(domain(sp_src), blocksectoriterator_src)

    blocksectoriterator_dst = blocksectors(sp_dst);
    rowr_dst, rowdims = TensorKit._buildblockstructure(codomain(sp_dst), blocksectoriterator_dst)
    colr_dst, coldims = TensorKit._buildblockstructure(domain(sp_dst), blocksectoriterator_dst)

    ftreemap = (f1, f2)->permute(f1, f2, p1, p2);
    I = eltype(rowr_src.keys);

    N = length(p1)+length(p2);
    table = Tuple{elt,Int,UnitRange{Int},UnitRange{Int},NTuple{N,Int},Int,UnitRange{Int},UnitRange{Int},NTuple{N,Int}}[];
    for (i_src,(s_src,f1_list_src)) in enumerate(rowr_src)
        f2_list_src = colr_src[s_src];

        for (f1_src,r_src) in f1_list_src, (f2_src,c_src) in f2_list_src
            d_src = (dims(codomain(sp_src), f1_src.uncoupled)..., dims(domain(sp_src), f2_src.uncoupled)...)
            for ((f1_dst,f2_dst),α) in ftreemap(f1_src,f2_src)
                
                d_dst = (dims(codomain(sp_dst), f1_dst.uncoupled)..., dims(domain(sp_dst), f2_dst.uncoupled)...)

                s_dst = f1_dst.coupled;
                
                i_dst = searchsortedfirst(rowr_dst.keys,s_dst);

                r_dst = rowr_dst.values[i_dst][f1_dst];
                c_dst = colr_dst.values[i_dst][f2_dst];


                push!(table,(α,i_src,r_src,c_src,d_src,i_dst,r_dst,c_dst,d_dst));
            end
        end
    end

    (table,p1,p2)
end

function execute_permute_table!(t_dst,t_src,bulk,beta=false)
    (table,p1,p2) = bulk
    rmul!(t_dst,beta);

    @inbounds for (α,s_src,r_src,c_src,d_src,s_dst,r_dst,c_dst,d_dst) in table

        view_dst = sreshape(StridedView(t_dst.data.values[s_dst])[r_dst,c_dst],d_dst)
        view_src = sreshape(StridedView(t_src.data.values[s_src])[r_src,c_src],d_src);
        
        #TensorOperations.tensoradd!(view_dst,(p1,p2),view_src,:N,α,true)
        axpy!(α,permutedims(view_src,(p1...,p2...)), view_dst);
    end

    t_dst
end

execute_permute_table! (generic function with 2 methods)

In [8]:
# placeholder code. I'm not yet using the mediators to speed up the code, but you could use these mediators to calculate the rowr/colr once and then reuse them in the mediated calls.
function create_mediated_tensorcontract!(C::SymbolicTensorMap, pC, A::SymbolicTensorMap, pA, conjA, B::SymbolicTensorMap, pB, conjB, α=1, β=0 , backend=nothing)
    S = spacetype(A.structure)
    if !(BraidingStyle(sectortype(S)) isa SymmetricBraiding)
        throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
    end
    #=
    copyA = false
    if BraidingStyle(sectortype(S)) isa Fermionic
        for i in cindA
            if !isdual(space(A, i))
                copyA = true
            end
        end
    end
    =#

    #A′ = permute(A, (oindA, cindA); copy=copyA)
    sp_dst_A =  ProductSpace{S,length(pA[1])}(map(n -> A.structure[n], pA[1])) ← ProductSpace{S,length(pA[2])}(map(n -> dual(A.structure[n]), pA[2]))
    fast_init_A = fast_init(codomain(sp_dst_A),domain(sp_dst_A),storagetype(ttype(A)))
    tbl_A = generate_permute_table(scalartype(ttype(A)),A.structure,sp_dst_A,pA[1],pA[2])

    #B′ = permute(B, (cindB, oindB))
    sp_dst_B =  ProductSpace{S,length(pB[1])}(map(n -> B.structure[n], pB[1])) ← ProductSpace{S,length(pB[2])}(map(n -> dual(B.structure[n]), pB[2]))
    fast_init_B = fast_init(codomain(sp_dst_B),domain(sp_dst_B),storagetype(ttype(B)))
    tbl_B = generate_permute_table(scalartype(ttype(B)),B.structure,sp_dst_B,pB[1],pB[2])
    
    #=
    if BraidingStyle(sectortype(S)) isa Fermionic
        for i in domainind(A′)
            if !isdual(space(A′, i))
                A′ = twist!(A′, i)
            end
        end
    end
    =#
    #=
    ipC = TupleTools.invperm((pC[1]..., pC[2]...))
    oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁))
    oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂))
    if has_shared_permute(C, (oindAinC, oindBinC))
        C′ = permute(C, (oindAinC, oindBinC))
        mul!(C′, A′, B′, α, β)
    else
        C′ = A′ * B′
        add_permute!(C, C′, (p₁, p₂), α, β)
    end
    return C
    =#
    fast_init_C′ = fast_init(codomain(sp_dst_A),domain(sp_dst_B),storagetype(ttype(C)));
    tbl_C′ = generate_permute_table(scalartype(ttype(C)),codomain(sp_dst_A)←domain(sp_dst_B),C.structure,pC[1],pC[2])

    (C,(fast_init_A,tbl_A,fast_init_B,tbl_B,fast_init_C′,tbl_C′))
end

function mediated_tensorcontract!(mediator,C, pC, A, pA, conjA, B, pB, conjB, α=1, β=0 , backend=nothing)
    (fast_init_A,tbl_A,fast_init_B,tbl_B,fast_init_C′,tbl_C′) = mediator

    Ap = fast_init_A()
    execute_permute_table!(Ap,A,tbl_A)

    Bp = fast_init_B()
    execute_permute_table!(Bp,B,tbl_B)

    C′ = mul!(fast_init_C′(),Ap,Bp,α)
    execute_permute_table!(C,C′,tbl_C′,β)
    
    C    
end

function create_mediated_tensoralloc_contract(TC, pC::Index2Tuple{N₁,N₂}, A::SymbolicTensorMap, pA, conjA, B::SymbolicTensorMap, pB, conjB, istemp=false, backend::TensorOperations.Backend...)  where {N₁,N₂}
    spaces1 = [TensorOperations.flag2op(conjA)(A.structure[p]) for p in pA[1]]
    spaces2 = [TensorOperations.flag2op(conjB)(B.structure[p]) for p in pB[2]]
    spaces = (spaces1..., spaces2...)

    S = spacetype(ttype(A))
    cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
    dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
    stortype = TensorKit.similarstoragetype(ttype(A),TC)
    C = SymbolicTensorMap(tensormaptype(S,N₁, N₂, stortype),dom → cod)

    (C,fast_init(cod,dom,stortype)) 
end

function mediated_tensoralloc_contract(mediator,TC, pC::Index2Tuple{N₁,N₂}, A, pA, conjA, B, pB, conjB, istemp=false, backend::TensorOperations.Backend...)  where {N₁,N₂}
    mediator()
end

mediated_tensoralloc_contract (generic function with 2 methods)

In [9]:
function create_mediated_tensoradd!(C, pC, A, conjA, α=1, β=1 , backend=nothing)
    (C,Nothing)
end

function mediated_tensoradd!(mediator,args...)
    TensorOperations.tensoradd!(args...)
end

function create_mediated_tensoralloc_add(TC, pC::Index2Tuple{N₁,N₂}, A::SymbolicTensorMap, conjA, istemp=false, backend::TensorOperations.Backend...)  where {N₁,N₂}
    spaces1 = [TensorOperations.flag2op(conjA)(A.structure[p]) for p in pC[1]]
    spaces2 = [TensorOperations.flag2op(conjA)(A.structure[p]) for p in pC[2]]
    spaces = (spaces1..., spaces2...)

    S = spacetype(ttype(A))
    cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
    dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
    stortype = TensorKit.similarstoragetype(ttype(A),TC)
    C = SymbolicTensorMap(tensormaptype(S,N₁, N₂, stortype),dom → cod)

    (C,fast_init(cod,dom,stortype))
end

function mediated_tensoralloc_add(mediator,TC, pC::Index2Tuple{N₁,N₂}, A, conjA, istemp=false, backend::TensorOperations.Backend...)  where {N₁,N₂}
    mediator()
end

mediated_tensoralloc_add (generic function with 2 methods)

In [10]:
function create_mediated_tensortrace!(C, pC, A, pA, conjA, α=1, β=0 , backend=nothing)
    (C,Nothing)
end

function mediated_tensortrace!(mediator,args...)
    TensorOperations.tensortrace!(args...)
end

mediated_tensortrace! (generic function with 1 method)

In [11]:
macro tightloop_tensor(name,args::Vararg{Expr})
    isempty(args) && throw(ArgumentError("No arguments passed to `@tensor`"))
    
    if length(args) == 1
        parser = TensorOperations.defaultparser
    else
        tensorexpr = args[end]
        kwargs = parse_tensor_kwargs(args[1:(end - 1)])
        parser = tensorparser(tensorexpr, kwargs...)
    end
    
    parsed = parser(args[end])
    
    (a,b,c) = split_execution(parsed)
    c_types = [gensym() for t in c]
    declaration = quote end
    for (c_v,c_t) in zip(c,c_types)
        declaration = quote
            $(declaration)
            $(c_v)::$(c_t)
        end
    end

    input_symbols =  TensorOperations.getinputtensorobjects(args[end])
    output_symbols =  TensorOperations.getoutputtensorobjects(args[end])
    
    arg_symbols = [input_symbols...,output_symbols...];
    kwarg_expr = Expr(:parameters,[Expr(:kw,s,nothing) for s in arg_symbols]...)
    abstract_eval_call = Expr(:parameters,[Expr(:kw,s,Expr(:call,:SymbolicTensorMap,Expr(:call,:getindex,s,1),Expr(:call,:getindex,s,2))) for s in arg_symbols]...)

    instantiated_struct_name = gensym()
    access_inner_fields = quote end
    for c_v in c
        access_inner_fields = quote
            $access_inner_fields
            $(c_v) = $(instantiated_struct_name).$(c_v)
        end
    end

    return esc(quote
        struct $(name){$(c_types...)}
            $(declaration)
            function $(name)($(kwarg_expr))
                tup = abstract_eval($(abstract_eval_call))
                new{typeof.(tup)...}(tup...)
            end
            
            function abstract_eval($(kwarg_expr))
                $(a)
                return tuple($(c...))
            end
            function ($(instantiated_struct_name)::$name)($(kwarg_expr))
                $(access_inner_fields)
                $(b)
            end
        end
    end)
end

@tightloop_tensor (macro with 1 method)

In [12]:
@tightloop_tensor ac_eff y[-1 -2;-3] := le[-1 2;1]*O[2 -2;3 4]*x[1 3;5]*re[5 4;-3]

(ex.head, ex.args) = (:call, Any[:(TensorOperations.promote_contract), :(TensorOperations.scalartype(le)), :(TensorOperations.scalartype(x))])
(ex.head, ex.args) = 

(:call, Any[:(TensorOperations.promote_contract), :(TensorOperations.scalartype(var"####y_A#292_A#293")), :(TensorOperations.scalartype(O))])
(ex.head, ex.args) = (:call, Any[:(TensorOperations.tensorfree!), Symbol("####y_A#292_A#293")])
(ex.head, ex.args) = (:call, Any[:(TensorOperations.promote_contract), :(TensorOperations.scalartype(var"##y_A#292")), :(TensorOperations.scalartype(re))])
(ex.head, ex.args) = (:call, Any[:(TensorOperations.tensorfree!), Symbol("##y_A#292")])


In [22]:
virtspace = Rep[SU₂](i => 20 for i in 0:10);
ospace = Rep[SU₂](0 => 5,1 => 2);
pspace = Rep[SU₂](1 => 1);

t_le = TensorMap(rand,ComplexF64,virtspace*ospace',virtspace);
t_re = TensorMap(rand,ComplexF64,virtspace*ospace,virtspace);
t_ac = TensorMap(rand,ComplexF64,virtspace*pspace,virtspace);
t_o = TensorMap(rand,ComplexF64,ospace*pspace,pspace*ospace);

factory = ac_eff(le = (typeof(t_le),space(t_le)),re = (typeof(t_re),space(t_re)),O = (typeof(t_o),space(t_o)),x = (typeof(t_ac),space(t_ac)))

ac_eff{var"#init_sym#14"{2, 2, GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, ProductSpace{GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, 2}, ProductSpace{GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, 2}, DataType, DataType, DataType, DataType, TensorKit.SortedVectorDict{SU2Irrep, Int64}, TensorKit.SortedVectorDict{SU2Irrep, Dict{FusionTree{SU2Irrep, 2, 0, 1, Nothing}, UnitRange{Int64}}}, TensorKit.SortedVectorDict{SU2Irrep, Int64}, TensorKit.SortedVectorDict{SU2Irrep, Dict{FusionTree{SU2Irrep, 2, 0, 1, Nothing}, UnitRange{Int64}}}, Vector{SU2Irrep}, DataType}, Tuple{var"#init_sym#14"{2, 1, GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, ProductSpace{GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, 2}, ProductSpace{GradedSpace{SU2Irrep, TensorKit.SortedVectorDict{SU2Irrep, Int64}}, 1}, DataType, DataType, DataType, DataType, TensorKit.SortedVectorDict{SU2Irrep, Int64}, TensorKit.Sorte

In [23]:
@benchmark $factory(le = $t_le, re=$t_re, x = $t_ac, O = $t_o)

BenchmarkTools.Trial: 810 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m4.278 ms[22m[39m … [35m11.995 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 40.24%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m5.061 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m6.152 ms[22m[39m ± [32m 1.883 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m17.61% ± 18.17%

  [39m [39m▁[39m▅[39m█[39m▄[39m▁[39m [34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▅[39m█[39m█[39m█[39m█[39m█[39m▇[3

In [24]:
function slowcontract(;le=nothing,re=nothing,x=nothing,O=nothing)
    @tensor y[-1 -2;-3] := le[-1 2;1]*O[2 -2;3 4]*x[1 3;5]*re[5 4;-3]
end
@benchmark slowcontract(le = $t_le, re=$t_re, x = $t_ac, O = $t_o)

BenchmarkTools.Trial: 968 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m3.686 ms[22m[39m … [35m10.163 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 45.21%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m4.338 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m5.141 ms[22m[39m ± [32m 1.662 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.41% ± 17.48%

  [39m [39m [39m▁[39m▆[39m█[39m▆[39m▃[34m▁[39m[39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▅[39m█[39m█[39m█[39m█[39m█[3

In [16]:
norm(slowcontract(le = t_le, re=t_re, x = t_ac, O = t_o)- factory(le = t_le, re=t_re, x = t_ac, O = t_o)) # bit worrying

8.189336637915264e-11