In [1]:
using Metatheory
using Metatheory.EGraphs
using TermInterface

1


In [3]:
# :IndexAnalysis value of a literal e-node: a literal that is itself a `Set`
# is its own analysis value; any other literal yields an empty `Set()`.
function EGraphs.make(::Val{:IndexAnalysis}, g::EGraph, n::ENodeLiteral)
    literal = n.value
    return literal isa Set ? literal : Set()
end

# :IndexAnalysis value of a term e-node (an index set, or `nothing` when it is
# not known):
#   * `mapJoin(op, a, b)`    -> union of the analysis values of `a` and `b`
#   * `reduceDim(op, is, a)` -> analysis value of `a` minus that of `is`
#   * two-argument `:ref`    -> analysis value of the second argument
#   * anything else          -> `nothing`
# When a child e-class has no analysis value (`getdata` falls back to
# `nothing`), `nothing` is returned rather than letting `union`/`setdiff`
# throw while trying to iterate `nothing`.
function EGraphs.make(::Val{:IndexAnalysis}, g::EGraph, n::ENodeTerm)
    if exprhead(n) == :call
        op = operation(n)
        if op == :mapJoin || op == :reduceDim
            child_eclasses = arguments(n)
            # Both operators are used as `name(op, x, y)`; without the two
            # operand children there is nothing to analyse.
            length(child_eclasses) < 3 && return nothing

            ldata = getdata(g[child_eclasses[2]], :IndexAnalysis, nothing)
            rdata = getdata(g[child_eclasses[3]], :IndexAnalysis, nothing)
            # An unknown child makes the result unknown as well.
            (ldata === nothing || rdata === nothing) && return nothing

            # mapJoin keeps every index of both operands; reduceDim removes
            # the indices in its second argument from those of its third.
            return op == :mapJoin ? union(ldata, rdata) : setdiff(rdata, ldata)
        end
    elseif exprhead(n) == :ref && arity(n) == 2
        return getdata(g[arguments(n)[2]], :IndexAnalysis, nothing)
    end
    return nothing
end

# Declare the :IndexAnalysis analysis as non-lazy.
function EGraphs.islazy(::Val{:IndexAnalysis})
    return false
end

# Combine two :IndexAnalysis values: equal values are kept as they are,
# differing values collapse to `nothing`.
function EGraphs.join(::Val{:IndexAnalysis}, a, b)
    # Two different values are treated as contradictory, so the analysis
    # value is dropped instead of picking one of them.
    return a == b ? a : nothing
end

# Build the call expression `reduceDim(f, union(is, js), A)`; the union is
# computed here and embedded in the returned `Expr` as a value.
function reduceDimUnion(f, is, js, A)
    return Expr(:call, :reduceDim, f, union(is, js), A)
end

# Declare `reduceDim` and `mapJoin` as generic functions without any methods.
# NOTE(review): presumably only needed so the names resolve inside the rewrite
# rules and quoted expressions — confirm.
function reduceDim end
function mapJoin end

# Fallback for arguments that are not a `Set` and an `EClass`: disjointness
# cannot be established, so answer `false`.
doesntShareIndices(is, a) = false

# Return `true` when none of the indices in `is` occurs in the :IndexAnalysis
# value of e-class `a`. When `a` has no analysis value (`nothing`), return
# `false`: disjointness is not established, and `intersect` would otherwise
# throw while trying to iterate `nothing`.
function doesntShareIndices(is::Set, a::EClass)
    indices = getdata(a, :IndexAnalysis, nothing)
    indices === nothing && return false
    return isempty(intersect(is, indices))
end

# Rewrite theory over `mapJoin` / `reduceDim` expressions, with pattern
# variables a, b, c, f, is, js.
# NOTE(review): `is::Set` / `js::Set` presumably restrict those variables to
# `Set` literals — confirm against Metatheory's pattern syntax.
t = @theory a b c f is js begin
    # Fuse reductions: nested reductions with the same `f` become a single
    # reduction over the union of both index sets
    reduceDim(f, is::Set, reduceDim(f, js::Set, a)) => :(reduceDim($f, $(union(is, js)), $a))

    # Commutativity
    mapJoin(+, a, b) == mapJoin(+, b, a)
    mapJoin(*, a, b) == mapJoin(*, b, a)

    # Associativity
    mapJoin(+, a, mapJoin(+, b, c)) == mapJoin(+, mapJoin(+, a, b), c)
    mapJoin(*, a, mapJoin(*, b, c)) == mapJoin(*, mapJoin(*, a, b), c)
    
    # Distributivity
    mapJoin(*, a, mapJoin(+, b, c)) == mapJoin(+, mapJoin(*, a, b), mapJoin(*, a, c))

    # Reduction PushUp
    # NOTE(review): unlike the first PushDown rule below, nothing here checks
    # that `a` is free of the indices in `is` — confirm this is intended.
    mapJoin(*, a, reduceDim(+, is::Set, b)) --> reduceDim(+, is, mapJoin(*, a, b))

    # Reduction PushDown: applies only when `doesntShareIndices(is, a)` holds;
    # the reduction is then moved onto `b` alone
    reduceDim(+, is::Set, mapJoin(*, a, b)) => :(mapJoin($*, $a, reduceDim($+, $is, $b))) where doesntShareIndices(is, a)

    # Partial PushDown: indices of `is` found in the analysis value of `a`
    # stay in the outer reduction, the remaining ones are pushed onto `b`.
    # NOTE(review): `getdata(a, :IndexAnalysis, nothing)` is passed straight to
    # `intersect`/`setdiff` with no check for `nothing` — confirm `a` always
    # carries an analysis value here.
    reduceDim(+, is::Set, mapJoin(*, a, b)) => :(reduceDim($+,  $(intersect(is, getdata(a, :IndexAnalysis, nothing))), mapJoin($*, $a, reduceDim($+, $(setdiff(is, getdata(a, :IndexAnalysis, nothing))), $b))))

