Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[Reranker] Set decoding method (#4473)
Browse files Browse the repository at this point in the history
* set decoding strategy

* cands to include more

* type

* add back inference opt key
  • Loading branch information
jxmsML committed Apr 5, 2022
1 parent 732c112 commit 2e7f917
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions parlai/agents/reranker/reranker.py
Expand Up @@ -62,7 +62,7 @@ def add_cmdline_args(cls, parser: ParlaiParser, partial_opt: Optional[Opt] = Non
'--reranker-delimiter',
type=str,
default=None,
help='delimiter for the retriever',
help='delimiter for the reranker',
)
return parser

Expand Down Expand Up @@ -505,11 +505,6 @@ def set_rerank_strategy(self, strategy: str):
assert strategy in RERANKER_STRATEGIES
self.reranker.reranker_strategy = strategy

def get_observations_for_reranker(
self, observations: List[Message]
) -> List[Message]:
return observations

def share(self):
"""
Share model parameters.
Expand All @@ -520,6 +515,14 @@ def share(self):
shared['reranker'] = self.reranker.share()
return shared

def set_decoding_method(self, strategy):
self.opt[self.inference_opt_key] = strategy

def get_observations_for_reranker(
self, observations: List[Message], batch_reply: List[Message]
) -> List[Message]:
return observations

def batch_act(self, observations: List[Message]) -> List[Message]:
"""
Batch process a list of observations.
Expand All @@ -530,15 +533,17 @@ def batch_act(self, observations: List[Message]) -> List[Message]:
batch_reply = [Message() for _ in range(len(observations))]
# 1. get all beam texts to consider
for strategy in self.inference_strategies:
self.opt[self.inference_opt_key] = strategy
self.set_decoding_method(strategy)
inference_batch_reply = super().batch_act(observations)
for i, resp in enumerate(inference_batch_reply):
beam_texts = batch_reply[i].get('beam_texts', [])
batch_reply[i] = resp # add metrics, other response items
new_beam_texts = [(*b, strategy) for b in resp.get('beam_texts', [])]
batch_reply[i].force_set('beam_texts', beam_texts + new_beam_texts)
# 2. Rerank
observations_for_reranker = self.get_observations_for_reranker(observations)
observations_for_reranker = self.get_observations_for_reranker(
observations, batch_reply
)
for observation, generator_response in zip(
observations_for_reranker, batch_reply
):
Expand Down

0 comments on commit 2e7f917

Please sign in to comment.