Does BART support more than 1024 tokens at inference for the summarization task? #1685
Comments
@JunhyunB Inference alone with a longer document won't work, because the summarization model was finetuned with a sequence length of 1024. What you can do is finetune the model with a longer seq_len on your custom training data; in fact, that is similar to what we did. For the above, you would need to adjust the positional embeddings by either:

1. randomly initializing the longer positional embedding matrix and learning it while finetuning, or
2. copying the pretrained positions into the start of the longer matrix ("copy 512 from pretrained bart to first 512 of your 2048 positional embedding").

I would recommend 2, but that might require slight code changes. (Let me know if you need some help with that.)
On the readme for CNN/DM, it says to use MAX_TOKENS=2048, but @ngoyal2707, you say it is 1024, and so does #1474. Is the readme incorrect?
You can reset the position embedding to a new length (e.g. 2048) and copy the 1024 trained positions from the model (the second half will be randomly initialized, while the first half is trained)... this is a common trick in summarization.
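A minimal sketch of this resize-and-copy trick applied to a saved checkpoint (not the authors' actual code; the checkpoint path, the 2048 target length, the init scale, and the choice to extend both encoder and decoder embeddings are assumptions):

```python
import torch

# Sketch: extend BART's learned positional embeddings from 1026 rows
# (2 offset rows + 1024 positions) to 2050 rows (2 + 2048), keeping the
# trained rows and leaving the new rows randomly initialized.
ckpt = torch.load("bart.large/model.pt", map_location="cpu")  # path is an assumption

for key in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]:
    old = ckpt["model"][key]                         # [1026, 1024] in pretrained bart.large
    new = torch.zeros(2050, old.size(1))
    torch.nn.init.normal_(new, mean=0.0, std=0.02)   # init scale is an assumption
    new = new.to(old.dtype)
    new[: old.size(0)] = old                         # first 1026 rows stay trained
    ckpt["model"][key] = new

torch.save(ckpt, "bart.large/model_2048.pt")
```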
Thanks for the response, yinhanliu. I wanted to know, though, which hyperparameter setting was used to get the best results when fine-tuning on CNN/DM: was it 1024 or 2048?

@loganlebanoff We only used 1024. Never tried 2048.
Change MAX_TOKENS=2048 --> 1024, as per yinhanliu in facebookresearch#1685
Thanks! I've created a pull request to fix the CNN/DM fine-tuning readme.
max_tokens, max_sentences, and tokens_per_sample are different args: max_sentences is the batch size (bsz), max_tokens is the maximum number of tokens allowed in a batch, and tokens_per_sample is the maximum sequence length of a single instance. The current readme instructions are correct (a toy illustration follows below).
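As a toy illustration of how these three limits differ (this is not fairseq's batching code; the numbers are arbitrary):

```python
# Illustration only: the three limits constrain a batch in different ways.
TOKENS_PER_SAMPLE = 1024   # --tokens-per-sample: max length of a single instance
MAX_TOKENS = 2048          # --max-tokens: max (padded) tokens in one batch
MAX_SENTENCES = 8          # --max-sentences: max number of instances per batch

def batch_is_legal(seq_lens):
    """Check whether sequences of the given lengths may share one batch."""
    return (all(n <= TOKENS_PER_SAMPLE for n in seq_lens)
            and len(seq_lens) <= MAX_SENTENCES
            and len(seq_lens) * max(seq_lens) <= MAX_TOKENS)

print(batch_is_legal([1000, 900]))  # True: two sequences, ~2000 padded tokens
print(batch_is_legal([1500]))       # False: one sequence exceeds TOKENS_PER_SAMPLE
```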
OK, thanks, I understand the difference now.
@ngoyal2707 I want to increase the max sequence length to 2048 as you said. Can you give some hint as to how to do this? I see that the size of the positional embedding matrix is 1026 (rather than 1024) in the pretrained BART:

```python
state['model']['encoder.embed_positions.weight'].shape
Out[37]: torch.Size([1026, 1024])
```

and similarly, the size is 2050 for the model I will be finetuning:

```python
self.get_model().state_dict()['encoder.embed_positions.weight'].shape
Out[46]: torch.Size([2050, 1024])
```

Would I copy over the parameters from [2 : 1026] to the second half, [1026 : 2050]?
I only tried this once, and I kept [1026:2050] random. The 1026 is because of bos + source (1024) + eos.
Thanks. I saw that the positions always start at 2, so I copied [2 : 1026] to [1026 : 2050]. I compared it to random initialization of the second half, and I got better scores on my specific application when copying vs. random. Thanks again!
@ngoyal2707 I'd like to finetune BART on quite a different domain, where the average sequence length of the input documents is about 8000 tokens. Does BART support lengths of this order? If not, is there a workaround to handle these cases?
With the modification below I can start training, but I'm not convinced it makes sense:

```python
# Concatenate the first 1025 rows with rows [1:] (another 1025 rows),
# giving a 2050-row positional embedding matrix.
state['model']['encoder.embed_positions.weight'] = torch.cat([
    state['model']['encoder.embed_positions.weight'][:1025].clone(),
    state['model']['encoder.embed_positions.weight'][1:].clone()
], 0)
```

Is this at all related to setting max_source_positions?
@ngoyal2707 Hi, would you please point me to where in your code for finetuning BART you copy the positional embeddings?
@loganlebanoff Would you please share what exact changes you made to finetune this model on a new dataset with longer sequences? Appreciate that.
After this line: https://github.com/pytorch/fairseq/blob/411531734df8c7294e82c68e9d42177382f362ef/fairseq/trainer.py#L202 I added code that grows the positional embedding matrices and copies the pretrained positions into them (a sketch of this change follows below).
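The exact snippet is not reproduced above; below is a rough sketch of what such a patch could look like, based on the copy scheme described in this thread (rows [2:1026] copied into [1026:2050]). The `state` variable, and the assumption that it holds the checkpoint's `model` state dict before it is loaded into the larger model, are illustrative and not the commenter's actual code:

```python
# Hypothetical sketch (not the commenter's exact code): grow each positional
# embedding from 1026 to 2050 rows before loading the pretrained weights into
# a model built with max_source_positions=2048.
for key in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]:
    old = state["model"][key]              # [1026, 1024] in pretrained BART
    new = old.new_zeros(2050, old.size(1))
    new[:1026] = old                       # keep offset rows and trained positions
    new[1026:2050] = old[2:1026]           # reuse trained positions for the new half
    state["model"][key] = new
```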
Thanks for the reply, @loganlebanoff. Did you also change max_source_positions?
Right, yes, I changed max_source_positions to 2048. I still used it on CNN/DM, but with a different setup than regular summarization. For my setup, I got slightly better performance by copying the positional embeddings to the last 1024 rather than randomizing the last 1024 (for both settings, I used max_source_positions=2048). I also took a look at https://github.com/pytorch/fairseq/blob/7a6519f84fed06947bbf161c7b66c9099bc4ce53/fairseq/utils.py#L191
Can you please show how we can increase it to take more than 1024 input tokens?
❓ Questions and Help
Does BART support more than 1024 tokens at inference for the summarization task?
For long text such as a novel, does BART use all of the input to generate the summary, or does it just use the first 1024 tokens and ignore the rest?