diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 9e0142e2b..95b6a3dd1 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -16,14 +16,14 @@ class DataJointError(Exception): # ----------- loads local configuration from file ---------------- -from .settings import Config, logger +from .settings import Config, logger, CONFIGVAR, LOCALCONFIG config = Config() -local_config_file = os.environ.get(config['config.varname'], None) +local_config_file = os.environ.get(CONFIGVAR, None) if local_config_file is None: - local_config_file = os.path.expanduser(config['config.file']) + local_config_file = os.path.expanduser(LOCALCONFIG) else: local_config_file = os.path.expanduser(local_config_file) - config['config.file'] = local_config_file + try: logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file)) config.load(local_config_file) diff --git a/datajoint/connection.py b/datajoint/connection.py index ca6551cea..4655866f7 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -35,6 +35,9 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False) host = host if host is not None else config['database.host'] user = user if user is not None else config['database.user'] passwd = passwd if passwd is not None else config['database.password'] + + if passwd is None: passwd = input("Please enter database password: ") + init_fun = init_fun if init_fun is not None else config['connection.init_function'] _connObj = Connection(host, user, passwd, init_fun) return _connObj diff --git a/datajoint/settings.py b/datajoint/settings.py index 22afcadb6..d6ad20b72 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -4,12 +4,15 @@ from . import DataJointError import json import pprint +from collections import OrderedDict __author__ = 'eywalker' import logging import collections from enum import Enum +LOCALCONFIG = 'dj_local_conf.json' +CONFIGVAR = 'DJ_LOCAL_CONF' validators = collections.defaultdict(lambda: lambda value: True) @@ -23,8 +26,7 @@ } prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys())) - -default = { +default = OrderedDict({ 'database.host': 'localhost', 'database.password': 'datajoint', 'database.user': 'datajoint', @@ -32,9 +34,7 @@ # 'connection.init_function': None, # - 'config.file': 'dj_local_conf.json', - 'config.varname': 'DJ_LOCAL_CONF' -} +}) class Config(collections.MutableMapping): @@ -102,4 +102,4 @@ def load(self, filename): ############################################################################# logger = logging.getLogger() -logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable \ No newline at end of file +logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable diff --git a/setup.py b/setup.py index fa72134a9..5c56b483c 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,6 @@ description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.', url='https://github.com/datajoint/datajoint-python', packages=['datajoint'], - requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'] - # todo: add license (check license of used packages) + requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'], + license = "MIT", ) diff --git a/tests/test_settings.py b/tests/test_settings.py index 9dadcb43a..aff05b9c9 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,13 +7,10 @@ def test_load_save(): - old = dj.config['config.file'] - dj.config['config.file'] = 'tmp.json' - dj.config.save() + dj.config.save('tmp.json') conf = dj.Config() conf.load('tmp.json') assert_true(conf == dj.config, 'Two config files do not match.') - dj.config['config.file'] = old os.remove('tmp.json') @raises(ValueError)