Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

wizint knowledge prediction teacher #3996

Merged
merged 5 commits into from Sep 8, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 36 additions & 0 deletions parlai/tasks/wizard_of_internet/agents.py
Expand Up @@ -571,3 +571,39 @@ def _knowledge_piece(self):
class GoldDocTitlesTeacher(BaseKnowledgeTeacher):
def _knowledge_piece(self):
return CONST.SELECTED_DOCS_TITLES


class PredictKnowledgeGivenLabelTeacher(WizardOfInternetBaseTeacher):
jaseweston marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, opt, shared=None):
super().__init__(opt, shared=shared)
self.id = 'PredictKnowledgeGivenLabelTeacher'

def _teacher_action_type(self) -> str:
return CONST.ACTION_WIZARD_DOC_SELECTION

def _knowledge_piece(self):
return CONST.SELECTED_SENTENCES

def additional_message_content(self, parlai_message: Message, action: Dict):
for item_key in (
CONST.SELECTED_DOCS,
CONST.SELECTED_DOCS_TITLES,
CONST.SELECTED_SENTENCES,
):
parlai_message[item_key] = action[item_key]

def create_parlai_message(self, dict_message: Dict):
parlai_msg = Message(
{
CONST.SPEAKER_ID: dict_message[CONST.SPEAKER_ID],
# CONST.MESSAGE_TEXT: dict_message[CONST.MESSAGE_TEXT] + "\n _label_ " +
CONST.LABELS: [' '.join(dict_message[CONST.SELECTED_SENTENCES])],
}
)
prv_msg = dict_message.get(CONST.PARTNER_PREVIOUS_MESSAGE)
label = '\n_label_ ' + dict_message[CONST.MESSAGE_TEXT]
jaseweston marked this conversation as resolved.
Show resolved Hide resolved
if prv_msg:
jaseweston marked this conversation as resolved.
Show resolved Hide resolved
parlai_msg[CONST.MESSAGE_TEXT] = prv_msg[1][CONST.MESSAGE_TEXT] + label
else:
parlai_msg[CONST.MESSAGE_TEXT] = label
return parlai_msg