Revert "Added support for other features for already supported models (
lewtun committed Dec 8, 2021
1 parent 0c70f14 commit 0f4e39c
Showing 17 changed files with 384 additions and 931 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/albert/configuration_albert.py
@@ -167,3 +167,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
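The axis mappings restored above mark which tensor dimensions stay dynamic in the exported graph. As a rough illustration only (not part of this commit, and bypassing the transformers.onnx export path), the same mappings could be passed straight to torch.onnx.export; the checkpoint name and output path below are placeholders.

# Illustrative sketch, not from this commit: wiring the dynamic-axis mappings above
# into a plain torch.onnx.export call. "albert-base-v2" and "albert.onnx" are placeholders.
from collections import OrderedDict

import torch
from transformers import AlbertModel, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
# torchscript=True makes the model return plain tuples, which traces more cleanly
model = AlbertModel.from_pretrained("albert-base-v2", torchscript=True).eval()

inputs = OrderedDict(
    [
        ("input_ids", {0: "batch", 1: "sequence"}),
        ("attention_mask", {0: "batch", 1: "sequence"}),
        ("token_type_ids", {0: "batch", 1: "sequence"}),
    ]
)
outputs = OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])

dummy = tokenizer("Hello world", return_tensors="pt")
torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"], dummy["token_type_ids"]),
    "albert.onnx",
    input_names=list(inputs),
    output_names=list(outputs),
    dynamic_axes={**inputs, **outputs},  # batch/sequence axes remain symbolic
    opset_version=12,
)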
184 changes: 19 additions & 165 deletions src/transformers/models/bart/configuration_bart.py
@@ -15,12 +15,10 @@
""" BART model configuration """
import warnings
from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Mapping

-from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
-from ...file_utils import TensorType, is_torch_available
-from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx import OnnxConfigWithPast
from ...utils import logging


@@ -182,174 +180,30 @@ def __init__(
        )


-class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
+class BartOnnxConfig(OnnxConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        if self.task in ["default", "seq2seq-lm"]:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-
-            if self.use_past:
-                common_inputs["decoder_input_ids"] = {0: "batch"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
-            else:
-                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
-                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
-
-            if self.use_past:
-                self.fill_with_past_key_values_(common_inputs, direction="inputs")
-        elif self.task == "causal-lm":
-            # TODO: figure this case out.
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                ]
-            )
-            if self.use_past:
-                num_encoder_layers, _ = self.num_layers
-                for i in range(num_encoder_layers):
-                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
-        else:
-            common_inputs = OrderedDict(
-                [
-                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
-                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
-                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
-                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
-                ]
-            )
-
-        return common_inputs
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        if self.task in ["default", "seq2seq-lm"]:
-            common_outputs = super().outputs
-        else:
-            common_outputs = super(OnnxConfigWithPast, self).outputs
-            if self.use_past:
-                num_encoder_layers, _ = self.num_layers
-                for i in range(num_encoder_layers):
-                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
-                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
-        return common_outputs
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
-
-    def generate_dummy_inputs(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        batch_size: int = -1,
-        seq_length: int = -1,
-        is_pair: bool = False,
-        framework: Optional[TensorType] = None,
-    ) -> Mapping[str, Any]:
-
-        if self.task in ["default", "seq2seq-lm"]:
-            encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-                tokenizer, batch_size, seq_length, is_pair, framework
-            )
-
-            # Generate decoder inputs
-            decoder_seq_length = seq_length if not self.use_past else 1
-            decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-                tokenizer, batch_size, decoder_seq_length, is_pair, framework
-            )
-            decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
-            common_inputs = dict(**encoder_inputs, **decoder_inputs)
-
-            if self.use_past:
-                if not is_torch_available():
-                    raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
-                else:
-                    import torch
-                batch, encoder_seq_length = common_inputs["input_ids"].shape
-                decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
-                num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
-                encoder_shape = (
-                    batch,
-                    num_encoder_attention_heads,
-                    encoder_seq_length,
-                    self._config.hidden_size // num_encoder_attention_heads,
-                )
-                decoder_past_length = decoder_seq_length + 3
-                decoder_shape = (
-                    batch,
-                    num_decoder_attention_heads,
-                    decoder_past_length,
-                    self._config.hidden_size // num_decoder_attention_heads,
-                )
-
-                common_inputs["decoder_attention_mask"] = torch.cat(
-                    [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
-                )
-
-                common_inputs["past_key_values"] = []
-                # If the number of encoder and decoder layers are present in the model configuration, both are considered
-                num_encoder_layers, num_decoder_layers = self.num_layers
-                min_num_layers = min(num_encoder_layers, num_decoder_layers)
-                max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
-                remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
-
-                for _ in range(min_num_layers):
-                    common_inputs["past_key_values"].append(
-                        (
-                            torch.zeros(decoder_shape),
-                            torch.zeros(decoder_shape),
-                            torch.zeros(encoder_shape),
-                            torch.zeros(encoder_shape),
-                        )
-                    )
-
-                # TODO: test this.
-                shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
-                for _ in range(min_num_layers, max_num_layers):
-                    common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
-
-        elif self.task == "causal-lm":
-            common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-                tokenizer, batch_size, seq_length, is_pair, framework
-            )
-
-            if self.use_past:
-                if not is_torch_available():
-                    raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
-                else:
-                    import torch
-
-                batch, seqlen = common_inputs["input_ids"].shape
-                # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
-                num_encoder_layers, _ = self.num_layers
-                num_encoder_attention_heads, _ = self.num_attention_heads
-                past_shape = (
-                    batch,
-                    num_encoder_attention_heads,
-                    past_key_values_length,
-                    self._config.hidden_size // num_encoder_attention_heads,
-                )
-
-                common_inputs["attention_mask"] = torch.cat(
-                    [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-                )
-                common_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
-                ]
-        else:
-            common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-                tokenizer, batch_size, seq_length, is_pair, framework
-            )
-
-        return common_inputs
-
-    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
-        if self.task in ["default", "seq2seq-lm"]:
-            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
-        else:
-            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
-                flattened_output, name, idx, t
-            )
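For orientation, the deleted seq2seq branch above allocates the dummy cache as one 4-tuple per decoder layer: two tensors for decoder self-attention (past decoder length) and two for cross-attention (encoder length). A standalone sketch of just that shape logic, with hypothetical sizes standing in for values the config and tokenizer would normally provide:

# Illustrative sketch only: the past_key_values layout built by the deleted seq2seq branch.
# All sizes below are hypothetical stand-ins for tokenizer/config-derived values.
import torch

batch = 2
encoder_seq_length = 8            # encoder input length
decoder_seq_length = 1            # with use_past=True the decoder sees one new token
decoder_past_length = decoder_seq_length + 3
num_attention_heads = 16          # e.g. BartConfig.decoder_attention_heads
hidden_size = 1024                # e.g. BartConfig.d_model
head_dim = hidden_size // num_attention_heads
num_layers = 12

encoder_shape = (batch, num_attention_heads, encoder_seq_length, head_dim)
decoder_shape = (batch, num_attention_heads, decoder_past_length, head_dim)

# One entry per layer: (self-attn key, self-attn value, cross-attn key, cross-attn value)
past_key_values = [
    (
        torch.zeros(decoder_shape),  # decoder self-attention key
        torch.zeros(decoder_shape),  # decoder self-attention value
        torch.zeros(encoder_shape),  # cross-attention key
        torch.zeros(encoder_shape),  # cross-attention value
    )
    for _ in range(num_layers)
]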
4 changes: 4 additions & 0 deletions src/transformers/models/bert/configuration_bert.py
@@ -169,3 +169,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
4 changes: 4 additions & 0 deletions (file path not shown)
@@ -142,3 +142,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
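The restored outputs properties in these files name the tensors the exported graph is expected to expose. As an illustration only, assuming a model.onnx already exported with matching input/output names and an installed onnxruntime, the declared names could be checked against the runtime session like this:

# Illustrative sketch, not from this commit. Assumes a previously exported "model.onnx"
# whose inputs are input_ids/attention_mask and whose output is last_hidden_state.
import numpy as np
import onnxruntime as ort

declared_outputs = ["last_hidden_state"]  # from the outputs property above

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
graph_outputs = [out.name for out in session.get_outputs()]
assert all(name in graph_outputs for name in declared_outputs), graph_outputs

# Axes declared as {0: "batch", 1: "sequence"} accept arbitrary batch/sequence sizes.
feed = {
    "input_ids": np.ones((3, 7), dtype=np.int64),
    "attention_mask": np.ones((3, 7), dtype=np.int64),
}
(last_hidden_state,) = session.run(declared_outputs, feed)
print(last_hidden_state.shape)  # e.g. (3, 7, hidden_size)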
66 changes: 22 additions & 44 deletions src/transformers/models/gpt2/configuration_gpt2.py
@@ -15,12 +15,12 @@
# limitations under the License.
""" OpenAI GPT-2 configuration """
from collections import OrderedDict
-from typing import Any, List, Mapping, Optional
+from typing import Any, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available

from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...onnx import OnnxConfigWithPast
from ...utils import logging


@@ -194,36 +194,29 @@ def __init__(


class GPT2OnnxConfig(OnnxConfigWithPast):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        task: str = "default",
-        patching_specs: List[PatchingSpec] = None,
-        use_past: bool = False,
-    ):
-        super().__init__(config, task=task, patching_specs=patching_specs)
-        if not getattr(self._config, "pad_token_id", None):
-            # TODO: how to do that better?
-            self._config.pad_token_id = 0
-
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        common_inputs = OrderedDict({"input_ids": {0: "batch"}})
        if self.use_past:
-            self.fill_with_past_key_values_(common_inputs, direction="inputs")
-            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+            for i in range(self._config.n_layer * 2):
+                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
+
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
-    def num_layers(self) -> int:
-        return self._config.n_layer
-
-    @property
-    def num_attention_heads(self) -> int:
-        return self._config.n_head
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            for i in range(self._config.n_layer * 2):
+                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
+
+            return common_outputs
+
+        return common_outputs

    def generate_dummy_inputs(
        self,
@@ -233,9 +226,7 @@ def generate_dummy_inputs(
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
-        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
-        )
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # We need to order the input in the way they appears in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
@@ -247,27 +238,14 @@ def generate_dummy_inputs(
            else:
                import torch

-                batch, seqlen = common_inputs["input_ids"].shape
-                # Not using the same length for past_key_values
-                past_key_values_length = seqlen + 2
-                past_shape = (
-                    batch,
-                    self.num_attention_heads,
-                    past_key_values_length,
-                    self._config.hidden_size // self.num_attention_heads,
-                )
+                batch = common_inputs["input_ids"].shape[0]
                ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    (
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                    )
+                    for _ in range(self._config.n_layer)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
-        if self.use_past:
-            ordered_inputs["attention_mask"] = torch.cat(
-                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
-            )

        return ordered_inputs

-    @property
-    def default_onnx_opset(self) -> int:
-        return 13
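The restored GPT-2 code above always allocates a single past step per layer as a (key, value) pair of shape (batch, n_head, 1, hidden_size // n_head), whereas the deleted variant used a past length of seqlen + 2 and lengthened the attention mask to match. A standalone sketch of that decoder-only cache layout, with hypothetical sizes in place of config values:

# Illustrative sketch only: the decoder-only past_key_values layout used by the restored
# GPT-2 dummy inputs. Sizes are hypothetical stand-ins for GPT2Config values.
import torch

batch = 2
n_layer = 12                  # GPT2Config.n_layer
n_head = 12                   # GPT2Config.n_head
hidden_size = 768             # GPT2Config.n_embd, exposed as hidden_size
head_dim = hidden_size // n_head
past_length = 1               # the restored code always allocates one past step

# One (key, value) pair per layer, each of shape (batch, n_head, past_length, head_dim)
past_key_values = [
    (
        torch.zeros((batch, n_head, past_length, head_dim)),
        torch.zeros((batch, n_head, past_length, head_dim)),
    )
    for _ in range(n_layer)
]

# The deleted variant instead used past_length = sequence_length + 2 and concatenated
# torch.ones(batch, past_length) onto attention_mask so the mask covered past + current tokens.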
