Rough sketch of a proposal

In [1]:
using TensorKit, TensorOperations

In [2]:
# I didn't want to parse the generated @tensor code again, so I want to simply run the tensor contraction code and record which contractions/additions/... are being done.
# For this, I used a "symbolictensormap" that keeps track of its structure and type

struct SymbolicTensorMap{A,B}
    structure::B
    SymbolicTensorMap(a,b) = new{a,typeof(b)}(b)
end
ttype(d::SymbolicTensorMap{A,B}) where {A,B} = A

TensorOperations.scalartype(a::SymbolicTensorMap) = TensorOperations.scalartype(ttype(a))

In [3]:
# I go through the generated @tensor code once, and generate two sets of code. One which takes in symbolictensormaps and will be run in the constructor of the struct, one which will be run while actually applying the struct and executing the contraction

function subsplit(ex)
    a = Any[ex.head]
    b = Any[ex.head]
    c = []
    for (sa,sb,sc) in split_execution.(ex.args)
        append!(c,sc)
        push!(a,sa)
        push!(b,sb)
    end
    return (Expr(a...),Expr(b...),c)
end

function split_execution(ex::Expr)
    splitmap = Dict(GlobalRef(TensorOperations,:tensorcontract!) => (create_mediated_tensorcontract!,mediated_tensorcontract!),
                    GlobalRef(TensorOperations,:tensoralloc_contract) => (create_mediated_tensoralloc_contract,mediated_tensoralloc_contract),)

    if ex.head == :(=) && length(ex.args) == 2
        if ex.args[2] isa Expr && ex.args[2].head == :call
            t = ex.args[2].args[1]

            if t in keys(splitmap)
                (mapped_1,mapped_2) = splitmap[t]
                nvar = gensym()
                a = quote
                    ($(ex.args[1]),$(nvar)) = $(mapped_1)($(ex.args[2].args[2:end]...))
                end
                b = quote
                    $(ex.args[1]) = $(mapped_2)($(nvar),$(ex.args[2].args[2:end]...))
                end
                return (a,b,[nvar])
            end
        end

        return subsplit(ex)
    elseif ex.head in (:block,)
        subsplit(ex)
    elseif ex isa Expr
        @show ex.head, ex.args
        return (ex,ex,[])
    end
end
split_execution(ex::Symbol) = (ex,ex,[])

split_execution (generic function with 2 methods)

In [4]:
# placeholder code. I'm not yet using the mediators to speed up the code, but you could use these mediators to calculate the rowr/colr once and then reuse them in the mediated calls.

function create_mediated_tensorcontract!(C::SymbolicTensorMap, pC, A::SymbolicTensorMap, pA, conjA, B::SymbolicTensorMap, pB, conjB, α=1, β=0 , backend=nothing)
    (C,Nothing)
end

function mediated_tensorcontract!(mediator,args...)
    TensorOperations.tensorcontract!(args...)
end

function create_mediated_tensoralloc_contract(TC, pC::Index2Tuple{N₁,N₂}, A::SymbolicTensorMap, pA, conjA, B::SymbolicTensorMap, pB, conjB, istemp=false, backend::TensorOperations.Backend...)  where {N₁,N₂}
    spaces1 = [TensorOperations.flag2op(conjA)(A.structure[p]) for p in pA[1]]
    spaces2 = [TensorOperations.flag2op(conjB)(B.structure[p]) for p in pB[2]]
    spaces = (spaces1..., spaces2...)

    S = spacetype(ttype(A))
    cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
    dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))

    C = SymbolicTensorMap(tensormaptype(S,N₁, N₂, TensorKit.similarstoragetype(ttype(A),TC)),dom → cod)

    (C,Nothing)
end

function mediated_tensoralloc_contract(mediator,args...)
    TensorOperations.tensoralloc_contract(args...)
end

mediated_tensoralloc_contract (generic function with 1 method)

