diff --git a/liquibase-standard/src/main/java/liquibase/snapshot/SnapshotGeneratorChain.java b/liquibase-standard/src/main/java/liquibase/snapshot/SnapshotGeneratorChain.java index 196d9490cc5..df4c049ab66 100644 --- a/liquibase-standard/src/main/java/liquibase/snapshot/SnapshotGeneratorChain.java +++ b/liquibase-standard/src/main/java/liquibase/snapshot/SnapshotGeneratorChain.java @@ -3,7 +3,11 @@ import liquibase.exception.DatabaseException; import liquibase.structure.DatabaseObject; -import java.util.*; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import java.util.SortedSet; public class SnapshotGeneratorChain { private Iterator snapshotGenerators; @@ -51,34 +55,43 @@ public T snapshot(T example, DatabaseSnapshot snapsho return null; } - SnapshotGenerator next = getNextValidGenerator(); - - if (next == null) { - return null; - } - - T obj = next.snapshot(example, snapshot, this); - if ((obj != null) && (obj.getSnapshotId() == null)) { - obj.setSnapshotId(snapshotIdService.generateId()); - } - return obj; - } - - public SnapshotGenerator getNextValidGenerator() { if (snapshotGenerators == null) { return null; } - if (!snapshotGenerators.hasNext()) { - return null; + boolean firstRoundDone = false; + SnapshotGenerator lastGenerator = null; + T lastObject = example; + while (snapshotGenerators.hasNext()) { + SnapshotGenerator generator = snapshotGenerators.next(); + if (replacedGenerators.contains(generator.getClass())) { + continue; + } + T object = generator.snapshot(lastObject, snapshot, this); + if ((object != null) && (object.getSnapshotId() == null)) { + object.setSnapshotId(snapshotIdService.generateId()); + } + // only first generator in the chain is allowed to create new instances - subsequent ones are not + if (firstRoundDone && object != lastObject) { + throw new DatabaseException(String.format("Snapshot generator %s has returned a different reference from the previous generator %s.\n" + + "\tSnapshot object was: %s, it is now: %s.\n" + + "\tConsider declaring %1$s as being replaced by one the generator in the chain via liquibase.snapshot.SnapshotGenerator#replaces.", + generator.getClass().getName(), + lastGenerator.getClass().getName(), + identity(lastObject), + identity(object))); + } + lastObject = object; + lastGenerator = generator; + firstRoundDone = true; } + return lastObject; + } - SnapshotGenerator next = snapshotGenerators.next(); - for (Class removedGenerator : replacedGenerators) { - if (removedGenerator.equals(next.getClass())) { - return getNextValidGenerator(); - } + private static String identity(Object object) { + if (object == null) { + return "null"; } - return next; + return String.format("%s@%s", object.getClass(), System.identityHashCode(object)); } } diff --git a/liquibase-standard/src/test/groovy/liquibase/snapshot/SnapshotGeneratorChainTest.groovy b/liquibase-standard/src/test/groovy/liquibase/snapshot/SnapshotGeneratorChainTest.groovy index f461891c9d6..f3bec63abfd 100644 --- a/liquibase-standard/src/test/groovy/liquibase/snapshot/SnapshotGeneratorChainTest.groovy +++ b/liquibase-standard/src/test/groovy/liquibase/snapshot/SnapshotGeneratorChainTest.groovy @@ -90,6 +90,23 @@ class SnapshotGeneratorChainTest extends Specification { snapshot.getAttribute("visited", Boolean.class) == expectedTable.getAttribute("visited", Boolean.class) } + def "snapshotting fails if subsequent generator returns a different instance"() { + given: + def chain = new SnapshotGeneratorChain(sortedSetOf(visitingGenerator, badGenerator)) + def object = new Table() + database.isSystemObject(object) >> false + snapshotControl.shouldInclude(object.class) >> true + def expectedTable = new Table() + expectedTable.setAttribute("visited", true) + + + when: + chain.snapshot(object, snapshotContext) + + then: + def exception = thrown(DatabaseException) + exception.message.startsWith("Snapshot generator liquibase.snapshot.BadSnapshotGenerator has returned a different reference from the previous generator liquibase.snapshot.VisitedSnapshotGenerator.") + } def "snapshotting works even if first generator returns a different instance"() { given: @@ -108,6 +125,25 @@ class SnapshotGeneratorChainTest extends Specification { result != null } + def "snapshotting works if bad generator is replaced in the chain"() { + given: + def chain = new SnapshotGeneratorChain(sortedSetOf(visitingGenerator, badGenerator, replacementForBadGenerator)) + def object = new Table() + database.isSystemObject(object) >> false + snapshotControl.shouldInclude(object.class) >> true + def expectedTable = new Table() + expectedTable.setAttribute("visited", true) + expectedTable.setAttribute("replacement", "done") + + + when: + def snapshot = chain.snapshot(object, snapshotContext) + + then: + snapshot.getAttribute("visited", Boolean.class) == expectedTable.getAttribute("visited", Boolean.class) + snapshot.getAttribute("replacement", String.class) == expectedTable.getAttribute("replacement", String.class) + } + private static SortedSet sortedSetOf(SnapshotGenerator... generators) { def result = new TreeSet() result.addAll(generators) @@ -132,7 +168,7 @@ class VisitedSnapshotGenerator implements SnapshotGenerator, Comparable T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { + T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { example.setAttribute("visited", true) return example } @@ -150,9 +186,9 @@ class VisitedSnapshotGenerator implements SnapshotGenerator, Comparable T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { + T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { example.setAttribute("replacement", "done") return example } @@ -195,9 +231,9 @@ class ReplacingSnapshotGenerator implements SnapshotGenerator, Comparable T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { + T snapshot(T example, DatabaseSnapshot snapshot, SnapshotGeneratorChain chain) throws DatabaseException, InvalidExampleException { // generators are expected to add nested attributes, ... to the provided example or delegate if the example type does not match // they are NOT expected to create new instances of the same type return acceptedType.newInstance() as T @@ -237,8 +273,8 @@ class BadSnapshotGenerator implements SnapshotGenerator, Comparable