Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/transformers/models/deberta/modeling_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Sequence

import torch
from torch import _softmax_backward_data, nn
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
Expand All @@ -31,12 +31,12 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import softmax_backward_data
from ...utils import logging
from .configuration_deberta import DebertaConfig


# Module-level logger, one per module per the transformers convention.
logger = logging.get_logger(__name__)

# Names used to fill in auto-generated code samples in model docstrings
# (presumably consumed by the add_code_sample_docstrings helper — verify against callers).
_CONFIG_FOR_DOC = "DebertaConfig"
_TOKENIZER_FOR_DOC = "DebertaTokenizer"
_CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
Expand Down Expand Up @@ -115,7 +115,7 @@ def forward(self, input, mask, dim):
@staticmethod
def backward(self, grad_output):
    """
    Compute the gradient of the masked softmax with respect to its input.

    Args:
        self: the autograd context; carries `saved_tensors` and the `dim`
            attribute stashed by `forward`.
        grad_output: gradient flowing back from the layer above.

    Returns:
        A 3-tuple matching `forward`'s (input, mask, dim) signature; only the
        input receives a gradient.
    """
    # Retrieve the softmax output saved by `forward`.
    (output,) = self.saved_tensors
    # Delegate to the version-aware wrapper: torch's internal
    # `_softmax_backward_data` changed its last argument in torch 1.11,
    # so the raw internal function must not be called directly here.
    inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
    return inputGrad, None, None

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np
import torch
from torch import _softmax_backward_data, nn
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

from ...activations import ACT2FN
Expand All @@ -32,6 +32,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import softmax_backward_data
from ...utils import logging
from .configuration_deberta_v2 import DebertaV2Config

Expand Down Expand Up @@ -116,7 +117,7 @@ def forward(self, input, mask, dim):
@staticmethod
def backward(self, grad_output):
    """
    Compute the gradient of the masked softmax with respect to its input.

    Args:
        self: the autograd context; carries `saved_tensors` and the `dim`
            attribute stashed by `forward`.
        grad_output: gradient flowing back from the layer above.

    Returns:
        A 3-tuple matching `forward`'s (input, mask, dim) signature; only the
        input receives a gradient.
    """
    # Retrieve the softmax output saved by `forward`.
    (output,) = self.saved_tensors
    # Delegate to the version-aware wrapper: torch's internal
    # `_softmax_backward_data` changed its last argument in torch 1.11,
    # so the raw internal function must not be called directly here.
    inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
    return inputGrad, None, None

@staticmethod
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
import torch
import torch.utils.checkpoint
from torch import _softmax_backward_data, nn
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm

from transformers.deepspeed import is_deepspeed_zero3_enabled
Expand All @@ -31,14 +31,13 @@
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...pytorch_utils import softmax_backward_data, torch_int_div
from ...utils import logging
from .configuration_sew_d import SEWDConfig


logger = logging.get_logger(__name__)


_HIDDEN_STATES_START_POSITION = 1


Expand Down Expand Up @@ -545,7 +544,7 @@ def forward(self, input, mask, dim):
@staticmethod
def backward(self, grad_output):
    """
    Compute the gradient of the masked softmax with respect to its input.

    Args:
        self: the autograd context; carries `saved_tensors` and the `dim`
            attribute stashed by `forward`.
        grad_output: gradient flowing back from the layer above.

    Returns:
        A 3-tuple matching `forward`'s (input, mask, dim) signature; only the
        input receives a gradient.
    """
    # Retrieve the softmax output saved by `forward`.
    (output,) = self.saved_tensors
    # Delegate to the version-aware wrapper: torch's internal
    # `_softmax_backward_data` changed its last argument in torch 1.11,
    # so the raw internal function must not be called directly here.
    inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
    return inputGrad, None, None

@staticmethod
Expand Down
18 changes: 17 additions & 1 deletion src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,34 @@

import torch
from packaging import version
from torch import _softmax_backward_data

from .utils import logging


logger = logging.get_logger(__name__)

# Evaluated once at import time: the installed torch version cannot change
# during a process's lifetime, so caching these comparisons is safe.
is_torch_less_than_1_8 = version.parse(torch.__version__) < version.parse("1.8.0")
is_torch_less_than_1_11 = version.parse(torch.__version__) < version.parse("1.11")
Comment on lines +24 to +25
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The torch version cannot change during runtime, so this is harmless



def torch_int_div(tensor1, tensor2):
    """
    Perform integer (floor) division in a way that works across PyTorch versions.

    torch >= 1.8 provides `torch.div(..., rounding_mode="floor")` as the
    supported spelling; on older versions the `//` operator is used instead.

    Args:
        tensor1: dividend tensor (or number broadcastable against `tensor2`).
        tensor2: divisor tensor (or number).

    Returns:
        The element-wise floor division of `tensor1` by `tensor2`.
    """
    # Uses the module-level cached flag rather than re-parsing the torch
    # version on every call.
    if is_torch_less_than_1_8:
        return tensor1 // tensor2
    return torch.div(tensor1, tensor2, rounding_mode="floor")


def softmax_backward_data(parent, grad_output, output, dim, self):
    """
    Call the internal `_softmax_backward_data` PyTorch method, adjusting the
    arguments according to the detected torch version.

    In torch < 1.11 the internal function takes the tensor itself as its last
    positional argument; from 1.11 onwards it expects the tensor's `dtype`.

    Args:
        parent: the autograd `Function` context; only its `dim` attribute is read.
        grad_output: gradient flowing back from the layer above.
        output: the softmax output saved during the forward pass.
        dim: softmax dimension. NOTE(review): unused — `parent.dim` is what is
            forwarded; confirm whether callers rely on this parameter at all.
        self: named after the corresponding parameter of the PyTorch internal
            function, whose signature this wrapper mirrors.

    Returns:
        The gradient of the softmax input.
    """
    if is_torch_less_than_1_11:
        return _softmax_backward_data(grad_output, output, parent.dim, self)
    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)