Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[RAG/FiD] Support Left Padded Inputs #4361

Merged
merged 3 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 9 additions & 2 deletions parlai/agents/fid/fid.py
Expand Up @@ -447,6 +447,7 @@ def concat_enc_outs(
mask: torch.BoolTensor,
embedding_size: int,
padding_idx: int,
right_padded: bool = True,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
"""
Concatenate Encoder Outputs.
Expand All @@ -464,6 +465,8 @@ def concat_enc_outs(
emb/hidden size of the enc representations
:param padding_idx:
pad token index; used for mask purposes.
:param right_padded:
whether the input is right padded (true) or left padded (false)

:return (new_out, new_mask):
return the encoder output and encoder mask, appropriately concatenated.
Expand All @@ -486,7 +489,11 @@ def concat_enc_outs(
new_mask.fill_(False)

for i, (out_i, length_i) in enumerate(zip(concat_outs, concat_lengths)):
new_out[i, :length_i] = out_i
new_mask[i, :length_i] = True
if right_padded:
new_out[i, :length_i] = out_i
new_mask[i, :length_i] = True
else:
new_out[i, new_out.size(1) - length_i :] = out_i
new_mask[i, new_out.size(1) - length_i :] = True

return new_out, new_mask
14 changes: 11 additions & 3 deletions parlai/agents/rag/modules.py
Expand Up @@ -329,6 +329,7 @@ def concat_docs_and_input(
input_lengths: torch.LongTensor,
top_docs: List[List[Document]],
max_num_docs: int,
right_padded: bool = True,
) -> torch.LongTensor:
"""
Add document tokens to input tokens.
Expand All @@ -341,6 +342,8 @@ def concat_docs_and_input(
list of n_docs top documents for each input sequence
:param max_num_docs:
maximum number of docs out of all examples
:param right_padded:
whether the input is right padded.

:return (tokens, lengths):
return expanded token vectors & corresponding lengths
Expand Down Expand Up @@ -368,7 +371,11 @@ def concat_docs_and_input(
self.expanded_input_truncate - self.min_doc_token_length,
input_i_len,
)
input_i = input_i[input_i_len - new_input_length : input_i_len]
if right_padded:
input_i = input_i[input_i_len - new_input_length : input_i_len]
else:
input_i = input_i[input_i.size(0) - new_input_length :]

doc_max_len = max(max_len - len(input_i), 0)
sample_doc_tokens = sample_doc_tokens[:doc_max_len]
expanded_input.append(
Expand All @@ -380,17 +387,18 @@ def concat_docs_and_input(
input_i_new = input_i.new(
self.n_positions - self.n_extra_positions
).fill_(self.pad_idx)
input_i_new[: input_i.size(0)] = input_i
input_i_new[input_i_new.size(0) - input_i.size(0) :] = input_i
expanded_input.append(torch.cat([input_i_new, sample_doc_tokens]))
# append extra null inputs if there are diff # of docs per input
expanded_input += [
input[i, :].new(input[i, :].size()).fill_(self.pad_idx)
] * (max_num_docs - len(docs))
expanded_input, _ = padded_tensor(
expanded_input,
fp16friendly=self.fp16,
fp16friendly=self.fp16 and right_padded,
max_len=max_len if self.n_extra_positions <= 0 else None,
pad_idx=self.pad_idx,
left_padded=not right_padded,
)
expanded_input = expanded_input.to(input.device)
return expanded_input # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions parlai/agents/rag/rag.py
Expand Up @@ -300,8 +300,7 @@ def eval_step(self, batch: Batch) -> Optional[Output]:
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
return output
assert isinstance(self.model, RagModel)
if hasattr(self.model.retriever, 'top_docs'):
if hasattr(self.model.retriever, 'top_docs'): # type: ignore
output.top_docs = self.model.retriever.top_docs # type: ignore
return output

Expand Down
128 changes: 127 additions & 1 deletion tests/nightly/gpu/test_rag.py
Expand Up @@ -18,7 +18,8 @@

try:
from parlai.agents.rag.dpr import DprQueryEncoder
from parlai.agents.rag.retrievers import RetrievedChunkRanker
from parlai.agents.rag.retrievers import RetrievedChunkRanker, Document
from parlai.agents.fid.fid import concat_enc_outs
except ImportError:
pass

Expand Down Expand Up @@ -556,5 +557,130 @@ def test_chunker(self):
)


class TestLeftPadding(unittest.TestCase):
    """
    Test whether left-padding functionality works.

    Exercises ``concat_enc_outs`` and ``concat_docs_and_input`` with both
    right-padded (the default) and left-padded inputs.
    """

    bsz = 4
    seqlen = 32
    n_docs = 5
    esz = 16
    # number of real (non-pad) tokens in each row of the batch
    batch_lens = [4, 8, 16, 32]
    # NOTE: the synthetic tokens are 1..len, so pad_idx must stay outside
    # that range for the pad-counting assertions below to be meaningful.
    pad_idx = 0

    def _create_input_and_mask(self, right_padded=True):
        """
        Build a padded batch of token ids and the matching mask.

        Row ``i`` contains the tokens ``1..batch_lens[i]``; every other
        position holds ``pad_idx`` in the input and ``False`` in the mask.

        :param right_padded:
            if True, tokens are placed at the start of each row (padding on
            the right); otherwise at the end (padding on the left).

        :return (enc_input, mask):
            a ``(bsz, seqlen)`` long tensor and a ``(bsz, seqlen)`` bool tensor.
        """
        enc_input = torch.full(
            (self.bsz, self.seqlen), self.pad_idx, dtype=torch.long
        )
        mask = torch.zeros(self.bsz, self.seqlen, dtype=torch.bool)
        for i, input_len in enumerate(self.batch_lens):
            if right_padded:
                span = slice(None, input_len)
            else:
                span = slice(-input_len, None)
            enc_input[i, span] = torch.arange(1, input_len + 1)
            mask[i, span] = True
        return enc_input, mask

    def _count_pads(self, tokens):
        """
        Return how many entries of ``tokens`` equal ``pad_idx``.
        """
        return tokens.eq(self.pad_idx).sum().item()

    def _check_expanded_padding(self, expanded_output, doc_len, right_padded):
        """
        Assert that every expanded row has its padding on the expected side.

        Each row is expected to hold ``doc_len + batch_lens[row // n_docs]``
        non-pad tokens, contiguous at the start of the row when right padded
        and at the end of the row when left padded; the rest must be padding.
        """
        total_len = expanded_output.size(1)
        for row in range(self.n_docs * self.bsz):
            n_tokens = doc_len + self.batch_lens[row // self.n_docs]
            if right_padded:
                token_part = expanded_output[row, :n_tokens]
                pad_part = expanded_output[row, n_tokens:]
            else:
                token_part = expanded_output[row, -n_tokens:]
                pad_part = expanded_output[row, :-n_tokens]
            self.assertEqual(
                self._count_pads(token_part),
                0,
                f'row {row}: pad token found inside the {n_tokens} real tokens',
            )
            self.assertEqual(
                self._count_pads(pad_part),
                total_len - n_tokens,
                f'row {row}: non-pad token found inside the padding region',
            )

    def test_concat_enc_outs(self):
        """
        The concatenated mask is True on the side where the tokens live.
        """
        enc_output = torch.rand(self.bsz * self.n_docs, self.seqlen, self.esz)

        # Right padded
        enc_input, mask = self._create_input_and_mask()
        mask = mask.repeat_interleave(self.n_docs, dim=0)
        _, new_mask = concat_enc_outs(
            enc_input, enc_output, mask, self.esz, self.pad_idx
        )
        ########################################################################
        # Assertion: new mask has `True` elements in first (n_docs * seqlen_i) #
        # tokens in concatenated output                                        #
        ########################################################################
        for i, input_len in enumerate(self.batch_lens):
            n_tokens = self.n_docs * input_len
            self.assertEqual(
                new_mask[i, :n_tokens].sum().item(),
                n_tokens,
                f'row {i}: first {n_tokens} mask entries are not all True',
            )

        # Left padded
        enc_input, mask = self._create_input_and_mask(right_padded=False)
        mask = mask.repeat_interleave(self.n_docs, dim=0)
        _, new_mask = concat_enc_outs(
            enc_input, enc_output, mask, self.esz, self.pad_idx, right_padded=False
        )
        #######################################################################
        # Assertion: new mask has `True` elements in last (n_docs * seqlen_i) #
        # tokens in concatenated output                                       #
        #######################################################################
        for i, input_len in enumerate(self.batch_lens):
            n_tokens = self.n_docs * input_len
            self.assertEqual(
                new_mask[i, -n_tokens:].sum().item(),
                n_tokens,
                f'row {i}: last {n_tokens} mask entries are not all True',
            )

    def test_concat_docs_and_input(self):
        """
        Expanded (doc + input) rows keep their padding on the requested side.
        """
        rag = create_agent(Opt({**test_opt, 'n_docs': self.n_docs}))
        docs = [
            [Document("title", "I am a document!", i) for i in range(self.n_docs)]
            for _ in range(self.bsz)
        ]
        # all documents share the same text, so one tokenized length suffices
        doc_len = len(rag.dict.txt2vec(docs[0][0].get_passage_str()))

        # Right padded
        enc_input, _ = self._create_input_and_mask()
        expanded_output = rag.model.concat_docs_and_input(
            enc_input, torch.LongTensor(self.batch_lens), docs, self.n_docs
        )
        ############################################################
        # Assertion: expanded output has non-pad elements in first #
        # (doc_len + seq_len_i) tokens, and pad elements in the    #
        # remaining total_len - (doc_len + seq_len_i) tokens       #
        ############################################################
        self._check_expanded_padding(expanded_output, doc_len, right_padded=True)

        # Left padded
        enc_input, _ = self._create_input_and_mask(right_padded=False)
        expanded_output = rag.model.concat_docs_and_input(
            enc_input,
            torch.LongTensor(self.batch_lens),
            docs,
            self.n_docs,
            right_padded=False,
        )
        ###########################################################
        # Assertion: expanded output has non-pad elements in last #
        # (doc_len + seq_len_i) tokens, and pad elements in the   #
        # leading total_len - (doc_len + seq_len_i) tokens        #
        ###########################################################
        self._check_expanded_padding(expanded_output, doc_len, right_padded=False)


if __name__ == '__main__':
    # Run the tests in this module when the file is executed as a script.
    unittest.main()