Skip to content

Commit

Permalink
Merge pull request #21 from datmo/arron
Browse files Browse the repository at this point in the history
Tests and refactoring
  • Loading branch information
asampat3090 committed Apr 26, 2018
2 parents e26b1ee + a914562 commit b1b4b32
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 30 deletions.
56 changes: 26 additions & 30 deletions datmo/core/controller/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,19 @@ def create(self, incoming_dictionary):
}

# Code setup
self.__code_setup(incoming_dictionary, create_dict)
self._code_setup(incoming_dictionary, create_dict)

# Environment setup
self.__env_setup(incoming_dictionary, create_dict)
self._env_setup(incoming_dictionary, create_dict)

# File setup
self.__file_setup(incoming_dictionary, create_dict)
self._file_setup(incoming_dictionary, create_dict)

# Config setup
self.__config_setup(incoming_dictionary, create_dict)
self._config_setup(incoming_dictionary, create_dict)

# Stats setup
self.__stats_setup(incoming_dictionary, create_dict)
self._stats_setup(incoming_dictionary, create_dict)

# If snapshot object with required args already exists, return it
# DO NOT create a new snapshot with the same required arguments
Expand Down Expand Up @@ -223,8 +223,8 @@ def delete(self, snapshot_id):
"snapshot_id"))
return self.dal.snapshot.delete(snapshot_id)

def __code_setup(self, incoming_dictionary, create_dict):
"""Set the code_id by using:
def _code_setup(self, incoming_dictionary, create_dict):
""" Set the code_id by using:
1. code_id
2. commit_id string, which creates a new code_id
3. create a new code id
Expand All @@ -245,8 +245,8 @@ def __code_setup(self, incoming_dictionary, create_dict):
else:
create_dict['code_id'] = self.code.create().id

def __env_setup(self, incoming_dictionary, create_dict):
"""Create or add environment to create_dict for Snapshot entity
def _env_setup(self, incoming_dictionary, create_dict):
""" TODO:
Parameters
----------
Expand All @@ -271,8 +271,8 @@ def __env_setup(self, incoming_dictionary, create_dict):
create_dict['environment_id'] = self.environment.\
create({}).id

def __file_setup(self, incoming_dictionary, create_dict):
""" Creates file collections and adds it to the create dict
def _file_setup(self, incoming_dictionary, create_dict):
""" TODO:
Parameters
----------
Expand All @@ -293,12 +293,11 @@ def __file_setup(self, incoming_dictionary, create_dict):
create_dict['file_collection_id'] = self.file_collection.\
create([]).id

def __config_setup(self, incoming_dictionary, create_dict):
"""Fills in snapshot config by having one of the following:
def _config_setup(self, incoming_dictionary, create_dict):
""" Fills in snapshot config by having one of the following:
1. config = JSON object
2. config_filepath = some location where a json file exists
3. config_filename = just the file name and we'll try to find it
3. config_filename = just the file nam
Parameters
----------
incoming_dictionary : dict
Expand All @@ -319,17 +318,18 @@ def __config_setup(self, incoming_dictionary, create_dict):
# If path exists transform file to config dict
config_json_driver = JSONStore(incoming_dictionary['config_filepath'])
create_dict['config'] = config_json_driver.to_dict()
else:
elif "config_filename" in incoming_dictionary:
config_filename = incoming_dictionary['config_filename'] \
if "config_filename" in incoming_dictionary else "config.json"
create_dict['config'] = \
self.__find_in_file_collection(config_filename, create_dict['file_collection_id'])
create_dict['config'] = self._find_in_filecollection(config_filename, create_dict['file_collection_id'])
else:
create_dict['config'] = {}

def __stats_setup(self, incoming_dictionary, create_dict):
def _stats_setup(self, incoming_dictionary, create_dict):
"""Fills in snapshot stats by having one of the following:
1. stats = JSON object
2. stats_filepath = some location where a json file exists
3. stats_filename = just the file name and we'll try to find it
3. stats_filename = just the file name
Parameters
----------
Expand All @@ -352,19 +352,15 @@ def __stats_setup(self, incoming_dictionary, create_dict):
# If path exists transform file to config dict
stats_json_driver = JSONStore(incoming_dictionary['stats_filepath'])
create_dict['stats'] = stats_json_driver.to_dict()
else:
elif "stats_filename" in incoming_dictionary:
stats_filename = incoming_dictionary['stats_filename'] \
if "stats_filename" in incoming_dictionary else "stats.json"
create_dict['stats'] = \
self.__find_in_file_collection(stats_filename, create_dict['file_collection_id'])

def __find_in_file_collection(self, file_to_find, file_collection_id):
"""Attempts to find a JSON file within the file collection
create_dict['stats'] = self._find_in_filecollection(stats_filename, create_dict['file_collection_id'])
else:
create_dict['stats'] = {}