In [11]:
macro tightloop_tensor(name,args::Vararg{Expr})
    isempty(args) && throw(ArgumentError("No arguments passed to `@tensor`"))
    
    if length(args) == 1
        parser = TensorOperations.defaultparser
    else
        tensorexpr = args[end]
        kwargs = parse_tensor_kwargs(args[1:(end - 1)])
        parser = tensorparser(tensorexpr, kwargs...)
    end
    
    parsed = parser(args[end])
    
    (a,b,c) = split_execution(parsed)
    c_types = [gensym() for t in c]
    declaration = quote end
    for (c_v,c_t) in zip(c,c_types)
        declaration = quote
            $(declaration)
            $(c_v)::$(c_t)
        end
    end

    input_symbols =  TensorOperations.getinputtensorobjects(args[end])
    output_symbols =  TensorOperations.getoutputtensorobjects(args[end])
    
    arg_symbols = [input_symbols...,output_symbols...];
    kwarg_expr = Expr(:parameters,[Expr(:kw,s,nothing) for s in arg_symbols]...)
    abstract_eval_call = Expr(:parameters,[Expr(:kw,s,Expr(:call,:SymbolicTensorMap,Expr(:call,:getindex,s,1),Expr(:call,:getindex,s,2))) for s in arg_symbols]...)

    instantiated_struct_name = gensym()
    access_inner_fields = quote end
    for c_v in c
        access_inner_fields = quote
            $access_inner_fields
            $(c_v) = $(instantiated_struct_name).$(c_v)
        end
    end

    return esc(quote
        struct $(name){$(c_types...)}
            $(declaration)
            function $(name)($(kwarg_expr))
                tup = abstract_eval($(abstract_eval_call))
                new{typeof.(tup)...}(tup...)
            end
            
            function abstract_eval($(kwarg_expr))
                $(a)
                return tuple($(c...))
            end
            function ($(instantiated_struct_name)::$name)($(kwarg_expr))
                $(access_inner_fields)
                $(b)
            end
        end
    end)
end

@tightloop_tensor (macro with 1 method)

In [9]:
@macroexpand @tightloop_tensor west a[-1;-2] := b[-1;1]*c[1;-2]

(ex.head, ex.args) = (:call, Any[:(TensorOperations.promote_contract), :(TensorOperations.scalartype(b)), :(TensorOperations.scalartype(c))])


quote
    [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:41 =#[39m
    struct west{var"##307", var"##308"}
        [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:42 =#[39m
        begin
            [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:19 =#[39m
            begin
                [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:19 =#[39m
                begin
                    [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:16 =#[39m
                end
                [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:20 =#[39m
                var"##305"::var"##307"
            end
            [90m#= /home/maarten/projects/MPSKitExperimental.jl/examples/tightloop_tensor.ipynb:20 =#[39m
            var"##306"::var"##308"
        end
        [90m#= /h

In [12]:
@tightloop_tensor best a[-1;-2] := b[-1;1]*c[1;-2]

t_1 = TensorMap(rand,ComplexF64,ℂ^5,ℂ^2)
t_2 = TensorMap(rand,Float64,ℂ^2,ℂ^3)

factory = best(b = (typeof(t_1),space(t_1)),c = (typeof(t_2),space(t_2)))
factory(b=rand(5,5),c=rand(5,5))

(ex.head, ex.args) = (:call, Any[:(TensorOperations.promote_contract), :(TensorOperations.scalartype(b)), :(TensorOperations.scalartype(c))])


5×5 Matrix{Float64}:
 0.678895  1.4084    1.76689  1.3367    1.13481
 0.705189  1.66558   1.60041  1.54012   1.17329
 0.647733  0.915739  1.00998  1.03513   0.768398
 0.604854  0.5956    1.13134  0.703446  0.611395
 0.682581  1.14462   1.80783  1.1488    1.04654

In [14]:
factory

best{DataType, DataType}(Nothing, Nothing)