konn/ad-delcont-primop

ad-delcont-primop

An attempt to implement reverse-mode automatic differentiation (AD) in terms of the delimited continuation (delcont) primops introduced in GHC 9.6.

That is, it reimplements ad-delcont, which translates the Scala implementation of Backpropagation with Continuation Callbacks, in terms of newPromptTag#, prompt#, and control0#.
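As a rough illustration of the idea, the sketch below wraps the three primops in IO and uses them for a tiny reverse-mode gradient. It is not the code of this repository: the names `PromptTag`, `prompt`, `control0`, `shift`, `D`, `addD`, `mulD`, `grad2`, and `example` are all hypothetical, and only `newPromptTag#`, `prompt#`, and `control0#` (exported from `GHC.Exts` in GHC 9.6 and later) come from GHC itself. Each operation captures the rest of the computation as a continuation, runs it as the forward pass, and then propagates the adjoint backwards once the continuation returns.

```haskell
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import GHC.Exts (PromptTag#, control0#, newPromptTag#, prompt#)
import GHC.IO (IO (..))

-- Boxed wrapper around the primitive prompt tag (hypothetical name).
data PromptTag a = PromptTag (PromptTag# a)

newPromptTag :: IO (PromptTag a)
newPromptTag = IO $ \s -> case newPromptTag# s of
  (# s', tag #) -> (# s', PromptTag tag #)

-- Delimit a computation with the given tag.
prompt :: PromptTag a -> IO a -> IO a
prompt (PromptTag tag) (IO m) = IO (prompt# tag m)

-- Capture the continuation up to the nearest prompt and remove that prompt.
control0 :: PromptTag a -> ((IO b -> IO a) -> IO a) -> IO b
control0 (PromptTag tag) f =
  IO $ control0# tag $ \k -> case f (\(IO a) -> IO (k a)) of IO b -> b

-- shift-style operator: reinstall the prompt around the body and around
-- every resumption of the captured continuation.
shift :: PromptTag a -> ((IO b -> IO a) -> IO a) -> IO b
shift tag f = control0 tag $ \k -> prompt tag (f (prompt tag . k))

-- A primal value paired with a mutable cell accumulating its adjoint.
data D = D Double (IORef Double)

addD, mulD :: PromptTag () -> D -> D -> IO D
addD tag (D x dx) (D y dy) = shift tag $ \k -> do
  dz <- newIORef 0
  k (pure (D (x + y) dz))   -- forward pass: run the rest of the program
  g <- readIORef dz          -- backward pass: adjoint of the result
  modifyIORef' dx (+ g)
  modifyIORef' dy (+ g)
mulD tag (D x dx) (D y dy) = shift tag $ \k -> do
  dz <- newIORef 0
  k (pure (D (x * y) dz))
  g <- readIORef dz
  modifyIORef' dx (+ g * y)
  modifyIORef' dy (+ g * x)

-- Gradient of a two-argument function built from addD and mulD.
grad2 :: (PromptTag () -> D -> D -> IO D) -> Double -> Double -> IO (Double, Double)
grad2 f x y = do
  tag <- newPromptTag
  dx <- newIORef 0
  dy <- newIORef 0
  prompt tag $ do
    D _ dz <- f tag (D x dx) (D y dy)
    modifyIORef' dz (+ 1)    -- seed the output adjoint with 1
  (,) <$> readIORef dx <*> readIORef dy

-- f(x, y) = (x + y) * y, differentiated at (2, 3).
example :: IO (Double, Double)
example = grad2 (\tag x y -> addD tag x y >>= \s -> mulD tag s y) 2 3
```

For $f(x, y) = (x + y) y$ the gradient is $(y, x + 2y)$, so `example` should evaluate to `(3.0, 8.0)`. The sketch still allocates one `IORef` per intermediate value, which is the kind of cost the TODO items below are concerned with.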

Performance

Summary

  • When computing multivariate gradients, our implementation is in most cases on par with or slightly faster than Edward Kmett's ad. In some cases, ours is 4x-10x faster.
  • To differentiate univariate functions, always use ad, as it uses forward mode.
  • In most cases, our implementation outperforms backprop and ad-delcont (the monad transformer-based implementation).
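To see why forward mode suits the univariate case, here is a minimal dual-number sketch. It only illustrates the technique and is not how ad is implemented; `Dual` and `diffDual` are hypothetical names. A single forward pass carries one tangent alongside the value, so one input needs no tape, mutable references, or continuations.

```haskell
-- A value together with its derivative with respect to the single input.
data Dual = Dual Double Double
  deriving (Show, Eq)

instance Num Dual where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (x * dy + dx * y)  -- product rule
  negate (Dual x dx) = Dual (negate x) (negate dx)
  abs (Dual x dx) = Dual (abs x) (dx * signum x)
  signum (Dual x _) = Dual (signum x) 0
  fromInteger n = Dual (fromInteger n) 0                  -- constants have zero tangent

-- Derivative of a univariate function: seed the input tangent with 1.
diffDual :: (Dual -> Dual) -> Double -> Double
diffDual f x = case f (Dual x 1) of
  Dual _ dx -> dx

-- The binomial benchmark function (x + 1)(x + 1), whose derivative is 2(x + 1).
binomial :: Dual -> Dual
binomial x = (x + 1) * (x + 1)
```

With these definitions, `diffDual binomial 2` evaluates to `6.0`. With n inputs, forward mode needs n such passes (or n tangents per value), which is where reverse mode starts to pay off.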

Legend

  • transformers: ad-delcont
  • ad: ad, generic functions from Numeric.AD
  • ad/double: ad, Double-specialised functions provided in Numeric.AD.Double.
  • backprop: backprop
  • primop: our generic implementation.
  • primop/double: our implementation specialised for Double.

Univariate Differentiation

Identity Function: $f(x) = x$

ad wins

Binomial: $f(x) = (x + 1)(x + 1)$

ad wins

Gauß-like: $f(x) = x e^{x^2 + 1}$

ad wins

Bivariate

Addition: $f(x, y) = x + y$

ours is about 10x faster!

Trigonometrics: $f(x,y) = \sin x \cos y (x^2 + y)$

we are 4x faster!

Exponentials: $f(x, y) = y e^{x^2 + y}$

still 4x faster!

Exponentials and Trigonometrics: $f(x, y) = (x \cos x + y)^2 e^{x \sin (x + y^2 + 1)}$

twice as fast

Complex formula

$$ f(x, y) = (\tanh (e^y \cosh x) + x ^ 2) ^ 3 - (x \cos x + y) ^ 2 e^{x \sin (x + y ^2 + 1)} $$

1.5x faster

Trivariate

Multiplication: $f(x,y,z) = xyz$

10x faster

Complex

$$ f(x,y,z) = (\tanh (e^{y + z ^ 2} \cosh x) + x ^ 2) ^ 3 - (x (z ^ 2 - 1) \cos x + y)^{2z} e^{x \sin (x + yzx + 1)} $$

1.5x faster

4-ary (quadrivariate)

Multiplication: $f(x,y,z,w) = xyzw$

10x faster

Trigonometrics: $f(x,y,z,w) = (x + w) ^ 4 \exp(x + \cos (y ^ 2 \sin z) w)$

thrice as fast

Some logarithm

$$ f(x,y,z,w) = \log (x ^ 2 + w) / \log (x + w) ^ 4 \exp (x + \cos (y ^ 2 \sin z) w) $$

twice as fast

Some more logarithm

$$ f(x,y,z,w) = \log_{x ^ 2 + w}(\cos (x ^ 2 + 2 z) + w + 1) ^ 4 \exp (x + \sin (\pi x) \cos ((e^y) ^ 2 \sin z) w) $$

slightly faster

Really complex

$$ f(x,y,z,w) = \log_{x ^ 2 + \tanh w} (\cos (x ^ 2 + 2z) + w + 1) ^ 4 + \exp (x + \sin (\pi x + w ^ 2) \cosh ((e^y)^ 2 \sin z) ^ 2 (w + 1)) $$

slightly faster

TODOs

  • :checkmark: Explore more fine-grained use of delcont
    • See Numeric.AD.DelCont.MultiPrompt for a proof of concept (PoC)
    • We can abolish refs except for the ones for the outermost primitive variables
      • perhaps a coroutine-like hack can eliminate these as well
    • This implementation, however, is not as time-efficient as the STRef-based one
      • This is because each continuation allocates distinct values rather than updating a single mutable variable
      • Still, in some cases this approach can slightly reduce allocation (needs confirmation)
      • In particular, as the number of variables increases, the time overhead seems to decay and allocation becomes slightly lower
  • Avoid (indirect) references at all costs!
  • Remove Refs from constants
    • This doubles both runtime and allocation (see the benchmark log)
    • The branching overhead outweighs the savings
