# Custom Types and Specialisation

The purpose of this notebook is to learn how to define custom data types. We will see how little effort is needed to integrate them into the Julia ecosystem (and benefit from predefined fallback functions), and how we can exploit known structure step by step to make computations faster.

## Defining data types

Apart from the built-in types, Julia also lets us declare custom types (i.e. data structures), for example:

In [None]:
struct MyType end

Note: By convention, type names are written in upper [camel case](https://en.wikipedia.org/wiki/Camel_case), e.g. `MyType`.

To obtain an object of type `MyType`, one conventionally calls a function with the same name as the type
([constructors](https://docs.julialang.org/en/v1/manual/constructors/)).

A trivial default constructor is created by Julia automatically:

In [None]:
methods(MyType)

In [None]:
m = MyType()

In [None]:
typeof(m)

In [None]:
m isa MyType

Even though such empty data structures have an important use case in Julia (as a marker *singleton type* for dispatch), more frequently we will need types to hold some data.
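As a minimal sketch of the singleton-as-marker idea, the hypothetical types `Celsius` and `Fahrenheit` below carry no data at all; they only select which method of the (equally hypothetical) function `format_temperature` gets called:

```julia
# Hypothetical marker types: no fields, so each has exactly one instance
struct Celsius end
struct Fahrenheit end

# The second argument is used only for dispatch, so it needs no name
format_temperature(t, ::Celsius) = "$(t) °C"
format_temperature(t, ::Fahrenheit) = "$(t) °F"

format_temperature(21, Celsius())       # "21 °C"
format_temperature(70, Fahrenheit())    # "70 °F"
Celsius() === Celsius()                 # true: all instances are identical
```

Because the marker carries no data, passing it around costs nothing; the choice of behaviour is resolved by dispatch on its type.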

In [None]:
struct A
    x::Int64  # A field
end

The default constructor expects values for all fields, in the order of appearance:

In [None]:
a = A(3)

In [None]:
a.x

In [None]:
a.x = 2  # throws an error: the fields of a struct cannot be reassigned

A key difference between Julia structs and equivalent constructs in many other languages is that `struct`s are immutable, i.e. their fields cannot be reassigned after construction. To make a struct mutable, add the `mutable` keyword:

In [None]:
mutable struct B
    x::Int64
end

b = B(3)
b.x = 4
b.x

In [None]:
struct C
    x::Vector{Int64}
end

In [None]:
c = C([1, 2, 3])

In [None]:
c.x

In [None]:
c.x[1] = -1  # allowed: modifies the contents of the vector, not the field

In [None]:
c.x

In [None]:
c.x = [4, 5, 6]  # throws an error: the field itself cannot be rebound

In [None]:
c.x .= [4, 5, 6]  # dot to perform the assignment element-wise
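The cells above illustrate that immutability is shallow: the field of `C` cannot be rebound to a new vector, but the vector it refers to can still be modified in place. A small sketch with a hypothetical `Wrapper` type makes this explicit using `ismutable` (available since Julia 1.5):

```julia
struct Wrapper
    data::Vector{Int}
end

w = Wrapper([1, 2, 3])

push!(w.data, 4)    # allowed: mutates the vector that the field points to
w.data .*= 2        # allowed: in-place broadcast, the field is not rebound

w.data              # [2, 4, 6, 8]
ismutable(w)        # false: the struct itself is immutable
ismutable(w.data)   # true: the Vector inside it is mutable

# Rebinding the field is what is forbidden
rebound = try
    w.data = [0]
    true
catch
    false
end                 # false
```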

Defining and using abstract types:

In [None]:
abstract type MyAbstractType end  # No fields! Just an informal interface

In [None]:
struct MyConcreteType <: MyAbstractType  # subtype
    somefield::String
end

In [None]:
c = MyConcreteType("test")

In [None]:
c isa MyAbstractType

In [None]:
supertype(MyConcreteType)

In [None]:
subtypes(MyAbstractType)
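Abstract types become useful once generic code is written against them. In the following sketch (all names are hypothetical) every subtype of `Shape` implements `area`, which plays the role of the informal interface; `total_area` relies only on that interface and therefore works for any current or future subtype:

```julia
abstract type Shape end

struct Square <: Shape
    side::Float64
end

struct Circle <: Shape
    radius::Float64
end

# Interface: each concrete subtype provides its own `area` method
area(s::Square) = s.side^2
area(c::Circle) = π * c.radius^2

# Generic code that only relies on the interface
total_area(shapes::AbstractVector{<:Shape}) = sum(area, shapes)

shapes = [Square(2.0), Circle(1.0)]   # element type is inferred as Shape
total_area(shapes)                    # 4 + π ≈ 7.14
```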

## Example: Diagonal Matrix

In [None]:
struct DiagMat
    diag::Vector{Float64}
end

In [None]:
DiagMat([1.2, 4.3, 5.0])

### Arithmetic

In [None]:
import Base: +, -, *, /

+(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag + Db.diag)
-(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag - Db.diag)
*(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .* Db.diag)
/(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag ./ Db.diag)

In [None]:
D1 = DiagMat([1,2,3])
D2 = DiagMat([2.4,1.9,5.7])

In [None]:
D1 + D2

In [None]:
D1 - D2

In [None]:
D1 * D2

In [None]:
D1 / D2

In [None]:
# Number
*(x::Number, D::DiagMat) = DiagMat(x * D.diag)
*(D::DiagMat, x::Number) = DiagMat(D.diag * x)
/(D::DiagMat, x::Number) = DiagMat(D.diag / x)

# Vector
*(D::DiagMat, V::AbstractVector) = D.diag .* V

In [None]:
D1 * 2

In [None]:
D1 * rand(3)
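Element-wise operations on the diagonals implement matrix arithmetic because the off-diagonal zeros never contribute. A quick numerical check against dense matrices built with `LinearAlgebra.diagm` (the vectors are arbitrary) confirms the rules used for `DiagMat` above:

```julia
using LinearAlgebra

da = [1.0, 2.0, 3.0]
db = [2.4, 1.9, 5.7]
x  = [0.5, -1.0, 2.0]
Ma, Mb = diagm(da), diagm(db)   # dense matrices with da, db on the diagonal

Ma + Mb ≈ diagm(da + db)        # true
Ma * Mb ≈ diagm(da .* db)       # true
Ma / Mb ≈ diagm(da ./ db)       # true: Ma * inv(Mb)
Ma * x ≈ da .* x                # true
```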

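Note that `Base` Julia's generic fallback implementations give us some functionality for free. For example, chained additions are reduced to the two-argument `+` defined above, and so is `sum`:

In [None]:
D1 + D2 + D1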

In [None]:
sum([D1, D2])

### Parameterization and `AbstractArray` interface

That's a good start, but we can do better, because

In [None]:
DiagMat([1, 2, 3]) # implicit conversion to Vector{Float64} => BAD

In [None]:
DiagMat([1+3im, 4-2im, im])  # No complex number support?

In [None]:
DiagMat(["Why", "not", "support", "strings?"])  # Would be cool, wouldn't it?

With actually *fewer* lines of code we can get a more generic version that is fully integrated into Julia's type hierarchy. We do this by defining a *parametric type* that subtypes `AbstractMatrix`:

In [None]:
struct DiagonalMatrix{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
    diag::V
end

To integrate properly we implement the [`AbstractArray` interface convention](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array-1):

In [None]:
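# implement AbstractArray interface
Base.size(D::DiagonalMatrix) = (length(D.diag), length(D.diag))

function Base.getindex(D::DiagonalMatrix{T,V}, i::Int, j::Int) where {T, V}
    @boundscheck checkbounds(D, i, j)  # reject out-of-range indices
    if i == j
        r = D.diag[i]
    else
        r = zero(T)  # off-diagonal entries are zero
    end
    r
end

# Must extend Base.setindex!; otherwise a new, unrelated function is created
function Base.setindex!(D::DiagonalMatrix, v, i::Int, j::Int)
    @boundscheck checkbounds(D, i, j)
    if i == j
        D.diag[i] = v
    else
        throw(ArgumentError("cannot set off-diagonal entry ($i, $j)"))
    end
    return v
end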

Now we can:

In [None]:
DiagonalMatrix([1.0, 2.0, 3.0])

In [None]:
D = DiagonalMatrix([1, 2, 3])

Note the fancy pretty-printing :)

In [None]:
D * D

In [None]:
D + D

In [None]:
D - D

In [None]:
D / D

Beyond basic arithmetic, broadcasting, reductions and even linear-algebra routines work without additional effort:

In [None]:
sin.(D)

In [None]:
sum([D, D, D])

In [None]:
using LinearAlgebra
eigvals(D)  # Compute the eigenvalues

These functions work, but they do not exploit the diagonal structure; they fall back to generic algorithms for dense matrices. We should therefore define a few specialised methods, so that dispatch can make use of what we know about our type.

In [None]:
@which D + D

In [None]:
+(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag + Db.diag)
*(x::Number, D::DiagonalMatrix) = DiagonalMatrix(x * D.diag)

In [None]:
@which D + D
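Following the same pattern, further structure-preserving methods could be added. The selection below (`*`, `-`, `tr`, `det`) is our own illustrative choice, not part of the original notebook; each method works on the diagonal vector only, instead of going through the generic dense fallbacks:

```julia
using LinearAlgebra

# Repeats the type and a compact version of the interface from above,
# so that this cell runs on its own
struct DiagonalMatrix{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
    diag::V
end
Base.size(D::DiagonalMatrix) = (length(D.diag), length(D.diag))
function Base.getindex(D::DiagonalMatrix{T}, i::Int, j::Int) where {T}
    @boundscheck checkbounds(D, i, j)
    return i == j ? D.diag[i] : zero(T)
end

# Specialisations that keep the result diagonal and touch only n entries
Base.:*(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag .* Db.diag)
Base.:-(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag - Db.diag)
Base.:*(D::DiagonalMatrix, x::Number) = DiagonalMatrix(D.diag * x)
LinearAlgebra.tr(D::DiagonalMatrix) = sum(D.diag)
LinearAlgebra.det(D::DiagonalMatrix) = prod(D.diag)

D = DiagonalMatrix([1, 2, 3])
D * D     # stays a DiagonalMatrix, with diagonal [1, 4, 9]
det(D)    # 6
```

Only the operations that actually show up as bottlenecks need such methods; everything else keeps working through the generic fallbacks.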

Of course, specialising every possible operation would again be a considerable effort, but we don't need to. In practice we would profile our code to identify bottlenecks and specialise only those, e.g.:

In [None]:
Dbig = DiagonalMatrix(randn(1000))
@time eigvals(Dbig);  # Ouch!

In [None]:
# Define a faster version to improve this part of the code
LinearAlgebra.eigvals(D::DiagonalMatrix) = copy(D.diag)  # copy, so modifying the result cannot modify D

In [None]:
@time eigvals(Dbig);  # Much better

In fact, this implementation is very similar to the [`Diagonal` implementation](https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/diagonal.jl#L5) in the `LinearAlgebra` standard library.
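For comparison, a short sketch with the standard library's `Diagonal` type shows that it already carries many such specialisations: arithmetic results stay structured instead of turning into dense matrices (the variable name `Dstd` is our own):

```julia
using LinearAlgebra

Dstd = Diagonal([1.0, 2.0, 3.0])

Dstd * Dstd isa Diagonal   # true: the product keeps the diagonal structure
Dstd + Dstd isa Diagonal   # true
inv(Dstd) isa Diagonal     # true
eigvals(Dstd)              # [1.0, 2.0, 3.0]
```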

## Exercise: One-hot vector

[One-hot encoding](https://en.wikipedia.org/wiki/One-hot) is useful for classification problems in machine learning.

It simply means that among a group of bits (each either `0` or `1`) exactly one is hot (`1`) while all others are cold (`0`). Usually this is stored as a vector

```julia
    v = [0, 0, 0, 0, 0, 1, 0, 0, 0]
```

1. What information does such a one-hot vector actually need to store?

2. Define a `OneHot` type which represents a vector with only a single hot (i.e. `== 1`) bit.

3. Extend all the necessary `Base` functions such that the following computation works
   for a matrix `M` and a vector of `OneHot` vectors `vs` (i.e. `vs isa Vector{OneHot}`):
   
    ```julia
    function innersum(M, vs)
        t = zero(eltype(M))
        for v in vs
            y = M * v
            for i in 1:length(vs[1])
                t += v[i] * y[i]
            end
        end
        t
    end

    M = rand(3, 3)
    vs = [rand(3) for i in 1:10] # This should be replaced by a `Vector{OneHot}`

    innersum(M, vs)

    ```

4. Benchmark the speed of `innersum` when called with a vector of `OneHot` vectors (i.e. `vs = [OneHot(3, rand(1:3)) for i in 1:10]`) and when called with a vector of `Vector{Float64}` vectors, respectively. Do you observe a speed-up?


5. Define a new `OneHotVector` type which is identical to `OneHot` but is declared as a subtype of `AbstractVector{Bool}`. Extend only the interface-defining functions `Base.getindex(v::OneHotVector, i::Int)` and `Base.size(v::OneHotVector)`. Here, the function `size` should return a `Tuple{Int64}` indicating the length of the vector, i.e. `(3,)` for a one-hot vector of length 3.
 

6. Create a single `OneHotVector` and run the `innersum` function using the new `OneHotVector` type.
 * What changes do you observe?
 * Do you have to implement any further methods?