From 493944857cd9f4af7e0846481f7ebcc027c17663 Mon Sep 17 00:00:00 2001 From: Kurt Alfred Kluever Date: Fri, 5 Jun 2026 11:16:01 -0700 Subject: [PATCH] Only hoist the last statement into the `assertThrows()` lambda. PiperOrigin-RevId: 927391038 --- .../bugpatterns/AssertThrowsUtils.java | 27 ++++++++----------- .../bugpatterns/MissingFailTest.java | 9 ++----- .../bugpatterns/TryFailRefactoringTest.java | 24 +++++++---------- 3 files changed, 22 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsUtils.java b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsUtils.java index 869d26c2b06..2fd1a524797 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsUtils.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/AssertThrowsUtils.java @@ -18,7 +18,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getLast; -import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.errorprone.fixes.SuggestedFixes.renameVariableUsages; import static com.google.errorprone.util.ASTHelpers.getStartPosition; import static java.util.stream.Collectors.joining; @@ -112,6 +111,10 @@ public static Optional tryFailToAssertThrows( resources.stream().map(state::getSourceForNode).collect(joining("\n", "try (", ") {\n"))); fixSuffix = "\n}"; } + // Hoist all but the last statement out of the lambda to narrow the exception scope. + for (StatementTree statement : throwingStatements.subList(0, throwingStatements.size() - 1)) { + fixPrefix.append(state.getSourceForNode(statement)).append("\n"); + } if (!catchStatements.isEmpty()) { String name = catchTree.getParameter().getName().toString(); String newName = namer.avoidShadowing(name); @@ -133,30 +136,22 @@ public static Optional tryFailToAssertThrows( .map(t -> state.getSourceForNode(t) + ", ") .orElse(""), state.getSourceForNode(catchTree.getParameter().getType()))); - boolean useExpressionLambda = - throwingStatements.size() == 1 - && getOnlyElement(throwingStatements) instanceof ExpressionStatementTree; + StatementTree lastStatement = getLast(throwingStatements); + boolean useExpressionLambda = lastStatement instanceof ExpressionStatementTree; if (!useExpressionLambda) { fixPrefix.append("{"); } - fix.replace( - getStartPosition(tryTree), - getStartPosition(throwingStatements.iterator().next()), - fixPrefix.toString()); + fix.replace(getStartPosition(tryTree), getStartPosition(lastStatement), fixPrefix.toString()); if (useExpressionLambda) { - fix.postfixWith( - ((ExpressionStatementTree) throwingStatements.iterator().next()).getExpression(), ")"); + fix.postfixWith(((ExpressionStatementTree) lastStatement).getExpression(), ")"); } else { - fix.postfixWith(getLast(throwingStatements), "});"); + fix.postfixWith(lastStatement, "});"); } if (catchStatements.isEmpty()) { - fix.replace( - state.getEndPosition(getLast(throwingStatements)), - state.getEndPosition(tryTree), - fixSuffix); + fix.replace(state.getEndPosition(lastStatement), state.getEndPosition(tryTree), fixSuffix); } else { fix.replace( - /* startPos= */ state.getEndPosition(getLast(throwingStatements)), + /* startPos= */ state.getEndPosition(lastStatement), /* endPos= */ getStartPosition(catchStatements.getFirst()), "\n"); fix.replace( diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/MissingFailTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/MissingFailTest.java index 7d14fd8fdfa..17b4c37a779 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/MissingFailTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/MissingFailTest.java @@ -948,13 +948,8 @@ class ExceptionTest { @Test public void f() throws Exception { Path p = Paths.get("NOSUCH"); - IOException e = - assertThrows( - IOException.class, - () -> { - Files.readAllBytes(p); - Files.readAllBytes(p); - }); + Files.readAllBytes(p); + IOException e = assertThrows(IOException.class, () -> Files.readAllBytes(p)); assertThat(e).hasMessageThat().contains("NOSUCH"); } diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/TryFailRefactoringTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/TryFailRefactoringTest.java index 7b69fcc94fb..fc386285c28 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/TryFailRefactoringTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/TryFailRefactoringTest.java @@ -45,12 +45,12 @@ public void catchBlock() { class ExceptionTest { @Test - public void f(String message) throws Exception { + public void f(String msg) throws Exception { Path p = Paths.get("NOSUCH"); try { Files.readAllBytes(p); Files.readAllBytes(p); - fail(message); + fail(msg); } catch (IOException e) { assertThat(e).hasMessageThat().contains("NOSUCH"); } @@ -81,16 +81,10 @@ public void g() throws Exception { class ExceptionTest { @Test - public void f(String message) throws Exception { + public void f(String msg) throws Exception { Path p = Paths.get("NOSUCH"); - IOException e = - assertThrows( - message, - IOException.class, - () -> { - Files.readAllBytes(p); - Files.readAllBytes(p); - }); + Files.readAllBytes(p); + IOException e = assertThrows(msg, IOException.class, () -> Files.readAllBytes(p)); assertThat(e).hasMessageThat().contains("NOSUCH"); } @@ -169,11 +163,11 @@ public void tryWithResources() { class ExceptionTest { @Test - public void f(String message, CharSource cs) throws IOException { + public void f(String msg, CharSource cs) throws IOException { try (BufferedReader buf = cs.openBufferedStream(); PushbackReader pbr = new PushbackReader(buf)) { pbr.read(); - fail(message); + fail(msg); } catch (IOException e) { assertThat(e).hasMessageThat().contains("NOSUCH"); } @@ -195,10 +189,10 @@ public void f(String message, CharSource cs) throws IOException { class ExceptionTest { @Test - public void f(String message, CharSource cs) throws IOException { + public void f(String msg, CharSource cs) throws IOException { try (BufferedReader buf = cs.openBufferedStream(); PushbackReader pbr = new PushbackReader(buf)) { - IOException e = assertThrows(message, IOException.class, () -> pbr.read()); + IOException e = assertThrows(msg, IOException.class, () -> pbr.read()); assertThat(e).hasMessageThat().contains("NOSUCH"); } }