
should mBART-large-en-ro have decoder_start_token_id by default? #6156

Closed
sshleifer opened this issue Jul 30, 2020 · 7 comments
Assignees: sshleifer
Labels: Help wanted · translation

Comments

@sshleifer (Contributor)

Hypothesis: since the argument prepend_bos is set to "False" in fairseq/examples/README.md, mbart-large-en-ro does not need decoder_start_token_id.

TODO:

  • create branch that deletes decoder_start_token_id. Setting it to None in the config might not be enough.
  • verify that decoder_start_token_id is in fact not being used by setting a breakpoint in generate (see the sketch after this list).
  • run_eval.py on wmt-en-ro/test and see if BLEU is >= 26.46, the score with decoder_start_token_id=250020.
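A rough sketch of the verification step, assuming the Auto* classes from transformers (the example sentence and generation settings here are illustrative, not part of the original plan):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")
tok = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")

# Simulate deleting the value from the config (the hypothesis under test).
model.config.decoder_start_token_id = None

batch = tok(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
# Put a breakpoint inside generate() to see which token actually seeds the decoder.
generated = model.generate(**batch, num_beams=5)
print(tok.batch_decode(generated, skip_special_tokens=True))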
sshleifer added the translation and Help wanted labels, added this to "To do" in Examples/seq2seq via automation, and self-assigned the issue on Jul 30, 2020.
@KMFODA (Contributor) commented Aug 10, 2020

Hi @sshleifer, I'd like to contribute and help out here if still needed. My thinking is to remove decoder_start_token_id from run_eval.py and generation_utils.py and change the following code:

# create empty decoder_input_ids
input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    decoder_start_token_id,
    dtype=torch.long,
    device=next(self.parameters()).device,
)

to:

input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    250020,
    dtype=torch.long,
    device=next(self.parameters()).device,
)

@sshleifer (Contributor, Author) commented Aug 11, 2020

I don't think that change will do anything, since decoder_start_token_id is already 250020.

What I would do is change the 250020 to the bos_token_id (0, I think) or the pad_token_id (1) and see what the BLEU score is.
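For example, a quick way to compare candidates without editing generation_utils.py (a sketch, assuming generate accepts decoder_start_token_id as a keyword override):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")
tok = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tok(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")

# Try each candidate start token; 250020 is the current config default.
for name, token_id in [("bos", 0), ("pad", 1), ("eos", 2), ("default", 250020)]:
    out = model.generate(**batch, num_beams=5, decoder_start_token_id=token_id)
    print(name, tok.batch_decode(out, skip_special_tokens=True))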

@KMFODA (Contributor) commented Aug 13, 2020

Ah yes, that makes sense. I tried those two plus the eos_token_id and got the following results:

ID                                 BLEU score
eos_token_id (2)                   28.22
decoder_start_token_id (250020)    28.06
pad_token_id (1)                   26.79
bos_token_id (0)                   26.01

@sshleifer (Contributor, Author)

Super interesting, thanks for running that. It seems like I should change decoder_start_token_id in the mbart-large-en-ro config to 2. Do you have opinions on mbart-large-cc25?
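A sketch of what that config change could look like (hypothetical local copy; the actual fix would go into the hosted facebook/mbart-large-en-ro config):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
config.decoder_start_token_id = 2  # eos_token_id, the best-scoring option in the table above
config.save_pretrained("./mbart-large-en-ro-eos-start")  # hypothetical local directory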

@KMFODA (Contributor) commented Aug 19, 2020

No problem! Yes, I think setting decoder_start_token_id to 2 is a good idea. Unfortunately, I'm getting the same issues you're getting with mbart-large-cc25 (the output is in English rather than Romanian and is missing the first word when I use bos_token_id or 250020, and is gibberish with eos/pad_token_id), and I don't understand why that's the case. I'll investigate and post any useful findings.

@sshleifer (Contributor, Author) commented Aug 21, 2020

I think I fixed this another way in #6526. On master:

python run_eval.py facebook/mbart-large-en-ro $ENRO_DIR/test.source eos_baseline_enro_test_generations.txt \
--reference_path $ENRO_DIR/test.target \
--score_path baseline_test_bleu_eos.json --bs 32 --task translation --fp16

=> {'bleu': 26.81}

python run_eval.py facebook/mbart-large-en-ro $ENRO_DIR/test.source \
eos_baseline_enro_test_generations.txt --reference_path $ENRO_DIR/test.target \
--score_path baseline_test_bleu_eos.json --bs 32 --task translation --fp16  \
--decoder_start_token_id 2

=> {'bleu': 11.57} (and takes 40 minutes!)

In the original fairseq I get 26.83.

@sshleifer (Contributor, Author)
Going to close this since the score is now basically the same as fairseq. Thanks for your help!
