In [1]:
from autotune import autotune_fork, search
from math import min, max
from memory.unsafe import DTypePointer, Pointer
from time import time_function
from benchmark import keep
from memory import memset as stdlib_memset

# Element type written by every memset in this notebook (one byte).
alias type = UInt8
# Pointer to a buffer of `type` elements.
alias ptr_type = DTypePointer[DType.uint8]

# Signature every memset implementation must match:
# (destination pointer, fill value, number of elements); positional-only.
alias fn_type = fn(ptr: ptr_type, value: type, count: Int, /) -> None

In [2]:
fn measure_time(func: fn_type, size: Int, iters: Int, samples: Int) -> Int:
    # Time `iters` back-to-back calls of `func` filling `size` elements and
    # return the best (smallest) of `samples` such measurements, in the units
    # returned by time_function (ns). Returns -1 when `samples` <= 0.
    alias alloc_size = 1024 * 1024
    # Successive calls rotate through offsets 0, 128, ..., 896. The mask must
    # be of the form 2^k - 1: a single-bit mask such as `& 1024` only ever
    # yields the two offsets 0 and 1024.
    alias offset_stride = 128
    alias offset_mask = 1024 - 1
    let ptr = ptr_type.alloc(alloc_size)

    var best = -1
    for _ in range(samples):

        @parameter
        fn runner():
            for iter in range(iters):
                # Offset pointer to shake up cache a bit
                let offset_ptr = ptr.offset((iter * offset_stride) & offset_mask)

                # Just in case the compiler tries to outsmart us and skip
                # repeated memsets, change the value we're filling with
                let v = type(iter & 255)

                # Actually call the memset function
                func(offset_ptr, v.value, size)

                # Avoid compiler optimizing things away
                keep(v)
                keep(size)
                keep(offset_ptr)

        let ns = time_function[runner]()
        if best < 0 or ns < best:
            best = ns

    ptr.free()
    return best

# Scale factor shared by the benchmark helpers: it sets iteration counts
# (benchmark times 30 * MULT calls per sample, memset_evaluator MULT calls)
# and chart resolution (visualize_result draws one '*' per MULT time units).
alias MULT = 2_000

fn visualize_result(size: Int, result: Int):
    # Print one row of the ASCII chart: the size label (single-digit sizes
    # padded so bars line up) followed by one '*' per MULT units of `result`.
    let bar_len = result // MULT

    print_no_newline("Size: ")
    if size < 10:
        print_no_newline(" ")
    print_no_newline(size, "  |")

    var drawn = 0
    while drawn < bar_len:
        print_no_newline("*")
        drawn += 1
    print()


fn benchmark(func: fn_type, title: StringRef):
    # Print a titled header, warm `func` up on every size in [0, 35), then
    # chart the best-of-N timing for each of those sizes.
    alias size_limit = 35
    alias benchmark_iterations = 30 * MULT
    alias warmup_samples = 10
    alias benchmark_samples = 1000

    print("\n===========================")
    print(title)
    print("---------------------------\n")

    # Warmup: results are discarded.
    for warm_size in range(size_limit):
        _ = measure_time(
            func, warm_size, benchmark_iterations, warmup_samples
        )

    # Actual run: each size's best time is drawn as one chart row.
    for run_size in range(size_limit):
        visualize_result(
            run_size,
            measure_time(
                func, run_size, benchmark_iterations, benchmark_samples
            ),
        )

In [3]:
@always_inline
fn overlapped_store[
    width: Int
](ptr: ptr_type, value: type, count: Int):
    # Fill `count` elements with two `width`-wide SIMD stores: one starting at
    # element 0 and one ending exactly at element `count`. The stores overlap
    # when count < 2 * width; together they cover every element only when
    # width <= count <= 2 * width, which is how the callers pick `width`.
    let tail = ptr.offset(count - width)
    let splatted = SIMD[DType.uint8, width].splat(value)
    ptr.simd_store[width](splatted)
    tail.simd_store[width](splatted)


fn memset_manual(ptr: ptr_type, value: type, count: Int):
    # Fill `count` elements at `ptr` with `value`, dispatching on size:
    # count <= 4 uses scalar stores, 5..31 uses two overlapping SIMD stores,
    # and count >= 32 calls memset_system.
    # memset_manual_2 fills the same bytes with the comparisons reordered; the
    # benchmark compares the two, so keep the branch order here unchanged.
    if count < 32:
        if count < 5:
            if count == 0:
                return
            # 0 < count <=4
            # Write both ends, then the next pair inward; for count < 4 some
            # stores hit the same element, which is harmless.
            ptr.store(0, value)
            ptr.store(count - 1, value)
            if count <= 2:
                return
            ptr.store(1, value)
            ptr.store(count - 2, value)
            return

        if count <= 16:
            if count >=8:
                # 8 <= count <= 16
                overlapped_store[8](ptr, value, count)
                return
            # 4 < count < 8
            overlapped_store[4](ptr, value, count)
            return

        # 16 < count < 32
        overlapped_store[16](ptr, value, count)
    else:
        # 32 <= count
        memset_system(ptr, value, count)

