Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[RAG/FiD] Support Left Padded Inputs #4361

Merged
merged 3 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 9 additions & 2 deletions parlai/agents/fid/fid.py
Expand Up @@ -447,6 +447,7 @@ def concat_enc_outs(
mask: torch.BoolTensor,
embedding_size: int,
padding_idx: int,
right_padded: bool = True,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
"""
Concatenate Encoder Outputs.
Expand All @@ -464,6 +465,8 @@ def concat_enc_outs(
emb/hidden size of the enc representations
:param padding_idx:
pad token index; used for mask purposes.
:param right_padded:
whether the input is right padded
klshuster marked this conversation as resolved.
Show resolved Hide resolved

:return (new_out, new_mask):
return the encoder output and encoder mask, appropriately concatenated.
Expand All @@ -486,7 +489,11 @@ def concat_enc_outs(
new_mask.fill_(False)

for i, (out_i, length_i) in enumerate(zip(concat_outs, concat_lengths)):
new_out[i, :length_i] = out_i
new_mask[i, :length_i] = True
if right_padded:
new_out[i, :length_i] = out_i
new_mask[i, :length_i] = True
else:
new_out[i, new_out.size(1) - length_i :] = out_i
new_mask[i, new_out.size(1) - length_i :] = True

return new_out, new_mask
14 changes: 11 additions & 3 deletions parlai/agents/rag/modules.py
Expand Up @@ -329,6 +329,7 @@ def concat_docs_and_input(
input_lengths: torch.LongTensor,
top_docs: List[List[Document]],
max_num_docs: int,
right_padded: bool = True,
) -> torch.LongTensor:
"""
Add document tokens to input tokens.
Expand All @@ -341,6 +342,8 @@ def concat_docs_and_input(
list of n_docs top documents for each input sequence
:param max_num_docs:
maximum number of docs out of all examples
:param right_padded:
whether the input is right padded.

:return (tokens, lengths):
return expanded token vectors & corresponding lengths
Expand Down Expand Up @@ -368,7 +371,11 @@ def concat_docs_and_input(
self.expanded_input_truncate - self.min_doc_token_length,
input_i_len,
)
input_i = input_i[input_i_len - new_input_length : input_i_len]
if right_padded:
input_i = input_i[input_i_len - new_input_length : input_i_len]
else:
input_i = input_i[input_i.size(0) - new_input_length :]

doc_max_len = max(max_len - len(input_i), 0)
sample_doc_tokens = sample_doc_tokens[:doc_max_len]
expanded_input.append(
Expand All @@ -380,17 +387,18 @@ def concat_docs_and_input(
input_i_new = input_i.new(
self.n_positions - self.n_extra_positions
).fill_(self.pad_idx)
input_i_new[: input_i.size(0)] = input_i
input_i_new[input_i_new.size(0) - input_i.size(0) :] = input_i
expanded_input.append(torch.cat([input_i_new, sample_doc_tokens]))
# append extra null inputs if there are diff # of docs per input
expanded_input += [
input[i, :].new(input[i, :].size()).fill_(self.pad_idx)
] * (max_num_docs - len(docs))
expanded_input, _ = padded_tensor(
expanded_input,
fp16friendly=self.fp16,
fp16friendly=self.fp16 and right_padded,
max_len=max_len if self.n_extra_positions <= 0 else None,
pad_idx=self.pad_idx,
left_padded=not right_padded,
)
expanded_input = expanded_input.to(input.device)
return expanded_input # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions parlai/agents/rag/rag.py
Expand Up @@ -300,8 +300,7 @@ def eval_step(self, batch: Batch) -> Optional[Output]:
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
return output
assert isinstance(self.model, RagModel)
if hasattr(self.model.retriever, 'top_docs'):
if hasattr(self.model.retriever, 'top_docs'): # type: ignore
output.top_docs = self.model.retriever.top_docs # type: ignore
return output

Expand Down
128 changes: 127 additions & 1 deletion tests/nightly/gpu/test_rag.py
Expand Up @@ -18,7 +18,8 @@

try:
from parlai.agents.rag.dpr import DprQueryEncoder
from parlai.agents.rag.retrievers import RetrievedChunkRanker
from parlai.agents.rag.retrievers import RetrievedChunkRanker, Document
from parlai.agents.fid.fid import concat_enc_outs
except ImportError:
pass

Expand Down Expand Up @@ -556,5 +557,130 @@ def test_chunker(self):
)


class TestLeftPadding(unittest.TestCase):
"""
Test whether left-padding functionality works.
"""

bsz = 4
seqlen = 32
n_docs = 5
esz = 16
batch_lens = [4, 8, 16, 32]
pad_idx = 0

def _create_input_and_mask(self, right_padded=True):
enc_input = torch.LongTensor(self.bsz, self.seqlen).fill_(0)
mask = torch.BoolTensor(self.bsz, self.seqlen).fill_(False)
for i, input_len in enumerate(self.batch_lens):
if right_padded:
enc_input[i, :input_len] = torch.arange(1, input_len + 1)
mask[i, :input_len] = True
else:
enc_input[i, -input_len:] = torch.arange(1, input_len + 1)
mask[i, -input_len:] = True
return enc_input, mask

def test_concat_enc_outs(self):
enc_output = torch.rand(self.bsz * self.n_docs, self.seqlen, self.esz)
enc_input, mask = self._create_input_and_mask()
# Right padded
mask = mask.repeat_interleave(self.n_docs, dim=0)
_, new_mask = concat_enc_outs(
enc_input, enc_output, mask, self.esz, self.pad_idx
)
########################################################################
# Assertion: new mask has `True` elements in first (n_docs * seqlen_i) #
# tokens in concatenated output #
########################################################################
assert all(
new_mask[i, : self.batch_lens[i] * self.n_docs].sum()
== self.n_docs * self.batch_lens[i]
for i in range(self.bsz)
)
# Left padded
enc_input, mask = self._create_input_and_mask(right_padded=False)
mask = mask.repeat_interleave(self.n_docs, dim=0)
_, new_mask = concat_enc_outs(
enc_input, enc_output, mask, self.esz, self.pad_idx, right_padded=False
)
#######################################################################
# Assertion: new mask has `True` elements in last (n_docs * seqlen_i) #
# tokens in concatenated output #
#######################################################################
assert all(
new_mask[i, -(self.batch_lens[i] * self.n_docs) :].sum()
== self.n_docs * self.batch_lens[i]
for i in range(self.bsz)
)

def test_concat_docs_and_input(self):
rag = create_agent(Opt({**test_opt, 'n_docs': self.n_docs}))
enc_input, _ = self._create_input_and_mask()
docs = [
[Document("title", "I am a document!", i) for i in range(self.n_docs)]
for _ in range(self.bsz)
]
doc_len = len(rag.dict.txt2vec(docs[0][0].get_passage_str()))
# right padded
expanded_output = rag.model.concat_docs_and_input(
enc_input, torch.LongTensor(self.batch_lens), docs, self.n_docs
)
############################################################
# Assertion: expanded output has non-pad elements in first #
# (doc_len + seq_len_i) tokens #
############################################################
assert all(
expanded_output[i, : doc_len + self.batch_lens[i // self.n_docs]]
.eq(0)
.sum()
== 0
for i in range(self.n_docs * self.bsz)
)
#######################################################
# Assertion: expanded output has pad elements in last #
# total_len - (doc_len + seq_len_i) tokens #
#######################################################
assert all(
expanded_output[i, doc_len + self.batch_lens[i // self.n_docs] :]
.eq(0)
.sum()
== expanded_output.size(1) - (doc_len + self.batch_lens[i // self.n_docs])
for i in range(self.n_docs * self.bsz)
)

# Left padded
enc_input, _ = self._create_input_and_mask(right_padded=False)
expanded_output = rag.model.concat_docs_and_input(
enc_input,
torch.LongTensor(self.batch_lens),
docs,
self.n_docs,
right_padded=False,
)
###########################################################
# Assertion: expanded output has non-pad elements in last #
# (doc_len + seq_len_i) tokens #
###########################################################
assert all(
expanded_output[i, -(doc_len + self.batch_lens[i // self.n_docs]) :]
.eq(0)
.sum()
== 0
for i in range(self.n_docs * self.bsz)
)
########################################################
# Assertion: expanded output has pad elements in first #
# total_len - (doc_len + seq_len_i) tokens #
########################################################
assert all(
expanded_output[i, : -(doc_len + self.batch_lens[i // self.n_docs])]
.eq(0)
.sum()
== expanded_output.size(1) - (doc_len + self.batch_lens[i // self.n_docs])
for i in range(self.n_docs * self.bsz)
)


if __name__ == '__main__':
unittest.main()