diff --git a/simd.go b/simd.go index 94a35d9..2002d51 100644 --- a/simd.go +++ b/simd.go @@ -14,9 +14,38 @@ var ( ) type number interface { - ~int8 | ~int16 | ~int32 | ~int64 | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 + ~int | ~int8 | ~int16 | ~int32 | ~int64 | uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 } +// Sum sums up all of the elements of the slice and returns the value +func Sum[T number](input []T) T { + switch v := any(input).(type) { + case []int8: + return T(SumInt8s(v)) + case []int16: + return T(SumInt16s(v)) + case []int32: + return T(SumInt32s(v)) + case []int64: + return T(SumInt64s(v)) + case []uint8: + return T(SumUint8s(v)) + case []uint16: + return T(SumUint16s(v)) + case []uint32: + return T(SumUint32s(v)) + case []uint64: + return T(SumUint64s(v)) + case []float32: + return T(SumFloat32s(v)) + case []float64: + return T(SumFloat64s(v)) + default: + return sum(input) + } +} + +// Sum sums up all of the elements of the slice and returns the value func sum[T number](input []T) (sum T) { for _, v := range input { sum += v @@ -24,16 +53,35 @@ func sum[T number](input []T) (sum T) { return } -func max[T number](input []T) T { - max := input[0] - for _, v := range input[1:] { - if v > max { - max = v - } +// Min returns the smallest element value in the slice +func Min[T number](input []T) T { + switch v := any(input).(type) { + case []int8: + return T(MinInt8s(v)) + case []int16: + return T(MinInt16s(v)) + case []int32: + return T(MinInt32s(v)) + case []int64: + return T(MinInt64s(v)) + case []uint8: + return T(MinUint8s(v)) + case []uint16: + return T(MinUint16s(v)) + case []uint32: + return T(MinUint32s(v)) + case []uint64: + return T(MinUint64s(v)) + case []float32: + return T(MinFloat32s(v)) + case []float64: + return T(MinFloat64s(v)) + default: + return min(input) } - return max } +// Min returns the smallest element value in the slice func min[T number](input []T) T { min := input[0] for _, v := range input[1:] { @@ -44,6 +92,46 @@ func min[T number](input []T) T { return min } +// Max returns the largest element value in the slice +func Max[T number](input []T) T { + switch v := any(input).(type) { + case []int8: + return T(MaxInt8s(v)) + case []int16: + return T(MaxInt16s(v)) + case []int32: + return T(MaxInt32s(v)) + case []int64: + return T(MaxInt64s(v)) + case []uint8: + return T(MaxUint8s(v)) + case []uint16: + return T(MaxUint16s(v)) + case []uint32: + return T(MaxUint32s(v)) + case []uint64: + return T(MaxUint64s(v)) + case []float32: + return T(MaxFloat32s(v)) + case []float64: + return T(MaxFloat64s(v)) + default: + return max(input) + } +} + +// Max returns the largest element value in the slice +func max[T number](input []T) T { + max := input[0] + for _, v := range input[1:] { + if v > max { + max = v + } + } + return max +} + +// Add adds input1 to input2 and writes back the result into dst slice func add[T number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v + input2[i] @@ -51,6 +139,7 @@ func add[T number](dst, input1, input2 []T) []T { return dst } +// Sub subtracts input2 from input1 and writes back the result into dst slice func sub[T number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v - input2[i] @@ -58,6 +147,7 @@ func sub[T number](dst, input1, input2 []T) []T { return dst } +// Mul multiplies input1 by input2 and writes back the result into dst slice func mul[T number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v * input2[i] @@ -65,6 +155,7 @@ func mul[T number](dst, input1, input2 []T) []T { return dst } +// Div divides input1 by input2 and writes back the result into dst slice func div[T number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v / input2[i] diff --git a/simd_test.go b/simd_test.go index de2f425..0065944 100644 --- a/simd_test.go +++ b/simd_test.go @@ -7,6 +7,8 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" ) // Result represents a result of a benchmark @@ -74,3 +76,45 @@ func setMode(mode string) { avx2 = false } } + +func TestSum(t *testing.T) { + assert.Equal(t, 3, int(Sum([]int8{1, 2}))) + assert.Equal(t, 3, int(Sum([]int16{1, 2}))) + assert.Equal(t, 3, int(Sum([]int32{1, 2}))) + assert.Equal(t, 3, int(Sum([]int64{1, 2}))) + assert.Equal(t, 3, int(Sum([]uint8{1, 2}))) + assert.Equal(t, 3, int(Sum([]uint16{1, 2}))) + assert.Equal(t, 3, int(Sum([]uint32{1, 2}))) + assert.Equal(t, 3, int(Sum([]uint64{1, 2}))) + assert.Equal(t, 3, int(Sum([]float32{1, 2}))) + assert.Equal(t, 3, int(Sum([]float64{1, 2}))) + assert.Equal(t, 3, int(Sum([]int{1, 2}))) +} + +func TestMin(t *testing.T) { + assert.Equal(t, 1, int(Min([]int8{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]int16{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]int32{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]int64{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]uint8{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]uint16{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]uint32{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]uint64{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]float32{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]float64{3, 1, 2}))) + assert.Equal(t, 1, int(Min([]int{3, 1, 2}))) +} + +func TestMax(t *testing.T) { + assert.Equal(t, 2, int(Max([]int8{1, 2}))) + assert.Equal(t, 2, int(Max([]int16{1, 2}))) + assert.Equal(t, 2, int(Max([]int32{1, 2}))) + assert.Equal(t, 2, int(Max([]int64{1, 2}))) + assert.Equal(t, 2, int(Max([]uint8{1, 2}))) + assert.Equal(t, 2, int(Max([]uint16{1, 2}))) + assert.Equal(t, 2, int(Max([]uint32{1, 2}))) + assert.Equal(t, 2, int(Max([]uint64{1, 2}))) + assert.Equal(t, 2, int(Max([]float32{1, 2}))) + assert.Equal(t, 2, int(Max([]float64{1, 2}))) + assert.Equal(t, 2, int(Max([]int{1, 2}))) +}