[PACER] Adding flexibility + removing padding idx before v2t (#4488)
* pacer class

* lint

* don't change self.context
Jing authored Apr 13, 2022
1 parent 1cb9083 commit ae2e12a
Showing 1 changed file with 25 additions and 6 deletions: projects/light_whoami/agents/pacer.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Optional, Any, Dict, List
+from parlai.agents.rag.retrievers import clean_vec
 
 from parlai.agents.transformer.transformer import TransformerGeneratorAgent
 from parlai.core.opt import Opt
@@ -34,17 +35,27 @@
 from projects.light_whoami.task.utils import extract_characters
 from projects.msc.agents.long_tga import TransformerVariantAgent
 
+from parlai.agents.reranker.reranker import AbstractReranker
+
 
 class PacerAgentMixin:
     """
     Override TGA to use a different tree search decoder.
     """
 
+    @classmethod
+    def get_partial_only_reranker_class(cls) -> AbstractReranker:
+        """
+        Return class to instantiate classifier.
+        """
+        return RPAReranker
+
     @classmethod
     def add_cmdline_args(
         cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
     ) -> ParlaiParser:
-        RPAReranker.add_cmdline_args(parser, partial_opt=partial_opt)
+        reranker_class = cls.get_partial_only_reranker_class() or AbstractReranker
+        reranker_class.add_cmdline_args(parser, partial_opt=partial_opt)
         group = parser.add_argument_group('PACER Group')
         group.add_argument(
             '--pacer-n-tokens',
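The new get_partial_only_reranker_class hook is the "flexibility" named in the commit title: a subclass can swap in its own reranker without re-implementing add_cmdline_args or __init__. A minimal sketch of such an override, assuming a hypothetical MyReranker subclass of AbstractReranker (not part of this commit):

    from parlai.agents.reranker.reranker import AbstractReranker
    from projects.light_whoami.agents.pacer import PacerAgentMixin

    class MyPacerMixin(PacerAgentMixin):
        @classmethod
        def get_partial_only_reranker_class(cls) -> AbstractReranker:
            # MyReranker is hypothetical; any AbstractReranker subclass works,
            # since the mixin uses this one hook both for command-line args
            # and for classifier construction in __init__.
            return MyReranker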
@@ -62,8 +73,9 @@ def add_cmdline_args
 
     def __init__(self, opt: Opt, shared=None):
         super().__init__(opt, shared)
-        if not shared:
-            self.classifier = RPAReranker(opt)
+        reranker_class = self.get_partial_only_reranker_class()
+        if not (shared and 'classifier' in shared):
+            self.classifier = reranker_class(opt)
         else:
             self.classifier = shared['classifier']
         assert opt[
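The widened guard in __init__ also fixes a sharing edge case: the old `if not shared` assumed any non-empty shared dict carries a classifier, so a shared dict lacking a 'classifier' key would raise a KeyError. A small hand-rolled illustration, not from the diff itself:

    shared = {'opt': opt}  # hypothetical shared dict with no 'classifier' key
    # old: `if not shared` is False -> shared['classifier'] -> KeyError
    # new: `if not (shared and 'classifier' in shared)` is True -> build fresh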
@@ -181,15 +193,22 @@ def __init__(self, *args, **kwargs):
         self.frequency = kwargs.pop('pacer_frequency_ratio')
         super().__init__(*args, **kwargs)
 
+    def get_target_character(self):
+        return extract_characters(self.context_str)['_self_name']
+
     def set_batch_context(
         self: TSType, batch_context_list: List[List[int]], batch_idx: int
     ) -> TSType:
         """
         Override to save de-tokenized version of context.
         """
+        # remove pad_idx from the batch vec
         self.context = batch_context_list[batch_idx]
-        self.context_str = self.agent._v2t(self.context)
-        self.character = extract_characters(self.context_str)['_self_name']
+        clean_context = clean_vec(
+            batch_context_list[batch_idx], self.agent.END_IDX, [self.agent.NULL_IDX]
+        )
+        self.context_str = self.agent._v2t(clean_context)
+        self.character = self.get_target_character()
         return self
 
     def select_paths(
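This is the "removing padding idx before v2t" half of the commit title: clean_vec (imported above from parlai.agents.rag.retrievers) strips the vector before de-tokenization, while self.context itself is deliberately left as the raw vec ("don't change self.context"). As a sketch of the behavior rather than the exact library code, clean_vec acts roughly like:

    import torch
    from typing import List

    def clean_vec_sketch(vec: torch.LongTensor, end_idx: int, special_toks: List[int]) -> List[int]:
        # Truncate at END_IDX and drop special tokens such as NULL_IDX (pad),
        # so _v2t never de-tokenizes padding into the context string.
        out: List[int] = []
        for tok in vec:
            if tok == end_idx:
                break          # everything after END is padding
            if tok in special_toks:
                continue       # skip pad/null tokens
            out.append(int(tok))
        return out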
@@ -257,7 +276,7 @@ def modify_logprobs(self, logprobs: torch.Tensor) -> torch.Tensor:
             torch.stack(
                 [
                     F.log_softmax(pred['sorted_scores'].float(), dim=0)[
-                        int(pred['text'] == self.character) - 1
+                        pred['text_candidates'].index(self.character)
                     ]
                     for pred in predictor_outputs
                 ]
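The indexing change in modify_logprobs fixes a correctness bug: `int(pred['text'] == self.character) - 1` can only select the first score (index 0, when the character is the top candidate) or the last one (index -1 otherwise), which is only sound when the reranker ranks exactly two candidates. Looking the character up in text_candidates handles any number of candidates. A worked example with invented candidate names and scores:

    import torch
    import torch.nn.functional as F

    # Hypothetical reranker output for one beam; values are illustrative only.
    pred = {
        'text': 'alice',                                 # top-ranked candidate
        'text_candidates': ['alice', 'bob', 'carol'],    # full ranking
        'sorted_scores': torch.tensor([2.0, 1.0, 0.5]),  # scores in ranked order
    }
    character = 'bob'

    log_probs = F.log_softmax(pred['sorted_scores'].float(), dim=0)
    old = log_probs[int(pred['text'] == character) - 1]        # index -1: carol's score (wrong)
    new = log_probs[pred['text_candidates'].index(character)]  # index 1: bob's score (right)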