diff --git a/datmo/cli/command/task.py b/datmo/cli/command/task.py index 72bd179c..23668593 100644 --- a/datmo/cli/command/task.py +++ b/datmo/cli/command/task.py @@ -25,12 +25,13 @@ def __init__(self, home, cli_helper): is the environment port available during a run. """) # run.add_argument("--data", nargs="*", dest="data", type=str, help="Path for data to be used during the Task") - run.add_argument("--env-def", dest="environment_definition_filepath", default="", + run.add_argument("--env-def", dest="environment_definition_filepath", default=None, nargs="?", type=str, help="Pass in the Dockerfile with which you want to build the environment") run.add_argument("--interactive", dest="interactive", action="store_true", help="Run the environment in interactive mode (keeps STDIN open)") - run.add_argument("cmd", nargs="?", default=None) + run.add_argument("--cmd", "-c", dest="cmd", default=None, nargs="?", + help="Pass in the command to be run inside container") # Task list arguments ls = subcommand_parsers.add_parser("ls", help="List tasks") @@ -47,15 +48,14 @@ def run(self, **kwargs): self.cli_helper.echo(__("info", "cli.task.run")) # Create input dictionaries - snapshot_dict = { - "environment_definition_filepath": + snapshot_dict = {} + if kwargs['environment_definition_filepath']: + snapshot_dict["environment_definition_filepath"] =\ kwargs['environment_definition_filepath'] - } - if not isinstance(kwargs['cmd'], list): if platform.system() == "Windows": kwargs['cmd'] = kwargs['cmd'] - else: + elif type(kwargs['cmd']) is str: kwargs['cmd'] = shlex.split(kwargs['cmd']) task_dict = { diff --git a/datmo/cli/command/test/test_task.py b/datmo/cli/command/test/test_task.py index 5e7ad208..03b7ad62 100644 --- a/datmo/cli/command/test/test_task.py +++ b/datmo/cli/command/test/test_task.py @@ -13,7 +13,6 @@ # import builtins as __builtin__ import os -import shutil import tempfile import platform from io import open @@ -86,7 +85,7 @@ def test_datmo_task_run_should_fail2(self): self.task.parse([ "task", "run", - test_command + "--cmd", test_command ]) result = self.task.execute() assert not result @@ -108,7 +107,7 @@ def test_datmo_task_run(self): "--ports", test_ports, "--env-def", test_dockerfile, "--interactive", - test_command + "--cmd", test_command ]) # test for desired side effects @@ -128,6 +127,27 @@ def test_datmo_task_run(self): assert result.results == {"accuracy": "0.45"} assert result.status == "SUCCESS" + def test_datmo_task_run_notebook(self): + self.__set_variables() + # Test success case + test_command = ["jupyter", "notebook", "list"] + test_ports = "8888:8888" + + self.task.parse([ + "task", + "run", + "--ports", test_ports, + "--cmd", test_command + ]) + + # test proper execution of task run command + result = self.task.execute() + assert result + assert isinstance(result, CoreTask) + assert result.logs + assert "Currently running servers" in result.logs + assert result.status == "SUCCESS" + def test_task_run_invalid_arg(self): self.__set_variables() exception_thrown = False @@ -182,7 +202,7 @@ def test_task_stop(self): "--ports", test_ports, "--env-def", test_dockerfile, "--interactive", - test_command + "--cmd", test_command ]) test_task_obj = self.task.execute() diff --git a/datmo/core/controller/environment/environment.py b/datmo/core/controller/environment/environment.py index 21ee745b..4b26e33e 100644 --- a/datmo/core/controller/environment/environment.py +++ b/datmo/core/controller/environment/environment.py @@ -70,11 +70,10 @@ def create(self, dictionary): create_dict = { "model_id": self.model.id, } - create_dict["driver_type"] = self.environment_driver.type create_dict["language"] = dictionary.get("language", None) - if "definition_filepath" in dictionary: + if "definition_filepath" in dictionary and dictionary['definition_filepath']: original_definition_filepath = dictionary['definition_filepath'] # Split up the given path and save definition filename definition_path, definition_filename = \ diff --git a/datmo/core/controller/snapshot.py b/datmo/core/controller/snapshot.py index f3ffe2d7..5fbdaad2 100644 --- a/datmo/core/controller/snapshot.py +++ b/datmo/core/controller/snapshot.py @@ -159,7 +159,6 @@ def create(self, incoming_dictionary): raise RequiredArgumentMissing(__("error", "controller.snapshot.create.arg", "message")) - # Code setup self._code_setup(incoming_dictionary, create_dict) @@ -306,7 +305,6 @@ def _env_setup(self, incoming_dictionary, create_dict): create_dict : dict dictionary for creating the Snapshot entity """ - language = incoming_dictionary.get("language", None) if "environment_id" in incoming_dictionary: create_dict['environment_id'] = incoming_dictionary['environment_id'] diff --git a/datmo/core/controller/task.py b/datmo/core/controller/task.py index d29412d2..50c4c065 100644 --- a/datmo/core/controller/task.py +++ b/datmo/core/controller/task.py @@ -227,7 +227,6 @@ def run(self, task_id, snapshot_dict=None, task_dict=None): raise TaskRunException(__("error", "controller.task.run", task_dirpath)) - # Create the before snapshot prior to execution before_snapshot_dict = snapshot_dict.copy() before_snapshot_dict['message'] = "autogenerated snapshot created before task %s is run" % task_obj.id