Skip to content

Commit

Permalink
Try tree iterators, streams, and iterables
Browse files Browse the repository at this point in the history
  • Loading branch information
matozoid committed Oct 19, 2017
1 parent b1bf722 commit 285803b
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 54 deletions.
80 changes: 63 additions & 17 deletions javaparser-core/src/main/java/com/github/javaparser/ast/Node.java
Expand Up @@ -43,10 +43,15 @@
import javax.annotation.Generated; import javax.annotation.Generated;
import java.util.*; import java.util.*;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;


import static com.github.javaparser.ast.Node.Parsedness.PARSED; import static com.github.javaparser.ast.Node.Parsedness.PARSED;
import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableList;
import static java.util.Spliterator.DISTINCT;
import static java.util.Spliterator.NONNULL;


/** /**
* Base class for all nodes of the abstract syntax tree. * Base class for all nodes of the abstract syntax tree.
Expand Down Expand Up @@ -679,49 +684,90 @@ protected SymbolResolver getSymbolResolver() {
public static final DataKey<SymbolResolver> SYMBOL_RESOLVER_KEY = new DataKey<SymbolResolver>() { public static final DataKey<SymbolResolver> SYMBOL_RESOLVER_KEY = new DataKey<SymbolResolver>() {
}; };



public enum TreeTraversal {
PREORDER, BREADTHFIRST, POSTORDER
}

public Iterator<Node> treeIterator(TreeTraversal traversal) {
switch (traversal) {
case BREADTHFIRST:
return new TreeVisitor.BreadthFirstIterator(this);
case POSTORDER:
return new TreeVisitor.PostOrderIterator(this);
case PREORDER:
return new TreeVisitor.PreOrderIterator(this);
default:
throw new IllegalArgumentException("Unknown traversal choice.");
}
}

public Iterable<Node> treeIterable(TreeTraversal traversal) {
return () -> treeIterator(traversal);
}

public Stream<Node> treeStream(TreeTraversal traversal) {
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(treeIterator(traversal), NONNULL | DISTINCT), false);
}

/** /**
* Walks the AST, calling the consumer for every node. * Walks the AST, calling the consumer for every node.
*/ */
public void walk(Consumer<Node> consumer) { public void walk(Consumer<Node> consumer) {
new TreeVisitor() { treeIterable(TreeTraversal.PREORDER).forEach(consumer);
@Override
public void process(Node node) {
consumer.accept(node);
}
}.visitPreOrder(this);
} }


/** /**
* Walks the AST, calling the consumer for every node of type "nodeType". * Walks the AST, calling the consumer for every node of type "nodeType".
*/ */
public <T extends Node> void walk(Class<T> nodeType, Consumer<T> consumer) { public <T extends Node> void walk(Class<T> nodeType, Consumer<T> consumer) {
new TreeVisitor() { for (Node node : treeIterable(TreeTraversal.PREORDER)) {
@Override if (nodeType.isInstance(node)) {
public void process(Node node) { consumer.accept(nodeType.cast(node));
if (nodeType.isInstance(node)) {
consumer.accept(nodeType.cast(node));
}
} }
}.visitPreOrder(this); }
} }


/** /**
* Walks the AST, returning the all nodes of type "nodeType". * Walks the AST, returning all nodes of type "nodeType".
*/ */
public <T extends Node> List<T> find(Class<T> nodeType) { public <T extends Node> List<T> findAll(Class<T> nodeType) {
final List<T> found = new ArrayList<>(); final List<T> found = new ArrayList<>();
walk(nodeType, found::add); walk(nodeType, found::add);
return found; return found;
} }


/** /**
* Walks the AST, returning the all nodes of type "nodeType" that match the predicate. * Walks the AST, returning all nodes of type "nodeType" that match the predicate.
*/ */
public <T extends Node> List<T> find(Class<T> nodeType, Predicate<T> predicate) { public <T extends Node> List<T> findAll(Class<T> nodeType, Predicate<T> predicate) {
final List<T> found = new ArrayList<>(); final List<T> found = new ArrayList<>();
walk(nodeType, n -> { walk(nodeType, n -> {
if (predicate.test(n)) found.add(n); if (predicate.test(n)) found.add(n);
}); });
return found; return found;
} }

