-
Notifications
You must be signed in to change notification settings - Fork 0
/
nlu_output.py
42 lines (31 loc) · 1.56 KB
/
nlu_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import cast
from loguru import logger
from emma_experience_hub.datamodels.simbot import (
SimBotIntent,
SimBotIntentType,
SimBotNLUIntentType,
)
from emma_experience_hub.parsers.parser import NeuralParser
class SimBotNLUOutputParser(NeuralParser[SimBotIntent[SimBotNLUIntentType]]):
    """Convert the output from the SimBot NLU module to a SimBot intent."""

    def __init__(self, intent_type_delimiter: str) -> None:
        """Store the delimiter separating the intent token(s) from the object name.

        Args:
            intent_type_delimiter: Separator between the leading intent-type token(s) and
                the optional object name in the generated text (e.g. ``" "``).
        """
        self._intent_type_delimiter = intent_type_delimiter

    def __call__(self, output_text: str) -> SimBotIntent[SimBotNLUIntentType]:
        """Parse the intent generated by the NLU component.

        The model is trained with the following templates:
            - <act><one_match>
            - <act><no_match> object_name
            - <act><too_many_matches> object_name
            - <act><missing_inventory> object_name
            - <search>

        Args:
            output_text: Raw text decoded from the NLU model.

        Returns:
            A ``SimBotIntent`` whose ``type`` is built from the text before the first
            delimiter, and whose ``entity`` is the text after it, or ``None`` when no
            non-blank object name follows the intent tokens.

        Raises:
            ValueError: If the text before the first delimiter is not a valid
                ``SimBotIntentType`` value, or if the configured delimiter is empty.
        """
        logger.debug(f"NLU output text: `{output_text}`")

        # Split only on the FIRST delimiter: everything after it is the object name and is
        # kept verbatim. Re-joining split pieces with a hard-coded " " would corrupt object
        # names whenever the delimiter is not a single space. Surrounding whitespace left
        # over from decoding is stripped so it cannot break the enum lookup below.
        intent_text, _, object_text = output_text.strip().partition(
            self._intent_type_delimiter
        )

        # Get the intent type from the left-side of the template.
        intent_type = cast(SimBotNLUIntentType, SimBotIntentType(intent_text))

        # If present, the object name is the right-side of the template. A missing or
        # blank object name is reported as ``None`` rather than an empty string.
        object_name = object_text.strip() or None

        return SimBotIntent(type=intent_type, entity=object_name)