diff --git a/simd.go b/simd.go index 2002d51..1494c18 100644 --- a/simd.go +++ b/simd.go @@ -13,12 +13,13 @@ var ( sve = cpuid.CPU.Supports(cpuid.SVE) ) -type number interface { +// Number represents a number constraint for SIMD operations +type Number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 } // Sum sums up all of the elements of the slice and returns the value -func Sum[T number](input []T) T { +func Sum[T Number](input []T) T { switch v := any(input).(type) { case []int8: return T(SumInt8s(v)) @@ -46,7 +47,7 @@ func Sum[T number](input []T) T { } // Sum sums up all of the elements of the slice and returns the value -func sum[T number](input []T) (sum T) { +func sum[T Number](input []T) (sum T) { for _, v := range input { sum += v } @@ -54,7 +55,7 @@ func sum[T number](input []T) (sum T) { } // Min returns the smallest element value in the slice -func Min[T number](input []T) T { +func Min[T Number](input []T) T { switch v := any(input).(type) { case []int8: return T(MinInt8s(v)) @@ -82,7 +83,7 @@ func Min[T number](input []T) T { } // Min returns the smallest element value in the slice -func min[T number](input []T) T { +func min[T Number](input []T) T { min := input[0] for _, v := range input[1:] { if v < min { @@ -93,7 +94,7 @@ func min[T number](input []T) T { } // Max returns the largest element value in the slice -func Max[T number](input []T) T { +func Max[T Number](input []T) T { switch v := any(input).(type) { case []int8: return T(MaxInt8s(v)) @@ -121,7 +122,7 @@ func Max[T number](input []T) T { } // Max returns the largest element value in the slice -func max[T number](input []T) T { +func max[T Number](input []T) T { max := input[0] for _, v := range input[1:] { if v > max { @@ -132,7 +133,7 @@ func max[T number](input []T) T { } // Add adds input1 to input2 and writes back the result into dst slice -func add[T number](dst, input1, input2 []T) []T { +func add[T Number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v + input2[i] } @@ -140,7 +141,7 @@ func add[T number](dst, input1, input2 []T) []T { } // Sub subtracts input2 from input1 and writes back the result into dst slice -func sub[T number](dst, input1, input2 []T) []T { +func sub[T Number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v - input2[i] } @@ -148,7 +149,7 @@ func sub[T number](dst, input1, input2 []T) []T { } // Mul multiplies input1 by input2 and writes back the result into dst slice -func mul[T number](dst, input1, input2 []T) []T { +func mul[T Number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v * input2[i] } @@ -156,7 +157,7 @@ func mul[T number](dst, input1, input2 []T) []T { } // Div divides input1 by input2 and writes back the result into dst slice -func div[T number](dst, input1, input2 []T) []T { +func div[T Number](dst, input1, input2 []T) []T { for i, v := range input1 { dst[i] = v / input2[i] } diff --git a/simd_test.go b/simd_test.go index faf14e5..d0ae144 100644 --- a/simd_test.go +++ b/simd_test.go @@ -21,7 +21,7 @@ type Result struct { } // makeVector generates a test vector -func makeVector[T number](count int) []T { +func makeVector[T Number](count int) []T { arr := make([]T, count) for i := 0; i < count; i++ { arr[i] = T((i % 100) + 1)