# Reactant demo

In [1]:
using Tenet
using EinExprs
using Reactant
using Enzyme
using BenchmarkTools
using Adapt

using LinearAlgebra
# Restrict BLAS to a single thread; presumably so the plain-Julia contraction and the
# Reactant-compiled runs below are timed on equal footing — TODO confirm intent.
BLAS.set_num_threads(1)

using Random
# Seed the global RNG so the random network built below is reproducible (assumes
# `rand(TensorNetwork, ...)` draws from the default RNG — confirm). `seed!` returns
# the seeded RNG, which is why the cell displays `TaskLocalRNG()`.
Random.seed!(0)

TaskLocalRNG()

In [2]:
# Random tensor network used throughout the demo. The positional arguments are
# presumably (number of tensors, graph regularity) and `dim` the bond-dimension
# range — confirm against Tenet's `rand(::Type{TensorNetwork}, ...)` docs.
tn = rand(TensorNetwork, 15, 3; dim=(16, 16))

# Contraction order found by exhaustive search. Every later cell contracts along
# this same `path`, so all timings follow the same order.
path = einexpr(tn; optimizer=Exhaustive())

# `$`-interpolate the non-const globals so BenchmarkTools times the contraction
# itself, not untyped global-variable lookups and the resulting dynamic dispatch.
@benchmark contract($tn; path=$path)

BenchmarkTools.Trial: 438 samples with 1 evaluation.
 Range (min … max):  10.154 ms … 202.250 ms  ┊ GC (min … max): 0.00% … 94.66%
 Time  (median):     10.385 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.387 ms ±   9.307 ms  ┊ GC (mean ± σ):  5.02% ±  5.30%

  █▆▅▃ ▄▄▂
  ████

In [3]:
# Re-wrap every tensor's data as a `Reactant.ConcreteRArray` (via Adapt.jl) so the
# tensors can be passed to a Reactant-compiled function.
tn′ = adapt(Reactant.ConcreteRArray, tn)

# Extract the tensors once. The same tuple is used both to compile and to call `g`,
# so the argument order seen at trace time and at call time is guaranteed to match.
targs = Tuple(tensors(tn′))

# Compile a function that rebuilds the network from its tensors and contracts it
# along the precomputed global `path`.
g = Reactant.compile(targs) do ts...
    _tn = TensorNetwork(ts)
    contract(_tn; path)
end

# Interpolate with `$` so only the compiled call is timed, not tensor extraction,
# vararg construction, or global-variable lookup.
@benchmark $g($targs...)

BenchmarkTools.Trial: 590 samples with 1 evaluation.
 Range (min … max):  7.593 ms … 17.671 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     8.174 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   8.481 ms ±  1.115 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▃▇█▇▆▆▆▄▃
  ▄▇█████

In [4]:
"""
    f(x...)

Rebuild a `TensorNetwork` from the tensors passed as varargs `x` and contract it
along the global contraction `path`. This is the function differentiated by the
Enzyme cells.
"""
function f(x...)
    network = TensorNetwork(x)
    return contract(network; path=path)
end

# Extract the Reactant-backed tensors once; reused for compilation and benchmarking.
targs = Tuple(tensors(tn′))

# Compile the reverse-mode gradient of `f` with respect to every input tensor.
∇g = Reactant.compile(targs) do x...
    # Zero-initialised shadows: Enzyme accumulates ∂f/∂xᵢ into dx[i] in place.
    dx = Enzyme.make_zero.(x)
    # `Active` return seeds the output adjoint with one; this assumes `f` yields a
    # scalar (fully contracted network) — NOTE(review): confirm.
    Enzyme.autodiff(Reverse, f, Active, Duplicated.(x, dx)...)
    return dx
end

# Interpolate with `$` so only the compiled gradient call is timed.
@benchmark $∇g($targs...)

BenchmarkTools.Trial: 181 samples with 1 evaluation.
 Range (min … max):  25.377 ms …  33.153 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     27.563 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.610 ms ± 782.976 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                     ▁ ▁▁ ▂▄▆▄▂▂▃█ ▄▃▁ ▁▁
  ▃▁▁▁▁

In [5]:
# Extract the Reactant-backed tensors once; reused for compilation and benchmarking.
targs = Tuple(tensors(tn′))

# Compile a function returning both the value of `f` and its gradient in one pass.
re∇g = Reactant.compile(targs) do x...
    # Zero-initialised shadows: Enzyme accumulates ∂f/∂xᵢ into dx[i] in place.
    dx = Enzyme.make_zero.(x)
    # With `ReverseWithPrimal`, `autodiff` returns `(derivatives, primal)`. The
    # derivative slots of `Duplicated` arguments are `nothing` (the gradients live in
    # `dx`), so keep only the primal value.
    _, primal = Enzyme.autodiff(ReverseWithPrimal, f, Active, Duplicated.(x, dx)...)
    return (primal, dx)
end

# Interpolate with `$` so only the compiled call is timed.
@benchmark $re∇g($targs...)

BenchmarkTools.Trial: 197 samples with 1 evaluation.
 Range (min … max):  23.327 ms … 44.784 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     24.845 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.364 ms ±  2.029 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▄▆█▆▆▅▂▁▁      ▁
  ▅▁▁▁▁▆