Skip to content

Commit

Permalink
add safe keyword, opposite of unsafe(left/right)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Dec 19, 2020
1 parent ebf5c70 commit 89b2318
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ function parse_options(exs...)
)
expr = nothing
nograd = Symbol[]
safe = Symbol[]
ranges = Tuple[]
for ex in exs
# Actual options:
Expand All @@ -160,6 +161,16 @@ function parse_options(exs...)
throw("this accepts nograd=A or nograd=(A,B,C)")
end

# Safe keyword
elseif isexpr(ex, :(=)) && ex.args[1] == :safe
if ex.args[2] isa Symbol
push!(safe, ex.args[2])
elseif isexpr(ex.args[2], :tuple)
append!(safe, ex.args[2].args)
else
throw("this accepts safe=i or safe=(i,j,k)")
end

# Ranges specified outside:
elseif isexpr(ex, :call) && ex.args[1] in [:in, :]
push!(ranges, (ex.args[2], ex.args[3]))
Expand Down Expand Up @@ -201,6 +212,7 @@ function parse_options(exs...)
cuda=opts[:cuda],
tensor=opts[:tensor],
nograd=nograd,
safe=safe,
), ranges, expr
end

Expand Down Expand Up @@ -586,7 +598,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex
MacroTools_postwalk(i) do x
@capture_(x, B_[inner__]) || return x
# Now we have found an array which indexes another one, mark its indices unsafe
append!(list, filter(j -> j isa Symbol, inner))
append!(list, setdiff(filter(j -> j isa Symbol, inner), store.safe))
unique!(list)
# and don't compute a gradient for the inner array
B isa Symbol && push!(store.nograd, B)
Expand Down

0 comments on commit 89b2318

Please sign in to comment.