Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
__*__
.idea
datajoint.egg-info/
*.pyc
44 changes: 38 additions & 6 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,46 @@
import logging
import os

__author__ = "Dimitri Yatsenko, Edgar Walker, and Fabian Sinz at Baylor College of Medicine"
__version__ = "0.2"
__all__ = ['__author__', '__version__',
'Connection', 'Heading', 'Base', 'Not',
'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob']

# ------------ define datajoint error before the import hierarchy is flattened ------------
class DataJointError(Exception):
"""
Base class for errors specific to DataJoint internal
operation.
"""
pass



# ----------- loads local configuration from file ----------------
from .settings import Config, logger
config = Config()
local_config_file = os.environ.get(config['config.varname'], None)
if local_config_file is None:
local_config_file = os.path.expanduser(config['config.file'])
else:
local_config_file = os.path.expanduser(local_config_file)
config['config.file'] = local_config_file
try:
logger.log(logging.INFO, "Loading local settings from {0:s}".format(local_config_file))
config.load(local_config_file)
except:
logger.warn("Local config file {0:s} does not exist! Creating it.".format(local_config_file))
config.save(local_config_file)


# ------------- flatten import hierarchy -------------------------
from .connection import conn, Connection
from .core import DataJointError
from .base import Base
from .task import TaskQueue
from .autopopulate import AutoPopulate
from . import blob
from .relational import Not

__author__ = "Dimitri Yatsenko and Edgar Walker at Baylor College of Medicine"
__version__ = "0.2"

__all__ = ['__author__', '__version__',
'Connection', 'Heading', 'Base', 'Not',
'AutoPopulate', 'TaskQueue', 'conn', 'DataJointError', 'blob']

3 changes: 2 additions & 1 deletion datajoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from types import ModuleType
import numpy as np
from enum import Enum
from .core import DataJointError, from_camel_case
from .utils import from_camel_case
from . import DataJointError
from .relational import _Relational
from .heading import Heading
import logging
Expand Down
2 changes: 1 addition & 1 deletion datajoint/blob.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import zlib
import collections
import numpy as np
from .core import DataJointError
from . import DataJointError


mxClassID = collections.OrderedDict(
Expand Down
3 changes: 2 additions & 1 deletion datajoint/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pymysql
import re
from .core import DataJointError, to_camel_case
from .utils import to_camel_case
from . import DataJointError
import os
from .heading import Heading
from .base import prefix_to_role
Expand Down
3 changes: 2 additions & 1 deletion datajoint/erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import matplotlib.pyplot as plt
from matplotlib import transforms

from .core import DataJointError, to_camel_case
from .utils import to_camel_case
from . import DataJointError


logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
from collections import OrderedDict, namedtuple
import numpy as np
from .core import DataJointError
from datajoint import DataJointError


class Heading:
Expand Down
2 changes: 1 addition & 1 deletion datajoint/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import abc
from copy import copy
from .core import DataJointError
from datajoint import DataJointError
from .fetch import Fetch

class _Relational(metaclass=abc.ABCMeta):
Expand Down
82 changes: 79 additions & 3 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,88 @@
"""
Settings for DataJoint.
"""
from . import DataJointError
import json
import pprint

__author__ = 'eywalker'
import logging
import collections


validators = collections.defaultdict(lambda: lambda value: True)

# Settings dictionary. Don't manipulate this directly
default = {
'database.host': 'localhost',
'database.password': 'datajoint',
'database.user': 'datajoint',
#
'config.file': 'dj_local_conf.json',
'config.varname': 'DJ_LOCAL_CONF'
}

class Config(object):
class Config(collections.MutableMapping):
"""
Configuration object
Stores datajoint settings. Behaves like a dictionary, but applies validator functions
when certain keys are set.

The default parameters are stored in datajoint.settings.default . If a local config file
exists, the settings specified in this file override the default settings.

"""

def __init__(self, *args, **kwargs):
self._conf = dict(default)
self.update(dict(*args, **kwargs)) # use the free update to set keys

def __getitem__(self, key):
return self._conf[key]

def __setitem__(self, key, value):
if validators[key](value):
self._conf[key] = value
else:
raise DataJointError(u'Validator for {0:s} did not pass'.format(key, ))

def __delitem__(self, key):
del self._conf[key]

def __iter__(self):
return iter(self._conf)

def __len__(self):
return len(self._conf)

def __str__(self):
return pprint.pformat(self._conf, indent=4)

def __repr__(self):
return self.__str__()

def save(self, filename=None):
"""
Saves the settings in JSON format to the given file path.
:param filename: filename of the local JSON settings file. If None, the local config file is used.
"""
if filename is None:
import datajoint as dj
filename = dj.config['config.file']
with open(filename, 'w') as fid:
json.dump(self._conf, fid)

def load(self, filename):
"""
Updates the setting from config file in JSON format.

:param filename=None: filename of the local JSON settings file. If None, the local config file is used.
"""
if filename is None:
import datajoint as dj
filename = dj.config['config.file']
with open(filename, 'r') as fid:
self.update(json.load(fid))


#############################################################################
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
13 changes: 1 addition & 12 deletions datajoint/core.py → datajoint/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
import re
import logging
# package-wide settings that control execution

# setup root logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) #set package wide logger level TODO:make this respond to environmental variable
from . import DataJointError


class Settings:
pass
# verbose = True

class DataJointError(Exception):
"""
Base class for errors specific to DataJoint internal
operation.
"""
pass


def to_camel_case(s):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
description='An object-relational mapping and relational algebra to facilitate data definition and data manipulation in MySQL databases.',
url='https://github.com/datajoint/datajoint-python',
packages=['datajoint'],
requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock']
requires=['numpy', 'pymysql', 'networkx', 'matplotlib', 'sphinx_rtd_theme', 'mock', 'json']
# todo: add license (check license of used packages)
)
2 changes: 1 addition & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from . import BASE_CONN, CONN_INFO, PREFIX, cleanup
from datajoint.connection import Connection
from nose.tools import assert_raises, assert_equal, assert_regexp_matches, assert_false, assert_true
from datajoint.core import DataJointError
from datajoint import DataJointError


def setup():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup)
from nose.tools import assert_true, assert_raises, assert_equal
import datajoint as dj
from datajoint.core import DataJointError
from datajoint.utils import DataJointError


def setup():
Expand Down
3 changes: 2 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
__author__ = 'eywalker'
from . import (CONN_INFO, PREFIX, BASE_CONN, cleanup)
from nose.tools import assert_true, assert_raises, assert_equal
from datajoint.core import to_camel_case, from_camel_case, DataJointError
from datajoint.utils import to_camel_case, from_camel_case
from datajoint import DataJointError


def setup():
Expand Down
32 changes: 32 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

__author__ = 'Fabian Sinz'

from nose.tools import assert_true, assert_raises, assert_equal
import datajoint as dj

def nested_dict_compare(d1, d2):
for k, v in d1.items():
if k not in d2:
return False
else:
if isinstance(v, dict):
tmp = nested_dict_compare(v, d2[k])
if not tmp: return False
else:
if not v == d2[k]: return False
else:
return True

def test_load_save():
old = dj.config['config.file']
dj.config['config.file'] = 'tmp.json'
dj.config.save()
conf = dj.Config()
conf.load('tmp.json')
assert_true(nested_dict_compare(conf, dj.config), 'Two config files do not match.')
dj.config['config.file'] = old
os.remove('tmp.json')