Fix and improve REALM fine-tuning #15297
Changes from all commits: 67bdd86, 0070268, ebd1d0d, c2734aa, 3eef182, d6a38ce, a3c9916, f8d6b3c, acf7617, bfca8df, 2c22cd3, 823843b
```diff
@@ -836,13 +836,13 @@ def __init__(self, config):
         self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
         self.relu = nn.ReLU()
 
-    def forward(self, hidden_states, token_type_ids):
+    def forward(self, hidden_states, block_mask):
         def span_candidates(masks):
             """
             Generate span candidates.
 
             Args:
-                masks: <int32> [num_retrievals, max_sequence_len]
+                masks: <bool> [num_retrievals, max_sequence_len]
 
             Returns:
                 starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
```
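For context, here is a minimal sketch of what `span_candidates` computes with a boolean mask. The enumeration below paraphrases the REALM reader logic rather than copying this file, and `max_span_width` stands in for `config.max_span_width`:

```python
import torch

def span_candidates(masks: torch.Tensor, max_span_width: int = 10):
    """Enumerate all spans up to `max_span_width` and mark which are valid.

    masks: <bool> [num_retrievals, max_sequence_len]
    """
    _, max_sequence_len = masks.shape
    starts, ends = [], []
    for width in range(max_span_width):
        starts.append(torch.arange(0, max_sequence_len - width))
        ends.append(torch.arange(width, max_sequence_len))
    starts = torch.cat(starts)  # <int64> [num_spans]
    ends = torch.cat(ends)      # <int64> [num_spans]
    # A span is valid only if both endpoints fall inside the evidence block.
    span_masks = masks[:, starts] & masks[:, ends]  # [num_retrievals, num_spans]
    return starts, ends, span_masks
```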
```diff
@@ -875,8 +875,7 @@ def mask_to_score(mask):
         hidden_states = self.dense_intermediate(hidden_states)
         # [reader_beam_size, max_sequence_len, span_hidden_size]
         start_projection, end_projection = hidden_states.chunk(2, dim=-1)
-        block_mask = token_type_ids.detach().clone()
-        block_mask[:, -1] = 0
+
         candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
 
         candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
```
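The `mask_to_score` helper named in this hunk's header follows a common pattern: it turns a boolean validity mask into an additive penalty on the logits. A hedged sketch of that pattern (paraphrased, not copied from this file):

```python
import torch

def mask_to_score(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Valid positions (True) contribute 0; invalid positions get a very large
    # negative value, so they vanish after a softmax over span candidates.
    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
```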
```diff
@@ -1543,6 +1542,7 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         relevance_score=None,
+        block_mask=None,
         start_positions=None,
         end_positions=None,
         has_answers=None,
```
```diff
@@ -1552,12 +1552,15 @@ def forward(
     ):
         r"""
         relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
-            Relevance score, which must be specified if you want to compute the marginal log loss.
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Relevance score, which must be specified if you want to compute the logits and marginal log loss.
+        block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
+            The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
+            loss.
+        start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+        end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
             Labels for position (index) of the end of the labelled span for computing the token classification loss.
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
```
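To make the new `block_mask` docstring concrete, here is an illustrative layout for one retrieval of a concatenated `[CLS] question [SEP] evidence [SEP]` sequence (a toy example, not taken from the PR):

```python
# tokens:          [CLS]  who    wrote  [SEP]  the   block  text  [SEP]
# token_type_ids:    0     0      0       0     1      1      1     1
# block_mask:      False  False  False  False  True   True   True  False
```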
```diff
@@ -1570,8 +1573,8 @@ def forward(
 
         if relevance_score is None:
             raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
-        if token_type_ids is None:
-            raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.")
+        if block_mask is None:
+            raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
         if token_type_ids.size(1) < self.config.max_span_width:
             raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
         outputs = self.realm(
```
```diff
@@ -1590,7 +1593,9 @@ def forward(
         sequence_output = outputs[0]
 
         # [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
-        reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids)
+        reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
+            sequence_output, block_mask[0 : self.config.reader_beam_size]
+        )
         # [searcher_beam_size, 1]
         retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
         # [reader_beam_size, num_candidates]
```
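Downstream of this hunk, the reader and retriever scores are combined so that training marginalizes over the retrieved blocks; a hedged, self-contained sketch of that broadcast (random tensors, shapes following the comments in the diff):

```python
import torch

reader_beam_size, num_candidates = 5, 3
reader_logits = torch.randn(reader_beam_size, num_candidates)
retriever_logits = torch.randn(reader_beam_size, 1)

# The retriever logit of each block is broadcast-added to every span
# candidate logit of that block, so the loss marginalizes over retrievals.
total_logits = reader_logits + retriever_logits  # [reader_beam_size, num_candidates]
```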
```diff
@@ -1737,11 +1742,21 @@ def __init__(self, config, retriever=None):
         self.post_init()
 
     @property
-    def beam_size(self):
+    def searcher_beam_size(self):
```
> **Contributor:** This breaks backward compatibility here.
>
> **Contributor:** That's the only thing that would be nice to revert - the rest looks good to me!
```diff
         if self.training:
             return self.config.searcher_beam_size
         return self.config.reader_beam_size
 
+    def block_embedding_to(self, device):
```
> **Contributor:** Ok for me, even though I don't think it's really necessary.
>
> **Contributor (Author):** @patrickvonplaten - Do you think it would be better if I just add some tips in the model docs, e.g. telling users they can manually send `self.block_emb` to a specific device? We still need to show users how they can fine-tune the model on a 12GB-memory GPU as promised by the paper.
>
> **Contributor:** I think the best would be to add a comment under `block_embedding_to`. But it's also ok for me to leave it as is now - your choice :-)
```diff
+        """Send `self.block_emb` to a specific device.
+
+        Args:
+            device (`str` or `torch.device`):
+                The device to which `self.block_emb` will be sent.
+        """
+
+        self.block_emb = self.block_emb.to(device)
+
     @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
     @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
```
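As discussed in the thread above, the new helper lets users keep the large `block_emb` matrix in CPU memory while the rest of the model runs on a GPU; a hedged usage sketch (the checkpoint name is illustrative, and a CUDA device is assumed):

```python
from transformers import RealmForOpenQA, RealmRetriever

# Illustrative checkpoint; substitute whichever REALM OpenQA checkpoint you use.
retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

model.to("cuda")                 # embedder + reader on the GPU
model.block_embedding_to("cpu")  # keep the evidence-block embeddings in CPU RAM
```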
```diff
@@ -1787,43 +1802,45 @@ def forward(
         question_outputs = self.embedder(
             input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
         )
 
         # [1, projection_size]
         question_projection = question_outputs[0]
 
-        # CPU computation starts.
         # [1, block_emb_size]
-        batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection)
+        batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
         # [1, searcher_beam_size]
-        _, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1)
+        _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
         # [searcher_beam_size]
-        # Must convert to cpu tensor for subsequent numpy operations
-        retrieved_block_ids = retrieved_block_ids.squeeze().cpu()
-        # [searcher_beam_size, projection_size]
-        retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
-        # CPU computation ends.
+        retrieved_block_ids = retrieved_block_ids.squeeze()
 
         # Retrieve possible answers
         has_answers, start_pos, end_pos, concat_inputs = self.retriever(
-            retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len
+            retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
         )
 
+        concat_inputs = concat_inputs.to(self.reader.device)
+        block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
+        block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))
+
         if has_answers is not None:
             has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
             start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
             end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
 
-        concat_inputs = concat_inputs.to(self.reader.device)
-
+        # [searcher_beam_size, projection_size]
+        retrieved_block_emb = torch.index_select(
+            self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device)
+        )
         # [searcher_beam_size]
         retrieved_logits = torch.einsum(
-            "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device)
+            "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
         )
 
         reader_output = self.reader(
             input_ids=concat_inputs.input_ids[0 : self.config.reader_beam_size],
             attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
             token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
             relevance_score=retrieved_logits,
+            block_mask=block_mask,
             has_answers=has_answers,
             start_positions=start_pos,
             end_positions=end_pos,
```
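For intuition on the retrieval step in this hunk: the `einsum` is just an inner product of the question projection with every block embedding, followed by a top-k over blocks. A self-contained toy version (random tensors, illustrative sizes):

```python
import torch

num_blocks, projection_size = 8, 4
block_emb = torch.randn(num_blocks, projection_size)   # [B, D], possibly kept on CPU
question_projection = torch.randn(1, projection_size)  # [Q=1, D], on the model device

# Same contraction as the model: a relevance score for every block.
batch_scores = torch.einsum("BD,QD->QB", block_emb, question_projection.to(block_emb.device))
# Keep the searcher_beam_size highest-scoring blocks (here k=3).
_, retrieved_block_ids = torch.topk(batch_scores, k=3, dim=-1)  # [1, 3]
```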
> **Contributor:** We previously used `token_type_ids` as the `block_mask`; now the `block_mask` is computed in `RealmForOpenQA.forward()`.
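In other words, the in-place pair of ops in the new code marks exactly the non-special tokens of the second segment; a small equivalent sketch (tensor names follow the diff, input values are illustrative):

```python
import torch

# Illustrative tensors for one [CLS] q q [SEP] e e e [SEP] sequence.
special_tokens_mask = torch.tensor([[1, 0, 0, 1, 0, 0, 0, 1]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]])

# Equivalent to: block_mask.logical_not_().logical_and_(token_type_ids.bool())
block_mask = ~special_tokens_mask.bool() & token_type_ids.bool()
# -> [[False, False, False, False, True, True, True, False]]
```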