Skip to content

Commit

Permalink
✨ add support for get queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Rezenders committed Apr 26, 2024
1 parent 597f7eb commit 6b33386
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 1 deletion.
27 changes: 27 additions & 0 deletions ros_typedb/ros_typedb/ros_typedb_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,29 @@ def fetch_query_result_to_ros_msg(
return response


def get_query_result_to_ros_msg(
query_result: int | float | None
) -> ros_typedb_msgs.srv.Query.Response:
"""
Convert get query result to :class:`ros_typedb_msgs.srv.Query`.
:param query_result: typedb get aggreate query result.
:return: converted query response.
"""
response = Query.Response()
for result in query_result:
_result = QueryResult()
_variables = result.variables()
for variable in _variables:
if result.get(variable).is_attribute():
_typedb_attr = result.get(variable).as_attribute()
_attr = Parameter()
_attr.name = variable
_attr.value = set_query_result_value(
_typedb_attr.get_value(),
str(_typedb_attr.get_type().get_value_type()))
_result.attributes.append(_attr)
response.results.append(_result)
return response


Expand Down Expand Up @@ -134,6 +157,8 @@ def query_result_to_ros_msg(
response = Query.Response()
if query_type == 'fetch':
response = fetch_query_result_to_ros_msg(query_result)
elif query_type == 'get':
response = get_query_result_to_ros_msg(query_result)
elif query_type == 'get_aggregate':
response = get_aggregate_query_result_to_ros_msg(query_result)
return response
Expand Down Expand Up @@ -273,6 +298,8 @@ def query_service_cb(
query_func = self.typedb_interface.delete_from_database
elif req.query_type == 'fetch':
query_func = self.typedb_interface.fetch_database
elif req.query_type == 'get':
query_func = self.typedb_interface.get_database
elif req.query_type == 'get_aggregate':
query_func = self.typedb_interface.get_aggregate_database
else:
Expand Down
26 changes: 25 additions & 1 deletion ros_typedb/ros_typedb/typedb_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def database_query(
session_type: SessionType,
transaction_type: TransactionType,
query_type: Literal[
'define', 'insert', 'delete', 'fetch', 'get_aggregate'],
'define', 'insert', 'delete', 'fetch', 'get', 'get_aggregate'],
query: str,
options: Optional[TypeDBOptions] = TypeDBOptions()
) -> Literal[True] | Iterator[ConceptMap] | \
Expand Down Expand Up @@ -212,6 +212,8 @@ def database_query(
for answer in query_answer:
answer_list.append(answer)
return answer_list
elif query_type == 'get':
return list(query_answer)
elif query_type == 'get_aggregate':
answer = query_answer.resolve()
if answer.is_long():
Expand Down Expand Up @@ -360,6 +362,28 @@ def fetch_database(
return []
return result

def get_database(self, query: str) -> int | float | None:
"""
Perform get query.
:param query: Query to be performed.
:return: Query result.
"""
result = None
try:
options = TypeDBOptions()
options.infer = self._infer
result = self.database_query(
SessionType.DATA,
TransactionType.READ,
'get',
query,
options)
except Exception as err:
print(
'Error with get query! Exception retrieved: ', err)
return result

def get_aggregate_database(self, query: str) -> int | float | None:
"""
Perform get aggregate query.
Expand Down
35 changes: 35 additions & 0 deletions ros_typedb/test/test_ros_typedb_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,41 @@ def test_ros_typedb_fetch_query_attribute(insert_query):
rclpy.shutdown()


@pytest.mark.launch(fixture=generate_test_description)
def test_ros_typedb_get_query(insert_query):
rclpy.init()
try:
node = MakeTestNode()
node.start_node()
node.activate_ros_typedb()

node.call_service(node.query_srv, insert_query)

query_req = Query.Request()
query_req.query_type = 'get'
query_req.query = """
match
$p isa person, has full-name $name, has email $email;
get $name, $email;
sort $name asc; limit 3;
"""
query_res = node.call_service(node.query_srv, query_req)

correct_name = False
correct_email = False
for r in query_res.results[0].attributes:
if r.name == 'name' and r.value.string_value == 'Ahmed Frazier':
correct_name = True
if r.name == 'email' and \
r.value.string_value == 'ahmed.frazier@gmail.com':
correct_email = True

assert query_res.success is True and \
len(query_res.results) == 3 and correct_name and correct_email
finally:
rclpy.shutdown()


@pytest.mark.launch(fixture=generate_test_description)
def test_ros_typedb_get_aggregate_query(insert_query):
rclpy.init()
Expand Down
12 changes: 12 additions & 0 deletions ros_typedb/test/test_typedb_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,15 @@ def test_update_attributes(typedb_interface, insert_dict, update_dict, r_dict):
match_result = typedb_interface.get_aggregate_database(
"match " + query + "get; count;")
assert r is not None and r is not False and match_result > 0


def test_get_query(typedb_interface):
query = """
match
$p isa person, has full-name $name, has email $email;
get $name, $email;
sort $name asc; limit 3;
"""
result = typedb_interface.get_database(query)
name = result[0].get("name").as_attribute().get_value()
assert len(result) == 3 and name == "Ahmed Frazier"

0 comments on commit 6b33386

Please sign in to comment.