Skip to content

Commit

Permalink
abstract data interface
Browse files Browse the repository at this point in the history
Abstract the data storage interface away from anything MongoDB-specific,
so that developers can more easily incorporate other
backends such as file storage and TinyDB.

Basic implementation for file storage with corresponding unit tests.
  • Loading branch information
gideonite committed Jun 5, 2017
1 parent 5464584 commit 5ecdcf4
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 6 deletions.
20 changes: 20 additions & 0 deletions sacredboard/app/data/datastorage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class Cursor:
    """Abstract handle on a collection of results returned by a data storage.

    Subclasses must override both :meth:`count` and :meth:`__iter__`.
    """

    def __init__(self):
        pass

    def count(self):
        """Return the number of items in the result set.

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # calling it fails with TypeError, so raise the proper exception.
        raise NotImplementedError()

    def __iter__(self):
        """Return an iterator over the items in the result set.

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

class DataStorage:
    """Abstract interface for a backend that stores experiment runs.

    Concrete backends must override :meth:`get_run` and :meth:`get_runs`.
    """

    def __init__(self):
        pass

    def get_run(self, run_id):
        """Return the single run identified by ``run_id``.

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # calling it fails with TypeError, so raise the proper exception.
        raise NotImplementedError()

    # NOTE: the default ``query`` dict is created once and shared between
    # calls; implementations must treat it as read-only.
    def get_runs(self, sort_by=None, sort_direction=None,
                 start=0, limit=None, query={"type": "and", "filters": []}):
        """Return the runs selected by the given sorting, paging and query.

        :param sort_by: field to sort by, or None.
        :param sort_direction: sort direction, or None.
        :param start: number of leading results to skip.
        :param limit: maximum number of results, or None for no limit.
        :param query: filter description; the default selects with an empty
            list of filters.
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()
64 changes: 64 additions & 0 deletions sacredboard/app/data/filestorage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import datetime
import os
import json

from sacredboard.app.data.datastorage import Cursor, DataStorage

# Names of the two JSON files looked up inside each run's directory:
# one for the run's configuration and one for the run record itself.
config_json = "config.json"
run_json = "run.json"

def _path_to_config(basepath, run_id):
    """Build the path of the configuration file of run ``run_id``.

    The file lives in a directory named after the run id, directly
    below ``basepath``.
    """
    run_directory = os.path.join(basepath, str(run_id))
    return os.path.join(run_directory, config_json)

def _path_to_run(basepath, run_id):
    """Build the path of the run record file of run ``run_id``.

    The file lives in a directory named after the run id, directly
    below ``basepath``.
    """
    run_directory = os.path.join(basepath, str(run_id))
    return os.path.join(run_directory, run_json)

def _read_json(path_to_json):
with open(path_to_json) as f:
return json.load(f)

def _create_run(runjson, configjson):
runjson["config"] = configjson

# TODO probably want a smarter way of detecting which values have type "time."
for k in ["start_time", "stop_time", "heartbeat"]:
runjson[k] = datetime.datetime.strptime(runjson[k], '%Y-%m-%dT%H:%M:%S.%f')

return runjson

class FileStoreCursor(Cursor):
    """Cursor pairing a precomputed item count with an iterable of items."""

    def __init__(self, count, iterable):
        """Remember the reported ``count`` and the ``iterable`` to walk."""
        self._count, self.iterable = count, iterable

    def count(self):
        """Return the count that was supplied at construction time."""
        return self._count

    def __iter__(self):
        """Return an iterator over the wrapped iterable."""
        return iter(self.iterable)

