TensorRules.jl

TensorRules.jl provides a macro @∇ (which you can type as \nabla<tab>) that enables automatic differentiation (AD) libraries (e.g., Zygote.jl and Diffractor.jl) to work with the @tensor and @tensoropt macros from TensorOperations.jl.

TensorRules.jl uses ChainRulesCore.jl to define the custom adjoints, so any AD library that supports ChainRulesCore.jl can be used.

How to use

julia> using TensorOperations, TensorRules, Zygote;
julia> function foo(a, b, c) # define function with Einstein summation
           # d_F = \sum_{A,B,C,D,E} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
julia> gradient(foo, a, b, c); # try to obtain gradient of `foo` by Zygote
ERROR: this intrinsic must be compiled to be called
Stacktrace:
...
julia> @∇ function foo(a, b, c) # use @∇
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> gradient(foo, a, b, c); # it works!
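
To sanity-check the result, you can compare the Zygote gradient with a simple finite-difference estimate (an illustrative check, not part of TensorRules.jl; the step size and tolerance below are arbitrary choices):

julia> ga, gb, gc = gradient(foo, a, b, c);
julia> h = 1e-6; δ = zeros(size(a)); δ[1, 1, 1] = h;
julia> fd = (foo(a .+ δ, b, c) - foo(a .- δ, b, c)) / 2h; # central difference along a[1, 1, 1]
julia> isapprox(fd, ga[1, 1, 1]; rtol = 1e-4) # expected to agree up to finite-difference error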

How it works

The strategy of TensorRules.jl is very similar to that of TensorGrad.jl.

@∇ converts functions that contain the @tensor or @tensoropt macro. First, @∇ detects @tensor and @tensoropt expressions in the function definition and converts them into inlined functions. Then, @∇ defines custom adjoint rules for the generated functions.

For example, the following definition

@∇ function foo(a, b, c, d, e, f)
    @tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
    x = x + f
    @tensor x[A, B] += a[A, C] * (a * a)[C, B]
    return x
end

will be converted to code equivalent to

function foo(a, b, c, d, e, f)
    x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
    x = x + f
    x += _foo_2(a, a * a)
    return x
end

@inline _foo_1(x1, x2, x3, x4, x5) =
    @tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]

@inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]

function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
    f = _foo_1(x1, x2, x3, x4, x5)
    Px1, Px2, Px3, Px4, Px5 = ProjectTo(x1), ProjectTo(x2), ProjectTo(x3), ProjectTo(x4), ProjectTo(x5)
    function _foo_1_pullback(Δf)
        fnΔx1(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]
        fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
        fnΔx2(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
        fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
        fnΔx3(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
        fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
        fnΔx4(Δf, x1, x2, x3, x4, x5) = first(@tensoropt !C _[] := conj(conj(Δf[A, B]) * x5[A, B]))
        fnΔx5(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, B] := conj(x4 * conj(Δf[A, B]))
        fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, B] += conj(x4 * conj(Δf[A, B]))
        Δx1 = InplaceableThunk(
            Thunk(() -> Px1(fnΔx1(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx2 = InplaceableThunk(
            Thunk(() -> Px2(fnΔx2(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx3 = InplaceableThunk(
            Thunk(() -> Px3(fnΔx3(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx4 = Thunk(() -> fnΔx4(Δf, x1, x2, x3, x4, x5))
        Δx5 = InplaceableThunk(
            Thunk(() -> Px5(fnΔx5(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        return (NoTangent(), Δx1, Δx2, Δx3, Δx4, Δx5)
    end
    return f, _foo_1_pullback
end

function rrule(::typeof(_foo_2), x1, x2)
    ...
end

By using Thunk and InplaceableThunk appropriately, the adjoints are evaluated only when they are actually needed.
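
As a rough illustration of this laziness (not part of TensorRules.jl): the closure wrapped in a Thunk runs only when unthunk is called on it, so an adjoint that is never requested costs nothing.

julia> using ChainRulesCore
julia> t = Thunk(() -> (println("computing adjoint"); randn(3, 3))); # nothing is computed yet
julia> unthunk(t); # only now does the closure run
computing adjoint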

Unsupported features

  • @∇ uses the @capture macro defined in MacroTools.jl to parse Expr. Because of a limitation of @capture, index notations based on :typed_vcat and :typed_hcat (A[a; b], A[a b]) are unsupported. Please use the A[a, b] style.
  • Specifying the contraction order with ord=(...) or in NCON style is unsupported. Please use @tensoropt and specify the cost of each bond instead.
  • Since Zygote.jl does not support in-place operations, we cannot use the in-place assignment @tensor A[...] = ... in the expression. Please use :=, +=, and -= instead, as in the sketch after this list.
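
A minimal sketch of the allowed forms (the function bar and its arguments are made up for illustration):

@∇ function bar(a, b)
    @tensor c[A, B] := a[A, C] * b[C, B]  # := allocates a new array c (supported)
    @tensor c[A, B] += a[A, C] * b[C, B]  # += accumulates into c (supported)
    # @tensor c[A, B] = a[A, C] * b[C, B] # in-place = is unsupported
    return c
end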

TODO

  • support @tensor block (@tensor begin ... end; see the sketch after this list)
  • support higher order differentiation (by applying @∇ to rrule and frule recursively)
    • better support of InplaceableThunk
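
For reference, a @tensor block groups several contractions under a single macro call, as in the sketch below (array names are illustrative); until block support lands, each contraction has to be written as its own @tensor call:

@tensor begin
    c[A, B] := a[A, C] * b[C, B]
    d[A] := c[A, B] * v[B]
end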
