Skip to content

Commit

Permalink
Implement one findFirst method
Browse files Browse the repository at this point in the history
  • Loading branch information
matozoid committed Oct 22, 2017
1 parent aa62a13 commit a863052
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 26 deletions.
74 changes: 63 additions & 11 deletions javaparser-core/src/main/java/com/github/javaparser/ast/Node.java
Expand Up @@ -31,7 +31,10 @@
import com.github.javaparser.ast.observer.AstObserver;
import com.github.javaparser.ast.observer.ObservableProperty;
import com.github.javaparser.ast.observer.PropagatingAstObserver;
import com.github.javaparser.ast.visitor.*;
import com.github.javaparser.ast.visitor.CloneVisitor;
import com.github.javaparser.ast.visitor.EqualsVisitor;
import com.github.javaparser.ast.visitor.HashCodeVisitor;
import com.github.javaparser.ast.visitor.Visitable;
import com.github.javaparser.metamodel.InternalProperty;
import com.github.javaparser.metamodel.JavaParserMetaModel;
import com.github.javaparser.metamodel.NodeMetaModel;
Expand All @@ -49,6 +52,7 @@
import java.util.stream.StreamSupport;

import static com.github.javaparser.ast.Node.Parsedness.PARSED;
import static com.github.javaparser.ast.Node.TreeTraversal.PREORDER;
import static java.util.Collections.unmodifiableList;
import static java.util.Spliterator.DISTINCT;
import static java.util.Spliterator.NONNULL;
Expand Down Expand Up @@ -709,31 +713,67 @@ private Iterator<Node> treeIterator(TreeTraversal traversal) {
private Iterable<Node> treeIterable(TreeTraversal traversal) {
return () -> treeIterator(traversal);
}

public Stream<Node> treeStream(TreeTraversal traversal) {

/**
* Make a stream of nodes using traversal algorithm "traversal".
*/
public Stream<Node> stream(TreeTraversal traversal) {
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(treeIterator(traversal), NONNULL | DISTINCT), false);
}

/**
* Make a stream of nodes using pre-order traversal.
*/
public Stream<Node> stream() {
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(treeIterator(PREORDER), NONNULL | DISTINCT), false);
}

/**
* Walks the AST, applying the function for every node, with traversal algorithm "traversal".
* If the function returns something else than null, the traversal is stopped and the function result is returned.
* <br/>This is the most general walk method. All other walk and find methods are based on this.
*/
public <T> Optional<T> walk(TreeTraversal traversal, Function<Node, T> consumer) {
for (Node node : treeIterable(traversal)) {
T result = consumer.apply(node);
if (result != null) {
return Optional.of(result);
}
}
return Optional.empty();
}


/**
* Walks the AST, calling the consumer for every node, with traversal algorithm "traversal".
*/
public void walk(TreeTraversal traversal, Consumer<Node> consumer) {
// Could be implemented as a call to the above walk method, but this is a little more efficient.
for (Node node : treeIterable(traversal)) {
consumer.accept(node);
}
}

/**
* Walks the AST, calling the consumer for every node.
* Walks the AST, calling the consumer for every node with pre-order traversal.
*/
public void walk(Consumer<Node> consumer) {
treeIterable(TreeTraversal.PREORDER).forEach(consumer);
walk(PREORDER, consumer);
}

