-
Notifications
You must be signed in to change notification settings - Fork 0
/
feedback.py
390 lines (322 loc) · 15.7 KB
/
feedback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import itertools
import string
from collections import Counter
from typing import Any, Optional
import orjson
from pydantic import BaseModel, Field, validator
from rule_engine import Context, Rule
from rule_engine.ast import ExpressionBase, StringExpression, SymbolExpression
from emma_experience_hub.datamodels.simbot.actions import SimBotAction
from emma_experience_hub.datamodels.simbot.agent_memory import get_area_from_compressed_mask
from emma_experience_hub.datamodels.simbot.enums import (
SimBotActionType,
SimBotAnyUserIntentType,
SimBotEnvironmentIntentType,
SimBotIntentType,
SimBotPhysicalInteractionIntentType,
SimBotVerbalInteractionIntentType,
)
from emma_experience_hub.datamodels.simbot.intents import SimBotIntent
from emma_experience_hub.datamodels.simbot.payloads import SimBotObjectInteractionPayload
def get_score_for_rule_expression(
    expression: ExpressionBase, left_attr_name: str = "left", right_attr_name: str = "right"
) -> int:
    """Count the conditions contained in a rule expression tree.

    rule_engine parses a rule's conditions into a binary tree of expressions, which this function
    walks recursively. A node whose left child is a `SymbolExpression` is treated as a single
    condition on that slot and scores 1. A node with both children present scores the sum of its
    children. Any other node scores 0. More specific rules (more conditions) therefore receive
    higher scores.

    Args:
        expression: Root of the (sub)tree to score.
        left_attr_name: Name of the attribute holding a node's left child.
        right_attr_name: Name of the attribute holding a node's right child.

    Returns:
        The number of slot conditions found in the tree.
    """
    left_child: Optional[ExpressionBase] = getattr(expression, left_attr_name, None)
    right_child: Optional[ExpressionBase] = getattr(expression, right_attr_name, None)

    # Reaching a symbol on the left means this node is a single condition on a slot
    if isinstance(left_child, SymbolExpression):
        return 1

    # Leaves and one-sided nodes contain no further conditions
    if left_child is None or right_child is None:
        return 0

    # Left subtree is scored before the right one, matching a left-to-right walk
    return sum(
        get_score_for_rule_expression(child, left_attr_name, right_attr_name)
        for child in (left_child, right_child)
    )
def should_rule_be_mandatory(expression: ExpressionBase) -> bool:  # noqa: WPS231
    """Decide whether a rule must always stay in the candidate pool.

    Mandatory rules carry responses that should be eligible no matter what, typically because
    they tell the user something important (e.g. confirm_before_plan, ask_about_the_game).

    The expression tree is searched recursively for a condition on the
    `verbal_interaction_intent_type` slot. The rule is mandatory if that condition's string value
    contains "confirm_before_plan" or "ask_about".

    NOTE(review): the comparison operator is not inspected, so a `!=` condition on one of these
    values also marks the rule as mandatory.

    Args:
        expression: Root of the (sub)tree to inspect.

    Returns:
        True if any matching condition is found, otherwise False.
    """
    left_child: Optional[ExpressionBase] = getattr(expression, "left", None)
    right_child: Optional[ExpressionBase] = getattr(expression, "right", None)

    # A symbol on the left means this node is a single slot condition; nothing below it
    if isinstance(left_child, SymbolExpression):
        if left_child.name != "verbal_interaction_intent_type" or not isinstance(
            right_child, StringExpression
        ):
            return False
        slot_value = right_child.value
        return "confirm_before_plan" in slot_value or "ask_about" in slot_value

    # Without both children there is nothing left to search
    if left_child is None or right_child is None:
        return False

    # The right subtree is only searched when the left one is not mandatory
    return should_rule_be_mandatory(left_child) or should_rule_be_mandatory(right_child)
