Skip to content

Commit

Permalink
cleaner message deserialzation
Browse files Browse the repository at this point in the history
  • Loading branch information
davebshow committed Jan 23, 2018
1 parent 50ed2bd commit 0bd35d3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 21 deletions.
27 changes: 6 additions & 21 deletions aiogremlin/driver/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,15 @@ async def write(self, request_id, request_message):
await func

async def data_received(self, data, results_dict):
serializer_version = self._message_serializer.version
data = data.decode('utf-8')
message = json.loads(data)
message = self._message_serializer.deserialize_message(json.loads(data))
request_id = message['requestId']
status_code = message['status']['code']
data = message['result']['data']
msg = message['status']['message']
if request_id in results_dict:
result_set = results_dict[request_id]
if serializer_version == b"application/vnd.gremlin-v2.0+json":
aggregate_to = data['result']['meta'].get('aggregateTo', 'list')
else:
meta_aggregate_to = message['result']['meta']['@value']
if len(meta_aggregate_to) > 1:
aggregate_to = meta_aggregate_to[1]
else:
aggregate_to = 'list'
aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
result_set.aggregate_to = aggregate_to

if status_code == 407:
Expand All @@ -72,18 +64,11 @@ async def data_received(self, data, results_dict):
result_set.queue_result(None)
else:
if data:
if serializer_version == b"application/vnd.gremlin-v2.0+json":
for result in data:
result = self._message_serializer.deserialize_message(result)
message = Message(status_code, result, msg)
result_set.queue_result(message)
else:
results = self._message_serializer.deserialize_message(data['@value'])
for result in results:
message = Message(status_code, result, msg)
result_set.queue_result(message)
for result in data:
result = self._message_serializer.deserialize_message(result)
message = Message(status_code, result, msg)
result_set.queue_result(message)
else:
data = self._message_serializer.deserialize_message(data)
message = Message(status_code, data, msg)
result_set.queue_result(message)
if status_code != 206:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def test_label(self, remote_connection):
statics.load_statics(globals())
g = Graph().traversal().withRemote(remote_connection)
result = await g.V().limit(1).toList()
await remote_connection.close()

@pytest.mark.asyncio
async def test_traversals(self, remote_connection):
Expand Down

0 comments on commit 0bd35d3

Please sign in to comment.