/**
* Walks the AST, calling the consumer for every node of type "nodeType".
* Walks the AST with pre-order traversal, calling the consumer for every node of type "nodeType".
*/
public <T extends Node> void walk(Class<T> nodeType, Consumer<T> consumer) {
for (Node node : treeIterable(TreeTraversal.PREORDER)) {
if (nodeType.isInstance(node)) {
walk(TreeTraversal.PREORDER, node -> {
if (nodeType.isAssignableFrom(node.getClass())) {
consumer.accept(nodeType.cast(node));
}
}
});
}

/**
* Walks the AST, returning all nodes of type "nodeType".
* Walks the AST with pre-order traversal, returning all nodes of type "nodeType".
*/
public <T extends Node> List<T> findAll(Class<T> nodeType) {
final List<T> found = new ArrayList<>();
Expand All @@ -742,7 +782,7 @@ public <T extends Node> List<T> findAll(Class<T> nodeType) {
}

/**
* Walks the AST, returning all nodes of type "nodeType" that match the predicate.
* Walks the AST with pre-order traversal, returning all nodes of type "nodeType" that match the predicate.
*/
public <T extends Node> List<T> findAll(Class<T> nodeType, Predicate<T> predicate) {
final List<T> found = new ArrayList<>();
Expand All @@ -752,6 +792,18 @@ public <T extends Node> List<T> findAll(Class<T> nodeType, Predicate<T> predicat
return found;
}

/**
* Walks the AST with pre-order traversal, returning the first node of type "nodeType" or empty() if none is found.
*/
public <N extends Node> Optional<N> findFirst(Class<N> nodeType) {
return walk(TreeTraversal.PREORDER, node -> {
if(nodeType.isAssignableFrom(node.getClass())){
return nodeType.cast(node);
}
return null;
});
}

/**
* Performs a breadth-first node traversal starting with a given node.
*
Expand Down
Expand Up @@ -19,6 +19,6 @@ public void accept(Node node, ProblemReporter problemReporter) {
if (type.isInstance(node)) {
validator.accept(type.cast(node), problemReporter);
}
node.getChildNodesByType(type).forEach(n -> validator.accept(n, problemReporter));
node.findAll(type).forEach(n -> validator.accept(n, problemReporter));
}
}
Expand Up @@ -14,6 +14,6 @@ public void storeNoTokens() {
ParseResult<CompilationUnit> result = new JavaParser(new ParserConfiguration().setStoreTokens(false)).parse(ParseStart.COMPILATION_UNIT, provider("class X{}"));

assertFalse(result.getTokens().isPresent());
assertTrue(result.getResult().get().getChildNodesByType(Node.class).stream().noneMatch(node -> node.getTokenRange().isPresent()));
assertTrue(result.getResult().get().findAll(Node.class).stream().noneMatch(node -> node.getTokenRange().isPresent()));
}
}
Expand Up @@ -41,6 +41,7 @@
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;

import static com.github.javaparser.JavaParser.parse;
import static com.github.javaparser.JavaParser.parseExpression;
Expand Down Expand Up @@ -373,4 +374,22 @@ public void typeOnlyFindAll() {
List<IntegerLiteralExpr> ints = e.findAll(IntegerLiteralExpr.class);
assertEquals("[1, 2, 3]", ints.toString());
}

@Test
public void typeOnlyFindAllMatchesSubclasses() {
Expression e = parseExpression("1+2+3");
List<Node> ints = e.findAll(Node.class);
assertEquals("[1 + 2 + 3, 1 + 2, 1, 2, 3]", ints.toString());
}

@Test
public void stream() {
Expression e = parseExpression("1+2+3");
List<IntegerLiteralExpr> ints = e.stream()
.filter(n -> n instanceof IntegerLiteralExpr)
.map(IntegerLiteralExpr.class::cast)
.filter(i -> i.asInt() > 1)
.collect(Collectors.toList());
assertEquals("[2, 3]", ints.toString());
}
}
Expand Up @@ -51,7 +51,7 @@ public void topClass() {
@Test
public void localClass() {
MethodDeclaration method= (MethodDeclaration)JavaParser.parseBodyDeclaration("void x(){class X{};}");
ClassOrInterfaceDeclaration x = method.getChildNodesByType(ClassOrInterfaceDeclaration.class).get(0);
ClassOrInterfaceDeclaration x = method.findFirst(ClassOrInterfaceDeclaration.class).get();

assertFalse(x.isInnerClass());
assertFalse(x.isNestedType());
Expand Down
Expand Up @@ -203,17 +203,17 @@ public void thenLambdaInStatementInMethodInClassIsParentOfContainedParameter(int
public void thenMethodReferenceInStatementInMethodInClassIsScope(int statementPosition, int methodPosition,
int classPosition, String expectedName) {
ExpressionStmt statementUnderTest = getStatementInMethodInClass(statementPosition, methodPosition, classPosition).asExpressionStmt();
assertEquals(1, statementUnderTest.getChildNodesByType(MethodReferenceExpr.class).size());
MethodReferenceExpr methodReferenceUnderTest = statementUnderTest.getChildNodesByType(MethodReferenceExpr.class).get(0);
assertEquals(1, statementUnderTest.findAll(MethodReferenceExpr.class).size());
MethodReferenceExpr methodReferenceUnderTest = statementUnderTest.findFirst(MethodReferenceExpr.class).get();
assertThat(methodReferenceUnderTest.getScope().toString(), is(expectedName));
}

@Then("method reference in statement $statementPosition in method $methodPosition in class $classPosition identifier is $expectedName")
public void thenMethodReferenceInStatementInMethodInClassIdentifierIsCompareByAge(int statementPosition, int methodPosition,
int classPosition, String expectedName) {
Statement statementUnderTest = getStatementInMethodInClass(statementPosition, methodPosition, classPosition);
assertEquals(1, statementUnderTest.getChildNodesByType(MethodReferenceExpr.class).size());
MethodReferenceExpr methodReferenceUnderTest = statementUnderTest.getChildNodesByType(MethodReferenceExpr.class).get(0);
assertEquals(1, statementUnderTest.findAll(MethodReferenceExpr.class).size());
MethodReferenceExpr methodReferenceUnderTest = statementUnderTest.findFirst(MethodReferenceExpr.class).get();
assertThat(methodReferenceUnderTest.getIdentifier(), is(expectedName));
}

Expand Down Expand Up @@ -334,7 +334,7 @@ public void thenTheAssignExprProducedDoesntHaveANullTarget() {

private void setSelectedNodeFromCompilationUnit(Class<? extends Node> nodeType) {
CompilationUnit compilationUnit = (CompilationUnit) state.get("cu1");
List<? extends Node> nodes = compilationUnit.getChildNodesByType(nodeType);
List<? extends Node> nodes = compilationUnit.findAll(nodeType);
if (nodes.size() != 1) {
throw new RuntimeException(format("Exactly one %s expected", nodeType.getSimpleName()));
}
Expand Down
Expand Up @@ -26,24 +26,20 @@
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.FieldDeclaration;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

import java.util.stream.IntStream;

public class PrettyPrinterTest {

private String prettyPrintField(String code) {
CompilationUnit cu = JavaParser.parse(code);
return new PrettyPrinter().print(cu.getChildNodesByType(FieldDeclaration.class).get(0));
return new PrettyPrinter().print(cu.findFirst(FieldDeclaration.class).get());
}

private String prettyPrintVar(String code) {
CompilationUnit cu = JavaParser.parse(code);
return new PrettyPrinter().print(cu.getChildNodesByType(VariableDeclarationExpr.class).get(0));
return new PrettyPrinter().print(cu.findAll(VariableDeclarationExpr.class).get(0));
}

@Test
Expand Down Expand Up @@ -87,7 +83,7 @@ public void printingArrayVariables() {
private String prettyPrintConfigurable(String code) {
CompilationUnit cu = JavaParser.parse(code);
PrettyPrinter printer = new PrettyPrinter(new PrettyPrinterConfiguration().setVisitorFactory(TestVisitor::new));
return printer.print(cu.getChildNodesByType(ClassOrInterfaceDeclaration.class).get(0));
return printer.print(cu.findFirst(ClassOrInterfaceDeclaration.class).get());
}

@Test
Expand Down

0 comments on commit a863052

Please sign in to comment.