# Useful Julia concepts for automatic differentiation

In [1]:
importall Base # allows us to extend the built-in definitions of operations
# Dual number re + ϵ·ε (with ε^2 = 0): a value paired with a derivative coefficient
immutable Dual{T} <: Number
    re::T # real (value) part
    ϵ::T  # coefficient of the infinitesimal ε
end
# Accessors for the two components; `real` extends the Base function of the same name
function real(z::Dual)
    return z.re
end
function epsilon(z::Dual)
    return z.ϵ
end;

In [2]:
d = Dual(2.0,1.0)

Dual{Float64}(2.0,1.0)

In [3]:
real(d)

2.0

In [4]:
epsilon(d)

1.0

In [5]:
# Arithmetic and elementary functions: each rule propagates the ϵ (derivative) part
(+)(a::Dual,b::Dual) = Dual(a.re+b.re, a.ϵ+b.ϵ)
(-)(a::Dual,b::Dual) = Dual(a.re-b.re, a.ϵ-b.ϵ)
(*)(a::Dual,b::Dual) = Dual(a.re*b.re, a.re*b.ϵ+b.re*a.ϵ) # product rule
(/)(a::Dual,b::Dual) = Dual(a.re/b.re, (a.ϵ*b.re-a.re*b.ϵ)/(b.re*b.re)) # quotient rule
abs(a::Dual) = Dual(abs(a.re), a.ϵ*sign(a.re)) # d|u| = sign(u)·du
sin(a::Dual) = Dual(sin(a.re), a.ϵ*cos(a.re)); # d sin(u) = cos(u)·du

In [6]:
Dual(2.0,1.0)*Dual(3.0,5.0)

Dual{Float64}(6.0,13.0)

In [7]:
# glue code: lets Dual values mix with ordinary real numbers through promotion
function promote_rule{A<:Real,B<:Real}(::Type{Dual{A}}, ::Type{B})
    return Dual{promote_type(B,A)}
end
function convert{R<:Real}(::Type{Dual{R}}, x::Real)
    return Dual(convert(R,x), zero(R)) # a plain real number has a zero ϵ part
end;

In [8]:
Dual(1.0,2.0)*2.0

Dual{Float64}(2.0,4.0)

In [9]:
3*Dual(1.0,2.0)

Dual{Float64}(3.0,6.0)

### We can now perform linear algebra with ``Dual`` types.

Let's solve the linear system:
$$
\left[\begin{array}{cc} 1 & 0 \\ 0 & 1\end{array}\right] \left[\begin{array}{c} x_1 \\ x_2 \end{array}\right] = 
\left[\begin{array}{c} 1+2\epsilon \\ 2+1\epsilon \end{array}\right]
$$

In [10]:
eye(2)\[Dual(1.0,2.0),Dual(2.0,1.0)]

2-element Array{Dual{Float64},1}:
 Dual{Float64}(1.0,2.0)
 Dual{Float64}(2.0,1.0)

### User-defined types are fast!

In [11]:
@code_native 2*3

	.text
Filename: int.jl
Source line: 19
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 19
	imulq	%rsi, %rdi
	movq	%rdi, %rax
	popq	%rbp
	ret


In [12]:
@code_native Dual(1.0,2.0)*Dual(2.0,3.0)

	.text
Filename: In[5]
Source line: 4
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 4
	vmovsd	(%rsi), %xmm0
Source line: 4
	vmulsd	8(%rdi), %xmm0, %xmm1
Source line: 4
	vmovsd	(%rdi), %xmm2
Source line: 4
	vmulsd	8(%rsi), %xmm2, %xmm3
	vaddsd	%xmm1, %xmm3, %xmm1
	vmulsd	%xmm0, %xmm2, %xmm0
	popq	%rbp
	ret


### Dual numbers interact transparently with user-defined functions.

In [13]:
f(x) = 2 * x + 3 * sin(x) # sample function to differentiate

f (generic function with 1 method)

In [14]:
# Differentiate `f` at the point `x` using forward-mode AD with dual numbers.
#   f: a function that accepts a `Dual` argument and returns a `Dual`
#   x: the real point at which the derivative is evaluated
# Returns the ϵ part of f(x + 1ϵ), which is the derivative of `f` at `x`.
function derivative(f,x)
    # Seed with one(x) instead of the literal 1.0 so both fields of `Dual` share a type;
    # with a hard-coded 1.0, a non-Float64 x (e.g. an Int) has no matching constructor.
    y = f(Dual(x,one(x)))
    return epsilon(y)
end

derivative (generic function with 1 method)

In [15]:
derivative(f, 2.0) - (2+3cos(2.0))

0.0

Julia compiles a specialized version of ``f`` according to the input type.

Compare 

``@code_native f(1.0)``

and

``@code_native f(Dual(1.0,1.0))``.

In [16]:
@code_native f(1.0)

	.text
Filename: In[13]
Source line: 1
	pushq	%rbp
	movq	%rsp, %rbp
	subq	$16, %rsp
	vmovsd	%xmm0, -8(%rbp)
	movabsq	$sin, %rax
