Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
add back inference opt key
Browse files Browse the repository at this point in the history
  • Loading branch information
jxmsML committed Apr 4, 2022
1 parent b26ac8a commit dd869cd
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions parlai/agents/reranker/reranker.py
Expand Up @@ -12,7 +12,7 @@
import logging
import torch
from abc import ABC, abstractmethod, abstractclassmethod
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.agents import create_agent_from_model_file, Agent
from parlai.core.build_data import modelzoo_path
Expand Down Expand Up @@ -283,7 +283,7 @@ def _rerank_candidates(
def rerank(
self,
observation: Message,
response_cands: Union[List[str], List[Message]],
response_cands: List[str],
response_cand_scores: torch.Tensor,
) -> Tuple[List[str], List[int]]:
"""
Expand Down Expand Up @@ -467,6 +467,12 @@ def add_cmdline_args(
default=False,
help='specify to enable certain debugging procedures.',
)
gen_agent.add_argument(
'--inference-opt-key',
type=str,
default='inference',
help='specify inference opt key for dialogue response model',
)

return parser

Expand All @@ -476,8 +482,9 @@ def __init__(self, opt: Opt, shared=None):
"""
super().__init__(opt, shared)
reranker_class = self.get_reranker_class()
self.inference_opt_key = opt.get('inference_opt_key', 'inference')
self.inference_strategies = (
opt['inference_strategies'] or opt['inference']
opt['inference_strategies'] or opt[self.inference_opt_key]
).split(',')
self.debug_mode = opt.get('debug_mode', False)
if not shared:
Expand All @@ -498,11 +505,6 @@ def set_rerank_strategy(self, strategy: str):
assert strategy in RERANKER_STRATEGIES
self.reranker.reranker_strategy = strategy

def get_observations_for_reranker(
self, observations: List[Message]
) -> List[Message]:
return observations

def share(self):
"""
Share model parameters.
Expand All @@ -514,10 +516,12 @@ def share(self):
return shared

def set_decoding_method(self, strategy):
self.opt['inference'] = strategy
self.opt[self.inference_opt_key] = strategy

def get_response_cands(self, generator_response):
return [b[0] for b in generator_response['beam_texts']]
def get_observations_for_reranker(
self, observations: List[Message], batch_reply: List[Message]
) -> List[Message]:
return observations

def batch_act(self, observations: List[Message]) -> List[Message]:
"""
Expand All @@ -537,7 +541,9 @@ def batch_act(self, observations: List[Message]) -> List[Message]:
new_beam_texts = [(*b, strategy) for b in resp.get('beam_texts', [])]
batch_reply[i].force_set('beam_texts', beam_texts + new_beam_texts)
# 2. Rerank
observations_for_reranker = self.get_observations_for_reranker(observations)
observations_for_reranker = self.get_observations_for_reranker(
observations, batch_reply
)
for observation, generator_response in zip(
observations_for_reranker, batch_reply
):
Expand All @@ -551,7 +557,7 @@ def batch_act(self, observations: List[Message]) -> List[Message]:
continue
reranked_candidates, indices = self.reranker.rerank(
observation,
self.get_response_cands(generator_response), # text
[b[0] for b in generator_response['beam_texts']], # text
torch.tensor([b[1] for b in generator_response['beam_texts']]), # score
)
if self.debug_mode:
Expand Down

0 comments on commit dd869cd

Please sign in to comment.