In [1]:
#| include: false
using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()
cd(@__DIR__)

[32m[1m  Activating[22m[39m project at `~/gitrepos/kdheepak.github.io/blog/effect-of-type-inference-on-performance-in-julia`


In [6]:
versioninfo(verbose=false)

Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  JULIA_PROJECT = @.


In Julia, benchmarking frequently is an important way to make sure the code you write runs fast and efficiently. The [Performance Tips](https://docs.julialang.org/en/v1/manual/performance-tips/) section of the official documentation has lots of really useful advice.

In this blog post, I want to touch on one specific performance tip: avoiding containers with abstract type parameters, and how this interacts with type inference.

# Toy problem

Let's define a toy problem to work with.

In [9]:
abstract type Shape end
area(::Shape) = 0.0

@kwdef struct Square <: Shape
    side::Float64 = rand()
end
area(s::Square) = s.side * s.side
    
@kwdef struct Rectangle <: Shape
    width::Float64 = rand()
    height::Float64 = rand()
end
area(r::Rectangle) = r.width * r.height
    
@kwdef struct Triangle <: Shape
    base::Float64 = rand()
    height::Float64 = rand()
end
area(t::Triangle) = 1.0/2.0 * t.base * t.height

@kwdef struct Circle <: Shape
    radius::Float64 = rand()
end
area(c::Circle) = π * c.radius^2

area (generic function with 5 methods)

We can use the built-in `Test` standard library to check that the code we wrote is correct.

In [10]:
using Test
@testset "Areas" begin
    @test area(Square(2)) == 4
    @test area(Rectangle(2,3)) == 6
    @test area(Triangle(2,3)) == 3
    @test area(Circle(2)) ≈ 4π
end;

[0m[1mTest Summary: | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
Areas         | [32m   4  [39m[36m    4  [39m[0m0.1s


Let's also build 1 million random shapes.

In [11]:
using Random

Random.seed!(42)

function shape_builder(choice::Integer)
    if choice == 1
        Square(rand())
    elseif choice == 2
        Rectangle(rand(), rand())
    elseif choice == 3
        Triangle(rand(), rand())
    elseif choice == 4
        Circle(rand())
    end
end

count = 1_000_000
shapes = [shape_builder(rand((1,2,3,4))) for _ in 1:count];

In [12]:
#| echo: false
using Format
using Markdown
l = cfmt("%'d", length(shapes))
Markdown.md"The total number of shapes we have is $l."

The total number of shapes we have is 1,000,000.


# Type inference

We can use the `typeof` function to see what the type of the data in the `shapes` variable is:

In [23]:
typeof(shapes)

Vector{Shape} (alias for Array{Shape, 1})

When an array is built from a comprehension, Julia chooses its element type from the values it actually contains: the most specific type in the type hierarchy that all of the elements share.
For example, if every element is a `Square`, the container is a `Vector{Square}`.

In [27]:
typeof([shape_builder(rand((1,))) for _ in 1:10])

Vector{Square} (alias for Array{Square, 1})

In [28]:
println(join(string.(supertypes(Square)), " <: "))

Square <: Shape <: Any
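For a pair of types, that shared supertype is what `typejoin` computes, and a mixed array literal falls back to it as well. A minimal standalone sketch, assuming a cut-down `Shape` hierarchy (run it in a fresh session, since it redefines the types):

```julia
# Cut-down version of the shape hierarchy, for illustration only
abstract type Shape end
struct Square <: Shape
    side::Float64
end
struct Circle <: Shape
    radius::Float64
end

# Nearest common supertype of two types
typejoin(Square, Square)  # Square
typejoin(Square, Circle)  # Shape

# An array literal mixing the two gets that common supertype as its element type
mixed = [Square(1.0), Circle(2.0)]
eltype(mixed)                  # Shape
isconcretetype(eltype(mixed))  # false
```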


Let's define a function that computes the `area` of every shape and adds them all up:

In [29]:
main1(shapes) = sum(area, shapes)

main1 (generic function with 1 method)

We can test this function, and trigger its compilation, by running it once.

In [30]:
@time main1(shapes)

  0.119760 seconds (2.00 M allocations: 30.552 MiB, 49.17% gc time, 1.84% compilation time)


439078.977716569

## Example of type inference failing

Unfortunately, it is easy to accidentally construct a container whose type parameter is an abstract type. For example, using `filter` to select a single kind of shape from `shapes` still produces a `Vector{Shape}`:

In [42]:
bad_shapes_by_type(::Type{T}, shapes) where T = filter(s -> isa(s, T), shapes)

shape_arr1 = bad_shapes_by_type(Square, shapes)
shape_arr2 = bad_shapes_by_type(Rectangle, shapes)
shape_arr3 = bad_shapes_by_type(Triangle, shapes)
shape_arr4 = bad_shapes_by_type(Circle, shapes)

@show typeof(shape_arr1)
@show typeof(shape_arr2)
@show typeof(shape_arr3)
@show typeof(shape_arr4)
nothing

typeof(shape_arr1) = Vector{Shape}
typeof(shape_arr2) = Vector{Shape}
typeof(shape_arr3) = Vector{Shape}
typeof(shape_arr4) = Vector{Shape}


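This is because `filter` on an `Array` returns an array with the same element type as its input, regardless of which elements it keeps. For better performance, it helps to have concrete types as the type parameters of a container. The comprehension below, by contrast, narrows the element type to the values it collects: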

In [40]:
good_shapes_by_type(::Type{T}, shapes) where T = [shape for shape in shapes if isa(shape, T)]

square_arr = good_shapes_by_type(Square, shapes)
rectangle_arr = good_shapes_by_type(Rectangle, shapes)
triangle_arr = good_shapes_by_type(Triangle, shapes)
circle_arr = good_shapes_by_type(Circle, shapes)

@show typeof(square_arr)
@show typeof(rectangle_arr)
@show typeof(triangle_arr)
@show typeof(circle_arr)
nothing

typeof(square_arr) = Vector{Square}
typeof(rectangle_arr) = Vector{Rectangle}
typeof(triangle_arr) = Vector{Triangle}
typeof(circle_arr) = Vector{Circle}
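The same effect can be reproduced in a few lines, with `Real` and `Int` standing in for `Shape` and `Square`. This sketch also shows two other ways to narrow an already filtered vector: an explicit `convert` when the target type is known, and broadcasting `identity`, which re-derives the element type from the values when the inferred one is abstract.

```julia
xs = Real[1, 2.5, 3, 4.0]           # abstract element type
ints = filter(x -> x isa Int, xs)   # keeps the input's element type
eltype(ints)                        # Real

# Comprehension: element type computed from the collected values
ints_comp = [x for x in xs if x isa Int]
eltype(ints_comp)                   # Int

# Explicit conversion when the target type is known in advance
ints_conv = convert(Vector{Int}, ints)
eltype(ints_conv)                   # Int

# Broadcasting identity narrows to the concrete type of the values
ints_bc = identity.(ints)
eltype(ints_bc)                     # Int
```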


In [60]:
sorted_shapes_shape = vcat(square_arr, rectangle_arr, triangle_arr, circle_arr);
sorted_shapes_any = Any[s for s in sorted_shapes_shape];
sorted_shapes_union = Union{Square, Rectangle, Triangle, Circle}[s for s in sorted_shapes_shape];

@show typeof(sorted_shapes_shape)
@show typeof(sorted_shapes_any)
@show typeof(sorted_shapes_union);

typeof(sorted_shapes_shape) = Vector{Shape}
typeof(sorted_shapes_any) = Vector{Any}
typeof(sorted_shapes_union) = Vector{Union{Circle, Rectangle, Square, Triangle}}
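Writing out the `Union` by hand works for four types, but it can also be derived from the data. A minimal sketch, using a hypothetical helper `narrow_to_union` (not part of Base or any package) and a cut-down shape hierarchy:

```julia
abstract type Shape end
struct Square <: Shape
    side::Float64
end
struct Circle <: Shape
    radius::Float64
end

# Hypothetical helper: copy `v` into a vector whose element type is the
# Union of the concrete types that actually occur in it.
function narrow_to_union(v::AbstractVector)
    U = Union{unique(typeof.(v))...}  # a single type if only one occurs
    return convert(Vector{U}, v)
end

shapes = Shape[Square(1.0), Circle(2.0), Square(3.0)]
narrowed = narrow_to_union(shapes)
eltype(narrowed)  # a Union of Square and Circle
```

If the data contains many distinct concrete types, the resulting `Union` becomes large and the compiler is unlikely to benefit from it, so this only makes sense for a handful of types.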


# Benchmarks

We can benchmark the performance of these different types using `BenchmarkTools`:

In [33]:
using BenchmarkTools

@benchmark main1(shapes)

BenchmarkTools.Trial: 127 samples with 1 evaluation.
 Range (min … max):  38.655 ms …  40.566 ms  ┊ GC (min … max): 0.00% … 0.75%
 Time  (median):     39.363 ms               ┊ GC (median):    0.70%
 Time  (mean ± σ):   39.403 ms ± 404.816 μs  ┊ GC (mean ± σ):  0.50% ± 0.34%

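Containers with an abstract element type, such as `Vector{Shape}` and `Vector{Any}`, can be inefficient: the compiler cannot know the concrete type of each element ahead of time, so every call to `area` has to be resolved by dynamic dispatch at runtime. This is consistent with the roughly 2 million allocations that `@time` reported for `main1(shapes)` above.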

The Julia manual has the following to say:

> If you cannot avoid containers with abstract value types, it is sometimes better to parametrize with `Any` to avoid runtime type checking. E.g. `IdDict{Any, Any}` performs better than `IdDict{Type, Vector}`.
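One way to see the cost of an abstract element type directly is to count allocations. The standalone sketch below redefines the shape types (without the `@kwdef` defaults) and uses a hypothetical helper `allocations` to measure the same reduction over a `Vector{Square}` and over a `Vector{Shape}`. On Julia 1.10 we would expect the concrete case to report zero bytes and the abstract case a number that grows with `n`, but exact figures depend on the Julia version and the compiler's inference limits.

```julia
abstract type Shape end
area(::Shape) = 0.0  # fallback method, as in the post

struct Square <: Shape
    side::Float64
end
area(s::Square) = s.side * s.side

struct Rectangle <: Shape
    width::Float64
    height::Float64
end
area(r::Rectangle) = r.width * r.height

struct Triangle <: Shape
    base::Float64
    height::Float64
end
area(t::Triangle) = 0.5 * t.base * t.height

struct Circle <: Shape
    radius::Float64
end
area(c::Circle) = π * c.radius^2

total_area(v) = sum(area, v)

# Hypothetical helper: call `f(x)` once so it is compiled, then report the
# bytes allocated by a second call. Measuring inside a function avoids
# counting allocations caused by accessing untyped global variables.
function allocations(f, x)
    f(x)
    return @allocated f(x)
end

n = 10_000
concrete = [Square(rand()) for _ in 1:n]                                # Vector{Square}
mixed = Shape[isodd(i) ? Square(rand()) : Circle(rand()) for i in 1:n]  # Vector{Shape}

allocations(total_area, concrete)
allocations(total_area, mixed)
```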

In [61]:
@benchmark main1(sorted_shapes_shape)

BenchmarkTools.Trial: 150 samples with 1 evaluation.
 Range (min … max):  30.877 ms … 38.375 ms  ┊ GC (min … max): 0.00% … 12.04%
 Time  (median):     31.742 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   33.335 ms ±  2.679 ms  ┊ GC (mean ± σ):  4.81% ±  5.99%

In [64]:
@benchmark main1(sorted_shapes_any)

BenchmarkTools.Trial: 148 samples with 1 evaluation.
 Range (min … max):  31.392 ms … 39.047 ms  ┊ GC (min … max): 0.00% … 13.01%
 Time  (median):     32.419 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   33.877 ms ±  2.629 ms  ┊ GC (mean ± σ):  4.55% ±  5.71%

What is interesting is that `Union{Circle, Rectangle, Square, Triangle}` is not a concrete type either, just like `Shape`:

In [67]:
@show isconcretetype(Shape)
@show isconcretetype(Union{Circle, Rectangle, Square, Triangle});

isconcretetype(Shape) = false
isconcretetype(Union{Circle, Rectangle, Square, Triangle}) = false


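Unlike `Shape`, however, it is a small `Union` of concrete types, and the compiler can take advantage of that through *union splitting*: it emits a branch for each member type, so each call to `area` goes directly to the matching method instead of through dynamic dispatch. You can see the difference clearly in the performance benchmarks.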

In [77]:
#| echo: false
f = cfmt("%'d", round(Int, 33.335e3 / 938.103))  # mean times in μs: Vector{Shape} vs. the Union vector
Markdown.md"`Union{Circle, Rectangle, Square, Triangle}` is roughly $f times faster than `Shape`."

`Union{Circle, Rectangle, Square, Triangle}` is roughly 36 times faster than `Shape`.


In [72]:
@benchmark main1(sorted_shapes_union)

BenchmarkTools.Trial: 5319 samples with 1 evaluation.
 Range (min … max):  927.708 μs …  1.203 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     934.917 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   938.103 μs ± 11.475 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

It's possible to get even better performance by computing the sum over each concretely typed array separately and then adding the four results together:

In [79]:
main2(arrs...) = sum(main1, arrs)

main2 (generic function with 1 method)

In [80]:
@time main2(square_arr, rectangle_arr, triangle_arr, circle_arr);

  0.087025 seconds (271.86 k allocations: 18.823 MiB, 99.55% compilation time)


In [81]:
@benchmark main2(square_arr, rectangle_arr, triangle_arr, circle_arr)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  276.000 μs … 383.500 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     278.959 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   281.081 μs ±   6.649 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

# Conclusion

The key takeaway is that if you care about performance in Julia, you have to be mindful of types! Keeping types as concrete as possible is important because when type inference fails, the uncertainty can propagate through the rest of your program. Even small changes to your code, such as how a container is constructed, can improve performance significantly.

Many thanks to the helpful [Julia community on Discourse](https://discourse.julialang.org/t/unusual-non-deterministic-benchmark-results/113273/) for always offering insightful comments and advice.