Source line: 1
	callq	*%rax
	vmovsd	-8(%rbp), %xmm1
	vucomisd	%xmm0, %xmm0
	jp	L68
L40:	movabsq	$139889244385568, %rax  # imm = 0x7F3A80B83120
	vmulsd	(%rax), %xmm0, %xmm0
	vaddsd	%xmm1, %xmm1, %xmm1
	vaddsd	%xmm0, %xmm1, %xmm0
	addq	$16, %rsp
	popq	%rbp
	ret
L68:	vucomisd	%xmm1, %xmm1
	jp	L40
	movabsq	$jl_domain_exception, %rax
	movq	(%rax), %rdi
	movabsq	$jl_throw_with_superfluous_argument, %rax
	movl	$1, %esi
	callq	*%rax


In [17]:
function squareroot(x) # Approximate sqrt(x) by Newton's method applied to z^2 - x = 0
    z = x # Initial starting point: the input itself
    while abs(z*z - x) > 1e-13 # loop until the residual |z^2 - x| is at most 1e-13
        z = z - (z*z-x)/(2z) # Newton step: subtract the residual over its derivative 2z
    end # NOTE(review): absolute tolerance — may be unreachable for large x; confirm
    return z
end

squareroot (generic function with 1 method)

In [18]:
squareroot(10.0)

3.1622776601683795

In [19]:
# Compare a Dual against a plain real by its real part only (the ϵ part is ignored),
# so that branches and loop conditions on Duals behave as they do for ordinary numbers
(>)(d::Dual,bound::Real) = d.re > bound;

In [24]:
derivative(squareroot,10.0)

0.15811388300841897

In [25]:
1/(2*sqrt(10))

0.15811388300841897

In [35]:
@time sqrt(10.0)

  

3.1622776601683795

0.000003 seconds (5 allocations: 176 bytes)


In [36]:
@time squareroot(10.0)

  

3.1622776601683795

0.000003 seconds (5 allocations: 176 bytes)


In [37]:
@time derivative(squareroot,10.0)

  

0.15811388300841897

0.000005 seconds (7 allocations: 240 bytes)


#### The Julia compiler is able to generate efficient code because it uses *type inference* to determine the types of all intermediate values (when possible).

In [45]:
Base.return_types(squareroot,(Float64,))[1]

Float64

In [46]:
Base.return_types(squareroot,(Dual{Float64},))[1]

Dual{Float64}

## Building blocks for reverse mode

- Operator overloading
- Built-in expression manipulation, metaprogramming
- Source code access
- Dynamic code generation

In [47]:
expr = :(sin(x)+y)

:(sin(x) + y)

In [48]:
dump(expr)

Expr 
  head: Symbol call
  args: Array(Any,(3,))
    1: Symbol +
    2: Expr 
      head: Symbol call
      args: Array(Any,(2,))
        1: Symbol sin
        2: Symbol x
      typ: Any
    3: Symbol y
  typ: Any


### Macros are compile-time source transformations

Useful for:
- Accessing expression graphs without operator overloading
- Implementing domain-specific languages

In [49]:
macro IMeantToSubtract(expr)
    # In a call expression, args[1] names the function being called: swap it for `-`
    callargs = expr.args
    callargs[1] = :-
    # esc: hand the rewritten expression back to be evaluated in the caller's scope
    return esc(expr)
end

In [50]:
@IMeantToSubtract 10+1

9

In [51]:
macroexpand(:(@IMeantToSubtract 10+1))

:(10 - 1)

### Access to source code of user-provided functions!

In [56]:
fdef = Base.uncompressed_ast(methods(squareroot,(Float64,))[1].func.code)

:($(Expr(:lambda, Any[:x], Any[Any[Any[:x,:Any,0],Any[:z,:Any,2]],Any[],0,Any[]], :(begin  # In[17], line 2:
        z = x # line 3:
        unless (Main.abs)(z * z - x) > 1.0e-13 goto 1
        2:  # line 4:
        z = z - (z * z - x) / (2z)
        3: 
        unless (top(!))((Main.abs)(z * z - x) > 1.0e-13) goto 2
        1: 
        0:  # line 6:
        return z
    end))))

### The compiler is a run-time utility (dynamic code generation)

In [3]:
# Define, at run time, a one-argument function called `name` whose body is `expr`.
#   name: the new function's name (a Symbol such as :sqr)
#   expr: expression for the return value; it may refer to the argument as `x`
# The definition is evaluated on the spot and the resulting function is returned.
function create_function(name,expr)
    @eval function ($name)(x)
        return $expr
    end
end;

create_function (generic function with 1 method)

In [6]:
create_function(:sqr, :(x*x));

sqr (generic function with 1 method)

In [7]:
sqr(2.0)

4.0

In [8]:
@code_llvm sqr(2.0)


define double @julia_sqr_21786(double) {
top:
  %1 = fmul double %0, %0
  ret double %1
}