fn memset_system(ptr: ptr_type, value: type, count: Int):
    # Baseline: forward to the standard library memset (imported as
    # stdlib_memset), passing the raw scalar held by `value`.
    stdlib_memset(ptr, value.value, count)

In [4]:
# Chart the hand-written dispatcher, then the library memset as a baseline.
benchmark(memset_manual, "Manual memset")
benchmark(memset_system, "System memset")


Manual memset
---------------------------

Size:  0   |*******************************************************************
Size:  1   |***********************************************************************************************
Size:  2   |*******************************************************************************************
Size:  3   |************************************************************************************************************
Size:  4   |*****************************************************************************************************
Size:  5   |*********************************************************************************
Size:  6   |****************************************************************************
Size:  7   |*************************************************************************
Size:  8   |*************************************************************************
Size:  9   |**************************************************************

In [5]:
fn memset_manual_2(ptr: ptr_type, value: type, count: Int):
    # Fills the same bytes as memset_manual, but tests 16 <= count < 32 first
    # (so count == 16 takes one 16-wide overlapped store instead of 8-wide).
    # Only the branch order differs and that is what the benchmark compares,
    # so keep it unchanged.
    if count < 32:
        if count >= 16:
            # 16 <= count < 32
            overlapped_store[16](ptr, value, count)
            return

        if count < 5:
            if count == 0:
                return
            # 0 < count <= 4
            # Write both ends, then the next pair inward; for count < 4 some
            # stores hit the same element, which is harmless.
            ptr.store(0, value)
            ptr.store(count - 1, value)
            if count <= 2:
                return
            ptr.store(1, value)
            ptr.store(count - 2, value)
            return
            
        if count >= 8:
            # 8 <= count < 16
            overlapped_store[8](ptr, value, count)
            return
        # 4 < count < 8
        overlapped_store[4](ptr, value, count)
    
    else:
        # 32 <= count
        memset_system(ptr, value, count)

In [6]:
# Chart the reordered dispatcher next to the same library baseline.
benchmark(memset_manual_2, "Manual memset v2")
benchmark(memset_system, "Mojo system memset")


Manual memset v2
---------------------------

Size:  0   |**************************************************************
Size:  1   |************************************************************************************
Size:  2   |******************************************************************************
Size:  3   |*****************************************************************************************************
Size:  4   |*****************************************************************************************************
Size:  5   |*********************************************************************************
Size:  6   |******************************************************************************
Size:  7   |*************************************************************************
Size:  8   |***********************************************************************
Size:  9   |***********************************************************************
Size: 10   |***********

In [7]:
@adaptive
@always_inline
fn memset_impl_layer[
    lower: Int, upper: Int
](ptr: ptr_type, value: type, count: Int):
    # Leaf variant of the adaptive search tree: fills a `count` that the
    # interior variants have narrowed to lower < count <= upper. -100 and 100
    # are sentinel bounds ("unbounded"), as passed by memset_autotune_impl.
    # Any (lower, upper) pair not listed here hits constrained[False]() and
    # fails to compile, which presumably removes this variant from the
    # adaptive set for that pair — confirm against the @adaptive docs.
    @parameter
    if lower == -100 and upper == 0:
        # count <= 0: nothing to write.
        pass
    elif lower == 0 and upper == 4:
        # 1 <= count <= 4: both ends, then the next pair inward.
        ptr.store(0, value)
        ptr.store(count - 1, value)
        if count <=2:
            return
        ptr.store(1, value)
        ptr.store(count - 2, value)
    elif lower == 4 and upper == 8:
        # 5 <= count <= 8
        overlapped_store[4](ptr, value, count)
    elif lower == 8 and upper == 16:
        # 9 <= count <= 16
        overlapped_store[8](ptr, value, count)
    elif lower == 16 and upper == 32:
        # 17 <= count <= 32
        overlapped_store[16](ptr, value, count)
    elif lower == 32 and upper == 100:
        # count > 32: defer to the library routine.
        memset_system(ptr, value, count)
    else:
        constrained[False]()

