Add new model RoFormer (use rotary position embedding) #11684

Merged
LysandreJik merged 56 commits into huggingface:master from JunnYu:master
May 20, 2021
Conversation

@JunnYu (Contributor) commented May 11, 2021

What does this PR do?

Add new model RoFormer

RoFormer: Enhanced Transformer with Rotary Position Embedding, by Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu.
The original code can be found here.

The abstract from the paper is the following:

Position encoding in the transformer architecture provides supervision for dependency modeling between elements at
different positions in the sequence. We investigate various methods to encode positional information in
transformer-based language models and propose a novel implementation named Rotary Position Embedding (RoPE). The
proposed RoPE encodes absolute positional information with a rotation matrix and naturally incorporates explicit relative
position dependency in the self-attention formulation. Notably, RoPE comes with valuable properties such as the flexibility
of being expanded to any sequence length, decaying inter-token dependency with increasing relative distances, and the
capability of equipping linear self-attention with relative position encoding. As a result, the enhanced
transformer with rotary position embedding, or RoFormer, achieves superior performance on tasks with long texts. We
release the theoretical analysis along with some preliminary experiment results on Chinese data. The ongoing
experiments on English benchmarks will be updated soon.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patil-suraj (Contributor) left a comment

Great addition, thanks a lot for adding this!

I left a few comments below, specifically,

  • please add as many copied-from statements as possible.
  • it would be nice to refactor the sinusoidal embeddings into their own module

JunnYu and others added 4 commits May 12, 2021 16:33
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@JunnYu (Contributor, Author) commented May 12, 2021

@patil-suraj I have updated the code, please review again. Thanks~

@JunnYu JunnYu requested a review from patil-suraj May 12, 2021 15:40
@patil-suraj (Contributor) left a comment

Thanks for working on this, the PR is in a good shape!

I left a few more comments below, specifically

  • the docstring format should be fixed; this will make the build_doc tests pass
  • resolve the merge conflicts
  • and fix the style issues; we can do this by running make style and make quality

Let me know if you need any help with this :)

TFRobertaModel,
)

# Add modeling imports here
Contributor: This shouldn't be here



# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
Contributor: Nice!
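The copied-from Marian class builds a fixed sinusoidal table. As a rough numpy sketch of the idea (not the library's implementation; the function name is illustrative, and the half-split sin/cos layout is an assumption chosen to match the `chunk(2, dim=-1)` split used by the rotary code later in this PR):

```python
import numpy as np

def sinusoidal_position_table(num_positions, dim):
    """Illustrative sketch: fixed sin/cos position table, one frequency per
    feature pair. Layout assumption: first half sin, second half cos."""
    positions = np.arange(num_positions)[:, None]             # (P, 1)
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))  # (dim // 2,)
    angles = positions * inv_freq                             # (P, dim // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
```

Because the table depends only on position and dimension, it can be precomputed once and needs no gradients, which is why it can live in an `nn.Embedding` with frozen weights.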

Comment on lines +274 to +290
# sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
# sin [θ0,θ1,θ2......θd/2-1]-> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
# cos [θ0,θ1,θ2......θd/2-1]-> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
rotate_half_query_layer = torch.stack(
    [-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1
).reshape_as(query_layer)
query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(
    key_layer
)
key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
Contributor: very cool!
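The rotate-half trick in the excerpt above can be illustrated outside the model in plain numpy (a minimal sketch under the same first-half-sin / second-half-cos layout; `apply_rotary` is a hypothetical name, not the PR's API):

```python
import numpy as np

def apply_rotary(sinusoidal_pos, layer):
    """Rotate each even/odd feature pair of `layer` by its position-dependent
    angle. `sinusoidal_pos` holds sin in the first half, cos in the second."""
    sin, cos = np.split(sinusoidal_pos, 2, axis=-1)
    sin_pos = np.repeat(sin, 2, axis=-1)  # [s0, s0, s1, s1, ...]
    cos_pos = np.repeat(cos, 2, axis=-1)  # [c0, c0, c1, c1, ...]
    # rotate-half: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    rotated = np.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape)
    return layer * cos_pos + rotated * sin_pos
```

Each pair (x2i, x2i+1) is rotated by its own position-dependent angle, so dot products between rotated queries and keys depend only on relative position.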

@JunnYu (Contributor, Author) commented May 13, 2021

@patil-suraj

  • I fixed the docstrings format and the build_doc tests pass
  • I have resolved the merge conflicts
  • I have run make style and make quality

Thank you for reviewing this PR. ∩▂∩

@JunnYu JunnYu requested a review from patil-suraj May 13, 2021 09:29
@patil-suraj (Contributor) left a comment

Very cool, that was really quick @JunnYu !

It would be nice to add the fast tokenizer as well. Other than that, looks good to me!

@LysandreJik the run_tests_torch job seems to be timing out, not sure why; could you please take a look?

The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
Contributor: Can this be set to 2048? Would the model work with that?

Contributor (Author): I have set it to 1536.
The pre-training was done in multiple stages, with a changing batch size and maximum input sequence length in
order to adapt the model to various scenarios.

stage   max_len   batch_size   train_step   loss   acc
1       512       256          200k         1.73   65.0%
2       1536      256          12.5k        1.61   66.8%
3       256       256          120k         1.75   64.6%
4       128       512          80k          1.83   63.4%
5       1536      256          10k          1.58   67.4%
6       512       512          30k          1.66   66.2%

Contributor (Author): And the roformer_chinese_base model is a word-level model, not a char-level model.

                     BERT   WoBERT   NEZHA   RoFormer
tokenization level   char   word     char    word

return "".join(output)


class RoFormerTokenizer(BertTokenizer):
Contributor: Could we try not to inherit from BertTokenizer and instead just copy and paste all the functionality here?

Member: +1

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
if sinusoidal_pos is not None:
# https://kexue.fm/archives/8265
Contributor: Could we maybe put all of this logic into a staticmethod that takes sinusoidal_pos, query_layer, and key_layer as input? This would make it a bit more readable IMO and, more importantly, would allow us to easily test this layer.

@LysandreJik (Member) left a comment

Very nice work! It's in very good shape, thanks @JunnYu!

One important aspect is the ability to save/load tokenizers. Once this issue is solved, it looks good to merge for me!

return "".join(output)


class RoFormerTokenizer(BertTokenizer):
Member: +1

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
try:
import jieba
Member: This is the second occurrence of an import jieba in the library; I guess the third time we'll refactor that with dummy objects @sgugger?

Comment on lines +86 to +99
def test_added_token_serializable(self):
pass

def test_save_pretrained(self):
pass

def test_pickle_tokenizer(self):
pass

def test_save_and_load_tokenizer(self):
pass

def test_encode_decode_with_spaces(self):
pass
Member: This is a bit of an issue, serializing a tokenizer is important! I guess one option to enable this is to delete self.jieba from the tokenizer before serializing it, as it is stateless?

Comment on lines +392 to +394
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
Member: Is this test supported? I don't think there's a position_embedding_type in the model.

Contributor (Author): @LysandreJik @sgugger RoFormerTokenizer is not identical to BertTokenizer when tokenizing Chinese characters, so we need to create a new file for RoFormerTokenizer.

import jieba
from transformers import BertTokenizer, RoFormerTokenizer

zh_text = "今天天气非常好!"
jieba.lcut(zh_text, HMM=False)
# ['今天天气', '非常', '好', '!']

# difference
bert_tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_chinese_base")
roformer_tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
print(bert_tokenizer.tokenize(zh_text))
# ['今', '天', '天', '气', '非', '常', '好', '!']
print(roformer_tokenizer.tokenize(zh_text))
# ['今', '天', '天', '气', '非常', '好', '!']

# same
en_text = "I love Beijing, Beijing is the capital of China!"
print(bert_tokenizer.tokenize(en_text))
# ['i', 'love', 'be', '##i', '##jing', ',', 'be', '##i', '##jing', 'is', 'the', 'capital', 'of', 'china', '!']
print(roformer_tokenizer.tokenize(en_text))
# (same output as above)
In this case, 今天天气 and 非常 are meaningful words, so we shouldn't split them into single characters.
今天天气 does not exist in the vocabulary, so we use the original BertTokenizer to tokenize it and get ['今', '天', '天', '气'].
非常 does exist in the vocabulary, so we keep it whole and do not run it through the original BertTokenizer.
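The fallback scheme just described (keep in-vocab words whole, split out-of-vocab words with the char-level tokenizer) can be sketched as follows; `word_level_tokenize` and its parameters are illustrative stand-ins, not the actual RoFormerTokenizer API:

```python
def word_level_tokenize(text, vocab, segment, char_tokenize):
    """segment: a word segmenter such as jieba.lcut; char_tokenize: the
    original char-level (BERT-style) tokenizer used as a fallback."""
    tokens = []
    for word in segment(text):
        if word in vocab:
            tokens.append(word)                 # in-vocab word: keep whole
        else:
            tokens.extend(char_tokenize(word))  # OOV word: char-level fallback
    return tokens
```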

@JunnYu JunnYu requested a review from LysandreJik May 15, 2021 06:34
@LysandreJik (Member) left a comment

Minus the pickle issue, I think this looks great! Thanks for the efforts @JunnYu!

@patrickvonplaten and @sgugger could you take an additional look at this? Thank you

def test_alignement_methods(self):
pass

def test_pickle_tokenizer(self):
Member: If I understand correctly, you're doing something custom for this test because the PreTokenizer cannot be pickled.

This seems to put quite a large burden on the user were they to want to pickle their tokenizer. Wouldn't it be better to leverage __getstate__ and __setstate__, similar to what we do for SentencePiece-based tokenizers? See an example here: https://github.com/huggingface/transformers/blob/master/src/transformers/models/albert/tokenization_albert.py#L182-L195
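The pattern used by the linked Albert tokenizer can be sketched like this (a minimal illustration with hypothetical names; the real tokenizer would re-import jieba in `__setstate__` rather than the stand-in segmenter used here):

```python
import pickle

class WordSegmenterTokenizer:
    """Sketch: drop the unpicklable word segmenter before pickling and
    rebuild it after unpickling, since it carries no state worth saving."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.segmenter = self._load_segmenter()

    def _load_segmenter(self):
        # Stand-in for `import jieba`; imagine any unpicklable object here.
        return lambda text: list(text)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["segmenter"]  # stateless, safe to drop
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        self.segmenter = self._load_segmenter()
```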

Contributor (Author): Thanks, I have changed this. :)

@sgugger (Collaborator) left a comment

Looking great apart from the issue with pickle! Thanks for adding the fast tokenizer!

from typing import List, Optional, Tuple

from ...utils import logging
from ..bert.tokenization_bert import BasicTokenizer, PreTrainedTokenizer, WordpieceTokenizer, load_vocab
Collaborator: PreTrainedTokenizer should be imported from ...tokenization_utils

return outputs

@staticmethod
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
Contributor: Great to factor it out here!

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>


@require_torch
class RoFormerSelfAttentionRotaryPositionEmbeddingTest(unittest.TestCase):
Contributor: Awesome!

@patrickvonplaten (Contributor) left a comment

Looks very clean now! Great job @JunnYu

You might have to run

make style

once to solve the check quality test ;-)

@JunnYu (Contributor, Author) commented May 20, 2021

@patrickvonplaten I have done it, thanks ;)

@patrickvonplaten (Contributor):

Tests are fine I think (PyTorch times out :-/).
Good to merge for me

@LysandreJik LysandreJik merged commit 206f06f into huggingface:master May 20, 2021
@LysandreJik (Member):

Thanks a lot @JunnYu, fantastic addition!

Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
…11684)

* add roformer

* Update docs/source/model_doc/roformer.rst

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update docs/source/model_doc/roformer.rst

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* update

* add TFRoFormerSinusoidalPositionalEmbedding and fix TFMarianSinusoidalPositionalEmbedding

* update docs

* make style and make quality

* roback

* unchanged

* rm copies from , this is a error in TFMarianSinusoidalPositionalEmbedding

* update Copyright year

* move # Add modeling imports here to the correct position

* max_position_embeddings can be set to 1536

* # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer

* # Copied from transformers.models.bert.modeling_bert.BertLayer.__init__ with Bert->RoFormer

* update tokenization_roformer

* make style

* add staticmethod apply_rotary_position_embeddings

* add TF staticmethod apply_rotary_position_embeddings

* update torch apply_rotary_position_embeddings

* fix tf apply_rotary_position_embeddings error

* make style

* add pytorch RoFormerSelfAttentionRotaryPositionEmbeddingTest

* add TF rotary_position_embeddings test

* update test_modeling_rofomer

* Update docs/source/model_doc/roformer.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/roformer/modeling_roformer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/roformer/modeling_roformer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/roformer/modeling_tf_roformer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refact roformer tokenizer

* add RoFormerTokenizerFast

* add RoFormerTokenizationTest

* add require_jieba

* update Copyright

* update tokenizer & add copy from

* add option rotary_value

* use rust jieba

* use rjieba

* use rust jieba

* fix test_alignement_methods

* slice normalized_string is too slow

* add config.embedding_size when embedding_size!=hidden_size

* fix pickle tokenizer

* Update docs/source/model_doc/roformer.rst

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style and make quality

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>


5 participants