From e00a3d6ee08cb754d907ce8ab8cc14b5339f4e8c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 15:06:57 +0100
Subject: [PATCH 1/7] Fixing slow pipeline tests

---
 tests/test_pipelines_audio_classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_pipelines_audio_classification.py b/tests/test_pipelines_audio_classification.py
index a1cfaafe6dde8e..1b0ad5d2cbe647 100644
--- a/tests/test_pipelines_audio_classification.py
+++ b/tests/test_pipelines_audio_classification.py
@@ -117,7 +117,7 @@ def test_large_model_pt(self):
         self.assertEqual(
             nested_simplify(output, decimals=4),
             [
-                {"score": 0.9809, "label": "go"},
+                {"score": 0.981, "label": "go"},
                 {"score": 0.0073, "label": "up"},
                 {"score": 0.0064, "label": "_unknown_"},
                 {"score": 0.0015, "label": "down"},

From c7ead4a7fef79694a3603068af4e56df0287d3c8 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 16:20:37 +0100
Subject: [PATCH 2/7] Remove the image-segmentation override.

---
 src/transformers/pipelines/image_segmentation.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py
index 84a3e67ef6b671..fac8cddc6731ce 100644
--- a/src/transformers/pipelines/image_segmentation.py
+++ b/src/transformers/pipelines/image_segmentation.py
@@ -91,9 +91,6 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
 
         return super().__call__(*args, **kwargs)
 
-    def get_inference_context(self):
-        return torch.no_grad
-
     def preprocess(self, image):
         image = load_image(image)
         target_size = torch.IntTensor([[image.height, image.width]])

From 443fcbaee6b9017aa7e35f96d5e07c40c9be0fc4 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 16:24:17 +0100
Subject: [PATCH 3/7] Fixing clamping only in training.

---
 src/transformers/models/detr/modeling_detr.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py
index af650e75e1a6cc..5e95cc3f324e2d 100644
--- a/src/transformers/models/detr/modeling_detr.py
+++ b/src/transformers/models/detr/modeling_detr.py
@@ -648,9 +648,10 @@ def forward(
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+        if self.training:
+            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
         outputs = (hidden_states,)

From f307ef69da657891c371743ef14377d8ca9da169 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 16:26:40 +0100
Subject: [PATCH 4/7] Wav2vec2.

---
 src/transformers/models/wav2vec2/modeling_wav2vec2.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 485983acd5ed6c..3b35991a0a3e22 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -989,7 +989,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths
 
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        mask = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(

From dca3788560016fe34b30fa98680aeef5bf251225 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 16:32:03 +0100
Subject: [PATCH 5/7] Remove last mention of `no_grad`.

---
 .../pipelines/table_question_answering.py | 110 +++++++++---------
 1 file changed, 54 insertions(+), 56 deletions(-)

diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py
index a58e3aacbe9ac4..7697752b2bc899 100644
--- a/src/transformers/pipelines/table_question_answering.py
+++ b/src/transformers/pipelines/table_question_answering.py
@@ -93,76 +93,74 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, **kwargs):
         )
 
     def batch_inference(self, **inputs):
-        with torch.no_grad():
-            return self.model(**inputs)
+        return self.model(**inputs)
 
     def sequential_inference(self, **inputs):
         """
         Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
         handle conversational query related to a table.
         """
-        with torch.no_grad():
-            all_logits = []
-            all_aggregations = []
-            prev_answers = None
-            batch_size = inputs["input_ids"].shape[0]
-
-            input_ids = inputs["input_ids"].to(self.device)
-            attention_mask = inputs["attention_mask"].to(self.device)
-            token_type_ids = inputs["token_type_ids"].to(self.device)
-            token_type_ids_example = None
-
-            for index in range(batch_size):
-                # If sequences have already been processed, the token type IDs will be created according to the previous
-                # answer.
-                if prev_answers is not None:
-                    prev_labels_example = token_type_ids_example[:, 3]  # shape (seq_len,)
-                    model_labels = np.zeros_like(prev_labels_example.cpu().numpy())  # shape (seq_len,)
-
-                    token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
-                    for i in range(model_labels.shape[0]):
-                        segment_id = token_type_ids_example[:, 0].tolist()[i]
-                        col_id = token_type_ids_example[:, 1].tolist()[i] - 1
-                        row_id = token_type_ids_example[:, 2].tolist()[i] - 1
-
-                        if row_id >= 0 and col_id >= 0 and segment_id == 1:
-                            model_labels[i] = int(prev_answers[(col_id, row_id)])
-
-                    token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)
-
-                input_ids_example = input_ids[index]
-                attention_mask_example = attention_mask[index]  # shape (seq_len,)
+        all_logits = []
+        all_aggregations = []
+        prev_answers = None
+        batch_size = inputs["input_ids"].shape[0]
+
+        input_ids = inputs["input_ids"].to(self.device)
+        attention_mask = inputs["attention_mask"].to(self.device)
+        token_type_ids = inputs["token_type_ids"].to(self.device)
+        token_type_ids_example = None
+
+        for index in range(batch_size):
+            # If sequences have already been processed, the token type IDs will be created according to the previous
+            # answer.
+            if prev_answers is not None:
+                prev_labels_example = token_type_ids_example[:, 3]  # shape (seq_len,)
+                model_labels = np.zeros_like(prev_labels_example.cpu().numpy())  # shape (seq_len,)
+                token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
+                for i in range(model_labels.shape[0]):
+                    segment_id = token_type_ids_example[:, 0].tolist()[i]
+                    col_id = token_type_ids_example[:, 1].tolist()[i] - 1
+                    row_id = token_type_ids_example[:, 2].tolist()[i] - 1
+
+                    if row_id >= 0 and col_id >= 0 and segment_id == 1:
+                        model_labels[i] = int(prev_answers[(col_id, row_id)])
+
+                token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)
+
+            input_ids_example = input_ids[index]
+            attention_mask_example = attention_mask[index]  # shape (seq_len,)
+            token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
+            outputs = self.model(
+                input_ids=input_ids_example.unsqueeze(0),
+                attention_mask=attention_mask_example.unsqueeze(0),
+                token_type_ids=token_type_ids_example.unsqueeze(0),
+            )
+            logits = outputs.logits
+
+            if self.aggregate:
+                all_aggregations.append(outputs.logits_aggregation)
+
+            all_logits.append(logits)
+
+            dist_per_token = torch.distributions.Bernoulli(logits=logits)
+            probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
+                dist_per_token.probs.device
+            )
+
+            coords_to_probs = collections.defaultdict(list)
+            for i, p in enumerate(probabilities.squeeze().tolist()):
+                segment_id = token_type_ids_example[:, 0].tolist()[i]
+                col = token_type_ids_example[:, 1].tolist()[i] - 1
+                row = token_type_ids_example[:, 2].tolist()[i] - 1
+                if col >= 0 and row >= 0 and segment_id == 1:
+                    coords_to_probs[(col, row)].append(p)
+
+            prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
+
+        logits_batch = torch.cat(tuple(all_logits), 0)
+
+        return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))
 
     def __call__(self, *args, **kwargs):
         r"""

From 0809fe30047e648cf11e18f13960228bddfeb5ad Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 16:53:43 +0100
Subject: [PATCH 6/7] Fixing copies.

---
 src/transformers/models/unispeech/modeling_unispeech.py    | 5 ++++-
 .../models/unispeech_sat/modeling_unispeech_sat.py         | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index b4a3423516275e..3592e9036d5696 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -947,7 +947,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths
 
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        mask = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(

diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index bd27b53edb991a..2361c248c6c768 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -948,7 +948,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths
 
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        mask = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(

From bbb9b277c0967adbaa5e6c0a4870dbeec5d32d3d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 3 Nov 2021 19:29:16 +0100
Subject: [PATCH 7/7] Rename.

---
 src/transformers/models/unispeech/modeling_unispeech.py    | 4 ++--
 .../models/unispeech_sat/modeling_unispeech_sat.py         | 4 ++--
 src/transformers/models/wav2vec2/modeling_wav2vec2.py      | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 3592e9036d5696..a8a89c302b75d1 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -949,8 +949,8 @@ def _conv_out_length(input_length, kernel_size, stride):
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
         # Effectively attention_mask.sum(-1), but not inplace to be able to run
         # in inference mode.
-        mask = attention_mask.cumsum(dim=-1)[:, -1]
-        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(

diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 2361c248c6c768..c5f8243bf11524 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -950,8 +950,8 @@ def _conv_out_length(input_length, kernel_size, stride):
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
         # Effectively attention_mask.sum(-1), but not inplace to be able to run
         # in inference mode.
-        mask = attention_mask.cumsum(dim=-1)[:, -1]
-        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(

diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 3b35991a0a3e22..6548f245f0e842 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -991,8 +991,8 @@ def _conv_out_length(input_length, kernel_size, stride):
     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
         # Effectively attention_mask.sum(-1), but not inplace to be able to run
         # in inference mode.
-        mask = attention_mask.cumsum(dim=-1)[:, -1]
-        output_lengths = self._get_feat_extract_output_lengths(mask).to(torch.long)
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
 
         batch_size = attention_mask.shape[0]
         attention_mask = torch.zeros(
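
Note on the `cumsum` change in patches 4, 6 and 7: the last element of a running sum is the total, so `attention_mask.cumsum(dim=-1)[:, -1]` yields the same per-row lengths as `attention_mask.sum(-1)`. A minimal sketch of the equivalence (toy mask values, not taken from the patches):

    import torch

    # Right-padded 0/1 attention mask; each row's 1s count the non-padded tokens.
    attention_mask = torch.tensor(
        [
            [1, 1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0, 0],
        ]
    )

    # The last column of the cumulative sum equals the row-wise sum.
    non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

    assert torch.equal(non_padded_lengths, attention_mask.sum(-1))
    print(non_padded_lengths)  # tensor([4, 2])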
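
Note on patch 3: the branch it gates behind `self.training` exists to keep float16 activations from overflowing during half-precision training. A standalone sketch of that clamping logic (toy values, not the DETR forward pass):

    import torch

    # float16 saturates to inf above ~65504, so overflowing values are clamped
    # back into the representable range before they propagate to later layers.
    hidden_states = torch.tensor([1.0, 7.0e4, -7.0e4]).to(torch.float16)  # -> [1., inf, -inf]

    if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    print(hidden_states)  # all values finite again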