In [9]:
@adaptive
@always_inline
fn memset_impl_layer[
    lower: Int, upper: Int
](ptr: ptr_type, value: type, count: Int):
    # Interior variant of the adaptive search tree: split the range at a
    # pivot `cur` and recurse into the half that holds `count`. This variant
    # tests `count > cur` first.
    # autotune_fork yields one compiled version per candidate pivot value.
    alias cur: Int
    autotune_fork[Int, 0, 4, 8, 16, 32 -> cur]()

    # Only pivots strictly inside (lower, upper) are valid; other forks fail
    # these compile-time checks.
    constrained[cur > lower]()
    constrained[cur < upper]()

    # Given the checks above, max(cur, lower) == cur and min(cur, upper) == cur.
    if count > cur:
        memset_impl_layer[max(cur, lower), upper](ptr, value, count)
    else:
        memset_impl_layer[lower, min(cur, upper)](ptr, value, count)

In [10]:
@adaptive
@always_inline
fn memset_impl_layer[
    lower: Int, upper: Int
](ptr: ptr_type, value: type, count: Int):
    # Interior variant of the adaptive search tree with the same pivot split,
    # but testing `count <= cur` first. The branch order is the only thing
    # that distinguishes this candidate from the `count > cur` variant.
    alias cur: Int
    autotune_fork[Int, 0, 4, 8, 16, 32 -> cur]()

    # Only pivots strictly inside (lower, upper) are valid; other forks fail
    # these compile-time checks.
    constrained[cur > lower]()
    constrained[cur < upper]()

    # Given the checks above, min(cur, upper) == cur and max(cur, lower) == cur.
    if count <= cur:
        memset_impl_layer[lower, min(cur, upper)](ptr, value, count)
    else:
        memset_impl_layer[max(cur, lower), upper](ptr, value, count)


In [11]:
@adaptive
fn memset_autotune_impl(ptr: ptr_type, value: type, count: Int, /):
    # Root of the adaptive search tree. These bounds are the sentinels that
    # the leaf memset_impl_layer variants match as "no lower bound" (-100) and
    # "no upper bound" (100); every candidate pivot lies strictly between them.
    alias unbounded_below = -100
    alias unbounded_above = 100
    memset_impl_layer[unbounded_below, unbounded_above](ptr, value, count)

In [12]:
fn memset_evaluator(funcs: Pointer[fn_type], size: Int) -> Int:
    # Autotune evaluator: time each of the `size` candidate functions stored
    # at `funcs` on one representative fill size and return the index of the
    # fastest. Returns -1 when `size` is 0. Note that `size` is the number of
    # candidates, not a memset size.

    # This size is picked at random, in real code we could use a real size
    # distribution
    let size_to_optimize_for = 17

    var best_idx: Int = -1
    var best_time: Int = -1

    alias eval_iterations = MULT
    alias eval_samples = 500

    # Find the function that's the fastest on the size we're optimizing for
    for f_idx in range(size):
        let func = funcs.load(f_idx)
        let cur_time = measure_time(
            func, size_to_optimize_for, eval_iterations, eval_samples
        )
        # The first candidate seeds the comparison; afterwards only a strictly
        # faster one replaces it. (Comparing `best_time < cur_time` would
        # keep picking slower candidates and select the slowest overall.)
        if best_idx < 0 or cur_time < best_time:
            best_idx = f_idx
            best_time = cur_time

    return best_idx

In [13]:
fn memset_autotune(ptr: ptr_type, value: type, count: Int):
    # memset whose implementation is picked by autotuning: every variant in
    # memset_autotune_impl's adaptive set is passed to memset_evaluator, and
    # the candidate at the index it returns is bound to the alias best_impl
    # (so the choice is made when this function is compiled).
    # Get the set of all candidates
    alias candidates = memset_autotune_impl.__adaptive_set

    # Use the evaluator to select the best candidate.
    alias best_impl: fn_type
    search[fn_type, VariadicList(candidates), memset_evaluator -> best_impl]()

    # Run the best candidate
    return best_impl(ptr, value, count)

In [14]:
# Final comparison: both manual dispatchers, the library baseline, and the
# autotuned variant selected via memset_evaluator.
benchmark(memset_manual, "Mojo manual memset")
benchmark(memset_manual_2, "Mojo manual memset 2")
benchmark(memset_system, "Mojo system memset")
benchmark(memset_autotune, "Mojo autotune memset")


Mojo manual memset
---------------------------

Size:  0   |**************************************************
Size:  1   |*************************************************************************
Size:  2   |*********************************************************************************
Size:  3   |************************************************************************************************************
Size:  4   |************************************************************************************************************
Size:  5   |************************************************************************************
Size:  6   |************************************************************************************
Size:  7   |*********************************************************************************
Size:  8   |*********************************************************************************
Size:  9   |************************************************************************