# Types in Julia

In this notebook we will start exploring types by defining a new type that implements algorithmic differentiation.

### What is a type?

A *type* can be thought of as a label that is associated with data stored in memory; this label tells Julia how to interpret the data. For example:

In [1]:
x = 1
y = 1.3

1.3

In [2]:
sizeof(x), sizeof(y)

(8, 8)

In [3]:
typeof(x), typeof(y)

(Int64, Float64)

In [5]:
z = reinterpret(Float64, x)

5.0e-324

In [4]:
bitstring(x)

"0000000000000000000000000000000000000000000000000000000000000001"

In [6]:
bitstring(z)

"0000000000000000000000000000000000000000000000000000000000000001"

Both `x` and `y` occupy 8 bytes (64 bits). Moreover, `x` and `z` have exactly the same bit pattern, but `x` is interpreted as an integer and `z` as a floating-point number. Similarly, a pair of numbers may be interpreted as a complex number, an interval, a dual number, and so on; although the same information is stored (two numbers), we want each of these different *kinds*, or *types*, of objects to be treated differently.
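The same idea can be checked directly with functions from Julia's `Base`. This short sketch reads the bits of the float `1.0` as an integer, and then compares a tuple of two `Float64`s with a `ComplexF64`: both hold two numbers in 16 bytes, but only the complex number has a multiplication rule.

```julia
# The same 64 bits, read as two different types
i = reinterpret(Int64, 1.0)            # bits of the float 1.0, read as an integer
@show i                                # 4607182418800017408
@show bitstring(i) == bitstring(1.0)   # identical bit patterns

# Two Float64s stored as a tuple or as a complex number: same size...
t = (1.0, 2.0)
c = 1.0 + 2.0im
@show sizeof(t) sizeof(c)              # 16 bytes each

# ...but only the complex number knows how to multiply itself
@show c * c                            # -3.0 + 4.0im
```

Trying `t * t` instead raises a `MethodError`: a plain tuple has no multiplication behaviour.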

## Algorithmic differentiation

In the previous notebook we used `ForwardDiff.jl` to automatically differentiate a function. Here we will see how to implement a simple version of this.

The idea is to approximate a (differentiable) function $f$ near a point $a$ by its first-order Taylor expansion, i.e. the straight line passing through $(a, f(a))$ with slope equal to the derivative $f'(a)$:

$$f(x) \simeq f(a) + \epsilon f'(a),$$

where $\epsilon := x - a$.

We now use this to derive the standard rules for the derivative of a sum and product:

$$f(x) + g(x) \simeq [f(a) + g(a)] + \epsilon [f'(a) + g'(a)]$$

$$f(x) \cdot g(x) \simeq [f(a) \cdot g(a)] + \epsilon [f(a) g'(a) + g(a) f'(a)],$$

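where we suppose that $\epsilon$ is small enough that we may neglect $\epsilon^2$, i.e. set $\epsilon^2 = 0$, or alternatively just "take the linear part". For the product, this means expanding

$$[f(a) + \epsilon f'(a)] [g(a) + \epsilon g'(a)] = f(a) g(a) + \epsilon [f(a) g'(a) + g(a) f'(a)] + \epsilon^2 f'(a) g'(a)$$

and dropping the final $\epsilon^2$ term.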

### Defining a composite type

We see that by using just two pieces of information, namely the value $f(a)$ and the derivative $f'(a)$, we can represent a function $f$ near a given point $a$. 

The pair $(f(a), f'(a))$ is often called a **dual number**. We have seen that it has specific **behaviours** under arithmetic operations. Whenever we have a new behaviour, a *new type is lurking*!

We group the two values into a **composite type**. We can think of a composite type as specifying the structure of a box containing several pieces of information (data) inside. Defining a composite type with two **fields** (pieces of information) has the following syntax:

In [3]:
struct MyType
    a
    b::Int
end

In [2]:
MyType

MyType

Here we have additionally specified that the information stored in the field `b` must be of type `Int` using the **type annotation operator**, `::`.

Creating an object of that type is accomplished as follows:

In [4]:
x = MyType(3, 4)

MyType(3, 4)

In [5]:
x

MyType(3, 4)

In [6]:
typeof(x)

MyType

We can extract information as follows:

In [7]:
x.a

3

In [8]:
x.b

4

In [9]:
x.c

ErrorException: type MyType has no field c

In [11]:
propertynames(x)

(:a, :b)

In [12]:
fieldnames(MyType)

(:a, :b)
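The `::Int` annotation on `b` is enforced whenever an object is constructed. The sketch below repeats the definition of `MyType` exactly as above, so that it runs on its own. It shows that a value convertible to `Int` is converted, that a value which cannot be converted exactly raises an `InexactError`, and that the fields of a `struct` cannot be reassigned, since `struct` types are immutable.

```julia
struct MyType
    a
    b::Int
end

m = MyType(3, 4.0)        # 4.0 is converted to the Int 4
@show m.b typeof(m.b)

# 4.5 cannot be converted to an Int exactly
try
    MyType(3, 4.5)
catch err
    @show typeof(err)     # InexactError
end

# Fields of an (immutable) struct cannot be reassigned
try
    m.a = 10
catch err
    @show err
end
```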

In [13]:
(3, 4)

(3, 4)

#### Exercise 1

1. Define a composite type `Dual` with fields `value` and `deriv` of type `Float64`.


2. Create two `Dual` numbers `x` and `y`.


3. What happens if you try to add `x` and `y` together?


4. Make a function `add` that adds `x` and `y` and returns a new `Dual` number, following the rules we found above.

In [35]:
struct Dual
    value::Float64
    deriv::Float64
end

In [37]:
struct Dual2
    value
    deriv
end

In [15]:
x = Dual(3, 4)

Dual(3.0, 4.0)

In [16]:
methods(Dual)

In [17]:
@which Dual(3, 4)

In [18]:
y = Dual(3.1, 4.2)

Dual(3.1, 4.2)

In [19]:
x

Dual(3.0, 4.0)

In [20]:
y

Dual(3.1, 4.2)

In [22]:
x.value

3.0

In [23]:
y.value

3.1

In [24]:
x + y

MethodError: no method matching +(::Dual, ::Dual)
Closest candidates are:
  +(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:502

In [25]:
add

UndefVarError: add not defined

In [27]:
add(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)

add (generic function with 1 method)

In [28]:
add(x, y)

Dual(6.1, 8.2)

## Implementing arithmetic for a type

We would like to be able to use `+` and `*` for our new `Dual` type, rather than typing `add(x, y)`. To do so, we need to add new methods to the functions `+` and `*`, as follows:

In [29]:
+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)

ErrorException: error in method definition: function Base.+ must be explicitly imported to be extended

In [30]:
import Base: +, *

In [31]:
+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)

+ (generic function with 164 methods)

In [22]:
Base.+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)

ErrorException: syntax: invalid function name ".+"

In [23]:
Base.:+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)

In [None]:
+(x::Dual, y::Dual) = add(x, y)

In [32]:
function +(x::Dual, y::Dual)
    return add(x, y)
end

+ (generic function with 164 methods)

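In Julia, `+` and `*` are just ordinary functions. They are defined in `Base` (a module containing basic function definitions), so before we can **extend** them we must either `import` them or qualify the name in the definition, as in `Base.:+`. They consist of many different **methods** (versions), each handling particular types of arguments: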
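As a small illustration, the sketch below uses a hypothetical `Money` type (not part of this notebook). It shows that `1 + 2` is just a call to the function `+`, and that extending `Base.:+` with a qualified name, without any `import`, adds exactly one new method (when run in a fresh session).

```julia
# Operators are ordinary functions: these two calls are identical
@show 1 + 2 == +(1, 2)

# Hypothetical type, used only for this illustration
struct Money
    cents::Int
end

n_before = length(methods(+))

# Qualified definition: extends Base's `+` without `import Base: +`
Base.:+(x::Money, y::Money) = Money(x.cents + y.cents)

n_after = length(methods(+))
@show n_after - n_before         # one new method
@show Money(150) + Money(275)
```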

In [26]:
@which +(3)

In [27]:
-(3)

-3

In [28]:
methods(+)

In [32]:
x

Dual(3.0, 4.0)

In [33]:
y

Dual(3.1, 4.2)

In [34]:
x + y

Dual(6.1, 8.2)

We can add methods that work on our own types. (We should not modify their behaviour on combinations of types that do not contain our user-defined types; doing so is known as "type piracy" and can affect other people's code in unexpected ways.)

#### Exercise 2

1. Import the `+` and `*` functions from `Base` and implement them for the `Dual` type.
They should return a new `Dual` object.


2. Check that the number of methods has changed. 


3. Use `@which x + y` to check that Julia knows which method to use when adding two `Dual`s.


4. Can you define `x + a` for a `Dual` number `x` and a real number `a`? What happens with `a + x`?

In [None]:
include("dual.jl")   # if you put all the definitions in the file "dual.jl"

In [38]:
*(f::Dual, g::Dual) = Dual(f.value * g.value, f.value*g.deriv + g.value*f.deriv)   # product rule

* (generic function with 350 methods)

In [40]:
using Test

In [42]:
x = Dual(1, 2)
y = Dual(3, 4)

x * y

Dual(3.0, 10.0)

In [43]:
@test x * y == Dual(3, 10)

Test Passed

In [44]:
+(x::Dual, a::Real) = Dual(x.value + a, x.deriv)

+ (generic function with 165 methods)

In [45]:
x

Dual(1.0, 2.0)

In [46]:
x + 3

Dual(4.0, 2.0)

In [47]:
@which x + 3

In [48]:
3 + x

MethodError: no method matching +(::Int64, ::Dual)
Closest candidates are:
  +(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:502
  +(!Matched::Dual, ::Dual) at In[31]:1
  +(::T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}, !Matched::T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}) where T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8} at int.jl:53
  ...

In [49]:
+(a::Real, x::Dual) = x + a

+ (generic function with 166 methods)

In [50]:
3 + x

Dual(4.0, 2.0)
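The same idea extends to multiplication by a real constant $c$. A constant has derivative $0$, so the product rule gives $(c f)' = c f'$, and both the value and the derivative are scaled. Here is a minimal sketch that repeats the `Dual` definition from Exercise 1 so it runs on its own, using the qualified `Base.:+` syntax instead of `import`:

```julia
struct Dual
    value::Float64
    deriv::Float64
end

# Dual + Real: a constant shifts the value but not the derivative
Base.:+(x::Dual, a::Real) = Dual(x.value + a, x.deriv)
Base.:+(a::Real, x::Dual) = x + a

# Dual * Real: a constant scales both the value and the derivative
Base.:*(x::Dual, a::Real) = Dual(x.value * a, x.deriv * a)
Base.:*(a::Real, x::Dual) = x * a

x = Dual(1, 2)
@show 3 + x   # Dual(4.0, 2.0)
@show 2x      # Dual(2.0, 4.0)
```

Julia parses a numeric coefficient such as `2x` as `2 * x`, so it dispatches to the `*(::Real, ::Dual)` method.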

In [51]:
@code_warntype 3 + x

Body::Dual
1 ─ %1 = (Base.getfield)(x, :value)::Float64
│   %2 = (Base.sitofp)(Float64, a)::Float64
│   %3 = (Base.add_float)(%1, %2)::Float64
│   %4 = (Base.getfield)(x, :deriv)::Float64
│   %5 = %new(Main.Dual, %3, %4)::Dual
└──      return %5


In [52]:
@code_native 3 + x

	.section	__TEXT,__text,regular,pure_instructions
; ┌ @ In[49]:1 within `+' @ In[44]:1 @ promotion.jl:313
; │┌ @ promotion.jl:284 within `promote'
; ││┌ @ promotion.jl:261 within `_promote'
; │││┌ @ number.jl:7 within `convert'
; ││││┌ @ In[49]:1 within `Type'
	vcvtsi2sdl	%esi, %xmm0, %xmm0
; │└└└└
; │ @ In[49]:1 within `+' @ In[44]:1 @ float.jl:395
	vaddsd	(%edx), %xmm0, %xmm0
; │ @ In[49]:1 within `+' @ In[35]:2
	decl	%eax
	movl	8(%edx), %eax
; │ @ In[49]:1 within `+'
	vmovsd	%xmm0, (%edi)
	decl	%eax
	movl	%eax, 8(%edi)
	decl	%eax
	movl	%edi, %eax
	retl
	nopl	(%eax)
; └


Travis CI will run tests each time you update your repository on GitHub.

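As the `@code_native` output above shows, `3 + x` compiles down to a handful of machine instructions that operate directly on the two `Float64` fields of the `Dual`: using composite types is a **zero-cost abstraction**.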
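Another way to see this: `Dual` is an immutable type whose fields are plain `Float64`s, so Julia treats it as a "bits type" that takes up exactly the space of two `Float64`s. An array of `Dual`s then stores those values inline rather than as pointers to separate objects. A quick check, repeating the definition:

```julia
struct Dual
    value::Float64
    deriv::Float64
end

@show isbitstype(Dual)              # true: only plain bits, no references
@show sizeof(Dual(1.0, 2.0))        # 16 bytes: two Float64s, nothing more
v = fill(Dual(1.0, 2.0), 10)
@show sizeof(v)                     # 160 bytes: the Duals are stored inline
```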

Amazingly, we now have enough to differentiate simple Julia functions involving only `+` and `*`. Define:

In [43]:
a = 3.0
xx = Dual(a, 1.0)  # "the identity function x ↦ x, with derivative 1"

Dual(3.0, 1.0)

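We initialize the derivative to 1.0 because `xx` represents the identity function, whose derivative is 1. If we now pass `xx` to a function built from `+` and `*`, the result carries both the value and the derivative at `a`: we automatically differentiate!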

#### Exercise 3

1. Define `a = 3.0` and `xx = Dual(a, 1.0)`.

    (i) Compute `xx + xx`. The result should have the value $2a$ and the derivative $2$; write a test that checks this.
    
    (ii) Do the same for `xx * xx`. 
    
    
2. Define the function `f(x) = x * x + x`. Compute `f(xx)` and check that it gives the correct value and derivative!


3. Does this work for the function `f(x) = x^2 + x`?  What do you need to do?


4. What happens for `f(x) = x^2 + 2x`? What do you need to do?


5. What should you do for `f(x) = sin(x) + x`?

In [53]:
a = 3.0
xx = Dual(a, 1.0)

Dual(3.0, 1.0)

In [54]:
xx + xx    # x -> x + x    -- differentiating this at the point x = a

Dual(6.0, 2.0)

In [55]:
xx * xx    # derivative 2x

Dual(9.0, 6.0)

In [57]:
f(x) = x*x + x   # derivative 2x + 1

f (generic function with 1 method)

In [58]:
f(xx)

Dual(12.0, 7.0)
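For parts 3 to 5 of the exercise, `x^2` and `sin(x)` each need a new method. Here is one possible sketch that repeats the `Dual` type so that it is self-contained; `f1`, `f2` and `f3` are just illustrative names. It uses the same first-order rules: the power rule $(f^n)' = n f^{n-1} f'$ and the chain rule $(\sin f)' = \cos(f) f'$. By default, a literal power such as `x^2` falls back to calling `^(x, 2)`, so a method for `^(::Dual, ::Integer)` should suffice.

```julia
struct Dual
    value::Float64
    deriv::Float64
end

Base.:+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)
Base.:*(x::Dual, y::Dual) = Dual(x.value * y.value, x.value * y.deriv + y.value * x.deriv)

# Mixed arithmetic with real constants (needed for `2x`)
Base.:+(x::Dual, a::Real) = Dual(x.value + a, x.deriv)
Base.:+(a::Real, x::Dual) = x + a
Base.:*(a::Real, x::Dual) = Dual(a * x.value, a * x.deriv)
Base.:*(x::Dual, a::Real) = a * x

# Power rule: (f^n)' = n f^(n-1) f'  (intended for n ≥ 1)
Base.:^(x::Dual, n::Integer) = Dual(x.value^n, n * x.value^(n - 1) * x.deriv)

# Chain rule for sin: (sin f)' = cos(f) f'
Base.sin(x::Dual) = Dual(sin(x.value), cos(x.value) * x.deriv)

f1(x) = x^2 + x       # derivative 2x + 1
f2(x) = x^2 + 2x      # derivative 2x + 2
f3(x) = sin(x) + x    # derivative cos(x) + 1

xx = Dual(3.0, 1.0)
@show f1(xx) f2(xx) f3(xx)
```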

#### Exercise 4

1. Define a function `derivative` that differentiates a function `f` at a point `a` using `Dual` numbers, following the above pattern. (It should return just the derivative at the given point.)

This is the basis of ("forward-mode") automatic differentiation. The `ForwardDiff.jl` package contains a sophisticated implementation of this method.

In [59]:
derivative(f, a) = f(Dual(a, 1.0)).deriv

derivative (generic function with 1 method)

In [60]:
derivative(x->x*x + x, 3.0)

7.0
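As a sanity check, the dual-number derivative can be compared with a central finite-difference approximation $[f(a+h) - f(a-h)]/(2h)$. The sketch below repeats the minimal `Dual` definitions. `central_difference` is a hypothetical helper, and the step `h = 1e-6` is an arbitrary choice. Up to floating-point rounding, the dual-number result is exact. The finite difference, by contrast, carries a truncation error of order $h^2$ plus a rounding error.

```julia
struct Dual
    value::Float64
    deriv::Float64
end

Base.:+(x::Dual, y::Dual) = Dual(x.value + y.value, x.deriv + y.deriv)
Base.:*(x::Dual, y::Dual) = Dual(x.value * y.value, x.value * y.deriv + y.value * x.deriv)

derivative(f, a) = f(Dual(a, 1.0)).deriv

# Hypothetical helper: numerical derivative for comparison
central_difference(f, a; h = 1e-6) = (f(a + h) - f(a - h)) / (2h)

g(x) = x * x * x + x        # derivative 3x^2 + 1

for a in (0.5, 1.0, 3.0)
    println("a = $a: dual = $(derivative(g, a)), finite difference = $(central_difference(g, a))")
end
```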

#### Exercise 5

1. Define a `LogProbability` type, such that `LogProbability(p)` represents the value $p$ by storing $\log(p)$.



2. Define `*` on objects of this type using the [corresponding mathematical definitions](https://en.wikipedia.org/wiki/Log_probability). Check that the function `prod` gives reasonable results for a collection of these objects.


3. Define `+` and check that `sum` gives reasonable results. Use the `log1p` function.

In [64]:
struct LogProbability
    log_p::Float64
end

If $l_1 = \log(p_1)$ and $l_2 = \log(p_2)$, then the product of the corresponding `LogProbability` objects should store $\log(p_1 p_2) = \log(p_1) + \log(p_2) = l_1 + l_2$:

In [68]:
*(l1::LogProbability, l2::LogProbability) = LogProbability(l1.log_p + l2.log_p)

* (generic function with 351 methods)

In [71]:
rs = rand(5)

5-element Array{Float64,1}:
 0.3937361031791753 
 0.9157177179507572 
 0.3485758618560504 
 0.49560494465717264
 0.6132743479640603 

In [74]:
log.(rs)

5-element Array{Float64,1}:
 -0.9320743829783306 
 -0.08804712997760601
 -1.0538993913915926 
 -0.7019761521639816 
 -0.488942893469747  

In [75]:
lr = LogProbability.(log.(rs))

5-element Array{LogProbability,1}:
 LogProbability(-0.9320743829783306) 
 LogProbability(-0.08804712997760601)
 LogProbability(-1.0538993913915926) 
 LogProbability(-0.7019761521639816) 
 LogProbability(-0.488942893469747)  

In [85]:
rs

5-element Array{Float64,1}:
 0.3937361031791753 
 0.9157177179507572 
 0.3485758618560504 
 0.49560494465717264
 0.6132743479640603 

In [87]:
@which sum(rs)

prod(lr)

In [78]:
log(prod(rs))

-3.2649399499812577

In [80]:
@which prod(rs)

In [81]:
@which prod(lr)

In [66]:
p = 0.5
l = LogProbability(log(p))

LogProbability(-0.6931471805599453)

In [82]:
@code_warntype prod(lr)

Body::LogProbability
1 ── %1  = Base.identity::typeof(identity)
│    %2  = Base.mul_prod::typeof(Base.mul_prod)
│    %3  = (Base.arraysize)(a, 1)::Int64
│    %4  = (Base.slt_int)(%3, 0)::Bool
│    %5  = (Base.ifelse)(%4, 0, %3)::Int64
│    %6  = (Base.sub_int)(%5, 0)::Int64
│    %7  = (%6 === 0)::Bool
└───       goto #3 if not %7
2 ──       invoke Base.mapreduce_empty(%1::typeof(identity), %2::Function, LogProbability::Type)
└───       $(Expr(:unreachable))
3 ┄─ %11 = (%6 === 1)::Bool
└───       goto #5 if not %11
4 ── %13 = (Base.arrayref)(false, a, 1)::LogProbability
└───       goto #11
5 ── %15 = (Base.slt_int)(%6, 16)::Bool
└───       goto #10 if not %15
6 ── %17 = (Base.arrayref)(false, a, 1)::LogProb

In [84]:
@code_llvm prod(lr)


;  @ reducedim.jl:648 within `prod'
define { double } @julia_prod_13590(%jl_value_t addrspace(10)* nonnull align 16 dereferenceable(40)) {
top:
  %1 = alloca %jl_value_t addrspace(10)*, i32 4
; ┌ @ reducedim.jl:648 within `#prod#552'
; │┌ @ reducedim.jl:652 within `_prod' @ reducedim.jl:653
; ││┌ @ reducedim.jl:304 within `mapreduce'
; │││┌ @ reducedim.jl:304 within `#mapreduce#548'
; ││││┌ @ reducedim.jl:308 within `_mapreduce_dim'
; │││││┌ @ reduce.jl:302 within `_mapreduce'
; ││││││┌ @ indices.jl:426 within `Type'
; │││││││┌ @ abstractarray.jl:75 within `axes'
; ││││││││┌ @ array.jl:155 within `size'
           %2 = addrspacecast %jl_value_t addrspace(10)* %0 to %jl_value_t addrspace(11)*
           %3 = bitcast %jl_value_t addrspace(11)* %2 to %jl_value_t addrspace(10)* addrspace(11)*
           %4 = getelementptr inbounds %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(11)* %3, i64 3
           %5 = bitcast %jl_value_t addrspace(10)* addrspace(11)* %4 to i64 addr

In [83]:
@code_native prod(lr)

	.section	__TEXT,__text,regular,pure_instructions
; ┌ @ reducedim.jl:648 within `prod'
; │┌ @ reducedim.jl:648 within `#prod#552'
; ││┌ @ reducedim.jl:652 within `_prod' @ reducedim.jl:653
; │││┌ @ reducedim.jl:304 within `mapreduce'
; ││││┌ @ reducedim.jl:304 within `#mapreduce#548'
; │││││┌ @ reducedim.jl:308 within `_mapreduce_dim'
; ││││││┌ @ reduce.jl:302 within `_mapreduce'
; │││││││┌ @ indices.jl:426 within `Type'
; ││││││││┌ @ abstractarray.jl:75 within `axes'
; │││││││││┌ @ reducedim.jl:648 within `size'
	decl	%eax
	subl	$40, %esp
	decl	%eax
	movl	24(%edi), %edx
; ││││││││└└
; ││││││││┌ @ promotion.jl:414 within `axes'
	decl	%eax
	testl	%edx, %edx
; │││││││└└
; │││││││ @ reduce.jl:304 within `_mapreduce'
	jle	L115
; │││││││ @ reduce.jl:306 within `_mapreduce'
; │││││││┌ @ promotion.jl:403 within `=='
	decl	%eax
	cmpl	$1, %edx
; │││││││└
	jne	L31
; │││││││ @ reduce.jl:307 within `_mapreduce'
; │││││││┌ @ array.jl:729 within `getindex'
	decl	%eax
	movl	(%edi), %eax
	vmovsd	(%eax

In [62]:
log1p

log1p (generic function with 8 methods)

In [63]:
?log1p

search: log1p log10



```
log1p(x)
```

Accurate natural logarithm of `1+x`. Throws `DomainError` for `Real` arguments less than -1.

**Examples**

```julia
julia> log1p(-0.5)
-0.6931471805599453

julia> log1p(0)
0.0

julia> log1p(-2)
ERROR: DomainError with -2.0:
log1p will only return a complex result if called with a complex argument. Try log1p(Complex(x)).
Stacktrace:
 [1] throw_complex_domainerror(::Symbol, ::Float64) at ./math.jl:31
[...]
```
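With `log1p` available, `+` for `LogProbability` follows from $\log(p_1 + p_2) = l_1 + \log(1 + e^{l_2 - l_1})$. Choosing $l_1$ to be the larger of the two logarithms keeps $e^{l_2 - l_1} \le 1$, which avoids overflow, and `log1p` stays accurate when that term is tiny. The sketch below is one possible solution to Exercise 5, repeating the `LogProbability` definition above. `probability` is a hypothetical helper that converts back to an ordinary number.

```julia
struct LogProbability
    log_p::Float64    # stores log(p) instead of p
end

# log(p1 * p2) = log(p1) + log(p2)
Base.:*(l1::LogProbability, l2::LogProbability) = LogProbability(l1.log_p + l2.log_p)

# log(p1 + p2) = hi + log1p(exp(lo - hi)), with hi the larger logarithm
function Base.:+(l1::LogProbability, l2::LogProbability)
    lo, hi = minmax(l1.log_p, l2.log_p)
    hi == -Inf && return LogProbability(-Inf)   # both probabilities are zero
    return LogProbability(hi + log1p(exp(lo - hi)))
end

# Hypothetical helper: recover the ordinary probability p = exp(log_p)
probability(l::LogProbability) = exp(l.log_p)

ps = [0.2, 0.3, 0.1]
ls = LogProbability.(log.(ps))
@show probability(prod(ls))   # ≈ 0.2 * 0.3 * 0.1 = 0.006
@show probability(sum(ls))    # ≈ 0.2 + 0.3 + 0.1 = 0.6
```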


### Parametric types

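For simplicity, in the above we fixed the fields in the `Dual` type to be of type `Float64`. By doing so we are actually *losing power*: any other kind of number, such as a `Rational` or a `BigFloat`, is silently converted to `Float64`. Instead we should let Julia "fill in" the types.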

To do so, we specify that we want to use a **type parameter** `T`. We can think of this as a "special kind of variable" that can only take on certain kinds of values. We specify this with the following syntax: 

In [80]:
struct MyType2{T}
    a::T
    b::T
end

[Note that we have not reused the name `MyType` since Julia *does not allow types to be redefined in a different way*.]

Here we are specifying that both fields `a` and `b` must share the same type `T`, but we have not restricted what values `T` can take. When we create an object, Julia will *infer* (work out) the type:

In [81]:
x = MyType2(3, 4)

MyType2{Int64}(3, 4)

In [16]:
y = MyType2(3.1, 4.2)

MyType2{Float64}(3.1, 4.2)

In [82]:
z = MyType2(1, 5.3)

MethodError: no method matching MyType2(::Int64, ::Float64)
Closest candidates are:
  MyType2(::T, !Matched::T) where T at In[80]:2
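If we do want a single `MyType2` built from an `Int` and a `Float64`, we can state the type parameter explicitly. Julia's default constructor then converts both arguments to that type. Here is a short sketch that repeats the definition:

```julia
struct MyType2{T}
    a::T
    b::T
end

z = MyType2{Float64}(1, 5.3)   # 1 is converted to 1.0
@show z typeof(z)
```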

In [83]:
struct MyType3{S,T}
    a::S
    b::T
end

In [85]:
x = MyType3(1, 5.3)

MyType3{Int64,Float64}(1, 5.3)

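Note that `MyType2(3, 4)` and `MyType2(3.1, 4.2)` have *different* types, `MyType2{Int64}` and `MyType2{Float64}`: the type parameter is part of the type.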

We can define functions acting on parametric types without necessarily mentioning the type parameter.
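For example, the sketch below repeats `MyType2` and uses two hypothetical function names. `swap` accepts any `MyType2` whatever its parameter, while `parameter` uses `where T` to name the parameter when the function actually needs it.

```julia
struct MyType2{T}
    a::T
    b::T
end

# No mention of T: this method accepts MyType2{Int64}, MyType2{Float64}, ...
swap(m::MyType2) = MyType2(m.b, m.a)

# Naming the parameter with `where` when we need to use it
parameter(::MyType2{T}) where {T} = T

@show swap(MyType2(3, 4))
@show parameter(MyType2(3.1, 4.2))   # Float64
```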

#### Exercise 6

1. Define a function that takes an object of type `MyType2`, *without* mentioning the type parameter, and returns the sum of the two fields.

   What happens when you apply this function to `MyType2(3, 4)` and `MyType2(3.1, 4.2)`?
   
   
2. Define a type `Dual2` with a type parameter `T` and the same functions `+` and `*` as before.


3. Define the function `f(x) = x * x + x`. What happens if you pass in `Dual` numbers with different type parameters?
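Here is a possible sketch for parts 2 and 3. This notebook already defined a non-parametric `Dual2` above, and Julia does not allow a type to be redefined in a different way, so the sketch uses the hypothetical name `DualT`. The restriction `T<:Real` is an additional assumption.

```julia
# Parametric dual number: both fields share the type parameter T
struct DualT{T<:Real}
    value::T
    deriv::T
end

Base.:+(x::DualT, y::DualT) = DualT(x.value + y.value, x.deriv + y.deriv)
Base.:*(x::DualT, y::DualT) = DualT(x.value * y.value, x.value * y.deriv + y.value * x.deriv)

f(x) = x * x + x

@show f(DualT(3, 1))                # stays DualT{Int64}: value 12, derivative 7
@show f(DualT(3.0, 1.0))            # DualT{Float64}: value 12.0, derivative 7.0
@show DualT(3, 1) + DualT(0.5, 2.0) # Int and Float64 promote, giving a DualT{Float64}
```

Each result keeps the parameter of its input. When two different parameters meet, the arithmetic inside the methods promotes `Int` to `Float64`, so the constructor receives two `Float64`s. By contrast, `DualT(3, 1.0)` raises a `MethodError`, just like `MyType2(1, 5.3)`.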