diff --git a/benchmarks/FsMath.Benchmarks/Vector.fs b/benchmarks/FsMath.Benchmarks/Vector.fs
index b5528cc..6bc5cf6 100644
--- a/benchmarks/FsMath.Benchmarks/Vector.fs
+++ b/benchmarks/FsMath.Benchmarks/Vector.fs
@@ -46,3 +46,23 @@ type VectorBenchmarks() =
         let result = Vector.norm vector1
         GC.KeepAlive(result) // Prevents the result from being optimized away
 
+    [<Benchmark>]
+    member _.Sum() =
+        let result = Vector.sum vector1
+        GC.KeepAlive(result) // Prevents the result from being optimized away
+
+    [<Benchmark>]
+    member _.Product() =
+        let result = Vector.product vector1
+        GC.KeepAlive(result) // Prevents the result from being optimized away
+
+    [<Benchmark>]
+    member _.Min() =
+        let result = Vector.min vector1
+        GC.KeepAlive(result) // Prevents the result from being optimized away
+
+    [<Benchmark>]
+    member _.Max() =
+        let result = Vector.max vector1
+        GC.KeepAlive(result) // Prevents the result from being optimized away
+
diff --git a/src/FsMath/SpanMath.fs b/src/FsMath/SpanMath.fs
index 7285b48..05801cb 100644
--- a/src/FsMath/SpanMath.fs
+++ b/src/FsMath/SpanMath.fs
@@ -253,10 +253,39 @@ type SpanMath =
     static member inline sum<'T when 'T :> Numerics.INumber<'T>
                 and 'T : (new: unit -> 'T)
                 and 'T : struct
-                and 'T :> ValueType> 
+                and 'T :> ValueType>
         (v:ReadOnlySpan<'T>) : 'T =
-        let zero = LanguagePrimitives.GenericZero<'T>
-        SpanINumberPrimitives.fold ( (+) , (+) , v , zero )
+        // Empty span: return the additive identity.
+        if v.Length = 0 then
+            LanguagePrimitives.GenericZero<'T>
+        elif Numerics.Vector.IsHardwareAccelerated && v.Length >= Numerics.Vector<'T>.Count then
+            // SIMD path: taken only when at least one full Vector<'T> fits in the input.
+            let simdWidth = Numerics.Vector<'T>.Count
+            let simdCount = v.Length / simdWidth
+            let ceiling = simdWidth * simdCount
+
+            // SIMD accumulation: each lane keeps its own running partial sum
+            let mutable accVec = Numerics.Vector<'T>.Zero
+
+            for i = 0 to simdCount - 1 do
+                let srcIndex = i * simdWidth
+                let vec = Numerics.Vector<'T>(v.Slice(srcIndex, simdWidth))
+                accVec <- accVec + vec
+
+            // Horizontal reduction using Vector.Sum for optimized performance
+            let mutable acc = Numerics.Vector.Sum(accVec)
+
+            // Tail: elements past the last full vector (indices ceiling .. Length - 1)
+            for i = ceiling to v.Length - 1 do
+                acc <- acc + v.[i]
+
+            acc
+        else
+            // Scalar fallback: no hardware acceleration, or fewer elements than one vector.
+            let mutable acc = LanguagePrimitives.GenericZero<'T>
+            for i = 0 to v.Length - 1 do
+                acc <- acc + v.[i]
+            acc
 
 
     /// Computes the product of all elements in the vector.