class SimBotFeedbackRule(BaseModel):
    """Rule for response generation.

    A rule pairs a rule_engine condition with a response template. It is suitable for a query
    when the condition matches the query and every slot in the template is present in the query
    (see `is_query_suitable`).
    """

    id: int = Field(..., description="Unique rule id")
    rule: Rule = Field(..., description="Logical expression of the rule")
    response: str = Field(
        ...,
        description="Response template that can include slots. Slot values are derived from the `SimBotFeedbackState`",
    )
    is_lightweight_dialog: bool = Field(
        ..., description="Should the response be a lightweight dialog action"
    )
    # Scores of 0 (the default) or 1 are recomputed from the rule expression by
    # `calculate_rule_score`; larger values are kept as given
    score: int = Field(default=0, description="Determined by the number of conditions in the rule")
    is_mandatory: bool = Field(
        default=False,
        description="Mandatory rules are always included in the candidate pool even if they have already been used",
    )

    class Config:
        """Updated config."""

        # `Rule` is not a pydantic type, so arbitrary types must be allowed for the `rule` field
        arbitrary_types_allowed = True

    @property
    def slot_names(self) -> list[str]:
        """Get the names of the `{slot}` placeholders in the response template, in order.

        Anonymous placeholders (`{}`) are skipped because they have an empty field name.
        """
        slot_names: list[str] = [
            field_name
            for _, field_name, _, _ in string.Formatter().parse(self.response)
            if field_name
        ]
        return slot_names

    def prepare_response(self, slots: Optional[dict[str, str]] = None) -> str:
        """Build the response for the given rule.

        Args:
            slots: Mapping from slot name to the value used to fill the response template.

        Returns:
            The template filled with `slots`, or the raw response if the template has no slots.

        Raises:
            AssertionError: If the template has slots but no slot values were provided.
            KeyError: If a named slot in the template is missing from `slots`.
        """
        # Parse the template once instead of once per check
        slot_names = self.slot_names
        if not slot_names:
            return self.response

        if not slots:
            raise AssertionError(
                "We should be providing slot-value pairs for the response template."
            )
        return self.response.format(**slots)

    @classmethod
    def from_raw(cls, raw_dict: dict[str, str]) -> "SimBotFeedbackRule":
        """Parse a dictionary into a SimBotFeedbackRule.

        Args:
            raw_dict: Raw rule record with the keys "id", "conditions", "response" and
                "is_lightweight".

        Returns:
            The parsed rule.

        Raises:
            AssertionError: If the rule conditions are not a valid rule_engine expression.
        """
        # If the rule is unable to resolve symbols, it defaults to None and returns False
        engine_context = Context(default_value=None)
        # Conditions are lowercased; `SimBotFeedbackState.to_query` lowercases the query too
        rule = Rule(raw_dict["conditions"].lower(), context=engine_context)
        rule_id = int(raw_dict["id"])

        if not rule.is_valid(rule.text):
            raise AssertionError(f"Invalid rule: ID {rule_id} - {rule.text}")

        is_mandatory = should_rule_be_mandatory(rule.statement.expression)

        # Accept any casing/padding of "true" (e.g. "TRUE" from spreadsheet exports); a strict
        # `== "True"` silently treated those values as False
        is_lightweight_dialog = str(raw_dict["is_lightweight"]).strip().lower() == "true"

        return cls(
            id=rule_id,
            rule=rule,
            response=raw_dict["response"],
            is_lightweight_dialog=is_lightweight_dialog,
            score=len(rule.context.symbols),
            is_mandatory=is_mandatory,
        )

    @validator("score", always=True)
    @classmethod
    def calculate_rule_score(cls, score: int, values: dict[str, Any]) -> int:  # noqa: WPS110
        """Calculate the score for the rule.

        A score greater than 1 is trusted as given. Otherwise (the default 0, or 1) the score is
        recomputed as the number of conditions in the rule expression.

        Raises:
            AssertionError: If the rule is missing or the recomputed score is below 1.
        """
        # Scores above 1 are kept as-is; 0 and 1 are recomputed from the expression tree
        if score > 1:
            return score

        # Get the rule and make sure it exists
        rule: Optional[Rule] = values.get("rule")
        if not rule:
            raise AssertionError("There should be a rule for this model?")

        # The score for a rule is the number of different criterion within it
        score = get_score_for_rule_expression(rule.statement.expression)

        if score < 1:
            raise AssertionError("Score should not be less than 1.")

        return score

    def is_query_suitable(self, query: dict[str, Any]) -> bool:
        """Evaluate the rule given the query and ensure it is suitable.

        The rule is suitable when its condition matches the query and every template slot is a
        key of the query. Any exception raised during evaluation is deliberately treated as a
        non-match rather than propagated.
        """
        try:
            return self.rule.matches(query) and all(name in query for name in self.slot_names)
        except Exception:
            return False
