Skip to content

Commit

Permalink
[Marian Conversion] Fix eos_token_id conversion in conversion script (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored Nov 8, 2021
1 parent c016dbd commit b48faae
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/transformers/models/marian/convert_marian_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def check_marian_cfg_assumptions(marian_cfg):


class OpusState:
def __init__(self, source_dir):
def __init__(self, source_dir, eos_token_id=0):
npz_path = find_model_file(source_dir)
self.state_dict = np.load(npz_path)
cfg = load_config_from_state_dict(self.state_dict)
Expand Down Expand Up @@ -492,7 +492,8 @@ def __init__(self, source_dir):
d_model=cfg["dim-emb"],
activation_function=cfg["transformer-aan-activation"],
pad_token_id=self.pad_token_id,
eos_token_id=0,
eos_token_id=eos_token_id,
forced_eos_token_id=eos_token_id,
bos_token_id=0,
max_position_embeddings=cfg["dim-emb"],
scale_embedding=True,
Expand Down Expand Up @@ -595,7 +596,11 @@ def convert(source_dir: Path, dest_dir):
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
tokenizer.save_pretrained(dest_dir)

opus_state = OpusState(source_dir)
# retrieve EOS token and set correctly
tokenizer_has_eos_token_id = hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None
eos_token_id = tokenizer.eos_token_id if tokenizer_has_eos_token_id else 0

opus_state = OpusState(source_dir, eos_token_id=eos_token_id)
if opus_state.cfg["vocab_size"] != len(tokenizer.encoder):
raise ValueError(
f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
Expand Down

1 comment on commit b48faae

@vackosar
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this fix!

Please sign in to comment.