diff --git a/.gitignore b/.gitignore index 2e1a2004f..ef4ea964a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __*__ .idea datajoint.egg-info/ +*.pyc diff --git a/datajoint/__init__.py b/datajoint/__init__.py index c95425020..c3c35b4cf 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -1,14 +1,46 @@ +import logging +import os + +__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine" +__version__ = "0.2" +__all__ = ['__author__', '__version__', + 'Connection', 'Heading', 'Base', 'Not', + 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob'] + +# ------------ define datajoint error before the import hierarchy is flattened ------------ +class DataJointError(Exception): + """ + Base class for errors specific to DataJoint internal + operation. + """ + pass + + + +# ----------- loads local configuration from file ---------------- +from .settings import Config, logger +config = Config() +local_config_file = os.environ.get(config['config.varname'], None) +if local_config_file is None: + local_config_file = os.path.expanduser(config['config.file']) +else: + local_config_file = os.path.expanduser(local_config_file) + config['config.file'] = local_config_file +try: + logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file)) + config.load(local_config_file) +except: + logger.warn("Local config file {0:s} does not exist! Creating it.".format(local_config_file)) + config.save(local_config_file) + + +# ------------- flatten import hierarchy ------------------------- from .connection import conn, Connection -from .core import DataJointError from .base import Base from .task import TaskQueue from .autopopulate import AutoPopulate from . import blob from .relational import Not -__author__ = "Dimitri Yatsenko and Edgar Walker at Baylor College of Medicine" -__version__ = "0.2" -__all__ = ['__author__', '__version__', - 'Connection', 'Heading', 'Base', 'Not', - 'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob'] + diff --git a/datajoint/base.py b/datajoint/base.py index 75dc7db2a..c30717b04 100644 --- a/datajoint/base.py +++ b/datajoint/base.py @@ -3,7 +3,8 @@ from types import ModuleType import numpy as np from enum import Enum -from .core import DataJointError, from_camel_case +from .utils import from_camel_case +from . import DataJointError from .relational import _Relational from .heading import Heading import logging diff --git a/datajoint/blob.py b/datajoint/blob.py index 69bd115b3..82ac9e338 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -1,7 +1,7 @@ import zlib import collections import numpy as np -from .core import DataJointError +from . import DataJointError mxClassID = collections.OrderedDict( diff --git a/datajoint/connection.py b/datajoint/connection.py index 6226077e8..55d300f63 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -1,6 +1,7 @@ import pymysql import re -from .core import DataJointError, to_camel_case +from .utils import to_camel_case +from . import DataJointError import os from .heading import Heading from .base import prefix_to_role diff --git a/datajoint/erd.py b/datajoint/erd.py index ab3d4f269..a7e1bc024 100644 --- a/datajoint/erd.py +++ b/datajoint/erd.py @@ -8,7 +8,8 @@ import matplotlib.pyplot as plt from matplotlib import transforms -from .core import DataJointError, to_camel_case +from .utils import to_camel_case +from . import DataJointError logger = logging.getLogger(__name__) diff --git a/datajoint/heading.py b/datajoint/heading.py index 0bdaeb11d..ca7495db0 100644 --- a/datajoint/heading.py +++ b/datajoint/heading.py @@ -8,7 +8,7 @@ import re from collections import OrderedDict, namedtuple import numpy as np -from .core import DataJointError +from datajoint import DataJointError class Heading: diff --git a/datajoint/relational.py b/datajoint/relational.py index 9a7e343ef..f02edbc30 100644 --- a/datajoint/relational.py +++ b/datajoint/relational.py @@ -7,7 +7,7 @@ import numpy as np import abc from copy import copy -from .core import DataJointError +from datajoint import DataJointError from .fetch import Fetch class _Relational(metaclass=abc.ABCMeta): diff --git a/datajoint/settings.py b/datajoint/settings.py index 0c26510f2..7977053a3 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -1,12 +1,88 @@ """ Settings for DataJoint. """ +from . import DataJointError +import json +import pprint + __author__ = 'eywalker' +import logging +import collections + + +validators = collections.defaultdict(lambda: lambda value: True) -# Settings dictionary. Don't manipulate this directly +default = { + 'database.host': 'localhost', + 'database.password': 'datajoint', + 'database.user': 'datajoint', + # + 'config.file': 'dj_local_conf.json', + 'config.varname': 'DJ_LOCAL_CONF' +} -class Config(object): +class Config(collections.MutableMapping): """ - Configuration object + Stores datajoint settings. Behaves like a dictionary, but applies validator functions + when certain keys are set. + + The default parameters are stored in datajoint.settings.default . If a local config file + exists, the settings specified in this file override the default settings. + """ + def __init__(self, *args, **kwargs): + self._conf = dict(default) + self.update(dict(*args, **kwargs)) # use the free update to set keys + + def __getitem__(self, key): + return self._conf[key] + + def __setitem__(self, key, value): + if validators[key](value): + self._conf[key] = value + else: + raise DataJointError(u'Validator for {0:s} did not pass'.format(key, )) + + def __delitem__(self, key): + del self._conf[key] + + def __iter__(self): + return iter(self._conf) + + def __len__(self): + return len(self._conf) + + def __str__(self): + return pprint.pformat(self._conf, indent=4) + + def __repr__(self): + return self.__str__() + + def save(self, filename=None): + """ + Saves the settings in JSON format to the given file path. + :param filename: filename of the local JSON settings file. If None, the local config file is used. + """ + if filename is None: + import datajoint as dj + filename = dj.config['config.file'] + with open(filename, 'w') as fid: + json.dump(self._conf, fid) + + def load(self, filename): + """ + Updates the setting from config file in JSON format. + + :param filename=None: filename of the local JSON settings file. If None, the local config file is used. + """ + if filename is None: + import datajoint as dj + filename = dj.config['config.file'] + with open(filename, 'r') as fid: + self.update(json.load(fid)) + + +############################################################################# +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable diff --git a/datajoint/core.py b/datajoint/utils.py similarity index 75% rename from datajoint/core.py rename to datajoint/utils.py index c92eda23b..9d1bf85de 100644 --- a/datajoint/core.py +++ b/datajoint/utils.py @@ -1,22 +1,11 @@ import re -import logging # package-wide settings that control execution # setup root logger -logger = logging.getLogger() -logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable +from . import DataJointError -class Settings: - pass - # verbose = True -class DataJointError(Exception): - """ - Base class for errors specific to DataJoint internal - operation. - """ - pass def to_camel_case(s): diff --git a/setup.py b/setup.py index 95adc3dfc..fa72134a9 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,6 @@ description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.', url='https://github.com/datajoint/datajoint-python', packages=['datajoint'], - requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock'] + requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'] # todo: add license (check license of used packages) ) diff --git a/tests/test_base.py b/tests/test_base.py index 4c9f2c98f..42e11bacf 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -10,7 +10,7 @@ from . import BASE_CONN, CONN_INFO, PREFIX, cleanup from datajoint.connection import Connection from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true -from datajoint.core import DataJointError +from datajoint import DataJointError def setup(): diff --git a/tests/test_connection.py b/tests/test_connection.py index 9eff46937..b54ad82a8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,7 +7,7 @@ from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) from nose.tools import assert_true, assert_raises, assert_equal import datajoint as dj -from datajoint.core import DataJointError +from datajoint.utils import DataJointError def setup(): diff --git a/tests/test_core.py b/tests/test_core.py index 3b259793a..bfbfb0dd2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,8 @@ __author__ = 'eywalker' from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup) from nose.tools import assert_true, assert_raises, assert_equal -from datajoint.core import to_camel_case, from_camel_case, DataJointError +from datajoint.utils import to_camel_case, from_camel_case +from datajoint import DataJointError def setup(): diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 000000000..988c19b3f --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,32 @@ +import os + +__author__ = 'Fabian Sinz' + +from nose.tools import assert_true, assert_raises, assert_equal +import datajoint as dj + +def nested_dict_compare(d1, d2): + for k, v in d1.items(): + if k not in d2: + return False + else: + if isinstance(v, dict): + tmp = nested_dict_compare(v, d2[k]) + if not tmp: return False + else: + if not v == d2[k]: return False + else: + return True + +def test_load_save(): + old = dj.config['config.file'] + dj.config['config.file'] = 'tmp.json' + dj.config.save() + conf = dj.Config() + conf.load('tmp.json') + assert_true(nested_dict_compare(conf, dj.config), 'Two config files do not match.') + dj.config['config.file'] = old + os.remove('tmp.json') + + +