class FileStorage(DataStorage):
    """Data storage that reads runs from a directory on disk.

    Each run is looked up in a sub-directory of ``path_to_dir`` named after
    the run id, containing a config file and a run file in JSON format.
    """

    def __init__(self, path_to_dir):
        """Create a storage rooted at ``path_to_dir`` (``~`` is expanded)."""
        super().__init__()
        self.path_to_dir = os.path.expanduser(path_to_dir)

    def get_run(self, run_id):
        """Load the run ``run_id`` from its config and run files.

        :return: the run record with the config stored under ``"config"``.
        """
        config = _read_json(_path_to_config(self.path_to_dir, run_id))
        run = _read_json(_path_to_run(self.path_to_dir, run_id))
        return _create_run(run, config)

    def get_runs(self, sort_by=None, sort_direction=None,
                 start=0, limit=None, query={"type": "and", "filters": []}):
        """Return a cursor over every run found in the storage directory.

        NOTE: ``sort_by``, ``sort_direction``, ``start``, ``limit`` and
        ``query`` are accepted for interface compatibility but are not
        applied by this implementation.

        :return: a :class:`FileStoreCursor` whose ``count()`` equals the
            number of runs it yields; runs are loaded lazily on iteration.
        """
        # Directory entries that are not runs and must be skipped.
        blacklist = {"_sources"}

        # Filter before counting so that the reported count matches the
        # number of runs actually produced by the iterator.
        run_ids = [entry for entry in os.listdir(self.path_to_dir)
                   if entry not in blacklist]

        def run_iterator():
            for run_id in run_ids:
                yield self.get_run(run_id)

        return FileStoreCursor(len(run_ids), run_iterator())
17 changes: 15 additions & 2 deletions sacredboard/app/data/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,19 @@
import bson
import pymongo

from sacredboard.app.data.datastorage import Cursor, DataStorage

class PyMongoDataAccess:
class MongoDbCursor(Cursor):
    """Cursor that delegates counting and iteration to a wrapped cursor."""

    def __init__(self, mongodb_cursor):
        """Wrap ``mongodb_cursor``, the object all calls are forwarded to."""
        self.mongodb_cursor = mongodb_cursor

    def count(self):
        """Return whatever the wrapped cursor's ``count()`` reports."""
        wrapped = self.mongodb_cursor
        return wrapped.count()

    def __iter__(self):
        """Return the wrapped cursor itself as the iterator."""
        wrapped = self.mongodb_cursor
        return wrapped

class PyMongoDataAccess(DataStorage):
"""Access records in MongoDB."""

