From 6a946fed61c9b662d25bbca8ca2698bc75cce8c3 Mon Sep 17 00:00:00 2001 From: jmh530 Date: Tue, 2 Jun 2020 11:20:11 -0400 Subject: [PATCH] Add overloads for sum (#264) * Add variadic overloads to sum * Add sum overloads * Remove sum variadic template * Remove prod/mean variadic templates --- source/mir/math/numeric.d | 26 ++++------------ source/mir/math/stat.d | 11 +++---- source/mir/math/sum.d | 64 ++++++++++++++++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/source/mir/math/numeric.d b/source/mir/math/numeric.d index d4e3ec94..767eed0f 100644 --- a/source/mir/math/numeric.d +++ b/source/mir/math/numeric.d @@ -241,7 +241,7 @@ prodType!Range prod(Range)(Range r) if (isIterable!Range) { import core.lifetime: move; - + alias F = typeof(return); return .prod!(F, Range)(r.move); } @@ -264,32 +264,18 @@ prodType!Range prod(Range)(Range r, ref long exp) /++ Params: - val = values + ar = values Returns: - The prduct of all the elements in `val` + The prduct of all the elements in `ar` +/ -F prod(F)(scope const F[] val...) - if (isFloatingPoint!F) +prodType!T prod(T)(scope const T[] ar...) { + alias F = typeof(return); ProdAccumulator!F prod; - prod.put(val); + prod.put(ar); return prod.prod; } -/++ -Params: - val = values -Returns: - The prduct of all the elements in `val` -+/ -prodType!(CommonType!T) prod(T...)(T val) - if (T.length > 0 && - !is(CommonType!T == void)) -{ - alias F = typeof(return); - return .prod!(F)(val); -} - /// Product of arbitrary inputs version(mir_test) @safe pure @nogc nothrow diff --git a/source/mir/math/stat.d b/source/mir/math/stat.d index f85d9a9c..d3ed7ca5 100644 --- a/source/mir/math/stat.d +++ b/source/mir/math/stat.d @@ -137,8 +137,6 @@ template mean(F, Summation summation = Summation.appropriate) /// ditto template mean(Summation summation = Summation.appropriate) { - import std.traits: CommonType; - /++ Params: r = range, must be finite iterable @@ -151,13 +149,12 @@ template mean(Summation summation = Summation.appropriate) /++ Params: - val = values + ar = values +/ - @fmamath CommonType!T mean(T...)(T val) - if (T.length > 0 && - !is(CommonType!T == void)) + @fmamath sumType!T mean(T)(scope const T[] ar...) { - return .mean!(CommonType!T, summation)(val); + alias F = typeof(return); + return .mean!(F, summation)(ar); } } diff --git a/source/mir/math/sum.d b/source/mir/math/sum.d index 9e9050ff..7d3ade31 100644 --- a/source/mir/math/sum.d +++ b/source/mir/math/sum.d @@ -235,6 +235,17 @@ unittest assert(ma.avg == (1010 * 1009 / 2 - 10 * 9 / 2) / 1000.0); } +/// Arbitrary sum +version(mir_test) +@safe pure nothrow +unittest +{ + assert(sum(1, 2, 3, 4) == 10); + assert(sum!float(1, 2, 3, 4) == 10f); + assert(sum(1f, 2, 3, 4) == 10f); + assert(sum(1.0 + 2i, 2 + 3i, 3 + 4i, 4 + 5i) == (10 + 14i)); +} + version(X86) version = X86_Any; version(X86_64) @@ -1701,6 +1712,21 @@ template sum(F, Summation summation = Summation.appropriate) } } } + + /// + F sum(scope const F[] r...) + { + static if (isComplex!F && summation == Summation.precise) + { + return sum(r, summationInitValue!F); + } + else + { + Summator!(F, ResolveSummationType!(summation, const(F)[], F)) sum; + sum.put(r); + return sum.sum; + } + } } ///ditto @@ -1722,6 +1748,13 @@ template sum(Summation summation = Summation.appropriate) import core.lifetime: move; return .sum!(F, ResolveSummationType!(summation, Range, F))(r.move, seed); } + + /// + sumType!T sum(T)(scope const T[] ar...) + { + alias F = typeof(return); + return .sum!(F, ResolveSummationType!(summation, F[], F))(ar); + } } ///ditto @@ -1840,6 +1873,24 @@ unittest } } +version(mir_test) +unittest +{ + assert(sum(1) == 1); + assert(sum(1, 2, 3) == 6); + assert(sum(1.0, 2.0, 3.0) == 6); + assert(sum(1.0 + 1i, 2.0 + 2i, 3.0 + 3i) == (6 + 6i)); +} + +version(mir_test) +unittest +{ + assert(sum!float(1) == 1f); + assert(sum!float(1, 2, 3) == 6f); + assert(sum!float(1.0, 2.0, 3.0) == 6f); + assert(sum!cfloat(1.0 + 1i, 2.0 + 2i, 3.0 + 3i) == (6f + 6i)); +} + version(LDC) version(X86_Any) version(mir_test) @@ -1935,10 +1986,15 @@ private T summationInitValue(T)() package template sumType(Range) { import mir.ndslice.slice: isSlice, DeepElementType; - static if (isSlice!Range) - alias T = Unqual!(DeepElementType!(Range.This)); - else - alias T = Unqual!(ForeachType!Range); + + static if (isIterable!Range) { + static if (isSlice!Range) + alias T = Unqual!(DeepElementType!(Range.This)); + else + alias T = Unqual!(ForeachType!Range); + } else { + alias T = Unqual!Range; + } static if (__traits(compiles, { auto a = T.init + T.init; a += T.init;