Feed forward chunking others #6365

Merged
File mode changed (no content changes): src/transformers/configuration_reformer.py
100644 → 100755
1 change: 1 addition & 0 deletions src/transformers/configuration_utils.py
@@ -188,6 +188,7 @@ def __init__(self, **kwargs):
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
Contributor: great, thanks for adding this!

Contributor: Can you move the docstring from Reformer to this file and delete the corresponding docstring / config variable from reformer?

Contributor: actually it's already done - never mind


# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
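For context, a minimal sketch of how the new base-config attribute is meant to be used once this PR is merged; the model class and chunk size below are illustrative choices, not taken from the diff:

from transformers import BertConfig, BertForMaskedLM

# chunk_size_feed_forward now lives on the base PretrainedConfig, so any model
# config accepts it as a kwarg; 0 (the default) disables chunking entirely.
config = BertConfig(chunk_size_feed_forward=64)
model = BertForMaskedLM(config)
# Each layer's feed-forward block now processes the sequence 64 positions at a
# time, lowering peak activation memory at the cost of a Python-level loop.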
19 changes: 14 additions & 5 deletions src/transformers/modeling_albert.py
100644 → 100755
@@ -43,7 +43,7 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices


logger = logging.getLogger(__name__)
@@ -69,6 +69,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model."""
try:
import re

import numpy as np
import tensorflow as tf
except ImportError:
@@ -286,6 +287,8 @@ def __init__(self, config):
super().__init__()

self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = AlbertAttention(config)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -297,14 +300,20 @@ def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
ffn_output = self.ffn(attention_output[0])
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
ffn_output = self.dropout(ffn_output)

ffn_output = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

return (hidden_states,) + attention_output[1:] # add attentions if we output them

def ff_chunk(self, attention_output):
ffn_output = self.ffn(attention_output)
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
return ffn_output


class AlbertLayerGroup(nn.Module):
def __init__(self, config):
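The same refactoring pattern repeats for DistilBERT, Longformer, XLM and XLNet below: the position-wise feed-forward body is moved into a small ff_chunk helper, and the call site is wrapped in apply_chunking_to_forward. A generic sketch of that pattern follows; the module and attribute names are illustrative, not any specific model in this diff:

import torch.nn as nn
from transformers.modeling_utils import apply_chunking_to_forward

class ChunkedFeedForward(nn.Module):
    def __init__(self, hidden_size, intermediate_size, chunk_size_feed_forward=0):
        super().__init__()
        self.chunk_size_feed_forward = chunk_size_feed_forward
        self.seq_len_dim = 1  # chunk along the sequence-length dimension
        self.dense_in = nn.Linear(hidden_size, intermediate_size)
        self.dense_out = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, hidden_states):
        # Attention is left untouched; only the feed-forward is chunked.
        return apply_chunking_to_forward(
            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
        )

    def ff_chunk(self, hidden_states):
        return self.dense_out(self.act(self.dense_in(hidden_states)))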
2 changes: 1 addition & 1 deletion src/transformers/modeling_bert.py
@@ -424,7 +424,7 @@ def forward(
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights

layer_output = apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
12 changes: 11 additions & 1 deletion src/transformers/modeling_distilbert.py
100644 → 100755
@@ -44,7 +44,12 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)


logger = logging.getLogger(__name__)
@@ -208,6 +213,8 @@ class FFN(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = nn.Dropout(p=config.dropout)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
@@ -216,6 +223,9 @@ def __init__(self, config):
self.activation = gelu if config.activation == "gelu" else nn.ReLU()

def forward(self, input):
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

def ff_chunk(self, input):
x = self.lin1(input)
x = self.activation(x)
x = self.lin2(x)
19 changes: 16 additions & 3 deletions src/transformers/modeling_longformer.py
100644 → 100755
@@ -41,7 +41,12 @@
TokenClassifierOutput,
)
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)


logger = logging.getLogger(__name__)
@@ -685,6 +690,8 @@ def __init__(self, config, layer_id=0):
self.attention = LongformerAttention(config, layer_id)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1

def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
@@ -693,11 +700,17 @@ def forward(
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights

intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
layer_output = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
)
outputs = (layer_output,) + outputs
return outputs

def ff_chunk(self, attn_output):
intermediate_output = self.intermediate(attn_output)
layer_output = self.output(intermediate_output, attn_output)
return layer_output


class LongformerEncoder(nn.Module):
def __init__(self, config):
4 changes: 2 additions & 2 deletions src/transformers/modeling_reformer.py
100644 → 100755
@@ -1369,7 +1369,7 @@ def __init__(self, config):

def forward(self, attention_output):
return apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
)

def forward_chunk(self, hidden_states):
@@ -1730,7 +1730,7 @@ def __init__(self, config):
self.decoder.bias = self.bias

def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)

def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
100644 → 100755
@@ -1447,7 +1447,7 @@ def prune_layer(


def apply_chunking_to_forward(
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
Contributor: Thanks for changing that. @LysandreJik - as you said, this is the better order of the arguments and should be fine in terms of backward compatibility.

) -> torch.Tensor:
"""
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
@@ -1457,12 +1457,12 @@ def apply_chunking_to_forward(
directly applying :obj:`forward_fn` to :obj:`input_tensors`.

Args:
forward_fn (:obj:`Callable[..., torch.Tensor]`):
The forward function of the model.
chunk_size (:obj:`int`):
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
chunk_dim (:obj:`int`):
The dimension over which the :obj:`input_tensors` should be chunked.
forward_fn (:obj:`Callable[..., torch.Tensor]`):
The forward function of the model.
input_tensors (:obj:`Tuple[torch.Tensor]`):
The input tensors of ``forward_fn`` which will be chunked.
Returns:
@@ -1478,7 +1478,7 @@ def forward_chunk(self, hidden_states):

# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
"""

assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
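For readers new to the helper, a stripped-down sketch of what apply_chunking_to_forward does with the reordered arguments; the real implementation in modeling_utils.py additionally validates the chunk size, tensor shapes, and the number of arguments forward_fn accepts:

import torch
from typing import Callable

def chunked_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    # chunk_size == 0 means "no chunking": call the forward function directly.
    if chunk_size == 0:
        return forward_fn(*input_tensors)
    # Split each input into equal chunks along chunk_dim (assumes the dimension
    # is divisible by chunk_size), run the feed-forward on every chunk
    # independently, and concatenate the results back together.
    num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
    chunks = tuple(tensor.chunk(num_chunks, dim=chunk_dim) for tensor in input_tensors)
    outputs = tuple(forward_fn(*chunk_args) for chunk_args in zip(*chunks))
    return torch.cat(outputs, dim=chunk_dim)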
6 changes: 6 additions & 0 deletions src/transformers/modeling_xlm.py
100644 → 100755
@@ -50,6 +50,7 @@
PreTrainedModel,
SequenceSummary,
SQuADHead,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
@@ -212,8 +213,13 @@ def __init__(self, in_dim, dim_hidden, out_dim, config):
self.lin1 = nn.Linear(in_dim, dim_hidden)
self.lin2 = nn.Linear(dim_hidden, out_dim)
self.act = gelu if config.gelu_activation else F.relu
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1

def forward(self, input):
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

def ff_chunk(self, input):
x = self.lin1(input)
x = self.act(x)
x = self.lin2(x)
21 changes: 18 additions & 3 deletions src/transformers/modeling_xlnet.py
100644 → 100755
@@ -35,7 +35,14 @@
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
from .modeling_utils import (
PoolerAnswerClass,
PoolerEndLogits,
PoolerStartLogits,
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
)


logger = logging.getLogger(__name__)
@@ -495,6 +502,8 @@ def __init__(self, config):
self.rel_attn = XLNetRelativeAttention(config)
self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1

def forward(
self,
@@ -524,12 +533,18 @@ def forward(
output_h, output_g = outputs[:2]

if output_g is not None:
output_g = self.ff(output_g)
output_h = self.ff(output_h)
output_g = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
)
output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)

outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
return outputs

def ff_chunk(self, output_x):
output_x = self.ff(output_x)
return output_x


class XLNetPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
7 changes: 3 additions & 4 deletions tests/test_modeling_bert.py
100644 → 100755
@@ -26,15 +26,15 @@
if is_torch_available():
from transformers import (
BertConfig,
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
BertLMHeadModel,
BertModel,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -370,7 +370,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True

def setUp(self):
self.model_tester = BertModelTester(self)
8 changes: 2 additions & 6 deletions tests/test_modeling_common.py
100644 → 100755
@@ -25,15 +25,15 @@


if is_torch_available():
import torch
import numpy as np
import torch

from transformers import (
AdaptiveEmbedding,
PretrainedConfig,
PreTrainedModel,
BertModel,
BertConfig,
BertModel,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -65,7 +65,6 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False

def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -552,9 +551,6 @@ def check_hidden_states_output(inputs_dict, config, model_class):

def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return

for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
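The body of test_feed_forward_chunking is truncated above; the property such a test checks is that chunking is purely a memory optimization. A hedged sketch of that equivalence check (the helper name, tolerance, and exact output indexing are assumptions, not the actual test body):

import copy
import torch

def check_feed_forward_chunking(model_class, config, inputs_dict):
    # Baseline: no chunking (chunk_size_feed_forward defaults to 0).
    torch.manual_seed(0)
    model = model_class(copy.deepcopy(config)).eval()
    baseline = model(**inputs_dict)[0]

    # Same weights (same seed), but the feed-forward applied one position at a time.
    torch.manual_seed(0)
    chunked_config = copy.deepcopy(config)
    chunked_config.chunk_size_feed_forward = 1
    chunked_model = model_class(chunked_config).eval()
    chunked = chunked_model(**inputs_dict)[0]

    # Outputs must match: chunking changes memory use, not the math.
    assert torch.allclose(baseline, chunked, atol=1e-3)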
2 changes: 0 additions & 2 deletions tests/test_modeling_reformer.py
@@ -555,7 +555,6 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True

def prepare_kwargs(self):
return {
@@ -616,7 +615,6 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True

def prepare_kwargs(self):
return {