Skip to content

Commit

Permalink
add reduce macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Jan 4, 2019
1 parent f8e509b commit 1e2920b
Showing 1 changed file with 102 additions and 7 deletions.
109 changes: 102 additions & 7 deletions src/TensorSlice.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
module TensorSlice

export @shape, @pretty
export @shape, @reduce, @pretty

using MacroTools

############################### THE MACRO ################################

V = false # capital V for extremely verbose
W = false # printouts where I'm working on things

Expand Down Expand Up @@ -40,8 +38,8 @@ macro shape(expr, rex=nothing)
sign = :(:=)
elseif @capture(expr, left_ == right_ )
sign = :(==)
# elseif @capture(expr, left_ |= right_ )
# sign = :(|=)
elseif @capture(expr, left_ |= right_ )
sign = :(|=)
else
throw(ArgumentError("@shape can't begin to understand $expr"))
end
Expand Down Expand Up @@ -91,6 +89,90 @@ macro shape(expr, rex=nothing)
throw(ArgumentError("@shape can't understand left hand side $left"))
end

tensor_slice_main(outn, oind, oi, # parsing of LHS in progress
willslice, willstaticslice, # flags from LHS of @shape
false, nothing, # info from LHS of @reduce
sign, right, rex, # RHS still to be done
usingstatic, usingstrided) # what's loaded
end

"""
@reduce A[j] := sum(i,k) B[i,j,k]
Tensor reduction macro.
* The right-hand-side can be anything that `@shape` understands, including gluing of slices, and reshaping.
* The reduction funcition can be anything which works like `sum(B, dims=(1,3))`.
* The left-hand side again works as in `@shape`, although slicing is not allowed.
* Only creation `:=` is currently working (although in-place `=` will eventually).
"""
macro reduce(expr, right, rex=nothing)

if @capture(expr, left_ = red_ )
sign = :(=)
error("@reduce doesn't work for in-place operations yet, try :=")
elseif @capture(expr, left_ := red_ )
sign = :(:=)
else
throw(ArgumentError("@reduce can't begin to understand $expr"))
end

V && println("sign = ", sign)#, " of type ", typeof(sign))
V && println("left = ", left)
V && println("right = ", red)
V && println("right = ", right)
V && println("rex = ", something(rex,"nothing"))

## test what packages are loaded
usingstrided = isdefined(TensorSlice, :Strided)
usingstatic = isdefined(TensorSlice, :StaticArrays)

if rex != nothing
@capture(rex, (rvec__,)) || (rvec = Any[rex])
for r in rvec
if r == :(_) # this flag means use base only
usingstrided = false
usingstatic = false
end
end
end

Vvec = [usingstatic ? "static" : missing, usingstrided ? "strided" : missing]
V && println("using = ", join(skipmissing(Vvec), " & "))

#==================== LEFT = OUTPUT = STEPS 4,5 ====================#
# Do this first to get a list of indices in canonical order, oflat

if @capture(left, ( outn_[oind__] | [oind__] ) )
else
throw(ArgumentError("@reduce can't understand left hand side $left"))
end

if @capture(red, redfun_(oi__) )
V && println("will reduce using redfun = ", redfun, ", inner indices = ", oi)
else
throw(ArgumentError("@reduce can't understand reduction formula $red"))
end

tensor_slice_main(outn, oind, oi, # parsing of LHS in progress
false, false, # flags from LHS of @shape
true, redfun, # info from LHS of @reduce
sign, right, rex, # RHS still to be done
usingstatic, usingstrided) # what's loaded
end

############################### GIANT MONO-FUNCTION ################################

function tensor_slice_main(outn, oind, oi,
willslice, willstaticslice, # flags from LHS of @shape
willreduce, redfun, # info from LHS of @reduce
sign, right, rex, # RHS still to be done
usingstatic, usingstrided) # what's loaded

if willreduce && willslice
error("impossible!")
end

#==================== LEFT = OUTPUT, CONTINUED ====================#

if outn == nothing
hasoutputname = false
outn = gensym()
Expand Down Expand Up @@ -137,7 +219,14 @@ macro shape(expr, rex=nothing)
V && println("will reshape output, osz = ",osz)
end

if length(oi)==1
odims = 1 # prettier output, but may make permutedims cleverness harder
else
odims = Tuple(1:length(oi))
end

V && (willslice || willstaticslice) && println("ocode = ", ocode, " for slicing")
V && willreduce && println("odims = ", odims, " for reduction over oi = ",oi)
V && println("oind = ", oind) # oind lists indices whose lengths are some size(A,d)
V && println("oflat = ", oflat, ", this is the canonical order")

Expand Down Expand Up @@ -389,7 +478,7 @@ macro shape(expr, rex=nothing)

needcopy = !havecopied && sign == :(:=)

## 4. SLICE
## 4. SLICE, OR REDUCE
if willstaticslice
if needcopy
ex = :( copy(ex) )
Expand All @@ -409,7 +498,13 @@ macro shape(expr, rex=nothing)
else
ex = :( sliceview($ex, $ocode) )
end
V && println("step 4: ex = ",ex)
V && println("step 4 slice: ex = ",ex)

elseif willreduce
redfun = esc(redfun)
ex = :( dropdims( $redfun($ex, dims = $odims), dims = $odims) )
havecopied = true
V && println("step 4 reduction: ex = ",ex)
end

## 5. RESHAPE
Expand Down

0 comments on commit 1e2920b

Please sign in to comment.