Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ class DataJointError(Exception):


# ----------- loads local configuration from file ----------------
from .settings import Config, logger
from .settings import Config, logger, CONFIGVAR, LOCALCONFIG
config = Config()
local_config_file = os.environ.get(config['config.varname'], None)
local_config_file = os.environ.get(CONFIGVAR, None)
if local_config_file is None:
local_config_file = os.path.expanduser(config['config.file'])
local_config_file = os.path.expanduser(LOCALCONFIG)
else:
local_config_file = os.path.expanduser(local_config_file)
config['config.file'] = local_config_file

try:
logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file))
config.load(local_config_file)
Expand Down
3 changes: 3 additions & 0 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def conn_function(host=None, user=None, passwd=None, init_fun=None, reset=False)
host = host if host is not None else config['database.host']
user = user if user is not None else config['database.user']
passwd = passwd if passwd is not None else config['database.password']

if passwd is None: passwd = input("Please enter database password: ")

init_fun = init_fun if init_fun is not None else config['connection.init_function']
_connObj = Connection(host, user, passwd, init_fun)
return _connObj
Expand Down
12 changes: 6 additions & 6 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from . import DataJointError
import json
import pprint
from collections import OrderedDict

__author__ = 'eywalker'
import logging
import collections
from enum import Enum

LOCALCONFIG = 'dj_local_conf.json'
CONFIGVAR = 'DJ_LOCAL_CONF'

validators = collections.defaultdict(lambda: lambda value: True)

Expand All @@ -23,18 +26,15 @@
}
prefix_to_role = dict(zip(role_to_prefix.values(), role_to_prefix.keys()))


default = {
default = OrderedDict({
'database.host': 'localhost',
'database.password': 'datajoint',
'database.user': 'datajoint',
'database.port': 3306,
#
'connection.init_function': None,
#
'config.file': 'dj_local_conf.json',
'config.varname': 'DJ_LOCAL_CONF'
}
})


class Config(collections.MutableMapping):
Expand Down Expand Up @@ -102,4 +102,4 @@ def load(self, filename):

#############################################################################
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.',
url='https://github.com/datajoint/datajoint-python',
packages=['datajoint'],
requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json']
# todo: add license (check license of used packages)
requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json'],
license = "MIT",
)
5 changes: 1 addition & 4 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@


def test_load_save():
old = dj.config['config.file']
dj.config['config.file'] = 'tmp.json'
dj.config.save()
dj.config.save('tmp.json')
conf = dj.Config()
conf.load('tmp.json')
assert_true(conf == dj.config, 'Two config files do not match.')
dj.config['config.file'] = old
os.remove('tmp.json')

@raises(ValueError)
Expand Down