# Lecture 2: Type Stability

## Defeating type inference: Type instabilities

To get good performance, there are some fairly simple rules that you need to follow in Julia code to avoid defeating the compiler's type inference.   See also the [performance tips section of the Julia manual](http://docs.julialang.org/en/stable/manual/performance-tips/).

Three of the most important are:

* Don't use (non-constant) global variables in critical code — put your critical code into a function (this is good advice anyway, from a software-engineering standpoint).  The compiler assumes that a **global variable can change type at any time**, so it is always stored in a "box", and "taints" anything that depends on it.

* Local variables should be "type-stable": **don't change the type of a variable inside a function**.  Use a new variable instead.

* Functions should be "type-stable": **a function's return type should only depend on the argument types, not on the argument values**.

To diagnose all of these problems, the `@code_warntype` macro that we used above is your friend.  If it labels any variables (or the function's return value) as `Any` or `Union{...}`, it means that the compiler couldn't figure out a precise type.
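As a minimal sketch of the second and third rules (assuming Julia 1.x; the function names `clamp_unstable`, `clamp_stable`, `halves_unstable`, and `inferred` are made up for illustration), we can ask the compiler directly what return type it inferred, using `Base.return_types`:

```julia
# Rule 3 violated: the return type depends on the *value* of x.
# For a Float64 argument this returns either x (a Float64) or the Int literal 0.
clamp_unstable(x) = x > 0 ? x : 0

# Type-stable variant: zero(x) has the same type as x.
clamp_stable(x) = x > 0 ? x : zero(x)

# Rule 2 violated: `s` starts out as an Int and becomes a Float64 in the loop.
function halves_unstable(n)
    s = 0
    for i in 1:n
        s += i / 2
    end
    return s
end

# Helper (hypothetical): the return type the compiler infers for given argument types.
inferred(f, argtypes) = first(Base.return_types(f, argtypes))

inferred(clamp_unstable, (Float64,))   # a Union of Float64 and Int
inferred(clamp_stable, (Float64,))     # Float64
inferred(halves_unstable, (Int,))      # a Union of Float64 and Int
```

This reports the same information that `@code_warntype` highlights, just in a form that is easy to check programmatically.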

The third point, type-stability of functions, leads to lots of important but subtle choices in library design.  For example, consider the (built-in) `sqrt(x)` function, which computes $\sqrt{x}$:

In [3]:
sqrt(4)

2.0

You might think that `sqrt(-1)` should return $i$ (or `im`, in Julia syntax).  (Matlab's `sqrt` function does this.)  Instead, we get:

In [4]:
sqrt(-1)

LoadError: DomainError:
sqrt will only return a complex result if called with a complex argument. Try sqrt(complex(x)).

In [5]:
sqrt(-1 + 0im)

0.0 + 1.0im

Why did Julia implement `sqrt` in this silly way, throwing an error for negative arguments unless you add a zero imaginary part?  Any reasonable person wants an imaginary result from `sqrt(-1)`, surely?

The problem is that defining `sqrt` to return an imaginary result from `sqrt(-1)` would **not be type stable**: `sqrt(x)` would return a real result for non-negative real `x`, and a complex result for negative real `x`, so the **return type would depend on the value of `x`** and **not just its type.**

That would defeat type inference, not just for the `sqrt` function, but for **anything the sqrt function touches**.  Unless the compiler can somehow figure out `x ≥ 0`, it will have to either store the result in a "box" or compile two branches of the result.  Let's see how that works by defining our own square-root function:

In [6]:
mysqrt(x::Complex) = sqrt(x)
mysqrt(x::Real) = x < 0 ? sqrt(complex(x)) : sqrt(x)

mysqrt (generic function with 2 methods)

This definition is an example of Julia's [multiple dispatch style](http://docs.julialang.org/en/stable/manual/methods/), which in some sense is a generalization of object-oriented programming but focuses on "verbs" (functions) rather than nouns.  We will discuss this more in a later lecture.

The `::Complex` and `::Real` are argument-type declarations.  Such declarations are **not related to performance**, but instead **act as a "filter"** to allow us to have one version of `mysqrt` for complex arguments and another for real arguments.
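A small sketch of this "filter" role, using hypothetical functions `kind` and `double` that are not part of the lecture code (Julia 1.x assumed):

```julia
# Two methods of one function; the argument type selects which one runs.
kind(x::Real) = "real"
kind(x::Complex) = "complex"

kind(2)         # uses the ::Real method
kind(2.5)       # also ::Real, since Float64 <: Real
kind(2 + 0im)   # uses the ::Complex method

# No declaration is needed for speed: an undeclared method is still
# specialized for each concrete argument type it is called with.
double(x) = 2x
first(Base.return_types(double, (Float64,)))   # Float64
first(Base.return_types(double, (Int,)))       # Int
```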

In [9]:
mysqrt(2)

1.4142135623730951

In [10]:
mysqrt(-2)

0.0 + 1.4142135623730951im

In [11]:
mysqrt(-2+0im)

0.0 + 1.4142135623730951im

Looks great, right?  But let's see what happens to type inference in a function that calls `mysqrt` instead of `sqrt`:

In [12]:
slowfun(x) = mysqrt(x) + 1
@code_warntype slowfun(2)

Variables:
  #self#::#slowfun
  x::Int64
  #temp#@_3[1m[91m::Union{Complex{Float64}, Float64}[39m[22m
  #temp#@_4::Core.MethodInstance
  #temp#@_5[1m[91m::Union{Complex{Float64}, Float64}[39m[22m

Body:
  begin 
      $(Expr(:inbounds, false))
      # meta: location In[6] mysqrt 2
      unless (Base.slt_int)(x::Int64, 0)::Bool goto 6
      #temp#@_3[1m[91m::Union{Complex{Float64}, Float64}[39m[22m = $(Expr(:invoke, MethodInstance for sqrt(::Complex{Float64}), :(Base.sqrt), :($(Expr(:new, Complex{Float64}, :((Base.sitofp)(Float64, x)::Float64), :((Base.sitofp)(Float64, 0)::Float64))))))
      goto 8
      6: 
      #temp#@_3[1m[91m::Union{Complex{Float64}, Float64}[39m[22m = (Base.Math.sqrt_llvm)((Base.sitofp)(Float64, x::Int64)::Float64)::Float64
      8: 
      # meta: pop location
      $(Expr(:inbounds, :pop))
      unless (#temp#@_3[1m[91m::Union{Complex{Float64}, Float64}[39m[22m isa Complex{Float64})::Bool goto 14
      #temp#@_4::Core.MethodInstance = MethodI

Because the compiler **doesn't know at compile-time that `x` is non-negative** (at compile-time it **uses only types, not values**), it doesn't know whether the result is real (`Float64`) or complex (`Complex{Float64}`) and has to store it in a "box".  This kills performance.
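For contrast, here is a sketch of one type-stable alternative (an illustration, not how `Base.sqrt` is implemented): always return a complex number for real input, so the return type no longer depends on the sign of `x`.  The names `csqrt` and `fastfun` are hypothetical, and Julia 1.x is assumed.

```julia
# Always convert real input to complex first: one return type for all values.
csqrt(x::Real) = sqrt(complex(float(x)))
csqrt(x::Complex) = sqrt(x)

fastfun(x) = csqrt(x) + 1

first(Base.return_types(fastfun, (Int,)))   # Complex{Float64}: a single concrete type
fastfun(-4)   # 1.0 + 2.0im
fastfun(4)    # 3.0 + 0.0im
```

The price is that callers with non-negative real input get a complex result they may not want, which is why the built-in `sqrt` makes the other type-stable choice: real in, real out, and an error for negative input.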

In [13]:
i = 3

3

In [17]:
typeof(i)

Int64

In [16]:
i^100000

-3665183052406099839

In [18]:
i = typemax(Int)

9223372036854775807

In [19]:
i + 1

-9223372036854775808

## Defining our own types

Let's define our own type to represent a **"point" in two dimensions**.  Each point will have an $(x,y)$ location.  So that we can use the points with our `sum` functions above, we'll also define `+` and `zero` functions to do the obvious **vector addition**.

One such definition in Julia is:

In [20]:
mutable struct Point1
    x
    y
end
Base.:+(p::Point1, q::Point1) = Point1(p.x + q.x, p.y + q.y)
Base.zero(::Type{Point1}) = Point1(0,0)

Point1(3,4)

Point1(3, 4)

In [21]:
Point1(3,4) + Point1(5,6)

Point1(8, 10)

Our type is very generic, and can hold any type of `x` and `y` values:

In [22]:
Point1(3.7, 4+5im)

Point1(3.7, 4 + 5im)

Perhaps too generic:

In [23]:
Point1("x", [3,4,5])

Point1("x", [3, 4, 5])

Since `x` and `y` can be *anything*, they must be **pointers to "boxes"**.  This is **bad news for performance**.
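We can check this directly with a short sketch (assuming Julia 1.x; `LoosePoint` is a hypothetical stand-in for `Point1`, so that the snippet is self-contained):

```julia
# Fields with no declared type default to type Any.
mutable struct LoosePoint
    x
    y
end

fieldtype(LoosePoint, :x)    # Any: the field can refer to a value of any type
isbitstype(LoosePoint)       # false: instances are not plain inline data

p = LoosePoint(3, 4)
p.x = "now a string"         # allowed, because the field type is Any
```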

A `mutable struct` is *mutable*, which means we can create a `Point1` object and then change `x` or `y`:

In [47]:
p = Point1(3,4)

Point1(3, 4)

In [48]:
p.x = 7
p

Point1(7, 4)

In [49]:
q = p

Point1(7, 4)

This means that every reference to a `Point1` object must be a *pointer* to an object stored elsewhere in memory, because *how else would we "know" when an object changes?*  Furthermore, an **array of `Point1` objects must be an array of pointers** (which is **bad news for performance** again):

In [27]:
P = [p,p,p]

3-element Array{Point1,1}:
 Point1(7, 4)
 Point1(7, 4)
 Point1(7, 4)

In [30]:
P

3-element Array{Point1,1}:
 Point1(7, 8)
 Point1(7, 8)
 Point1(7, 8)

In [33]:
Q = [deepcopy(p), deepcopy(p)]

2-element Array{Point1,1}:
 Point1(7, 8)
 Point1(7, 8)

In [34]:
p.x = -3
P

3-element Array{Point1,1}:
 Point1(-3, 8)
 Point1(-3, 8)
 Point1(-3, 8)

In [35]:
Q

2-element Array{Point1,1}:
 Point1(7, 8)
 Point1(7, 8)

Let's test this out by creating an array of `Point1` objects and summing it.  Ideally, this would be about twice as slow as summing an equal-length array of numbers, since there are twice as many numbers to sum.  But because of all of the boxes and pointer-chasing, it should be far slower.

To create the array, we'll call the `Point1(x,y)` constructor with our array `a`, using Julia's ["dot-call" syntax](http://docs.julialang.org/en/stable/manual/functions/#dot-syntax-for-vectorizing-functions) that applies a function "element-wise" to arrays:

In [36]:
a = rand(10^7)
a1 = Point1.(a, a)

10000000-element Array{Point1,1}:
 Point1(0.45576, 0.45576)    
 Point1(0.199227, 0.199227)  
 Point1(0.664066, 0.664066)  
 Point1(0.106669, 0.106669)  
 Point1(0.108969, 0.108969)  
 Point1(0.146101, 0.146101)  
 Point1(0.066834, 0.066834)  
 Point1(0.933081, 0.933081)  
 Point1(0.901608, 0.901608)  
 Point1(0.680544, 0.680544)  
 Point1(0.891394, 0.891394)  
 Point1(0.128344, 0.128344)  
 Point1(0.354404, 0.354404)  
 ⋮                           
 Point1(0.838992, 0.838992)  
 Point1(0.940758, 0.940758)  
 Point1(0.416696, 0.416696)  
 Point1(0.17713, 0.17713)    
 Point1(0.493455, 0.493455)  
 Point1(0.44981, 0.44981)    
 Point1(0.0298092, 0.0298092)
 Point1(0.894299, 0.894299)  
 Point1(0.270014, 0.270014)  
 Point1(0.32509, 0.32509)    
 Point1(0.191285, 0.191285)  
 Point1(0.643986, 0.643986)  
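
The dot-call used here works for any function, not just constructors.  A tiny sketch with a hypothetical function `hyp` (not part of the lecture code):

```julia
# A two-argument function, applied element-wise by adding a dot.
hyp(x, y) = sqrt(x^2 + y^2)

xs = [3.0, 6.0]
ys = [4.0, 8.0]

hyp.(xs, ys)                             # [5.0, 10.0]
[hyp(x, y) for (x, y) in zip(xs, ys)]    # the same result as a comprehension
hyp.(xs, 4.0)                            # a scalar is broadcast against the array
```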

In [37]:
using BenchmarkTools, Compat

In [38]:
@btime sum($a1)

  517.472 ms (29999997 allocations: 610.35 MiB)


Point1(5.000915878780597e6, 5.000915878780597e6)

The time is at least **50× slower** than we would like, but consistent with our other timing results on "boxed" values from last lecture.

### An imperfect solution: A concrete immutable type

We can avoid these two problems by:

* Declaring the types of `x` and `y` to be *concrete* types, so that they don't need to be pointers to boxes.
* Declaring our Point to be an [immutable](https://en.wikipedia.org/wiki/Immutable_object) type (`x` and `y` cannot change), so that Julia is not forced to make every reference to a Point into a pointer.  This means writing just `struct`, not `mutable struct`:

In [39]:
struct Point2
    x::Float64
    y::Float64
end
Base.:+(p::Point2, q::Point2) = Point2(p.x + q.x, p.y + q.y)
Base.zero(::Type{Point2}) = Point2(0.0,0.0)

Point2(3,4)

Point2(3.0, 4.0)

In [40]:
Point2(3,4) + Point2(5,6)

Point2(8.0, 10.0)

In [52]:
p = Point2(3,4)
P = [p,p,p]

3-element Array{Point2,1}:
 Point2(3.0, 4.0)
 Point2(3.0, 4.0)
 Point2(3.0, 4.0)

In [53]:
p.x = 6 # gives an error since p is immutable

LoadError: [91mtype Point2 is immutable[39m

In [54]:
sizeof(P) / length(P)

16.0
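
The 16 bytes per element are just two 8-byte `Float64` values stored inline.  A self-contained sketch of this layout (assuming Julia 1.x; `FlatPoint` is a hypothetical stand-in for `Point2`):

```julia
# Immutable, with concrete Float64 fields.
struct FlatPoint
    x::Float64
    y::Float64
end

isbitstype(FlatPoint)    # true: plain data with no pointers
sizeof(FlatPoint)        # 16 = 2 fields × 8 bytes each

pts = [FlatPoint(1, 2), FlatPoint(3, 4)]
sizeof(pts)              # 32: the two points are stored back to back

# Viewing the same memory as Float64 values shows the inline layout.
collect(reinterpret(Float64, pts))   # [1.0, 2.0, 3.0, 4.0]
```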

If this is working as we hope, then summation should be much faster:

In [55]:
a2 = Point2.(a,a)
@btime sum($a2)

  11.792 ms (0 allocations: 0 bytes)


Point2(5.000915878780597e6, 5.000915878780597e6)

Now the time is **only about 10ms**, only slightly more than twice the cost of summing an array of individual numbers of the same length!

Unfortunately, we paid a big price for this performance: our `Point2` type only works with *a single numeric type* (`Float64`), much like a C implementation.

### The best of both worlds: Parameterized immutable types

How do we get a `Point` type that works for *any* type of `x` and `y`, but at the same time allows us to have an array of points that is concrete and homogeneous (every point in the array is forced to be the same type)?  At first glance, this seems like a contradiction in terms.

The answer is not to define a *single* type, but rather to **define a whole family of types** that are *parameterized* by the type `T` of `x` and `y`.  In computer science, this is known as [parametric polymorphism](https://en.wikipedia.org/wiki/Parametric_polymorphism).  (An example of this can be found in [C++ templates](https://en.wikipedia.org/wiki/Template_%28C%2B%2B%29).)

In Julia, we will define such a family of types as follows:

In [56]:
struct Point3{T<:Real}
    x::T
    y::T
end
Base.:+(p::Point3, q::Point3) = Point3(p.x + q.x, p.y + q.y)
Base.zero(::Type{Point3{T}}) where {T} = Point3(zero(T),zero(T))

Point3(3,4)

Point3{Int64}(3, 4)

Here, `Point3` is actually a family of subtypes `Point3{T}` for different types `T`.   The notation `<:` in Julia means "is a subtype of", and hence `T<:Real` means that we are constraining `T` to be a `Real` type (a built-in *abstract type* in Julia that includes e.g. integers or floating point).
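
A self-contained sketch of how this family behaves (assuming Julia 1.x; `PointT` is a hypothetical stand-in for `Point3`):

```julia
struct PointT{T<:Real}
    x::T
    y::T
end

typeof(PointT(3, 4))         # PointT{Int64} on a 64-bit machine
typeof(PointT(3.5, 4.5))     # PointT{Float64}
PointT{Float64} <: PointT    # true: each PointT{T} belongs to the family
PointT{Float64}(3, 4)        # with T given explicitly, arguments are converted

# The constraint T<:Real rejects non-real field types such as String.
rejected = try
    PointT("a", "b")
    false
catch err
    err isa MethodError
end
```

Note that `T` is inferred from the arguments when it is not given explicitly, which is why `Point3(3,4)` above produced a `Point3{Int64}`.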

In [57]:
Point3(3,4) + Point3(5.6, 7.8)

Point3{Float64}(8.6, 11.8)

Now, let's make an array:

In [58]:
a3 = Point3.(a,a)

10000000-element Array{Point3{Float64},1}:
 Point3{Float64}(0.45576, 0.45576)    
 Point3{Float64}(0.199227, 0.199227)  
 Point3{Float64}(0.664066, 0.664066)  
 Point3{Float64}(0.106669, 0.106669)  
 Point3{Float64}(0.108969, 0.108969)  
 Point3{Float64}(0.146101, 0.146101)  
 Point3{Float64}(0.066834, 0.066834)  
 Point3{Float64}(0.933081, 0.933081)  
 Point3{Float64}(0.901608, 0.901608)  
 Point3{Float64}(0.680544, 0.680544)  
 Point3{Float64}(0.891394, 0.891394)  
 Point3{Float64}(0.128344, 0.128344)  
 Point3{Float64}(0.354404, 0.354404)  
 ⋮                                    
 Point3{Float64}(0.838992, 0.838992)  
 Point3{Float64}(0.940758, 0.940758)  
 Point3{Float64}(0.416696, 0.416696)  
 Point3{Float64}(0.17713, 0.17713)    
 Point3{Float64}(0.493455, 0.493455)  
 Point3{Float64}(0.44981, 0.44981)    
 Point3{Float64}(0.0298092, 0.0298092)
 Point3{Float64}(0.894299, 0.894299)  
 Point3{Float64}(0.270014, 0.270014)  
 Point3{Float64}(0.32509, 0.32509)    
 Point3{Float64}(0.19

In [59]:
typeof(a3)

Array{Point3{Float64},1}

Note that the type of this array is `Array{Point3{Float64},1}` (we could equivalently write this as `Vector{Point3{Float64}}`, since `Vector{T}` is a synonym for `Array{T,1}`).  You should learn a few things from this:

* An `Array{T,N}` in Julia is itself a parameterized type, parameterized by the element type `T` and the dimensionality `N`.

* Since the element type `T` is encoded in the `Array{T,N}` type, the element type does not need to be stored in each element.  That means that the `Array` is free to store an array of "inlined" elements, rather than an array of pointers to boxes.  (This is why `Array{Float64,1}` earlier could be stored in memory like a C `double*`.)

* It is still important that the element type be `immutable`, since an array of mutable elements would still need to be an array of pointers (so that it could "notice" if another reference to an element mutates it).
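
These points can be checked with a few one-liners (a sketch, assuming Julia 1.x):

```julia
v = [1.5, 2.5, 3.5]

typeof(v) === Array{Float64,1}           # true
Vector{Float64} === Array{Float64,1}     # true: Vector{T} is just an alias
Matrix{Int} === Array{Int,2}             # true: the analogous alias for N = 2
eltype(v), ndims(v)                      # (Float64, 1): read off the type parameters

# The element size is fixed by the type, so elements can be stored inline.
sizeof(v) == length(v) * sizeof(Float64)   # true
```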

In [61]:
@btime sum($a3)

  11.805 ms (0 allocations: 0 bytes)


Point3{Float64}(5.000915878780597e6, 5.000915878780597e6)

Hooray! It is again **only about 10ms**, the same time as our completely concrete and inflexible `Point2`.

In [62]:
p = Point3{Int8}(3,5)

Point3{Int8}(3, 5)

In [64]:
p = Point3{Integer}(3,Int8(5))

Point3{Integer}(3, 5)

In [65]:
typeof(p.x)

Int64

In [66]:
typeof(p.y)

Int8

In [67]:
3 + 7im

3 + 7im

In [68]:
typeof(3 + 7im)

Complex{Int64}

In [69]:
@which sqrt(3 + 4im)

In [70]:
supertype(Point3{Int8})

Any
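
The cells above construct a point whose parameter is the *abstract* type `Integer`.  As a sketch of what that implies (assuming Julia 1.x; `ParamPoint` is a hypothetical stand-in for `Point3`), such a point is allowed, but its fields are no longer stored inline, and it is unrelated by subtyping to `ParamPoint{Int8}`:

```julia
struct ParamPoint{T<:Real}
    x::T
    y::T
end

p = ParamPoint{Integer}(3, Int8(5))
typeof(p.x), typeof(p.y)                  # (Int64, Int8) on a 64-bit machine

isbitstype(ParamPoint{Int8})              # true: two inline Int8 fields
isbitstype(ParamPoint{Integer})           # false: fields of abstract type are references
ParamPoint{Int8} <: ParamPoint{Integer}   # false: type parameters are invariant
ParamPoint{Int8} <: ParamPoint{<:Integer} # true: matches any T that is a subtype of Integer
supertype(ParamPoint{Int8})               # Any
```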

## Another type of "type instability"

In [71]:
function mysum_slow(a)
    s = 0
    for x in a
        s += x
    end
    return s
end

mysum_slow (generic function with 1 method)

In [72]:
mysum_slow([3,4,5])

12

In [73]:
mysum_slow([3.3,4,5])

12.3

In [74]:
@code_warntype(mysum_slow([3.3,4,5]))

Variables:
  #self#::#mysum_slow
  a::Array{Float64,1}
  x::Float64
  #temp#@_4::Int64
  s[1m[91m::Union{Float64, Int64}[39m[22m
  #temp#@_6::Core.MethodInstance
  #temp#@_7::Float64

Body:
  begin 
      s[1m[91m::Union{Float64, Int64}[39m[22m = 0 # line 3:
      #temp#@_4::Int64 = 1
      4: 
      unless (Base.not_int)((#temp#@_4::Int64 === (Base.add_int)((Base.arraylen)(a::Array{Float64,1})::Int64, 1)::Int64)::Bool)::Bool goto 29
      SSAValue(2) = (Base.arrayref)(a::Array{Float64,1}, #temp#@_4::Int64)::Float64
      SSAValue(3) = (Base.add_int)(#temp#@_4::Int64, 1)::Int64
      x::Float64 = SSAValue(2)
      #temp#@_4::Int64 = SSAValue(3) # line 4:
      unless (s[1m[91m::Union{Float64, Int64}[39m[22m isa Int64)::Bool goto 14
      #temp#@_6::Core.MethodInstance = MethodInstance for +(::Int64, ::Float64)
      goto 23
      14: 
      unless (s[1m[91m::Union{Float64, Int64}[39m[22m isa Float64)::Bool goto 18
      #temp#@_6::Core.MethodInstance = MethodInstance fo

In [75]:
function mysum_fast(a)
    s = zero(eltype(a))
    for x in a
        s += x
    end
    return s
end

mysum_fast (generic function with 1 method)

In [76]:
@code_warntype mysum_fast([3.3,4,5])

Variables:
  #self#::#mysum_fast
  a::Array{Float64,1}
  x::Float64
  #temp#::Int64
  s::Float64

Body:
  begin 
      s::Float64 = (Base.sitofp)(Float64, 0)::Float64 # line 3:
      #temp#::Int64 = 1
      4: 
      unless (Base.not_int)((#temp#::Int64 === (Base.add_int)((Base.arraylen)(a::Array{Float64,1})::Int64, 1)::Int64)::Bool)::Bool goto 14
      SSAValue(2) = (Base.arrayref)(a::Array{Float64,1}, #temp#::Int64)::Float64
      SSAValue(3) = (Base.add_int)(#temp#::Int64, 1)::Int64
      x::Float64 = SSAValue(2)
      #temp#::Int64 = SSAValue(3) # line 4:
      s::Float64 = (Base.add_float)(s::Float64, x::Float64)::Float64
      12: 
      goto 4
      14:  # line 6:
      return s::Float64
  end::Float64
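
The difference between the two versions can also be seen without reading the full `@code_warntype` listing.  In this sketch (assuming Julia 1.x; `sum_int_init` and `sum_zero_init` are hypothetical copies of `mysum_slow` and `mysum_fast`, repeated so that the snippet is self-contained), the integer-initialized version has a return type that depends on whether the loop runs at all:

```julia
function sum_int_init(a)      # same pattern as mysum_slow
    s = 0                     # always an Int, whatever the element type is
    for x in a
        s += x
    end
    return s
end

function sum_zero_init(a)     # same pattern as mysum_fast
    s = zero(eltype(a))       # a zero of the element type
    for x in a
        s += x
    end
    return s
end

# Helper (hypothetical): the return type inferred for an argument like `a`.
inferred(f, a) = first(Base.return_types(f, (typeof(a),)))

a = [3.3, 4, 5]
inferred(sum_int_init, a)     # a Union of Float64 and Int
inferred(sum_zero_init, a)    # Float64

sum_int_init(Float64[])       # 0   (an Int: the loop never runs)
sum_zero_init(Float64[])      # 0.0 (a Float64, matching the element type)
```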


In [77]:
mysum_fast([rand(3,3) for i = 1:10])

LoadError: [91mMethodError: no method matching zero(::Type{Array{Float64,2}})[0m
Closest candidates are:
  zero([91m::Type{Base.LibGit2.GitHash}[39m) at libgit2/oid.jl:106
  zero([91m::Type{Base.Pkg.Resolve.VersionWeights.VWPreBuildItem}[39m) at pkg/resolve/versionweight.jl:82
  zero([91m::Type{Base.Pkg.Resolve.VersionWeights.VWPreBuild}[39m) at pkg/resolve/versionweight.jl:124
  ...[39m

In [78]:
sum([rand(3,3) for i = 1:10])

3×3 Array{Float64,2}:
 5.11819  4.99643  3.24856
 5.22759  5.39936  5.74627
 6.72741  4.78524  4.62801

In [79]:
sum(Matrix{Float64}[])

LoadError: [91mMethodError: no method matching zero(::Type{Array{Float64,2}})[0m
Closest candidates are:
  zero([91m::Type{Base.LibGit2.GitHash}[39m) at libgit2/oid.jl:106
  zero([91m::Type{Base.Pkg.Resolve.VersionWeights.VWPreBuildItem}[39m) at pkg/resolve/versionweight.jl:82
  zero([91m::Type{Base.Pkg.Resolve.VersionWeights.VWPreBuild}[39m) at pkg/resolve/versionweight.jl:124
  ...[39m
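
One possible workaround for element types that have no `zero(T)` method is to seed the accumulator with the first element instead.  This is only a sketch under that assumption (it is not necessarily how the built-in `sum` works, and `mysum_seeded` is a made-up name); like the built-in `sum` above, it cannot produce an answer for an empty array of matrices, so it throws an error in that case.

```julia
function mysum_seeded(a)
    isempty(a) && throw(ArgumentError("cannot sum an empty collection without a zero"))
    s = first(a)                       # same type as the elements: type-stable
    for x in Iterators.drop(a, 1)      # iterate over the remaining elements
        s = s + x                      # builds a new value; first(a) is not mutated
    end
    return s
end

mysum_seeded([ones(2, 2) for i = 1:3])   # 2×2 matrix with every entry equal to 3.0
mysum_seeded([3.3, 4, 5])                # the same sum as mysum_fast gives
```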