-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #52 from datmo/arron
Added Sessions and refactoring
- Loading branch information
Showing
16 changed files
with
506 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from __future__ import print_function | ||
|
||
import prettytable | ||
|
||
from datmo.core.util.i18n import get as __ | ||
from datmo.core.controller.session import SessionController | ||
from datmo.cli.command.project import ProjectCommand | ||
|
||
|
||
class SessionCommand(ProjectCommand): | ||
def __init__(self, home, cli_helper): | ||
super(SessionCommand, self).__init__(home, cli_helper) | ||
# dest="subcommand" argument will populate a "subcommand" property with the subparsers name | ||
# example "subcommand"="create" or "subcommand"="ls" | ||
snapshot_parser = self.subparsers.add_parser("session", help="Session module") | ||
subcommand_parsers = snapshot_parser.add_subparsers(title="subcommands", dest="subcommand") | ||
|
||
create = subcommand_parsers.add_parser("create", help="Create session") | ||
create.add_argument("--name", "-m", dest="name", default="", help="Session name") | ||
create.add_argument("--current", dest="current", action="store_false", | ||
help="Boolean if you want to switch to this session") | ||
|
||
delete = subcommand_parsers.add_parser("delete", help="Delete a snapshot by id") | ||
delete.add_argument("--name", dest="name", help="Name of session to delete") | ||
|
||
ls = subcommand_parsers.add_parser("ls", help="List sessions") | ||
|
||
checkout = subcommand_parsers.add_parser("select", help="Select a session") | ||
checkout.add_argument("--name", dest="name", help="Name of session to select") | ||
|
||
self.session_controller = SessionController(home=home) | ||
|
||
def create(self, **kwargs): | ||
name = kwargs.get('name') | ||
self.session_controller.create(kwargs) | ||
self.cli_helper.echo(__("info","cli.session.create", name)) | ||
return True | ||
|
||
def delete(self, **kwargs): | ||
name = kwargs.get('name') | ||
if self.session_controller.delete_by_name(name): | ||
self.cli_helper.echo(__("info", "cli.session.delete", name)) | ||
return True | ||
|
||
|
||
def select(self, **kwargs): | ||
name = kwargs.get("name") | ||
self.cli_helper.echo(__("info", "cli.session.select", name)) | ||
return self.session_controller.select(name) | ||
|
||
def ls(self, **kwargs): | ||
sessions = self.session_controller.list() | ||
header_list = ["name", "selected", "tasks", "snapshots"] | ||
t = prettytable.PrettyTable(header_list) | ||
for sess in sessions: | ||
snapshot_count = len(self.session_controller.dal.snapshot.query({"session_id": sess.id, "model_id": self.session_controller.model.id })) | ||
task_count = len(self.session_controller.dal.task.query({"session_id": sess.id, "model_id": self.session_controller.model.id })) | ||
t.add_row([sess.name, sess.current, task_count, snapshot_count]) | ||
|
||
self.cli_helper.echo(t) | ||
|
||
return True | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
""" | ||
Tests for SessopmCommand | ||
""" | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
import shutil | ||
import tempfile | ||
from io import open | ||
try: | ||
to_unicode = unicode | ||
except NameError: | ||
to_unicode = str | ||
|
||
import os | ||
from datmo.cli.driver.helper import Helper | ||
from datmo.cli.command.session import SessionCommand | ||
from datmo.cli.command.project import ProjectCommand | ||
|
||
|
||
class TestSession(): | ||
def setup_class(self): | ||
# provide mountable tmp directory for docker | ||
tempfile.tempdir = '/tmp' | ||
test_datmo_dir = os.environ.get('TEST_DATMO_DIR', | ||
tempfile.gettempdir()) | ||
self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir) | ||
self.cli_helper = Helper() | ||
self.session_command = SessionCommand(self.temp_dir, self.cli_helper) | ||
|
||
init = ProjectCommand(self.temp_dir, self.cli_helper) | ||
init.parse([ | ||
"init", | ||
"--name", "foobar", | ||
"--description", "test model"]) | ||
init.execute() | ||
|
||
def teardown_class(self): | ||
shutil.rmtree(self.temp_dir) | ||
|
||
def test_session_create(self): | ||
self.session_command.parse([ | ||
"session", | ||
"create", | ||
"--name","pizza" | ||
]) | ||
assert self.session_command.execute() | ||
|
||
def test_session_select(self): | ||
self.session_command.parse([ | ||
"session", | ||
"select", | ||
"--name","pizza" | ||
]) | ||
assert self.session_command.execute() | ||
current = 0 | ||
for s in self.session_command.session_controller.list(): | ||
print("%s - %s" % (s.name, s.current)) | ||
if s.current == True: | ||
current = current + 1 | ||
assert current == 1 | ||
|
||
def test_session_ls(self): | ||
self.session_command.parse([ | ||
"session", | ||
"ls" | ||
]) | ||
assert self.session_command.execute() | ||
|
||
def test_session_delete(self): | ||
self.session_command.parse([ | ||
"session", | ||
"delete", | ||
"--name","pizza" | ||
]) | ||
assert self.session_command.execute() | ||
session_removed = True | ||
for s in self.session_command.session_controller.list(): | ||
if s.name == 'pizza': | ||
session_removed = False | ||
assert session_removed | ||
self.session_command.parse([ | ||
"session", | ||
"ls" | ||
]) | ||
assert self.session_command.execute() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.