Fix marian tokenizer save pretrained (#5043)
sshleifer committed Jun 16, 2020
1 parent d5477ba commit 3d495c6
Showing 2 changed files with 14 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/transformers/tokenization_marian.py

@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer):
 
     def __init__(
         self,
-        vocab=None,
-        source_spm=None,
-        target_spm=None,
+        vocab,
+        source_spm,
+        target_spm,
         source_lang=None,
         target_lang=None,
         unk_token="<unk>",
@@ -59,6 +59,7 @@ def __init__(
             pad_token=pad_token,
             **kwargs,
         )
+        assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
         self.encoder = load_json(vocab)
         if self.unk_token not in self.encoder:
             raise KeyError("<unk> token must be in vocab")
@@ -179,10 +180,11 @@ def save_vocabulary(self, save_directory: str) -> Tuple[str]:
         assert save_dir.is_dir(), f"{save_directory} should be a directory"
         save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
 
-        for f in self.spm_files:
+        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
             dest_path = save_dir / Path(f).name
             if not dest_path.exists():
-                copyfile(f, save_dir / Path(f).name)
+                copyfile(f, save_dir / orig)
 
         return tuple(save_dir / f for f in self.vocab_files_names)
 
     def get_vocab(self) -> Dict:
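Net effect of these changes: the vocab and SentencePiece files become required constructor arguments (with an early assert that the source spm file actually exists), and save_vocabulary now copies the spm files out under the canonical names source.spm and target.spm that loading expects, rather than under whatever filenames the originals happened to have. A minimal round-trip sketch of the behavior this fixes; the checkpoint name assumes ORG_NAME in the test below is Helsinki-NLP/, and the temp-dir handling is illustrative:

    import tempfile

    from transformers import MarianTokenizer

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    # save_pretrained delegates spm handling to save_vocabulary; after this
    # fix the copies land in save_dir as source.spm / target.spm.
    save_dir = tempfile.mkdtemp()
    tok.save_pretrained(save_dir)

    # Before the fix, the copied spm files could keep their original (e.g.
    # cache) filenames, so reloading from the saved directory failed.
    reloaded = MarianTokenizer.from_pretrained(save_dir)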
9 changes: 7 additions & 2 deletions tests/test_tokenization_marian.py

@@ -15,6 +15,7 @@
 
 
 import os
+import tempfile
 import unittest
 from pathlib import Path
 from shutil import copyfile
@@ -23,7 +24,6 @@
 from transformers.tokenization_utils import BatchEncoding
 
 from .test_tokenization_common import TokenizerTesterMixin
-from .utils import slow
 
 
 SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
@@ -60,10 +60,15 @@ def get_input_output_texts(self, tokenizer):
             "This is a test",
         )
 
-    @slow
     def test_tokenizer_equivalence_en_de(self):
         en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
         batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
         self.assertIsInstance(batch, BatchEncoding)
         expected = [38, 121, 14, 697, 38848, 0]
         self.assertListEqual(expected, batch.input_ids[0])
 
+        save_dir = tempfile.mkdtemp()
+        en_de_tokenizer.save_pretrained(save_dir)
+        contents = [x.name for x in Path(save_dir).glob("*")]
+        self.assertIn("source.spm", contents)
+        MarianTokenizer.from_pretrained(save_dir)
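Dropping the @slow marker moves this round-trip regression test into the default suite; it still downloads the real en-de checkpoint on first run. One way to run just this test locally (invocation illustrative, assuming a repo checkout with network access):

    pytest tests/test_tokenization_marian.py -k test_tokenizer_equivalence_en_de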
