Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/datmo/datmo
Browse files Browse the repository at this point in the history
  • Loading branch information
nmwalsh committed Apr 24, 2018
2 parents d602340 + 556f570 commit 8ef6025
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.idea/*
**/.DS_Store

# Created by .ignore support plugin (hsz.mobi)
### Python template
Expand Down
2 changes: 0 additions & 2 deletions datmo/cli/command/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datmo.core.util.i18n import get as __
from datmo.cli.command.project import ProjectCommand
from datmo.core.controller.snapshot import SnapshotController
from datmo.core.util.exceptions import ProjectNotInitializedException


class SnapshotCommand(ProjectCommand):
Expand Down Expand Up @@ -52,7 +51,6 @@ def __init__(self, home, cli_helper):

ls = subcommand_parsers.add_parser("ls", help="List snapshots")
ls.add_argument("--session-id", dest="session_id", default=None, help="Session ID to filter")
ls.add_argument("--session-name", dest="session_name", default=None, help="Session name to filter")
ls.add_argument("--all", "-a", dest="details", action="store_true",
help="Show detailed snapshot information")

Expand Down
9 changes: 4 additions & 5 deletions datmo/core/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,17 @@ class BaseController(object):
Return the config dictionary based on key
get_config_defaults()
Return the configuration defaults
"""

def __init__(self, home):
self.home = home
if not os.path.isdir(self.home):
raise InvalidProjectPathException(__("error",
"controller.base.__init__",
home))
self.config = JSONStore(os.path.join(self.home,
".datmo",
".config"))
if not os.path.isdir(self.home):
raise InvalidProjectPathException(__("error",
"controller.base.__init__",
home))
# property caches and initial values
self._dal = None
self._model = None
Expand Down
4 changes: 2 additions & 2 deletions datmo/core/controller/code/driver/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def __init__(self, filepath, execpath, remote_url=None):
# Check if filepath exists
if not os.path.exists(self.filepath):
raise DoesNotExistException(__("error",
"controller.code.driver.git.__init__.dne",
filepath))
"controller.code.driver.git.__init__.dne",
filepath))
self.execpath = execpath
# Check the execpath and the version
try:
Expand Down
Binary file removed datmo/core/controller/environment/.DS_Store
Binary file not shown.
4 changes: 2 additions & 2 deletions datmo/core/controller/file/driver/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, filepath):
# Check if filepath exists
if not os.path.exists(self.filepath):
raise DoesNotExistException(__("error",
"controller.file.driver.local.__init__",
filepath))
"controller.file.driver.local.__init__",
filepath))
self._is_initialized = self.is_initialized
self.type = "local"

Expand Down
8 changes: 7 additions & 1 deletion datmo/core/controller/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datmo.core.util.i18n import get as __
from datmo.core.util.json_store import JSONStore
from datmo.core.util.exceptions import FileIOException, RequiredArgumentMissing, \
ProjectNotInitializedException
ProjectNotInitializedException, SessionDoesNotExistException, EntityNotFound


class SnapshotController(BaseController):
Expand Down Expand Up @@ -303,6 +303,12 @@ def checkout(self, id):
def list(self, session_id=None):
query = {}
if session_id:
try:
self.dal.session.get_by_id(session_id)
except EntityNotFound:
raise SessionDoesNotExistException(__("error",
"controller.snapshot.list",
session_id))
query['session_id'] = session_id
return self.dal.snapshot.query(query)

Expand Down
12 changes: 10 additions & 2 deletions datmo/core/controller/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from datmo.core.controller.base import BaseController
from datmo.core.controller.code.driver.git import GitCodeDriver
from datmo.core.util.exceptions import \
DatmoModelNotInitializedException
DatmoModelNotInitializedException, InvalidProjectPathException


class TestBaseController():
Expand All @@ -21,11 +21,19 @@ def setup_method(self):
test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
tempfile.gettempdir())
self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)
self.base = BaseController(self.temp_dir)
self.base = BaseController(home=self.temp_dir)

def teardown_method(self):
shutil.rmtree(self.temp_dir)

def test_failed_controller_instantiation(self):
failed = False
try:
BaseController(home=os.path.join("does", "not", "exist"))
except InvalidProjectPathException:
failed = True
assert failed

def test_instantiation(self):
assert self.base != None

Expand Down
11 changes: 10 additions & 1 deletion datmo/core/controller/test/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from datmo.core.controller.project import ProjectController
from datmo.core.controller.snapshot import SnapshotController
from datmo.core.util.exceptions import EntityNotFound, \
DoesNotExistException, GitCommitDoesNotExist
DoesNotExistException, GitCommitDoesNotExist, \
SessionDoesNotExistException


class TestSnapshotController():
Expand Down Expand Up @@ -209,6 +210,14 @@ def test_checkout(self):
os.path.isdir(snapshot_obj_1_path)

def test_list(self):
# Check for error if incorrect session given
failed = False
try:
self.snapshot.list(session_id="does_not_exist")
except SessionDoesNotExistException:
failed = True
assert failed

# Create files to add
self.snapshot.file_driver.create("dirpath1", dir=True)
self.snapshot.file_driver.create("dirpath2", dir=True)
Expand Down
Binary file removed datmo/core/util/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions datmo/core/util/lang/en.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
"controller.snapshot.create.arg": "Required argument missing to create snapshot: %s",
"controller.snapshot.create.file_config": "Config file does not exist",
"controller.snapshot.create.file_stat": "Stats file does not exist",
"controller.snapshot.list": "Session does not exist for id: %s",
"controller.snapshot.delete.arg": "Delete argument %s not present in input",
"controller.task.__init__": "Project has not been initialized",
"controller.task._run_helper.env_dne": "Environment specified does not exist: %s",
Expand Down
66 changes: 61 additions & 5 deletions datmo/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

def create(message, label=None, commit_id=None, environment_id=None, filepaths=None,
config=None, stats=None, home=None):
"""
Create a snapshot within a project
"""Create a snapshot within a project
The project must be created before this is implemented. You can do that by using
the following command::
Expand Down Expand Up @@ -43,7 +42,7 @@ def create(message, label=None, commit_id=None, environment_id=None, filepaths=N
Returns
-------
Snapshot
returns a snapshot entity for reference
returns a snapshot entity (as defined in datmo.core.entity.snapshot)
Examples
--------
Expand All @@ -53,7 +52,6 @@ def create(message, label=None, commit_id=None, environment_id=None, filepaths=N
>>> import datmo
>>> datmo.snapshot.create(message="my first snapshot", filepaths=["/path/to/a/large/file"], config={"test": 0.4, "test2": "string"}, stats={"accuracy": 0.94})
Snapshot()
"""
if not home:
home = os.getcwd()
Expand All @@ -79,4 +77,62 @@ def create(message, label=None, commit_id=None, environment_id=None, filepaths=N
if label:
snapshot_create_dict['label'] = label

return snapshot_controller.create(snapshot_create_dict)
return snapshot_controller.create(snapshot_create_dict)

def ls(session_id=None, filter=None, home=None):
"""List snapshots within a project
The project must be created before this is implemented. You can do that by using
the following command::
$ datmo init
Parameters
----------
session_id : str, optional
a description of the snapshot for later reference
(default is None, which means no session filter is given)
filter : str, optional
a string to use to filter from message and label
(default is to give all snapshots, unless provided a specific string. eg: best)
home : str, optional
absolute home path of the project
(default is None, which will use the CWD as the project path)
Returns
-------
list
returns a list of Snapshot entities (as defined in datmo.core.entity.snapshot)
Examples
--------
You can use this function within a project repository to list snapshots.
>>> import datmo
>>> snapshots = datmo.snapshot.ls()
"""
if not home:
home = os.getcwd()
snapshot_controller = SnapshotController(home=home)

# add arguments if they are not None
if not session_id:
session_id = snapshot_controller.current_session.id

snapshot_objs = snapshot_controller.list(session_id)

# Filtering Snapshots
# TODO: move to list function in SnapshotController
# Add in preliminary snapshots if no filter
filtered_snapshot_objs = [snapshot_obj for snapshot_obj in snapshot_objs
if snapshot_obj.visible and not filter]
# If filter is present then use it and only add those that pass filter
for snapshot_obj in snapshot_objs:
if snapshot_obj.visible:
if filter and (filter in snapshot_obj.message
or filter in snapshot_obj.label):
filtered_snapshot_objs.append(snapshot_obj)

# Return Snapshot entities
return filtered_snapshot_objs
69 changes: 65 additions & 4 deletions datmo/test/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import shutil
import tempfile

from datmo.snapshot import create
from datmo.snapshot import create, ls
from datmo.core.controller.project import ProjectController
from datmo.core.util.exceptions import GitCommitDoesNotExist, \
DoesNotExistException
InvalidProjectPathException, SessionDoesNotExistException


class TestSnapshotModule():
Expand All @@ -25,6 +25,15 @@ def teardown_method(self):
shutil.rmtree(self.temp_dir)

def test_create(self):
# check project is not initialized if wrong home
failed = False
try:
create(message="test",
home=os.path.join("does","not", "exist"))
except InvalidProjectPathException:
failed = True
assert failed

# Create a snapshot with default params
# (fails w/ no commit)
failed = False
Expand All @@ -35,7 +44,6 @@ def test_create(self):
assert failed

# Create a snapshot with default params and files to commit
# (fails w/ no environment)
test_filepath = os.path.join(self.temp_dir, "script.py")
with open(test_filepath, "w") as f:
f.write("import numpy\n")
Expand Down Expand Up @@ -64,4 +72,57 @@ def test_create(self):
assert snapshot_obj_2.file_collection_id
assert snapshot_obj_2.config == {}
assert snapshot_obj_2.stats == {}
assert snapshot_obj_2 != snapshot_obj_1
assert snapshot_obj_2 != snapshot_obj_1

def test_ls(self):
# check project is not initialized if wrong home
failed = False
try:
ls(home=os.path.join("does","not", "exist"))
except InvalidProjectPathException:
failed = True
assert failed

# check session does not exist if wrong session
failed = False
try:
ls(session_id="does_not_exist", home=self.temp_dir)
except SessionDoesNotExistException:
failed = True
assert failed

# create with default params and files to commit
test_filepath = os.path.join(self.temp_dir, "script.py")
with open(test_filepath, "w") as f:
f.write("import numpy\n")
f.write("import sklean\n")

create(message="test1", home=self.temp_dir)

# list all snapshots with no filters
snapshot_list_1 = ls(home=self.temp_dir)

assert snapshot_list_1
assert len(list(snapshot_list_1)) == 1

# Create a snapshot with default params, files, and environment
test_filepath = os.path.join(self.temp_dir, "Dockerfile")
with open(test_filepath, "w") as f:
f.write("FROM datmo/xgboost:cpu")
create(message="test2", home=self.temp_dir)

# list all snapshots with no filters (works when more than 1 snapshot)
snapshot_list_2 = ls(home=self.temp_dir)

assert snapshot_list_2
assert len(list(snapshot_list_2)) == 2

# list snapshots with specific filter
snapshot_list_3 = ls(filter='test2', home=self.temp_dir)

assert len(list(snapshot_list_3)) == 1

# list snapshots with filter of none
snapshot_list_4 = ls(filter='test3', home=self.temp_dir)

assert len(list(snapshot_list_4)) == 0
Binary file removed docs/.DS_Store
Binary file not shown.
Binary file removed docs/source/.DS_Store
Binary file not shown.

0 comments on commit 8ef6025

Please sign in to comment.