end

# Example: reduce over "i" and then over "k" the product of `a` (indexed by
# "i", "j") and `b` (indexed by "j", "k").
is = Set(["i"])
ks = Set(["k"])
a_is = Set(["i", "j"])
b_is = Set(["j", "k"])
# The index sets and the `+` / `*` functions are interpolated into the quoted
# expression as values; `a` and `b` stay symbols.
g= EGraph(:(reduceDim($+, $is, reduceDim($+, $ks, mapJoin($*, a[$a_is], b[$b_is])))))

# Run the index analysis, saturate with theory `t`, then print the expression
# extracted with the `astsize` cost function followed by the saturation report.
analyze!(g, :IndexAnalysis)
report = saturate!(g, t);
println(extract!(g, astsize))
print(report)
# Analysis value stored on the root e-class (shown as the cell result).
getdata(g[g.root], :IndexAnalysis)
# NOTE: Running this required a one-line change to Metatheory.jl: around line 200
# of saturation.jl, the `eclass = g[ecid]` assignment has to be moved above the
# if statement.

reduceDim(+, Set(["k", "i"]), mapJoin(*, b[Set(["j", "k"])], a[Set(["j", "i"])]))
SaturationReport
	Stop Reason: saturated
	Iterations: 7
	EGraph Size: 21 eclasses, 142 nodes
[0m[1m ────────────────────────────────────────────────────────────────────[22m
[0m[1m                   [22m         Time                    Allocations      
                   ───────────────────────   ────────────────────────
 Tot / % measured:      222ms /  52.4%           11.8MiB /  85.7%    

 Section   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────
 Apply          7    113ms   97.7%  16.2ms   8.30MiB   82.2%  1.19MiB
 Search         7   1.97ms    1.7%   281μs   1.59MiB   15.7%   232KiB
   1            7    315μs    0.3%  45.0μs    229KiB    2.2%  32.7KiB
   8            7    304μs    0.3%  43.4μs    285KiB    2.8%  40.8KiB
   9            7    286μs    0.2%  40.9μs    285KiB    2.8%  40.8KiB
   3            7    201μs  

Set{String} with 1 element:
  "j"

In [5]:
# Look up the :IndexAnalysis value stored on e-class 6; `nothing` is passed as
# the third argument (presumably the fallback when no value is stored).
getdata(g[6], :IndexAnalysis, nothing)

Set{String} with 2 elements:
  "j"
  "i"

In [10]:
# Display e-class 2 of the e-graph `g`.
g[2]

EClass 2 ([Set(["i"])], (IndexAnalysis = Base.RefValue{Any}(Set(["i"])), astsize = Base.RefValue{Any}((Set(["i"]), 1))))

In [24]:
# Display e-class 10 of the e-graph `g`.
g[10]

EClass 10 ([ENode(call, mapJoin, Expr, [3, 10, 6]), ENode(call, mapJoin, Expr, [3, 6, 10]), 0], (IndexAnalysis = Base.RefValue{Any}(nothing),))

In [73]:
# Small arithmetic theory; note that this rebinds the globals `t` and `g`
# used by the cells above.
t = @theory a b c begin
    # Commutativity of *
    a * b == b * a
    # 1 is the multiplicative identity
    a * 1 == a
    # Associativity of *
    a * (b * c) == (a * b) * c
    # Distributivity of * over +
    a * (b + c) == a * b + a * c
end

# Build an e-graph for (a * b) * (1 * (b + c)), saturate it with `t`, and
# extract an expression using the `astsize` cost function.
g = EGraph(:((a * b) * (1 * (b + c))));
report = saturate!(g, t);
extract!(g, astsize)

:((b + c) * (b * a))