Skip to content

Commit

Permalink
Add support for using specific field/type comparator in recursive com…
Browse files Browse the repository at this point in the history
…parison
  • Loading branch information
joel-costigliola committed Mar 1, 2016
1 parent 8a0bcf6 commit 33e3e6a
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 130 deletions.
6 changes: 5 additions & 1 deletion src/main/java/org/assertj/core/api/AbstractObjectAssert.java
Expand Up @@ -450,6 +450,7 @@ public AbstractObjectArrayAssert<?, Object> extracting(String... propertiesOrFie
* <pre><code class='java'> public class Person {
* public String name;
* public Home home = new Home();
* public Person bestFriend;
* }
*
* public class Home {
Expand All @@ -467,6 +468,9 @@ public AbstractObjectArrayAssert<?, Object> extracting(String... propertiesOrFie
* Person jackClone = new Person();
* jackClone.name = "Jack";
* jackClone.home.address.number = 1;
* // cycle are handled in comparison
* jack.home.bestFriend = jackClone;
* jackClone.home.bestFriend = jack;
*
* // will fail as equals compares object references
* assertThat(jack).isEqualsTo(jackClone);
Expand All @@ -480,7 +484,7 @@ public AbstractObjectArrayAssert<?, Object> extracting(String... propertiesOrFie
* @throws IntrospectionError if one property/field to compare can not be found.
*/
public S isEqualToComparingFieldByFieldRecursively(Object other) {
objects.assertIsEqualToComparingFieldByFieldRecursively(info, actual, other);
objects.assertIsEqualToComparingFieldByFieldRecursively(info, actual, other, comparatorByPropertyOrField, comparatorByType);
return myself;
}
}
Expand Up @@ -20,22 +20,28 @@

import org.assertj.core.internal.DeepDifference.Difference;

public class ShouldBeEqualByComparingFieldByFieldRecursive extends BasicErrorMessageFactory {
public class ShouldBeEqualByComparingFieldByFieldRecursively extends BasicErrorMessageFactory {

public static ErrorMessageFactory shouldBeEqualByComparingFieldByFieldRecursive(Object actual, Object other,
List<Difference> differences) {
List<String> descriptionOfDifferences = new ArrayList<>(differences.size());
for (Difference difference : differences) {
descriptionOfDifferences.add(format("%nPath to difference: <%s>%nexpected: <%s>%nbut was: <%s>",
descriptionOfDifferences.add(format("%nPath to difference: <%s>%n" +
"- expected: <%s>%n" +
"- actual : <%s>",
join(difference.getPath()).with("."), difference.getOther(),
difference.getActual()));
}
return new ShouldBeEqualByComparingFieldByFieldRecursive("Expecting: <%s>%nto be equal to: <%s>%nwhen recursively comparing field by field, but found the following difference(s):%n"
+ join(descriptionOfDifferences).with(format("%n")),
actual, other);
return new ShouldBeEqualByComparingFieldByFieldRecursively("Expecting:%n" +
" <%s>%n" +
"to be equal to:%n"+
" <%s>%n" +
"when recursively comparing field by field, but found the following difference(s):%n"
+ join(descriptionOfDifferences).with(format("%n")),
actual, other);
}

private ShouldBeEqualByComparingFieldByFieldRecursive(String message, Object actual, Object other) {
private ShouldBeEqualByComparingFieldByFieldRecursively(String message, Object actual, Object other) {
super(message, actual, other);
}
}
96 changes: 36 additions & 60 deletions src/main/java/org/assertj/core/internal/DeepDifference.java
Expand Up @@ -13,13 +13,16 @@
package org.assertj.core.internal;

import static org.assertj.core.internal.Objects.getDeclaredFieldsIncludingInherited;
import static org.assertj.core.util.introspection.PropertyOrFieldSupport.EXTRACTION;
import static org.assertj.core.internal.Objects.propertyOrFieldValuesAreEqual;
import static org.assertj.core.util.Strings.join;
import static org.assertj.core.util.introspection.PropertyOrFieldSupport.COMPARISON;

import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -44,8 +47,6 @@ public class DeepDifference {

private static final Map<Class<?>, Boolean> customEquals = new ConcurrentHashMap<>();
private static final Map<Class<?>, Boolean> customHash = new ConcurrentHashMap<>();
private static final double doubleEplison = 1e-15;
private static final double floatEplison = 1e-6;

private final static class DualKey {

Expand All @@ -62,7 +63,7 @@ private DualKey(List<String> path, Object key1, Object key2) {
private DualKey(Object key1, Object key2) {
this(new ArrayList<String>(), key1, key2);
}

@Override
public boolean equals(Object other) {
if (!(other instanceof DualKey)) {
Expand All @@ -88,6 +89,10 @@ public String toString() {
public List<String> getPath() {
return path;
}

public String getConcatenatedPath() {
return join(path).with(".");
}
}

public static class Difference {
Expand Down Expand Up @@ -140,7 +145,9 @@ public String toString() {
* either at the field level or via the respectively encountered overridden
* .equals() methods during traversal.
*/
public static List<Difference> determineDifferences(Object a, Object b) {
public static List<Difference> determineDifferences(Object a, Object b,
Map<String, Comparator<?>> comparatorByPropertyOrField,
Map<Class<?>, Comparator<?>> comparatorByType) {
final Set<DualKey> visited = new HashSet<>();
final Deque<DualKey> toCompare = initStack(a, b, visited);
final List<Difference> differences = new ArrayList<>();
Expand Down Expand Up @@ -202,16 +209,6 @@ public static List<Difference> determineDifferences(Object a, Object b) {
continue;
}

if (key1 instanceof Double) {
if (compareFloatingPointNumbers(key1, key2, doubleEplison))
continue;
}

if (key1 instanceof Float) {
if (compareFloatingPointNumbers(key1, key2, floatEplison))
continue;
}

// Handle all [] types. In order to be equal, the arrays must be the
// same length, be of the same type, be in the same order, and all
// elements within the array must be deeply equivalent.
Expand Down Expand Up @@ -275,6 +272,12 @@ public static List<Difference> determineDifferences(Object a, Object b) {
continue;
}

if (hasCustomComparator(dualKey, comparatorByPropertyOrField, comparatorByType)) {
if (propertyOrFieldValuesAreEqual(key1, key2, dualKey.getConcatenatedPath(),
comparatorByPropertyOrField, comparatorByType))
continue;
}

if (hasCustomEquals(key1.getClass())) {
if (!key1.equals(key2)) {
differences.add(new Difference(currentPath, key1, key2));
Expand All @@ -288,8 +291,8 @@ public static List<Difference> determineDifferences(Object a, Object b) {
String fieldName = field.getName();
path.add(fieldName);
DualKey dk = new DualKey(path,
EXTRACTION.getSimpleValue(fieldName, key1),
EXTRACTION.getSimpleValue(fieldName, key2));
COMPARISON.getSimpleValue(fieldName, key1),
COMPARISON.getSimpleValue(fieldName, key2));
if (!visited.contains(dk)) {
toCompare.addFirst(dk);
}
Expand All @@ -299,6 +302,17 @@ public static List<Difference> determineDifferences(Object a, Object b) {
return differences;
}

private static boolean hasCustomComparator(DualKey dualKey, Map<String, Comparator<?>> comparatorByPropertyOrField,
Map<Class<?>, Comparator<?>> comparatorByType) {
if (dualKey.key1.getClass() == dualKey.key2.getClass()) {
String fieldName = dualKey.getConcatenatedPath();
Comparator<?> fieldComparator = comparatorByPropertyOrField.containsKey(fieldName)
? comparatorByPropertyOrField.get(fieldName) : comparatorByType.get(dualKey.key1.getClass());
return fieldComparator != null;
}
return false;
}

private static Deque<DualKey> initStack(Object a, Object b, Set<DualKey> visited) {
Deque<DualKey> stack = new LinkedList<>();
if (a != null && !isContainerType(a)) {
Expand All @@ -308,8 +322,8 @@ private static Deque<DualKey> initStack(Object a, Object b, Set<DualKey> visited
for (Field field : fieldsOfRootObject) {
String fieldName = field.getName();
DualKey dk = new DualKey(Arrays.asList(fieldName),
EXTRACTION.getSimpleValue(fieldName, a),
EXTRACTION.getSimpleValue(fieldName, b));
COMPARISON.getSimpleValue(fieldName, a),
COMPARISON.getSimpleValue(fieldName, b));
if (!visited.contains(dk)) {
stack.addFirst(dk);
}
Expand Down Expand Up @@ -372,16 +386,12 @@ private static boolean compareArrays(Object array1, Object array2, List<String>
private static <K, V> boolean compareOrderedCollection(Collection<K> col1, Collection<V> col2,
List<String> path, Deque<DualKey> toCompare,
Set<DualKey> visited) {
if (col1.size() != col2.size()) {
return false;
}
if (col1.size() != col2.size()) return false;

Iterator<V> i2 = col2.iterator();
for (K k : col1) {
DualKey dk = new DualKey(path, k, i2.next());
if (!visited.contains(dk)) {
toCompare.addFirst(dk);
}
if (!visited.contains(dk)) toCompare.addFirst(dk);
}
return true;
}
Expand Down Expand Up @@ -517,40 +527,6 @@ private static <K1, V1, K2, V2> boolean compareUnorderedMap(Map<K1, V1> map1, Ma
return true;
}

/**
* Compare if two floating point numbers are within a given range
*/
private static boolean compareFloatingPointNumbers(Object a, Object b, double epsilon) {
double a1 = a instanceof Double ? (Double) a : (Float) a;
double b1 = b instanceof Double ? (Double) b : (Float) b;
return nearlyEqual(a1, b1, epsilon);
}

/**
* Correctly handles floating point comparison. (source: http://floating-point-gui.de/errors/comparison/)
*
* @param a first number
* @param b second number
* @param epsilon double tolerance value
* @return true if a and b are close enough
*/
private static boolean nearlyEqual(double a, double b, double epsilon) {
final double absA = Math.abs(a);
final double absB = Math.abs(b);
final double diff = Math.abs(a - b);

if (a == b) {
// shortcut, handles infinities
return true;
} else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
// a or b is zero or both are extremely close to it
// relative error is less meaningful here
return diff < (epsilon * Double.MIN_NORMAL);
} else { // use relative error
return diff / (absA + absB) < epsilon;
}
}

/**
* Determine if the passed in class has a non-Object.equals() method. This
* method caches its results in static ConcurrentHashMap to benefit
Expand Down Expand Up @@ -644,7 +620,7 @@ static int deepHashCode(Object obj) {

Collection<Field> fields = getDeclaredFieldsIncludingInherited(obj.getClass());
for (Field field : fields) {
stack.addFirst(EXTRACTION.getSimpleValue(field.getName(), obj));
stack.addFirst(COMPARISON.getSimpleValue(field.getName(), obj));
}
}
return hash;
Expand Down
10 changes: 6 additions & 4 deletions src/main/java/org/assertj/core/internal/Objects.java
Expand Up @@ -15,7 +15,7 @@
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual;
import static org.assertj.core.error.ShouldBeEqualByComparingFieldByFieldRecursive.shouldBeEqualByComparingFieldByFieldRecursive;
import static org.assertj.core.error.ShouldBeEqualByComparingFieldByFieldRecursively.shouldBeEqualByComparingFieldByFieldRecursive;
import static org.assertj.core.error.ShouldBeEqualByComparingOnlyGivenFields.shouldBeEqualComparingOnlyGivenFields;
import static org.assertj.core.error.ShouldBeEqualToIgnoringFields.shouldBeEqualToIgnoringGivenFields;
import static org.assertj.core.error.ShouldBeExactlyInstanceOf.shouldBeExactlyInstance;
Expand Down Expand Up @@ -651,7 +651,7 @@ private <A> ByFieldsComparison isEqualToIgnoringGivenFields(A actual, A other,
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private boolean propertyOrFieldValuesAreEqual(Object actualFieldValue, Object otherFieldValue, String fieldName,
static boolean propertyOrFieldValuesAreEqual(Object actualFieldValue, Object otherFieldValue, String fieldName,
Map<String, Comparator<?>> comparatorByPropertyOrField,
Map<Class<?>, Comparator<?>> comparatorByType) {
if (actualFieldValue != null && otherFieldValue != null
Expand All @@ -676,9 +676,11 @@ private <A> boolean canReadFieldValue(Field field, A actual) {
* @throws AssertionError if actual is {@code null}.
* @throws AssertionError if the actual and the given object are not "deeply" equal.
*/
public <A> void assertIsEqualToComparingFieldByFieldRecursively(AssertionInfo info, A actual, A other) {
public <A> void assertIsEqualToComparingFieldByFieldRecursively(AssertionInfo info, A actual, A other,
Map<String, Comparator<?>> comparatorByPropertyOrField,
Map<Class<?>, Comparator<?>> comparatorByType) {
assertNotNull(info, actual);
List<Difference> differences = determineDifferences(actual, other);
List<Difference> differences = determineDifferences(actual, other, comparatorByPropertyOrField, comparatorByType);
if (!differences.isEmpty()) {
throw failures.failure(info, shouldBeEqualByComparingFieldByFieldRecursive(actual, other, differences));
}
Expand Down
33 changes: 33 additions & 0 deletions src/test/java/org/assertj/core/internal/AtPrecisionComparator.java
@@ -0,0 +1,33 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
* Copyright 2012-2015 the original author or authors.
*/
package org.assertj.core.internal;

import static java.lang.Math.abs;

import java.util.Comparator;

public class AtPrecisionComparator<NUMBER extends Number> implements Comparator<NUMBER> {

private NUMBER precision;

public AtPrecisionComparator(NUMBER precision) {
this.precision = precision;
}

@Override
public int compare(NUMBER i1, NUMBER i2) {
double diff = abs(i1.doubleValue() - i2.doubleValue());
if (diff <= precision.doubleValue()) return 0;
return diff < 0.0 ? -1 : 1;
}
}
Expand Up @@ -21,6 +21,8 @@
import static java.lang.Math.sin;
import static java.lang.Math.tan;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.internal.ObjectsBaseTest.noFieldComparators;
import static org.assertj.core.internal.ObjectsBaseTest.noTypeComparators;
import static org.assertj.core.util.Lists.newArrayList;
import static org.assertj.core.util.Sets.newLinkedHashSet;

Expand Down Expand Up @@ -254,11 +256,11 @@ public void testHasCustomMethod() {
}

private void assertHaveNoDifferences(Object x, Object y) {
assertThat(DeepDifference.determineDifferences(x, y)).isEmpty();
assertThat(DeepDifference.determineDifferences(x, y, noFieldComparators(), noTypeComparators())).isEmpty();
}

private void assertHaveDifferences(Object x, Object y) {
assertThat(DeepDifference.determineDifferences(x, y)).isNotEmpty();
assertThat(DeepDifference.determineDifferences(x, y, noFieldComparators(), noTypeComparators())).isNotEmpty();
}

private static class EmptyClass {
Expand Down
12 changes: 12 additions & 0 deletions src/test/java/org/assertj/core/internal/ObjectsBaseTest.java
Expand Up @@ -16,7 +16,9 @@

import static org.mockito.Mockito.spy;

import java.util.Collections;
import java.util.Comparator;
import java.util.Map;

import org.assertj.core.api.Assertions;
import org.assertj.core.test.ExpectedException;
Expand Down Expand Up @@ -59,4 +61,14 @@ protected Comparator<?> comparatorForCustomComparisonStrategy() {
return new CaseInsensitiveStringComparator();
}

@SuppressWarnings("unchecked")
protected static Map<String, Comparator<?>> noFieldComparators() {
return Collections.EMPTY_MAP;
}

@SuppressWarnings("unchecked")
protected static Map<Class<?>, Comparator<?>> noTypeComparators() {
return Collections.EMPTY_MAP;
}

}

0 comments on commit 33e3e6a

Please sign in to comment.