Skip to content

Commit

Permalink
[test] add test for --config_overrides (#14466)
Browse files: browse the repository at this point in the history
* add test for --config_overrides

* remove unneeded parts of the test
  • Loading branch information
stas00 committed Nov 22, 2021
1 parent e0e2da1 commit 11f65d4
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/pytorch/language-modeling/run_clm.py
Expand Up @@ -324,6 +324,7 @@ def main():
if model_args.config_overrides is not None:
logger.info(f"Overriding config: {model_args.config_overrides}")
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/language-modeling/run_mlm.py
Expand Up @@ -326,6 +326,7 @@ def main():
if model_args.config_overrides is not None:
logger.info(f"Overriding config: {model_args.config_overrides}")
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/language-modeling/run_plm.py
Expand Up @@ -318,6 +318,7 @@ def main():
if model_args.config_overrides is not None:
logger.info(f"Overriding config: {model_args.config_overrides}")
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")

tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
Expand Down
27 changes: 26 additions & 1 deletion examples/pytorch/test_examples.py
Expand Up @@ -25,7 +25,7 @@

from transformers import Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device


SRC_DIRS = [
Expand Down Expand Up @@ -157,6 +157,31 @@ def test_run_clm(self):
result = get_results(tmp_dir)
self.assertLess(result["perplexity"], 100)

def test_run_clm_config_overrides(self):
    """Check that ``--config_overrides`` values show up in what ``run_clm`` logs.

    Runs ``run_clm.main()`` with ``n_embd=10,n_head=2`` passed as overrides
    while capturing the script's logger, then asserts that both overridden
    values appear in the captured output. The logged text is inspected
    because other dumps of the config (via the tokenizer) can be misleading:
    they show the default, un-updated config.
    """
    tmp_dir = self.get_auto_remove_tmp_dir()

    # The whitespace layout inside the literal is irrelevant: .split()
    # tokenizes it on any run of whitespace.
    cli_args = f"""
        run_clm.py
        --model_type gpt2
        --tokenizer_name gpt2
        --train_file ./tests/fixtures/sample_text.txt
        --output_dir {tmp_dir}
        --config_overrides n_embd=10,n_head=2
        """.split()

    if torch_device != "cuda":
        cli_args.append("--no_cuda")

    # Swap in the fake argv and capture the script logger for the duration of main().
    with patch.object(sys, "argv", cli_args), CaptureLogger(run_clm.logger) as captured:
        run_clm.main()

    for expected_fragment in ('"n_embd": 10', '"n_head": 2'):
        self.assertIn(expected_fragment, captured.out)

def test_run_mlm(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
Expand Down

0 comments on commit 11f65d4

Please sign in to comment.