diff --git a/rewrite-core/src/main/java/org/openrewrite/config/CompositeRefactorVisitor.java b/rewrite-core/src/main/java/org/openrewrite/config/CompositeRefactorVisitor.java index 41ac28df956..84d2e1098b8 100644 --- a/rewrite-core/src/main/java/org/openrewrite/config/CompositeRefactorVisitor.java +++ b/rewrite-core/src/main/java/org/openrewrite/config/CompositeRefactorVisitor.java @@ -22,6 +22,7 @@ import org.openrewrite.SourceVisitor; import org.openrewrite.Tree; +import java.awt.*; import java.util.List; public class CompositeRefactorVisitor extends SourceVisitor { @@ -50,7 +51,9 @@ CompositeRefactorVisitor setName(String name) { public Class getVisitorType() { return delegates.stream().findAny() - .map(Object::getClass) + .map(d -> d instanceof CompositeRefactorVisitor ? + ((CompositeRefactorVisitor) d).getVisitorType() : + d.getClass()) .orElse(null); } @@ -60,7 +63,7 @@ public String getName() { @Override public Tree visitTree(Tree tree) { - if(tree instanceof SourceFile) { + if (tree instanceof SourceFile) { Refactor refactor = new Refactor<>(tree); return refactor.visit(delegates).fix().getFixed(); } @@ -68,8 +71,17 @@ public Tree visitTree(Tree tree) { return super.visitTree(tree); } + void extendsFrom(CompositeRefactorVisitor delegate) { + delegates.add(0, delegate); + andThen().add(0, delegate); + } + @Override public Tree defaultTo(Tree t) { - return delegates.stream().findAny().map(v -> v.defaultTo(t)).orElse(null); + return delegates.stream().findAny() + .map(d -> d instanceof CompositeRefactorVisitor ? + ((CompositeRefactorVisitor) d).defaultTo(t) : + d.defaultTo(t)) + .orElse(null); } } diff --git a/rewrite-core/src/main/java/org/openrewrite/config/YamlResourceLoader.java b/rewrite-core/src/main/java/org/openrewrite/config/YamlResourceLoader.java index af8c2225f42..6234fcd1cf3 100644 --- a/rewrite-core/src/main/java/org/openrewrite/config/YamlResourceLoader.java +++ b/rewrite-core/src/main/java/org/openrewrite/config/YamlResourceLoader.java @@ -39,31 +39,37 @@ public class YamlResourceLoader implements ProfileConfigurationLoader, SourceVis .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); private final Map profiles = new HashMap<>(); - private final Collection> visitors = new ArrayList<>(); + private final Collection visitors = new ArrayList<>(); + private final Map visitorExtensions = new HashMap<>(); public YamlResourceLoader(InputStream yamlInput) throws UncheckedIOException { try { - Yaml yaml = new Yaml(); - for (Object resource : yaml.loadAll(yamlInput)) { - if (resource instanceof Map) { - @SuppressWarnings("unchecked") Map resourceMap = (Map) resource; - String type = resourceMap.getOrDefault("type", "invalid").toString(); - switch(type) { - case "beta.openrewrite.org/v1/visitor": - mapVisitor(resourceMap); - break; - case "beta.openrewrite.org/v1/profile": - mapProfile(resourceMap); - break; + try { + Yaml yaml = new Yaml(); + for (Object resource : yaml.loadAll(yamlInput)) { + if (resource instanceof Map) { + @SuppressWarnings("unchecked") Map resourceMap = (Map) resource; + String type = resourceMap.getOrDefault("type", "invalid").toString(); + switch (type) { + case "beta.openrewrite.org/v1/visitor": + mapVisitor(resourceMap); + break; + case "beta.openrewrite.org/v1/profile": + mapProfile(resourceMap); + break; + } } } - } - } finally { - try { + + for (Map.Entry extendingVisitor : visitorExtensions.entrySet()) { + visitors.stream().filter(v -> v.getName().equals(extendingVisitor.getValue())).findAny() + .ifPresent(v -> extendingVisitor.getKey().extendsFrom(v)); + } + } finally { yamlInput.close(); - } catch (IOException e) { - throw new UncheckedIOException(e); } + } catch (IOException e) { + throw new UncheckedIOException(e); } } @@ -108,7 +114,13 @@ private void mapVisitor(Map visitorMap) { } } - this.visitors.add(new CompositeRefactorVisitor(visitorMap.get("name").toString(), subVisitors)); + CompositeRefactorVisitor visitor = new CompositeRefactorVisitor(visitorMap.get("name").toString(), subVisitors); + + if (visitorMap.containsKey("extends")) { + visitorExtensions.put(visitor, visitorMap.get("extends").toString()); + } + + this.visitors.add(visitor); } private Class visitorClass(String name) throws ClassNotFoundException { @@ -136,8 +148,9 @@ public Collection loadProfiles() { return profiles.values(); } + @SuppressWarnings("unchecked") @Override public Collection> loadVisitors() { - return visitors; + return (Collection>) (Collection) visitors; } }