Skip to content

Commit

Permalink
Improved subquery support
Browse files Browse the repository at this point in the history
  • Loading branch information
huntfx committed May 21, 2022
1 parent 257cfeb commit 8691706
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 36 deletions.
32 changes: 17 additions & 15 deletions ftrack_query/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@


class AbstractQuery(object):
"""Class to use for inheritance checks."""
"""Base class to use mainly for inheritance checks."""

def __init__(self):
self._where = []


class AbstractComparison(object):
Expand Down Expand Up @@ -94,10 +97,7 @@ def parser(cls, *args, **kwargs):
args:
Query: An unexecuted query object.
This is not recommended, but an attempt will be made
to execute it for a single result.
It will raise an exception if multiple or none are
found.
This will be added as a subquery if supported.
dict: Like kargs, but with relationships allowed.
A relationship like "parent.name" is not compatible
Expand All @@ -110,24 +110,23 @@ def parser(cls, *args, **kwargs):
Anything else passed in will get converted to strings.
The comparison class has been designed to evaluate when
__str__ is called, but any custom class could be used.
to_str() is called, but any custom class could be used.
kwargs:
Search for attributes of an entity.
This is the recommended way to query if possible.
`(x=y)` is the equivelant of `(entity.x == y)`.
"""
for arg in args:
# The query has not been performed, attempt to execute
# This shouldn't really be used, so don't catch any errors
if isinstance(arg, AbstractQuery):
arg = arg.one()
for item in arg._where:
yield item

if isinstance(arg, dict):
elif isinstance(arg, dict):
for key, value in arg.items():
yield cls(key)==value
yield cls(key) == value

elif isinstance(arg, ftrack_api.entity.base.Entity):
raise TypeError("keyword required for {}".format(arg))
raise TypeError('keyword required for {}'.format(arg))

# The object is likely a comparison object, so convert to str
# If an actual string is input, then assume it's valid syntax
Expand All @@ -136,8 +135,11 @@ def parser(cls, *args, **kwargs):

for key, value in kwargs.items():
if isinstance(value, AbstractQuery):
value = value.one()
yield cls(key)==value
for item in value._where:
yield item

else:
yield cls(key) == value


class AbstractStatement(object):
Expand Down
3 changes: 1 addition & 2 deletions ftrack_query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ class Query(AbstractQuery):
)

def __init__(self, session, entity):
super(Query, self).__init__()
self._session = session
self._entity = entity
self._where = []
self._populate = []
self._sort = []
self._offset = 0
Expand Down Expand Up @@ -294,7 +294,6 @@ def __call__(self, *args, **kwargs):
return result
raise


raise TypeError("'Query' object is not callable, "
"perhaps you meant to use 'Query.where()'?")

Expand Down
3 changes: 1 addition & 2 deletions ftrack_query/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def parse_operators(func):
"""Parse the value when an operator is used."""
@wraps(func)
def wrapper(self, value):
# If the item is constructed query, assume it's a single object
if isinstance(value, AbstractQuery):
value = value.one()
raise NotImplementedError('query comparisons are not supported')

# If the item is an FTrack entity, use the ID
if isinstance(value, ftrack_api.entity.base.Entity):
Expand Down
63 changes: 46 additions & 17 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_any(self):

def test_sort(self):
self.assertEqual(str(entity.a.desc()), 'a descending')
self.assertNotEqual(str(entity.a.desc), 'a descending')
self.assertEqual(str(entity.a.desc), 'a.desc')
self.assertEqual(str(entity.a.b.asc()), 'a.b ascending')

def test_call(self):
Expand Down Expand Up @@ -138,30 +138,59 @@ def test_id_remap_in(self):
'Project where project_schema.id in ("{}")'.format(schema['id'])
)

def test_query_remap(self):
def test_id_remap_in_multiple(self):
schema = self.session.ProjectSchema.first()
query = self.session.ProjectSchema.where(id=schema['id'])
self.assertEqual(
str(self.session.Project.where(entity.project_schema == query)),
'Project where project_schema.id is "{}"'.format(schema['id']),
)
self.assertEqual(
str(self.session.Project.where(project_schema=query)),
'Project where project_schema.id is "{}"'.format(schema['id']),
str(self.session.Project.where(entity.project_schema.in_(schema, schema))),
'Project where project_schema.id in ("{s}", "{s}")'.format(s=schema['id'])
)

def test_subquery_in(self):
schema = self.session.ProjectSchema.first()
query = self.session.ProjectSchema.where(id=schema['id'])

class TestQueryComparison(unittest.TestCase):
def setUp(self):
self.session = FTrackQuery(debug=True)

def test_in(self):
query = self.session.ProjectSchema.where(name='My Schema')
self.assertEqual(
str(self.session.Project.where(entity.project_schema.in_(query))),
'Project where project_schema.id in (select id from ProjectSchema where id is "{}")'.format(schema['id']),
'Project where project_schema.id in (select id from ProjectSchema where name is "My Schema")',
)
with self.assertRaises(ValueError):
self.assertEqual(
str(self.session.Project.where(entity.project_schema.in_(query, query))),
'Project where project_schema.id in ("{id}", "{id}")'.format(id=schema['id']),
)
str(self.session.Project.where(entity.project_schema.in_(query, query)))

def test_has_simple(self):
query = self.session.ProjectSchema.where(name='My Schema')
self.assertEqual(
str(self.session.Project.where(entity.project_schema.has(query))),
'Project where project_schema has (name is "My Schema")',
)

def test_has_complex(self):
query = self.session.ProjectSchema.where(~entity.project.has(name='Invalid Project'), name='My Schema')
self.assertEqual(
str(self.session.Project.where(entity.project_schema.has(query))),
'Project where project_schema has (not project has (name is "Invalid Project") and name is "My Schema")',
)

def test_has_multiple(self):
query1 = self.session.ProjectSchema.where(~entity.project.has(name='Invalid Project'))
query2 = self.session.ProjectSchema.where(name='My Schema')
self.assertEqual(
str(self.session.Project.where(entity.project_schema.has(query1, query2))),
'Project where project_schema has (not project has (name is "Invalid Project") and name is "My Schema")',
)
self.assertEqual(
str(self.session.Project.where(entity.project_schema.any(query1, query2))),
'Project where project_schema any (not project has (name is "Invalid Project") and name is "My Schema")',
)

def test_equals(self):
with self.assertRaises(NotImplementedError):
entity.value == self.session.ProjectSchema.where(name='My Schema')
with self.assertRaises(NotImplementedError):
entity.value != self.session.ProjectSchema.where(name='My Schema')



if __name__ == '__main__':
Expand Down

0 comments on commit 8691706

Please sign in to comment.