@rewrite for arbitrary operation/function #253

Open
mattsignorelli opened this issue Dec 26, 2023 · 7 comments

mattsignorelli commented Dec 26, 2023

Does MutableArithmetics not work for general functions, only for the standard arithmetic operations (+,-,/,*)? I have a mutable type that overrides many of the Base functions (sin, cos, abs, sqrt, csc) as well as some others not in base (sinhc). I'd really like to use this interface and the @rewrite macro to speed up evaluation of expressions, but I've found it doesn't work for these functions.

My code for sin for example is

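import MutableArithmetics: mutability, IsMutable, promote_operation, operate!, operate_to!

mutability(::Type{TPS}) = IsMutable()

# `sin` is unary, so `promote_operation` takes a single argument type.
function promote_operation(::typeof(sin), ::Type{TPS})
  return TPS
end

# In-place version: overwrite `a` with sin(a).
function operate!(::typeof(sin), a::TPS)
  mad_tpsa_sin!(a.tpsa, a.tpsa)
  return a
end

# Write sin(a) into the preallocated `output`.
function operate_to!(output::TPS, ::typeof(sin), a::TPS)
  mad_tpsa_sin!(a.tpsa, output.tpsa)
  return output
end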

mattsignorelli commented Dec 26, 2023

Specifically, there is neither a speedup nor a reduction in memory allocation when evaluating, say,

t = @rewrite sin(x*sin(y)) + sin(z)


odow commented Dec 27, 2023

> Does MutableArithmetics not work for general functions, only for the standard arithmetic operations (+,-,/,*)?

Correct. The @rewrite macro works only for a limited subset of the standard arithmetic operations.
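
To make the limitation concrete, here is a minimal sketch using `BigInt`, a mutable type that MutableArithmetics already supports. It assumes the package is installed and that its version includes the `@rewrite` definition quoted in this thread; `move_factors_into_sums = false` is passed with a comma, as that docstring requires. The values and variable names are illustrative only.

```julia
import MutableArithmetics as MA

a, b, c = big(2), big(3), big(4)

# Arithmetic on `+` and `*` is something the macro knows how to rewrite.
t = MA.@rewrite(a * b + c, move_factors_into_sums = false)

# `abs` is not one of the rewritten operators. Per the `_rewrite_generic` code
# quoted in this thread, such calls are emitted as ordinary non-mutating calls,
# so the result is still correct but the call itself gains nothing.
u = MA.@rewrite(abs(a * b) + c, move_factors_into_sums = false)

println((t, u))
```

Both calls return the same value as the plain expressions; the difference is only in which intermediate results are allowed to be overwritten.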

@mattsignorelli

> Correct. The @rewrite macro works only for a limited subset of the standard arithmetic operations.

Would it be a lot of work to generalize this macro to any overloaded function? Other parts of the interface appear to handle arbitrary functions (abs is implemented for BigInt, for example), so there could be substantial performance benefits for expressions with many calls to sin, cos, sqrt, log, etc.

odow commented Dec 27, 2023

The macro is defined here (x-ref #254):

"""
    @rewrite(expr, move_factors_into_sums = false)

Return the value of `expr`, exploiting the mutability of the temporary
expressions created for the computation of the result.

If you have an `Expr` as input, use [`rewrite_and_return`](@ref) instead.

See [`rewrite`](@ref) for an explanation of the keyword argument.

!!! info
    Passing `move_factors_into_sums` after a `;` is not supported. Use a `,`
    instead.
"""
macro rewrite(args...)
    @assert 1 <= length(args) <= 2
    if length(args) == 1
        return rewrite_and_return(args[1]; move_factors_into_sums = true)
    end
    @assert Meta.isexpr(args[2], :(=), 2) &&
            args[2].args[1] == :move_factors_into_sums
    return rewrite_and_return(args[1]; move_factors_into_sums = args[2].args[2])
end

It isn't very complicated.

Here's the actual rewrites:

"""
    _rewrite_generic(stack::Expr, expr::Expr)

This method is the heart of the rewrite logic. It converts `expr` into a mutable
equivalent.
"""
function _rewrite_generic(stack::Expr, expr::Expr)
    if !Meta.isexpr(expr, :call)
        # In situations like `x[i]`, we do not attempt to rewrite. Return `expr`
        # and don't let future callers mutate.
        return esc(expr), false
    elseif Meta.isexpr(expr, :call, 1)
        # A zero-argument function
        return esc(expr), false
    elseif Meta.isexpr(expr.args[2], :(...))
        # If the first argument is a splat.
        return esc(expr), false
    elseif _is_generator(expr) || _is_flatten(expr) || _is_parameters(expr)
        if !(expr.args[1] in (:sum, :Σ, :∑))
            # We don't know what this is. Return the expression and don't let
            # future callers mutate.
            return esc(expr), false
        end
        # This is a generator expression like `sum(i for i in args)`. Generators
        # come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`.
        # The latter is a `:flatten` expression and needs additional handling,
        # but we delay this complexity for _rewrite_generic_generator.
        if Meta.isexpr(expr.args[2], :parameters)
            # The summation has keyword arguments. We can deal with `init`, but
            # not any of the others.
            p = expr.args[2]
            if length(p.args) == 1 && _is_kwarg(p.args[1], :init)
                # sum(iter ; init) form!
                root = gensym()
                init, _ = _rewrite_generic(stack, p.args[1].args[2])
                push!(stack.args, :($root = $init))
                return _rewrite_generic_generator(stack, :+, expr.args[3], root)
            else
                # We don't know how to deal with this
                return esc(expr), false
            end
        else
            # Summations use :+ as the reduction operator.
            init_expr = expr.args[2].args[end]
            if Meta.isexpr(init_expr, :(=)) && init_expr.args[1] == :init
                # sum(iter, init) form!
                root = gensym()
                init, _ = _rewrite_generic(stack, init_expr.args[2])
                push!(stack.args, :($root = $init))
                new_expr = copy(expr.args[2])
                pop!(new_expr.args)
                return _rewrite_generic_generator(stack, :+, new_expr, root)
            elseif Meta.isexpr(expr.args[2], :flatten)
                # sum(iter for iter, init) form!
                first_generator = expr.args[2].args[1].args[1]
                init_expr = first_generator.args[end]
                if Meta.isexpr(init_expr, :(=)) && init_expr.args[1] == :init
                    root = gensym()
                    init, _ = _rewrite_generic(stack, init_expr.args[2])
                    push!(stack.args, :($root = $init))
                    new_expr = copy(expr.args[2])
                    pop!(new_expr.args[1].args[1].args)
                    return _rewrite_generic_generator(stack, :+, new_expr, root)
                end
            end
            return _rewrite_generic_generator(stack, :+, expr.args[2])
        end
    end
    # At this point, we have an expression like `op(args...)`. We can either
    # choose to convert the operation to its mutable equivalent, or return the
    # non-mutating operation.
    if expr.args[1] == :+
        # +(args...) => add_mul(add_mul(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # +(arg)
            return _rewrite_generic(stack, expr.args[2])
        elseif length(expr.args) == 3 && _is_call(expr.args[3], :*)
            # +(x, *(y...)) => add_mul(x, y...)
            x, is_mutable = _rewrite_generic(stack, expr.args[2])
            rhs = if is_mutable
                Expr(:call, operate!!, add_mul, x)
            else
                Expr(:call, operate, add_mul, x)
            end
            for i in 2:length(expr.args[3].args)
                yi, _ = _rewrite_generic(stack, expr.args[3].args[i])
                push!(rhs.args, yi)
            end
            root = gensym()
            push!(stack.args, :($root = $rhs))
            return root, true
        end
        return _rewrite_generic_to_nested_op(stack, expr, add_mul)
    elseif expr.args[1] == :-
        # -(args...) => sub_mul(sub_mul(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # -(arg)
            return _rewrite_generic(stack, Expr(:call, :*, -1, expr.args[2]))
        end
        return _rewrite_generic_to_nested_op(stack, expr, sub_mul)
    elseif expr.args[1] == :*
        # *(args...) => *(*(arg1, arg2), arg3)
        @assert length(expr.args) > 2
        arg1, is_mutable = _rewrite_generic(stack, expr.args[2])
        arg2, _ = _rewrite_generic(stack, expr.args[3])
        rhs = if is_mutable
            Expr(:call, operate!!, *, arg1, arg2)
        else
            Expr(:call, operate, *, arg1, arg2)
        end
        root = gensym()
        push!(stack.args, :($root = $rhs))
        for i in 4:length(expr.args)
            arg, _ = _rewrite_generic(stack, expr.args[i])
            rhs = if is_mutable
                Expr(:call, operate!!, *, root, arg)
            else
                Expr(:call, operate, *, root, arg)
            end
            root = gensym()
            push!(stack.args, :($root = $rhs))
        end
        return root, is_mutable
    elseif expr.args[1] == :.+
        # .+(args...) => add_mul.(add_mul.(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # .+(arg)
            return _rewrite_generic(stack, expr.args[2])
        end
        return _rewrite_generic_to_nested_op(
            stack,
            expr,
            add_mul;
            broadcast = true,
        )
    elseif expr.args[1] == :.-
        # .-(args...) => sub_mul.(sub_mul.(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # .-(arg)
            return _rewrite_generic(stack, Expr(:call, :.*, -1, expr.args[2]))
        end
        return _rewrite_generic_to_nested_op(
            stack,
            expr,
            sub_mul;
            broadcast = true,
        )
    else
        # Use the non-mutating call.
        result = Expr(:call, esc(expr.args[1]))
        for i in 2:length(expr.args)
            arg, _ = _rewrite_generic(stack, expr.args[i])
            push!(result.args, arg)
        end
        root = gensym()
        push!(stack.args, Expr(:(=), root, result))
        # This value isn't safe to mutate, because it might be a reference to
        # another object.
        return root, false
    end
end

I don't know if we want to add a generalized rewrite. The main purpose of MutableArithmetics is for JuMP. @blegat is the one who would need to decide.
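
For readers wondering what a generalized rewrite could look like, the following is a self-contained toy model, not MutableArithmetics code and not a proposed patch. It borrows only the idea visible in `_rewrite_generic` above: each rewritten subexpression is returned together with a flag saying whether it is a temporary that may safely be overwritten. The names `toy_rewrite`, `TOY_UNARY`, `apply`, and `apply!` are hypothetical, and plain `Vector{Float64}` stands in for a mutable user type.

```julia
# Hypothetical helpers: `apply` allocates a result, `apply!` overwrites its argument.
apply(f, v::Vector{Float64}) = f.(v)
apply!(f, v::Vector{Float64}) = (v .= f.(v); v)

# Functions this toy rewriter is allowed to turn into in-place calls.
const TOY_UNARY = (:sin, :cos, :sqrt)

# Returns `(new_expr, is_temporary)`. Only temporaries created by the rewritten
# expression itself may be overwritten; user variables never are.
function toy_rewrite(ex)
    if !(ex isa Expr && ex.head == :call)
        return ex, false  # symbols and literals are left untouched
    end
    f, args = ex.args[1], ex.args[2:end]
    if f in TOY_UNARY && length(args) == 1
        inner, is_temp = toy_rewrite(args[1])
        # Reuse the argument's storage only when it is a temporary.
        helper = is_temp ? :apply! : :apply
        return Expr(:call, helper, f, inner), true
    end
    # Anything else stays a plain call whose result is treated as unsafe to mutate.
    new_args = [first(toy_rewrite(a)) for a in args]
    return Expr(:call, f, new_args...), false
end

code, _ = toy_rewrite(:(sin(cos(x)) + sin(y)))
println(code)
```

Here the outer `sin` reuses the vector allocated by `cos(x)`, while `sin(y)` must allocate because `y` belongs to the caller. A real implementation would also have to respect the concern in the fallback branch above: a function's return value may alias another object, so it cannot be assumed to be a fresh temporary in general.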

@mattsignorelli

Thanks. Worst case, I suppose I could import and modify it. The use case is GTPSA.jl, a package wrapping a C library for manipulating truncated power series. Each mutable struct is a truncated power series, and all the Base math functions are overloaded. Currently, every intermediate value in an expression allocates a new struct. Preliminary testing with a preallocated temporary buffer shows quite a big speedup, so a macro or some other method that could do this automatically would be very advantageous.
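
The buffer-reuse pattern described here can be sketched in a few lines. This is an illustration under stated assumptions, not GTPSA.jl code: `ToySeries`, `apply_to!`, `nested_alloc`, and `nested_buffered!` are hypothetical names, and the "kernel" just applies the function elementwise to the coefficients, which is not how a real power-series `sin` is computed.

```julia
# Hypothetical stand-in for a mutable truncated power series.
mutable struct ToySeries
    coeffs::Vector{Float64}
end

# In-place kernel: write `f` applied to `a` into `out`. `out` may alias `a`.
function apply_to!(out::ToySeries, f, a::ToySeries)
    out.coeffs .= f.(a.coeffs)
    return out
end

# Allocating version: each intermediate result gets a new ToySeries.
function nested_alloc(x::ToySeries)
    t1 = apply_to!(ToySeries(similar(x.coeffs)), cos, x)
    return apply_to!(ToySeries(similar(x.coeffs)), sin, t1)
end

# Buffered version: one preallocated temporary is reused for both steps.
function nested_buffered!(tmp::ToySeries, x::ToySeries)
    apply_to!(tmp, cos, x)           # tmp = cos(x)
    return apply_to!(tmp, sin, tmp)  # tmp = sin(tmp), overwriting in place
end

x = ToySeries([0.0, 0.5])
tmp = ToySeries(zeros(2))
r = nested_buffered!(tmp, x)
println(r.coeffs)
```

The buffered version returns `tmp` itself and leaves `x` untouched, which is the property a rewriting macro would need to preserve when it decides which values are safe to overwrite.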

blegat commented Dec 28, 2023

I'm not opposed to adding support for these unary functions. PRs are welcome.

odow commented Dec 28, 2023

We'd just need to be very careful to ensure that the new rewrite doesn't break JuMP's nonlinear code.
