diff --git a/std/functional.d b/std/functional.d index f63b79d209c..188eac95418 100644 --- a/std/functional.d +++ b/std/functional.d @@ -22,6 +22,7 @@ module std.functional; import std.traits, std.typetuple; + private template needOpCallAlias(alias fun) { /* Determine whether or not unaryFun and binaryFun need to alias to fun or @@ -62,8 +63,11 @@ template unaryFun(alias fun, string parmName = "a") { static if (is(typeof(fun) : string)) { - import std.traits, std.typecons, std.typetuple; - import std.algorithm, std.conv, std.exception, std.math, std.range, std.string; + static if (!fun._ctfeMatchUnary(parmName)) + { + import std.traits, std.typecons, std.typetuple; + import std.algorithm, std.conv, std.exception, std.math, std.range, std.string; + } auto unaryFun(ElementType)(auto ref ElementType __a) { mixin("alias " ~ parmName ~ " = __a ;"); @@ -151,8 +155,11 @@ template binaryFun(alias fun, string parm1Name = "a", { static if (is(typeof(fun) : string)) { - import std.traits, std.typecons, std.typetuple; - import std.algorithm, std.conv, std.exception, std.math, std.range, std.string; + static if (!fun._ctfeMatchBinary(parm1Name, parm2Name)) + { + import std.traits, std.typecons, std.typetuple; + import std.algorithm, std.conv, std.exception, std.math, std.range, std.string; + } auto binaryFun(ElementType1, ElementType2) (auto ref ElementType1 __a, auto ref ElementType2 __b) { @@ -215,8 +222,170 @@ unittest static assert(!is(typeof(binaryFun!FuncObj))); } -private template safeOp(string S) - if (is(typeof(mixin("0 "~S~" 0")) == bool)) +// skip all ASCII chars except a..z, A..Z, 0..9, '_' and '.'. +private uint _ctfeSkipOp(ref string op) +{ + if (!__ctfe) assert(false); + import std.ascii : isASCII, isAlphaNum; + immutable oldLength = op.length; + while (op.length) + { + immutable front = op[0]; + if(front.isASCII && !(front.isAlphaNum || front == '_' || front == '.')) + op = op[1..$]; + else + break; + } + return oldLength != op.length; +} + +// skip all digits +private uint _ctfeSkipInteger(ref string op) +{ + if (!__ctfe) assert(false); + import std.ascii : isDigit; + immutable oldLength = op.length; + while (op.length) + { + immutable front = op[0]; + if(front.isDigit) + op = op[1..$]; + else + break; + } + return oldLength != op.length; +} + +// skip name +private uint _ctfeSkipName(ref string op, string name) +{ + if (!__ctfe) assert(false); + if (op.length >= name.length && op[0..name.length] == name) + { + op = op[name.length..$]; + return 1; + } + return 0; +} + +// returns 1 if $(D fun) is trivial unary function +private uint _ctfeMatchUnary(string fun, string name) +{ + if (!__ctfe) assert(false); + import std.stdio; + fun._ctfeSkipOp; + for (;;) + { + immutable h = fun._ctfeSkipName(name) + fun._ctfeSkipInteger; + if (h == 0) + { + fun._ctfeSkipOp; + break; + } + else if (h == 1) + { + if(!fun._ctfeSkipOp) + break; + } + else + return 0; + } + return fun.length == 0; +} + +unittest +{ + static assert(!_ctfeMatchUnary("sqrt(ё)", "ё")); + static assert(!_ctfeMatchUnary("ё.sqrt", "ё")); + static assert(!_ctfeMatchUnary(".ё+ё", "ё")); + static assert(!_ctfeMatchUnary("_ё+ё", "ё")); + static assert(!_ctfeMatchUnary("ёё", "ё")); + static assert(_ctfeMatchUnary("a+a", "a")); + static assert(_ctfeMatchUnary("a + 10", "a")); + static assert(_ctfeMatchUnary("4 == a", "a")); + static assert(_ctfeMatchUnary("2==a", "a")); + static assert(_ctfeMatchUnary("1 != a", "a")); + static assert(_ctfeMatchUnary("a!=4", "a")); + static assert(_ctfeMatchUnary("a< 1", "a")); + static assert(_ctfeMatchUnary("434 < a", "a")); + static assert(_ctfeMatchUnary("132 > a", "a")); + static assert(_ctfeMatchUnary("123 >a", "a")); + static assert(_ctfeMatchUnary("a>82", "a")); + static assert(_ctfeMatchUnary("ё>82", "ё")); + static assert(_ctfeMatchUnary("ё[ё(ё)]", "ё")); + static assert(_ctfeMatchUnary("ё[21]", "ё")); +} + +// returns 1 if $(D fun) is trivial binary function +private uint _ctfeMatchBinary(string fun, string name1, string name2) +{ + if (!__ctfe) assert(false); + fun._ctfeSkipOp; + for (;;) + { + immutable h = fun._ctfeSkipName(name1) + fun._ctfeSkipName(name2) + fun._ctfeSkipInteger; + if (h == 0) + { + fun._ctfeSkipOp; + break; + } + else if (h == 1) + { + if(!fun._ctfeSkipOp) + break; + } + else + return 0; + } + return fun.length == 0; +} + +unittest { + + static assert(!_ctfeMatchBinary("sqrt(ё)", "ё", "b")); + static assert(!_ctfeMatchBinary("ё.sqrt", "ё", "b")); + static assert(!_ctfeMatchBinary(".ё+ё", "ё", "b")); + static assert(!_ctfeMatchBinary("_ё+ё", "ё", "b")); + static assert(!_ctfeMatchBinary("ёё", "ё", "b")); + static assert(_ctfeMatchBinary("a+a", "a", "b")); + static assert(_ctfeMatchBinary("a + 10", "a", "b")); + static assert(_ctfeMatchBinary("4 == a", "a", "b")); + static assert(_ctfeMatchBinary("2==a", "a", "b")); + static assert(_ctfeMatchBinary("1 != a", "a", "b")); + static assert(_ctfeMatchBinary("a!=4", "a", "b")); + static assert(_ctfeMatchBinary("a< 1", "a", "b")); + static assert(_ctfeMatchBinary("434 < a", "a", "b")); + static assert(_ctfeMatchBinary("132 > a", "a", "b")); + static assert(_ctfeMatchBinary("123 >a", "a", "b")); + static assert(_ctfeMatchBinary("a>82", "a", "b")); + static assert(_ctfeMatchBinary("ё>82", "ё", "q")); + static assert(_ctfeMatchBinary("ё[ё(10)]", "ё", "q")); + static assert(_ctfeMatchBinary("ё[21]", "ё", "q")); + + static assert(!_ctfeMatchBinary("sqrt(ё)+b", "b", "ё")); + static assert(!_ctfeMatchBinary("ё.sqrt-b", "b", "ё")); + static assert(!_ctfeMatchBinary(".ё+b", "b", "ё")); + static assert(!_ctfeMatchBinary("_b+ё", "b", "ё")); + static assert(!_ctfeMatchBinary("ba", "b", "a")); + static assert(_ctfeMatchBinary("a+b", "b", "a")); + static assert(_ctfeMatchBinary("a + b", "b", "a")); + static assert(_ctfeMatchBinary("b == a", "b", "a")); + static assert(_ctfeMatchBinary("b==a", "b", "a")); + static assert(_ctfeMatchBinary("b != a", "b", "a")); + static assert(_ctfeMatchBinary("a!=b", "b", "a")); + static assert(_ctfeMatchBinary("a< b", "b", "a")); + static assert(_ctfeMatchBinary("b < a", "b", "a")); + static assert(_ctfeMatchBinary("b > a", "b", "a")); + static assert(_ctfeMatchBinary("b >a", "b", "a")); + static assert(_ctfeMatchBinary("a>b", "b", "a")); + static assert(_ctfeMatchBinary("ё>b", "b", "ё")); + static assert(_ctfeMatchBinary("b[ё(-1)]", "b", "ё")); + static assert(_ctfeMatchBinary("ё[-21]", "b", "ё")); +} + +//undocumented +template safeOp(string S) + if (S=="<"||S==">"||S=="<="||S==">="||S=="=="||S=="!=") { private bool unsafeOp(ElementType1, ElementType2)(ElementType1 a, ElementType2 b) pure if (isIntegral!ElementType1 && isIntegral!ElementType2) @@ -225,7 +394,7 @@ private template safeOp(string S) return mixin("cast(T)a "~S~" cast(T)b"); } - private bool safeOp(T0, T1)(T0 a, T1 b) pure + bool safeOp(T0, T1)(auto ref T0 a, auto ref T1 b) { static if (isIntegral!T0 && isIntegral!T1 && (mostNegative!T0 < 0) != (mostNegative!T1 < 0)) @@ -256,16 +425,27 @@ private template safeOp(string S) } } +unittest //check user defined types +{ + import std.algorithm : equal; + struct Foo + { + int a; + auto opEquals(Foo foo) + { + return a == foo.a; + } + } + assert(safeOp!"!="(Foo(1), Foo(2))); +} + /** Predicate that returns $(D_PARAM a < b). Correctly compares signed and unsigned integers, ie. -1 < 2U. */ -bool lessThan(T0, T1)(T0 a, T1 b) -{ - return safeOp!"<"(a, b); -} +alias lessThan = safeOp!"<"; -unittest +pure @safe @nogc nothrow unittest { assert(lessThan(2, 3)); assert(lessThan(2U, 3U)); @@ -283,10 +463,7 @@ unittest Predicate that returns $(D_PARAM a > b). Correctly compares signed and unsigned integers, ie. 2U > -1. */ -bool greaterThan(T0, T1)(T0 a, T1 b) -{ - return safeOp!">"(a, b); -} +alias greaterThan = safeOp!">"; unittest { @@ -306,10 +483,7 @@ unittest Predicate that returns $(D_PARAM a == b). Correctly compares signed and unsigned integers, ie. !(-1 == ~0U). */ -bool equalTo(T0, T1)(T0 a, T1 b) -{ - return safeOp!"=="(a, b); -} +alias equalTo = safeOp!"=="; unittest { @@ -361,7 +535,7 @@ unittest template binaryReverseArgs(alias pred) { auto binaryReverseArgs(ElementType1, ElementType2) - (ElementType1 a, ElementType2 b) + (auto ref ElementType1 a, auto ref ElementType2 b) { return pred(b, a); } @@ -384,8 +558,7 @@ Negates predicate $(D pred). */ template not(alias pred) { - auto not(T...)(T args) - if (is(typeof(!pred(args))) || is(typeof(!unaryFun!pred(args))) || is(typeof(!binaryFun!pred(args)))) + auto not(T...)(auto ref T args) { static if (is(typeof(!pred(args)))) return !pred(args);