From 0214cd1c5dd5090af8152481f599a346cb2cd1db Mon Sep 17 00:00:00 2001 From: Alec Theriault Date: Mon, 29 Mar 2021 05:30:36 -0700 Subject: [PATCH] Emit efficient code for switch over strings The pattern matcher will now emit `Match` with `String` scrutinee as well as the existing `Int` scrutinee. The JVM backend handles this case by emitting bytecode that switches on the String's `hashCode` (this matches what Java does). The SJS already handles `String` matches. The approach is similar to scala/scala#8451 (see scala/bug#11740 too), except that instead of doing a transformation on the AST, we just emit the right bytecode straight away. This is desirable since it means that Scala.js (and any other backend) can choose their own optimised strategy for compiling a match on strings. --- .../tools/backend/jvm/BCodeBodyBuilder.scala | 202 ++++++++++++++---- .../tools/dotc/transform/PatternMatcher.scala | 41 ++-- tests/run/string-switch-defaults-null.check | 2 + tests/run/string-switch-defaults-null.scala | 16 ++ tests/run/string-switch.check | 29 +++ tests/run/string-switch.scala | 69 ++++++ 6 files changed, 295 insertions(+), 64 deletions(-) create mode 100644 tests/run/string-switch-defaults-null.check create mode 100644 tests/run/string-switch-defaults-null.scala create mode 100644 tests/run/string-switch.check create mode 100644 tests/run/string-switch.scala diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala index 156551519cb9..bd4bdd0ad0eb 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala @@ -3,6 +3,7 @@ package backend package jvm import scala.annotation.switch +import scala.collection.mutable.SortedMap import scala.tools.asm import scala.tools.asm.{Handle, Label, Opcodes} @@ -826,61 +827,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { generatedType } - /* - * A Match node contains one or more case clauses, - * each case clause lists one or more Int values to use as keys, and a code block. - * Except the "default" case clause which (if it exists) doesn't list any Int key. - * - * On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels). - * That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch. - * - * On a second pass, we emit the switch blocks, one for each different target. + /* A Match node contains one or more case clauses, each case clause lists one or more + * Int/String values to use as keys, and a code block. The exception is the "default" case + * clause which doesn't list any key (there is exactly one of these per match). */ private def genMatch(tree: Match): BType = tree match { case Match(selector, cases) => lineNumber(tree) - genLoad(selector, INT) val generatedType = tpeTK(tree) + val postMatch = new asm.Label - var flatKeys: List[Int] = Nil - var targets: List[asm.Label] = Nil - var default: asm.Label = null - var switchBlocks: List[(asm.Label, Tree)] = Nil - - // collect switch blocks and their keys, but don't emit yet any switch-block. - for (caze @ CaseDef(pat, guard, body) <- cases) { - assert(guard == tpd.EmptyTree, guard) - val switchBlockPoint = new asm.Label - switchBlocks ::= (switchBlockPoint, body) - pat match { - case Literal(value) => - flatKeys ::= value.intValue - targets ::= switchBlockPoint - case Ident(nme.WILDCARD) => - assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") - default = switchBlockPoint - case Alternative(alts) => - alts foreach { - case Literal(value) => - flatKeys ::= value.intValue - targets ::= switchBlockPoint - case _ => - abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") - } - case _ => - abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + // Only two possible selector types exist in `Match` trees at this point: Int and String + if (tpeTK(selector) == INT) { + + /* On a first pass over the case clauses, we flatten the keys and their + * targets (the latter represented with asm.Labels). That representation + * allows JCodeMethodV to emit a lookupswitch or a tableswitch. + * + * On a second pass, we emit the switch blocks, one for each different target. + */ + + var flatKeys: List[Int] = Nil + var targets: List[asm.Label] = Nil + var default: asm.Label = null + var switchBlocks: List[(asm.Label, Tree)] = Nil + + genLoad(selector, INT) + + // collect switch blocks and their keys, but don't emit yet any switch-block. + for (caze @ CaseDef(pat, guard, body) <- cases) { + assert(guard == tpd.EmptyTree, guard) + val switchBlockPoint = new asm.Label + switchBlocks ::= (switchBlockPoint, body) + pat match { + case Literal(value) => + flatKeys ::= value.intValue + targets ::= switchBlockPoint + case Ident(nme.WILDCARD) => + assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") + default = switchBlockPoint + case Alternative(alts) => + alts foreach { + case Literal(value) => + flatKeys ::= value.intValue + targets ::= switchBlockPoint + case _ => + abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") + } + case _ => + abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + } } - } - bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) + bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) - // emit switch-blocks. - val postMatch = new asm.Label - for (sb <- switchBlocks.reverse) { - val (caseLabel, caseBody) = sb - markProgramPoint(caseLabel) - genLoad(caseBody, generatedType) - bc goTo postMatch + // emit switch-blocks. + for (sb <- switchBlocks.reverse) { + val (caseLabel, caseBody) = sb + markProgramPoint(caseLabel) + genLoad(caseBody, generatedType) + bc goTo postMatch + } + } else { + + /* Since the JVM doesn't have a way to switch on a string, we switch + * on the `hashCode` of the string then do an `equals` check (with a + * possible second set of jumps if blocks can be reach from multiple + * string alternatives). + * + * This mirrors the way that Java compiles `switch` on Strings. + */ + + var default: asm.Label = null + var indirectBlocks: List[(asm.Label, Tree)] = Nil + + import scala.collection.mutable + + // Cases grouped by their hashCode + val casesByHash = SortedMap.empty[Int, List[(String, Either[asm.Label, Tree])]] + var caseFallback: Tree = null + + for (caze @ CaseDef(pat, guard, body) <- cases) { + assert(guard == tpd.EmptyTree, guard) + pat match { + case Literal(value) => + val strValue = value.stringValue + casesByHash.updateWith(strValue.##) { existingCasesOpt => + val newCase = (strValue, Right(body)) + Some(newCase :: existingCasesOpt.getOrElse(Nil)) + } + case Ident(nme.WILDCARD) => + assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") + default = new asm.Label + indirectBlocks ::= (default, body) + case Alternative(alts) => + // We need an extra basic block since multiple strings can lead to this code + val indirectCaseGroupLabel = new asm.Label + indirectBlocks ::= (indirectCaseGroupLabel, body) + alts foreach { + case Literal(value) => + val strValue = value.stringValue + casesByHash.updateWith(strValue.##) { existingCasesOpt => + val newCase = (strValue, Left(indirectCaseGroupLabel)) + Some(newCase :: existingCasesOpt.getOrElse(Nil)) + } + case _ => + abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") + } + + case _ => + abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + } + } + + // Organize the hashCode options into switch cases + var flatKeys: List[Int] = Nil + var targets: List[asm.Label] = Nil + var hashBlocks: List[(asm.Label, List[(String, Either[asm.Label, Tree])])] = Nil + for ((hashValue, hashCases) <- casesByHash) { + val switchBlockPoint = new asm.Label + hashBlocks ::= (switchBlockPoint, hashCases) + flatKeys ::= hashValue + targets ::= switchBlockPoint + } + + // Push the hashCode of the string (or `0` it is `null`) onto the stack and switch on it + genLoadIf( + If( + tree.selector.select(defn.Any_==).appliedTo(nullLiteral), + Literal(Constant(0)), + tree.selector.select(defn.Any_hashCode).appliedToNone + ), + INT + ) + bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) + + // emit blocks for each hash case + for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) { + markProgramPoint(hashLabel) + for ((caseString, indirectLblOrBody) <- caseAlternatives) { + val comparison = if (caseString == null) defn.Any_== else defn.Any_equals + val condp = Literal(Constant(caseString)).select(defn.Any_==).appliedTo(tree.selector) + val keepGoing = new asm.Label + indirectLblOrBody match { + case Left(jump) => + genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing) + + case Right(caseBody) => + val thisCaseMatches = new asm.Label + genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches) + markProgramPoint(thisCaseMatches) + genLoad(caseBody, generatedType) + bc goTo postMatch + } + markProgramPoint(keepGoing) + } + bc goTo default + } + + // emit blocks for common patterns + for ((caseLabel, caseBody) <- indirectBlocks.reverse) { + markProgramPoint(caseLabel) + genLoad(caseBody, generatedType) + bc goTo postMatch + } } markProgramPoint(postMatch) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 047b2587299a..0463bfbb22e2 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -20,7 +20,7 @@ import util.Property._ /** The pattern matching transform. * After this phase, the only Match nodes remaining in the code are simple switches - * where every pattern is an integer constant + * where every pattern is an integer or string constant */ class PatternMatcher extends MiniPhase { import ast.tpd._ @@ -768,13 +768,15 @@ object PatternMatcher { (tpe isRef defn.IntClass) || (tpe isRef defn.ByteClass) || (tpe isRef defn.ShortClass) || - (tpe isRef defn.CharClass) + (tpe isRef defn.CharClass) || + (tpe isRef defn.StringClass) - val seen = mutable.Set[Int]() + val seen = mutable.Set[Any]() - def isNewIntConst(tree: Tree) = tree match { - case Literal(const) if const.isIntRange && !seen.contains(const.intValue) => - seen += const.intValue + def isNewSwitchableConst(tree: Tree) = tree match { + case Literal(const) + if (const.isIntRange || const.tag == Constants.StringTag) && !seen.contains(const.value) => + seen += const.value true case _ => false @@ -789,7 +791,7 @@ object PatternMatcher { val alts = List.newBuilder[Tree] def rec(innerPlan: Plan): Boolean = innerPlan match { case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail) - if scrut === scrutinee && isNewIntConst(tree) => + if scrut === scrutinee && isNewSwitchableConst(tree) => alts += tree rec(tail) case ReturnPlan(`outerLabel`) => @@ -809,7 +811,7 @@ object PatternMatcher { def recur(plan: Plan): List[(List[Tree], Plan)] = plan match { case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail) - if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) => + if scrut === scrutinee && !canFallThrough(ons) && isNewSwitchableConst(tree) => (tree :: Nil, ons) :: recur(tail) case SeqPlan(AlternativesPlan(alts, ons), tail) => (alts, ons) :: recur(tail) @@ -832,29 +834,32 @@ object PatternMatcher { /** Emit a switch-match */ private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = { - /* Make sure to adapt the scrutinee to Int, as well as all the alternatives - * of all cases, so that only Matches on pritimive Ints survive this phase. + /* Make sure to adapt the scrutinee to Int or String, as well as all the + * alternatives, so that only Matches on pritimive Ints or Strings survive + * this phase. */ - val intScrutinee = - if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee - else scrutinee.select(nme.toInt) + val (primScrutinee, scrutineeTpe) = + if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType) + else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType) + else (scrutinee.select(nme.toInt), defn.IntType) - def intLiteral(lit: Tree): Tree = + def primLiteral(lit: Tree): Tree = val Literal(constant) = lit if (constant.tag == Constants.IntTag) lit + else if (constant.tag == Constants.StringTag) lit else cpy.Literal(lit)(Constant(constant.intValue)) val caseDefs = cases.map { (alts, ons) => val pat = alts match { - case alt :: Nil => intLiteral(alt) - case Nil => Underscore(defn.IntType) // default case - case _ => Alternative(alts.map(intLiteral)) + case alt :: Nil => primLiteral(alt) + case Nil => Underscore(scrutineeTpe) // default case + case _ => Alternative(alts.map(primLiteral)) } CaseDef(pat, EmptyTree, emit(ons)) } - Match(intScrutinee, caseDefs) + Match(primScrutinee, caseDefs) } /** If selfCheck is `true`, used to check whether a tree gets generated twice */ diff --git a/tests/run/string-switch-defaults-null.check b/tests/run/string-switch-defaults-null.check new file mode 100644 index 000000000000..4bbcfcf56827 --- /dev/null +++ b/tests/run/string-switch-defaults-null.check @@ -0,0 +1,2 @@ +2 +-1 diff --git a/tests/run/string-switch-defaults-null.scala b/tests/run/string-switch-defaults-null.scala new file mode 100644 index 000000000000..9fc4ce235a2d --- /dev/null +++ b/tests/run/string-switch-defaults-null.scala @@ -0,0 +1,16 @@ +import annotation.switch + +object Test { + def test(s: String): Int = { + (s : @switch) match { + case "1" => 0 + case null => -1 + case _ => s.toInt + } + } + + def main(args: Array[String]): Unit = { + println(test("2")) + println(test(null)) + } +} diff --git a/tests/run/string-switch.check b/tests/run/string-switch.check new file mode 100644 index 000000000000..7ab6b33ec0ae --- /dev/null +++ b/tests/run/string-switch.check @@ -0,0 +1,29 @@ +fido Success(dog) +garfield Success(cat) +wanda Success(fish) +henry Success(horse) +felix Failure(scala.MatchError: felix (of class java.lang.String)) +deuteronomy Success(cat) +===== +AaAa 2031744 Success(1) +BBBB 2031744 Success(2) +BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String)) +cCCc 3015872 Success(3) +ddDd 3077408 Success(4) +EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String)) +===== +A Success(()) +X Failure(scala.MatchError: X (of class java.lang.String)) +===== + Success(3) +null Success(2) +7 Failure(scala.MatchError: 7 (of class java.lang.String)) +===== +pig Success(1) +dog Success(2) +===== +Ea 2236 Success(1) +FB 2236 Success(2) +cC 3136 Success(3) +xx 3840 Success(4) +null 0 Success(4) diff --git a/tests/run/string-switch.scala b/tests/run/string-switch.scala new file mode 100644 index 000000000000..6a1522b416d9 --- /dev/null +++ b/tests/run/string-switch.scala @@ -0,0 +1,69 @@ +// scalac: -Werror +import annotation.switch +import util.Try + +object Test extends App { + + def species(name: String) = (name.toLowerCase : @switch) match { + case "fido" => "dog" + case "garfield" | "deuteronomy" => "cat" + case "wanda" => "fish" + case "henry" => "horse" + } + List("fido", "garfield", "wanda", "henry", "felix", "deuteronomy").foreach { n => println(s"$n ${Try(species(n))}") } + + println("=====") + + def collide(in: String) = (in : @switch) match { + case "AaAa" => 1 + case "BBBB" => 2 + case "cCCc" => 3 + case x if x == "ddDd" => 4 + } + List("AaAa", "BBBB", "BBAa", "cCCc", "ddDd", "EEee").foreach { s => + println(s"$s ${s.##} ${Try(collide(s))}") + } + + println("=====") + + def unitary(in: String) = (in : @switch) match { + case "A" => + case x => throw new MatchError(x) + } + List("A","X").foreach { s => + println(s"$s ${Try(unitary(s))}") + } + + println("=====") + + def nullFun(in: String) = (in : @switch) match { + case "1" => 1 + case null => 2 + case "" => 3 + } + List("", null, "7").foreach { s => + println(s"$s ${Try(nullFun(s))}") + } + + println("=====") + + def default(in: String) = (in : @switch) match { + case "pig" => 1 + case _ => 2 + } + List("pig","dog").foreach { s => + println(s"$s ${Try(default(s))}") + } + + println("=====") + + def onceOnly(in: Iterator[String]) = (in.next() : @switch) match { + case "Ea" => 1 + case "FB" => 2 //collision with above + case "cC" => 3 + case _ => 4 + } + List("Ea", "FB", "cC", "xx", null).foreach { s => + println(s"$s ${s.##} ${Try(onceOnly(Iterator(s)))}") + } +}