# Custom Types

## Defining data types

We can define types (i.e. data structures) ourselves using the `struct` keyword.

It is a convention that type names are capitalized and [camel cased](https://en.wikipedia.org/wiki/Camel_case).

(**Note that types cannot be redefined**: you have to restart your Julia session to change a type definition.)

In [1]:
struct MyType end

To create an object of type `MyType` we have to call a [constructor](https://docs.julialang.org/en/v1/manual/constructors/). Loosely speaking, a constructor is a function that creates new objects.

Julia automatically creates a trivial constructor for us, which has the same name as the type.

In [2]:
methods(MyType)

In [3]:
m = MyType()

MyType()

In [4]:
typeof(m)

MyType

In [5]:
m isa MyType

true

Since our `MyType` contains no data (it is a so-called *singleton type*), we can basically only use it for dispatch.
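Even without data, a singleton type is useful as a "tag" that selects a method. Here is a small sketch; the types `Metric` and `Imperial` and the function `format_length` are hypothetical names made up for this illustration:

```julia
# Hypothetical singleton "tag" types: they carry no data, only a type
struct Metric end
struct Imperial end

# The tag argument only selects the method; its value is never inspected
format_length(::Metric, meters::Real) = string(meters, " m")
format_length(::Imperial, meters::Real) = string(round(meters / 0.3048; digits=2), " ft")

format_length(Metric(), 1.0)     # "1.0 m"
format_length(Imperial(), 1.0)   # "3.28 ft"
```

Since a singleton type has exactly one instance, `Metric() === Metric()` is `true`.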

Most of the time, we'll want a self-defined type to hold some data. For this, we need *fields*.

In [6]:
struct Point
    x::Float64
    y::Float64
end

In [7]:
Point()

LoadError: MethodError: no method matching Point()
Closest candidates are:
  Point(::Float64, ::Float64) at In[6]:2
  Point(::Any, ::Any) at In[6]:2

The default constructor always expects values for all fields.
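If a call like `Point()` should work, we can add further constructor methods ourselves (so-called outer constructors) that simply call the default one. A minimal sketch, assuming the origin is a sensible default point; the one-argument method is a hypothetical convenience constructor:

```julia
struct Point
    x::Float64
    y::Float64
end

# Extra (outer) constructor: no arguments gives the origin (an assumed default)
Point() = Point(0.0, 0.0)

# Hypothetical convenience constructor: one number gives a point on the diagonal
Point(a::Real) = Point(a, a)

Point()      # Point(0.0, 0.0)
Point(2)     # Point(2.0, 2.0): the default constructor converts 2 to Float64
```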

In [8]:
Point(1.2, 3.4)

Point(1.2, 3.4)

In [9]:
p = Point(1.2, 3.4)

Point(1.2, 3.4)

In [10]:
# p.<TAB> (tab completion) lists the fields of p
p.x

1.2

Note that types defined with `struct` are **immutable**, that is, the values of their fields cannot be changed.

In [11]:
p.x = 2

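LoadError: setfield!: immutable struct of type Point cannot be changed

If we need to change field values after construction, we can declare the type with `mutable struct` instead: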

In [12]:
mutable struct Point2
    x::Float64
    y::Float64
end

In [13]:
p = Point2(1.2, 3.4)

Point2(1.2, 3.4)

In [14]:
p.x = 4

4

In [15]:
p

Point2(4.0, 3.4)
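A practical consequence of the difference between the two: immutable objects are compared by their field values, whereas mutable objects have an identity of their own. A short sketch with the two types from above (repeated so the snippet runs on its own):

```julia
struct Point
    x::Float64
    y::Float64
end

mutable struct Point2
    x::Float64
    y::Float64
end

# Immutable: two objects with equal field values are indistinguishable
Point(1.0, 2.0) === Point(1.0, 2.0)    # true

# Mutable: each call creates a distinct object, even with equal field values
Point2(1.0, 2.0) === Point2(1.0, 2.0)  # false

# Mutating through one name is visible through every name bound to the same object
a = Point2(1.0, 2.0)
b = a
b.x = 10.0
a.x    # 10.0
```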

Abstract types are just as easy to define using the keyword `abstract type`.

In [16]:
abstract type AbstractPoint end

Since abstract types don't have fields, they only (informally) define interfaces and can be used for dispatch.

In [17]:
struct Point3 <: AbstractPoint
    x::Float64
    y::Float64
end

In [18]:
c = Point3(1.2, 3.4)

Point3(1.2, 3.4)

In [19]:
c isa AbstractPoint

true

In [20]:
supertype(Point3)

AbstractPoint

In [21]:
subtypes(AbstractPoint)

1-element Vector{Any}:
 Point3
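To make the "informal interface" idea concrete, here is a sketch with a separate, hypothetical abstract type `AbstractPoint2D` and two subtypes that store a point differently. The function names `xcoord`, `ycoord` and `distance_to_origin` are made up for this example; the interface exists only by convention:

```julia
abstract type AbstractPoint2D end

# Two concrete representations of a 2D point (hypothetical example types)
struct CartesianPoint <: AbstractPoint2D
    x::Float64
    y::Float64
end

struct PolarPoint <: AbstractPoint2D
    r::Float64   # distance from the origin
    θ::Float64   # angle in radians
end

# The informal interface: every AbstractPoint2D subtype provides xcoord and ycoord
xcoord(p::CartesianPoint) = p.x
ycoord(p::CartesianPoint) = p.y
xcoord(p::PolarPoint) = p.r * cos(p.θ)
ycoord(p::PolarPoint) = p.r * sin(p.θ)

# Generic code written once against the abstract type works for every subtype
distance_to_origin(p::AbstractPoint2D) = hypot(xcoord(p), ycoord(p))

distance_to_origin(CartesianPoint(3.0, 4.0))  # 5.0
distance_to_origin(PolarPoint(2.0, π / 3))    # ≈ 2.0
```

Any new subtype that implements `xcoord` and `ycoord` gets `distance_to_origin` for free.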

## Example: Diagonal Matrix

In [22]:
struct DiagMat
    diag::Vector{Float64}
end

In [23]:
DiagMat([1.2,4.3,5.0])

DiagMat([1.2, 4.3, 5.0])

### Arithmetic

In [24]:
import Base: +, -, *, /

+(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .+ Db.diag)
-(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .- Db.diag)
*(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .* Db.diag)
/(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag ./ Db.diag)

/ (generic function with 112 methods)

In [25]:
D1 = DiagMat([1,2,3])
D2 = DiagMat([2.4,1.9,5.7])

DiagMat([2.4, 1.9, 5.7])

In [26]:
D1 + D2

DiagMat([3.4, 3.9, 8.7])

In [27]:
D1 - D2

DiagMat([-1.4, 0.10000000000000009, -2.7])

In [28]:
D1 * D2

DiagMat([2.4, 3.8, 17.1])

In [29]:
D1 / D2

DiagMat([0.4166666666666667, 1.0526315789473684, 0.5263157894736842])

Arithmetic involving other types:

In [30]:
# Number
*(x::Number, D::DiagMat) = DiagMat(x * D.diag)
*(D::DiagMat, x::Number) = DiagMat(D.diag * x)
/(D::DiagMat, x::Number) = DiagMat(D.diag / x)

# Vector
*(D::DiagMat, V::AbstractVector) = D.diag .* V

* (generic function with 368 methods)

In [31]:
D1 * 2

DiagMat([2.0, 4.0, 6.0])

In [32]:
D1 * rand(3)

3-element Vector{Float64}:
 0.135921160356552
 0.4426045917200139
 2.453034600712555

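Note that some functions already work for our `DiagMat`, because they are implemented generically in terms of operations we have defined (`sum` only needs `+`):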

In [33]:
sum([D1, D2])

DiagMat([3.4, 3.9, 8.7])
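A sketch of why this works: `sum` and `prod` are generic reductions that combine the elements with `+` and `*`, respectively, so defining those two operations is enough. The definitions are repeated in condensed form so the snippet runs on its own:

```julia
import Base: +, *

struct DiagMat
    diag::Vector{Float64}
end

+(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .+ Db.diag)
*(Da::DiagMat, Db::DiagMat) = DiagMat(Da.diag .* Db.diag)

D1 = DiagMat([1, 2, 3])
D2 = DiagMat([2.4, 1.9, 5.7])

sum([D1, D2])             # combines the elements with +
prod([D1, D2])            # combines the elements with *
reduce(+, [D1, D2, D1])   # the same idea, spelled out explicitly
```

The sum of an *empty* `Vector{DiagMat}` would still fail, since that needs a `zero` method for `DiagMat`, which we have not defined.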

### Parameterization

In [34]:
DiagMat([1,2,3]) # implicit conversion to Vector{Float64}

DiagMat([1.0, 2.0, 3.0])

In [35]:
DiagMat([1+3im, 4-2im, im])

LoadError: InexactError: Float64(1 + 3im)

In [36]:
DiagMat(["Why", "not", "support", "strings?"])

LoadError: MethodError: Cannot `convert` an object of type String to an object of type Float64
Closest candidates are:
  convert(::Type{T}, ::T) where T<:Number at /opt/julia-1.7.3/share/julia/base/number.jl:6
  convert(::Type{T}, ::Number) where T<:Number at /opt/julia-1.7.3/share/julia/base/number.jl:7
  convert(::Type{T}, ::Base.TwicePrecision) where T<:Number at /opt/julia-1.7.3/share/julia/base/twiceprecision.jl:262
  ...

We can easily relax our type definition with *type parameters*: `T` is the element type, and `V` is the type of the vector that stores the diagonal (any `AbstractVector{T}`).

In [37]:
struct DiagMatParam{T, V<:AbstractVector{T}}
    diag::V
end
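To see what the two type parameters buy us, here is a short sketch: both parameters are inferred from the argument, and `V` lets us store any `AbstractVector`, for example a range, without converting it. The helper `diag_eltype` is hypothetical and only shows how a method can read the parameter `T`:

```julia
struct DiagMatParam{T, V<:AbstractVector{T}}
    diag::V
end

# Hypothetical helper: read the element type parameter T via dispatch
diag_eltype(::DiagMatParam{T}) where {T} = T

a = DiagMatParam([1.0, 2.0, 3.0])    # DiagMatParam{Float64, Vector{Float64}}
b = DiagMatParam(1:3)                # V is a UnitRange: the range is stored as-is
c = DiagMatParam([1 + 3im, 4 - 2im]) # complex entries, no conversion error

diag_eltype(a), diag_eltype(b), diag_eltype(c)
```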

Let's (again) define some arithmetic.

In [38]:
# Essentially copied from above
import Base: +, -, *, /
+(Da::DiagMatParam, Db::DiagMatParam) = DiagMatParam(Da.diag .+ Db.diag)
-(Da::DiagMatParam, Db::DiagMatParam) = DiagMatParam(Da.diag .- Db.diag)
*(Da::DiagMatParam, Db::DiagMatParam) = DiagMatParam(Da.diag .* Db.diag)
/(Da::DiagMatParam, Db::DiagMatParam) = DiagMatParam(Da.diag ./ Db.diag)
# Number
*(x::Number, D::DiagMatParam) = DiagMatParam(x * D.diag)
*(D::DiagMatParam, x::Number) = DiagMatParam(D.diag * x)
/(D::DiagMatParam, x::Number) = DiagMatParam(D.diag / x)
# Vector
*(D::DiagMatParam, V::AbstractVector) = D.diag .* V

* (generic function with 372 methods)
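Copying the same definition four times is error-prone. A common Julia idiom is to generate such methods in a loop with `@eval`; the following is a sketch of that approach for the four element-wise operations (an alternative to the cell above, not what it does):

```julia
struct DiagMatParam{T, V<:AbstractVector{T}}
    diag::V
end

# Generate the four element-wise methods in a loop instead of writing them out by hand.
# Inside @eval, `$op` splices the operator symbol (:+, :-, :*, :/) into the code.
for op in (:+, :-, :*, :/)
    @eval Base.$op(Da::DiagMatParam, Db::DiagMatParam) = DiagMatParam(broadcast($op, Da.diag, Db.diag))
end

DiagMatParam([1.0, 2.0]) / DiagMatParam([4.0, 8.0])   # diagonal [0.25, 0.25]
```

The `Number` and `Vector` methods could be generated in the same way.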

In [39]:
DiagMatParam([1+3im, 4-2im, im])

DiagMatParam{Complex{Int64}, Vector{Complex{Int64}}}(Complex{Int64}[1 + 3im, 4 - 2im, 0 + 1im])

In [40]:
DiagMatParam(["This ", "just "]) * DiagMatParam(["should", "work!"])

DiagMatParam{String, Vector{String}}(["This should", "just work!"])

### Duck typing: `AbstractArray`

Let's **integrate our diagonal matrix into Julia's type hierarchy** by subtyping `AbstractArray{T,2}` aka `AbstractMatrix{T}`.

Of course, our diagonal matrix type should then also behave like an array! The [`AbstractArray` interface](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array-1) specifies the methods we should define: at minimum `size` and `getindex`, plus `setindex!` if entries should be assignable.

In [41]:
struct DiagonalMatrix{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
    diag::V
end

In [42]:
# implement AbstractArray interface
Base.size(D::DiagonalMatrix) = (length(D.diag), length(D.diag))

function Base.getindex(D::DiagonalMatrix{T,V}, i::Int, j::Int) where {T,V}
    if i == j
        r = D.diag[i]
    else
        r = zero(T)
    end
    return r
end

# extend Base.setindex! (a plain `setindex!` would define a new, unrelated function)
function Base.setindex!(D::DiagonalMatrix, v, i::Int, j::Int)
    if i == j
        D.diag[i] = v
    else
        throw(ArgumentError("cannot set off-diagonal entry ($i, $j)"))
    end
    return v
end

In [43]:
D = DiagonalMatrix([1,2,3])

3×3 DiagonalMatrix{Int64, Vector{Int64}}:
 1  0  0
 0  2  0
 0  0  3

Note how it is automatically pretty-printed: the generic display code for `AbstractMatrix` only needs `size` and `getindex`.
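Once `setindex!` is added as a method of `Base.setindex!`, the usual assignment syntax `D[i, j] = v` calls it. A sketch, with the type and the interface methods repeated in condensed form so the snippet runs on its own:

```julia
struct DiagonalMatrix{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
    diag::V
end

Base.size(D::DiagonalMatrix) = (length(D.diag), length(D.diag))
Base.getindex(D::DiagonalMatrix{T}, i::Int, j::Int) where {T} = i == j ? D.diag[i] : zero(T)

function Base.setindex!(D::DiagonalMatrix, v, i::Int, j::Int)
    i == j || throw(ArgumentError("cannot set off-diagonal entry ($i, $j)"))
    D.diag[i] = v   # the struct is immutable, but the wrapped Vector is not
    return v
end

D = DiagonalMatrix([1, 2, 3])
D[2, 2] = 5      # indexing syntax calls Base.setindex!
D.diag           # [1, 5, 3]: the wrapped vector was modified

try
    D[1, 2] = 7  # off-diagonal entries are rejected
catch err
    println(err) # prints the ArgumentError
end
```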

In [44]:
D * D

3×3 Matrix{Int64}:
 1  0  0
 0  4  0
 0  0  9

In [45]:
D + D

3×3 Matrix{Int64}:
 2  0  0
 0  4  0
 0  0  6

In [46]:
D - D

3×3 Matrix{Int64}:
 0  0  0
 0  0  0
 0  0  0

In [47]:
D / D

3×3 Matrix{Float64}:
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0

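Basic arithmetic **just works!** Note, however, that the results are dense `Matrix` objects, since the generic fallbacks know nothing about the diagonal structure. What about broadcasting ("element-wise dot") and more complicated functions?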

In [48]:
sin.(D)

3×3 Matrix{Float64}:
 0.841471  0.0       0.0
 0.0       0.909297  0.0
 0.0       0.0       0.14112

In [49]:
sum([D, D, D])

3×3 Matrix{Int64}:
 3  0  0
 0  6  0
 0  0  9

In [50]:
using LinearAlgebra
eigen(D)

Eigen{Float64, Float64, Matrix{Float64}, Vector{Float64}}
values:
3-element Vector{Float64}:
 1.0
 2.0
 3.0
vectors:
3×3 Matrix{Float64}:
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0

It is, of course, still advantageous to define fast versions that utilize the special diagonal structure:

In [51]:
@which D + D

In [52]:
@which 3 * D

In [53]:
import Base: +, *

+(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag + Db.diag)
*(x::Number, D::DiagonalMatrix) = DiagonalMatrix(x * D.diag)

* (generic function with 373 methods)

In [54]:
@which D + D

In [55]:
@which 3 * D
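Following the same pattern, we can give the remaining operations structure-aware methods as well, so that their results also stay diagonal. A sketch; the `-` and matrix-product methods are our additions, and the type and interface methods are repeated in condensed form so the snippet runs on its own:

```julia
struct DiagonalMatrix{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
    diag::V
end

Base.size(D::DiagonalMatrix) = (length(D.diag), length(D.diag))
Base.getindex(D::DiagonalMatrix{T}, i::Int, j::Int) where {T} = i == j ? D.diag[i] : zero(T)

# Structure-aware methods: O(n) work on the diagonal instead of the generic
# O(n^2) (+, -) or O(n^3) (matrix *) fallbacks; the results stay DiagonalMatrix
Base.:+(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag + Db.diag)
Base.:-(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag - Db.diag)
Base.:*(Da::DiagonalMatrix, Db::DiagonalMatrix) = DiagonalMatrix(Da.diag .* Db.diag)
Base.:*(x::Number, D::DiagonalMatrix) = DiagonalMatrix(x * D.diag)

D = DiagonalMatrix([1, 2, 3])
D * D                              # a DiagonalMatrix with diagonal 1, 4, 9
D * D == [1 0 0; 0 4 0; 0 0 9]     # true: same values as the dense result
```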

In [57]:
import Pkg;
Pkg.add("BenchmarkTools")

An important thing to note is that **user-defined types are just as good as built-in types**!

There is nothing special about built-in types. In fact, [they are implemented in precisely the same way](https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/diagonal.jl#L5)!

Let us quickly check whether our `DiagonalMatrix` type comes with any performance overhead by benchmarking it in a simple function.

# Benchmarking with `BenchmarkTools.jl`

In [58]:
using BenchmarkTools

In [59]:
operation(x) = x + 2*x

operation (generic function with 1 method)

In [60]:
x = rand(2,2)
@time operation.(x)

  0.065552 seconds (194.47 k allocations: 10.245 MiB, 99.59% compilation time)


2×2 Matrix{Float64}:
 2.33756  2.97858
 1.93508  1.4521

In [61]:
function f()
    x = rand(2,2)
    @time operation.(x)
end

f (generic function with 1 method)

In [62]:
f()

  0.000001 seconds (1 allocation: 96 bytes)


2×2 Matrix{Float64}:
 1.00282     2.26088
 0.00149101  1.55283

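We should wrap benchmarks in functions! At global scope, the first measurement is dominated by compilation (99.59% compilation time) and works with the non-constant global variable `x`, whose type the compiler cannot rely on. Inside `f`, `x` is a local variable of known type, and compilation happens before the timed expression runs.

Fortunately, there are tools that do this for us. They also collect statistics by running the benchmark many times.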

In [63]:
@benchmark operation.(x)

BenchmarkTools.Trial: 10000 samples with 205 evaluations.
 Range (min … max):  374.005 ns …  17.800 μs  ┊ GC (min … max): 0.00% … 96.58%
 Time  (median):     379.810 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   401.040 ns ± 416.589 ns  ┊ GC (mean ± σ):  2.72% ±  2.57%

Typically we don't need all this information: `@btime` can be used in place of `@time` and reports only the minimum time and the allocations.

In [66]:
@btime operation.(x);

  381.564 ns (4 allocations: 160 bytes)


However, we still have to take some care to avoid accessing global variables: `x` is a non-constant global, and looking it up inside the benchmarked expression adds overhead.

In [67]:
@btime operation.($x); # interpolate the value of x into the expression to avoid overhead of globals

  35.437 ns (1 allocation: 96 bytes)


This is similar to string interpolation:

In [68]:
x = 42
s = "The answer to the ultimate question of life, the universe, and everything is $x !"

"The answer to the ultimate question of life, the universe, and everything is 42 !"

More information: [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl/blob/master/doc/manual.md).

Finally, we can compare the performance of our custom `DiagonalMatrix` type with the built-in `Diagonal` type from `LinearAlgebra`.

In [69]:
using LinearAlgebra
x = rand(100);
Djl = Diagonal(x)
D = DiagonalMatrix(x)
@btime operation($Djl);
@btime operation($D);

  111.798 ns (2 allocations: 1.75 KiB)
  223.540 ns (2 allocations: 1.75 KiB)


# Core messages of this Notebook

* **User-defined types are as good as built-in types.**
* There are `mutable struct`s and immutable `struct`s.
* We can easily **extend `Base` functions** for our types to implement arithmetic and such.
* **Subtyping an existing interface** can give lots of functionality for free.
* We should always benchmark our code with **BenchmarkTools.jl**'s `@btime` and `@benchmark`.