Support various BERT relative position embeddings (2nd) #8276

Merged
merged 21 commits into from Nov 24, 2020
Changes from 15 commits
75 changes: 75 additions & 0 deletions examples/question-answering/README.md
@@ -159,6 +159,81 @@ Larger batch size may improve the performance while costing more memory.
}
```

#### Fine-tuning BERT on SQuAD1.0 with relative position embeddings

The following examples show how to fine-tune BERT models with different relative position embeddings. The BERT model
`bert-base-uncased` was pre-trained with the default absolute position embeddings. We provide the following pre-trained
models, which were pre-trained on the same data (BooksCorpus and English Wikipedia) as the original BERT model, but
with different relative position embeddings.

* `zhiheng-huang/bert-base-uncased-embedding-relative-key`, trained from scratch with the relative position embedding
proposed by Shaw et al. in [Self-Attention with Relative Position Representations](https://arxiv.org/abs/1803.02155)
* `zhiheng-huang/bert-base-uncased-embedding-relative-key-query`, trained from scratch with relative embedding method 4
of Huang et al., [Improve Transformer Models with Better Relative Position Embeddings](https://arxiv.org/abs/2009.13658)
* `zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query`, fine-tuned from
`bert-large-uncased-whole-word-masking` for 3 additional epochs with relative embedding method 4 of Huang et al.,
[Improve Transformer Models with Better Relative Position Embeddings](https://arxiv.org/abs/2009.13658)

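Before launching fine-tuning, it can be useful to check that a checkpoint really uses relative position embeddings by inspecting its config. The snippet below is a minimal sketch: it assumes the `transformers` library is installed and that the checkpoints above are available on the model hub, and the value shown in the comment is what the config is expected to contain.

```python
from transformers import AutoConfig, AutoModelForQuestionAnswering

# Inspect the config of one of the relative-position-embedding checkpoints listed above.
config = AutoConfig.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
print(config.position_embedding_type)  # expected: "relative_key_query"

# The checkpoint loads like any other BERT model, e.g. for extractive question answering.
model = AutoModelForQuestionAnswering.from_pretrained(
    "zhiheng-huang/bert-base-uncased-embedding-relative-key-query"
)
```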

##### Base models fine-tuning

```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-base-uncased-embedding-relative-key-query \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--per_gpu_eval_batch_size=60 \
--per_gpu_train_batch_size=6
```
Training with the above command leads to the following results, boosting the default BERT F1 score from 88.52 to 90.54.

```bash
'exact': 83.6802270577105, 'f1': 90.54772098174814
```

Changing `max_seq_length` from 512 to 384 in the above command yields an F1 score of 90.34. Replacing the model
`zhiheng-huang/bert-base-uncased-embedding-relative-key-query` with
`zhiheng-huang/bert-base-uncased-embedding-relative-key` yields an F1 score of 89.51. Training on a single GPU instead
of 8 GPUs yields an F1 score of 90.71.

##### Large models fine-tuning

```bash
export SQUAD_DIR=/path/to/SQUAD
output_dir=relative_squad
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_squad.py \
--model_type bert \
--model_name_or_path zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query \
--do_train \
--do_eval \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 512 \
--doc_stride 128 \
--output_dir ${output_dir} \
--per_gpu_eval_batch_size=6 \
--per_gpu_train_batch_size=2 \
--gradient_accumulation_steps 3
```
Training with the above command leads to an F1 score of 93.52, slightly better than the 93.15 obtained with
`bert-large-uncased-whole-word-masking`.

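Once fine-tuning has finished, the checkpoint saved in `${output_dir}` can be used for inference like any other BERT question-answering model. The snippet below is a minimal sketch assuming the model and tokenizer were saved to `relative_squad` as in the commands above; the question and context are made-up examples.

```python
from transformers import pipeline

# Load the fine-tuned checkpoint written by run_squad.py (assumed to live in ./relative_squad).
qa = pipeline("question-answering", model="relative_squad", tokenizer="relative_squad")

result = qa(
    question="Which corpora were used for pre-training?",
    context="The models were pre-trained on BooksCorpus and English Wikipedia, "
            "the same data as the original BERT model.",
)
print(result["answer"], result["score"])
```
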
## SQuAD with the Tensorflow Trainer

1 change: 0 additions & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -215,7 +215,6 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
9 changes: 9 additions & 0 deletions src/transformers/models/bert/configuration_bert.py
@@ -91,6 +91,13 @@ class BertConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.

Examples::

@@ -123,6 +130,7 @@ def __init__(
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -140,3 +148,4 @@ def __init__(
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
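As an illustration of the new configuration option, a BERT model can be built with relative position embeddings simply by setting `position_embedding_type` on the config. This is a minimal sketch with default BERT-base hyperparameters and randomly initialized weights, not a pre-trained model:

```python
from transformers import BertConfig, BertModel

# "relative_key" follows Shaw et al. (2018); "relative_key_query" follows method 4 of Huang et al. (2020).
config = BertConfig(position_embedding_type="relative_key_query")
model = BertModel(config)

print(config.position_embedding_type)   # relative_key_query
print(config.max_position_embeddings)   # 512 by default; bounds the learned relative distances
```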
28 changes: 26 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -178,6 +178,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type

def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
@@ -195,10 +196,12 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
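Note that with either relative variant the embedding layer above no longer injects any position information; positions enter the model only through the self-attention scores shown in the next hunk. The following sketch makes that visible on randomly initialized modules (illustrative only; the import path matches the file being modified here):

```python
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])

# With "absolute", word + position + token-type embeddings are summed before LayerNorm.
absolute = BertEmbeddings(BertConfig(position_embedding_type="absolute"))
# With "relative_key" / "relative_key_query", the position term is skipped here and is
# handled later inside self-attention via the distance embedding.
relative = BertEmbeddings(BertConfig(position_embedding_type="relative_key"))

print(absolute(input_ids).shape, relative(input_ids).shape)  # both torch.Size([1, 6, 768])
```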
@@ -222,6 +225,10 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -256,6 +263,23 @@ def forward(

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
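To make the relative-position term above easier to follow in isolation, the sketch below reproduces the score computation outside the model. For `"relative_key"` the score between query position l and key position r gains a term q_l · a_{l-r}, and `"relative_key_query"` additionally adds k_r · a_{l-r}, where a_{l-r} is the learned embedding of the signed distance l - r. Shapes and names mirror the diff; the embedding table is randomly initialized rather than taken from a trained checkpoint.

```python
import torch
from torch import nn

batch, heads, seq_length, head_size = 2, 12, 16, 64
max_position_embeddings = 512

query_layer = torch.randn(batch, heads, seq_length, head_size)
key_layer = torch.randn(batch, heads, seq_length, head_size)

# One embedding per signed distance in [-(max_len - 1), max_len - 1], as in the diff above.
distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, head_size)

# (seq_length, seq_length) matrix of signed distances between query position l and key position r.
position_ids_l = torch.arange(seq_length).view(-1, 1)
position_ids_r = torch.arange(seq_length).view(1, -1)
distance = position_ids_l - position_ids_r

# Shift distances into [0, 2 * max_len - 2] so they can index the embedding table.
positional_embedding = distance_embedding(distance + max_position_embeddings - 1)

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

# "relative_key" adds only the query term; "relative_key_query" adds both terms below.
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

print(attention_scores.shape)  # torch.Size([2, 12, 16, 16])
```

The same pattern is reused verbatim in the ELECTRA and LayoutLM self-attention modules further down in this diff.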
@@ -54,6 +54,13 @@ class BertGenerationConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.

Examples::

@@ -87,6 +94,7 @@ def __init__(
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -103,3 +111,4 @@ def __init__(
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
7 changes: 7 additions & 0 deletions src/transformers/models/electra/configuration_electra.py
@@ -97,6 +97,11 @@ class ElectraConfig(PretrainedConfig):
Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

The dropout ratio to be used after the projection and activation.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.

Examples::

@@ -133,6 +138,7 @@ def __init__(
summary_activation="gelu",
summary_last_dropout=0.1,
pad_token_id=0,
position_embedding_type="absolute",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -155,3 +161,4 @@ def __init__(
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
self.position_embedding_type = position_embedding_type
28 changes: 26 additions & 2 deletions src/transformers/models/electra/modeling_electra.py
@@ -165,6 +165,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -183,10 +184,12 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -211,6 +214,10 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -245,6 +252,23 @@ def forward(

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)
21 changes: 21 additions & 0 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -146,6 +146,10 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -180,6 +184,23 @@ def forward(

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
3 changes: 0 additions & 3 deletions src/transformers/models/longformer/modeling_longformer.py
@@ -446,7 +446,6 @@ class LongformerEmbeddings(nn.Module):
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -461,7 +460,6 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

# End copy
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
@@ -475,7 +473,6 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
if input_ids is not None:
input_shape = input_ids.size()
else: