Skip to content

Commit

Permalink
Merge pull request #220 from datmo/hotfix
Browse files Browse the repository at this point in the history
hot fix for circle-ci test and adding initial as an option
  • Loading branch information
asampat3090 committed Jul 2, 2018
2 parents d0a1228 + e436963 commit f9033e6
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 30 deletions.
42 changes: 25 additions & 17 deletions datmo/cli/command/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import print_function

import os
import sys
import shlex
import platform
from datetime import datetime
Expand Down Expand Up @@ -167,8 +168,7 @@ def __get_core_task(self):
datmo.core.entity.task.Task
core task object fo the task
"""
task_controller = TaskController()
return task_controller.get(self.id)
return self._core_task

def __get_core_snapshot(self):
"""Returns the latest core snapshot object for id
Expand Down Expand Up @@ -275,7 +275,6 @@ def run(self, **kwargs):
self.cli_helper.echo(__("info", "cli.task.run"))
# Create input dictionaries
snapshot_dict = {}

# Environment
if kwargs.get("environment_id", None) or kwargs.get(
"environment_paths", None):
Expand All @@ -295,7 +294,13 @@ def run(self, **kwargs):
task_dict['command_list'] = kwargs['cmd']

# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")
task_obj = self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")
if not task_obj:
return False
# Creating the run object
run_obj = RunObject(task_obj)
return run_obj


@Helper.notify_no_project_found
def ls(self, **kwargs):
Expand Down Expand Up @@ -352,19 +357,25 @@ def rerun(self, **kwargs):
self.task_controller = TaskController()
# Get task id
task_id = kwargs.get("id", None)
initial = kwargs.get("initial", False)
self.cli_helper.echo(__("info", "cli.task.rerun", task_id))
# Create the task_obj
task_obj = self.task_controller.get(task_id)
# Create the run obj
run_obj = RunObject(task_obj)

environment_id = run_obj.environment_id
command = run_obj.command
snapshot_id = run_obj.core_snapshot_id
command = task_obj.command_list
snapshot_id = run_obj.core_snapshot_id if not initial else run_obj.before_snapshot_id

# Checkout to the core snapshot id before rerunning the task
self.snapshot_controller = SnapshotController()
checkout_success = self.snapshot_controller.checkout(snapshot_id)
try:
checkout_success = self.snapshot_controller.checkout(snapshot_id)
except Exception:
self.cli_helper.echo(__("error", "cli.snapshot.checkout.failure"))
sys.exit(1)

if checkout_success:
self.cli_helper.echo(
__("info", "cli.snapshot.checkout.success", snapshot_id))
Expand All @@ -373,19 +384,16 @@ def rerun(self, **kwargs):
# Create input dictionary for the new task
snapshot_dict = {}
snapshot_dict["environment_id"] = environment_id

task_dict = {
"ports": task_obj.ports,
"interactive": task_obj.interactive,
"mem_limit": task_obj.mem_limit,
"command_list": command
}
if not isinstance(command, list):
if platform.system() == "Windows":
task_dict['command'] = command
elif isinstance(command, basestring):
task_dict['command_list'] = shlex.split(command)
else:
task_dict['command_list'] = command

# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")
new_task_obj = self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")
if not new_task_obj:
return False
# Creating the run object
new_run_obj = RunObject(new_task_obj)
return new_run_obj
50 changes: 38 additions & 12 deletions datmo/cli/command/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,18 @@ def test_run(self):
result = self.run_command.execute()
time.sleep(1)
assert result
assert isinstance(result, CoreTask)
assert isinstance(result, RunObject)
assert result.logs
assert "accuracy" in result.logs
assert result.results
assert result.results == {"accuracy": "0.45"}
assert result.status == "SUCCESS"
assert result.start_time
assert result.end_time
assert result.duration
assert result.core_snapshot_id
assert result.core_snapshot_id == result.after_snapshot_id
assert result.environment_id