def turn_requires_lightweight_dialog(
    verbal_interaction_intent: Optional[SimBotIntent[SimBotVerbalInteractionIntentType]],
    utterance_queue_not_empty: bool,
    find_queue_not_empty: bool,
    interaction_action: Optional[SimBotAction],
) -> bool:
    """Does this turn require a lightweight dialog?

    A lightweight dialog is required when any of the following holds:
        - there is an interaction action that does not end the trajectory;
        - the utterance queue still has instructions, unless the verbal interaction intent
          triggers a question to the user (in which case the pending queue is ignored);
        - the find routine still has queued instructions.
    """
    # Evaluated up front so the intent is always inspected, whatever the other inputs are
    asks_user_question = (
        verbal_interaction_intent is not None
        and verbal_interaction_intent.type.triggers_question_to_user
    )

    if interaction_action and not interaction_action.is_end_of_trajectory:
        return True

    if utterance_queue_not_empty and not asks_user_question:
        return True

    return find_queue_not_empty
class SimBotFeedbackState(BaseModel):
    """Flattened representation of the session state for feedback generation.

    Instances are usually built with `from_all_information` and turned into the dictionary that
    feedback rules are evaluated against with `to_query`.
    """

    # Whether this turn should be answered with a lightweight dialog action
    require_lightweight_dialog: bool = False

    # Session statistics
    num_turns: int

    # Room the agent is currently in
    current_room: str

    # Inventory state as provided by the caller
    inventory_entity: Optional[str] = None
    inventory_turn: int

    # Number of turns spent in each room
    visited_room_counter: Counter[str]

    # User intent
    user_intent_type: Optional[SimBotAnyUserIntentType] = None

    # Environment intent
    environment_intent_type: Optional[SimBotEnvironmentIntentType] = None
    environment_intent_action_type: Optional[SimBotActionType] = None
    environment_intent_entity: Optional[str] = None

    # Physical interaction intent
    physical_interaction_intent_type: Optional[SimBotPhysicalInteractionIntentType] = None
    physical_interaction_intent_entity: Optional[str] = None

    # Verbal interaction intent
    verbal_interaction_intent_type: Optional[SimBotVerbalInteractionIntentType] = None
    verbal_interaction_intent_entity: Optional[str] = None

    # Interaction action of the current turn
    interaction_action_type: Optional[SimBotActionType] = None
    interaction_action_entity: Optional[str] = None

    # Interaction actions taken so far in the session
    interaction_action_per_turn: list[SimBotAction]

    # How often each (non-None) entity name appears across the session's interaction actions
    interacted_entities_counter: Counter[str]

    # How often each action type name appears across the session's interaction actions
    action_type_counter: Counter[str]

    # Intent types recorded for each turn of the session
    intent_types_per_turn: list[list[SimBotIntentType]]

    # How often each intent type name appears across all turns
    intent_type_counter: Counter[str]

    # There are more instructions to execute from the latest user utterance
    utterance_queue_not_empty: bool

    # There are more instructions to execute from the find routine (now and on the previous turn)
    find_queue_not_empty: bool = False
    previous_find_queue_not_empty: bool = False

    # IDs of rules that have already been used
    used_rule_ids: list[int] = Field(default_factory=list)

    # Agent responses since the last user utterance, joined into one string.
    # This lets rules check that words are not repeated across responses, which gives
    # more control and allows for more natural responses.
    agent_responses_since_last_user_utterance: str = ""

    current_turn_has_user_utterance: bool = False

    # Area of the interacted object's mask; only set for object-interaction payloads
    object_area: Optional[float] = None

    class Config:
        """Config for the model."""

        json_encoders = {
            # Serialise action types by their name in the JSON output
            SimBotActionType: lambda action_type: action_type.name,
            # Serialise intent types by their name in the JSON output
            SimBotIntentType: lambda intent_type: intent_type.name,
        }

    @classmethod
    def from_all_information(
        cls,
        num_turns: int,
        current_room: str,
        user_intent_type: Optional[SimBotAnyUserIntentType],
        environment_intent: Optional[SimBotIntent[SimBotEnvironmentIntentType]],
        physical_interaction_intent: Optional[SimBotIntent[SimBotPhysicalInteractionIntentType]],
        verbal_interaction_intent: Optional[SimBotIntent[SimBotVerbalInteractionIntentType]],
        interaction_action: Optional[SimBotAction],
        current_room_per_turn: list[str],
        interaction_action_per_turn: list[SimBotAction],
        intent_types_per_turn: list[list[SimBotIntentType]],
        utterance_queue_not_empty: bool,
        find_queue_not_empty: bool,
        previous_find_queue_not_empty: bool,
        used_rule_ids: list[int],
        inventory_turn: int,
        inventory_entity: Optional[str],
        agent_responses_since_last_user_utterance: list[str],
        current_turn_has_user_utterance: bool,
    ) -> "SimBotFeedbackState":
        """Create the state from everything known about the session and the current turn.

        Missing intents and actions become None-valued fields. Per-turn histories are also
        summarised into counters. The agent responses since the last user utterance are joined
        into a single space-separated string. Every field is passed explicitly, even when None.
        """
        # Decide whether this turn should be answered with a lightweight dialog action
        require_lightweight_dialog = turn_requires_lightweight_dialog(
            interaction_action=interaction_action,
            utterance_queue_not_empty=utterance_queue_not_empty,
            find_queue_not_empty=find_queue_not_empty,
            verbal_interaction_intent=verbal_interaction_intent,
        )

        # Only object-interaction payloads carry an object mask whose area can be measured
        object_area = None
        if interaction_action is not None:
            action_payload = interaction_action.payload
            if isinstance(action_payload, SimBotObjectInteractionPayload):
                object_area = get_area_from_compressed_mask(action_payload.object.mask)

        # Flatten each optional intent / action into its individual fields
        environment_type = environment_intent.type if environment_intent else None
        environment_action_type = environment_intent.action if environment_intent else None
        environment_entity = environment_intent.entity if environment_intent else None

        physical_type = physical_interaction_intent.type if physical_interaction_intent else None
        physical_entity = (
            physical_interaction_intent.entity if physical_interaction_intent else None
        )

        verbal_type = verbal_interaction_intent.type if verbal_interaction_intent else None
        verbal_entity = verbal_interaction_intent.entity if verbal_interaction_intent else None

        action_type = interaction_action.type if interaction_action else None
        action_entity = interaction_action.payload.entity_name if interaction_action else None

        # Summarise the per-turn histories
        entity_names = [
            past_action.payload.entity_name
            for past_action in interaction_action_per_turn
            if past_action.payload.entity_name is not None
        ]
        action_type_names = [past_action.type.name for past_action in interaction_action_per_turn]
        intent_names = [
            intent_type.name
            for intent_type in itertools.chain.from_iterable(intent_types_per_turn)
        ]
        joined_agent_responses = " ".join(agent_responses_since_last_user_utterance)

        return cls(
            require_lightweight_dialog=require_lightweight_dialog,
            num_turns=num_turns,
            current_room=current_room,
            user_intent_type=user_intent_type,
            environment_intent_type=environment_type,
            environment_intent_action_type=environment_action_type,
            environment_intent_entity=environment_entity,
            physical_interaction_intent_type=physical_type,
            physical_interaction_intent_entity=physical_entity,
            verbal_interaction_intent_type=verbal_type,
            verbal_interaction_intent_entity=verbal_entity,
            interaction_action_type=action_type,
            interaction_action_entity=action_entity,
            visited_room_counter=Counter(current_room_per_turn),
            interaction_action_per_turn=interaction_action_per_turn,
            interacted_entities_counter=Counter(entity_names),
            action_type_counter=Counter(action_type_names),
            intent_types_per_turn=intent_types_per_turn,
            intent_type_counter=Counter(intent_names),
            utterance_queue_not_empty=utterance_queue_not_empty,
            find_queue_not_empty=find_queue_not_empty,
            previous_find_queue_not_empty=previous_find_queue_not_empty,
            used_rule_ids=used_rule_ids,
            inventory_turn=inventory_turn,
            inventory_entity=inventory_entity,
            agent_responses_since_last_user_utterance=joined_agent_responses,
            current_turn_has_user_utterance=current_turn_has_user_utterance,
            object_area=object_area,
        )

    def to_query(self) -> dict[str, Any]:
        """Convert the state to a dictionary for the feedback engine.

        Unset, default-valued and None fields are dropped. The whole JSON text (keys and values)
        is lowercased before parsing, so it lines up with the lowercased rule conditions.
        """
        serialized_state = self.json(
            exclude_unset=True, exclude_defaults=True, exclude_none=True
        )
        return orjson.loads(serialized_state.lower())