-
Notifications
You must be signed in to change notification settings - Fork 0
/
session_db.py
125 lines (101 loc) · 4.66 KB
/
session_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import boto3
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError
from loguru import logger
from emma_experience_hub.api.clients.dynamo_db import DynamoDbClient
from emma_experience_hub.datamodels.simbot import SimBotSessionTurn
class SimBotSessionDbClient(DynamoDbClient):
    """Client for storing SimBot session data.

    Every turn is stored as a single item holding the session ID (`primary_key`), the turn
    index (`sort_key`) and the turn itself serialised to a JSON string (`data_key`).
    """

    primary_key = "session_id"
    sort_key = "idx"
    data_key = "turn"

    def healthcheck(self) -> bool:
        """Verify that the DB table can be accessed.

        Returns:
            True if the table can be described, False if the table cannot be found. Any
            other error raised by the client is not handled here and propagates.
        """
        dynamodb_client = boto3.client(
            "dynamodb", region_name=self._resource_region  # pyright: ignore
        )

        try:
            dynamodb_client.describe_table(TableName=self._table_name)
        except dynamodb_client.exceptions.ResourceNotFoundException:
            logger.exception("Cannot find DynamoDB table")
            return False

        return True

    def add_session_turn(self, session_turn: SimBotSessionTurn) -> None:
        """Add a session turn to the table, without overwriting an existing one.

        The write is conditional on the sort key attribute not existing yet. When that
        condition fails, the error is logged and swallowed.

        Raises:
            ClientError: For any client error other than a failed conditional check.
        """
        try:
            response = self._table.put_item(
                Item=self._build_item(session_turn),
                ConditionExpression="attribute_not_exists(#sort_key)",
                ExpressionAttributeNames={"#sort_key": self.sort_key},
            )
        except ClientError as err:
            logger.exception("Could not add turn to table.")

            error_code = err.response["Error"]["Code"]  # pyright: ignore
            # A failed condition only means the turn is already stored, so it is not fatal.
            if error_code != "ConditionalCheckFailedException":
                raise
        else:
            logger.debug(response)

    def put_session_turn(self, session_turn: SimBotSessionTurn) -> None:
        """Put a session turn to the table.

        If the turn already exists, it WILL overwrite it.

        Raises:
            ClientError: If the write fails.
        """
        try:
            self._table.put_item(Item=self._build_item(session_turn))
        except ClientError:
            logger.exception("Could not add turn to table.")
            raise

    def get_session_turn(self, session_id: str, idx: int) -> SimBotSessionTurn:
        """Get the session turn from the table.

        Raises:
            ClientError: If the request to the table fails.
            KeyError: If the response holds no item for the given session ID and index.
        """
        try:
            response = self._table.get_item(Key={self.primary_key: session_id, self.sort_key: idx})
        except ClientError:
            logger.exception("Could not get session turn from table")
            raise

        item = response.get("Item")
        if item is None:
            raise KeyError(f"No turn found for session `{session_id}` at index `{idx}`")

        raw_turn = item[self.data_key]

        # Turns are written as JSON strings, so they must be parsed as raw JSON; `parse_obj`
        # only accepts mappings and is kept as a fallback for payloads stored as a mapping.
        if isinstance(raw_turn, str):
            return SimBotSessionTurn.parse_raw(raw_turn)
        return SimBotSessionTurn.parse_obj(raw_turn)

    def get_all_session_turns(self, session_id: str) -> list[SimBotSessionTurn]:
        """Get all the turns for a given session, sorted by their index.

        Returns:
            The parsed turns in ascending `idx` order, or an empty list if any of the stored
            turns could not be parsed.

        Raises:
            ClientError: If querying the table fails.
        """
        try:
            all_raw_turns = self._get_all_session_turns(session_id)
        except ClientError:
            logger.exception("Could not query for session turns")
            raise

        with ThreadPoolExecutor() as thread_pool:
            # Best effort: a single unparsable turn discards the whole result instead of
            # crashing the caller.
            try:
                parsed_responses = list(
                    thread_pool.map(
                        SimBotSessionTurn.parse_raw,
                        (response_item[self.data_key] for response_item in all_raw_turns),
                    )
                )
            except Exception:
                logger.exception(
                    "Could not parse session turns from response. Returning an empty list."
                )
                return []

        logger.debug(f"Successfully got previous `{len(parsed_responses)}` turns")

        # Sort the responses by the sort key before returning
        return sorted(parsed_responses, key=lambda turn: turn.idx)

    def _build_item(self, session_turn: SimBotSessionTurn) -> dict[str, Any]:
        """Convert a session turn into the item that is written to the table."""
        return {
            self.primary_key: session_turn.session_id,
            self.sort_key: session_turn.idx,
            self.data_key: session_turn.json(by_alias=True),
        }

    def _get_all_session_turns(self, session_id: str) -> list[dict[str, Any]]:
        """Query every raw item for the session, following pagination until exhausted."""
        key_condition = Key(self.primary_key).eq(session_id)

        response = self._table.query(KeyConditionExpression=key_condition)
        all_response_items = response["Items"]

        # If not all the instances have been returned, get the next set
        while "LastEvaluatedKey" in response:
            response = self._table.query(
                KeyConditionExpression=key_condition,
                ExclusiveStartKey=response["LastEvaluatedKey"],
            )
            all_response_items.extend(response["Items"])

        return all_response_items  # type: ignore[unreachable]