
Add DeeBERT (entropy-based early exiting for *BERT) #5477

Merged
7 commits merged into huggingface:master on Jul 8, 2020

Conversation

@ji-xin (Contributor) commented on Jul 2, 2020

Add DeeBERT (entropy-based early exiting for *BERT).
Paper: https://www.aclweb.org/anthology/2020.acl-main.204/
Based on its original repository: https://github.com/castorini/DeeBERT
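For readers new to the approach, here is a minimal, self-contained sketch of the entropy-based exit rule the paper describes (the function and variable names are illustrative only, not the code added in this PR): each transformer layer feeds an extra "off-ramp" classifier, and inference stops at the first layer whose prediction entropy drops below a chosen threshold.

```python
import torch

def prediction_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Shannon entropy of the softmax distribution over classes."""
    probs = torch.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-12)).sum(dim=-1)

def early_exit_forward(hidden_states, layers, off_ramps, threshold):
    # Run one example (batch size 1) through the layers; stop at the first
    # off-ramp whose prediction entropy is below the threshold.
    # `layers` and `off_ramps` are illustrative stand-ins for the encoder
    # layers and per-layer classifiers, not the repository's actual API.
    logits = None
    for layer, off_ramp in zip(layers, off_ramps):
        hidden_states = layer(hidden_states)
        logits = off_ramp(hidden_states[:, 0])       # classify from the [CLS] position
        if prediction_entropy(logits).item() < threshold:
            return logits                            # confident enough: exit early
    return logits                                    # fell through to the last layer
```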

@codecov bot commented on Jul 3, 2020

Codecov Report

Merging #5477 into master will increase coverage by 0.41%.
The diff coverage is n/a.


```
@@            Coverage Diff             @@
##           master    #5477      +/-   ##
==========================================
+ Coverage   77.83%   78.25%   +0.41%
==========================================
  Files         141      141
  Lines       24634    24634
==========================================
+ Hits        19175    19278     +103
+ Misses       5459     5356     -103
```
| Impacted Files | Coverage Δ |
|---|---|
| src/transformers/modeling_tf_roberta.py | 43.47% <0.00%> (-49.57%) ⬇️ |
| src/transformers/modeling_openai.py | 81.09% <0.00%> (+1.37%) ⬆️ |
| src/transformers/generation_tf_utils.py | 86.71% <0.00%> (+1.50%) ⬆️ |
| src/transformers/modeling_tf_openai.py | 94.98% <0.00%> (+74.19%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 58cca47...f44de41.

@JetRunner self-requested a review on July 3, 2020 at 00:27
@JetRunner (Contributor) left a comment

Thanks for your contribution!

Some high-level comments:

  1. How many iterations does your test run? You may want to reduce max_iteration a little to make it faster.
  2. /examples/deebert/src/ seems better than /examples/deebert/deebert/.


## Installation

First, install [pytorch](https://pytorch.org/) and the [transformer library](https://github.com/huggingface/transformers/blob/master/README.md)
Contributor: This is not necessary.

publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-main.204",
pages = "2246--2251",
abstract = "Large-scale pre-trained language models such as BERT have brought significant improvements to NLP applications. However, they are also notorious for being slow in inference, which makes them difficult to deploy in real-time applications. We propose a simple but effective method, DeeBERT, to accelerate BERT inference. Our approach allows samples to exit earlier without passing through the entire model. Experiments show that DeeBERT is able to save up to {\textasciitilde}40{\%} inference time with minimal degradation in model quality. Further analyses show different behaviors in the BERT transformer layers and also reveal their redundancy. Our work provides new ideas to efficiently apply deep transformer-based models to downstream tasks. Code is available at https://github.com/castorini/DeeBERT.",
Contributor: You may want to exclude the abstract.

@@ -0,0 +1,5 @@
boto3
Contributor: Why do we need boto3?

Comment on lines 2 to 5
tensorboard
tensorboardX
scikit-learn
seqeval
@@ -0,0 +1,5 @@
boto3
tensorboard
tensorboardX
Contributor: tensorboardX can be replaced with PyTorch's built-in tensorboard, so there's no need for this dependency!
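For reference, the built-in writer in torch.utils.tensorboard (available since PyTorch 1.1) is essentially a drop-in replacement; the log directory and tag below are just placeholders:

```python
# Instead of: from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/deebert")        # placeholder directory
writer.add_scalar("train/loss", 0.42, global_step=1)  # same API as tensorboardX
writer.close()
```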

@stefan-it (Collaborator) commented on Jul 3, 2020
I think all of these are already specified in the top-level requirements from examples folder:

https://github.com/huggingface/transformers/blob/master/examples/requirements.txt

So there's no need to add it here (except boto3, but it may not be needed here anyway).

model = model_class.from_pretrained(checkpoint)
if args.model_type == "bert":
model.bert.encoder.set_early_exit_entropy(args.early_exit_entropy)
else:
Contributor: Same here.

Comment on lines 12 to 16
There are a few other packages to install. Assuming we are in `transformers/examples/deebert`, simply run

```
pip install -r requirements.txt
```
Contributor: See the comment for requirements.txt.


# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
Contributor: We now have a util function; you can use it to refactor this code:

extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
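A sketch of how that helper slots into a BERT-style forward pass inside the model class (the method skeleton and variable names are assumptions for illustration; `get_extended_attention_mask` is the existing `ModuleUtilsMixin` utility inherited via `PreTrainedModel`):

```python
import torch

def forward(self, input_ids, attention_mask=None, token_type_ids=None):
    input_shape = input_ids.size()
    device = input_ids.device
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, device=device)
    # One call replaces the manual 2D/3D branching and broadcasting: the helper
    # returns a mask shaped for all heads, with masked positions set to a large
    # negative value so they are ignored by the softmax.
    extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
    embedding_output = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
    return self.encoder(embedding_output, attention_mask=extended_attention_mask)
```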

# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
Contributor: We now have a util function; you can use it to refactor this code:

head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
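And similarly for the head mask; `get_head_mask` is the existing `ModuleUtilsMixin` helper, while the encoder loop below is an illustrative sketch rather than the PR's exact code:

```python
# Normalizes a [num_heads] or [num_hidden_layers, num_heads] mask to one entry
# per layer, and returns [None] * num_hidden_layers when head_mask is None.
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

for i, layer_module in enumerate(self.layer):
    layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
    hidden_states = layer_outputs[0]
```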

Comment on lines +32 to +29
#### train_deebert.sh

This is for fine-tuning DeeBERT models.

#### eval_deebert.sh

This is for evaluating each exit layer for fine-tuned DeeBERT models.

#### entropy_eval.sh

This is for evaluating fine-tuned DeeBERT models, given a number of different early exit entropy thresholds.
Contributor: You may want to elaborate a little on which variable is which in these scripts.

@stefan-it (Collaborator): Btw, it would be awesome to see a token classification example 😅

@ji-xin (Contributor, Author) commented on Jul 3, 2020

Hi @JetRunner, thanks for the review! I have updated the PR according to your suggestions.

@ji-xin (Contributor, Author) commented on Jul 3, 2020

Two checks fail, but they don't seem relevant to my commits.

@ji-xin requested a review from @JetRunner on July 3, 2020 at 19:48
@JetRunner (Contributor) left a comment

Thanks @ji-xin ! It looks much better now!

Please wait for the final approval of @LysandreJik

@LysandreJik (Member) left a comment

Cool, very clean!


class DeeBertEncoder(nn.Module):
def __init__(self, config):
super(DeeBertEncoder, self).__init__()
Member: (nit) We don't need to specify the class and instance here, since we're not Python 2 compatible.

Suggested change:
- super(DeeBertEncoder, self).__init__()
+ super().__init__()

Comment on lines 27 to 31
def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
Member: Those should already be defined in the DeeBertModel?

@ji-xin (Contributor, Author) commented on Jul 7, 2020

@LysandreJik Thanks for the comments; I've updated accordingly!

@JetRunner merged commit cfbb982 into huggingface:master on Jul 8, 2020