Parameters
----------
file_to_find : str
filename for file in collection
def _find_in_filecollection(self, file_to_find, file_collection_id):
""" Attempts to find a file within the file collection
Returns
-------
Expand Down
137 changes: 137 additions & 0 deletions datmo/core/controller/test/test_snapshot_private.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Tests for SnapshotController
"""
import os
import shutil
import tempfile

from datmo.core.controller.project import ProjectController
from datmo.core.controller.snapshot import SnapshotController


class TestSnapshotController():
def setup_method(self):
# provide mountable tmp directory for docker
tempfile.tempdir = '/tmp'
test_datmo_dir = os.environ.get('TEST_DATMO_DIR',
tempfile.gettempdir())
self.temp_dir = tempfile.mkdtemp(dir=test_datmo_dir)

self.project = ProjectController(self.temp_dir)
self.project.init("test", "test description")
self.snapshot = SnapshotController(self.temp_dir)

# Create environment_driver definition
self.env_def_path = os.path.join(self.temp_dir,
"Dockerfile")
with open(self.env_def_path, "w") as f:
f.write(str("FROM datmo/xgboost:cpu"))

# Create config
self.config_filepath = os.path.join(self.snapshot.home,
"config.json")
with open(self.config_filepath, "w") as f:
f.write(str('{"foo":1}'))

# Create stats
self.stats_filepath = os.path.join(self.snapshot.home,
"stats.json")
with open(self.stats_filepath, "w") as f:
f.write(str('{"bar":1}'))

# Create test file
self.filepath = os.path.join(self.snapshot.home,
"file.txt")
with open(self.filepath, "w") as f:
f.write(str("test"))

def teardown_method(self):
shutil.rmtree(self.temp_dir)

def test_code_setup_with_code_id(self):
val = 1
incoming_data = {"code_id": val}
final_data = {}
self.snapshot._code_setup(incoming_data, final_data)
assert final_data['code_id'] == val

def test_code_setup_with_commit_id(self):
val = "f38a8ace"
incoming_data = {"commit_id": val}
final_data = {}
self.snapshot._code_setup(incoming_data, final_data)
assert final_data['code_id']

def test_code_setup_with_none(self):
incoming_data = {}
final_data = {}
self.snapshot._code_setup(incoming_data, final_data)
assert final_data['code_id']

def test_env_setup_with_none(self):
incoming_data = {}
final_data = {}
self.snapshot._env_setup(incoming_data, final_data)
assert final_data['environment_id']

def test_file_setup_with_none(self):
incoming_data = {}
final_data = {}
self.snapshot._file_setup(incoming_data, final_data)
assert final_data['file_collection_id']

def test_file_setup_with_filepaths(self):
incoming_data = {"filepaths": [self.filepath]}
final_data = {}
self.snapshot._file_setup(incoming_data, final_data)
assert final_data['file_collection_id']

def test_config_setup_with_json(self):
incoming_data = {"config":{"foo":1}}
final_data = {}
self.snapshot._config_setup(incoming_data, final_data)
assert final_data['config']["foo"] == 1

def test_config_setup_with_filepath(self):
incoming_data = {"config_filepath": self.config_filepath }
final_data = {}
self.snapshot._config_setup(incoming_data, final_data)
assert final_data['config']["foo"] == 1

def test_config_setup_with_filename(self):
incoming_data = {"config_filename": "config.json" }
final_data = {}
self.snapshot._file_setup(incoming_data, final_data)
self.snapshot._config_setup(incoming_data, final_data)
assert final_data['config']["foo"] == 1

def test_config_setup_with_empty(self):
incoming_data = {}
final_data = {}
self.snapshot._config_setup(incoming_data, final_data)
assert final_data['config'] == {}

def test_stats_setup_with_json(self):
incoming_data = {"stats":{"bar":1}}
final_data = {}
self.snapshot._stats_setup(incoming_data, final_data)
assert final_data['stats']["bar"] == 1

def test_stats_setup_with_filepath(self):
incoming_data = {"stats_filepath": self.stats_filepath }
final_data = {}
self.snapshot._stats_setup(incoming_data, final_data)
assert final_data['stats']["bar"] == 1

def test_stats_setup_with_empty(self):
incoming_data = {}
final_data = {}
self.snapshot._stats_setup(incoming_data, final_data)
assert final_data['stats'] == {}

def test_stats_setup_with_filename(self):
incoming_data = {"stats_filename": "stats.json" }
final_data = {}
self.snapshot._file_setup(incoming_data, final_data)
self.snapshot._stats_setup(incoming_data, final_data)
assert final_data['stats']["bar"] == 1

0 comments on commit b1b4b32

Please sign in to comment.