Fixing slow pipeline tests #14260

Merged · 7 commits · Nov 4, 2021
Changes from all commits
7 changes: 4 additions & 3 deletions src/transformers/models/detr/modeling_detr.py
@@ -648,9 +648,10 @@ def forward(
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)

-        if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+        if self.training:
+            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
Comment on lines +651 to +654

Contributor Author
@stas00 Is it OK to remove this at inference time?

Contributor

stas00 · Nov 3, 2021
In theory yes. In practice, it depends on how the model was pre-trained.

The model weights don't change during inference, so we don't need to keep things in check all the time.

However, if the pre-trained model's weights lead to an overflow in a single iteration during training, as is the case with some mt5 models under mixed precision, then this can occur just as well during inference.

This is primarily an issue with models pre-trained in bf16 and then fine-tuned or run for inference in fp16 (mixed or non-mixed precision).

If a model was pretrained with fp16/mixed precision, it's fairly safe to say the clamping won't be needed.

To give you a more definitive answer, it'd require running some tests with the actual DETR models and checking their activation magnitudes at the point you're asking about. That should be pretty trivial using https://huggingface.co/transformers/debugging.html#underflow-and-overflow-detection, which can be plugged into the HF Trainer and the examples with a single command-line argument: --debug underflow_overflow.
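As a minimal, hedged sketch of that workflow used standalone rather than through the Trainer flag (the checkpoint name and dummy input are illustrative, not taken from this PR):

```python
import torch
from transformers import DetrModel
from transformers.debug_utils import DebugUnderflowOverflow

model = DetrModel.from_pretrained("facebook/detr-resnet-50")
# Registers forward hooks on every submodule and traces the abs min/max of
# activations; it aborts with a per-frame report as soon as inf/nan appears.
debug_overflow = DebugUnderflowOverflow(model)

pixel_values = torch.randn(1, 3, 800, 800)  # dummy image batch
outputs = model(pixel_values=pixel_values)
```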

Contributor
To be honest I think this code was just badly copy pasted, so I'm more in favor of disabling this hack for training (as it is done now)

Contributor Author
OK, if everyone is in favor, then let's do this.

Contributor
> To be honest I think this code was just badly copy pasted, so I'm more in favor of disabling this hack for training (as it is done now)

you must have meant for inference, right Patrick?


         outputs = (hidden_states,)

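For readers following along, a small hedged illustration of what the retained training-time guard does in fp16 (not part of the diff):

```python
import torch

x = torch.tensor([1e5], dtype=torch.float16)  # overflows to inf in fp16 (max ~65504)
clamp_value = torch.finfo(torch.float16).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
print(torch.isinf(x).any())  # tensor(False) -- the activation is finite again
```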
5 changes: 4 additions & 1 deletion src/transformers/models/unispeech/modeling_unispeech.py
@@ -947,7 +947,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths

     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
         batch_size = attention_mask.shape[0]

         attention_mask = torch.zeros(
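A quick hedged sanity check of the equivalence relied on above (not part of the diff): the last element of the cumulative sum equals the row sum, while producing a fresh tensor that plays well with torch.inference_mode():

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
with torch.inference_mode():
    non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
assert torch.equal(non_padded_lengths, attention_mask.sum(-1))
```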
@@ -948,7 +948,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths

     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
         batch_size = attention_mask.shape[0]

         attention_mask = torch.zeros(
5 changes: 4 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -989,7 +989,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         return input_lengths

     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
-        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
         batch_size = attention_mask.shape[0]

         attention_mask = torch.zeros(
3 changes: 0 additions & 3 deletions src/transformers/pipelines/image_segmentation.py
@@ -91,9 +91,6 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:

         return super().__call__(*args, **kwargs)

-    def get_inference_context(self):
-        return torch.no_grad
-
     def preprocess(self, image):
         image = load_image(image)
         target_size = torch.IntTensor([[image.height, image.width]])
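This removal is assumed to be safe because the Pipeline base class supplies the same default and applies it around the forward pass itself; a simplified, hedged sketch of that pattern (not the literal base implementation):

```python
import torch

class PipelineSketch:
    def get_inference_context(self):
        # default inference context; subclasses no longer need to override it
        return torch.no_grad

    def forward(self, model_inputs, **forward_params):
        inference_context = self.get_inference_context()
        with inference_context():
            return self._forward(model_inputs, **forward_params)
```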
110 changes: 54 additions & 56 deletions src/transformers/pipelines/table_question_answering.py
@@ -93,76 +93,74 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, *
         )

     def batch_inference(self, **inputs):
-        with torch.no_grad():
-            return self.model(**inputs)
+        return self.model(**inputs)

     def sequential_inference(self, **inputs):
         """
         Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
         handle conversational query related to a table.
         """
-        with torch.no_grad():

Contributor

Nice!

+        all_logits = []
+        all_aggregations = []
+        prev_answers = None
+        batch_size = inputs["input_ids"].shape[0]
+
+        input_ids = inputs["input_ids"].to(self.device)
+        attention_mask = inputs["attention_mask"].to(self.device)
+        token_type_ids = inputs["token_type_ids"].to(self.device)
+        token_type_ids_example = None
+
+        for index in range(batch_size):
+            # If sequences have already been processed, the token type IDs will be created according to the previous
+            # answer.
+            if prev_answers is not None:
+                prev_labels_example = token_type_ids_example[:, 3]  # shape (seq_len,)
+                model_labels = np.zeros_like(prev_labels_example.cpu().numpy())  # shape (seq_len,)
+
+                token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
+                for i in range(model_labels.shape[0]):
+                    segment_id = token_type_ids_example[:, 0].tolist()[i]
+                    col_id = token_type_ids_example[:, 1].tolist()[i] - 1
+                    row_id = token_type_ids_example[:, 2].tolist()[i] - 1
+
+                    if row_id >= 0 and col_id >= 0 and segment_id == 1:
+                        model_labels[i] = int(prev_answers[(col_id, row_id)])
+
+                token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)
+
+            input_ids_example = input_ids[index]
+            attention_mask_example = attention_mask[index]  # shape (seq_len,)
+            token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
+            outputs = self.model(
+                input_ids=input_ids_example.unsqueeze(0),
+                attention_mask=attention_mask_example.unsqueeze(0),
+                token_type_ids=token_type_ids_example.unsqueeze(0),
+            )
+            logits = outputs.logits
+
+            if self.aggregate:
+                all_aggregations.append(outputs.logits_aggregation)
+
+            all_logits.append(logits)
+
+            dist_per_token = torch.distributions.Bernoulli(logits=logits)
+            probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
+                dist_per_token.probs.device
+            )
+
+            coords_to_probs = collections.defaultdict(list)
+            for i, p in enumerate(probabilities.squeeze().tolist()):
+                segment_id = token_type_ids_example[:, 0].tolist()[i]
+                col = token_type_ids_example[:, 1].tolist()[i] - 1
+                row = token_type_ids_example[:, 2].tolist()[i] - 1
+                if col >= 0 and row >= 0 and segment_id == 1:
+                    coords_to_probs[(col, row)].append(p)
+
+            prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
+
+        logits_batch = torch.cat(tuple(all_logits), 0)
+
+        return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))

     def __call__(self, *args, **kwargs):
         r"""
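A hedged aside on the cell-selection step kept above (not part of the diff): Bernoulli(logits=...).probs is simply sigmoid(logits), masked to valid tokens, so a table cell is carried over as a previous answer when its mean token probability exceeds 0.5:

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5])
probs = torch.distributions.Bernoulli(logits=logits).probs
assert torch.allclose(probs, torch.sigmoid(logits))
```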
2 changes: 1 addition & 1 deletion tests/test_pipelines_audio_classification.py
@@ -117,7 +117,7 @@ def test_large_model_pt(self):
         self.assertEqual(
             nested_simplify(output, decimals=4),
             [
-                {"score": 0.9809, "label": "go"},
+                {"score": 0.981, "label": "go"},
                 {"score": 0.0073, "label": "up"},
                 {"score": 0.0064, "label": "_unknown_"},
                 {"score": 0.0015, "label": "down"},