diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 701f212d9fef..e75e4c9719ad 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import torch -from torch import _softmax_backward_data, nn +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -31,12 +31,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import softmax_backward_data from ...utils import logging from .configuration_deberta import DebertaConfig logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "DebertaConfig" _TOKENIZER_FOR_DOC = "DebertaTokenizer" _CHECKPOINT_FOR_DOC = "microsoft/deberta-base" @@ -115,7 +115,7 @@ def forward(self, input, mask, dim): @staticmethod def backward(self, grad_output): (output,) = self.saved_tensors - inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) return inputGrad, None, None @staticmethod diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 903b153111f3..108f08e4704a 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -19,7 +19,7 @@ import numpy as np import torch -from torch import _softmax_backward_data, nn +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from ...activations import ACT2FN @@ -32,6 +32,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import softmax_backward_data from ...utils import logging from .configuration_deberta_v2 import DebertaV2Config @@ -116,7 +117,7 @@ 
def forward(self, input, mask, dim): @staticmethod def backward(self, grad_output): (output,) = self.saved_tensors - inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) return inputGrad, None, None @staticmethod diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index af7dcba4b9a5..7443a67bcc8c 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -22,7 +22,7 @@ import numpy as np import torch import torch.utils.checkpoint -from torch import _softmax_backward_data, nn +from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from transformers.deepspeed import is_deepspeed_zero3_enabled @@ -31,14 +31,13 @@ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import torch_int_div +from ...pytorch_utils import softmax_backward_data, torch_int_div from ...utils import logging from .configuration_sew_d import SEWDConfig logger = logging.get_logger(__name__) - _HIDDEN_STATES_START_POSITION = 1 @@ -545,7 +544,7 @@ def forward(self, input, mask, dim): @staticmethod def backward(self, grad_output): (output,) = self.saved_tensors - inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) return inputGrad, None, None @staticmethod diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index b41f438d9c3a..ee0c94bd9c70 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -14,18 +14,34 @@ import torch from packaging import version +from torch import 
# Parse the running torch version once; both feature flags below derive from it.
_parsed_torch_version = version.parse(torch.__version__)
is_torch_less_than_1_8 = _parsed_torch_version < version.parse("1.8.0")
is_torch_less_than_1_11 = _parsed_torch_version < version.parse("1.11")


def torch_int_div(tensor1, tensor2):
    """
    Perform integer (floor) division of two tensors, portably across PyTorch versions.

    Args:
        tensor1: the dividend tensor.
        tensor2: the divisor tensor.

    Returns:
        The element-wise floor of `tensor1 / tensor2`.
    """
    if not is_torch_less_than_1_8:
        # `rounding_mode` was introduced in torch 1.8 and is the supported spelling there.
        return torch.div(tensor1, tensor2, rounding_mode="floor")
    # On older torch, tensor `//` already performs floor division.
    return tensor1 // tensor2


def softmax_backward_data(parent, grad_output, output, dim, self):
    """
    Call the internal `_softmax_backward_data` PyTorch method, adjusting its arguments to
    the torch version detected: before 1.11 the last argument is a tensor, from 1.11 on
    it is that tensor's dtype.

    Args:
        parent: the autograd context object; its `dim` attribute is the softmax axis.
        grad_output: the gradient flowing back from the following layer.
        output: the softmax output saved during the forward pass.
        self: the tensor whose dtype is forwarded on torch >= 1.11.
        dim: the softmax axis (kept for signature compatibility; `parent.dim` is what is used).

    Returns:
        The gradient with respect to the softmax input.
    """
    last_argument = self if is_torch_less_than_1_11 else self.dtype
    return _softmax_backward_data(grad_output, output, parent.dim, last_argument)