# Simple parallel prefix sum on the CPU (basic threads)

The scan itself uses only bare-bones threads from Julia's standard library, with no external libraries.

In [1]:
# Prerequisite: set the number of threads with the JULIA_NUM_THREADS environment variable.
# Run the line below to check that it matches the number you intended.
# If it prints 1, you probably need to set the environment variable correctly BEFORE
# starting Jupyter Lab.
Threads.nthreads()

4

In [8]:
using Base.Threads

"""
    cpu_scan_parallel(x::AbstractVector{T}; verbose::Bool=true) where T

Return the inclusive prefix sum (running total) of `x` as a new array created with
`similar(x)`; `x` itself is not modified.

The input is split into `Threads.nthreads()` contiguous chunks and processed in three
phases:

1. each chunk is scanned independently (in parallel) and its total is recorded;
2. the chunk totals are turned into running totals (sequentially, one entry per chunk);
3. every chunk after the first adds the running total of all preceding chunks to its
   local results (in parallel).

# Keywords
- `verbose::Bool=true`: print the chunk boundaries and the per-chunk totals after phases
  1 and 2. Pass `verbose=false` when benchmarking or calling from non-interactive code,
  so the measurement does not include the printing.

# Notes
For floating-point element types the additions are grouped differently than in a single
left-to-right pass, so results may differ from `cumsum(x)` by rounding error.
"""
function cpu_scan_parallel(x::AbstractVector{T}; verbose::Bool=true) where T
    nchunks = Threads.nthreads()
    n = length(x)
    out = similar(x)

    # Work relative to the array's own index range instead of assuming it starts at 1.
    lo = firstindex(x)
    hi = lastindex(x)

    # Exact integer ceiling division; when n < nchunks the trailing chunks are empty
    # ranges and simply contribute zero(T).
    chunk = cld(n, nchunks)
    partial_sums = zeros(T, nchunks)

    # 1. Each chunk computes its own local scan and records its total.
    Threads.@threads for tid in 1:nchunks
        start = lo + (tid - 1) * chunk
        stop = min(start + chunk - 1, hi)
        verbose && println("Chunk $tid start: $start, stop: $stop")

        s = zero(T)
        for i in start:stop
            s += x[i]
            out[i] = s
        end
        partial_sums[tid] = s
    end
    if verbose
        println("partial_sums after step 1: ", partial_sums)
        println()
    end

    # 2. Prefix sum across chunk totals: partial_sums[k] becomes the sum of chunks 1..k.
    for i in 2:nchunks
        partial_sums[i] += partial_sums[i - 1]
    end
    if verbose
        println("partial_sums after step 2: ", partial_sums)
        println()
    end

    # 3. Shift every chunk but the first by the total of all chunks before it.
    Threads.@threads for tid in 2:nchunks
        offset = partial_sums[tid - 1]
        start = lo + (tid - 1) * chunk
        stop = min(start + chunk - 1, hi)
        for i in start:stop
            out[i] += offset
        end
    end

    return out
end


cpu_scan_parallel (generic function with 1 method)

### Super small sanity test example

In [9]:
# Tiny input whose prefix sum is easy to verify by eye.
x = collect(1:8)
scan_x = cpu_scan_parallel(x)

# Print the parallel result next to Base's `cumsum`; the last two rows should match.
report = (
    ("input:      ", x),
    ("scan:       ", scan_x),
    ("reference:  ", cumsum(x)),
)
for (label, vals) in report
    println(label, vals)
end

Chunk 4 start: 7, stop: 8
Chunk 2 start: 3, stop: 4
Chunk 1 start: 1, stop: 2
Chunk 3 start: 5, stop: 6
partial_sums after step 1: [3, 7, 11, 15]

partial_sums after step 2: [3, 10, 21, 36]

input:      [1, 2, 3, 4, 5, 6, 7, 8]
scan:       [1, 3, 6, 10, 15, 21, 28, 36]
reference:  [1, 3, 6, 10, 15, 21, 28, 36]


### Larger example

In [6]:
using BenchmarkTools

# Benchmark on 10_000 uniformly distributed Float64 values. The `$x` interpolation
# passes the value to the benchmark rather than the global variable.
x = rand(10_000)
scan_x = @btime cpu_scan_parallel($x);

# Show the input followed by its prefix sum.
foreach(display, (x, scan_x))

  7.067 μs (89 allocations: 84.70 KiB)


10000-element Vector{Float64}:
 0.36525572772639936
 0.46015952417310635
 0.3891924496910021
 0.8568845834671545
 0.5188883653310511
 0.027969361393221925
 0.2747274379048975
 0.4644162604843287
 0.5900924476815493
 0.3443894668475367
 0.5612536505562862
 0.9853531845370834
 0.8372917706306173
 ⋮
 0.25857009665966846
 0.8562080063732498
 0.3749706374094278
 0.3128926083126875
 0.7833848418581378
 0.17115599231065304
 0.017302777967360572
 0.010869342523478376
 0.20949536422480686
 0.43407110802832893
 0.6146397455003578
 0.11737874076281796

10000-element Vector{Float64}:
    0.36525572772639936
    0.8254152518995057
    1.214607701590508
    2.0714922850576625
    2.5903806503887137
    2.618350011781936
    2.8930774496868334
    3.357493710171162
    3.9475861578527116
    4.291975624700248
    4.853229275256534
    5.838582459793617
    6.675874230424235
    ⋮
 5018.68834206586
 5019.544550072233
 5019.919520709642
 5020.2324133179545
 5021.015798159813
 5021.186954152124
 5021.204256930091
 5021.215126272615
 5021.424621636839
 5021.858692744868
 5022.473332490368
 5022.590711231131