From 58c7ca54011bba1208de298d7b2e6877570d5b47 Mon Sep 17 00:00:00 2001 From: Pavel Marek Date: Tue, 21 Feb 2023 01:56:11 +0100 Subject: [PATCH] Performance improvements for Comparators (#5687) Critical performance improvements after #4067 # Important Notes - Replace if-then-else expressions in `Any.==` with case expressions. - Fix caching in `EqualsNode`. - This includes fixing specializations, along with fallback guard. --- .../lib/Standard/Base/0.0.0-dev/src/Any.enso | 22 +- .../Base/0.0.0-dev/src/Data/Numbers.enso | 6 +- .../Base/0.0.0-dev/src/Data/Ordering.enso | 13 +- .../Base/0.0.0-dev/src/Runtime/State.enso | 2 +- .../expression/builtin/meta/EqualsNode.java | 421 +++++++++++++----- .../expression/builtin/meta/HashCodeNode.java | 65 ++- .../enso/interpreter/runtime/data/Array.java | 12 +- .../org/enso/interpreter/test/EqualsTest.java | 24 +- .../enso/interpreter/test/HashCodeTest.java | 25 +- .../interpreter/test/ValuesGenerator.java | 34 +- 10 files changed, 462 insertions(+), 162 deletions(-) diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Any.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Any.enso index dfb75583cea2..2988b1ae422f 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Any.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Any.enso @@ -110,16 +110,20 @@ type Any # host or polyglot values, so we just compare them with the default comparator. eq_self = Panic.catch No_Such_Conversion (Comparable.from self) _-> Default_Unordered_Comparator eq_that = Panic.catch No_Such_Conversion (Comparable.from that) _-> Default_Unordered_Comparator - if Meta.is_same_object eq_self Incomparable then False else - similar_type = Meta.is_same_object eq_self eq_that - if similar_type.not then False else - case eq_self.is_ordered of + case Meta.is_same_object eq_self Incomparable of + True -> False + False -> + similar_type = Meta.is_same_object eq_self eq_that + case similar_type of + False -> False True -> - # Comparable.equals_builtin is a hack how to directly access EqualsNode from the - # engine, so that we don't end up in an infinite recursion here (which would happen - # if we would compare with `eq_self == eq_that`). - Comparable.equals_builtin (eq_self.compare self that) Ordering.Equal - False -> eq_self.equals self that + case eq_self.is_ordered of + True -> + # Comparable.equals_builtin is a hack how to directly access EqualsNode from the + # engine, so that we don't end up in an infinite recursion here (which would happen + # if we would compare with `eq_self == eq_that`). + Comparable.equals_builtin (eq_self.compare self that) Ordering.Equal + False -> eq_self.equals self that ## ALIAS Inequality diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso index ccf7887e3a3a..d6e8bc29c860 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Numbers.enso @@ -1,5 +1,6 @@ import project.Data.Ordering.Ordering import project.Data.Ordering.Comparable +import project.Data.Ordering.Incomparable import project.Data.Ordering.Default_Ordered_Comparator import project.Data.Text.Text import project.Data.Locale.Locale @@ -940,7 +941,10 @@ type Integer parse_builtin text radix = @Builtin_Method "Integer.parse" -Comparable.from (_:Number) = Default_Ordered_Comparator +Comparable.from (that:Number) = + case that.is_nan of + True -> Incomparable + False -> Default_Ordered_Comparator ## UNSTABLE diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Ordering.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Ordering.enso index 3686f7071493..433d6b03dd2b 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Ordering.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Ordering.enso @@ -9,6 +9,7 @@ import project.Error.Unimplemented.Unimplemented import project.Nothing import project.Meta import project.Meta.Atom +import project.Panic.Panic from project.Data.Boolean import all ## Provides custom ordering, equality check and hash code for types that need it. @@ -165,9 +166,15 @@ type Default_Ordered_Comparator ## Handles only primitive types, not atoms or vectors. compare : Any -> Any -> Ordering compare x y = - if Comparable.less_than_builtin x y then Ordering.Less else - if Comparable.equals_builtin x y then Ordering.Equal else - if Comparable.less_than_builtin y x then Ordering.Greater + case Comparable.less_than_builtin x y of + True -> Ordering.Less + False -> + case Comparable.equals_builtin x y of + True -> Ordering.Equal + False -> + case Comparable.less_than_builtin y x of + True -> Ordering.Greater + False -> Panic.throw "Unreachable" hash : Number -> Integer hash x = Comparable.hash_builtin x diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Runtime/State.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Runtime/State.enso index 63b30f20b1d6..f62ea1dfdb6a 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Runtime/State.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Runtime/State.enso @@ -47,7 +47,7 @@ get key = @Builtin_Method "State.get" - key: The key with which to associate the new state. - new_state: The new state to store. - Returns an uninitialized state error if the user tries to read from an + Returns an uninitialized state error if the user tries to put into an uninitialized slot. > Example diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java index c3be97e8d4b6..d9edd092d060 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java @@ -1,9 +1,9 @@ package org.enso.interpreter.node.expression.builtin.meta; import com.ibm.icu.text.Normalizer; +import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.GenerateUncached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.ArityException; @@ -15,18 +15,19 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; import com.oracle.truffle.api.library.CachedLibrary; +import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.ConditionProfile; -import com.oracle.truffle.api.profiles.LoopConditionProfile; +import java.math.BigInteger; import java.time.LocalDateTime; import java.time.ZonedDateTime; import java.util.Arrays; -import java.util.Map; import org.enso.interpreter.dsl.AcceptsError; import org.enso.interpreter.dsl.BuiltinMethod; import org.enso.interpreter.node.callable.InvokeCallableNode.ArgumentsExecutionMode; import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode; import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode; +import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps; import org.enso.interpreter.node.expression.builtin.ordering.HasCustomComparatorNode; import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.Module; @@ -78,48 +79,77 @@ public static EqualsNode build() { @Specialization - boolean equalsBoolean(boolean self, boolean other) { + boolean equalsBoolBool(boolean self, boolean other) { return self == other; } @Specialization - boolean equalsBytes(byte self, byte other) { - return self == other; + boolean equalsBoolDouble(boolean self, double other) { + return false; + } + + @Specialization + boolean equalsBoolLong(boolean self, long other) { + return false; + } + + @Specialization + boolean equalsBoolBigInt(boolean self, EnsoBigInteger other) { + return false; + } + + @Specialization + boolean equalsBoolText(boolean self, Text other) { + return false; } @Specialization - boolean equalsLong(long self, long other) { + boolean equalsByteByte(byte self, byte other) { return self == other; } @Specialization - boolean equalsDouble(double self, double other) { + boolean equalsLongLong(long self, long other) { return self == other; } + @Specialization + boolean equalsLongBool(long self, boolean other) { + return false; + } + + @Specialization + boolean equalsLongInt(long self, int other) { + return self == (long) other; + } + @Specialization boolean equalsLongDouble(long self, double other) { return (double) self == other; } @Specialization - boolean equalsDoubleLong(double self, long other) { - return self == (double) other; + boolean equalsLongText(long self, Text other) { + return false; } @Specialization - boolean equalsIntLong(int self, long other) { - return (long) self == other; + boolean equalsDoubleDouble(double self, double other) { + if (Double.isNaN(self) || Double.isNaN(other)) { + return false; + } else { + return self == other; + } } @Specialization - boolean equalsLongInt(long self, int other) { - return self == (long) other; + boolean equalsDoubleLong(double self, long other) { + return self == (double) other; } @Specialization - boolean equalsIntDouble(int self, double other) { - return (double) self == other; + boolean equalsDoubleBool(double self, boolean other) { + return false; } @Specialization @@ -129,7 +159,33 @@ boolean equalsDoubleInt(double self, int other) { @Specialization @TruffleBoundary - boolean equalsBigInt(EnsoBigInteger self, EnsoBigInteger otherBigInt) { + boolean equalsDoubleBigInt(double self, EnsoBigInteger other) { + return self == other.doubleValue(); + } + + @Specialization + boolean equalsDoubleText(double self, Text other) { + return false; + } + + @Specialization + boolean equalsIntInt(int self, int other) { + return self == other; + } + + @Specialization + boolean equalsIntLong(int self, long other) { + return (long) self == other; + } + + @Specialization + boolean equalsIntDouble(int self, double other) { + return (double) self == other; + } + + @Specialization + @TruffleBoundary + boolean equalsBigIntBigInt(EnsoBigInteger self, EnsoBigInteger otherBigInt) { return self.equals(otherBigInt); } @@ -141,8 +197,63 @@ boolean equalsBitIntDouble(EnsoBigInteger self, double other) { @Specialization @TruffleBoundary - boolean equalsDoubleBigInt(double self, EnsoBigInteger other) { - return self == other.doubleValue(); + boolean equalsBigIntLong(EnsoBigInteger self, long other) { + if (BigIntegerOps.fitsInLong(self.getValue())) { + return self.getValue().compareTo(BigInteger.valueOf(other)) == 0; + } else { + return false; + } + } + + @Specialization + boolean equalsBigIntBool(EnsoBigInteger self, boolean other) { + return false; + } + + @Specialization + boolean equalsBigIntText(EnsoBigInteger self, Text other) { + return false; + } + + @Specialization + @TruffleBoundary + boolean equalsLongBigInt(long self, EnsoBigInteger other) { + if (BigIntegerOps.fitsInLong(other.getValue())) { + return BigInteger.valueOf(self).compareTo(other.getValue()) == 0; + } else { + return false; + } + } + + @Specialization(limit = "3") + boolean equalsTextText(Text selfText, Text otherText, + @CachedLibrary("selfText") InteropLibrary selfInterop, + @CachedLibrary("otherText") InteropLibrary otherInterop) { + if (selfText.is_normalized() && otherText.is_normalized()) { + return selfText.toString().compareTo(otherText.toString()) == 0; + } else { + return equalsStrings(selfText, otherText, selfInterop, otherInterop); + } + } + + @Specialization + boolean equalsTextBool(Text self, boolean other) { + return false; + } + + @Specialization + boolean equalsTextLong(Text selfText, long otherLong) { + return false; + } + + @Specialization + boolean equalsTextDouble(Text selfText, double otherDouble) { + return false; + } + + @Specialization + boolean equalsTextBigInt(Text self, EnsoBigInteger other) { + return false; } /** @@ -225,17 +336,6 @@ boolean equalsWithWarnings(Object selfWithWarnings, Object otherWithWarnings, } } - @Specialization(limit = "3") - boolean equalsTexts(Text selfText, Text otherText, - @CachedLibrary("selfText") InteropLibrary selfInterop, - @CachedLibrary("otherText") InteropLibrary otherInterop) { - if (selfText.is_normalized() && otherText.is_normalized()) { - return selfText.toString().compareTo(otherText.toString()) == 0; - } else { - return equalsStrings(selfText, otherText, selfInterop, otherInterop); - } - } - /** Interop libraries **/ @Specialization(guards = { @@ -249,7 +349,6 @@ boolean equalsNull( return selfInterop.isNull(selfNull) && otherInterop.isNull(otherNull); } - @Specialization(guards = { "selfInterop.isBoolean(selfBoolean)", "otherInterop.isBoolean(otherBoolean)" @@ -268,12 +367,8 @@ boolean equalsBooleanInterop( } @Specialization(guards = { - "!selfInterop.isDate(selfTimeZone)", - "!selfInterop.isTime(selfTimeZone)", - "selfInterop.isTimeZone(selfTimeZone)", - "!otherInterop.isDate(otherTimeZone)", - "!otherInterop.isTime(otherTimeZone)", - "otherInterop.isTimeZone(otherTimeZone)" + "isTimeZone(selfTimeZone, selfInterop)", + "isTimeZone(otherTimeZone, otherInterop)", }, limit = "3") boolean equalsTimeZones(Object selfTimeZone, Object otherTimeZone, @CachedLibrary("selfTimeZone") InteropLibrary selfInterop, @@ -289,12 +384,8 @@ boolean equalsTimeZones(Object selfTimeZone, Object otherTimeZone, @TruffleBoundary @Specialization(guards = { - "selfInterop.isDate(selfZonedDateTime)", - "selfInterop.isTime(selfZonedDateTime)", - "selfInterop.isTimeZone(selfZonedDateTime)", - "otherInterop.isDate(otherZonedDateTime)", - "otherInterop.isTime(otherZonedDateTime)", - "otherInterop.isTimeZone(otherZonedDateTime)" + "isZonedDateTime(selfZonedDateTime, selfInterop)", + "isZonedDateTime(otherZonedDateTime, otherInterop)", }, limit = "3") boolean equalsZonedDateTimes(Object selfZonedDateTime, Object otherZonedDateTime, @CachedLibrary("selfZonedDateTime") InteropLibrary selfInterop, @@ -318,12 +409,8 @@ boolean equalsZonedDateTimes(Object selfZonedDateTime, Object otherZonedDateTime } @Specialization(guards = { - "selfInterop.isDate(selfDateTime)", - "selfInterop.isTime(selfDateTime)", - "!selfInterop.isTimeZone(selfDateTime)", - "otherInterop.isDate(otherDateTime)", - "otherInterop.isTime(otherDateTime)", - "!otherInterop.isTimeZone(otherDateTime)" + "isDateTime(selfDateTime, selfInterop)", + "isDateTime(otherDateTime, otherInterop)", }, limit = "3") boolean equalsDateTimes(Object selfDateTime, Object otherDateTime, @CachedLibrary("selfDateTime") InteropLibrary selfInterop, @@ -344,12 +431,8 @@ boolean equalsDateTimes(Object selfDateTime, Object otherDateTime, } @Specialization(guards = { - "selfInterop.isDate(selfDate)", - "!selfInterop.isTime(selfDate)", - "!selfInterop.isTimeZone(selfDate)", - "otherInterop.isDate(otherDate)", - "!otherInterop.isTime(otherDate)", - "!otherInterop.isTimeZone(otherDate)" + "isDate(selfDate, selfInterop)", + "isDate(otherDate, otherInterop)", }, limit = "3") boolean equalsDates(Object selfDate, Object otherDate, @CachedLibrary("selfDate") InteropLibrary selfInterop, @@ -364,12 +447,8 @@ boolean equalsDates(Object selfDate, Object otherDate, } @Specialization(guards = { - "!selfInterop.isDate(selfTime)", - "selfInterop.isTime(selfTime)", - "!selfInterop.isTimeZone(selfTime)", - "!otherInterop.isDate(otherTime)", - "otherInterop.isTime(otherTime)", - "!otherInterop.isTimeZone(otherTime)" + "isTime(selfTime, selfInterop)", + "isTime(otherTime, otherInterop)", }, limit = "3") boolean equalsTimes(Object selfTime, Object otherTime, @CachedLibrary("selfTime") InteropLibrary selfInterop, @@ -505,21 +584,8 @@ boolean equalsHashMaps(Object selfHashMap, Object otherHashMap, } @Specialization(guards = { - "!isAtom(selfObject)", - "!isAtom(otherObject)", - "!isHostObject(selfObject)", - "!isHostObject(otherObject)", - "interop.hasMembers(selfObject)", - "interop.hasMembers(otherObject)", - "!interop.isDate(selfObject)", - "!interop.isDate(otherObject)", - "!interop.isTime(selfObject)", - "!interop.isTime(otherObject)", - // Objects with types are handled in `equalsTypes` specialization, so we have to - // negate the guards of that specialization here - to make the specializations - // disjunctive. - "!typesLib.hasType(selfObject)", - "!typesLib.hasType(otherObject)", + "isObjectWithMembers(selfObject, interop)", + "isObjectWithMembers(otherObject, interop)", }) boolean equalsInteropObjectWithMembers(Object selfObject, Object otherObject, @CachedLibrary(limit = "10") InteropLibrary interop, @@ -584,57 +650,65 @@ static EqualsNode[] createEqualsNodes(int size) { return nodes; } - @Specialization + @Specialization(guards = { + "selfCtorCached == self.getConstructor()" + }, limit = "10") + @ExplodeLoop boolean equalsAtoms( Atom self, Atom other, - @Cached LoopConditionProfile loopProfile, - @Cached(value = "createEqualsNodes(equalsNodeCountForFields)", allowUncached = true) EqualsNode[] fieldEqualsNodes, - @Cached ConditionProfile enoughEqualNodesForFieldsProfile, + @Cached("self.getConstructor()") AtomConstructor selfCtorCached, + @Cached(value = "selfCtorCached.getFields().length", allowUncached = true) int fieldsLenCached, + @Cached(value = "createEqualsNodes(fieldsLenCached)", allowUncached = true) EqualsNode[] fieldEqualsNodes, @Cached ConditionProfile constructorsNotEqualProfile, - @CachedLibrary(limit = "3") StructsLibrary selfStructs, - @CachedLibrary(limit = "3") StructsLibrary otherStructs, @Cached HasCustomComparatorNode hasCustomComparatorNode, - @Cached InvokeAnyEqualsNode invokeAnyEqualsNode + @Cached InvokeAnyEqualsNode invokeAnyEqualsNode, + @CachedLibrary(limit = "5") StructsLibrary structsLib ) { if (constructorsNotEqualProfile.profile( self.getConstructor() != other.getConstructor() )) { return false; } - var selfFields = selfStructs.getFields(self); - var otherFields = otherStructs.getFields(other); - assert selfFields.length == otherFields.length; - - int fieldsSize = selfFields.length; - if (enoughEqualNodesForFieldsProfile.profile(fieldsSize <= equalsNodeCountForFields)) { - loopProfile.profileCounted(fieldsSize); - for (int i = 0; loopProfile.inject(i < fieldsSize); i++) { - boolean fieldsAreEqual; - // We don't check whether `other` has the same type of comparator, that is checked in - // `Any.==` that we invoke here anyway. - if (selfFields[i] instanceof Atom selfAtomField - && otherFields[i] instanceof Atom otherAtomField - && hasCustomComparatorNode.execute(selfAtomField)) { - // If selfFields[i] has a custom comparator, we delegate to `Any.==` that deals with - // custom comparators. EqualsNode cannot deal with custom comparators. - fieldsAreEqual = invokeAnyEqualsNode.execute(selfAtomField, otherAtomField); - } else { - fieldsAreEqual = fieldEqualsNodes[i].execute(selfFields[i], otherFields[i]); - } - if (!fieldsAreEqual) { - return false; - } + var selfFields = structsLib.getFields(self); + var otherFields = structsLib.getFields(other); + assert selfFields.length == otherFields.length : "Constructors are same, atoms should have the same number of fields"; + + CompilerAsserts.partialEvaluationConstant(fieldsLenCached); + for (int i = 0; i < fieldsLenCached; i++) { + boolean fieldsAreEqual; + // We don't check whether `other` has the same type of comparator, that is checked in + // `Any.==` that we invoke here anyway. + if (selfFields[i] instanceof Atom selfAtomField + && otherFields[i] instanceof Atom otherAtomField + && hasCustomComparatorNode.execute(selfAtomField)) { + // If selfFields[i] has a custom comparator, we delegate to `Any.==` that deals with + // custom comparators. EqualsNode cannot deal with custom comparators. + fieldsAreEqual = invokeAnyEqualsNode.execute(selfAtomField, otherAtomField); + } else { + fieldsAreEqual = fieldEqualsNodes[i].execute( + selfFields[i], + otherFields[i] + ); + } + if (!fieldsAreEqual) { + return false; } - } else { - return equalsAtomsFieldsUncached(selfFields, otherFields); } return true; } @TruffleBoundary - private static boolean equalsAtomsFieldsUncached(Object[] selfFields, Object[] otherFields) { - assert selfFields.length == otherFields.length; + @Specialization(replaces = "equalsAtoms") + boolean equalsAtomsUncached(Atom self, Atom other) { + if (!equalsAtomConstructors(self.getConstructor(), other.getConstructor())) { + return false; + } + Object[] selfFields = StructsLibrary.getUncached().getFields(self); + Object[] otherFields = StructsLibrary.getUncached().getFields(other); + if (selfFields.length != otherFields.length) { + return false; + } for (int i = 0; i < selfFields.length; i++) { boolean areFieldsSame; if (selfFields[i] instanceof Atom selfFieldAtom @@ -683,17 +757,136 @@ boolean equalsHostFunctions(Object selfHostFunc, Object otherHostFunc, return equalsNode.execute(selfFuncStrRepr, otherFuncStrRepr); } - @Fallback + @Specialization(guards = "fallbackGuard(left, right, interop)") @TruffleBoundary boolean equalsGeneric(Object left, Object right, - @CachedLibrary(limit = "5") InteropLibrary interop, - @CachedLibrary(limit = "5") TypesLibrary typesLib) { + @CachedLibrary(limit = "10") InteropLibrary interop, + @CachedLibrary(limit = "10") TypesLibrary typesLib) { return left == right || interop.isIdentical(left, right, interop) || left.equals(right) || (isNullOrNothing(left, typesLib, interop) && isNullOrNothing(right, typesLib, interop)); } + // We have to manually specify negation of guards of other specializations, because + // we cannot use @Fallback here. Note that this guard is not precisely the negation of + // all the other guards on purpose. + boolean fallbackGuard(Object left, Object right, InteropLibrary interop) { + if (isPrimitive(left) && isPrimitive(right)) { + return false; + } + if (isHostObject(left) && isHostObject(right)) { + return false; + } + if (isHostFunction(left) && isHostFunction(right)) { + return false; + } + if (left instanceof Atom && right instanceof Atom) { + return false; + } + if (interop.isNull(left) && interop.isNull(right)) { + return false; + } + if (interop.isString(left) && interop.isString(right)) { + return false; + } + if (interop.hasArrayElements(left) && interop.hasArrayElements(right)) { + return false; + } + if (interop.hasHashEntries(left) && interop.hasHashEntries(right)) { + return false; + } + if (isObjectWithMembers(left, interop) && isObjectWithMembers(right, interop)) { + return false; + } + if (isTimeZone(left, interop) && isTimeZone(right, interop)) { + return false; + } + if (isZonedDateTime(left, interop) && isZonedDateTime(right, interop)) { + return false; + } + if (isDateTime(left, interop) && isDateTime(right, interop)) { + return false; + } + if (isDate(left, interop) && isDate(right, interop)) { + return false; + } + if (isTime(left, interop) && isTime(right, interop)) { + return false; + } + if (interop.isDuration(left) && interop.isDuration(right)) { + return false; + } + // For all other cases, fall through to the generic specialization + return true; + } + + /** + * Return true iff object is a primitive value used in some of the specializations + * guard. By primitive value we mean any value that can be present in Enso, so, + * for example, not Integer, as that cannot be present in Enso. + * All the primitive types should be handled in their corresponding specializations. + * See {@link org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode}. + */ + private static boolean isPrimitive(Object object) { + return object instanceof Boolean || + object instanceof Long || + object instanceof Double || + object instanceof EnsoBigInteger || + object instanceof Text; + } + + boolean isTimeZone(Object object, InteropLibrary interop) { + return + !interop.isTime(object) && + !interop.isDate(object) && + interop.isTimeZone(object); + } + + boolean isZonedDateTime(Object object, InteropLibrary interop) { + return + interop.isTime(object) && + interop.isDate(object) && + interop.isTimeZone(object); + } + + boolean isDateTime(Object object, InteropLibrary interop) { + return + interop.isTime(object) && + interop.isDate(object) && + !interop.isTimeZone(object); + } + + boolean isDate(Object object, InteropLibrary interop) { + return + !interop.isTime(object) && + interop.isDate(object) && + !interop.isTimeZone(object); + } + + boolean isTime(Object object, InteropLibrary interop) { + return + interop.isTime(object) && + !interop.isDate(object) && + !interop.isTimeZone(object); + } + + boolean isObjectWithMembers(Object object, InteropLibrary interop) { + if (object instanceof Atom) { + return false; + } + if (isHostObject(object)) { + return false; + } + if (interop.isDate(object)) { + return false; + } + if (interop.isTime(object)) { + return false; + } + return interop.hasMembers(object); + } + private boolean isNullOrNothing(Object object, TypesLibrary typesLib, InteropLibrary interop) { if (typesLib.hasType(object)) { return typesLib.getType(object) == EnsoContext.get(this).getNothing(); @@ -734,11 +927,11 @@ boolean invokeEqualsCachedAtomCtor(Atom selfAtom, Atom thatAtom, @Cached(value = "getAnyEqualsMethod()", allowUncached = true) Function anyEqualsFunc, @Cached(value = "buildInvokeFuncNodeForAnyEquals()", allowUncached = true) InvokeFunctionNode invokeAnyEqualsNode, @CachedLibrary(limit = "3") InteropLibrary interop) { - // TODO: Shouldn't Comparable type be the very first argument? (synthetic self)? Object ret = invokeAnyEqualsNode.execute( anyEqualsFunc, null, State.create(EnsoContext.get(this)), + // TODO: Shouldn't Any type be the very first argument? (synthetic self)? new Object[]{selfAtom, thatAtom} ); try { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/HashCodeNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/HashCodeNode.java index ac6dac78c4ab..7fc7b75fe8f8 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/HashCodeNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/HashCodeNode.java @@ -2,6 +2,7 @@ import com.google.common.base.Objects; import com.ibm.icu.text.Normalizer2; +import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.GenerateUncached; @@ -14,6 +15,7 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException; import com.oracle.truffle.api.interop.UnsupportedTypeException; import com.oracle.truffle.api.library.CachedLibrary; +import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; @@ -103,7 +105,10 @@ long hashCodeForFloat(float f) { @Specialization long hashCodeForDouble(double d) { - if (d % 1.0 != 0 || BigIntegerOps.fitsInLong(d)) { + if (Double.isNaN(d)) { + // NaN is Incomparable, just return a "random" constant + return 456879; + } else if (d % 1.0 != 0 || BigIntegerOps.fitsInLong(d)) { return Double.hashCode(d); } else { return bigDoubleHash(d); @@ -127,6 +132,7 @@ long hashCodeForBigInteger(EnsoBigInteger bigInteger) { @Specialization long hashCodeForAtomConstructor(AtomConstructor atomConstructor) { + // AtomConstructors are singletons, we take system hash code explicitly. return System.identityHashCode(atomConstructor); } @@ -181,23 +187,23 @@ long hashCodeForType(Type type, } } - /** How many {@link HashCodeNode} nodes should be created for fields in atoms. */ - static final int hashCodeNodeCountForFields = 10; - static HashCodeNode[] createHashCodeNodes(int size) { HashCodeNode[] nodes = new HashCodeNode[size]; Arrays.fill(nodes, HashCodeNode.build()); return nodes; } - @Specialization + @Specialization(guards = { + "atomCtorCached == atom.getConstructor()" + }, limit = "5") + @ExplodeLoop long hashCodeForAtom( Atom atom, - @Cached(value = "createHashCodeNodes(hashCodeNodeCountForFields)", allowUncached = true) + @Cached("atom.getConstructor()") AtomConstructor atomCtorCached, + @Cached("atomCtorCached.getFields().length") int fieldsLenCached, + @Cached(value = "createHashCodeNodes(fieldsLenCached)", allowUncached = true) HashCodeNode[] fieldHashCodeNodes, @Cached ConditionProfile isHashCodeCached, - @Cached ConditionProfile enoughHashCodeNodesForFields, - @Cached LoopConditionProfile loopProfile, @CachedLibrary(limit = "10") StructsLibrary structs, @Cached HasCustomComparatorNode hasCustomComparatorNode, @Cached HashCallbackNode hashCallbackNode) { @@ -208,22 +214,18 @@ long hashCodeForAtom( Object[] fields = structs.getFields(atom); int fieldsCount = fields.length; + CompilerAsserts.partialEvaluationConstant(fieldsLenCached); // hashes stores hash codes for all fields, and for constructor. int[] hashes = new int[fieldsCount + 1]; - if (enoughHashCodeNodesForFields.profile(fieldsCount <= hashCodeNodeCountForFields)) { - loopProfile.profileCounted(fieldsCount); - for (int i = 0; loopProfile.inject(i < fieldsCount); i++) { - if (fields[i] instanceof Atom atomField && hasCustomComparatorNode.execute(atomField)) { - hashes[i] = (int) hashCallbackNode.execute(atomField); - } else { - hashes[i] = (int) fieldHashCodeNodes[i].execute(fields[i]); - } + for (int i = 0; i < fieldsLenCached; i++) { + if (fields[i] instanceof Atom atomField && hasCustomComparatorNode.execute(atomField)) { + hashes[i] = (int) hashCallbackNode.execute(atomField); + } else { + hashes[i] = (int) fieldHashCodeNodes[i].execute(fields[i]); } - } else { - hashCodeForAtomFieldsUncached(fields, hashes); } - int ctorHashCode = System.identityHashCode(atom.getConstructor()); + int ctorHashCode = (int) hashCodeForAtomConstructor(atom.getConstructor()); hashes[hashes.length - 1] = ctorHashCode; int atomHashCode = Arrays.hashCode(hashes); @@ -232,15 +234,29 @@ long hashCodeForAtom( } @TruffleBoundary - private void hashCodeForAtomFieldsUncached(Object[] fields, int[] fieldHashes) { + @Specialization(replaces = "hashCodeForAtom") + long hashCodeForAtomUncached(Atom atom) { + if (atom.getHashCode() != null) { + return atom.getHashCode(); + } + + Object[] fields = StructsLibrary.getUncached().getFields(atom); + int[] hashes = new int[fields.length + 1]; for (int i = 0; i < fields.length; i++) { if (fields[i] instanceof Atom atomField && HasCustomComparatorNode.getUncached().execute(atomField)) { - fieldHashes[i] = (int) HashCallbackNode.getUncached().execute(atomField); + hashes[i] = (int) HashCallbackNode.getUncached().execute(atomField); } else { - fieldHashes[i] = (int) HashCodeNodeGen.getUncached().execute(fields[i]); + hashes[i] = (int) HashCodeNodeGen.getUncached().execute(fields[i]); } } + + int ctorHashCode = (int) hashCodeForAtomConstructor(atom.getConstructor()); + hashes[hashes.length - 1] = ctorHashCode; + + int atomHashCode = Arrays.hashCode(hashes); + atom.setHashCode(atomHashCode); + return atomHashCode; } @Specialization( @@ -434,7 +450,10 @@ long hashCodeForArray( * Two maps are considered equal, if they have the same entries. Note that we do not care about * ordering. */ - @Specialization(guards = "interop.hasHashEntries(selfMap)") + @Specialization(guards = { + "interop.hasHashEntries(selfMap)", + "!interop.hasArrayElements(selfMap)", + }) long hashCodeForMap( Object selfMap, @CachedLibrary(limit = "5") InteropLibrary interop, diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/Array.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/Array.java index afd1fa603f43..4cd3887b48b2 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/Array.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/Array.java @@ -1,6 +1,7 @@ package org.enso.interpreter.runtime.data; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.InvalidArrayIndexException; import com.oracle.truffle.api.interop.TruffleObject; @@ -9,6 +10,7 @@ import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.BranchProfile; import org.enso.interpreter.dsl.Builtin; import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.error.Warning; @@ -25,7 +27,7 @@ @Builtin(pkg = "mutable", stdlibName = "Standard.Base.Data.Array.Array") public final class Array implements TruffleObject { private final Object[] items; - private @CompilerDirectives.CompilationFinal Boolean withWarnings; + private Boolean withWarnings; /** * Creates a new array @@ -75,14 +77,20 @@ public boolean hasArrayElements() { * @throws InvalidArrayIndexException when the index is out of bounds. */ @ExportMessage - public Object readArrayElement(long index, @CachedLibrary(limit = "3") WarningsLibrary warnings) + public Object readArrayElement( + long index, + @CachedLibrary(limit = "3") WarningsLibrary warnings, + @Cached BranchProfile errProfile, + @Cached BranchProfile hasWarningsProfile) throws InvalidArrayIndexException, UnsupportedMessageException { if (index >= items.length || index < 0) { + errProfile.enter(); throw InvalidArrayIndexException.create(index); } var v = items[(int) index]; if (this.hasWarnings(warnings)) { + hasWarningsProfile.enter(); Warning[] extracted = this.getWarnings(null, warnings); if (warnings.hasWarnings(v)) { v = warnings.removeWarnings(v); diff --git a/engine/runtime/src/test/java/org/enso/interpreter/test/EqualsTest.java b/engine/runtime/src/test/java/org/enso/interpreter/test/EqualsTest.java index 464f746f3e22..82b271f28025 100644 --- a/engine/runtime/src/test/java/org/enso/interpreter/test/EqualsTest.java +++ b/engine/runtime/src/test/java/org/enso/interpreter/test/EqualsTest.java @@ -11,7 +11,9 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode; import org.enso.interpreter.node.expression.builtin.meta.EqualsNode; +import org.enso.interpreter.node.expression.builtin.meta.EqualsNodeGen; import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Value; import org.junit.AfterClass; @@ -28,19 +30,21 @@ public class EqualsTest extends TestBase { private static Context context; private static EqualsNode equalsNode; private static TestRootNode testRootNode; + private static HostValueToEnsoNode hostValueToEnsoNode; @BeforeClass public static void initContextAndData() { context = createDefaultContext(); - unwrappedValues = fetchAllUnwrappedValues(); executeInContext( context, () -> { testRootNode = new TestRootNode(); equalsNode = EqualsNode.build(); - testRootNode.insertChildren(equalsNode); + hostValueToEnsoNode = HostValueToEnsoNode.build(); + testRootNode.insertChildren(equalsNode, hostValueToEnsoNode); return null; }); + unwrappedValues = fetchAllUnwrappedValues(); } @AfterClass @@ -74,6 +78,7 @@ private static Object[] fetchAllUnwrappedValues() { try { return values.stream() .map(value -> unwrapValue(context, value)) + .map(unwrappedValue -> hostValueToEnsoNode.execute(unwrappedValue)) .collect(Collectors.toList()) .toArray(new Object[] {}); } catch (Exception e) { @@ -105,6 +110,21 @@ public void equalsOperatorShouldBeConsistent(Object value) { }); } + @Theory + public void equalsNodeCachedIsConsistentWithUncached(Object firstVal, Object secondVal) { + executeInContext( + context, + () -> { + boolean uncachedRes = EqualsNodeGen.getUncached().execute(firstVal, secondVal); + boolean cachedRes = equalsNode.execute(firstVal, secondVal); + assertEquals( + "Result from uncached EqualsNode should be the same as result from its cached variant", + uncachedRes, + cachedRes); + return null; + }); + } + /** Test for some specific values, for which we know that they are equal. */ @Test public void testDateEquality() { diff --git a/engine/runtime/src/test/java/org/enso/interpreter/test/HashCodeTest.java b/engine/runtime/src/test/java/org/enso/interpreter/test/HashCodeTest.java index 09ed164afad6..cb8ac43bfc85 100644 --- a/engine/runtime/src/test/java/org/enso/interpreter/test/HashCodeTest.java +++ b/engine/runtime/src/test/java/org/enso/interpreter/test/HashCodeTest.java @@ -7,8 +7,10 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode; import org.enso.interpreter.node.expression.builtin.meta.EqualsNode; import org.enso.interpreter.node.expression.builtin.meta.HashCodeNode; +import org.enso.interpreter.node.expression.builtin.meta.HashCodeNodeGen; import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Value; import org.junit.AfterClass; @@ -26,20 +28,22 @@ public class HashCodeTest extends TestBase { private static HashCodeNode hashCodeNode; private static EqualsNode equalsNode; + private static HostValueToEnsoNode hostValueToEnsoNode; private static TestRootNode testRootNode; @BeforeClass public static void initContextAndData() { context = createDefaultContext(); - // Initialize datapoints here, to make sure that it is initialized just once. - unwrappedValues = fetchAllUnwrappedValues(); executeInContext(context, () -> { hashCodeNode = HashCodeNode.build(); equalsNode = EqualsNode.build(); + hostValueToEnsoNode = HostValueToEnsoNode.build(); testRootNode = new TestRootNode(); - testRootNode.insertChildren(hashCodeNode, equalsNode); + testRootNode.insertChildren(hashCodeNode, equalsNode, hostValueToEnsoNode); return null; }); + // Initialize datapoints here, to make sure that it is initialized just once. + unwrappedValues = fetchAllUnwrappedValues(); } @AfterClass @@ -79,6 +83,7 @@ private static Object[] fetchAllUnwrappedValues() { return values .stream() .map(value -> unwrapValue(context, value)) + .map(unwrappedValue -> hostValueToEnsoNode.execute(unwrappedValue)) .collect(Collectors.toList()) .toArray(new Object[]{}); } catch (Exception e) { @@ -132,4 +137,18 @@ public void hashCodeIsConsistent(Object value) { return null; }); } + + @Theory + public void hashCodeCachedNodeIsConsistentWithUncached(Object value) { + executeInContext(context, () -> { + long uncachedRes = HashCodeNodeGen.getUncached().execute(value); + long cachedRes = hashCodeNode.execute(value); + assertEquals( + "Result from cached HashCodeNode should be the same as from its uncached variant", + uncachedRes, + cachedRes + ); + return null; + }); + } } diff --git a/engine/runtime/src/test/java/org/enso/interpreter/test/ValuesGenerator.java b/engine/runtime/src/test/java/org/enso/interpreter/test/ValuesGenerator.java index e8ca263b8ad3..b8493361eb45 100644 --- a/engine/runtime/src/test/java/org/enso/interpreter/test/ValuesGenerator.java +++ b/engine/runtime/src/test/java/org/enso/interpreter/test/ValuesGenerator.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.Set; import java.util.TimeZone; +import org.enso.polyglot.MethodNames.Module; import org.graalvm.polyglot.Context; import org.graalvm.polyglot.PolyglotException; import org.graalvm.polyglot.Value; @@ -74,6 +75,32 @@ private ValueInfo v(String k, String t, String s) { return v; } + /** + * Converts expressions into values of type described by {@code typeDefs} by concatenating + * everything into a single source. + * + * This method exists so that there are no multiple definitions of a single type. + * + * @param typeDefs Type definitions. + * @param expressions List of expressions - every expression will be converted to a {@link Value}. + * @return List of values converted from the given expressions. + */ + private List createValuesOfCustomType(String typeDefs, List expressions) { + var sb = new StringBuilder(); + sb.append(typeDefs); + sb.append("\n"); + for (int i = 0; i < expressions.size(); i++) { + sb.append("var_").append(i).append(" = ").append(expressions.get(i)).append("\n"); + } + Value module = ctx.eval("enso", sb.toString()); + List values = new ArrayList<>(expressions.size()); + for (int i = 0; i < expressions.size(); i++) { + Value val = module.invokeMember(Module.EVAL_EXPRESSION, "var_" + i); + values.add(val); + } + return values; + } + public Value typeAny() { return v("typeAny", """ import Standard.Base.Any.Any @@ -521,7 +548,7 @@ public List multiLevelAtoms() { Nil Value value """; - for (var expr : List.of( + var exprs = List.of( "Node.C2 Node.Nil (Node.Value 42)", "Node.C2 (Node.Value 42) Node.Nil", "Node.Nil", @@ -536,9 +563,8 @@ public List multiLevelAtoms() { "Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 (Node.C1 Node.Nil))) (Node.C2 (Node.C3 (Node.Nil) (Node.Value 22) (Node.Nil)) (Node.C2 (Node.Value 22) (Node.Nil)))", "Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 Node.Nil)) (Node.C2 (Node.C3 (Node.Nil) (Node.Value 22) (Node.Nil)) (Node.C2 (Node.Value 22) (Node.Nil)))", "Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 Node.Nil)) (Node.C2 (Node.C3 (Node.Nil) (Node.Nil) (Node.Value 22)) (Node.C2 (Node.Value 22) (Node.Nil)))" - )) { - collect.add(v(null, nodeTypeDef, expr).type()); - } + ); + collect.addAll(createValuesOfCustomType(nodeTypeDef, exprs)); } return collect; }