# Defining data types

We can define types (i.e. data structures) ourselves using the `struct` keyword.

It is a convention that type names are capitalized and [camel cased](https://en.wikipedia.org/wiki/Camel_case).

(Note that types cannot be redefined: you have to restart your Julia session to change a type definition.)

In [1]:
struct MyType end

To create an object of type `MyType`, we have to call a [constructor](https://docs.julialang.org/en/latest/manual/constructors/). Loosely speaking, a constructor is a function that creates new objects.

Julia automatically creates a trivial constructor for us, which has the same name as the type.

In [2]:
methods(MyType)

In [3]:
m = MyType()

MyType()

In [4]:
typeof(m)

MyType

In [5]:
m isa MyType

true

Since our `MyType` contains no data (it is a so-called *singleton type*), we can basically only use it for dispatch.
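To illustrate what dispatching on singleton types can look like, here is a minimal sketch with two hypothetical tag types, `Celsius` and `Fahrenheit` (they are not part of this notebook). The instances carry no data; only their type decides which method runs.

```julia
# Hypothetical singleton types used only as "tags" for dispatch.
struct Celsius end
struct Fahrenheit end

# The method is selected by the *type* of the argument; the tag itself holds no data.
freezing_point(::Celsius) = 0.0
freezing_point(::Fahrenheit) = 32.0

freezing_point(Celsius())     # 0.0
freezing_point(Fahrenheit())  # 32.0

# All instances of a singleton type are identical and occupy no memory.
Celsius() === Celsius()       # true
sizeof(Celsius())             # 0
```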

Most of the time, we'll want a self-defined type to hold some data. For this, we need *fields*.

In [6]:
struct A
    x::Int64
end

In [7]:
A()

MethodError: no method matching A()
Closest candidates are:
  A(!Matched::Int64) at In[6]:2
  A(!Matched::Any) at In[6]:2

The default constructor always expects values for all fields.

In [8]:
A(3)

A(3)
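The `A(::Any)` candidate in the error message above hints that the default constructor also accepts values of other types and tries to `convert` them to the declared field type. A short sketch of what this means for `A`:

```julia
struct A
    x::Int64
end

# The default constructor converts each argument to the declared field type.
a3 = A(3.0)   # 3.0 can be represented exactly as an Int64, so this gives A(3)
a3.x          # 3

# A value that cannot be represented as an Int64 makes the conversion fail.
try
    A(3.5)
catch err
    println(typeof(err))   # prints InexactError
end
```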

In [9]:
a = A(3)

A(3)

In [10]:
# a.<TAB>
a.x

3

Note that types defined with `struct` are **immutable**, that is, the values of their fields cannot be changed once the object has been created.

In [11]:
a.x = 2

ErrorException: setfield! immutable struct of type A cannot be changed

In [12]:
mutable struct B
    x::Int64
end

In [13]:
b = B(3)

B(3)

In [14]:
b.x

3

In [15]:
b.x = 4

4

In [16]:
b.x

4

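Note, however, that **immutability is not recursive**: a field of an immutable struct cannot be reassigned, but a mutable object stored in that field (such as a `Vector`) can still be modified in place.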

In [17]:
struct C
    x::Vector{Int64}
end

In [18]:
c = C([1, 2, 3])

C([1, 2, 3])

In [19]:
c.x

3-element Array{Int64,1}:
 1
 2
 3

In [20]:
c.x = [3,4,5]

ErrorException: setfield! immutable struct of type C cannot be changed

In [22]:
c.x[1] = 3

3

In [23]:
c.x

3-element Array{Int64,1}:
 3
 2
 3

In [24]:
c.x .= [3,4,5] # dot to perform the assignment element-wise

3-element Array{Int64,1}:
 3
 4
 5

Abstract types are just as easy to define using the keyword `abstract type`.

In [25]:
abstract type MyAbstractType end

Since abstract types don't have fields, they only (informally) define interfaces and can be used for dispatch.

In [26]:
struct MyConcreteType <: MyAbstractType # subtype
    somefield::String
end

In [27]:
c = MyConcreteType("test")

MyConcreteType("test")

In [28]:
c isa MyAbstractType

true

In [29]:
supertype(MyConcreteType)

MyAbstractType

In [30]:
subtypes(MyAbstractType)

1-element Array{Any,1}:
 MyConcreteType
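To make the idea of an informal interface more concrete, here is a small sketch with a hypothetical `Shape` hierarchy (the names `Shape`, `Rect`, `Disk`, `area` and `total_area` are made up for illustration). Generic code is written once against the abstract type, and each concrete subtype only has to provide the methods the interface expects.

```julia
# Hypothetical abstract type defining an informal interface.
abstract type Shape end

struct Rect <: Shape
    width::Float64
    height::Float64
end

struct Disk <: Shape
    radius::Float64
end

# The "interface": every concrete Shape is expected to implement `area`.
area(r::Rect) = r.width * r.height
area(d::Disk) = π * d.radius^2

# Written once against the abstract type; dispatch picks the right `area` per element.
total_area(shapes::AbstractVector{<:Shape}) = sum(area, shapes)

shapes = Shape[Rect(2.0, 3.0), Disk(1.0)]
total_area(shapes)   # 6.0 + π ≈ 9.14
```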

# Custom constructor

In [31]:
struct VolNaive
    value::Float64
end

In [32]:
VolNaive(3.0)

VolNaive(3.0)

In [33]:
VolNaive(-3.0)

VolNaive(-3.0)

In [34]:
struct VolSimple
    value::Float64
    
    function VolSimple(x) # inner constructor. function name must match the type name.
        if !(x isa Real)
            throw(ArgumentError("Must be real"))
        end
        if x < 0
            throw(ArgumentError("Negative volume not allowed."))
        end
        new(x) # within an inner constructor, the `new` function can be used to create an object.
    end
end

---

**Side note:**

```julia
if !(x isa Real)
    throw(ArgumentError("Must be real"))
end
if x < 0
    throw(ArgumentError("Negative volume not allowed."))
end
```

This can be written more compactly as
```julia
x isa Real || throw(ArgumentError("Must be real"))
x < 0 && throw(ArgumentError("Negative volume not allowed."))
```

See ["short-circuit evaluation"](https://docs.julialang.org/en/latest/manual/control-flow/#Short-Circuit-Evaluation-1) for more information.

---

In [35]:
VolSimple(3.0)

VolSimple(3.0)

In [36]:
VolSimple(-3.0)

ArgumentError: Negative volume not allowed.

In [37]:
VolSimple("test")

ArgumentError: Must be real

In [38]:
VolSimple(3) # implicit conversion from Int64 -> Float64

VolSimple(3.0)
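Once a type has an inner constructor, Julia no longer generates the default constructors for it. Additional convenience constructors can still be added *outside* the struct as outer constructors. Below is a sketch with a hypothetical `Probability` type: since outer constructors cannot call `new`, they have to go through the inner constructor, so the validation always runs.

```julia
# Hypothetical type; the inner constructor enforces 0 <= p <= 1.
struct Probability
    p::Float64

    function Probability(p::Real)
        0 <= p <= 1 || throw(ArgumentError("Probability must lie in [0, 1]."))
        new(p)
    end
end

# Outer constructor: an ordinary method defined outside the struct.
# It cannot call `new`, so it delegates to the inner constructor above.
Probability(successes::Integer, trials::Integer) = Probability(successes / trials)

Probability(1, 4)    # Probability(0.25)
Probability(5, 4)    # throws an ArgumentError, since 5/4 > 1
```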

# Parametric types

Volumes don't have to be `Float64` values. We can easily relax our type definition to allow all sorts of internal value types.

In [39]:
struct VolParam{T}
    value::T
    
    function VolParam(x::T) where T # x can be of any type T
        if !(x isa Real)
            throw(ArgumentError("Must be real"))
        end
        if x < 0
            throw(ArgumentError("Negative volume not allowed."))
        end
        new{T}(x) # for a parametric type, `new` needs the type parameter {T} explicitly
    end
end

In [40]:
VolParam(3.0)

VolParam{Float64}(3.0)

In [41]:
VolParam(3)

VolParam{Int64}(3)
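Each choice of the parameter `T` produces its own concrete type, as the outputs `VolParam{Float64}(3.0)` and `VolParam{Int64}(3)` show. A minimal sketch with a hypothetical `Box{T}` type illustrates how these concrete types relate to each other; note that parametric types are *invariant* in their parameters.

```julia
# Hypothetical minimal parametric type (default constructors only).
struct Box{T}
    value::T
end

typeof(Box(1))        # Box{Int64} (on a 64-bit system)
typeof(Box(1.5))      # Box{Float64}

# Every concrete Box{T} is a subtype of the UnionAll type `Box` ...
Box{Int64} <: Box             # true
# ... but Box{Int64} is NOT a Box{Real}, even though Int64 <: Real (invariance).
Box{Int64} <: Box{Real}       # false
Box{Int64} <: Box{<:Real}     # true
```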

Instead of explicitly checking in the inner constructor whether the input `x` is real, we can impose type constraints in the type and function signatures.

In [42]:
struct Vol{T<:Real} <: Real # the last <: Real tells Julia that a Vol is a subtype of Real, i.e. basically a real number
    value::T
    
    function Vol(x::T) where T<:Real # x can be of any type T<:Real
        x < 0 && throw(ArgumentError("Negative volume not allowed."))
        new{T}(x)
    end
end

In [43]:
Vol(3)

Vol{Int64}(3)

In [44]:
Vol(3.0)

Vol{Float64}(3.0)

In [45]:
Vol("1.23")

MethodError: no method matching Vol(::String)
Closest candidates are:
  Vol(!Matched::Complex) where T<:Real at complex.jl:37
  Vol(!Matched::T<:Real) where T<:Real at In[42]:5
  Vol(!Matched::T<:Number) where T<:Number at boot.jl:718
  ...

In [46]:
Vol(-2)

ArgumentError: Negative volume not allowed.

# Arithmetic

In [47]:
Vol(3) + Vol(4)

ErrorException: + not defined for Vol{Int64}

In [48]:
+(x::Vol, y::Vol) = Vol(x.value + y.value)

ErrorException: error in method definition: function Base.+ must be explicitly imported to be extended

If we want to extend or override functions that already exist, we need to `import` them first.

In [49]:
import Base: +

+(x::Vol, y::Vol) = Vol(x.value + y.value)

+ (generic function with 162 methods)
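Instead of `import Base: +`, a method can also be added by qualifying the function name with its module, as in `Base.:+` (the `:` quotes the operator so that it parses as a name). Here is a sketch with a hypothetical `Meters` type which, unlike `Vol`, is not a subtype of `Real`:

```julia
# Hypothetical unit type; it does not subtype Real, so only what we define works.
struct Meters
    value::Float64
end

# Extend Base functions by their qualified names; no `import` needed.
Base.:+(a::Meters, b::Meters) = Meters(a.value + b.value)
Base.:*(k::Real, a::Meters) = Meters(k * a.value)

Meters(1.5) + Meters(2.0)                     # Meters(3.5)
2 * Meters(1.5)                               # Meters(3.0)
sum([Meters(1.0), Meters(2.0), Meters(0.5)])  # Meters(3.5), via the generic `sum`
```

Because `sum` is written generically in terms of `+`, it works for `Meters` without any further methods.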

In [50]:
Vol(3) + Vol(4)

Vol{Int64}(7)

In [51]:
Vol(2) + Vol(8.3) # implicit promotion: 2 + 8.3 gives a Float64

Vol{Float64}(10.3)

In [52]:
methodswith(Vol)

In [53]:
import Base: -, *

-(x::Vol, y::Vol) = Vol(x.value - y.value)
*(x::Vol, y::Vol) = Vol(x.value * y.value)

* (generic function with 355 methods)

Now that we have addition defined for our volume type, some functions already **just work**.

In [54]:
sum([Vol(3), Vol(4.8), Vol(1)])

Vol{Float64}(8.8)

In [55]:
M = Vol.(rand(3,3))

3×3 Array{Vol{Float64},2}:
 Vol{Float64}(0.453224)  Vol{Float64}(0.39036)   Vol{Float64}(0.442347)  
 Vol{Float64}(0.749864)  Vol{Float64}(0.611579)  Vol{Float64}(0.00645594)
 Vol{Float64}(0.639554)  Vol{Float64}(0.599829)  Vol{Float64}(0.93067)   

In [56]:
N = Vol.(rand(3,3))

3×3 Array{Vol{Float64},2}:
 Vol{Float64}(0.675248)  Vol{Float64}(0.338185)  Vol{Float64}(0.203866)
 Vol{Float64}(0.777313)  Vol{Float64}(0.372231)  Vol{Float64}(0.56093) 
 Vol{Float64}(0.509858)  Vol{Float64}(0.369495)  Vol{Float64}(0.483703)

In [57]:
M + N

3×3 Array{Vol{Float64},2}:
 Vol{Float64}(1.12847)  Vol{Float64}(0.728545)  Vol{Float64}(0.646213)
 Vol{Float64}(1.52718)  Vol{Float64}(0.98381)   Vol{Float64}(0.567386)
 Vol{Float64}(1.14941)  Vol{Float64}(0.969325)  Vol{Float64}(1.41437) 

Whenever something doesn't work, we implement the missing methods.

In [58]:
sin(Vol(3))

MethodError: no method matching AbstractFloat(::Vol{Int64})
Closest candidates are:
  AbstractFloat(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:194
  AbstractFloat(::T<:Number) where T<:Number at boot.jl:718
  AbstractFloat(!Matched::Bool) at float.jl:252
  ...

In [59]:
import Base: AbstractFloat
AbstractFloat(x::Vol{T}) where T = AbstractFloat(x.value)

AbstractFloat

In [60]:
sin(Vol(3))

0.1411200080598672

In [61]:
sqrt(Vol(3))

1.7320508075688772
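Why does defining `AbstractFloat` fix `sin` and `sqrt`? For a generic `Real`, these functions fall back to computing on `float(x)`, and `float(x)` in turn calls `AbstractFloat(x)`, which is exactly the method the `MethodError` above complained about. A sketch with a hypothetical `Wrapped` type that has no other methods defined:

```julia
# Hypothetical Real wrapper without any arithmetic defined.
struct Wrapped{T<:Real} <: Real
    value::T
end

# Without this method, sqrt(Wrapped(4)) throws a MethodError for AbstractFloat.
Base.AbstractFloat(x::Wrapped) = AbstractFloat(x.value)

float(Wrapped(3))             # 3.0, since float(x) calls AbstractFloat(x)
sqrt(Wrapped(4))              # 2.0, via the generic sqrt(x::Real) fallback on float(x)
sin(Wrapped(3)) == sin(3.0)   # true
```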

If we really wanted `Vol{T}` objects to behave like real numbers in all operations, we'd have to do a bit more work, such as specifying [promotion and conversion rules](https://docs.julialang.org/en/latest/manual/conversion-and-promotion/).

An important thing to note is that **user-defined types are just as good as built-in types**!

There is nothing special about built-in types. In fact, types that ship with Julia, such as `Diagonal` from the `LinearAlgebra` standard library, [are implemented in the same way](https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/diagonal.jl#L5)!

Let us quickly confirm that our volume "wrapper" type does not come with any performance overhead by benchmarking it in a simple function.

# Benchmarking with `BenchmarkTools.jl`

In [62]:
using BenchmarkTools

In [63]:
operation(x) = x^2 + sqrt(x)

operation (generic function with 1 method)

In [64]:
x = rand(2,2)
@time operation.(x)

  0.061957 seconds (209.45 k allocations: 10.382 MiB)


2×2 Array{Float64,2}:
 1.55068   0.291338
 0.442158  0.254889

In [65]:
function f()
    x = rand(2,2)
    @time operation.(x)
end

f (generic function with 1 method)

In [66]:
f()

  0.000000 seconds (1 allocation: 112 bytes)


2×2 Array{Float64,2}:
 0.430906  0.977564
 0.209756  0.608237

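The first measurement includes the time needed to compile `operation.(x)` for this input type, and it operates on the non-constant global variable `x`. We should wrap benchmarks into functions!

Fortunately, there are tools that do this for us. In addition, they collect statistics by running the benchmark many times.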

In [67]:
@benchmark operation.(x)

BenchmarkTools.Trial: 
  memory estimate:  144 bytes
  allocs estimate:  3
  --------------
  minimum time:     286.960 ns (0.00% GC)
  median time:      307.973 ns (0.00% GC)
  mean time:        377.350 ns (7.01% GC)
  maximum time:     159.375 μs (99.78% GC)
  --------------
  samples:          10000
  evals/sample:     276

Typically we don't need all this information. `@btime` prints only the minimum time and the allocations, so we can simply use it instead of `@time`!

In [69]:
@btime operation.(x);

  287.500 ns (3 allocations: 144 bytes)


However, we still have to take some care to avoid accessing global variables.

In [70]:
@btime operation.($x); # interpolate the value of x into the expression to avoid overhead of globals

  31.187 ns (1 allocation: 112 bytes)
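Interpolating with `$` is specific to BenchmarkTools. In ordinary code, the same overhead is usually avoided by passing data as function arguments or by declaring the global `const`. The sketch below checks this with `@inferred` from the `Test` standard library; the names `xglobal`, `XCONST`, `use_global`, `use_const` and `use_arg` are hypothetical.

```julia
using Test   # provides @inferred

operation(x) = x^2 + sqrt(x)

xglobal = rand(2, 2)        # non-constant global: its type could change at any time
const XCONST = rand(2, 2)   # constant global: its type is fixed

use_global() = operation.(xglobal)  # the compiler cannot infer the type of xglobal
use_const() = operation.(XCONST)    # the type of XCONST is known at compile time
use_arg(x) = operation.(x)          # data passed as an argument: specialized on its type

# @inferred throws an error if the inferred return type differs from the actual one.
@inferred use_const()        # passes: inferred as Matrix{Float64}
@inferred use_arg(xglobal)   # passes: inferred as Matrix{Float64}
# @inferred use_global()     # would throw: the inferred return type is Any
```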


More information: [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl/blob/master/doc/manual.md).

Finally, we can check the performance of our custom volume type.

In [75]:
@btime sqrt(Vol(3));
@btime sqrt(3);

  1.299 ns (0 allocations: 0 bytes)
  1.299 ns (0 allocations: 0 bytes)
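Timings are one way to look for overhead; another is to check how the wrapper is stored in memory. A sketch using a hypothetical `Volume` type with the same structure as `Vol` (renamed so that it does not clash with the definition above):

```julia
# Same structure as Vol above, under a hypothetical name.
struct Volume{T<:Real} <: Real
    value::T

    function Volume(x::T) where T<:Real
        x < 0 && throw(ArgumentError("Negative volume not allowed."))
        new{T}(x)
    end
end

# An immutable struct with a single Float64 field is stored just like a Float64.
isbitstype(Volume{Float64})           # true
sizeof(Volume(3.0)) == sizeof(3.0)    # true (8 bytes each)

# Arrays of Volume store the values inline, without per-element pointers.
vols = Volume.([1.0, 2.0, 3.0])
eltype(vols)                          # Volume{Float64}
sizeof(vols) == 3 * sizeof(Float64)   # true
```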


# Core messages of this Notebook

* There are immutable `struct`s and `mutable struct`s, and immutability is not recursive.
* **Constructors** are functions that create objects. In an inner constructor, we can use the function `new` to create objects.
* We can easily **extend `Base` functions** for our types to implement arithmetic and the like.
* We should always benchmark our code with **BenchmarkTools.jl's `@btime` and `@benchmark`**.

# Exercise: One-hot vector

[One-hot encoding](https://en.wikipedia.org/wiki/One-hot) is useful in machine learning, as we'll see later.

It simply means that, among a group of bits (each either 0 or 1), only one is hot (1) while all others are cold (0). For example:

`v = [0, 0, 0, 0, 0, 1, 0, 0, 0]`

### Task

1. Think about what information an implementation of a one-hot vector actually has to store.
2. Define a `OneHot` type which represents a vector with only a single hot (i.e. `== 1`) bit.
3. Extend all the necessary `Base` functions such that the following computation works for a matrix `A` and a vector of `OneHot` vectors `vs` (i.e. `vs isa Vector{OneHot}`).

    ```julia
    function innersum(A, vs)
        t = zero(eltype(A)) # generic!
        for v in vs
            y = A*v
            for i in 1:length(vs[1])
                t += v[i] * y[i]
            end
        end
        return t
    end

    A = rand(3,3)
    vs = [rand(3) for _ = 1:10] # This should be replaced by a `Vector{OneHot}`

    innersum(A, vs)

    ```

4. Benchmark the speed of `innersum` when called with a vector of `OneHot` vectors and with a vector of `Vector{Float64}`s, respectively.
 * Do you observe a speedup?


5. Now, define a `OneHotVector` type which is identical to `OneHot` but is declared to be a subtype of `AbstractVector{Bool}` and extend only the functions `Base.getindex(v::OneHotVector, i::Int)` and `Base.size(v::OneHotVector)`.
 * Here, the function `size` should return a `Tuple{Int64}` indicating the length of the vector, i.e. `(3,)` for a one-hot vector of length 3 (see the [AbstractArray interface](https://docs.julialang.org/en/latest/manual/interfaces/#man-interface-array-1) for more information)
 

6. Try to create a single `OneHotVector`, then try to run the `innersum` function with a vector of the new `OneHotVector` type.
 * What changes do you observe?
 * Do you have to implement any further methods?