Skip to content

Commit

Permalink
Add actionable item to PatternMatchExhaustivity diagnostic (scala#18314)
Browse files Browse the repository at this point in the history
The purpose of this PR is to add an actionable item for non-exhaustive
pattern match diagnostic, so that people can auto insert missing cases.

Relates to
scalameta/metals-feature-requests#350
  • Loading branch information
dwijnand committed Aug 9, 2023
2 parents 5fc691a + 1a592d3 commit c5adafc
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 11 deletions.
35 changes: 33 additions & 2 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,13 @@ extends Message(LossyWideningConstantConversionID):
|Write `.to$targetType` instead."""
def explain(using Context) = ""

class PatternMatchExhaustivity(uncoveredFn: => String, hasMore: Boolean)(using Context)
class PatternMatchExhaustivity(uncoveredCases: Seq[String], tree: untpd.Match)(using Context)
extends Message(PatternMatchExhaustivityID) {
def kind = MessageKind.PatternMatchExhaustivity
lazy val uncovered = uncoveredFn

private val hasMore = uncoveredCases.lengthCompare(6) > 0
val uncovered = uncoveredCases.take(6).mkString(", ")

def msg(using Context) =
val addendum = if hasMore then "(More unmatched cases are elided)" else ""
i"""|${hl("match")} may not be exhaustive.
Expand All @@ -862,6 +865,34 @@ extends Message(PatternMatchExhaustivityID) {
| - If an extractor always return ${hl("Some(...)")}, write ${hl("Some[X]")} for its return type
| - Add a ${hl("case _ => ...")} at the end to match all remaining cases
|"""

override def actions(using Context) =
import scala.language.unsafeNulls
val endPos = tree.cases.lastOption.map(_.endPos)
.getOrElse(tree.selector.endPos)
val startColumn = tree.cases.lastOption
.map(_.startPos.startColumn)
.getOrElse(tree.selector.startPos.startColumn + 2)

val pathes = List(
ActionPatch(
srcPos = endPos,
replacement = uncoveredCases.map(c => indent(s"case $c => ???", startColumn))
.mkString("\n", "\n", "")
),
)
List(
CodeAction(title = s"Insert missing cases (${uncoveredCases.size})",
description = None,
patches = pathes
)
)


private def indent(text:String, margin: Int): String = {
import scala.language.unsafeNulls
" " * margin + text
}
}

class UncheckedTypePattern(msgFn: => String)(using Context)
Expand Down
9 changes: 4 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ object SpaceEngine {
checkConstraint(genConstraint(sp))(using ctx.fresh.setNewTyperState())
}

def showSpaces(ss: Seq[Space])(using Context): String = ss.map(show).mkString(", ")
def showSpaces(ss: Seq[Space])(using Context): Seq[String] = ss.map(show)

/** Display spaces */
def show(s: Space)(using Context): String = {
Expand All @@ -784,7 +784,7 @@ object SpaceEngine {

def doShow(s: Space, flattenList: Boolean = false): String = s match {
case Empty => "empty"
case Typ(c: ConstantType, _) => "" + c.value.value
case Typ(c: ConstantType, _) => c.value.show
case Typ(tp: TermRef, _) =>
if (flattenList && tp <:< defn.NilType) ""
else tp.symbol.showName
Expand Down Expand Up @@ -894,9 +894,8 @@ object SpaceEngine {


if uncovered.nonEmpty then
val hasMore = uncovered.lengthCompare(6) > 0
val deduped = dedup(uncovered.take(6))
report.warning(PatternMatchExhaustivity(showSpaces(deduped), hasMore), m.selector)
val deduped = dedup(uncovered)
report.warning(PatternMatchExhaustivity(showSpaces(deduped), m), m.selector)
}

private def redundancyCheckable(sel: Tree)(using Context): Boolean =
Expand Down
89 changes: 85 additions & 4 deletions compiler/test/dotty/tools/dotc/reporting/CodeActionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,88 @@ class CodeActionTest extends DottyTest:
// TODO look into trying to remove the extra space that is left behind
"""|final class Test
|""".stripMargin
)

@Test def insertMissingCases =
checkCodeAction(
code =
"""|enum Tree:
| case Node(l: Tree, r: Tree)
| case Leaf(v: String)
|
|object Test:
| def foo(tree: Tree) = tree match {
| case Tree.Node(_, _) => ???
| }
|""".stripMargin,
title = "Insert missing cases (1)",
expected =
"""|enum Tree:
| case Node(l: Tree, r: Tree)
| case Leaf(v: String)
|
|object Test:
| def foo(tree: Tree) = tree match {
| case Tree.Node(_, _) => ???
| case Tree.Leaf(_) => ???
| }
|""".stripMargin,
afterPhase = "patternMatcher"
)

@Test def insertMissingCasesForUnionStringType =
checkCodeAction(
code =
"""object Test:
| def foo(text: "Alice" | "Bob") = text match {
| case "Alice" => ???
| }
|""".stripMargin,
title = "Insert missing cases (1)",
expected =
"""object Test:
| def foo(text: "Alice" | "Bob") = text match {
| case "Alice" => ???
| case "Bob" => ???
| }
|""".stripMargin,
afterPhase = "patternMatcher"
)

@Test def insertMissingCasesForUnionIntType =
checkCodeAction(
code =
"""object Test:
| def foo(text: 1 | 2) = text match {
| case 2 => ???
| }
|""".stripMargin,
title = "Insert missing cases (1)",
expected =
"""object Test:
| def foo(text: 1 | 2) = text match {
| case 2 => ???
| case 1 => ???
| }
|""".stripMargin,
afterPhase = "patternMatcher"
)

@Test def insertMissingCasesUsingBracelessSyntax =
checkCodeAction(
code =
"""object Test:
| def foo(text: 1 | 2) = text match
| case 2 => ???
|""".stripMargin,
title = "Insert missing cases (1)",
expected =
"""object Test:
| def foo(text: 1 | 2) = text match
| case 2 => ???
| case 1 => ???
|""".stripMargin,
afterPhase = "patternMatcher"
)

// Make sure we're not using the default reporter, which is the ConsoleReporter,
Expand All @@ -61,16 +142,16 @@ class CodeActionTest extends DottyTest:
val rep = new StoreReporter(null) with UniqueMessagePositions with HideNonSensicalMessages
initialCtx.setReporter(rep).withoutColors

private def checkCodeAction(code: String, title: String, expected: String) =
private def checkCodeAction(code: String, title: String, expected: String, afterPhase: String = "typer") =
ctx = newContext
val source = SourceFile.virtual("test", code).content
val runCtx = checkCompile("typer", code) { (_, _) => () }
val runCtx = checkCompile(afterPhase, code) { (_, _) => () }
val diagnostics = runCtx.reporter.removeBufferedMessages
assertEquals(1, diagnostics.size)
assertEquals("Expected exactly one diagnostic", 1, diagnostics.size)

val diagnostic = diagnostics.head
val actions = diagnostic.msg.actions.toList
assertEquals(1, actions.size)
assertEquals("Expected exactly one action", 1, actions.size)

// TODO account for more than 1 action
val action = actions.head
Expand Down

0 comments on commit c5adafc

Please sign in to comment.