# Multithreading

## What are threads?
Threads are **execution units within a process** that can run simultaneously. While processes are separate, threads run in a **shared memory** space (heap).

<!-- <img src="./imgs/what-are-threads.png" width=500px> -->

<br>
<img src="imgs/stack_heap_threads.svg" width=450px>
<br>

## Starting Julia with multiple threads

By default, Julia starts with a single *user thread*. We must tell it explicitly to start multiple user threads. There are a couple of ways to do this:

* Environment variable: `export JULIA_NUM_THREADS=4`
* Command line argument: `julia --threads 4` or `julia -t 4`

**It is currently not (easily) possible to change the number of threads at runtime!**

For Jupyter, we create another kernel that starts Julia with multiple threads.

In [None]:
using IJulia
installkernel("Julia (4 threads)", "--project=@.", env=Dict("JULIA_NUM_THREADS"=>"4"))

Afterwards, we need to **refresh the page** and select the new `Julia (4 threads) 1.10` kernel in the top right corner. (Restart Jupyter if the kernel doesn't show up.)

We can readily check how many threads we are running:

In [None]:
using Base.Threads: nthreads
nthreads()

### User threads vs default threads

Technically, even in "single-threaded" mode the Julia process already spawns multiple threads, such as
* a thread for unix signal listening
* multiple OpenBLAS threads for BLAS/LAPACK operations
* GC threads

We call the threads that we can actually run computations on *user threads* or *Julia threads*.

In [None]:
using LinearAlgebra
BLAS.get_num_threads()
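
The different kinds of threads can be counted separately. The following is a minimal sketch that assumes Julia ≥ 1.9, where `nthreads` accepts a thread pool name (`:default` or `:interactive`); the variable names are only illustrative.

```julia
using Base.Threads: nthreads
using LinearAlgebra: BLAS

# Julia (user) threads that run our computations
n_default = nthreads(:default)

# Julia threads reserved for interactive/latency-sensitive tasks (0 unless requested)
n_interactive = nthreads(:interactive)

# threads used internally by the BLAS library (independent of the Julia threads)
n_blas = BLAS.get_num_threads()

println("default: ", n_default, ", interactive: ", n_interactive, ", BLAS: ", n_blas)
```

Called without arguments, `nthreads()` reports the size of the default pool, i.e. the number of user threads discussed above.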

## Where are my threads running?

In [None]:
using ThreadPinning

In [None]:
threadinfo()

## Task-based multithreading

Julia implements **task-based** multithreading. In this paradigm, a task (e.g. a computational piece of code) is marked for **parallel** execution on **any** of the Julia threads. Julia's **dynamic scheduler** then assigns the task to a thread and triggers its execution.

<br>
<!-- <img src="imgs/task-based-parallelism.png" width=768px> -->
<img src="imgs/tasks_threads_cores.svg" width=650px>
<br>

Task-based multithreading: **The user should think about tasks and not threads**.
* By default, the user does not control on which thread a task will run (the task might even [migrate](https://docs.julialang.org/en/v1/manual/multi-threading/#man-task-migration) between threads!).

**Advantages:**
* high-level abstraction: one can spawn many tasks (>> number of threads)
* nestable multithreading

**Disadvantages:**
* dynamic scheduling overhead
* uncertainty and potentially suboptimal task → thread assignment
  * can get in the way when performance engineering

### Spawning tasks

In [None]:
using Base.Threads

In [None]:
@spawn 3+3

`@spawn` creates a `Task` and schedules it for execution on an available Julia thread (we don't control which one!).

Note that `Threads.@spawn` is **asynchronous** and **non-blocking**, that is, it doesn't wait for the task to actually run but immediately returns a `Task`.

We can fetch the result of a task with `fetch`.

In [None]:
t = @spawn 3+3
fetch(t)

While `@spawn` returns right away, `fetch` is **blocking** as it has to wait for the task to actually finish.

In [None]:
@time t = @spawn begin
    sleep(3)
    return 3+3
end
@time fetch(t)
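
The difference between the non-blocking `@spawn` and the blocking `fetch`/`wait` can also be observed by querying the task state with `istaskdone`. This is a small sketch using only functions from `Base`; the variable names are illustrative.

```julia
using Base.Threads: @spawn

t = @spawn begin
    sleep(0.5)
    21 * 2
end

# @spawn has already returned; the task is most likely still sleeping
done_right_away = istaskdone(t)

wait(t)                          # block until the task has finished
done_after_wait = istaskdone(t)  # guaranteed to be true now

result = fetch(t)                # returns immediately, the task is done
println((done_right_away, done_after_wait, result))
```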

We can use the macro `@sync` to synchronize all (lexically) encompassed asynchronous operations (`@spawn`).

In [None]:
@time @sync t = @spawn begin
    sleep(3)
    return 3+3
end
@time fetch(t)
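
`@sync` becomes more useful once several tasks are spawned inside it. The sketch below (illustrative names, `Base` only) spawns four sleeping tasks in a loop and waits for all of them at the end of the `@sync` block. Note that only tasks spawned lexically inside the block are waited for; tasks spawned inside a function that is merely called from the block are not.

```julia
using Base.Threads: @spawn

results = zeros(Int, 4)

elapsed = @elapsed @sync for i in 1:4
    @spawn begin
        sleep(0.2)          # stand-in for real work
        results[i] = i^2    # every task writes to its own slot
    end
end

# all four tasks have finished here
println("results = ", results, ", elapsed ≈ ", round(elapsed; digits=2), " s")
```

Because the tasks sleep concurrently, the elapsed time should be close to that of a single task rather than four times as long.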

#### Example: multithreaded `map`

`tmap`: *threaded map*

In [None]:
function tmap(f, collection)
    # for each x ∈ collection, spawn a task to compute f(x)
    tasks = map(collection) do x
        @spawn f(x)
    end
    # fetch and return all the results
    return fetch.(tasks)
end

In [None]:
M = [rand(200,200) for i in 1:8];

In [None]:
using LinearAlgebra: svdvals

In [None]:
tmap(svdvals, M)

In [None]:
using BenchmarkTools

In [None]:
@btime tmap($svdvals, $M) samples=10 evals=3;
@btime map($svdvals, $M) samples=10 evals=3;

If you use multithreading in Julia in combination with BLAS/LAPACK functions, it is important to carefully consider and configure the [interplay between Julia threads and BLAS threads](https://carstenbauer.github.io/ThreadPinning.jl/stable/explanations/blas/).

Easiest way out: turn off BLAS/LAPACK multithreading.

In [None]:
using LinearAlgebra: BLAS
BLAS.set_num_threads(1)

In [None]:
@btime tmap($svdvals, $M) samples=10 evals=3;

#### Example: multithreading for-loops

In [None]:
using OhMyThreads.Tools: taskid

In [None]:
@sync for i in 1:8
    @spawn println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid())
end

#### Example: nestable multithreading

Recursive Fibonacci series

$$ F(n) = F(n-1) + F(n-2), \qquad F(1) = F(2) = 1$$

(Note: Algorithmically, this is a highly inefficient implementation of the Fibonacci series!)

In [None]:
function fib(n)
    n < 2 && return n
    t = @spawn fib(n-2)
    return fib(n-1) + fetch(t)
end

We are nesting `@spawn` calls recursively!

In [None]:
fib(20)

In [None]:
tmap(fib, 1:20) # multithreaded tmap applying a multithreaded fib

### Load-balancing and chunking

If there are many tasks (e.g. many more than available threads), Julia's scheduler balances the load of these tasks among threads.

In [None]:
using OhMyThreads: chunks, index_chunks

In [None]:
x = rand(10)
collect(chunks(x; n=3)) # chunks hold elements of x (views)

In [None]:
collect(index_chunks(x; n=3)) # chunks hold indices of elements of x
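
To illustrate the idea of index chunking without any package, here is a rough `Base`-only sketch built on `Iterators.partition`. It is not how `index_chunks` is implemented, and the resulting split can differ (here the last chunk is simply shorter); the names `chunk_len` and `idx_chunks` are illustrative.

```julia
x = rand(10)
n = 3                                # desired number of chunks

chunk_len = cld(length(x), n)        # ceiling division: at most n chunks
idx_chunks = collect(Iterators.partition(eachindex(x), chunk_len))

# the corresponding element chunks, as views (no copies)
elem_chunks = [view(x, idcs) for idcs in idx_chunks]

println(idx_chunks)
```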

In [None]:
# this function is purely pedagogical
function tmap_tracking(f, collection; tracker = [UnitRange[] for _ in 1:nthreads()], ntasks=nthreads())
    result = zeros(Float64, length(collection))
    @sync for chunk_indices in index_chunks(collection; n=ntasks)   # chunk up collection into ntasks-many chunks
        @spawn begin                                                # spawn a task for each chunk
            for i in chunk_indices                                  # for each index that belongs to this chunk/task
                result[i] = f(collection[i])                        # apply f
            end
            push!(tracker[threadid()], chunk_indices)               # keep track of which thread ran the task
        end
    end
    return result, tracker
end

In [None]:
xs = 1:2^7
f(x) = sum(abs2, rand() for _ in 1:(2^14*x)) # computational cost is increasing as a function of x (non-uniform)

In [None]:
using StatsPlots
using Base.Threads: nthreads

result, tracker = tmap_tracking(f, xs; ntasks=length(xs))   # create a task for each element of `xs`
# result, tracker = tmap_tracking(f, xs; ntasks=8*nthreads()) # create 8*nthreads() tasks, each handling a chunk of `xs`
# result, tracker = tmap_tracking(f, xs; ntasks=4*nthreads()) # create 4*nthreads() tasks, each handling a chunk of `xs`
# result, tracker = tmap_tracking(f, xs; ntasks=2*nthreads()) # create 2*nthreads() tasks, each handling a chunk of `xs`
# result, tracker = tmap_tracking(f, xs; ntasks=1)            # create a single task, handling all of `xs`
# result, tracker = tmap_tracking(f, xs; ntasks=nthreads())   # create nthreads() tasks, each handling a chunk of `xs`

# plotting
thread_workloads = zeros(Int, nthreads(), maximum(length, tracker))
for th in eachindex(tracker)
    for (i, ws) in enumerate(tracker[th])
        thread_workloads[th, i] = sum(ws)
    end
end
b = groupedbar(thread_workloads, xlab="threadid", ylab="workload", title="@spawn", legend=false, bar_position=:stack)
display(b)

#### Multithreading for-loops (revisited): `OhMyThreads.@tasks`

In [None]:
using OhMyThreads: @tasks

In [None]:
@tasks for i in 1:8
    println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid())
end

By default, `@tasks` **divides the iteration space into `nthreads()` contiguous chunks** and creates one task per chunk. $\quad \Rightarrow \quad $ **no load balancing!**

You can tune the number of tasks to spawn (== chunking granularity) for `@tasks` with `@set ntasks = value`.

In [None]:
using OhMyThreads: @set

In [None]:
@tasks for i in 1:8
    @set ntasks = 1
    println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid())
end

In [None]:
@tasks for i in 1:8
    @set ntasks = 8   # same as @sync for .... @spawn ... end
    println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid())
end

(Note that you can't tune the number of tasks for `Threads.@threads`! 🙁)

## Opting out of dynamic scheduling

For "traditional HPC", where you tell each thread what to do, you might want to opt out of dynamic scheduling and task migration. 

**Advantages:**

* guaranteed task-thread mapping ("task pinning")
* lower overhead

**Disadvantages:**

* often less portable code (e.g. hardcoded assumptions about the system)
* no (or at least bad) nestability

### Spawning a sticky task on a specific thread

In [None]:
using OhMyThreads: @spawnat

In [None]:
@spawnat 4 println("Task ", taskid(), " is running on thread ", threadid(), ", and always will be 😉");

### Static scheduling

* **Statically** map tasks to threads, specifically: task 1 → thread 1, task 2 → thread 2, and so on.

For `@tasks` there is `@set scheduler = :static`.

In [None]:
@tasks for i in 1:2*nthreads()
    @set scheduler = :static
    println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid());
end

For `scheduler = :static`, every thread handles precisely two iterations and always the same iterations!

In [None]:
@tasks for i in 1:2*nthreads()
    @set scheduler = :dynamic # :dynamic is the default
    println("Task ", taskid(), " is running iteration ", i, " on thread ", threadid());
end
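
For comparison, Base Julia offers a similar static mode via `Threads.@threads :static`. The sketch below is an analogue, not the OhMyThreads implementation; the helper name `static_thread_ids` is made up for illustration, and the function is assumed to be called from the main task (static scheduling cannot be nested). It records which thread ran which iteration.

```julia
using Base.Threads: @threads, nthreads, threadid

# hypothetical helper: run n iterations with static scheduling
# and record the thread id of each iteration
function static_thread_ids(n)
    ids = zeros(Int, n)
    @threads :static for i in 1:n
        ids[i] = threadid()   # safe here: :static tasks never migrate
    end
    return ids
end

ids = static_thread_ids(2 * nthreads())
println(ids)
```

With `2 * nthreads()` iterations, every thread should appear exactly twice in `ids`, and repeated calls should give the same assignment.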

## Beware of Multithreading: Parallel Summation

In [None]:
data = rand(1_000_000 * nthreads());

sum(data) # we want to parallelize this

### How you should parallelize it

The real answer is: There is no need to roll your own parallel summation (or your own `tmap` 😉). 

In [None]:
using OhMyThreads: treduce

treduce(+, data)

In [None]:
treduce(+, data) ≈ sum(data)

But let's assume we want to write a parallel version ourselves.

### Task-focused parallel version

Key questions for task-based parallelisation:
* How to divide the computation into separate **tasks**?
    * Answer: chunk up the data and perform partial sums.
* How many **tasks** should we create?
    * Answer: since the workload is uniform, `nthreads()` many tasks is a reasonable choice.

In [None]:
function sum_map_spawn(data; ntasks=nthreads())
    ts = map(chunks(data, n=ntasks)) do chunk_elements
        @spawn sum(chunk_elements)
    end
    return sum(fetch.(ts))
end

* Conceptually simple and task-focused
  * → We're **explicitly** spawning one task per chunk.
  * → No mention of threads, except in `ntasks=nthreads()`.
* We don't even need to pre-allocate storage for the partial sums manually (it is hidden in the `map` operation).

In [None]:
sum_map_spawn(data) ≈ sum(data)

In [None]:
@btime sum_map_spawn($data);

### Mistake 1: Race condition

In [None]:
function sum_threads_naive(data)
    s = zero(eltype(data))
    @tasks for i in eachindex(data)
        s += data[i]
    end
    return s
end

In [None]:
@show sum(data);
@show sum_threads_naive(data);

**Wrong** result! Even worse, it's **non-deterministic** and different every time!

There is a [race condition](https://en.wikipedia.org/wiki/Race_condition), which typically appears when multiple tasks modify shared state simultaneously.

→ If possible, **don't modify shared (i.e. non task-local) state!**

### Mistake 2: Thread-focused rather than task-focused

You might be inclined to write something similar to the following (intentionally written in a slightly more verbose form):

In [None]:
function sum_threads_unsafe(data)
    psums = zeros(eltype(data), nthreads())
    @threads for i in eachindex(data)    # spawn nthreads many tasks
        current_sum = psums[threadid()]  # read
        new_sum = current_sum + data[i]  # "work"
        psums[threadid()] = new_sum      # write
    end
    return sum(psums)
end

Such an approach is generally **unsafe** because Julia's scheduler may **migrate tasks between threads**!
  * For example, a task might start on thread 1, be paused (say, after "work"), and then be migrated to thread 3, where it finishes execution.
  * → The output of `threadid()` might change within a task! To be safe, [don't use `threadid()`](https://julialang.org/blog/2023/07/PSA-dont-use-threadid/) at all!
  
It also goes against the idea of task-based multithreading, as we're **thinking about threads rather than tasks**.

(Note that, in spite of the comments above, the `threadid()` pattern will often still work correctly. This is because as of Julia 1.10 task migrations are very rare. **You can't rely on it though!**)

In [None]:
sum_threads_unsafe(data) ≈ sum(data)

### (Performance) Mistake 3: False sharing

In [None]:
function sum_threads_chunks(data; nchunks=nthreads())
    psums = zeros(eltype(data), nchunks)
    @tasks for (c, chunk_elements) in enumerate(chunks(data; n=nchunks)) # spawn nchunks many tasks
        @simd for x in chunk_elements
            psums[c] += x
        end
    end
    return sum(psums)
end

In [None]:
sum_threads_chunks(data) ≈ sum(data)

In [None]:
@btime sum($data);
@btime sum_threads_chunks($data);

Safe, but slow?! Why?

##### Performance issue: [False sharing](https://en.wikipedia.org/wiki/False_sharing)

Why does `sum_threads_chunks` above have bad performance? Although arguably subtle, this is because different tasks mutate shared data (`psums`) in parallel. There is no *logical* sharing: tasks access different slots of `psums` and there is no data race. However, CPU cores operate on whole **cache lines** rather than on single elements, which leads to *implicit* sharing of cache lines.

**Despite its subtlety, false sharing can lead to dramatic slowdown!**

In [None]:
using CpuId

In [None]:
cachelinesize() ÷ sizeof(Float64)

<img src="imgs/false_sharing.svg" width=850px>

Different tasks modify the same cache line
* need for synchronization to ensure cache coherency
* performance decreases (dramatically).
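
One common mitigation, sketched here purely for illustration, is to pad the accumulator array so that each task's slot lies on its own cache line. The function name `sum_padded` and the default `stride=8` are assumptions (8 `Float64` values correspond to a 64-byte cache line); the sketch uses only `Base` and expects non-empty `data`.

```julia
using Base.Threads: @spawn, nthreads

# hypothetical padded variant: slots are `stride` elements apart,
# so two tasks never write to the same cache line
function sum_padded(data; ntasks=nthreads(), stride=8)
    psums = zeros(eltype(data), stride * ntasks)
    chunk_len = cld(length(data), ntasks)
    @sync for (c, idcs) in enumerate(Iterators.partition(eachindex(data), chunk_len))
        @spawn begin
            slot = (c - 1) * stride + 1   # first element of this task's padded region
            for i in idcs
                psums[slot] += data[i]
            end
        end
    end
    return sum(psums)                      # unused padding slots are zero
end

data_demo = rand(100_000)
println(sum_padded(data_demo) ≈ sum(data_demo))
```

Padding trades memory for speed and hardcodes an assumption about the hardware, so accumulating into a task-local variable is usually the simpler choice.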

Once again: **The less you modify shared (i.e. non task-local) state, the better!**

##### "Fixed" version

In [None]:
function sum_threads_chunks_local(data; nchunks=nthreads())
    psums = zeros(eltype(data), nchunks)
    @tasks for (c, chunk_elements) in enumerate(chunks(data; n=nchunks))  # spawn nchunks many tasks
        psums[c] = sum(chunk_elements)
    end
    return sum(psums)
end

* each task/iteration computes a local sum independently
* no *frequent* non-local mutation

In [None]:
sum(data) ≈ sum_threads_chunks_local(data)

In [None]:
@btime sum($data);
@btime sum_threads_chunks_local($data);

## Additional comments

### Synchronization/communication of tasks

Low-level to high-level (roughly):

* [atomic operations](https://docs.julialang.org/en/v1/base/multi-threading/#Atomic-operations)
* [locks](https://docs.julialang.org/en/v1/base/parallel/#Base.ReentrantLock)
* [channels](https://docs.julialang.org/en/v1/base/parallel/#Channels)

Generally, one should try to minimize synchronization/communication as much as possible as it can lead to serialization (or worse).

In [None]:
d = Dict()
lck = ReentrantLock()
@tasks for i in 1:1000
    @lock lck d[i] = i
end
d
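
For simple counters, an atomic operation can replace the lock. The following sketch uses `Threads.Atomic` and `atomic_add!` from Base; the function name `count_atomic` and its parameters are made up for illustration.

```julia
using Base.Threads: @spawn, Atomic, atomic_add!

# hypothetical example: ntasks tasks each increment a shared counter n times
function count_atomic(n; ntasks=4)
    counter = Atomic{Int}(0)
    @sync for _ in 1:ntasks
        @spawn for _ in 1:n
            atomic_add!(counter, 1)   # indivisible read-modify-write
        end
    end
    return counter[]                  # read the final value
end

total = count_atomic(10_000)
println(total)
```

Every increment is still a synchronization point between cores, so this is correct but not necessarily fast; accumulating locally and combining once per task is typically cheaper.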

### Garbage collection

When garbage collection (GC) gets triggered, it stops the world (all threads) while it frees memory.

Hence, when using multithreading, it is even more important to **avoid heap allocations!**

(If you can't avoid allocations, consider using multiprocessing instead.)
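
Whether a function allocates on the heap can be checked with `@allocated`. The sketch below contrasts an allocating and an in-place version of the same operation; the function names are illustrative, and the exact byte counts may vary between Julia versions.

```julia
# allocating version: creates a new array on every call
scale_alloc(x, a) = a .* x

# in-place version: writes into a preallocated output array
function scale_inplace!(y, x, a)
    y .= a .* x
    return y
end

# measure inside a function to avoid global-scope effects
function measure(x, y)
    scale_alloc(x, 2.0); scale_inplace!(y, x, 2.0)   # warm up (compile first)
    bytes_alloc = @allocated scale_alloc(x, 2.0)
    bytes_inplace = @allocated scale_inplace!(y, x, 2.0)
    return bytes_alloc, bytes_inplace
end

x_demo = rand(1000)
bytes_alloc, bytes_inplace = measure(x_demo, similar(x_demo))
println("allocating: ", bytes_alloc, " bytes, in-place: ", bytes_inplace, " bytes")
```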

### Pinning Julia threads to CPU threads/cores

A compute node has a complex topology (e.g. two sockets, multiple memory channels/domains). Placing the Julia threads systematically on CPU threads matters for

* the computational performance of your Julia code
* fluctuations/noise in benchmarks
* hardware-level performance monitoring

#### ThreadPinning.jl

`pinthreads(strategy)`
* `:cputhreads`: pin to CPU threads (incl. "hyperthreads") one after another
* `:cores`: pin to CPU cores one after another
* `:numa`: round-robin between NUMA domains
* `:sockets`: round-robin between sockets
* `:affinitymask`: according to an external affinity mask (e.g. set by SLURM)

In [None]:
threadinfo()

In [None]:
pinthreads(:cores)
threadinfo()

We'll explore the effect of thread pinning on performance in more detail later → **daxpy_cpu exercise**