Dual numbers faster than reals?? #57

Closed
KristofferC opened this issue Sep 24, 2015 · 6 comments

Comments

@KristofferC
Collaborator

This is not strictly related to this package, but I thought maybe one of you could explain it; otherwise, just close this. I was playing around a bit with the short implementation of dual numbers I found on @mlubin's GitHub page and wanted to benchmark it to see the performance differences. The results suggest that calling the function with dual numbers is faster than calling it with normal reals (??).

First, here is the short implementation

importall Base
immutable Dual{T} <: Number
    re::T
    ɛ::T
end
real(z::Dual) = z.re
dual(z::Dual) = z.ɛ;


(+)(x::Dual,y::Dual) = Dual(real(x)+real(y), dual(x)+dual(y))
(-)(x::Dual,y::Dual) = Dual(real(x)-real(y), dual(x)-dual(y))
(*)(x::Dual,y::Dual) = Dual(real(x)*real(y), real(x)*dual(y)+real(y)*dual(x))
(/)(x::Dual,y::Dual) = Dual(real(x)/real(y), (dual(x)*real(y)-real(x)*dual(y))/(real(y)*real(y)))
exp(x::Dual) = Dual(exp(real(x)), dual(x)*exp(real(x)))
sin(x::Dual) = Dual(sin(real(x)), dual(x)*cos(real(x)))
cos(x::Dual) = Dual(cos(real(x)), -dual(x)*sin(real(x)));

promote_rule{S<:Real,T<:Real}(::Type{Dual{S}},::Type{T}) = Dual{promote_type(T,S)};
convert{T<:Real}(::Type{Dual{T}}, x::Real) = Dual(convert(T,x), zero(T));

const ɛ = Dual(0.0, 1.0);
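The snippet above uses Julia 0.4 syntax (`importall`, `immutable`, and `f{T}(...)` method parameters), none of which parses on Julia 1.x. A port of the same implementation to current syntax might look like the following. This is a sketch, not code from the issue; the names mirror the original, and Base functions are extended with qualified `Base.` definitions instead of `importall`.

```julia
# Julia 1.x port of the Dual type from the issue (a sketch, not the original code).
struct Dual{T<:Real} <: Number
    re::T   # real (value) part
    ɛ::T    # dual part: coefficient of ɛ, where ɛ^2 == 0
end

Base.real(z::Dual) = z.re
dual(z::Dual) = z.ɛ

# Arithmetic follows from expanding (a + bɛ) op (c + dɛ) and dropping ɛ^2 terms.
Base.:+(x::Dual, y::Dual) = Dual(real(x) + real(y), dual(x) + dual(y))
Base.:-(x::Dual, y::Dual) = Dual(real(x) - real(y), dual(x) - dual(y))
Base.:*(x::Dual, y::Dual) = Dual(real(x) * real(y), real(x) * dual(y) + real(y) * dual(x))
Base.:/(x::Dual, y::Dual) =
    Dual(real(x) / real(y), (dual(x) * real(y) - real(x) * dual(y)) / (real(y) * real(y)))

# Elementary functions apply the chain rule to the dual part.
Base.exp(x::Dual) = Dual(exp(real(x)), dual(x) * exp(real(x)))
Base.sin(x::Dual) = Dual(sin(real(x)), dual(x) * cos(real(x)))
Base.cos(x::Dual) = Dual(cos(real(x)), -dual(x) * sin(real(x)))

# Mixed Real/Dual arithmetic: promote the Real to a Dual with zero dual part.
Base.promote_rule(::Type{Dual{S}}, ::Type{T}) where {S<:Real,T<:Real} = Dual{promote_type(T, S)}
Base.convert(::Type{Dual{T}}, x::Real) where {T<:Real} = Dual(convert(T, x), zero(T))

const ɛ = Dual(0.0, 1.0)
```

With this in place, `sin(0.3 + ɛ)` returns a `Dual` whose real part is `sin(0.3)` and whose dual part is `cos(0.3)`, i.e. the value and the derivative in one call.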

This is the trial function I used

f(x) = exp(x) / (cos(x)^3 + sin(x)^3)
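As a sanity check on what the dual-number call should return, the derivative of the trial function can be worked out by hand. At x = π/4 the sine and cosine coincide, so the derivative of the denominator vanishes and the value and derivative of f are equal:

```math
\begin{aligned}
f(x) &= \frac{e^{x}}{g(x)}, \qquad g(x) = \cos^{3}x + \sin^{3}x,\\
f'(x) &= \frac{e^{x}\,g(x) - e^{x}\,g'(x)}{g(x)^{2}}, \qquad g'(x) = 3\sin^{2}x\cos x - 3\cos^{2}x\sin x,\\
g\!\left(\tfrac{\pi}{4}\right) &= 2\left(\tfrac{\sqrt{2}}{2}\right)^{3} = \tfrac{\sqrt{2}}{2}, \qquad g'\!\left(\tfrac{\pi}{4}\right) = 0,\\
f\!\left(\tfrac{\pi}{4}\right) &= f'\!\left(\tfrac{\pi}{4}\right) = \sqrt{2}\,e^{\pi/4} \approx 3.102.
\end{aligned}
```

So both components of the `Dual` returned for the input π/4 + ɛ should be approximately 3.102, up to floating-point rounding.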

And the benchmarking:

Pkg.clone("https://github.com/johnmyleswhite/Benchmarks.jl")
using Benchmarks

@benchmark f(π/4 + ɛ)

@benchmark f(π/4 + im)

@benchmark f(π/4)

which gives

julia> @benchmark f(π/4 + ɛ)
================ Benchmark Results ========================
     Time per evaluation: 104.91 ns [104.65 ns, 105.17 ns]
   Number of evaluations: 3726101
 Time spent benchmarking: 0.52 s


julia> @benchmark f(π/4 + im)
================ Benchmark Results ========================
     Time per evaluation: 315.03 ns [314.58 ns, 315.48 ns]
   Number of evaluations: 1437201
 Time spent benchmarking: 0.54 s


julia> @benchmark f(π/4)
================ Benchmark Results ========================
     Time per evaluation: 198.68 ns [196.78 ns, 200.58 ns]
   Number of evaluations: 2314101
 Time spent benchmarking: 0.50 s

Am I making some obvious mistake here? It seems that the dual numbers are significantly faster than the real numbers, and that can't be possible...

@StefanKarpinski

Here's the LLVM code I'm seeing:

julia> @code_llvm f(π/4)

define double @julia_f_22987(double) {
top:
  %1 = call double inttoptr (i64 13147291680 to double (double)*)(double %0)
  %2 = call double inttoptr (i64 13147331552 to double (double)*)(double %0)
  %3 = call double inttoptr (i64 13147347792 to double (double)*)(double %0)
  %4 = fcmp ord double %2, 0.000000e+00
  %5 = fcmp uno double %0, 0.000000e+00
  %6 = or i1 %4, %5
  br i1 %6, label %pass, label %fail

fail:                                             ; preds = %top
  %7 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %7, i32 1)
  unreachable

pass:                                             ; preds = %top
  %8 = call double @pow(double %2, double 3.000000e+00)
  %9 = fcmp ord double %3, 0.000000e+00
  %10 = or i1 %9, %5
  br i1 %10, label %pass4, label %fail3

fail3:                                            ; preds = %pass
  %11 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %11, i32 1)
  unreachable

pass4:                                            ; preds = %pass
  %12 = call double @pow(double %3, double 3.000000e+00)
  %13 = fadd double %8, %12
  %14 = fdiv double %1, %13
  ret double %14
}

julia> @code_llvm f(π/4 + ɛ)

define void @julia_f_22763(%Dual* sret, %Dual*) {
top:
  %2 = bitcast %Dual* %1 to double*
  %3 = load double* %2, align 8
  %4 = call double inttoptr (i64 13147291680 to double (double)*)(double %3)
  %5 = load double* %2, align 8
  %6 = call double inttoptr (i64 13147291680 to double (double)*)(double %5)
  %7 = load double* %2, align 8
  %8 = call double inttoptr (i64 13147331552 to double (double)*)(double %7)
  %9 = load double* %2, align 8
  %10 = call double inttoptr (i64 13147347792 to double (double)*)(double %9)
  %11 = fcmp ord double %8, 0.000000e+00
  %12 = fcmp uno double %7, 0.000000e+00
  %13 = or i1 %11, %12
  br i1 %13, label %pass, label %fail

fail:                                             ; preds = %top
  %14 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %14, i32 1)
  unreachable

pass:                                             ; preds = %top
  %15 = fcmp ord double %10, 0.000000e+00
  %16 = fcmp uno double %9, 0.000000e+00
  %17 = or i1 %15, %16
  br i1 %17, label %pass2, label %fail1

fail1:                                            ; preds = %pass
  %18 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %18, i32 1)
  unreachable

pass2:                                            ; preds = %pass
  %19 = bitcast %Dual* %1 to double*
  %20 = alloca %Dual, align 8
  %21 = alloca %Dual, align 8
  %22 = getelementptr inbounds %Dual* %1, i64 0, i32 1
  %23 = load double* %22, align 8
  %24 = fmul double %23, -1.000000e+00
  %25 = insertvalue %Dual undef, double %8, 0
  %26 = fmul double %10, %24
  %27 = insertvalue %Dual %25, double %26, 1
  store %Dual %27, %Dual* %21, align 8
  call void @julia_power_by_squaring_22764(%Dual* sret %20, %Dual* %21, i64 3)
  %28 = load %Dual* %20, align 8
  %29 = load double* %19, align 8
  %30 = call double inttoptr (i64 13147347792 to double (double)*)(double %29)
  %31 = load double* %19, align 8
  %32 = call double inttoptr (i64 13147331552 to double (double)*)(double %31)
  %33 = fcmp ord double %30, 0.000000e+00
  %34 = fcmp uno double %29, 0.000000e+00
  %35 = or i1 %33, %34
  br i1 %35, label %pass4, label %fail3

fail3:                                            ; preds = %pass2
  %36 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %36, i32 1)
  unreachable

pass4:                                            ; preds = %pass2
  %37 = fcmp ord double %32, 0.000000e+00
  %38 = fcmp uno double %31, 0.000000e+00
  %39 = or i1 %37, %38
  br i1 %39, label %pass6, label %fail5

fail5:                                            ; preds = %pass4
  %40 = load %jl_value_t** @jl_domain_exception, align 8
  call void @jl_throw_with_superfluous_argument(%jl_value_t* %40, i32 1)
  unreachable

pass6:                                            ; preds = %pass4
  %41 = alloca %Dual, align 8
  %42 = alloca %Dual, align 8
  %43 = extractvalue %Dual %28, 0
  %44 = extractvalue %Dual %28, 1
  %sunkaddr = ptrtoint %Dual* %1 to i64
  %sunkaddr13 = add i64 %sunkaddr, 8
  %sunkaddr14 = inttoptr i64 %sunkaddr13 to double*
  %45 = load double* %sunkaddr14, align 8
  %46 = insertvalue %Dual undef, double %30, 0
  %47 = fmul double %32, %45
  %48 = insertvalue %Dual %46, double %47, 1
  store %Dual %48, %Dual* %42, align 8
  call void @julia_power_by_squaring_22764(%Dual* sret %41, %Dual* %42, i64 3)
  %49 = load %Dual* %41, align 8
  %50 = extractvalue %Dual %49, 0
  %51 = extractvalue %Dual %49, 1
  %52 = load double* %sunkaddr14, align 8
  %53 = fmul double %6, %52
  %54 = fadd double %43, %50
  %55 = fadd double %44, %51
  %56 = fdiv double %4, %54
  %57 = insertvalue %Dual undef, double %56, 0
  %58 = fmul double %53, %54
  %59 = fmul double %4, %55
  %60 = fsub double %58, %59
  %61 = fmul double %54, %54
  %62 = fdiv double %60, %61
  %63 = insertvalue %Dual %57, double %62, 1
  store %Dual %63, %Dual* %0, align 8
  ret void
}

It seems that the real version is calling the pow intrinsic while the Dual version is calling power_by_squaring. They're getting the same answer, which is reassuring, but it's a bit crazy that doing power by squaring could be faster than calling pow.
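Power by squaring computes `x^p` using nothing but `*`, which is why it works for the `Dual` type even though no `^` method was defined for it. A minimal sketch of the technique follows; the name `pow_by_squaring` is hypothetical, and this is not Base's internal `power_by_squaring`, whose details may differ. For `p = 3` it performs two multiplications.

```julia
# Exponentiation by squaring for any type that supports `*` (hypothetical helper).
function pow_by_squaring(x, p::Integer)
    p < 0 && throw(DomainError(p, "this sketch only handles p >= 0"))
    p == 0 && return one(x)
    # Square x past the trailing zero bits of p, up to its lowest set bit.
    while iseven(p)
        x = x * x
        p >>= 1
    end
    result = x              # contribution of the lowest set bit
    p >>= 1
    while p > 0
        x = x * x           # x now holds the original base raised to a power of two
        if isodd(p)
            result = result * x
        end
        p >>= 1
    end
    return result
end
```

Because it only relies on `*`, the same function works on numbers and on matrices, e.g. `pow_by_squaring([1 1; 1 0], 10)` yields a matrix of Fibonacci numbers.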

@StefanKarpinski

Also, installing all of that code was insanely easy. Great work packaging, everyone!

@KristofferC
Collaborator Author

Ah, that's the difference, of course: I haven't implemented ^ for Dual numbers, so it falls back to power_by_squaring.
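A dedicated `^` method would skip the generic fallback by applying the power rule directly to the dual part. A minimal sketch in Julia 1.x syntax, repeating a bare `Dual` definition so the snippet stands alone; it covers integer exponents only and is an illustration, not code from this thread:

```julia
struct Dual{T<:Real} <: Number
    re::T   # value
    ɛ::T    # derivative part
end

# Power rule on the dual part: (a + bɛ)^n = a^n + n*a^(n-1)*b*ɛ.
# n == 0 is special-cased so that 0 * a^(-1) cannot produce NaN when a == 0.
function Base.:^(x::Dual, n::Integer)
    n == 0 && return Dual(one(x.re), zero(x.ɛ))
    return Dual(x.re^n, n * x.re^(n - 1) * x.ɛ)
end
```

For example, `Dual(2.0, 1.0)^3` gives a value of 8.0 and a dual part of 3 · 2.0² = 12.0, using one scalar power per component instead of repeated `Dual` multiplications.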

@KristofferC
Collaborator Author

Writing out the powers as explicit multiplication gives:

f(x) = exp(x) / (cos(x)*cos(x)*cos(x) + sin(x)*sin(x)*sin(x))

julia> @benchmark f(π/4 + ɛ)
================ Benchmark Results ========================
     Time per evaluation: 90.01 ns [88.58 ns, 91.44 ns]


julia> @benchmark f(π/4 + im)
================ Benchmark Results ========================
     Time per evaluation: 604.65 ns [602.47 ns, 606.82 ns]


julia> @benchmark f(π/4)
================ Benchmark Results ========================
     Time per evaluation: 81.51 ns [81.30 ns, 81.73 ns]

where reals are a bit faster than duals and complex numbers are much(!) slower.

@KristofferC
Collaborator Author

Okay, I think this is the reason:

julia> @benchmark sin(π/4)^3
================ Benchmark Results ========================
     Time per evaluation: 87.58 ns [86.45 ns, 88.72 ns]


julia> @benchmark sin(π/4)^3.0
================ Benchmark Results ========================
     Time per evaluation: 22.57 ns [22.27 ns, 22.87 ns]

Changing the exponents in f to floats like:

f(x) = exp(x) / (cos(x)^3.0 + sin(x)^3.0)

makes the float version ~4 times faster than dual.

You heard it here first folks, start writing your exponents as floats!
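On current Julia (1.x) the `^3` vs `^3.0` comparison may not come out the same way. According to the docstring of `^`, a literal integer exponent such as `x^3` is transformed by the compiler into `Base.literal_pow(^, x, Val(3))`, which lets methods specialize on the exponent. A quick way to see that transformation; the printed form differs between versions, and no timings are implied here:

```julia
# Lower (but do not evaluate) a literal integer power; `x` need not be defined.
lowered = Meta.@lower x^3
println(lowered)    # the lowered code should contain a call to Base.literal_pow

# The same entry point can be called directly with the exponent wrapped in Val.
Base.literal_pow(^, 1.5, Val(3))
```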

@KristofferC
Collaborator Author

Closing this, since it turned out to be unrelated to dual numbers or automatic differentiation. Sorry for the noise.