RUNNING_DEAD_RUN_CLAUSE = {
Expand All @@ -19,6 +30,7 @@ def __init__(self, uri, database_name, collection_name):
Better use the static methods build_data_access
or build_data_access_with_uri
"""
super().__init__()
self._uri = uri
self._db_name = database_name
self._client = None
Expand Down Expand Up @@ -72,7 +84,8 @@ def get_runs(self, sort_by=None, sort_direction=None,
cursor = cursor.skip(start)
if limit is not None:
cursor = cursor.limit(limit)
return cursor

return MongoDbCursor(cursor)

def get_run(self, run_id):
try:
Expand Down
18 changes: 14 additions & 4 deletions sacredboard/bootstrap.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gevent.pywsgi import WSGIServer

from sacredboard.app.config import jinja_filters
from sacredboard.app.data.filestorage import FileStorage
from sacredboard.app.data.mongodb import PyMongoDataAccess
from sacredboard.app.webapi import routes

Expand All @@ -35,13 +36,15 @@
"You might need it if you use a custom collection name "
"or Sacred v0.6 (which used default.runs). "
"Default: runs")
@click.option("-F", default="",
help="Path to directory containing experiments.")
@click.option("--no-browser", is_flag=True, default=False,
help="Do not open web browser automatically.")
@click.option("--debug", is_flag=True, default=False,
help="Run the application in Flask debug mode "
"(for development).")
@click.version_option()
def run(debug, no_browser, m, mu, mc):
def run(debug, no_browser, m, mu, mc, f):
"""
Sacredboard.
Expand Down Expand Up @@ -76,12 +79,20 @@ def run(debug, no_browser, m, mu, mc):
Note: MongoDB must be listening on localhost.
"""
add_mongo_config(app, m, mu, mc)

if m:
add_mongo_config(app, m, mu, mc)
app.config["data"].connect()
elif f:
app.config["data"] = FileStorage(f)
else:
print("Must specify either a mongodb instance or a path to a file storage.")

app.config['DEBUG'] = debug
app.debug = debug
jinja_filters.setup_filters(app)
routes.setup_routes(app)
app.config["data"].connect()

if debug:
app.run(host="0.0.0.0", debug=True)
else:
Expand All @@ -98,7 +109,6 @@ def run(debug, no_browser, m, mu, mc):
http_server.serve_forever()
break


def add_mongo_config(app, simple_connection_string,
mongo_uri, collection_name):
"""
Expand Down
81 changes: 81 additions & 0 deletions sacredboard/tests/data/test_filestorage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# coding=utf-8
import bson
import pytest
import json
import tempfile
import os

from sacredboard.app.data.filestorage import FileStorage


def create_tmp_datastore():
    """
    Build a small file store in a new temporary directory and return its path.

    Rather than mocking the file system, this writes real temporary files
    that emulate the file store layout used by Sacred: a directory named
    "42" holding a "config.json" and a "run.json". Sacred and Sacredboard
    are completely decoupled, so nothing here can guarantee that the layout
    matches what Sacred actually produces.
    """
    run = {
        "status": "COMPLETED",
        "_id": "57f9efb2e4b8490d19d7c30e",
        "info": {},
        "resources": [],
        "host": {
            "os": "Linux",
            "os_info": "Linux-3.16.0-38-generic-x86_64-with-LinuxMint-17.2-rafaela",
            "cpu": "Intel(R) Core(TM) i3 CPU M 370 @ 2.40GHz",
            "python_version": "3.4.3",
            "python_compiler": "GCC 4.8.4",
            "cpu_count": 4,
            "hostname": "ntbacer",
        },
        "experiment": {
            "doc": None,
            "sources": [[
                "/home/martin/mnt/noun-classification/train_model.py",
                "86aaa9b81d6e32a181598ed78bb1d7a1",
            ]],
            "dependencies": [
                ["h5py", "2.6.0"],
                ["numpy", "1.11.2"],
                ["sacred", "0.6.10"],
            ],
            "name": "German nouns",
        },
        "result": 2403.52,
        "artifacts": [],
        "comment": "",
        # N.B. time formatting is different between mongodb and file store.
        "start_time": "2017-06-02T07:13:05.305845",
        "stop_time": "2017-06-02T07:14:02.455460",
        "heartbeat": "2017-06-02T07:14:02.452597",
        "captured_out": "Output: \n",
    }

    config = {
        "length": None,
        "n_input": 255,
        "batch_size": None,
        "dataset_path": "./german-nouns.hdf5",
        "validation_ds": "validation",
        "log_dir": "./log/rnn500_dropout0.5_lrate1e-4_minibatch_1000steps",
        "seed": 144363069,
        "dropout_keep_probability": 0.5,
        "max_character_ord": 255,
        "training_ds": "training",
        "num_classes": 3,
        "training_steps": 1000,
        "learning_rate": 0.0001,
        "hidden_size": 500,
    }

    store_root = tempfile.mkdtemp()
    run_dir = os.path.join(store_root, "42")  # experiment number 42
    os.mkdir(run_dir)

    # Same write order as before: the config file first, then the run file.
    for file_name, payload in (("config.json", config), ("run.json", run)):
        with open(os.path.join(run_dir, file_name), 'w') as json_file:
            json.dump(payload, json_file)

    return store_root

@pytest.fixture
def tmpfilestore() -> FileStorage:
    """Provide a FileStorage rooted at a freshly created temporary store."""
    return FileStorage(create_tmp_datastore())

def test_get_run(tmpfilestore: FileStorage):
    """The run fetched by id 42 contains every expected top-level key."""
    expected_keys = (
        "info", "resources", "host", "experiment", "result", "artifacts",
        "comment", "start_time", "stop_time", "heartbeat", "captured_out",
        "config",
    )

    loaded_run = tmpfilestore.get_run(42)

    for expected in expected_keys:
        assert expected in loaded_run

def test_get_runs(tmpfilestore: FileStorage):
    """Listing the store yields exactly one run with every expected key."""
    expected_keys = (
        "info", "resources", "host", "experiment", "result", "artifacts",
        "comment", "start_time", "stop_time", "heartbeat", "captured_out",
        "config",
    )

    all_runs = list(tmpfilestore.get_runs())

    assert 1 == len(all_runs)

    only_run = all_runs[0]
    for expected in expected_keys:
        assert expected in only_run

0 comments on commit 5ecdcf4

Please sign in to comment.