diff --git a/docs/source/model_doc/realm.mdx b/docs/source/model_doc/realm.mdx index f96e322ebfa7..545b1e0a3bf8 100644 --- a/docs/source/model_doc/realm.mdx +++ b/docs/source/model_doc/realm.mdx @@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi ## RealmForOpenQA [[autodoc]] RealmForOpenQA + - block_embedding_to - forward \ No newline at end of file diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 0c953f1636bf..9284bfa70efc 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -48,6 +48,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict( [ ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), + ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 118916413863..cbbf96e7e4f8 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -836,13 +836,13 @@ def __init__(self, config): self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps) self.relu = nn.ReLU() - def forward(self, hidden_states, token_type_ids): + def forward(self, hidden_states, block_mask): def span_candidates(masks): """ Generate span candidates. Args: - masks: [num_retrievals, max_sequence_len] + masks: [num_retrievals, max_sequence_len] Returns: starts: [num_spans] ends: [num_spans] span_masks: [num_retrievals, num_spans] @@ -875,8 +875,7 @@ def mask_to_score(mask): hidden_states = self.dense_intermediate(hidden_states) # [reader_beam_size, max_sequence_len, span_hidden_size] start_projection, end_projection = hidden_states.chunk(2, dim=-1) - block_mask = token_type_ids.detach().clone() - block_mask[:, -1] = 0 + candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask) candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts) @@ -1543,6 +1542,7 @@ def forward( head_mask=None, inputs_embeds=None, relevance_score=None, + block_mask=None, start_positions=None, end_positions=None, has_answers=None, @@ -1552,12 +1552,15 @@ def forward( ): r""" relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*): - Relevance score, which must be specified if you want to compute the marginal log loss. - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Relevance score, which must be specified if you want to compute the logits and marginal log loss. + block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*): + The mask of the evidence block, which must be specified if you want to compute the logits and marginal log + loss. + start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. @@ -1570,8 +1573,8 @@ def forward( if relevance_score is None: raise ValueError("You have to specify `relevance_score` to calculate logits and loss.") - if token_type_ids is None: - raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.") + if block_mask is None: + raise ValueError("You have to specify `block_mask` to separate question block and evidence block.") if token_type_ids.size(1) < self.config.max_span_width: raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.") outputs = self.realm( @@ -1590,7 +1593,9 @@ def forward( sequence_output = outputs[0] # [reader_beam_size, num_candidates], [num_candidates], [num_candidates] - reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids) + reader_logits, candidate_starts, candidate_ends = self.qa_outputs( + sequence_output, block_mask[0 : self.config.reader_beam_size] + ) # [searcher_beam_size, 1] retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1) # [reader_beam_size, num_candidates] @@ -1737,11 +1742,21 @@ def __init__(self, config, retriever=None): self.post_init() @property - def beam_size(self): + def searcher_beam_size(self): if self.training: return self.config.searcher_beam_size return self.config.reader_beam_size + def block_embedding_to(self, device): + """Send `self.block_emb` to a specific device. + + Args: + device (`str` or `torch.device`): + The device to which `self.block_emb` will be sent. + """ + + self.block_emb = self.block_emb.to(device) + @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length")) @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1787,36 +1802,37 @@ def forward( question_outputs = self.embedder( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True ) - # [1, projection_size] question_projection = question_outputs[0] + + # CPU computation starts. # [1, block_emb_size] - batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection) + batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device)) # [1, searcher_beam_size] - _, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1) + _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1) # [searcher_beam_size] - # Must convert to cpu tensor for subsequent numpy operations - retrieved_block_ids = retrieved_block_ids.squeeze().cpu() + retrieved_block_ids = retrieved_block_ids.squeeze() + # [searcher_beam_size, projection_size] + retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids) + # CPU computation ends. # Retrieve possible answers has_answers, start_pos, end_pos, concat_inputs = self.retriever( - retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len + retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len ) + concat_inputs = concat_inputs.to(self.reader.device) + block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device) + block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool)) + if has_answers is not None: has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device) start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device) end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device) - concat_inputs = concat_inputs.to(self.reader.device) - - # [searcher_beam_size, projection_size] - retrieved_block_emb = torch.index_select( - self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device) - ) # [searcher_beam_size] retrieved_logits = torch.einsum( - "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device) + "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device) ) reader_output = self.reader( @@ -1824,6 +1840,7 @@ def forward( attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size], token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size], relevance_score=retrieved_logits, + block_mask=block_mask, has_answers=has_answers, start_positions=start_pos, end_positions=end_pos, diff --git a/src/transformers/models/realm/retrieval_realm.py b/src/transformers/models/realm/retrieval_realm.py index 20ae30861583..db6c8c7246be 100644 --- a/src/transformers/models/realm/retrieval_realm.py +++ b/src/transformers/models/realm/retrieval_realm.py @@ -20,9 +20,9 @@ import numpy as np from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from ...utils import logging -from .tokenization_realm import RealmTokenizer _REALM_BLOCK_RECORDS_FILENAME = "block_records.npy" @@ -97,7 +97,9 @@ def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_leng text.append(question) text_pair.append(retrieved_block.decode()) - concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length) + concat_inputs = self.tokenizer( + text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length + ) concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors) if answer_ids is not None: @@ -115,7 +117,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) block_records = np.load(block_records_path, allow_pickle=True) - tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) return cls(block_records, tokenizer) @@ -133,13 +135,15 @@ def block_has_answer(self, concat_inputs, answer_ids): max_answers = 0 for input_id in concat_inputs.input_ids: + input_id_list = input_id.tolist() + # Check answers between two [SEP] tokens + first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id) + second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id) + start_pos.append([]) end_pos.append([]) - input_id_list = input_id.tolist() - # Checking answers after the [SEP] token - sep_idx = input_id_list.index(self.tokenizer.sep_token_id) for answer in answer_ids: - for idx in range(sep_idx, len(input_id)): + for idx in range(first_sep_idx + 1, second_sep_idx): if answer[0] == input_id_list[idx]: if input_id_list[idx : idx + len(answer)] == answer: start_pos[-1].append(idx) @@ -158,5 +162,4 @@ def block_has_answer(self, concat_inputs, answer_ids): padded = [-1] * (max_answers - len(start_pos_)) start_pos_ += padded end_pos_ += padded - return has_answers, start_pos, end_pos diff --git a/tests/realm/test_modeling_realm.py b/tests/realm/test_modeling_realm.py index 99d09ac48f86..02eaa6556e9f 100644 --- a/tests/realm/test_modeling_realm.py +++ b/tests/realm/test_modeling_realm.py @@ -345,7 +345,7 @@ def test_model_various_embeddings(self): self.model_tester.create_and_check_embedder(*config_and_inputs) self.model_tester.create_and_check_encoder(*config_and_inputs) - def test_retriever(self): + def test_scorer(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_scorer(*config_and_inputs) @@ -408,6 +408,13 @@ def test_training(self): loss = model(**inputs).reader_output.loss loss.backward() + # Test model.block_embedding_to + device = torch.device("cpu") + model.block_embedding_to(device) + loss = model(**inputs).reader_output.loss + loss.backward() + self.assertEqual(model.block_emb.device.type, device.type) + @slow def test_embedder_from_pretrained(self): model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") @@ -506,10 +513,15 @@ def test_inference_reader(self): concat_input_ids = torch.arange(10).view((2, 5)) concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64) + concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64) relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32) output = model( - concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True + concat_input_ids, + token_type_ids=concat_token_type_ids, + relevance_score=relevance_score, + block_mask=concat_block_mask, + return_dict=True, ) block_idx_expected_shape = torch.Size(()) diff --git a/tests/realm/test_retrieval_realm.py b/tests/realm/test_retrieval_realm.py index 3ffefef16e25..939d98440049 100644 --- a/tests/realm/test_retrieval_realm.py +++ b/tests/realm/test_retrieval_realm.py @@ -98,6 +98,7 @@ def get_dummy_block_records(self): b"This is the third record", b"This is the fourth record", b"This is the fifth record", + b"This is a longer longer longer record", ], dtype=np.object, ) @@ -135,6 +136,7 @@ def test_retrieve(self): self.assertEqual(concat_inputs.input_ids.shape, (2, 10)) self.assertEqual(concat_inputs.attention_mask.shape, (2, 10)) self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10)) + self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10)) self.assertEqual( tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]), ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"], @@ -149,10 +151,10 @@ def test_block_has_answer(self): retriever = self.get_dummy_retriever() tokenizer = retriever.tokenizer - retrieved_block_ids = np.array([0, 3], dtype=np.long) + retrieved_block_ids = np.array([0, 3, 5], dtype=np.long) question_input_ids = tokenizer(["Test question"]).input_ids answer_ids = tokenizer( - ["the fourth"], + ["the fourth", "longer longer"], add_special_tokens=False, return_token_type_ids=False, return_attention_mask=False, @@ -163,9 +165,9 @@ def test_block_has_answer(self): retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np" ) - self.assertEqual([False, True], has_answers) - self.assertEqual([[-1], [6]], start_pos) - self.assertEqual([[-1], [7]], end_pos) + self.assertEqual([False, True, True], has_answers) + self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos) + self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos) def test_save_load_pretrained(self): retriever = self.get_dummy_retriever()