# teardown
self.task_command.parse(["task", "stop", "--all"])
Expand Down Expand Up @@ -190,12 +196,19 @@ def test_run_string_command(self):
# test proper execution of run command
result = self.run_command.execute()
assert result
assert isinstance(result, CoreTask)
assert isinstance(result, RunObject)
assert result.logs
assert "accuracy" in result.logs
assert result.results
assert result.results == {"accuracy": "0.45"}
assert result.status == "SUCCESS"
assert result.start_time
assert result.end_time
assert result.duration
assert result.core_snapshot_id
assert result.core_snapshot_id == result.after_snapshot_id
assert result.environment_id
assert result

# teardown
self.task_command.parse(["task", "stop", "--all"])
Expand Down Expand Up @@ -282,7 +295,7 @@ def test_run_notebook(self):
# test proper execution of run command
result = self.run_command.execute()
assert result
assert isinstance(result, CoreTask)
assert isinstance(result, RunObject)
assert result.logs
assert "Currently running servers" in result.logs
assert result.status == "SUCCESS"
Expand Down Expand Up @@ -431,23 +444,36 @@ def test_rerun(self):
])

# test proper execution of run command
task_obj = self.run_command.execute()
run_id = task_obj.id
# Test success rerun
run_obj = self.run_command.execute()
run_id = run_obj.id
# 1. Test success rerun
self.run_command.parse(
["rerun", run_id])
result = self.run_command.execute()
assert result
assert isinstance(result, CoreTask)
assert result.command == task_obj.command
assert result.status == "SUCCESS"
assert result.logs
result_run_obj = self.run_command.execute()
assert result_run_obj
assert isinstance(result_run_obj, RunObject)
assert result_run_obj.command == run_obj.command
assert result_run_obj.status == "SUCCESS"
assert result_run_obj.logs
assert result_run_obj.before_snapshot_id == run_obj.after_snapshot_id

# 2. Test success rerun
self.run_command.parse(
["rerun", "--initial", run_id])
result_run_obj = self.run_command.execute()
assert result_run_obj
assert isinstance(result_run_obj, RunObject)
assert result_run_obj.command == run_obj.command
assert result_run_obj.status == "SUCCESS"
assert result_run_obj.logs
assert result_run_obj.before_snapshot_id == run_obj.before_snapshot_id

# teardown
self.task_command.parse(["task", "stop", "--all"])
# test when all is passed to stop all
task_stop_command = self.task_command.execute()

@pytest_docker_environment_failed_instantiation(test_datmo_dir)
def test_rerun_invalid_arg(self):
self.__set_variables()
exception_thrown = False
Expand Down
5 changes: 5 additions & 0 deletions datmo/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def get_datmo_parser():
# Rerun
rerun_parser = subparsers.add_parser("rerun", help="To rerun an experiment")
rerun_parser.add_argument("id", help="run id to be rerun")
rerun_parser.add_argument(
"--initial",
dest="initial",
action="store_true",
help="boolean if you want to rerun the experiment with the state at the beginning of the run")

# Session
session_parser = subparsers.add_parser("session", help="session module")
Expand Down
3 changes: 2 additions & 1 deletion datmo/core/entity/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def __str__(self):
if self.results:
table_data.append(["Results", "-> " + str(self.results)])
final_str = final_str + format_table(table_data)
final_str = final_str + "\n" + " " + self.command + "\n" + "\n"
if self.command:
final_str = final_str + "\n" + " " + self.command + "\n" + "\n"
return final_str

def __repr__(self):
Expand Down
2 changes: 2 additions & 0 deletions datmo/core/util/lang/en.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@
"Error while stopping all tasks",
"cli.snapshot.create.task.args":
"Error due to passing excluded args while creating snapshot from task: %s",
"cli.snapshot.checkout.failure":
"Error while checking out to a snapshot due to unstaged changes",
"util.misc_functions.get_filehash":
"Filepath does not point to a valid file: %s",
"util.misc_functions.mutually_exclusive":
Expand Down

0 comments on commit f9033e6

Please sign in to comment.