Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[GoldDocFidAgent] clear mapping at the end of each eval_step rather t…
Browse files Browse the repository at this point in the history
…han retrieve step (#4503)

* clear mapping at the end

* apply to train step as well

* dyn

* comment

* black
  • Loading branch information
jxmsML committed Apr 20, 2022
1 parent ae2e12a commit 71c4d44
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
17 changes: 16 additions & 1 deletion parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,13 @@ def __init__(self, opt: Opt, shared: TShared = None):
'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. '
f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`'
)

if not (
opt['dynamic_batching'] in [None, 'off']
and opt.get('eval_dynamic_batching') in [None, 'off']
):
raise RuntimeError(
"For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act."
)
super().__init__(opt, shared=shared)

@abstractmethod
Expand Down Expand Up @@ -376,6 +382,15 @@ def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)

def batch_act(self, observations):
    """
    Run the parent batch act, then reset the retriever's per-batch state.

    The reset happens only after the whole batch has been processed, and only
    when the retriever exposes a ``clear_mapping`` method (as
    ObservationEchoRetriever does for its _saved_docs and _query_ids
    mappings); any other retriever is left untouched.

    :param observations:
        batch of observations, forwarded unchanged to the parent class.

    :return:
        whatever the parent ``batch_act`` returned.
    """
    replies = super().batch_act(observations)
    retriever = self.model_api.retriever
    # Not every retriever keeps per-batch mappings, so probe before clearing.
    if hasattr(retriever, 'clear_mapping'):
        retriever.clear_mapping()
    return replies


class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
"""
Expand Down
7 changes: 2 additions & 5 deletions parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def tokenize_query(self, query: str) -> List[int]:
def get_delimiter(self) -> str:
    """
    Return the delimiter string stored on this instance.
    """
    return self._delimiter

def _clear_mapping(self):
def clear_mapping(self):
self._query_ids = dict()
self._saved_docs = dict()
self._largest_seen_idx = -1
Expand All @@ -1354,9 +1354,6 @@ def retrieve_and_score(
query.device
)

# empty the 2 mappings after each retrieval
self._clear_mapping()

return retrieved_docs, retrieved_doc_scores


Expand Down Expand Up @@ -1416,7 +1413,7 @@ def get_top_chunks(
doc_url: str,
):
"""
Return chunks according to the woi_chunk_retrieved_docs_mutator
Return chunks according to the woi_chunk_retrieved_docs_mutator.
"""
if isinstance(doc_chunks, list):
docs = ''.join(doc_chunks)
Expand Down

0 comments on commit 71c4d44

Please sign in to comment.