-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'feature/plot-metrics' into develop
- Loading branch information
Showing
40 changed files
with
1,906 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,7 @@ | ||
"""Sacred(board) Data Access Layer.""" | ||
from .datastorage import Cursor, DataStorage | ||
from sacredboard.app.data.errors import NotFoundError, DataSourceError | ||
from .metricsdao import MetricsDAO | ||
|
||
__all__ = ["Cursor", "DataStorage", "MetricsDAO", "NotFoundError", | ||
"DataSourceError"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Errors that might occur during data access.""" | ||
|
||
|
||
class NotFoundError(Exception): | ||
"""Record not found exception.""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
"""Record not found exception.""" | ||
Exception.__init__(self, *args, **kwargs) | ||
|
||
|
||
class DataSourceError(Exception): | ||
"""Error when accessing the data source.""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
"""Error when accessing the data source.""" | ||
Exception.__init__(self, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
""" | ||
Interface for accessing Sacred metrics. | ||
Issue: https://github.com/chovanecm/sacredboard/issues/60 | ||
""" | ||
|
||
|
||
class MetricsDAO: | ||
""" | ||
Interface for accessing Sacred metrics. | ||
Issue: https://github.com/chovanecm/sacredboard/issues/58 | ||
""" | ||
|
||
def get_metric(self, run_id, metric_id): | ||
""" | ||
Read a metric of the given id and run. | ||
The returned object has the following format (timestamps are datetime | ||
objects). | ||
.. code:: | ||
{"steps": [0,1,20,40,...], | ||
"timestamps": [timestamp1,timestamp2,timestamp3,...], | ||
"values": [0,1 2,3,4,5,6,...], | ||
"name": "name of the metric", | ||
"metric_id": "metric_id", | ||
"run_id": "run_id"} | ||
:param run_id: ID of the Run that the metric belongs to. | ||
:param metric_id: The ID fo the metric. | ||
:return: The whole metric as specified. | ||
:raise NotFoundError | ||
""" | ||
raise NotImplementedError("The MetricsDAO method is abstract.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Module responsible for accessing the MongoDB database.""" | ||
from .metricsdao import MongoMetricsDAO | ||
from .genericdao import GenericDAO | ||
|
||
__all__ = ["MongoMetricsDAO", "GenericDAO"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
""" | ||
Generic DAO object for safe access to the MongoDB. | ||
Issue: https://github.com/chovanecm/sacredboard/issues/61 | ||
""" | ||
import pymongo | ||
from pymongo.errors import InvalidName | ||
|
||
from sacredboard.app.data import DataSourceError | ||
from .mongocursor import MongoDbCursor | ||
|
||
|
||
class GenericDAO: | ||
""" | ||
Generic DAO object for safe access to the MongoDB. | ||
Issue: https://github.com/chovanecm/sacredboard/issues/61 | ||
""" | ||
|
||
def __init__(self, pymongo_client, database_name): | ||
""" | ||
Create a new GenericDAO object that will work on the given database. | ||
:param pymongo_client: PyMongo client that is connected to MongoDB. | ||
:param database_name: Name of the database this GenericDAO works with. | ||
:raise DataSourceError | ||
""" | ||
self._client = pymongo_client | ||
self._database = self._get_database(database_name) | ||
|
||
def find_record(self, collection_name, query): | ||
""" | ||
Return the first record mathing the given Mongo query. | ||
:param collection_name: Name of the collection to search in. | ||
:param query: MongoDB Query, e.g. {_id: 123} | ||
:return: A single MongoDB record or None if not found. | ||
:raise DataSourceError | ||
""" | ||
cursor = self._get_collection(collection_name).find(query) | ||
for record in cursor: | ||
# Return the first record found. | ||
return record | ||
# Return None if nothing found. | ||
return None | ||
|
||
def find_records(self, collection_name, query={}, sort_by=None, | ||
sort_direction=None, start=0, limit=None): | ||
""" | ||
Return a cursor of records from the given MongoDB collection. | ||
:param collection_name: Name of the MongoDB collection to query. | ||
:param query: Standard MongoDB query. By default no restriction. | ||
:param sort_by: Name of a single field to sort by. | ||
:param sort_direction: The direction to sort, "asc" or "desc". | ||
:param start: Skip first n results. | ||
:param limit: The maximum number of results to return. | ||
:return: Cursor -- An iterable with results. | ||
:raise DataSourceError | ||
""" | ||
cursor = self._get_collection(collection_name).find(query) | ||
if sort_by is not None: | ||
cursor = self._apply_sort(cursor, sort_by, sort_direction) | ||
cursor = cursor.skip(start) | ||
if limit is not None: | ||
cursor = cursor.limit(limit) | ||
return MongoDbCursor(cursor) | ||
|
||
def _get_database(self, database_name): | ||
""" | ||
Get PyMongo client pointing to the current database. | ||
:return: MongoDB client of the current database. | ||
:raise DataSourceError | ||
""" | ||
try: | ||
return self._client[database_name] | ||
except InvalidName as ex: | ||
raise DataSourceError("Cannot connect to database %s!" | ||
% self._database) from ex | ||
|
||
def _get_collection(self, collection_name): | ||
""" | ||
Get PyMongo client pointing to the current DB and the given collection. | ||
:return: MongoDB client of the current database and given collection. | ||
:raise DataSourceError | ||
""" | ||
try: | ||
return self._database[collection_name] | ||
except InvalidName as ex: | ||
raise DataSourceError("Cannot access MongoDB collection %s!" | ||
% collection_name) from ex | ||
except Exception as ex: | ||
raise DataSourceError("Unexpected error when accessing MongoDB" | ||
"collection %s!" | ||
% collection_name) from ex | ||
|
||
def _apply_sort(self, cursor, sort_by, sort_direction): | ||
""" | ||
Apply sort to a cursor. | ||
:param cursor: The cursor to apply sort on. | ||
:param sort_by: The field name to sort by. | ||
:param sort_direction: The direction to sort, "asc" or "desc". | ||
:return: | ||
""" | ||
if sort_direction is not None and sort_direction.lower() == "desc": | ||
sort = pymongo.DESCENDING | ||
else: | ||
sort = pymongo.ASCENDING | ||
return cursor.sort(sort_by, sort) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
Module responsible for accessing the Metrics data in MongoDB. | ||
Issue: https://github.com/chovanecm/sacredboard/issues/60 | ||
""" | ||
|
||
from bson import ObjectId | ||
from bson.errors import InvalidId | ||
|
||
from sacredboard.app.data import NotFoundError | ||
from .genericdao import GenericDAO | ||
from ..metricsdao import MetricsDAO | ||
|
||
|
||
class MongoMetricsDAO(MetricsDAO): | ||
"""Implementation of MetricsDAO for MongoDB.""" | ||
|
||
def __init__(self, generic_dao: GenericDAO): | ||
""" | ||
Create new metrics accessor for MongoDB. | ||
:param generic_dao: A configured generic MongoDB data access object | ||
pointing to an appropriate database. | ||
""" | ||
self.generic_dao = generic_dao | ||
self.metrics_collection_name = "metrics" | ||
"""Name of the MongoDB collection with metrics.""" | ||
|
||
def get_metric(self, run_id, metric_id): | ||
""" | ||
Read a metric of the given id and run. | ||
The returned object has the following format (timestamps are datetime | ||
objects). | ||
.. code:: | ||
{"steps": [0,1,20,40,...], | ||
"timestamps": [timestamp1,timestamp2,timestamp3,...], | ||
"values": [0,1 2,3,4,5,6,...], | ||
"name": "name of the metric", | ||
"metric_id": "metric_id", | ||
"run_id": "run_id"} | ||
:param run_id: ID of the Run that the metric belongs to. | ||
:param metric_id: The ID fo the metric. | ||
:return: The whole metric as specified. | ||
:raise NotFoundError | ||
""" | ||
query = self._build_query(run_id, metric_id) | ||
row = self._read_metric_from_db(metric_id, run_id, query) | ||
metric = self._to_intermediary_object(row) | ||
return metric | ||
|
||
def _read_metric_from_db(self, metric_id, run_id, query): | ||
row = self.generic_dao.find_record(self.metrics_collection_name, query) | ||
if row is None: | ||
raise NotFoundError("Metric %s for run %s not found." | ||
% (metric_id, run_id)) | ||
return row | ||
|
||
def _build_query(self, run_id, metric_id): | ||
# Metrics in MongoDB is always an ObjectId | ||
try: | ||
id = ObjectId(metric_id) | ||
return {"run_id": run_id, "_id": id} | ||
except InvalidId as ex: | ||
raise NotFoundError("Metric Id %s is invalid " | ||
"ObjectId in MongoDB" % metric_id) from ex | ||
|
||
def _to_intermediary_object(self, row): | ||
return { | ||
"metric_id": str(row["_id"]), | ||
"run_id": row["run_id"], | ||
"name": row["name"], | ||
"steps": row["steps"], | ||
"timestamps": row["timestamps"], | ||
"values": row["values"], | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Implementation of cursor for iterating over results. | ||
Backed by pymongo cursor. | ||
""" | ||
from sacredboard.app.data.datastorage import Cursor | ||
|
||
|
||
class MongoDbCursor(Cursor): | ||
"""Implements Cursor for mongodb.""" | ||
|
||
def __init__(self, mongodb_cursor): | ||
"""Initialize a MongoDB cursor.""" | ||
self.mongodb_cursor = mongodb_cursor | ||
|
||
def count(self): | ||
"""Return the number of items in this cursor.""" | ||
return self.mongodb_cursor.count() | ||
|
||
def __iter__(self): | ||
"""Iterate over runs.""" | ||
return self.mongodb_cursor |
Oops, something went wrong.