# Julia's type system

There are five important features of Julia's type system to understand:
1. Interfaces & abstract types
2. Subtyping
3. Parametric types
4. Multiple dispatch
5. Type stability


## Type operator & function reference

| Operator / Function | Name                           | Purpose                                                                                                   |
|---------------------|--------------------------------|-----------------------------------------------------------------------------------------------------------|
| `::`                | **Type annotation/assertion**  | Annotates the type of a variable, field, or function argument/return value; on an expression, asserts its type at runtime |
| `<:`                | **Subtype**                    | Expresses "is a subtype of" (e.g. `Int <: Number` is `true`)                                              |
| `>:`                | **Supertype**                  | Expresses "is a supertype of" (e.g. `Number >: Int` is `true`)                                            |
| `isa`               | **Type check**                 | Checks whether a value *is an instance* of a type (e.g. `isa(1, Number)`)                                 |
| `typeof`            | **Type query**                 | Returns the concrete type of a value                                                                      |
| `where`             | **Type parameter declaration** | Introduces type parameters for a method or type, optionally with bounds (e.g. `where T <: Number`)        |
| `methods`           | **Method lookup**              | Lists the methods of a function, optionally filtered by argument types                                    |


## Interfaces & abstract types

Reference: https://docs.julialang.org/en/v1/manual/interfaces/

Some languages have support for *interfaces*, which are collections of
properties and functions that must be implemented by any type adhering to an
interface.

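Julia does not have a formal interface system: its interfaces (e.g. iteration,
indexing, `AbstractArray`) are documented conventions that the compiler never
checks, even though people like to talk as if they were enforced.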
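
As an illustration, here is a minimal sketch of the iteration interface from the
reference above: a type becomes usable in `for` loops, `collect` and `sum` once it
has an `iterate` method (plus `length` and `eltype` so that `collect` can
preallocate). The `Countdown` type is hypothetical and exists only for this example.

```julia
# Hypothetical collection that counts down from `from` to 1.
struct Countdown
    from::Int
end

# Core of the iteration interface: return (item, next_state), or nothing when done.
Base.iterate(c::Countdown, state = c.from) = state < 1 ? nothing : (state, state - 1)

# Optional parts of the interface, used e.g. by `collect` to preallocate its output.
Base.length(c::Countdown) = max(c.from, 0)
Base.eltype(::Type{Countdown}) = Int

for k in Countdown(3)
    println(k)          # prints 3, 2 and 1 on separate lines
end
collect(Countdown(3))   # [3, 2, 1]
sum(Countdown(4))       # 10
```

Nothing checks this conformance up front: had we skipped `length`, the `for` loop
would still work, but `collect(Countdown(3))` would fail with a `MethodError` at
runtime.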

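*Abstract types* are a close cousin of interfaces: they cannot be instantiated,
and they name a family of types that are expected to support a common set of
functions. They are not the same thing, though, since subtyping an abstract type
does not oblige a type to implement anything.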
Some examples of abstract types:
- `Number`: The abstract supertype for all numeric types.
- `Real`: The abstract supertype for real-valued numeric types.
- `AbstractArray`: The abstract supertype for array-like data structures.
- `AbstractString`: The abstract supertype for string-like types.
- `AbstractDict`: The abstract supertype for dictionary-like collections.
- `IO`: The abstract supertype for I/O streams.
- `Function`: The abstract supertype for all function objects.
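
The hierarchy behind these types can be explored with the reflection functions
`supertype`, `subtypes` (from the `InteractiveUtils` standard library, which the
REPL and notebooks load automatically), `isabstracttype` and `isconcretetype`.
A short sketch; `supertype_chain` is a hypothetical helper, not a Base function.

```julia
using InteractiveUtils: subtypes

# Hypothetical helper: follow `supertype` from T up to the root type `Any`.
function supertype_chain(T::DataType)
    chain = Any[T]
    while T !== Any
        T = supertype(T)
        push!(chain, T)
    end
    return chain
end

supertype_chain(Int64)     # Any[Int64, Signed, Integer, Real, Number, Any]
Base.isabstracttype(Real)  # true: Real cannot be instantiated
isconcretetype(Int64)      # true
subtypes(Signed)           # includes BigInt, Int8, Int16, Int32, Int64, Int128
```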

Functions can take an abstractly typed argument; any subtype of that abstract
type that implements the associated interface can then be passed to the
function.
For example, an `Int` should adhere to the interface associated with the
`Number` abstract type, so basic arithmetic ought to work for `Int`s passed to
a function that accepts `Number`s and performs arithmetic operations.

In [None]:
function operate(x::Number)
  x * x + x
end

In [None]:
typeof(-7)

In [None]:
operate(-7)
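
Because `operate` only requires a `Number`, other numeric types work as well,
while a non-`Number` argument is rejected at dispatch time with a `MethodError`.
A sketch (the definition of `operate` is repeated so that this cell runs on its
own; `outcome` is just an illustrative variable name):

```julia
function operate(x::Number)
  x * x + x
end

operate(2.5)    # 8.75     (Float64)
operate(1//2)   # 3//4     (Rational{Int64})
operate(2im)    # -4 + 2im (Complex{Int64})

# A String is not a Number, so no method of `operate` matches.
outcome = try
  operate("seven")
catch err
  err
end
outcome isa MethodError  # true
```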

## Subtyping

In [None]:
# Define a Zero type that is a subtype of Number.
struct Zero <: Number end

# Define basic arithmetic operations for Zero.
Base.:+(::Zero, ::Zero) = Zero()
Base.:+(x::Number, ::Zero) = x
Base.:+(::Zero, x::Number) = x

Base.:-(::Zero, ::Zero) = Zero()
Base.:-(x::Number, ::Zero) = x
Base.:-(::Zero, x::Number) = -x

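# Without this method, Zero() * Zero() would be ambiguous between the two below.
Base.:*(::Zero, ::Zero) = Zero()
Base.:*(::Zero, ::Number) = Zero()
Base.:*(::Number, ::Zero) = Zero()

# Likewise, this resolves the ambiguity of Zero() / Zero().
Base.:/(::Zero, ::Zero) = error("Division by zero.")
Base.:/(::Zero, x::Number) = Zero()
Base.:/(::Number, ::Zero) = error("Division by zero.")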

# Show method for printing.
Base.show(io::IO, ::Zero) = print(io, "zero")

In [None]:
# Create an instance of Zero (named `z` to avoid shadowing the Base function `zero`).
z = Zero()
println("zero + zero = ", z + z)
println("zero - 5 = ", z - 5)
println("5 * zero = ", 5 * z)
println("5 / zero = ", 5 / z) # Throws an error by design.
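
The subtyping operators from the reference table can query these relationships
directly. A sketch using built-in types; note that parametric types such as
`Vector{T}` are *invariant*, so `Vector{Int}` is not a subtype of
`Vector{Number}` even though `Int <: Number`. `sum_numbers` is a hypothetical
function.

```julia
Int <: Integer            # true
Real >: Float64           # true: Real is a supertype of Float64
isa(1.0, Real)            # true
Float64 <: Integer        # false

# Parametric types are invariant in their parameters...
Vector{Int} <: Vector{Number}       # false
# ...so a bounded parameter (`<:Number`) is needed to accept any numeric element type.
Vector{Int} <: Vector{<:Number}     # true
Vector{Int} <: AbstractVector{Int}  # true

# Hypothetical function accepting a Vector of any Number subtype.
sum_numbers(xs::Vector{<:Number}) = sum(xs)
sum_numbers([1, 2, 3])     # 6
sum_numbers([1.5, 2.5])    # 4.0
```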

## Parametric types

Types can be created for "container" data (e.g. vectors, matrices) with
constraints on the types of contained data.
Constraints can be refined further in particular methods that accept or return
values of those container types.
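
Before the larger sparse-matrix example, a minimal sketch of both ideas: a
hypothetical `Point{T<:Real}` type constrains its coordinates, and a `where`
clause binds `T` so that individual methods can narrow it further. `Point`,
`coordtype` and `midpoint` are illustrative names only.

```julia
# Hypothetical 2D point; both coordinates share one Real type T.
struct Point{T<:Real}
  x::T
  y::T
end

# `where T` binds the type parameter so the method body can use it.
coordtype(::Point{T}) where T = T

# Method restricted to points with floating-point coordinates.
midpoint(p::Point{T}, q::Point{T}) where T <: AbstractFloat =
  Point((p.x + q.x) / 2, (p.y + q.y) / 2)

coordtype(Point(1, 2))                       # Int64 (on a 64-bit machine)
midpoint(Point(0.0, 0.0), Point(1.0, 3.0))   # Point{Float64}(0.5, 1.5)

# Both of these throw a MethodError:
#   Point("a", "b")                          # String is not <: Real
#   midpoint(Point(0, 0), Point(1, 3))       # Int is not <: AbstractFloat
```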

In [None]:
struct SparseMatrix{T<:Number} <: AbstractMatrix{T}
  size::Tuple{Int, Int}
  indices::Vector{Tuple{Int, Int}}
  values::Vector{T}
end

# Constructor with empty matrix.
function SparseMatrix{T}(m::Int, n::Int) where T <: Number
  SparseMatrix{T}((m, n), Tuple{Int, Int}[], T[])
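end

# Part of the AbstractArray interface; also used e.g. when printing a BoundsError.
Base.size(A::SparseMatrix) = A.size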

# Get value at specific index.
function Base.getindex(A::SparseMatrix{T}, i::Int, j::Int) where T <: Number
  if i < 1 || i > A.size[1] || j < 1 || j > A.size[2]
    throw(BoundsError(A, (i, j)))
  end
  
  for (idx, (row, col)) in enumerate(A.indices)
    if row == i && col == j
      return A.values[idx]
    end
  end
  
  return zero(T) # A zero of type T for non-stored elements keeps the return type stable.
end

# Set value at specific index.
function Base.setindex!(A::SparseMatrix{T}, v, i::Int, j::Int) where T <: Number
  if i < 1 || i > A.size[1] || j < 1 || j > A.size[2]
    throw(BoundsError(A, (i, j)))
  end
  
  # Check if index already exists.
  for (idx, (row, col)) in enumerate(A.indices)
    if row == i && col == j
      if v == 0
        # Remove the element if setting to zero.
        deleteat!(A.indices, idx)
        deleteat!(A.values, idx)
      else
        # Update existing value.
        A.values[idx] = v
      end
      return A
    end
  end
  
  # If we're here, the index doesn't exist yet.
  if v != 0
    push!(A.indices, (i, j))
    push!(A.values, v)
  end
  
  return A
end

# Matrix addition.
function Base.:+(A::SparseMatrix{T}, B::SparseMatrix{T}) where T <: Number
  if A.size != B.size
    throw(DimensionMismatch("Matrix dimensions must match."))
  end
  
  result = SparseMatrix{T}(A.size[1], A.size[2])
  
  # Add all elements from A.
  for (idx, (i, j)) in enumerate(A.indices)
    result[i, j] += A.values[idx]
  end
  
  # Add all elements from B.
  for (idx, (i, j)) in enumerate(B.indices)
    result[i, j] += B.values[idx]
  end
  
  return result
end

# Matrix-matrix multiplication.
function Base.:*(A::SparseMatrix{T}, B::SparseMatrix{S}) where {T <: Number, S <: Number}
  if A.size[2] != B.size[1]
    throw(DimensionMismatch("Inner matrix dimensions must match."))
  end
  
  R = promote_type(T, S)
  result = SparseMatrix{R}(A.size[1], B.size[2])
  
  # For each non-zero element in A.
  for (idx_a, (i, k)) in enumerate(A.indices)
    # For each non-zero element in B where the row matches k.
    for (idx_b, (k_b, j)) in enumerate(B.indices)
      if k == k_b
        # Multiply and add to result.
        result[i, j] += A.values[idx_a] * B.values[idx_b]
      end
    end
  end
  
  return result
end

# Matrix-scalar multiplication.
function Base.:*(A::SparseMatrix{T}, scalar::Number) where T <: Number
  R = promote_type(T, typeof(scalar))
  result = SparseMatrix{R}(A.size[1], A.size[2])
  
  for (idx, (i, j)) in enumerate(A.indices)
    result[i, j] = A.values[idx] * scalar
  end
  
  return result
end

# Scalar-matrix multiplication.
function Base.:*(scalar::Number, A::SparseMatrix{T}) where T <: Number
  return A * scalar  # Reuse the matrix-scalar multiplication.
end

# Show method for printing.
function Base.show(io::IO, A::SparseMatrix{T}) where T <: Number
  println(io, "$(A.size[1])×$(A.size[2]) SparseMatrix{$T} with $(length(A.indices)) stored entries:")
  for (idx, (i, j)) in enumerate(A.indices)
    println(io, "  [$i, $j] = $(A.values[idx])")
  end
end

Note: We could allow multiplication/addition of `SparseMatrix`es with
`Matrix`es, but this would nearly triple the amount of code we need to write.

Because methods belong to functions rather than to types (multiple dispatch
follows the *anaemic domain model*, see later), we could publish our
`SparseMatrix` implementation as a package, and someone else could implement
`SparseMatrix`-`Matrix` algebra externally without modifying our code.

In [None]:
# Create two sparse matrices and multiply them.
# First, create a 50x7 sparse matrix.
A = SparseMatrix{Int}(50, 7)
# Add some non-zero entries.
for i in 1:10
    A[rand(1:50), rand(1:7)] = rand(1:10)
end

# Create a 7x50 sparse matrix.
B = SparseMatrix{Int}(7, 50)
# Add some non-zero entries.
for i in 1:10
    B[rand(1:7), rand(1:50)] = rand(1:10)
end

# Multiply the matrices.
C = A * B
println("Result is a $(C.size[1])×$(C.size[2]) matrix with $(length(C.indices)) non-zero entries.")

In [None]:
println("A: ", A)
println("B: ", B)

In [None]:
println("C: ", C)

## Multiple dispatch

Multiple dispatch is the function routing paradigm in Julia, whereby the method
that gets invoked is selected based on the concrete types of *all* arguments
**at runtime** (when those types can be inferred, the compiler resolves the
call ahead of time).
This differs from:

1. **Single dispatch** (Python/Java): Only considers the type of the
   first/owning object.
2. **Function overloading** (C++/Java): Resolves at compile-time using static
   types.
3. **Duck typing** (MATLAB/Python): Relies on runtime checks for method
   existence, incurring a runtime performance penalty.

### Advantages over other languages

**Over MATLAB/Python**

Eliminates `isa`/`isinstance` checks and `switch`/`nargin` patterns.
You can thus write type-stable (more on this later) code that's both generic and
fully optimized by the compiler:
```julia
# Single method handling both scalars and arrays.
function scale(x::Union{Number, AbstractArray}, factor::Number)
    x * factor
end
```
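
A hedged sketch of what replacing such checks looks like: each branch of an
`isa` chain becomes its own method, and dispatch on the number of arguments
replaces `nargin` checks. `describe` and `area` are hypothetical names.

```julia
# Instead of one function that branches on isa(x, ...), write one method per case.
describe(x::Number) = "scalar $x"
describe(x::AbstractArray) = "array of $(length(x)) elements"
describe(x) = "something else: $(typeof(x))"   # fallback for all other types

# Dispatch also considers the number of arguments, replacing nargin checks.
area(r::Real) = π * r^2          # circle of radius r
area(w::Real, h::Real) = w * h   # w-by-h rectangle

describe(3)          # "scalar 3"
describe([1, 2])     # "array of 2 elements"
describe("hi")       # "something else: String"
area(2, 3)           # 6
```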

**Over C++/Java/Rust**

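Enables dynamic polymorphism without class hierarchies or trait objects --
multiple dispatch adheres to the *anaemic domain model*, in which types hold
data and behaviour lives in functions defined outside them.
You can add functionality (i.e. methods) to existing functions and types
without modifying them:
```julia
# Extend a Base function for a custom type.
struct MyCustomNumber <: Number; value::Float64; end
Base.sqrt(x::MyCustomNumber) = MyCustomNumber(sqrt(x.value))
```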

**Over Haskell/Rust**

Permits working with concrete types rather than type classes/traits to generate
specialized machine code automatically:
```julia
# Each method is compiled to machine code specialized for its argument types (an
# unannotated `add(x, y) = x + y` would be specialized per type combination too).
add(x::Int, y::Int) = x + y
add(x::Float64, y::Float64) = x + y
```

**High-level abstractions & fine control over the machine**

The JIT compiler generates specialized machine code for each type combination,
which enables:
- Zero-cost abstractions,
- Automatic SIMD vectorization (or guide with `@inbounds` and `@simd`), and
- Type-stable code paths.
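
A sketch of guiding the compiler by hand: `@inbounds` removes bounds checks
(safe here only because `eachindex` yields valid indices) and `@simd` permits
the compiler to reorder the additions so that they can be vectorized. Because
floating-point addition is not associative, the result may differ in the last
bits from a strict left-to-right sum. `simd_sum` is a hypothetical name.

```julia
function simd_sum(xs::Vector{Float64})
    s = 0.0
    # eachindex(xs) only yields valid indices, so skipping bounds checks is safe.
    @inbounds @simd for i in eachindex(xs)
        s += xs[i]
    end
    return s
end

simd_sum([1.0, 2.0, 3.0])   # 6.0
```

Running `@code_llvm simd_sum(rand(1000))` shows the generated LLVM IR; depending
on the CPU, the loop body may contain vector types such as `<4 x double>`.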

### An example

Consider the addition operator:
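- **Python/MATLAB**: which implementation of `a + b` runs is determined
  primarily by the left operand's type (Python consults the right operand's
  `__radd__` mainly as a fallback).
- **Julia**: `+(a, b)` dispatches on both types, correctly handling:
  - Mixed precision via promotion by default (`Int + Float64`),
  - Dimensional quantities with packages such as Unitful.jl (`3m + 5cm`), and
  - Arrays (including distributed arrays) combined with scalars, via the
    broadcasting form `a .+ b`.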
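
A minimal sketch of dispatch on both operands, using a hypothetical `Cents`
type: each combination of operand types gets its own method, so `Cents + Int`
and `Int + Cents` both work without either type "owning" the operator.

```julia
# Hypothetical currency type, for illustration only.
struct Cents
    amount::Int
end

# One method per combination of operand types.
Base.:+(a::Cents, b::Cents) = Cents(a.amount + b.amount)
Base.:+(a::Cents, b::Integer) = Cents(a.amount + b)
Base.:+(a::Integer, b::Cents) = b + a   # reuse the (Cents, Integer) method

Cents(5) + Cents(3)   # Cents(8)
Cents(5) + 3          # Cents(8)
3 + Cents(5)          # Cents(8)

# Built-in mixed precision is handled by promotion.
1 + 2.5               # 3.5

# Arrays and scalars combine through broadcasting; plain [1, 2] + 1 is a MethodError.
[1, 2] .+ 1           # [2, 3]
```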


### Where multiple dispatch can be tricky

We need to take care to prevent ambiguous argument type matching when using
multiple dispatch.

Julia selects the most specific applicable method (subtypes are preferred over
supertypes), but two methods of the same function can have incomparable
specificities: each may be more specific in some arguments and less specific in
others.
The call is then *ambiguous*: no single applicable method is at least as
specific as every other one in all argument positions, and Julia throws a
`MethodError`.

In [None]:
function addVectors(
  a::AbstractVector{T},
  b::Vector{T}
)::Vector{T} where T<:Number
  # Add the vectors component-wise.
  return a .+ b
end

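# Uncommenting this method makes the call below throw a MethodError: for two
# Vector{Int} arguments, neither method is more specific than the other.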
# function addVectors(
#   a::Vector{T},
#   b::AbstractVector{T}
# )::Vector{T} where T<:Number
#   addVectors(b, a)
# end

addVectors([1, 2, 3], [4, 5, 6])

In [None]:
methods(addVectors)
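
A minimal sketch of the problem and its usual fix, with a hypothetical `pick`
function: two methods whose signatures are incomparable for two `Vector`
arguments make such a call ambiguous, and adding a method for the intersection
of the two signatures resolves it.

```julia
pick(a::AbstractVector, b::Vector) = "first"
pick(a::Vector, b::AbstractVector) = "second"

# For (Vector, Vector) neither method is more specific than the other.
ambiguous = try
  pick([1], [2])
  false
catch err
  err isa MethodError   # ambiguities are reported as a MethodError
end

# A method for the intersection of both signatures removes the ambiguity.
pick(a::Vector, b::Vector) = "both"

pick([1], [2])     # "both"
pick(1:2, [3])     # "first"
pick([3], 1:2)     # "second"
```

For larger code bases, `Test.detect_ambiguities` can list ambiguous method pairs
in a module.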

We can use multiple dispatch to return values of different types, dependent
upon the types of the arguments.

In [None]:
function return_input(x::String)::String
  x
end

function return_input(x::Int)::Int
  x
end

string_return = return_input("foo")
println("$string_return, $(typeof(string_return))")
int_return = return_input(0)
println("$int_return, $(typeof(int_return))")

methods(return_input)

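However, it is not possible to dispatch on the return type of a function
*alone*: methods are distinguished only by their argument types, and `::T`
after a call is a type assertion on the result, not a way of selecting a method.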

In [None]:
function return_something()::String
  "foo"
end

function return_something()::Int
  0
end

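# Both definitions have the signature `return_something()`, so the second one
# overwrote the first. Uncommenting the next two lines throws a TypeError,
# because the result is an Int, not a String.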
# string_return = return_something()::String
# println("$string_return, $(typeof(string_return))")
int_return = return_something()::Int
println("$int_return, $(typeof(int_return))")

methods(return_something)
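
When the caller needs to choose the result type, the idiomatic workaround is to
pass the type itself as an argument and dispatch on `Type{T}`; Base does this in
calls such as `parse(Int, "42")` and `zero(Float64)`. A sketch with a
hypothetical `default_value` function:

```julia
# Dispatch on the *type object* passed as the argument.
default_value(::Type{String}) = "foo"
default_value(::Type{Int}) = 0

default_value(String)   # "foo"
default_value(Int)      # 0

# The same pattern in Base:
parse(Int, "42")        # 42
zero(Float64)           # 0.0
```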

## Type stability

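Because Julia is JIT-compiled, we can get a large performance boost over
interpreted languages when the compiler can infer concrete types for a
function's variables and return value from the types of its arguments.
Such code is called *type-stable*, and the compiler can turn it into optimized
machine code.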

https://docs.julialang.org/en/v1/manual/performance-tips/
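
Type stability is about the *types* of the arguments, not their values: a
function whose return type depends on which branch is taken at runtime is
unstable. A small sketch with hypothetical `relu_*` functions;
`Base.return_types` is an unexported reflection utility that reports the return
types inferred for given argument types.

```julia
# Unstable: returns an Int for positive x but a Float64 otherwise.
relu_unstable(x::Int) = x > 0 ? x : 0.0
# Stable: zero(x) has the same type as x, so both branches return an Int.
relu_stable(x::Int) = x > 0 ? x : zero(x)

only(Base.return_types(relu_unstable, (Int,)))  # Union{Float64, Int64}
only(Base.return_types(relu_stable, (Int,)))    # Int64
```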

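There are many packages for profiling and analysing Julia code:
- [AllocCheck.jl](https://github.com/JuliaLang/AllocCheck.jl): statically checks
  that a function does not allocate.
- [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl): reliable
  benchmarking (`@btime`, `@benchmark`).
- [Cthulhu.jl](https://github.com/JuliaDebug/Cthulhu.jl): interactively descends
  through type-inferred code (`@descend`).
- [DispatchDoctor.jl](https://github.com/MilesCranmer/DispatchDoctor.jl):
  enforces type stability (`@stable`).
- [JET.jl](https://github.com/aviatesk/JET.jl): static analysis that reports type
  errors and instabilities (`@report_opt`).
- [The *Profile* module](https://docs.julialang.org/en/v1/manual/profile/): the
  built-in sampling profiler.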

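For the differences between `@code_typed`, `@code_llvm` and `@code_native`, see:
https://stackoverflow.com/questions/43453944/what-is-the-difference-between-code-native-code-typed-and-code-llvm-in-julia

There are two common ways to check type stability:
1. `@code_warntype` prints the inferred type of every variable and of the return
   value, highlighting non-concrete types such as `Union{Float64, Int64}` or `Any`.
2. `@inferred` (from the `Test` standard library) throws an error if the inferred
   return type does not match the type of the value actually returned.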

In [None]:
# Type unstable version.
function unstable_sum(n)
  s = 0 # s starts as an Int.
  for i in 1:n
    if iseven(i)
      s += i       # Add an Int.
    else
      s += i * 1.0 # Adding a Float64 turns s into a Float64, so s is inferred as Union{Int, Float64}.
    end
  end
  return s
end

# Type stable version.
function stable_sum(n)
  s = 0.0 # s is now a Float64 from the start.
  for i in 1:n
    if iseven(i)
      s += i # i is automatically converted to Float64.
    else
      s += i * 1.0
    end
  end
  return s
end

In [None]:
# Check type stability.
@code_warntype unstable_sum(10)

In [None]:
# Check type stability.
@code_warntype stable_sum(10)

In [None]:
using Test

In [None]:
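# Expected to throw an error: the inferred return type is Union{Float64, Int64},
# but the actual result is a Float64.
@inferred unstable_sum(10)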

In [None]:
@inferred stable_sum(10)

In `stable_sum`, all of our variables have a concrete inferred type and the
concrete return type is correctly inferred, so we have achieved type stability
and can expect efficient performance.
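
The same pitfall appears whenever an accumulator is initialised with a literal
of the wrong type. A hedged sketch of the usual remedy, initialising with
`zero(T)` where `T` is the element type; `total_unstable` and `total_stable` are
hypothetical names.

```julia
using Test

# s starts as an Int but becomes a Float64 for Float64 input, so for a
# Vector{Float64} its inferred type is Union{Float64, Int64}.
function total_unstable(xs)
  s = 0
  for x in xs
    s += x
  end
  return s
end

# s starts as zero(T), so it keeps the element type throughout.
function total_stable(xs::AbstractVector{T}) where T <: Number
  s = zero(T)
  for x in xs
    s += x
  end
  return s
end

@inferred total_stable([1.0, 2.0, 3.0])      # 6.0
# @inferred total_unstable([1.0, 2.0, 3.0])  # throws: Float64 vs Union{Float64, Int64}
```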

Let's benchmark and compare performance.

In [None]:
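# Benchmark. The first @time call may include JIT compilation time,
# so run the cell twice for a fairer comparison.
ns = 1:1000 # An Int range (1:1e3 would produce Float64 values).
@time unstable_result = [unstable_sum(n) for n in ns]
println("$(length(unstable_result)) unstable sums computed.")
@time stable_result = [stable_sum(n) for n in ns]
println("$(length(stable_result)) stable sums computed.")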

In this simple case, we tend to get a ~3x speedup from ensuring type stability.