Feed forward chunking #6024

Merged: 14 commits, Aug 11, 2020
src/transformers/configuration_reformer.py (0 additions, 7 deletions)
@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of 0 means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
eos_token_id (:obj:`int`, optional, defaults to 2):
The token id for the <EOS> token.
feed_forward_size (:obj:`int`, optional, defaults to 512):
@@ -147,7 +142,6 @@ def __init__(
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
chunk_size_feed_forward=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
@@ -202,5 +196,4 @@ def __init__(
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.chunk_size_feed_forward = chunk_size_feed_forward
self.attn_layers = attn_layers
src/transformers/configuration_utils.py (6 additions, 0 deletions; file mode 100644 → 100755)
@@ -64,6 +64,11 @@ class PretrainedConfig(object):
2.
xla_device (:obj:`bool`, `optional`):
A flag to indicate if TPU are available or not.
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
The chunk size of all feed forward layers in the residual attention blocks.
A chunk size of :obj:`0` means that the feed forward layer is not chunked.
A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .

Parameters for sequence generation
- **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
@@ -160,6 +165,7 @@ def __init__(self, **kwargs):
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
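The added docstring describes feed forward chunking only in prose. The sketch below is a simplified illustration of the behaviour it refers to, not the actual apply_chunking_to_forward implementation in modeling_utils.py: the feed forward block is applied to chunk_size positions of the sequence at a time and the per-chunk outputs are concatenated, so the result is unchanged while the wide intermediate activations only exist for chunk_size positions at once.

import torch
from torch import nn


def chunked_forward_sketch(chunk_size, chunk_dim, forward_fn, *input_tensors):
    # chunk_size == 0 disables chunking: run the feed forward function in one pass
    if chunk_size == 0:
        return forward_fn(*input_tensors)
    # split each input into pieces of `chunk_size` along the sequence dimension,
    # apply the feed forward function to every piece, and concatenate the outputs
    num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
    input_chunks = tuple(t.chunk(num_chunks, dim=chunk_dim) for t in input_tensors)
    output_chunks = [forward_fn(*chunk) for chunk in zip(*input_chunks)]
    return torch.cat(output_chunks, dim=chunk_dim)


# Chunked and unchunked passes give the same output (this is what the new
# test_feed_forward_chunking test below checks at the model level).
feed_forward = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden_size)
assert torch.allclose(
    feed_forward(hidden_states),
    chunked_forward_sketch(1, 1, feed_forward, hidden_states),
    atol=1e-6,
)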
src/transformers/modeling_bert.py (17 additions, 3 deletions; file mode 100644 → 100755)
@@ -48,7 +48,12 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)


logger = logging.getLogger(__name__)
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""
try:
import re

import numpy as np
import tensorflow as tf
except ImportError:
@@ -376,6 +382,8 @@ def forward(self, hidden_states, input_tensor):
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
if self.is_decoder:
@@ -410,11 +418,17 @@ def forward(
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights

intermediate_output = self.intermediate(attention_output)
Review comment (Contributor):
To solve the problem, my suggestion would be to wrap these two calls in a forward_chunk function that is part of this class (def forward_chunk(self, ...)) and to call apply_chunking_to_forward(self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output) here.

Reply (Contributor Author):
I don't think I quite follow what you mean here. Which two calls do you want to wrap? Did you mean to have a forward_chunk function in the BertLayer class?

Reply (Contributor Author):
OK, I fixed it based on your input; it looks fine to me now.

Reply (Contributor):
Great, that's exactly what I meant :-)

layer_output = self.output(intermediate_output, attention_output)
layer_output = apply_chunking_to_forward(
self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
)
Review comment on lines +421 to +423 (Member):
Very much a nitpick here, probably for future PRs, but this looks a lot like the gradient checkpointing method from PyTorch. That method takes the callable (the forward method) as its first positional argument, and I think it makes sense to have it this way.

Reply (Contributor Author):
I can do this globally in the new PR where I add chunking for the other models. Let me know if you have any concerns with that.

A short sketch contrasting the two call conventions follows this file's diff.

outputs = (layer_output,) + outputs
return outputs

def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output


class BertEncoder(nn.Module):
def __init__(self, config):
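Regarding the nitpick above: PyTorch's gradient checkpointing helper, torch.utils.checkpoint.checkpoint, takes the callable as its first positional argument, while apply_chunking_to_forward as used in this PR takes it third. The snippet below only contrasts the two orderings; the reordered chunking call is hypothetical and not part of this diff.

import torch
from torch.utils.checkpoint import checkpoint

dense = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 8, 16, requires_grad=True)

# PyTorch gradient checkpointing: callable first, then its inputs
output = checkpoint(dense, hidden_states)

# Call introduced in this PR: chunk size and dimension first, callable third
# layer_output = apply_chunking_to_forward(
#     self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
# )

# Hypothetical callable-first variant following the checkpoint convention
# layer_output = apply_chunking_to_forward(
#     self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
# )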
tests/test_modeling_bert.py (1 addition, 0 deletions)
@@ -368,6 +368,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_chunking = True

def setUp(self):
self.model_tester = BertModelTester(self)
tests/test_modeling_common.py (24 additions, 0 deletions)
@@ -60,6 +60,7 @@ class ModelTesterMixin:
test_resize_embeddings = True
test_head_masking = True
test_missing_keys = True
test_chunking = False
is_encoder_decoder = False

def _prepare_for_class(self, inputs_dict, model_class):
@@ -519,6 +520,29 @@ def check_hidden_states_output(inputs_dict, config, model_class):

check_hidden_states_output(inputs_dict, config, model_class)

def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_chunking:
return

for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()

hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))

def test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
tests/test_modeling_reformer.py (2 additions, 22 deletions)
@@ -291,24 +291,6 @@ def create_and_check_reformer_layer_dropout_seed(
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
)

def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]

config.chunk_size_lm_head = 1
config.chunk_size_feed_forward = 1

torch.manual_seed(0)
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()

hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))

def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
@@ -517,10 +499,6 @@ def test_reformer_layer_training_dropout(self):
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=True)
self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, is_decoder=False)

def test_reformer_chunking_forward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)

def test_reformer_chunking_backward_equality(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs)
@@ -577,6 +555,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True

def prepare_kwargs(self):
return {
@@ -637,6 +616,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
test_pruning = False
test_headmasking = False
test_torchscript = False
test_chunking = True

def prepare_kwargs(self):
return {