
[RAG/FiD/BB2] Incremental Decoding (#4088)
* incremental decoding

* update bb2

* handle gold docs correctly

* comment
klshuster committed Oct 13, 2021
1 parent 0a2e683 commit ac0d5d3
Showing 6 changed files with 189 additions and 16 deletions.
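For context on the change: with incremental decoding, each decoder layer caches its attention keys and values so that every generation step only attends from the newest token. Beam search re-ranks hypotheses at each step, so those caches must be reordered to follow the surviving beams; that reordering hook is what this commit implements for RAG, FiD, and BB2. A minimal sketch of the idea (illustrative only, not ParlAI code; shapes are assumptions):

import torch

# one layer's cache, shaped (beams, heads, seq_so_far, head_dim)
cache = {'prev_key': torch.randn(4, 8, 5, 16), 'prev_value': torch.randn(4, 8, 5, 16)}
# after re-ranking, beam slot i continues the hypothesis at old index inds[i]
inds = torch.tensor([2, 2, 0, 1])
cache = {name: t.index_select(0, inds) for name, t in cache.items()}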
27 changes: 25 additions & 2 deletions parlai/agents/fid/fid.py
@@ -19,7 +19,11 @@
from parlai.agents.rag.args import RetrieverType
from parlai.agents.rag.modules import RagModel, Document, T5RagModel
from parlai.agents.rag.rag import RagAgent
from parlai.agents.rag.model_types import RagToken, get_forced_decoder_inputs
from parlai.agents.rag.model_types import (
RagToken,
get_forced_decoder_inputs,
fix_incremental_state,
)
from parlai.utils.typing import TShared


@@ -66,7 +70,7 @@ class FidModel(RagModel):

def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
super().__init__(opt, dictionary, retriever_shared=retriever_shared)
self.rag_model_interface = Fid(opt, dictionary[dictionary.null_token])
self._rag_model_interface = Fid(opt, dictionary[dictionary.null_token])
self.embedding_size = opt['embedding_size']

def reorder_encoder_states(
@@ -84,6 +88,25 @@ def reorder_encoder_states(
self, (enc, mask), indices
)

def reorder_decoder_incremental_state(
self, incremental_state: Dict[int, dict], inds: torch.Tensor
) -> Dict[int, dict]:
"""
Override RagModel.reorder_decoder_incremental_state to revert to
standard per-layer reordering.
See ``TorchGeneratorModel.reorder_decoder_incremental_state`` for a description.
"""
incremental_state = fix_incremental_state(
self.generation_model, incremental_state
)
if not incremental_state:
return incremental_state
return {
idx: layer.reorder_incremental_state(incremental_state[idx], inds)
for idx, layer in enumerate(self.seq2seq_decoder.layers)
}

def encoder(
self,
input: torch.LongTensor,
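The state being reordered here is, judging by the keys used in fix_incremental_state below, a nested dict keyed by decoder layer index, with per-attention sub-dicts of cached tensors. A rough sketch of that layout and of what each layer's reorder_incremental_state presumably does (key names beyond 'self_attn'/'prev_mask', and all shapes, are assumptions):

import torch

def make_kv():
    return torch.zeros(2, 8, 4, 16)  # assumed (bsz * beams, heads, seq, head_dim)

incremental_state = {
    layer_idx: {
        'self_attn': {'prev_key': make_kv(), 'prev_value': make_kv()},
        'encoder_attn': {'prev_key': make_kv(), 'prev_value': make_kv()},
    }
    for layer_idx in range(2)  # one entry per decoder layer
}
inds = torch.tensor([1, 0])
# per-layer reordering, as the dict comprehensions in this commit do via
# layer.reorder_incremental_state:
reordered = {
    i: {attn: {k: v.index_select(0, inds) for k, v in sub.items()}
        for attn, sub in layer_state.items()}
    for i, layer_state in incremental_state.items()
}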
116 changes: 114 additions & 2 deletions parlai/agents/rag/model_types.py
@@ -15,7 +15,7 @@
import torch.nn
import torch.nn.functional as F
import torch.cuda
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union


from parlai.core.message import Message
Expand All @@ -24,7 +24,7 @@
from parlai.core.torch_agent import Batch
from parlai.utils.torch import padded_tensor, FP16_PAD_SIZE

from parlai.agents.rag.modules import RagModel
from parlai.agents.rag.modules import RagDecoder, RagModel
from parlai.agents.rag.retrievers import Document


@@ -95,6 +95,33 @@ def get_forced_decoder_inputs(
return dec_inputs # type: ignore


def fix_incremental_state(
generation_model: str, incremental_state: Dict[int, Any]
) -> Dict[int, Any]:
"""
Fix the incremental state. Essentially, this accounts for BART's two-token forced decoder start.
:param generation_model:
generation model
:param incremental_state:
incremental decoded state
"""
if generation_model == 'bart':
for incr_state_l in incremental_state.values():
assert 'self_attn' in incr_state_l
assert 'prev_mask' in incr_state_l['self_attn']
self_attn_mask = incr_state_l['self_attn']['prev_mask']
# check whether this is the very first run with incremental state
if self_attn_mask.ndim == 3 and tuple(self_attn_mask.shape[1:]) == (2, 2):
# cut off the inappropriate incremental state
incr_state_l['self_attn']['prev_mask'] = self_attn_mask[:, -1:, :]
elif generation_model == 't5':
# No solution currently exists for T5; clear the state to disable incremental decoding.
incremental_state = {}

return incremental_state
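To see what the BART branch trims: BART's forced decoder start is two tokens, so the first incremental step leaves a prev_mask with a 2x2 time dimension, of which only the last query row is valid for subsequent steps. A toy illustration (the (bsz * heads) leading dimension is an assumption):

import torch

prev_mask = torch.ones(8, 2, 2)   # (bsz * heads, 2, 2) right after the forced start
trimmed = prev_mask[:, -1:, :]    # keep only the last query row, as above
assert tuple(trimmed.shape) == (8, 1, 2)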


class RagModelInterface(ABC):
"""
Define an interface for the RAG Model Types.
@@ -198,6 +225,21 @@ def reorder_encoder_states(
See ``TorchGeneratorModel.reorder_encoder_states`` for a description.
"""

@abstractmethod
def reorder_decoder_incremental_state(
self,
incremental_state: Dict[int, Any],
inds: Union[List[int], torch.LongTensor],
decoder: RagDecoder,
) -> Dict[int, dict]:
"""
Reorder the decoder incremental state during incremental decoding.
See ``TorchGeneratorModel.reorder_decoder_incremental_state`` for a description.
Each RagModelType will require specialized reordering, depending on the method used.
"""

###########################################
# Loss Computation/Output Marginalization #
###########################################
@@ -340,6 +382,28 @@ def reorder_encoder_states(
mask = torch.index_select(mask, 0, indices)
return enc, mask, None, None, None # type: ignore

def reorder_decoder_incremental_state(
self,
incremental_state: Dict[int, Any],
inds: Union[List[int], torch.LongTensor],
decoder: RagDecoder,
) -> Dict[int, dict]:
"""
For RAG Sequence, each doc/context pair is its own batch item,
so we can simply reorder normally.
"""
assert incremental_state is not None
incremental_state = fix_incremental_state(
self.generation_model, incremental_state
)
if not incremental_state:
return incremental_state
return {
idx: layer.reorder_incremental_state(incremental_state[idx], inds)
for idx, layer in enumerate(decoder.layers)
}
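Because RAG Sequence decodes one row per (episode, document) pair, the cache already has bsz * n_docs independent rows, and beam indices address those rows directly. A sketch of why plain reordering suffices (not ParlAI code):

import torch

bsz, n_docs = 2, 3
cache_rows = torch.arange(bsz * n_docs)   # stand-in for one cached tensor's batch dim
inds = torch.tensor([0, 0, 2, 4, 5, 5])   # surviving beams, already in row space
print(cache_rows.index_select(0, inds))   # tensor([0, 0, 2, 4, 5, 5])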

def get_ctxt_index(self, batch: Batch, batch_idx: int) -> int:
"""
Map the batch_idx back to the appropriate batch item during generation.
@@ -761,6 +825,43 @@ def reorder_encoder_states(

return enc, mask, input_turns_cnt, docs, doc_probs # type: ignore

def reorder_decoder_incremental_state(
self,
incremental_state: Dict[int, Any],
inds: Union[List[int], torch.LongTensor],
decoder: RagDecoder,
) -> Dict[int, dict]:
"""
For RAG Token, each decoder input is passed through the decoder n_docs times.
As with reordering the encoder states, we thus need to reorder along
the documents dimension.
"""
assert incremental_state is not None
incremental_state = fix_incremental_state(
self.generation_model, incremental_state
)
if not incremental_state:
return incremental_state
for incr_state_l in incremental_state.values():
for key in incr_state_l:
for sub_key in incr_state_l[key]:
incr_state_l[key][sub_key] = _unstack_ctxt(
incr_state_l[key][sub_key], self.n_docs
)

new_incr_state = {
idx: layer.reorder_incremental_state(incremental_state[idx], inds)
for idx, layer in enumerate(decoder.layers)
}

for incr_state_l in new_incr_state.values():
for key in incr_state_l:
for sub_key in incr_state_l[key]:
incr_state_l[key][sub_key] = _stack_ctxt(incr_state_l[key][sub_key])

return new_incr_state
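_unstack_ctxt and _stack_ctxt are not shown in this diff; a plausible sketch of their contract, assuming they move between a flat (bsz * n_docs, ...) layout and a per-episode (bsz, n_docs, ...) layout so that reordering happens per episode rather than per row:

import torch

def unstack_ctxt(t, n_docs):  # hypothetical stand-in for _unstack_ctxt
    return t.reshape(-1, n_docs, *t.shape[1:])

def stack_ctxt(t):  # hypothetical stand-in for _stack_ctxt
    return t.reshape(-1, *t.shape[2:])

n_docs = 2
cache = torch.arange(8.0).reshape(4, 2)       # (bsz * n_docs, dim), with bsz = 2
per_episode = unstack_ctxt(cache, n_docs)     # (2, 2, 2): one row per episode
inds = torch.tensor([1, 1])                   # both beam slots continue episode 1
reordered = stack_ctxt(per_episode.index_select(0, inds))  # flat (4, 2) again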

def rerank_beams(
self,
model: RagModel,
@@ -1053,6 +1154,17 @@ def reorder_encoder_states(

return enc, mask, input_turns_cnt, docs, doc_probs # type: ignore

def reorder_decoder_incremental_state(
self,
incremental_state: Dict[int, Any],
inds: Union[List[int], torch.LongTensor],
decoder: RagDecoder,
) -> Dict[int, dict]:
"""
Unsupported for RAG Turn; we return None to fall back to non-incremental decoding.
"""
return None # type: ignore

def rerank_beams(
self,
model: RagModel,
16 changes: 9 additions & 7 deletions parlai/agents/rag/modules.py
@@ -65,7 +65,7 @@ def __init__(self, opt, dictionary, retriever_shared=None):
)
# attrs
self.rag_model_type = opt['rag_model_type']
self.rag_model_interface = RAG_MODELS[self.rag_model_type](opt, self.pad_idx)
self._rag_model_interface = RAG_MODELS[self.rag_model_type](opt, self.pad_idx)
self.generation_model = opt['generation_model']
self.n_extra_positions = opt['n_extra_positions']
self.n_positions = get_n_positions_from_options(opt) + opt['n_extra_positions']
@@ -226,7 +226,7 @@ def decoder(
) # [bsz * beam_size, n_docs, input_len, esz]

# 3. Marginalize
marginalized = self.rag_model_interface.marginalize(
marginalized = self._rag_model_interface.marginalize(
out_probs, F.log_softmax(doc_scores, dim=1), input_turns_cnt
)
else:
@@ -257,7 +257,7 @@ def seq2seq_forward_pass(
bsz = ys.size(0)
seqlen = ys.size(1)
inputs = ys.narrow(1, 0, seqlen - 1)
dec_inputs = self.rag_model_interface.get_initial_forced_decoder_input(
dec_inputs = self._rag_model_interface.get_initial_forced_decoder_input(
bsz,
inputs,
n_docs=1,
@@ -415,17 +415,19 @@ def reorder_encoder_states(
indices = torch.LongTensor(indices).to(
encoder_states[0].device
) # type: ignore
return self.rag_model_interface.reorder_encoder_states(encoder_states, indices)
return self._rag_model_interface.reorder_encoder_states(encoder_states, indices)

def reorder_decoder_incremental_state(
self,
incremental_state: Dict[str, Any],
inds: Union[List[int], torch.LongTensor],
) -> Optional[Dict[str, Any]]:
) -> Optional[Dict[int, dict]]:
"""
Reorder the decoder incremental state, delegating to the RAG model interface.
"""
return None
return self._rag_model_interface.reorder_decoder_incremental_state(
incremental_state, inds, self.seq2seq_decoder
)

def decode_forced(
self, encoder_states: Tuple[torch.Tensor, ...], ys: torch.LongTensor
Expand Down Expand Up @@ -458,7 +460,7 @@ def decode_forced(
)
doc_scores = encoder_states[-1]

inputs = self.rag_model_interface.get_initial_forced_decoder_input(
inputs = self._rag_model_interface.get_initial_forced_decoder_input(
bsz,
inputs,
n_docs=doc_scores.size(1) if doc_scores is not None else None,
Expand Down
4 changes: 3 additions & 1 deletion projects/blenderbot2/agents/blenderbot2.py
@@ -36,6 +36,7 @@
SELECTED_DOCS,
SELECTED_DOCS_TITLES,
SELECTED_SENTENCES,
NO_SELECTED_DOCS_TOKEN,
)
from parlai.utils.torch import padded_3d

@@ -607,7 +608,8 @@ def _set_gold_doc_vec(self, observation: Message) -> Message:
:return observation:
return observation with gold doc vec.
"""
if not observation[self.opt['gold_document_key']]:
gold_docs = observation[self.opt['gold_document_key']]
if not gold_docs or gold_docs == [NO_SELECTED_DOCS_TOKEN]:
return observation
doc_vecs = None
doc_title_vecs = None
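The gold-document fix above additionally treats a list containing only NO_SELECTED_DOCS_TOKEN as "no gold docs". A sketch of that guard in isolation (the token's actual literal is defined in BB2's constants module; the value below is a placeholder):

NO_SELECTED_DOCS_TOKEN = '<no-selected-docs>'  # placeholder; real value lives in BB2 constants

def has_gold_docs(gold_docs):
    # mirrors the new condition: skip when empty or when only the sentinel is present
    return bool(gold_docs) and gold_docs != [NO_SELECTED_DOCS_TOKEN]

assert not has_gold_docs([])
assert not has_gold_docs([NO_SELECTED_DOCS_TOKEN])
assert has_gold_docs(['a selected passage'])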
37 changes: 36 additions & 1 deletion projects/blenderbot2/agents/modules.py
@@ -11,7 +11,7 @@
import torch.nn
from typing import List, Tuple, Dict, Optional

from parlai.agents.fid.fid import FidModel, T5FidModel, concat_enc_outs
from parlai.agents.fid.fid import FidModel, T5FidModel, concat_enc_outs, Fid
from parlai.agents.rag.args import RetrieverType
from parlai.agents.rag.rag import RagModel, T5RagModel
from parlai.agents.rag.dpr import DprQueryEncoder, DprDocumentEncoder
@@ -76,6 +76,8 @@ class BlenderBot2RagModel(RagModel):
"""

def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
from .blenderbot2 import RAG_MODELS

# TODO: Get rid of this hack
opt['converting'] = True
super().__init__(opt, dictionary, retriever_shared)
@@ -97,6 +99,7 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None)
self.memory_decoder = MemoryDecoder(opt)

# attrs
self._rag_model_interface = RAG_MODELS[self.rag_model_type](opt, self.pad_idx)
self.knowledge_access_method = KnowledgeAccessMethod(
opt['knowledge_access_method']
)
@@ -323,6 +326,14 @@ def retrieve_and_concat(
num_memories = num_memories.repeat_interleave(
input_turns_cnt, dim=0
) # type: ignore
if memory_decoder_vec is not None:
memory_decoder_vec = memory_decoder_vec.repeat_interleave(
input_turns_cnt, dim=0
) # type: ignore
if num_memory_decoder_vecs is not None:
num_memory_decoder_vecs = num_memory_decoder_vecs.repeat_interleave(
input_turns_cnt, dim=0
) # type: ignore
n_input = (
input_turns_cnt.sum().item()
if input_turns_cnt is not None
@@ -777,10 +788,34 @@ def retrieve_and_score(
return top_docs, torch.stack(top_doc_scores)


class BlenderBot2Fid(Fid):
"""
FiD Interface for BB2.
"""

def __init__(self, opt: Opt, null_idx: int):
super().__init__(opt, null_idx)
if (
KnowledgeAccessMethod(opt['knowledge_access_method'])
is KnowledgeAccessMethod.ALL
):
# Need to account for memories + search results
self.n_docs *= 2


class BlenderBot2FidModelMixin:
embedding_size: int
pad_idx: int

def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
super().__init__(
opt, dictionary, retriever_shared=retriever_shared
) # type: ignore
self._rag_model_interface = BlenderBot2Fid(
opt, dictionary[dictionary.null_token]
)
self.embedding_size = opt['embedding_size']

def encoder(
self,
input: torch.LongTensor,
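The retrieve_and_concat change above expands the memory-decoder tensors once per dialogue turn, matching how the other per-episode tensors are already expanded. torch's repeat_interleave with a count tensor does exactly that; a small demo (values are illustrative):

import torch

memory_decoder_vec = torch.tensor([[11, 12], [21, 22]])  # one row per episode
input_turns_cnt = torch.tensor([2, 1])                   # turns per episode
expanded = memory_decoder_vec.repeat_interleave(input_turns_cnt, dim=0)
print(expanded)  # tensor([[11, 12], [11, 12], [21, 22]])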
5 changes: 2 additions & 3 deletions tests/nightly/gpu/test_bb2.py
@@ -25,8 +25,8 @@
LOCAL = True

if TRANSFORMER_INSTALLED:
SEARCH_QUERY_MODEL = ZOO_MEMORY_DECODER
PERSONA_SUMMARY_MODEL = ZOO_QUERY_GENERATOR
SEARCH_QUERY_MODEL = ZOO_QUERY_GENERATOR
PERSONA_SUMMARY_MODEL = ZOO_MEMORY_DECODER
ZOO_BB2 = 'zoo:blenderbot2/blenderbot2_400M/model'
ZOO_BB2_3B = 'zoo:blenderbot2/blenderbot2_3B/model'
SEARCH_SERVER = '<SERVER_API>'
@@ -50,7 +50,6 @@ def _test_bb2_rag(retrieval_method: KnowledgeAccessMethod, **kwargs):
opt = copy.deepcopy(common_opt)
opt['knowledge_access_method'] = retrieval_method.value
opt.update(dict(kwargs))
print(' '.join([f'--{k} {v}' for k, v in opt.items()]))
testing_utils.eval_model(opt, skip_test=True)
torch.cuda.empty_cache()

