diff --git a/graalpython/com.oracle.graal.python.test/src/tests/test_super.py b/graalpython/com.oracle.graal.python.test/src/tests/test_super.py index d518468b5e..7a441031e8 100644 --- a/graalpython/com.oracle.graal.python.test/src/tests/test_super.py +++ b/graalpython/com.oracle.graal.python.test/src/tests/test_super.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, 2024, Oracle and/or its affiliates. All rights reserved. +# Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved. # DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. # # The Universal Permissive License (UPL), Version 1.0 @@ -47,3 +47,44 @@ def m(self): return ["B"] + self.my_super.m() B.my_super = B.my_super assert B().m() == ["B", "A"] + + +def test_super_subclass_descr_get_invokes_subclass_type(): + class MySuper(super): + news = [] + calls = [] + + def __new__(cls, *args): + cls.news.append(args) + return super().__new__(cls) + + def __init__(self, *args): + type(self).calls.append(args) + super().__init__(*args) + + class A: + def f(self): + return "A.f" + + class B(A): + pass + + raw = MySuper(B) + MySuper.news.clear() + MySuper.calls.clear() + obj = B() + bound = raw.__get__(obj, B) + + assert type(bound) is MySuper + assert MySuper.news == [(B, obj)] + assert MySuper.calls == [(B, obj)] + assert bound.f() == "A.f" + + raw = MySuper.__new__(MySuper) + MySuper.news.clear() + MySuper.calls.clear() + bound = raw.__get__(obj, B) + + assert type(bound) is MySuper + assert MySuper.news == [()] + assert MySuper.calls == [()] diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/superobject/SuperBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/superobject/SuperBuiltins.java index d56938b421..9e8fd679a5 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/superobject/SuperBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/superobject/SuperBuiltins.java @@ -90,6 +90,7 @@ import com.oracle.graal.python.nodes.bytecode.FrameInfo; import com.oracle.graal.python.nodes.bytecode.PBytecodeRootNode; import com.oracle.graal.python.nodes.bytecode_dsl.PBytecodeDSLRootNode; +import com.oracle.graal.python.nodes.call.CallNode; import com.oracle.graal.python.nodes.classes.IsSubtypeNode; import com.oracle.graal.python.nodes.frame.ReadFrameNode; import com.oracle.graal.python.nodes.function.BuiltinFunctionRootNode; @@ -97,7 +98,9 @@ import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode; import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode; import com.oracle.graal.python.nodes.function.builtins.PythonVarargsBuiltinNode; +import com.oracle.graal.python.nodes.object.BuiltinClassProfiles.IsBuiltinClassExactProfile; import com.oracle.graal.python.nodes.object.GetClassNode; +import com.oracle.graal.python.nodes.object.GetClassNode.GetPythonObjectClassNode; import com.oracle.graal.python.nodes.object.IsForeignObjectNode; import com.oracle.graal.python.runtime.CallerFlags; import com.oracle.graal.python.runtime.PythonOptions; @@ -105,6 +108,7 @@ import com.oracle.graal.python.runtime.exception.PythonErrorType; import com.oracle.graal.python.runtime.object.PFactory; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff; import com.oracle.truffle.api.bytecode.BytecodeFrame; import com.oracle.truffle.api.bytecode.BytecodeNode; import com.oracle.truffle.api.dsl.Bind; @@ -459,19 +463,37 @@ private Object supercheck(VirtualFrame frame, Node inliningTarget, Object cls, O @GenerateNodeFactory public abstract static class GetNode extends DescrGetBuiltinNode { @Specialization - static Object doNoneOrBound(SuperObject self, Object obj, @SuppressWarnings("unused") Object type, + static Object doNoneOrBound(VirtualFrame frame, SuperObject self, Object obj, @SuppressWarnings("unused") Object type, @Bind Node inliningTarget, @Cached InlinedConditionProfile objIsNoneProfile, @Cached InlinedConditionProfile selfObjIsNullProfile, + @Cached IsBuiltinClassExactProfile isBuiltinSuperProfile, @Cached GetObjectNode getObject, + @Cached GetTypeNode getType, + @Cached GetPythonObjectClassNode getClass, + @Cached CallNode callNode, + @Cached InlinedConditionProfile superTypeIsNullProfile, @Cached DoGetNode doGetNode) { - // TODO: (GR-53092) doesn't seem to handle super subclasses like CPython if (objIsNoneProfile.profile(inliningTarget, PGuards.isPNone(obj)) || // selfObjIsNullProfile.profile(inliningTarget, getObject.execute(inliningTarget, self) != null)) { return self; } + Object cls = getClass.execute(inliningTarget, self); + if (!isBuiltinSuperProfile.profileClass(inliningTarget, cls, PythonBuiltinClassType.Super)) { + return doSuperSubclass(frame, inliningTarget, self, obj, cls, getType, callNode, superTypeIsNullProfile); + } return doGetNode.execute(inliningTarget, self, obj); } + + @InliningCutoff + private static Object doSuperSubclass(VirtualFrame frame, Node inliningTarget, SuperObject self, Object obj, Object cls, GetTypeNode getType, CallNode callNode, + InlinedConditionProfile superTypeIsNullProfile) { + Object superType = getType.execute(inliningTarget, self); + if (superTypeIsNullProfile.profile(inliningTarget, superType == null)) { + return callNode.execute(frame, cls); + } + return callNode.execute(frame, cls, superType, obj); + } } @GenerateInline