
GH-873: PyTorch-Transformers update #941

Merged · 6 commits merged into master from GH-873-pytorch-transformers on Aug 7, 2019
Conversation

stefan-it (Member) commented Aug 1, 2019

Hi,

this PR updates the old pytorch-pretrained-BERT library to the latest version of pytorch-transformers to support various new Transformer-based architectures for embeddings.

A total of 7 (new/updated) embeddings can be used in Flair now:

from flair.embeddings import (
    BertEmbeddings,
    OpenAIGPTEmbeddings,
    OpenAIGPT2Embeddings,
    TransformerXLEmbeddings,
    XLNetEmbeddings,
    XLMEmbeddings,
    RoBERTaEmbeddings,
)

bert_embeddings = BertEmbeddings()
gpt1_embeddings = OpenAIGPTEmbeddings()
gpt2_embeddings = OpenAIGPT2Embeddings()
txl_embeddings = TransformerXLEmbeddings()
xlnet_embeddings = XLNetEmbeddings()
xlm_embeddings = XLMEmbeddings()
roberta_embeddings = RoBERTaEmbeddings()
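
For reference, a minimal usage sketch showing how any of these embedding classes is then applied to a sentence (standard Flair usage; the sentence text is just an example):

from flair.data import Sentence

# create an example sentence (arbitrary text, for illustration only)
sentence = Sentence("Berlin is a beautiful city .")

# embed all tokens in the sentence with the BERT embeddings instantiated above
bert_embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.embedding.shape)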

Detailed benchmarks on the downsampled CoNLL-2003 NER dataset for English can be found in #873. This PR is the first working attempt to include various new Transformer-based embeddings.

Unit tests can be executed with pytest --runslow tests. The unit tests for the Transformer embeddings take ~4 minutes on a GPU.

stefan-it changed the title from "GH-873: PyTorch-Transformers update" to "WIP: GH-873: PyTorch-Transformers update" on Aug 2, 2019
The following Transformer-based architectures are now supported
via pytorch-transformers:

- BertEmbeddings (Updated API)
- OpenAIGPTEmbeddings (Updated API, various fixes)
- OpenAIGPT2Embeddings (New)
- TransformerXLEmbeddings (Updated API, tokenization fixes)
- XLNetEmbeddings (New)
- XLMEmbeddings (New)
- RoBERTaEmbeddings (New, via torch.hub module)

It is also possible to use a scalar mix of specified layers from the
Transformer-based models. Scalar mix was proposed by Liu et al. (2019).
The scalar mix implementation is copied and slightly modified from
the allennlp repo (Apache 2.0 license).
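
A small sketch of how the scalar mix option might be enabled when instantiating one of the embeddings; the parameter names layers and use_scalar_mix are assumptions based on this description and may differ from the final API:

from flair.embeddings import XLNetEmbeddings

# "layers" (assumed name): comma-separated indices of the Transformer layers to use
# "use_scalar_mix" (assumed name): enables the learned scalar mixture (Liu et al., 2019)
xlnet_embeddings = XLNetEmbeddings(layers="1,2,3,4", use_scalar_mix=True)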
stefan-it changed the title from "WIP: GH-873: PyTorch-Transformers update" to "GH-873: PyTorch-Transformers update" on Aug 4, 2019
try:
    self.model = torch.hub.load("pytorch/fairseq", model)
except:
    log_line(log)
Collaborator commented:
log_line needs to be imported, otherwise this fails.

from flair.training_utils import log_line
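
For clarity, a sketch of what the loading block looks like with the suggested import in place (assuming log is the module-level logger already used in flair/embeddings.py; catching Exception and re-raising are illustrative tweaks, not part of the original hunk):

from flair.training_utils import log_line

try:
    # load the RoBERTa model via the torch.hub fairseq entry point
    self.model = torch.hub.load("pytorch/fairseq", model)
except Exception:
    # log_line is now imported, so the error path no longer raises a NameError
    log_line(log)
    # re-raise so the failure is not silently swallowed (assumption; the original hunk is truncated here)
    raise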

stefan-it (Member, Author) replied:
Fixed :) I think I have to change my PyCharm default theme...

# Use the dict.copy() method to avoid modifying the original state.
state = self.__dict__.copy()
# Remove the unpicklable entries.
state["model"] = None
Collaborator commented:
I think this line does nothing, since "model" is part of "_modules". However, the saved model is still huge, which is strange because in __setstate__ the RoBERTa model is re-loaded from torch.hub.
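
For context, the idiom under discussion looks roughly like this; a sketch of the __getstate__/__setstate__ pattern on a plain class (for an nn.Module subclass the submodule lives in _modules, which is exactly the point above), not the actual Flair implementation:

import torch

class HubBackedEmbeddings:
    """Sketch: drop a torch.hub model before pickling and reload it on unpickling."""

    def __init__(self, model_name: str = "roberta.base"):
        self.model_name = model_name
        self.model = torch.hub.load("pytorch/fairseq", model_name)

    def __getstate__(self):
        # copy the dict so the live object is not modified
        state = self.__dict__.copy()
        # drop the heavy / unpicklable model before serialization
        state["model"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # re-load the model from torch.hub after unpickling
        self.model = torch.hub.load("pytorch/fairseq", self.model_name)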

stefan-it (Member, Author) commented Aug 7, 2019:
@alanakbik I compared the model sizes for BERT and RoBERTa:

BERT (base): 429MB
RoBERTa (base): 487MB

I will check the state["model"] now.

However, there's an upcoming PR in the PyTorch-Transformers repo that adds RoBERTa 🔥 So in the near future it won't be necessary to use the torch.hub wrapper here :)

Collaborator replied:
Ah ok, in this case don't spend too much time on this. It works already, so no need to fix something that will be fixed upstream :)

stefan-it (Member, Author) replied:
Great, I'll just leave it as it is now and will update the RoBERTaEmbeddings implementation whenever it is available in pytorch-transformers!

alanakbik (Collaborator):
👍

yosipk (Collaborator) commented Aug 7, 2019:
👍

yosipk merged commit 4db491d into master on Aug 7, 2019
yosipk deleted the GH-873-pytorch-transformers branch on August 7, 2019 at 09:35
alanakbik (Collaborator):
Awesome - thank you @stefan-it!!

Hellisotherpeople:
Yes I love this!!!!!

MarcioPorto mentioned this pull request on Aug 9, 2019
berfubuyukoz:
You are great! @stefan-it Thank you for your generosity!!!
