Skip to content

Commit

Permalink
Fix NPE when returning double-brace-initialised things.
Browse files Browse the repository at this point in the history
The previous code assumes that unmodifiableFoo(new Foo<> {{ }}) must be assigned to a variable, rather than returned by a function (or, who knows, it could just be a statement).

Fixes #1040

RELNOTES: Fix NPE in DoubleBraceInitialization.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=211413826
  • Loading branch information
graememorgan authored and ronshapiro committed Sep 4, 2018
1 parent 2783e13 commit ca3356d
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 21 deletions.
Expand Up @@ -45,10 +45,13 @@
import com.sun.source.tree.MethodTree; import com.sun.source.tree.MethodTree;
import com.sun.source.tree.NewClassTree; import com.sun.source.tree.NewClassTree;
import com.sun.source.tree.ParameterizedTypeTree; import com.sun.source.tree.ParameterizedTypeTree;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.StatementTree; import com.sun.source.tree.StatementTree;
import com.sun.source.tree.Tree; import com.sun.source.tree.Tree;
import com.sun.source.tree.Tree.Kind; import com.sun.source.tree.Tree.Kind;
import com.sun.source.tree.VariableTree; import com.sun.source.tree.VariableTree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol.VarSymbol; import com.sun.tools.javac.code.Symbol.VarSymbol;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
Expand Down Expand Up @@ -128,22 +131,36 @@ Optional<Fix> maybeFix(NewClassTree tree, VisitorState state, BlockTree block) {
// check the enclosing context: calls to Collections.unmodifiable* are now redundant, and // check the enclosing context: calls to Collections.unmodifiable* are now redundant, and
// if there's an enclosing constant variable declaration we can rewrite its type to Immutable* // if there's an enclosing constant variable declaration we can rewrite its type to Immutable*
Tree unmodifiable = null; Tree unmodifiable = null;
VariableTree enclosingVariable = null;
boolean constant = false; boolean constant = false;
Tree typeTree = null;
Tree toReplace = null;


for (Tree enclosing : state.getPath().getParentPath()) { for (TreePath path = state.getPath().getParentPath();
path != null;
path = path.getParentPath()) {
Tree enclosing = path.getLeaf();
if (unmodifiableMatcher.matches(enclosing, state)) { if (unmodifiableMatcher.matches(enclosing, state)) {
unmodifiable = enclosing; unmodifiable = enclosing;
continue; continue;
} }
if (enclosing instanceof ParenthesizedTree) {
continue;
}
if (enclosing instanceof VariableTree) { if (enclosing instanceof VariableTree) {
enclosingVariable = (VariableTree) enclosing; VariableTree enclosingVariable = (VariableTree) enclosing;
toReplace = enclosingVariable.getInitializer();
typeTree = enclosingVariable.getType();
VarSymbol symbol = ASTHelpers.getSymbol(enclosingVariable); VarSymbol symbol = ASTHelpers.getSymbol(enclosingVariable);
constant = constant =
symbol.isStatic() symbol.isStatic()
&& symbol.getModifiers().contains(Modifier.FINAL) && symbol.getModifiers().contains(Modifier.FINAL)
&& symbol.getKind() == ElementKind.FIELD; && symbol.getKind() == ElementKind.FIELD;
} }
if (enclosing instanceof ReturnTree) {
toReplace = ((ReturnTree) enclosing).getExpression();
MethodTree enclosingMethod = ASTHelpers.findEnclosingNode(path, MethodTree.class);
typeTree = enclosingMethod == null ? null : enclosingMethod.getReturnType();
}
break; break;
} }
SuggestedFix.Builder fix = SuggestedFix.builder(); SuggestedFix.Builder fix = SuggestedFix.builder();
Expand Down Expand Up @@ -172,14 +189,13 @@ Optional<Fix> maybeFix(NewClassTree tree, VisitorState state, BlockTree block) {
if (unmodifiable != null || constant) { if (unmodifiable != null || constant) {
// there's an enclosing unmodifiable* call, or we're in the initializer of a constant, // there's an enclosing unmodifiable* call, or we're in the initializer of a constant,
// so rewrite the variable's type to be immutable and drop the unmodifiable* method // so rewrite the variable's type to be immutable and drop the unmodifiable* method
Tree typeType = enclosingVariable.getType(); if (typeTree instanceof ParameterizedTypeTree) {
if (typeType instanceof ParameterizedTypeTree) { typeTree = ((ParameterizedTypeTree) typeTree).getType();
typeType = ((ParameterizedTypeTree) typeType).getType(); }
if (typeTree != null) {
fix.replace(typeTree, immutableType);
} }
fix.replace(typeType, immutableType) fix.replace(unmodifiable == null ? toReplace : unmodifiable, replacement);
.replace(
unmodifiable != null ? unmodifiable : enclosingVariable.getInitializer(),
replacement);
} else { } else {
// the result may need to be mutable, so rewrite e.g. // the result may need to be mutable, so rewrite e.g.
// `new ArrayList<>() {{ add(1); }}` -> `new ArrayList<>(ImmutableList.of(1));` // `new ArrayList<>() {{ add(1); }}` -> `new ArrayList<>(ImmutableList.of(1));`
Expand Down
Expand Up @@ -32,7 +32,7 @@ public class DoubleBraceInitializationTest {
public void negative() { public void negative() {
CompilationTestHelper.newInstance(DoubleBraceInitialization.class, getClass()) CompilationTestHelper.newInstance(DoubleBraceInitialization.class, getClass())
.addSourceLines( .addSourceLines(
"Test.java", // "Test.java",
"import java.util.ArrayList;", "import java.util.ArrayList;",
"import java.util.List;", "import java.util.List;",
"class Test {", "class Test {",
Expand Down Expand Up @@ -61,7 +61,7 @@ public void negative() {
public void positiveNoFix() { public void positiveNoFix() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.ArrayList;", "import java.util.ArrayList;",
"import java.util.List;", "import java.util.List;",
"// BUG: Diagnostic contains:", "// BUG: Diagnostic contains:",
Expand All @@ -78,7 +78,7 @@ public void positiveNoFix() {
public void list() { public void list() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.ArrayList;", "import java.util.ArrayList;",
"import java.util.Collections;", "import java.util.Collections;",
"import java.util.List;", "import java.util.List;",
Expand All @@ -89,7 +89,7 @@ public void list() {
" List<Integer> c = new ArrayList<Integer>() {{ add(1); add(2); }};", " List<Integer> c = new ArrayList<Integer>() {{ add(1); add(2); }};",
"}") "}")
.addOutputLines( .addOutputLines(
"out/Test.java", // "out/Test.java",
"import com.google.common.collect.ImmutableList;", "import com.google.common.collect.ImmutableList;",
"import java.util.ArrayList;", "import java.util.ArrayList;",
"import java.util.Collections;", "import java.util.Collections;",
Expand All @@ -106,7 +106,7 @@ public void list() {
public void set() { public void set() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.Collections;", "import java.util.Collections;",
"import java.util.HashSet;", "import java.util.HashSet;",
"import java.util.Set;", "import java.util.Set;",
Expand All @@ -117,7 +117,7 @@ public void set() {
" Set<Integer> c = new HashSet<Integer>() {{ add(1); add(2); }};", " Set<Integer> c = new HashSet<Integer>() {{ add(1); add(2); }};",
"}") "}")
.addOutputLines( .addOutputLines(
"out/Test.java", // "out/Test.java",
"import com.google.common.collect.ImmutableSet;", "import com.google.common.collect.ImmutableSet;",
"import java.util.Collections;", "import java.util.Collections;",
"import java.util.HashSet;", "import java.util.HashSet;",
Expand All @@ -134,7 +134,7 @@ public void set() {
public void collection() { public void collection() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.ArrayDeque;", "import java.util.ArrayDeque;",
"import java.util.Collection;", "import java.util.Collection;",
"import java.util.Collections;", "import java.util.Collections;",
Expand All @@ -147,7 +147,7 @@ public void collection() {
" Deque<Integer> c = new ArrayDeque<Integer>() {{ add(1); add(2); }};", " Deque<Integer> c = new ArrayDeque<Integer>() {{ add(1); add(2); }};",
"}") "}")
.addOutputLines( .addOutputLines(
"out/Test.java", // "out/Test.java",
"import com.google.common.collect.ImmutableCollection;", "import com.google.common.collect.ImmutableCollection;",
"import com.google.common.collect.ImmutableList;", "import com.google.common.collect.ImmutableList;",
"import java.util.ArrayDeque;", "import java.util.ArrayDeque;",
Expand All @@ -166,7 +166,7 @@ public void collection() {
public void map() { public void map() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.Collections;", "import java.util.Collections;",
"import java.util.HashMap;", "import java.util.HashMap;",
"import java.util.Map;", "import java.util.Map;",
Expand All @@ -191,7 +191,7 @@ public void map() {
" }};", " }};",
"}") "}")
.addOutputLines( .addOutputLines(
"out/Test.java", // "out/Test.java",
"import com.google.common.collect.ImmutableMap;", "import com.google.common.collect.ImmutableMap;",
"import java.util.Collections;", "import java.util.Collections;",
"import java.util.HashMap;", "import java.util.HashMap;",
Expand Down Expand Up @@ -220,7 +220,7 @@ public void map() {
public void nulls() { public void nulls() {
testHelper testHelper
.addInputLines( .addInputLines(
"in/Test.java", // "in/Test.java",
"import java.util.*;", "import java.util.*;",
"// BUG: Diagnostic contains:", "// BUG: Diagnostic contains:",
"class Test {", "class Test {",
Expand All @@ -232,4 +232,93 @@ public void nulls() {
.expectUnchanged() .expectUnchanged()
.doTest(); .doTest();
} }

@Test
public void returned() {
testHelper
.addInputLines(
"Test.java",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"class Test {",
" private Map<String, Object> test() {",
" return Collections.unmodifiableMap(new HashMap<String, Object>() {",
" {}",
" });",
" }",
"}")
.addOutputLines(
"Test.java",
"import com.google.common.collect.ImmutableMap;",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"class Test {",
" private ImmutableMap<String, Object> test() {",
" return ImmutableMap.of();",
" }",
"}")
.doTest();
}

@Test
public void lambda() {
testHelper
.addInputLines(
"Test.java",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"import java.util.function.Supplier;",
"class Test {",
" private Supplier<Map<String, Object>> test() {",
" return () -> Collections.unmodifiableMap(new HashMap<String, Object>() {",
" {}",
" });",
" }",
"}")
.addOutputLines(
"Test.java",
"import com.google.common.collect.ImmutableMap;",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"import java.util.function.Supplier;",
"class Test {",
" private Supplier<Map<String, Object>> test() {",
" return () -> ImmutableMap.of();",
" }",
"}")
.doTest();
}

@Test
public void statement() {
testHelper
.addInputLines(
"Test.java",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"class Test {",
" private void test() {",
" Collections.unmodifiableMap(new HashMap<String, Object>() {",
" {}",
" });",
" }",
"}")
.addOutputLines(
"Test.java",
"import com.google.common.collect.ImmutableMap;",
"import java.util.Collections;",
"import java.util.HashMap;",
"import java.util.Map;",
"class Test {",
" private void test() {",
" ImmutableMap.of();",
" }",
"}")
.doTest();
}
} }

0 comments on commit ca3356d

Please sign in to comment.