Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gel/_internal/_qbmodel/_abstract/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gel._internal._schemapath import (
TypeNameIntersection,
)
from gel._internal import _type_expression
from gel._internal._xmethod import classonlymethod

from ._base import AbstractGelModel
Expand Down Expand Up @@ -246,6 +247,7 @@ def __edgeql_qb_expr__(cls) -> _qb.Expr: # pyright: ignore [reportIncompatibleM

class BaseGelModelIntersection(
BaseGelModel,
_type_expression.Intersection,
Generic[_T_Lhs, _T_Rhs],
):
__gel_type_class__: ClassVar[type]
Expand Down
15 changes: 15 additions & 0 deletions gel/_internal/_type_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-PackageName: gel-python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors.

import typing


class Intersection:
lhs: typing.ClassVar[type]
rhs: typing.ClassVar[type]


class Union:
lhs: typing.ClassVar[type]
rhs: typing.ClassVar[type]
5 changes: 5 additions & 0 deletions gel/_internal/_typing_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import typing

from gel._internal import _namespace
from gel._internal import _type_expression
from gel._internal import _typing_eval
from gel._internal import _typing_inspect
from gel._internal import _typing_parametric
Expand Down Expand Up @@ -66,6 +67,10 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
# subtypes of the variable bounds.
# This lets us handle cases like:
# std.array[Object] <: std.array[_T_anytype].

if issubclass(lhs, _type_expression.Intersection):
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))

if _typing_inspect.is_generic_alias(tp):
origin = typing.get_origin(tp)
args = typing.get_args(tp)
Expand Down
4 changes: 4 additions & 0 deletions tests/dbsetup/orm_qb.gel
Original file line number Diff line number Diff line change
Expand Up @@ -618,4 +618,8 @@ type Link_Inh_A {
};
};

function Read_Inh_A(x: Inh_A) -> int64 using (x.a ?? -1);
function Read_Inh_A_Overload(x: Inh_A) -> int64 using (x.a ?? -1);
function Read_Inh_A_Overload(x: Inh_AB) -> int64 using (x.ab ?? -1);

}
99 changes: 93 additions & 6 deletions tests/test_qb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,9 +1695,7 @@ def test_qb_is_type_basic_07(self):
# Link TypeIntersection
from models.orm_qb import default

result = self.client.query(
default.Link_Inh_A.l.is_(default.Inh_B)
)
result = self.client.query(default.Link_Inh_A.l.is_(default.Inh_B))

self._assertObjectsWithFields(
result,
Expand Down Expand Up @@ -1900,9 +1898,9 @@ def test_qb_is_type_for_01(self):
from models.orm_qb import default, std

result = self.client.query(
std.for_(
default.Inh_A.is_(default.Inh_B), lambda x: x
).select(a=True)
std.for_(default.Inh_A.is_(default.Inh_B), lambda x: x).select(
a=True
)
)

self._assertObjectsWithFields(
Expand Down Expand Up @@ -2014,6 +2012,95 @@ def test_qb_is_type_for_03(self):
excluded_fields={'b', 'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
)

def test_qb_is_type_as_function_arg_01(self):
# Test that type exprs produced by is_ can be passed as function args
from models.orm_qb import default, std

result = self.client.query(
std.distinct(default.Inh_A.is_(default.Inh_B)).select('*')
)

self._assertObjectsWithFields(
result,
"a",
[
(
default.Inh_AB,
{
"a": 4,
"b": 5,
},
),
(
default.Inh_ABC,
{
"a": 13,
"b": 14,
},
),
(
default.Inh_AB_AC,
{
"a": 17,
"b": 18,
},
),
],
excluded_fields={'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
)

def test_qb_is_type_as_function_arg_02(self):
# Test that complex type exprs produced by is_ can be passed as
# function args
from models.orm_qb import default, std

result = self.client.query(
std.distinct(
default.Inh_A.is_(default.Inh_B).is_(default.Inh_C)
).select('*')
)

self._assertObjectsWithFields(
result,
"a",
[
(
default.Inh_ABC,
{
"a": 13,
"b": 14,
"c": 15,
},
),
(
default.Inh_AB_AC,
{
"a": 17,
"b": 18,
"c": 19,
},
),
],
excluded_fields={'ab', 'ac', 'bc', 'abc', 'ab_ac'},
)

def test_qb_is_type_as_function_arg_03(self):
# Test that exprs produced by is_ can be passed as function args to
# user defined function
from models.orm_qb import default

# Note, we do Inh_A[is Inh_B] since is_ currently pretends its return
# type is its argument type.
result = self.client.query(
default.Read_Inh_A(default.Inh_B.is_(default.Inh_A))
)
self.assertEqual(sorted(result), [4, 13, 17])

result = self.client.query(
default.Read_Inh_A_Overload(default.Inh_B.is_(default.Inh_A))
)
self.assertEqual(sorted(result), [6, 13, 20])


class TestQueryBuilderModify(tb.ModelTestCase):
"""This test suite is for data manipulation using QB."""
Expand Down