diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java index 70696c081e..9de90f5ede 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java @@ -30,6 +30,8 @@ import static com.oracle.graal.python.builtins.objects.PNone.NO_VALUE; import static com.oracle.graal.python.builtins.objects.PNotImplemented.NOT_IMPLEMENTED; import static com.oracle.graal.python.nodes.BuiltinNames.ABS; +import static com.oracle.graal.python.nodes.BuiltinNames.ALL; +import static com.oracle.graal.python.nodes.BuiltinNames.ANY; import static com.oracle.graal.python.nodes.BuiltinNames.ASCII; import static com.oracle.graal.python.nodes.BuiltinNames.BIN; import static com.oracle.graal.python.nodes.BuiltinNames.BREAKPOINT; @@ -102,7 +104,9 @@ import com.oracle.graal.python.builtins.objects.code.PCode; import com.oracle.graal.python.builtins.objects.common.DynamicObjectStorage; import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes; +import com.oracle.graal.python.builtins.objects.common.HashingStorage; import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary; +import com.oracle.graal.python.builtins.objects.common.PHashingCollection; import com.oracle.graal.python.builtins.objects.common.SequenceNodes.GetObjectArrayNode; import com.oracle.graal.python.builtins.objects.common.SequenceNodesFactory.GetObjectArrayNodeGen; import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes; @@ -119,6 +123,7 @@ import com.oracle.graal.python.builtins.objects.module.PythonModule; import com.oracle.graal.python.builtins.objects.object.ObjectNodes; import com.oracle.graal.python.builtins.objects.object.PythonObject; +import com.oracle.graal.python.builtins.objects.set.PBaseSet; import com.oracle.graal.python.builtins.objects.str.PString; import com.oracle.graal.python.builtins.objects.tuple.PTuple; import com.oracle.graal.python.builtins.objects.type.SpecialMethodSlot; @@ -134,6 +139,7 @@ import com.oracle.graal.python.lib.PyObjectGetAttr; import com.oracle.graal.python.lib.PyObjectGetIter; import com.oracle.graal.python.lib.PyObjectHashNode; +import com.oracle.graal.python.lib.PyObjectIsTrueNode; import com.oracle.graal.python.lib.PyObjectLookupAttr; import com.oracle.graal.python.lib.PyObjectReprAsObjectNode; import com.oracle.graal.python.lib.PyObjectSizeNode; @@ -144,6 +150,7 @@ import com.oracle.graal.python.nodes.ErrorMessages; import com.oracle.graal.python.nodes.GraalPythonTranslationErrorNode; import com.oracle.graal.python.nodes.PGuards; +import com.oracle.graal.python.nodes.PNodeWithContext; import com.oracle.graal.python.nodes.PNodeWithRaise; import com.oracle.graal.python.nodes.PRaiseNode; import com.oracle.graal.python.nodes.PRootNode; @@ -204,6 +211,8 @@ import com.oracle.graal.python.runtime.exception.PException; import com.oracle.graal.python.runtime.exception.PythonErrorType; import com.oracle.graal.python.runtime.object.PythonObjectFactory; +import com.oracle.graal.python.runtime.sequence.storage.BoolSequenceStorage; +import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage; import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage; import com.oracle.graal.python.util.CharsetMapping; import com.oracle.graal.python.util.PythonUtils; @@ -232,11 +241,13 @@ import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.ExplodeLoop.LoopExplosionKind; +import com.oracle.truffle.api.nodes.LoopNode; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.RootNode; import com.oracle.truffle.api.nodes.UnexpectedResultException; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.api.source.Source; import com.oracle.truffle.api.utilities.TriState; @@ -286,6 +297,217 @@ public Object absObject(VirtualFrame frame, Object object, } } + /** + * Common class for all() and any() operations, as their logic and behaviors are very similar. + */ + abstract static class AllOrAnyNode extends PNodeWithContext { + enum NodeType { + ALL, + ANY + } + + @Child private PyObjectIsTrueNode isTrueNode = PyObjectIsTrueNode.create(); + + final private LoopConditionProfile loopConditionProfile = LoopConditionProfile.create(); + + abstract boolean execute(Frame frame, Object storageObj, NodeType nodeType); + + @Specialization + boolean doBoolSequence(VirtualFrame frame, + BoolSequenceStorage sequenceStorage, + NodeType nodeType) { + boolean[] internalArray = sequenceStorage.getInternalBoolArray(); + int seqLength = sequenceStorage.length(); + + loopConditionProfile.profileCounted(seqLength); + for (int i = 0; loopConditionProfile.inject(i < seqLength); i++) { + if (nodeType == NodeType.ALL && !isTrueNode.execute(frame, internalArray[i])) { + return false; + } else if (nodeType == NodeType.ANY && isTrueNode.execute(frame, internalArray[i])) { + return true; + } + } + + return nodeType == NodeType.ALL; + } + + @Specialization + boolean doIntSequence(VirtualFrame frame, + IntSequenceStorage sequenceStorage, + NodeType nodeType) { + int[] internalArray = sequenceStorage.getInternalIntArray(); + int seqLength = sequenceStorage.length(); + + loopConditionProfile.profileCounted(seqLength); + for (int i = 0; loopConditionProfile.inject(i < seqLength); i++) { + if (nodeType == NodeType.ALL && !isTrueNode.execute(frame, internalArray[i])) { + return false; + } else if (nodeType == NodeType.ANY && isTrueNode.execute(frame, internalArray[i])) { + return true; + } + } + + return nodeType == NodeType.ALL; + } + + @Specialization + boolean doGenericSequence(VirtualFrame frame, + SequenceStorage sequenceStorage, + NodeType nodeType, + @Cached SequenceStorageNodes.LenNode lenNode) { + Object[] internalArray = sequenceStorage.getInternalArray(); + int seqLength = lenNode.execute(sequenceStorage); + + loopConditionProfile.profileCounted(seqLength); + for (int i = 0; loopConditionProfile.inject(i < seqLength); i++) { + if (nodeType == NodeType.ALL && !isTrueNode.execute(frame, internalArray[i])) { + return false; + } else if (nodeType == NodeType.ANY && isTrueNode.execute(frame, internalArray[i])) { + return true; + } + } + + return nodeType == NodeType.ALL; + } + + @Specialization(limit = "3") + protected boolean doHashStorage(VirtualFrame frame, + HashingStorage hashingStorage, + NodeType nodeType, + @CachedLibrary("hashingStorage") HashingStorageLibrary hlib) { + HashingStorageLibrary.HashingStorageIterator keysIter = hlib.keys(hashingStorage).iterator(); + int seqLength = hlib.length(hashingStorage); + + loopConditionProfile.profileCounted(seqLength); + for (int i = 0; loopConditionProfile.inject(i < seqLength); i++) { + Object key = keysIter.next(); + if (nodeType == NodeType.ALL) { + if (!isTrueNode.execute(frame, key)) { + return false; + } + } else if (nodeType == NodeType.ANY && isTrueNode.execute(frame, key)) { + return true; + } + } + + return nodeType == NodeType.ALL; + } + } + + @Builtin(name = ALL, minNumOfPositionalArgs = 1) + @GenerateNodeFactory + public abstract static class AllNode extends PythonUnaryBuiltinNode { + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doList(VirtualFrame frame, + PList object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getSequenceStorage(), AllOrAnyNode.NodeType.ALL); + } + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doTuple(VirtualFrame frame, + PTuple object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getSequenceStorage(), AllOrAnyNode.NodeType.ALL); + } + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doHashColl(VirtualFrame frame, + PHashingCollection object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getDictStorage(), AllOrAnyNode.NodeType.ALL); + } + + @Specialization + boolean doObject(VirtualFrame frame, + Object object, + @Cached PyObjectGetIter getIter, + @Cached GetNextNode nextNode, + @Cached IsBuiltinClassProfile errorProfile, + @Cached PyObjectIsTrueNode isTrueNode) { + Object iterator = getIter.execute(frame, object); + int nbrIter = 0; + + while (true) { + try { + Object next = nextNode.execute(frame, iterator); + nbrIter++; + if (!isTrueNode.execute(frame, next)) { + return false; + } + } catch (PException e) { + e.expectStopIteration(errorProfile); + break; + } finally { + LoopNode.reportLoopCount(this, nbrIter); + } + } + + return true; + } + } + + @Builtin(name = ANY, minNumOfPositionalArgs = 1) + @GenerateNodeFactory + public abstract static class AnyNode extends PythonUnaryBuiltinNode { + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doList(VirtualFrame frame, + PList object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getSequenceStorage(), AllOrAnyNode.NodeType.ANY); + } + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doTuple(VirtualFrame frame, + PTuple object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getSequenceStorage(), AllOrAnyNode.NodeType.ANY); + } + + @Specialization(guards = "cannotBeOverridden(object, getClassNode)", limit = "1") + static boolean doHashColl(VirtualFrame frame, + PHashingCollection object, + @SuppressWarnings("unused") @Shared("getClassNode") @Cached GetClassNode getClassNode, + @Shared("allOrAnyNode") @Cached AllOrAnyNode allOrAnyNode) { + return allOrAnyNode.execute(frame, object.getDictStorage(), AllOrAnyNode.NodeType.ANY); + } + + @Specialization + boolean doObject(VirtualFrame frame, + Object object, + @Cached PyObjectGetIter getIter, + @Cached GetNextNode nextNode, + @Cached IsBuiltinClassProfile errorProfile, + @Cached PyObjectIsTrueNode isTrueNode) { + Object iterator = getIter.execute(frame, object); + int nbrIter = 0; + + while (true) { + try { + Object next = nextNode.execute(frame, iterator); + nbrIter++; + if (isTrueNode.execute(frame, next)) { + return true; + } + } catch (PException e) { + e.expectStopIteration(errorProfile); + break; + } finally { + LoopNode.reportLoopCount(this, nbrIter); + } + } + + return false; + } + } + // bin(object) @Builtin(name = BIN, minNumOfPositionalArgs = 1) @TypeSystemReference(PythonArithmeticTypes.class) diff --git a/graalpython/lib-graalpython/functions.py b/graalpython/lib-graalpython/functions.py index 1a18a43564..6462f07368 100644 --- a/graalpython/lib-graalpython/functions.py +++ b/graalpython/lib-graalpython/functions.py @@ -43,22 +43,6 @@ def hasattr(obj, key): return getattr(obj, key, default) is not default -@__graalpython__.builtin -def any(iterable): - for i in iterable: - if i: - return True - return False - - -@__graalpython__.builtin -def all(iterable): - for i in iterable: - if not i: - return False - return True - - from sys import _getframe as __getframe__