private <T> Optional<T> visitPreOrder(Function<Node, Optional<T>> visitor) {
Optional<T> result = visitor.apply(this);

if(result.isPresent()){
return result;
}

for(Node n: new ArrayList<>(getChildNodes())) {
result = n.visitPreOrder(visitor);

if(result.isPresent()){
return result;
}
}
return Optional.empty();
}

private void visitPreOrder(Consumer<Node> visitor) {
visitor.accept(this);

new ArrayList<>(getChildNodes()).forEach(n -> n.visitPreOrder(visitor));
}
} }
Expand Up @@ -23,57 +23,145 @@


import com.github.javaparser.ast.Node; import com.github.javaparser.ast.Node;


import java.util.ArrayList; import java.util.*;
import java.util.LinkedList;
import java.util.Queue;


/** /**
* Iterate over all the nodes in (a part of) the AST. * Iterate over all the nodes in (a part of) the AST.
*/ */
public abstract class TreeVisitor { public abstract class TreeVisitor {


public void visitLeavesFirst(Node node) { public void visitLeavesFirst(Node node) {
for (Node child : node.getChildNodes()) { visitPostOrder(node);
visitLeavesFirst(child);
}
process(node);
} }


/** /**
* Performs a pre-order node traversal starting with a given node. When each node is visited, * Performs a pre-order node traversal starting with a given node. When each node is visited, {@link #process(Node)}
* {@link #process(Node)} is called for further processing. * is called for further processing.
* *
* @param node The node at which the traversal begins. * @param node The node at which the traversal begins.
*
* @see <a href="https://en.wikipedia.org/wiki/Pre-order">Pre-order traversal</a> * @see <a href="https://en.wikipedia.org/wiki/Pre-order">Pre-order traversal</a>
*/ */
public void visitPreOrder(Node node) { public void visitPreOrder(Node node) {
process(node); new PreOrderIterator(node).forEachRemaining(this::process);
new ArrayList<>(node.getChildNodes()).forEach(this::visitPreOrder); }

public static class BreadthFirstIterator implements Iterator<Node> {
private final Queue<Node> queue = new LinkedList<>();

public BreadthFirstIterator(Node node) {
queue.add(node);
}

@Override
public boolean hasNext() {
return !queue.isEmpty();
}

@Override
public Node next() {
Node next = queue.remove();
queue.addAll(next.getChildNodes());
return next;
}
}

public static class PreOrderIterator implements Iterator<Node> {
private final Stack<Node> stack = new Stack<>();

public PreOrderIterator(Node node) {
stack.add(node);
}

@Override
public boolean hasNext() {
return !stack.isEmpty();
}

@Override
public Node next() {
Node next = stack.pop();
List<Node> children = next.getChildNodes();
for (int i = children.size() - 1; i >= 0; i--) {
stack.add(children.get(i));
}
return next;
}
}

public static class PostOrderIterator implements Iterator<Node> {
private final Stack<List<Node>> nodesStack = new Stack<>();
private final Stack<Integer> cursorStack = new Stack<>();
private final Node root;
private boolean hasNext = true;

public PostOrderIterator(Node root) {
this.root = root;
fillStackToLeaf(root);
}

private void fillStackToLeaf(Node node) {
while (true) {
List<Node> childNodes = new ArrayList<>(node.getChildNodes());
if (childNodes.isEmpty()) {
break;
}
nodesStack.push(childNodes);
cursorStack.push(0);
node = childNodes.get(0);
}
}

@Override
public boolean hasNext() {
return hasNext;
}

@Override
public Node next() {
final List<Node> nodes = nodesStack.peek();
final int cursor = cursorStack.peek();
final boolean levelHasNext = cursor < nodes.size();
if (levelHasNext) {
Node node = nodes.get(cursor);
fillStackToLeaf(node);
return nextFromLevel();
} else {
nodesStack.pop();
cursorStack.pop();
hasNext = !nodesStack.empty();
if (hasNext) {
return nextFromLevel();
}
return root;
}
}

private Node nextFromLevel() {
final List<Node> nodes = nodesStack.peek();
final int cursor = cursorStack.pop();
cursorStack.push(cursor + 1);
return nodes.get(cursor);
}
} }


/** /**
* Performs a post-order node traversal starting with a given node. When each node is visited, * Performs a post-order node traversal starting with a given node. When each node is visited, {@link
* {@link #process(Node)} is called for further processing. * #process(Node)} is called for further processing.
* *
* @param node The node at which the traversal begins. * @param node The node at which the traversal begins.
*
* @see <a href="https://en.wikipedia.org/wiki/Post-order">Post-order traversal</a> * @see <a href="https://en.wikipedia.org/wiki/Post-order">Post-order traversal</a>
*/ */
public void visitPostOrder(Node node) { public void visitPostOrder(Node node) {
new ArrayList<>(node.getChildNodes()).forEach(this::visitPostOrder); new PostOrderIterator(node).forEachRemaining(this::process);
process(node);
} }


/** /**
* Performs a pre-order node traversal starting with a given node. When each node is visited, * Performs a pre-order node traversal starting with a given node. When each node is visited, {@link #process(Node)}
* {@link #process(Node)} is called for further processing. * is called for further processing.
*
* @deprecated As of release 3.1.0, replaced by {@link #visitPreOrder(Node)}
* *
* @param node The node at which the traversal begins. * @param node The node at which the traversal begins.
*
* @see <a href="https://en.wikipedia.org/wiki/Pre-order">Pre-order traversal</a> * @see <a href="https://en.wikipedia.org/wiki/Pre-order">Pre-order traversal</a>
* @deprecated As of release 3.1.0, replaced by {@link #visitPreOrder(Node)}
*/ */
@Deprecated @Deprecated
public void visitDepthFirst(Node node) { public void visitDepthFirst(Node node) {
Expand All @@ -86,15 +174,7 @@ public void visitDepthFirst(Node node) {
* @param node the start node, and the first one that is passed to process(node). * @param node the start node, and the first one that is passed to process(node).
*/ */
public void visitBreadthFirst(Node node) { public void visitBreadthFirst(Node node) {
final Queue<Node> queue = new LinkedList<>(); new BreadthFirstIterator(node).forEachRemaining(this::process);
queue.offer(node);
while (queue.size() > 0) {
final Node head = queue.peek();
for (Node child : head.getChildNodes()) {
queue.offer(child);
}
process(queue.poll());
}
} }


/** /**
Expand Down
Expand Up @@ -344,32 +344,32 @@ public void cantFindCompilationUnit() {
} }


@Test @Test
public void walk1() { public void genericWalk() {
Expression e = parseExpression("1+1"); Expression e = parseExpression("1+1");
StringBuilder b = new StringBuilder(); StringBuilder b = new StringBuilder();
e.walk(n -> b.append(n.toString())); e.walk(n -> b.append(n.toString()));
assertEquals("1 + 111", b.toString()); assertEquals("1 + 111", b.toString());
} }


@Test @Test
public void walk2() { public void classSpecificWalk() {
Expression e = parseExpression("1+1"); Expression e = parseExpression("1+1");
StringBuilder b = new StringBuilder(); StringBuilder b = new StringBuilder();
e.walk(IntegerLiteralExpr.class, n -> b.append(n.toString())); e.walk(IntegerLiteralExpr.class, n -> b.append(n.toString()));
assertEquals("11", b.toString()); assertEquals("11", b.toString());
} }


@Test @Test
public void find1() { public void conditionalFindAll() {
Expression e = parseExpression("1+2+3"); Expression e = parseExpression("1+2+3");
List<IntegerLiteralExpr> ints = e.find(IntegerLiteralExpr.class, n -> n.asInt() > 1); List<IntegerLiteralExpr> ints = e.findAll(IntegerLiteralExpr.class, n -> n.asInt() > 1);
assertEquals("[2, 3]", ints.toString()); assertEquals("[2, 3]", ints.toString());
} }


@Test @Test
public void find2() { public void typeOnlyFindAll() {
Expression e = parseExpression("1+2+3"); Expression e = parseExpression("1+2+3");
List<IntegerLiteralExpr> ints = e.find(IntegerLiteralExpr.class); List<IntegerLiteralExpr> ints = e.findAll(IntegerLiteralExpr.class);
assertEquals("[1, 2, 3]", ints.toString()); assertEquals("[1, 2, 3]", ints.toString());
} }
} }

0 comments on commit 285803b

Please sign in to comment.