Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Test parlai/core/script.py #3117

Merged
merged 3 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion parlai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.core.script import superscript_main as main
from parlai.core.script import superscript_main


def main():
# indirect call so that console_entry (see setup.py) doesn't use the return
# value of the final command as the return code. This lets us call
# superscript_main as a function in other places (test_script.py).
superscript_main()


if __name__ == '__main__':
main()
5 changes: 2 additions & 3 deletions parlai/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ class _SubcommandParser(ParlaiParser):
def __init__(self, **kwargs):
kwargs['add_parlai_args'] = False
kwargs['add_model_args'] = False
if 'description' not in kwargs:
kwargs['description'] = None
assert 'description' in kwargs, 'Must supply description'
return super().__init__(**kwargs)

def parse_known_args(self, args=None, namespace=None, nohelp=False):
Expand Down Expand Up @@ -269,4 +268,4 @@ def superscript_main(args=None):
elif cmd == 'help' or cmd is None:
parser.print_help()
elif cmd is not None:
SCRIPT_REGISTRY[cmd].klass._run_from_parser_and_opt(opt, parser)
return SCRIPT_REGISTRY[cmd].klass._run_from_parser_and_opt(opt, parser)
99 changes: 99 additions & 0 deletions tests/test_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import unittest
from unittest.mock import patch
from parlai.core.params import ParlaiParser
import parlai.core.script as script
import parlai.utils.testing as testing_utils


@script.register_script("test_script")
class _TestScript(script.ParlaiScript):
@classmethod
def setup_args(cls):
parser = ParlaiParser(True, False, description='My Description')
parser.add_argument('--foo', default='defaultvalue')
parser.add_argument('--bar', default='sneaky', hidden=True)
return parser

def run(self):
return self.opt


@script.register_script("hidden_script", hidden=True)
class _HiddenScript(_TestScript):
pass


@script.register_script("no_setup_args")
class _NoSetupArgsScript(script.ParlaiScript):
pass


class TestScriptRegistry(unittest.TestCase):
def test_setup_script(self):
script.setup_script_registry()
assert 'train_model' in script.SCRIPT_REGISTRY

def test_register_script(self):
assert 'test_script' in script.SCRIPT_REGISTRY

def test_main_kwargs(self):
opt = _TestScript.main(foo='test')
assert opt.get('foo') == 'test'
assert opt.get('bar') == 'sneaky'

def test_main_args(self):
opt = _TestScript.main('--foo', 'test')
assert opt.get('foo') == 'test'
assert opt.get('bar') == 'sneaky'

def test_main_noargs(self):
with patch.object(sys, 'argv', ['test_script.py']): # argv[0] doesn't matter
opt = _TestScript.main()
assert opt.get('foo') == 'defaultvalue'
assert opt.get('bar') == 'sneaky'

def test_help(self):
helptext = _TestScript.help()
assert 'My Description' in helptext
assert '--foo' in helptext
assert '--bar' not in helptext

with testing_utils.capture_output() as output:
with self.assertRaises(SystemExit):
_TestScript.main('--help')
assert '--foo' in output.getvalue()
assert '--bar' not in output.getvalue()

with testing_utils.capture_output() as output:
with self.assertRaises(SystemExit):
_TestScript.main('--helpall')
assert '--foo' in output.getvalue()
assert '--bar' in output.getvalue()


class TestSuperCommand(unittest.TestCase):
def test_supercommand(self):
opt = script.superscript_main(args=['test_script', '--foo', 'test'])
assert opt.get('foo') == 'test'

def test_no_setup_args(self):
with self.assertRaises(NotImplementedError):
script.superscript_main(args=['no_setup_args'])

def test_help(self):
with testing_utils.capture_output() as output:
stephenroller marked this conversation as resolved.
Show resolved Hide resolved
script.superscript_main(args=['help'])
assert 'test_script' in output.getvalue()
assert 'hidden_script' not in output.getvalue()

with testing_utils.capture_output() as output:
script.superscript_main(args=['helpall'])
assert 'test_script' in output.getvalue()
assert 'hidden_script' in output.getvalue()