From 0dd91169910202e41c524962b4bc00dc16c67c2b Mon Sep 17 00:00:00 2001 From: k-hara Date: Fri, 17 Jul 2015 14:24:26 +0900 Subject: [PATCH] fix Issue 14802 - Template argument deduction depends on order of arguments --- src/dcast.d | 23 ++++++++++++++++++++++ src/dtemplate.d | 26 ++++++++++++++++++++++--- test/runnable/template9.d | 41 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/src/dcast.d b/src/dcast.d index 356ce4a8d771..08e3d85d77b0 100644 --- a/src/dcast.d +++ b/src/dcast.d @@ -2431,6 +2431,29 @@ extern (C++) bool isVoidArrayLiteral(Expression e, Type other) return (e.op == TOKarrayliteral && t.ty == Tarray && t.nextOf().ty == Tvoid && (cast(ArrayLiteralExp)e).elements.dim == 0); } +// used by deduceType() +extern (C++) Type rawTypeMerge(Type t1, Type t2) +{ + Type t1b = t1.toBasetype(); + Type t2b = t2.toBasetype(); + + if (t1.equals(t2)) + { + return t1; + } + else if (t1b.equals(t2b)) + { + return t1b; + } + else + { + TY ty = cast(TY)impcnvResult[t1b.ty][t2b.ty]; + if (ty != Terror) + return Type.basic[ty]; + } + return null; +} + /************************************** * Combine types. * Output: diff --git a/src/dtemplate.d b/src/dtemplate.d index 06f2c226323b..c525b9c01512 100644 --- a/src/dtemplate.d +++ b/src/dtemplate.d @@ -4052,9 +4052,7 @@ extern (C++) MATCH deduceType(RootObject o, Scope* sc, Type tparam, TemplatePara at = xt.tded; } // From previous matched expressions to current deduced type - MATCH match1 = MATCHnomatch; - if (xt) - match1 = xt.matchAll(tt); + MATCH match1 = xt ? xt.matchAll(tt) : MATCHnomatch; // From current expresssion to previous deduced type Type pt = at.addMod(tparam.mod); if (*wm) @@ -4086,6 +4084,11 @@ extern (C++) MATCH deduceType(RootObject o, Scope* sc, Type tparam, TemplatePara } //printf("tt = %s, at = %s\n", tt->toChars(), at->toChars()); } + else + { + match1 = MATCHnomatch; + match2 = MATCHnomatch; + } } if (match1 > MATCHnomatch) { @@ -4105,6 +4108,23 @@ extern (C++) MATCH deduceType(RootObject o, Scope* sc, Type tparam, TemplatePara result = match2; return; } + + /* Deduce common type + */ + if (Type t = rawTypeMerge(at, tt)) + { + if (xt) + xt.update(t, e, tparam); + else + (*dedtypes)[i] = t; + + pt = tt.addMod(tparam.mod); + if (*wm) + pt = pt.substWildTo(*wm); + result = e.implicitConvTo(pt); + return; + } + result = MATCHnomatch; } diff --git a/test/runnable/template9.d b/test/runnable/template9.d index 63024c296e27..032e7d93adc1 100644 --- a/test/runnable/template9.d +++ b/test/runnable/template9.d @@ -4629,6 +4629,46 @@ class A14743 auto func2(T)() {} } +/******************************************/ +// 14802 + +void test14802() +{ + auto func(T)(T x, T y) { return x; } + + struct S1 { double x; alias x this; } + struct S2 { double x; alias x this; } + S1 s1; + S2 s2; + + enum E1 : double { a = 1.0 } + enum E2 : double { a = 1.0 } + + static assert(is(typeof( func(1 , 1 ) ) == int)); + static assert(is(typeof( func(1u, 1u) ) == uint)); + static assert(is(typeof( func(1u, 1 ) ) == uint)); + static assert(is(typeof( func(1 , 1u) ) == uint)); + + static assert(is(typeof( func(1.0f, 1.0f) ) == float)); + static assert(is(typeof( func(1.0 , 1.0 ) ) == double)); + static assert(is(typeof( func(1.0 , 1.0f) ) == double)); + static assert(is(typeof( func(1.0f, 1.0 ) ) == double)); + + static assert(is(typeof( func(s1, s1) ) == S1)); + static assert(is(typeof( func(s2, s2) ) == S2)); + static assert(is(typeof( func(s1, s2) ) == double)); + static assert(is(typeof( func(s2, s1) ) == double)); + + static assert(is(typeof( func(E1.a, E1.a) ) == E1)); + static assert(is(typeof( func(E2.a, E2.a) ) == E2)); + static assert(is(typeof( func(E1.a, 1.0) ) == double)); + static assert(is(typeof( func(E2.a, 1.0) ) == double)); + static assert(is(typeof( func(1.0, E1.a) ) == double)); + static assert(is(typeof( func(1.0, E2.a) ) == double)); + static assert(is(typeof( func(E1.a, E2.a) ) == double)); + static assert(is(typeof( func(E2.a, E1.a) ) == double)); +} + /******************************************/ int main() @@ -4741,6 +4781,7 @@ int main() test13694(); test14836(); test14735(); + test14802(); printf("Success\n"); return 0;