
Commit 258fed3

Merge 56576d5 into 56e5e93
shabazpatel committed Jun 27, 2018
2 parents 56e5e93 + 56576d5 commit 258fed3
Showing 32 changed files with 528 additions and 505 deletions.
2 changes: 1 addition & 1 deletion datmo/VERSION
@@ -1 +1 @@
0.0.14-dev
0.0.15-dev
45 changes: 45 additions & 0 deletions datmo/cli/command/base.py
@@ -4,6 +4,7 @@
from datmo.core.util.i18n import get as __
from datmo.core.util.exceptions import ClassMethodNotFound
from datmo.cli.parser import get_datmo_parser
from datmo.core.controller.task import TaskController
from datmo.core.util.logger import DatmoLogger
from datmo.core.util.misc_functions import parameterized

@@ -98,6 +99,50 @@ def execute(self):
method_result = method(**command_args)
return method_result

def task_run_helper(self, task_dict, snapshot_dict, error_identifier):
"""
Run task with given parameters and provide error identifier
Parameters
----------
task_dict : dict
input task dictionary for task run controller
snapshot_dict : dict
input snapshot dictionary for task run controller
error_identifier : str
identifier to print error
Returns
-------
Task or False
the Task object which completed its run with updated parameters;
returns False if an error occurs
"""
self.task_controller = TaskController()
task_obj = self.task_controller.create()

updated_task_obj = task_obj
# Pass in the task
status = "NOT STARTED"
try:
updated_task_obj = self.task_controller.run(
task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
status = "SUCCESS"
except Exception as e:
status = "FAILED"
self.logger.error("%s %s" % (e, task_dict))
self.cli_helper.echo("%s" % e)
self.cli_helper.echo(__("error", error_identifier, task_obj.id))
return False
finally:
self.cli_helper.echo(__("info", "cli.task.run.stop"))
self.task_controller.stop(
task_id=updated_task_obj.id, status=status)
self.cli_helper.echo(
__("info", "cli.task.run.complete", updated_task_obj.id))

return updated_task_obj


@parameterized
def usage_docs(description):
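
For context, this new helper centralizes the create/run/stop/report flow that the run, task, and workspace commands below previously duplicated. The following is a minimal sketch of how a command subclass might call it; the ExampleCommand class, the BaseCommand base-class name, and the argument values are hypothetical and not part of this commit:

class ExampleCommand(BaseCommand):  # base-class name assumed for illustration
    def run(self, **kwargs):
        # Build the dictionaries expected by the TaskController
        task_dict = {"command_list": kwargs.get("cmd", [])}
        snapshot_dict = {}
        # Delegate task creation, execution, stop, and error reporting to the
        # shared helper; it returns the updated Task object, or False on failure
        return self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")
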
22 changes: 2 additions & 20 deletions datmo/cli/command/run.py
@@ -239,9 +239,6 @@ def __init__(self, cli_helper):
@Helper.notify_no_project_found
def run(self, **kwargs):
self.cli_helper.echo(__("info", "cli.task.run"))
# Create controllers
self.task_controller = TaskController()
self.snapshot_controller = SnapshotController()
# Create input dictionaries
snapshot_dict = {}

@@ -263,27 +260,12 @@ def run(self, **kwargs):
else:
task_dict['command_list'] = kwargs['cmd']

# Create the task object
task_obj = self.task_controller.create()
try:
# Pass in the task to run
updated_task_obj = self.task_controller.run(
task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
except Exception as e:
self.logger.error("%s %s" % (e, task_dict))
self.cli_helper.echo("%s" % e)
self.cli_helper.echo(__("error", "cli.task.run", task_obj.id))
return False

self.cli_helper.echo(
__("info", "cli.task.run.complete", updated_task_obj.id))
return updated_task_obj
# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")

@Helper.notify_no_project_found
def ls(self, **kwargs):
# Create controllers
self.task_controller = TaskController()
self.snapshot_controller = SnapshotController()
session_id = kwargs.get('session_id',
self.task_controller.current_session.id)
print_format = kwargs.get('format', "table")
24 changes: 2 additions & 22 deletions datmo/cli/command/task.py
@@ -29,7 +29,6 @@ def task(self):
@Helper.notify_environment_active(TaskController)
@Helper.notify_no_project_found
def run(self, **kwargs):
self.task_controller = TaskController()
self.cli_helper.echo(__("info", "cli.task.run"))
# Create input dictionaries
snapshot_dict = {}
@@ -52,27 +51,8 @@ def run(self, **kwargs):
else:
task_dict['command_list'] = kwargs['cmd']

# Create the task object
task_obj = self.task_controller.create()

updated_task_obj = task_obj
try:
# Pass in the task
updated_task_obj = self.task_controller.run(
task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
except Exception as e:
self.logger.error("%s %s" % (e, task_dict))
self.cli_helper.echo("%s" % e)
self.cli_helper.echo(__("error", "cli.task.run", task_obj.id))
return False
finally:
self.cli_helper.echo(
__("info", "cli.task.run.stop"))
self.task_controller.stop(updated_task_obj.id)
self.cli_helper.echo(
__("info", "cli.task.run.complete", updated_task_obj.id))

return updated_task_obj
# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict, "cli.task.run")

@Helper.notify_no_project_found
def ls(self, **kwargs):
1 change: 0 additions & 1 deletion datmo/cli/command/tests/test_datmo.py
@@ -13,7 +13,6 @@
# import builtins as __builtin__

import os
import shutil
import tempfile
import platform

14 changes: 7 additions & 7 deletions datmo/cli/command/tests/test_run.py
@@ -323,7 +323,7 @@ def test_run_ls(self):
self.run_command.parse(["ls"])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"

test_session_id = 'test_session_id'
self.run_command.parse(["ls", "--session-id", test_session_id])
@@ -351,13 +351,13 @@ def test_run_ls(self):
self.run_command.parse(["ls", "--format", "csv"])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"

# Test success format csv, download default
self.run_command.parse(["ls", "--format", "csv", "--download"])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"
test_wildcard = os.path.join(os.getcwd(), "run_ls_*")
paths = [n for n in glob.glob(test_wildcard) if os.path.isfile(n)]
assert paths
@@ -371,7 +371,7 @@ def test_run_ls(self):
])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"
assert os.path.isfile(test_path)
assert open(test_path, "r").read()
os.remove(test_path)
@@ -380,13 +380,13 @@ def test_run_ls(self):
self.run_command.parse(["ls"])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"

# Test success format table, download default
self.run_command.parse(["ls", "--download"])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"
test_wildcard = os.path.join(os.getcwd(), "run_ls_*")
paths = [n for n in glob.glob(test_wildcard) if os.path.isfile(n)]
assert paths
@@ -399,7 +399,7 @@ def test_run_ls(self):
["ls", "--download", "--download-path", test_path])
run_objs = self.run_command.execute()
assert run_objs
assert run_objs[0].status == 'SUCCESS'
assert run_objs[0].status == "SUCCESS"
assert os.path.isfile(test_path)
assert open(test_path, "r").read()
os.remove(test_path)
22 changes: 11 additions & 11 deletions datmo/cli/command/tests/test_task.py
@@ -138,9 +138,9 @@ def test_task_run(self):
assert result.status == "SUCCESS"

# teardown
self.task_command.parse(["task", "stop", "--all"])
# test when all is passed to stop all
task_stop_command = self.task_command.execute()
# self.task_command.parse(["task", "stop", "--all"])
# # test when all is passed to stop all
# task_stop_command = self.task_command.execute()

@pytest_docker_environment_failed_instantiation(test_datmo_dir)
def test_task_run_string_command(self):
@@ -172,10 +172,10 @@ def test_task_run_string_command(self):
assert result.results == {"accuracy": "0.45"}
assert result.status == "SUCCESS"

# teardown
self.task_command.parse(["task", "stop", "--all"])
# test when all is passed to stop all
task_stop_command = self.task_command.execute()
# # teardown
# self.task_command.parse(["task", "stop", "--all"])
# # test when all is passed to stop all
# task_stop_command = self.task_command.execute()

# def test_multiple_concurrent_task_run_command(self):
# test_dockerfile = os.path.join(self.temp_dir, "Dockerfile")
@@ -261,10 +261,10 @@ def test_task_run_notebook(self):
assert "Currently running servers" in result.logs
assert result.status == "SUCCESS"

# teardown
self.task_command.parse(["task", "stop", "--all"])
# test when all is passed to stop all
_ = self.task_command.execute()
# # teardown
# self.task_command.parse(["task", "stop", "--all"])
# # test when all is passed to stop all
# _ = self.task_command.execute()

def test_task_run_invalid_arg(self):
self.__set_variables()
50 changes: 6 additions & 44 deletions datmo/cli/command/workspace.py
@@ -12,7 +12,6 @@ def __init__(self, cli_helper):
@Helper.notify_environment_active(TaskController)
@Helper.notify_no_project_found
def notebook(self, **kwargs):
self.task_controller = TaskController()
self.cli_helper.echo(__("info", "cli.workspace.notebook"))
# Creating input dictionaries
snapshot_dict = {}
@@ -29,32 +28,13 @@ def notebook(self, **kwargs):
"mem_limit": kwargs["mem_limit"]
}

# Create the task object
task_obj = self.task_controller.create()

updated_task_obj = task_obj
# Pass in the task
try:
updated_task_obj = self.task_controller.run(
task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
except Exception as e:
self.logger.error("%s %s" % (e, task_dict))
self.cli_helper.echo(
__("error", "cli.workspace.notebook", task_obj.id))
return False
finally:
self.cli_helper.echo(
__("info", "cli.task.run.stop"))
self.task_controller.stop(updated_task_obj.id)
self.cli_helper.echo(
__("info", "cli.task.run.complete", updated_task_obj.id))

return updated_task_obj
# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict,
"cli.workspace.notebook")

@Helper.notify_environment_active(TaskController)
@Helper.notify_no_project_found
def rstudio(self, **kwargs):
self.task_controller = TaskController()
self.cli_helper.echo(__("info", "cli.workspace.rstudio"))
# Creating input dictionaries
snapshot_dict = {}
@@ -75,24 +55,6 @@
kwargs["mem_limit"]
}

# Create the task object
task_obj = self.task_controller.create()

updated_task_obj = task_obj
# Pass in the task
try:
updated_task_obj = self.task_controller.run(
task_obj.id, snapshot_dict=snapshot_dict, task_dict=task_dict)
except Exception as e:
self.logger.error("%s %s" % (e, task_dict))
self.cli_helper.echo(
__("error", "cli.workspace.rstudio", task_obj.id))
return False
finally:
self.cli_helper.echo(
__("info", "cli.task.run.stop"))
self.task_controller.stop(updated_task_obj.id)
self.cli_helper.echo(
__("info", "cli.task.run.complete", updated_task_obj.id))

return updated_task_obj
# Run task and return Task object result
return self.task_run_helper(task_dict, snapshot_dict,
"cli.workspace.rstudio")
2 changes: 1 addition & 1 deletion datmo/cli/driver/helper.py
@@ -209,4 +209,4 @@ def wrapper(self, *args, **kwargs):

return wrapper

return real_decorator
return real_decorator
2 changes: 1 addition & 1 deletion datmo/core/controller/code/code.py
@@ -53,7 +53,7 @@ def create(self, commit_id=None):
create_dict = {
"model_id": self.model.id,
}
## Required args for Code entity
# Required args for Code entity
required_args = ["driver_type", "commit_id"]
for required_arg in required_args:
# Handle Id if provided or not
7 changes: 1 addition & 6 deletions datmo/core/controller/code/driver/file.py
@@ -13,7 +13,7 @@
from datmo.core.util.i18n import get as __
from datmo.core.util.exceptions import (PathDoesNotExist, FileIOError,
UnstagedChanges, CodeNotInitialized,
CommitDoesNotExist, CommitFailed)
CommitDoesNotExist)
from datmo.core.controller.code.driver import CodeDriver


@@ -198,11 +198,6 @@ def create_ref(self, commit_id=None):
return commit_id
# Find all tracked files (_get_tracked_files)
tracked_filepaths = self._get_tracked_files()
# If no tracked filepaths, then commit fails
if not tracked_filepaths:
raise CommitFailed(
__("error",
"controller.code.driver.file.create_ref.cannot_commit"))
# Create the hash of the files (_calculate_commit_hash)
commit_hash = self._calculate_commit_hash(tracked_filepaths)
# Check if the hash already exists with exists_ref
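
The comments in create_ref above outline the flow: gather the tracked files, hash their contents into a commit id, and reuse an existing ref if that hash already exists. The diff does not show _calculate_commit_hash itself; the sketch below only illustrates the general idea, with a hypothetical helper name:

import hashlib

def calculate_commit_hash_sketch(tracked_filepaths):
    # Illustrative only: derive a deterministic id from the contents of the
    # tracked files so that identical file sets map to the same commit hash
    digest = hashlib.sha1()
    for filepath in sorted(tracked_filepaths):
        with open(filepath, "rb") as f:
            digest.update(f.read())
    return digest.hexdigest()
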
2 changes: 1 addition & 1 deletion datmo/core/controller/code/driver/git.py
@@ -492,7 +492,7 @@ def check_unstaged_changes(self):
__("error", "controller.code.driver.git.status",
str(stderr)))
stdout = stdout.decode().strip()
if "working tree clean" not in stdout:
if "clean" not in stdout:
raise UnstagedChanges()
except subprocess.CalledProcessError as e:
raise GitExecutionError(
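
The updated check above loosens the match from "working tree clean" to "clean", which also covers older git versions that print "nothing to commit, working directory clean". A standalone sketch of the same idea, with an assumed function name and a direct subprocess call rather than the driver's own wrapper:

import subprocess

def has_unstaged_changes(repo_path="."):
    # Older git prints "working directory clean", newer git prints
    # "working tree clean"; testing for the shared word "clean" covers both
    stdout = subprocess.check_output(["git", "status"], cwd=repo_path)
    return "clean" not in stdout.decode().strip()
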
7 changes: 0 additions & 7 deletions datmo/core/controller/code/driver/tests/test_file.py
@@ -191,13 +191,6 @@ def test_create_ref(self):
except CommitDoesNotExist:
failed = True
assert failed
# Test failure, no files to track
failed = False
try:
self.file_code_manager.create_ref()
except CommitFailed:
failed = True
assert failed
# Test successful creation of ref
self.__setup()
result = self.file_code_manager.create_ref()
8 changes: 2 additions & 6 deletions datmo/core/controller/code/tests/test_code.py
@@ -48,12 +48,8 @@ def test_create(self):
self.project.init("test3", "test description")

# Test failing for nothing to commit, no id
failed = False
try:
self.code.create()
except CommitFailed:
failed = True
assert failed
result = self.code.create()
assert result

# Test failing for non-existent commit_id
failed = False
