Skip to content

Commit

Permalink
Merge pull request #52 from datmo/arron
Browse files Browse the repository at this point in the history
Added Sessions and refactoring
  • Loading branch information
pennyfx committed May 1, 2018
2 parents 6228130 + 318de4b commit af271e8
Show file tree
Hide file tree
Showing 16 changed files with 506 additions and 112 deletions.
70 changes: 70 additions & 0 deletions datmo/cli/command/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import print_function

import prettytable

from datmo.core.util.i18n import get as __
from datmo.core.controller.session import SessionController
from datmo.cli.command.project import ProjectCommand


class SessionCommand(ProjectCommand):
def __init__(self, home, cli_helper):
super(SessionCommand, self).__init__(home, cli_helper)
# dest="subcommand" argument will populate a "subcommand" property with the subparsers name
# example "subcommand"="create" or "subcommand"="ls"
snapshot_parser = self.subparsers.add_parser("session", help="Session module")
subcommand_parsers = snapshot_parser.add_subparsers(title="subcommands", dest="subcommand")

create = subcommand_parsers.add_parser("create", help="Create session")
create.add_argument("--name", "-m", dest="name", default="", help="Session name")
create.add_argument("--current", dest="current", action="store_false",
help="Boolean if you want to switch to this session")

delete = subcommand_parsers.add_parser("delete", help="Delete a snapshot by id")
delete.add_argument("--name", dest="name", help="Name of session to delete")

ls = subcommand_parsers.add_parser("ls", help="List sessions")

checkout = subcommand_parsers.add_parser("select", help="Select a session")
checkout.add_argument("--name", dest="name", help="Name of session to select")

self.session_controller = SessionController(home=home)

def create(self, **kwargs):
name = kwargs.get('name')
self.session_controller.create(kwargs)
self.cli_helper.echo(__("info","cli.session.create", name))
return True

def delete(self, **kwargs):
name = kwargs.get('name')
if self.session_controller.delete_by_name(name):
self.cli_helper.echo(__("info", "cli.session.delete", name))
return True


def select(self, **kwargs):
name = kwargs.get("name")
self.cli_helper.echo(__("info", "cli.session.select", name))
return self.session_controller.select(name)

def ls(self, **kwargs):
sessions = self.session_controller.list()
header_list = ["name", "selected", "tasks", "snapshots"]
t = prettytable.PrettyTable(header_list)
for sess in sessions:
snapshot_count = len(self.session_controller.dal.snapshot.query({"session_id": sess.id, "model_id": self.session_controller.model.id }))
task_count = len(self.session_controller.dal.task.query({"session_id": sess.id, "model_id": self.session_controller.model.id }))
t.add_row([sess.name, sess.current, task_count, snapshot_count])

self.cli_helper.echo(t)

return True








3 changes: 2 additions & 1 deletion datmo/cli/command/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datmo.core.util.i18n import get as __
from datmo.cli.command.project import ProjectCommand
from datmo.core.controller.task import TaskController
from datmo.core.util.exceptions import RequiredArgumentMissing


class TaskCommand(ProjectCommand):
Expand Down Expand Up @@ -62,7 +63,7 @@ def run(self, **kwargs):
"gpu": kwargs['gpu'],
"ports": kwargs['ports'],
"interactive": kwargs['interactive'],
"command": kwargs['cmd']
"command": cmd
}

# Create the task object
Expand Down
87 changes: 87 additions & 0 deletions datmo/cli/command/test/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Tests for SessopmCommand
"""
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import shutil
import tempfile
from io import open
try:
to_unicode = unicode
except NameError:
to_unicode = str

import os
from datmo.cli.driver.helper import Helper
from datmo.cli.command.session import SessionCommand
from datmo.cli.command.project import ProjectCommand


class TestSession():
def setup_class(self):
# provide mountable tmp directory for docker
tempfile.tempdir = '/tmp'
test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
tempfile.gettempdir())
self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
self.cli_helper = Helper()
self.session_command = SessionCommand(self.temp_dir, self.cli_helper)

init = ProjectCommand(self.temp_dir, self.cli_helper)
init.parse([
"init",
"--name", "foobar",
"--description", "test model"])
init.execute()

def teardown_class(self):
shutil.rmtree(self.temp_dir)

def test_session_create(self):
self.session_command.parse([
"session",
"create",
"--name","pizza"
])
assert self.session_command.execute()

def test_session_select(self):
self.session_command.parse([
"session",
"select",
"--name","pizza"
])
assert self.session_command.execute()
current = 0
for s in self.session_command.session_controller.list():
print("%s - %s" % (s.name, s.current))
if s.current == True:
current = current + 1
assert current == 1

def test_session_ls(self):
self.session_command.parse([
"session",
"ls"
])
assert self.session_command.execute()

def test_session_delete(self):
self.session_command.parse([
"session",
"delete",
"--name","pizza"
])
assert self.session_command.execute()
session_removed = True
for s in self.session_command.session_controller.list():
if s.name == 'pizza':
session_removed = False
assert session_removed
self.session_command.parse([
"session",
"ls"
])
assert self.session_command.execute()
19 changes: 12 additions & 7 deletions datmo/cli/command/test/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def teardown_class(self):
pass

def __set_variables(self):
self.init = ProjectCommand(self.temp_dir, self.cli_helper)
self.init.parse([
init = ProjectCommand(self.temp_dir, self.cli_helper)
init.parse([
"init",
"--name", "foobar",
"--description", "test model"])
self.init.execute()
init.execute()

self.task = TaskCommand(self.temp_dir, self.cli_helper)

# Create environment_driver definition
Expand All @@ -64,9 +65,8 @@ def test_task_project_not_init(self):
failed = True
assert failed

def test_datmo_task_run(self):
def test_datmo_task_run_should_fail1(self):
self.__set_variables()

# Test failure case
self.task.parse([
"task",
Expand All @@ -79,6 +79,8 @@ def test_datmo_task_run(self):
failed = True
assert failed

def test_datmo_task_run_should_fail2(self):
self.__set_variables()
# Test failure case execute
test_command = ["yo", "yo"]
self.task.parse([
Expand All @@ -89,6 +91,9 @@ def test_datmo_task_run(self):
result = self.task.execute()
assert not result


def test_datmo_task_run(self):
self.__set_variables()
# Test success case
test_command = ["sh", "-c", "echo accuracy:0.45"]
test_gpu = True # TODO: implement in controller
Expand Down Expand Up @@ -195,12 +200,12 @@ def test_task_stop(self):
task_stop_command = self.task.execute()
assert task_stop_command == True

def test_task_stop_invalid_task_id(self):
# Passing wrong task id
test_task_id = "task_id"
self.task.parse([
"task",
"stop",
"--id", test_task_id
"--id", "invalid-task-id"
])

# test when wrong task id is passed to stop it
Expand Down
5 changes: 3 additions & 2 deletions datmo/core/controller/code/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def create(self, commit_id=None):
self.code_driver.create_ref()
# If code object with commit id exists, return it
results = self.dal.code.query({
"commit_id": create_dict[required_arg]
"commit_id": create_dict[required_arg],
"model_id": self.model.id
})
if results: return results[0];
else:
Expand All @@ -74,7 +75,7 @@ def create(self, commit_id=None):

def list(self):
# TODO: Add time filters
return self.dal.code.query({})
return self.dal.code.query({"model_id": self.model.id })

def delete(self, code_id):
"""Delete all traces of Code object
Expand Down

0 comments on commit af271e8

Please sign in to comment.