Leave LayoutLMv2Processor tests unchanged
NielsRogge committed Oct 21, 2021
1 parent e57df29 · commit 7d2f83a
Showing 1 changed file with 2 additions and 106 deletions.
tests/test_processor_layoutlmv2.py (108 changes: 2 additions & 106 deletions)
@@ -21,20 +21,9 @@

 from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 from transformers.file_utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available
-from transformers.models.layoutlmv2 import (
-    LayoutLMv2Tokenizer,
-    LayoutLMv2TokenizerFast,
-    LayoutXLMTokenizer,
-    LayoutXLMTokenizerFast,
-)
+from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
 from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import (
-    require_pytesseract,
-    require_sentencepiece,
-    require_tokenizers,
-    require_torch,
-    slow,
-)
+from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow


 if is_pytesseract_available():
@@ -43,11 +32,7 @@
     from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2Processor


-SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
-
-
 @require_pytesseract
-@require_sentencepiece
 @require_tokenizers
 class LayoutLMv2ProcessorTest(unittest.TestCase):
     tokenizer_class = LayoutLMv2Tokenizer
@@ -149,96 +134,7 @@ def test_save_load_pretrained_additional_features(self):
         self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)


-@require_pytesseract
-@require_sentencepiece
-@require_tokenizers
-class LayoutXLMProcessorTest(unittest.TestCase):
-    tokenizer_class = LayoutXLMTokenizer
-    rust_tokenizer_class = LayoutXLMTokenizerFast
-
-    def setUp(self):
-        feature_extractor_map = {
-            "do_resize": True,
-            "size": 224,
-            "apply_ocr": True,
-        }
-
-        self.tmpdirname = tempfile.mkdtemp()
-        self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
-        with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
-            fp.write(json.dumps(feature_extractor_map) + "\n")
-
-    def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
-        return self.tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
-
-    def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
-        return self.rust_tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
-
-    def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
-        return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
-
-    def get_feature_extractor(self, **kwargs):
-        return LayoutLMv2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
-
-    def tearDown(self):
-        shutil.rmtree(self.tmpdirname)
-
-    def test_save_load_pretrained_default(self):
-        feature_extractor = self.get_feature_extractor()
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
-
-            processor.save_pretrained(self.tmpdirname)
-            processor = LayoutLMv2Processor.from_pretrained(self.tmpdirname, use_xlm=True)
-
-            self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
-            self.assertIsInstance(processor.tokenizer, (LayoutXLMTokenizer, LayoutXLMTokenizerFast))
-
-            self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
-            self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
-
-    def test_save_load_pretrained_additional_features(self):
-        processor = LayoutLMv2Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer())
-        processor.save_pretrained(self.tmpdirname)
-
-        # slow tokenizer
-        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
-        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
-
-        processor = LayoutLMv2Processor.from_pretrained(
-            self.tmpdirname,
-            use_fast=False,
-            use_xlm=True,
-            bos_token="(BOS)",
-            eos_token="(EOS)",
-            do_resize=False,
-            size=30,
-        )
-
-        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
-        self.assertIsInstance(processor.tokenizer, LayoutXLMTokenizer)
-
-        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
-        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
-
-        # fast tokenizer
-        tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
-        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
-
-        processor = LayoutLMv2Processor.from_pretrained(
-            self.tmpdirname, use_xlm=True, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
-        )
-
-        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
-        self.assertIsInstance(processor.tokenizer, LayoutXLMTokenizerFast)
-
-        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
-        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
-
-
 # different use cases tests
-@require_sentencepiece
 @require_torch
 @require_pytesseract
 class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
