diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index 28ef2638360..b0c5c1bf1a3 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -348,7 +348,13 @@ def __init__(self, opt: Opt, shared: TShared = None): 'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. ' f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`' ) - + if not ( + opt['dynamic_batching'] in [None, 'off'] + and opt.get('eval_dynamic_batching') in [None, 'off'] + ): + raise RuntimeError( + "For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act." + ) super().__init__(opt, shared=shared) @abstractmethod @@ -376,6 +382,15 @@ def _set_query_vec(self, observation: Message) -> Message: self.show_observation_to_echo_retriever(observation) super()._set_query_vec(observation) + def batch_act(self, observations): + """ + Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever. + """ + batch_reply = super().batch_act(observations) + if hasattr(self.model_api.retriever, 'clear_mapping'): + self.model_api.retriever.clear_mapping() + return batch_reply + class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent): """ diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index f7d472612d1..5d15059847b 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -1332,7 +1332,7 @@ def tokenize_query(self, query: str) -> List[int]: def get_delimiter(self) -> str: return self._delimiter - def _clear_mapping(self): + def clear_mapping(self): self._query_ids = dict() self._saved_docs = dict() self._largest_seen_idx = -1 @@ -1354,9 +1354,6 @@ def retrieve_and_score( query.device ) - # empty the 2 mappings after each retrieval - self._clear_mapping() - return retrieved_docs, retrieved_doc_scores @@ -1416,7 +1413,7 @@ def get_top_chunks( doc_url: str, ): """ - Return chunks according to the woi_chunk_retrieved_docs_mutator + Return chunks according to the woi_chunk_retrieved_docs_mutator. """ if isinstance(doc_chunks, list): docs = ''.join(doc_chunks)