Skip to content

Commit

Permalink
Merge pull request #317 from dimitri-yatsenko/master
Browse files Browse the repository at this point in the history
Fix #288 (add connection id to jobs table) and modify `key_source` handling
  • Loading branch information
eywalker committed Jun 9, 2017
2 parents 9cb16a8 + 79d77f9 commit 17acda7
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ MANIFEST
.vagrant/
dj_local_conf.json
build/
.coverage
./tests/.coverage
2 changes: 1 addition & 1 deletion datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .version import __version__

__author__ = "Dimitri Yatsenko, Edgar Y. Walker, and Fabian Sinz at Baylor College of Medicine"
__date__ = "March 8, 2017"
__date__ = "June 1, 2017"
__all__ = ['__author__', '__version__',
'config', 'conn', 'kill', 'BaseRelation',
'Connection', 'Heading', 'FreeRelation', 'Not', 'schema',
Expand Down
7 changes: 3 additions & 4 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,20 @@ def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, ord
todo = self.key_source
if not isinstance(todo, RelationalOperand):
raise DataJointError('Invalid key_source value')
todo = todo & AndList(restrictions)
todo = todo.proj() & AndList(restrictions)

error_list = [] if suppress_errors else None

jobs = self.connection.jobs[self.target.database] if reserve_jobs else None


# define and setup signal handler for SIGTERM
if reserve_jobs:
def handler(signum, frame):
logger.info('Populate terminated by SIGTERM')
raise SystemExit('SIGTERM received')
old_handler = signal.signal(signal.SIGTERM, handler)

todo -= self.target.proj()
todo -= self.target
keys = list(todo.fetch.keys())
if order == "reverse":
keys.reverse()
Expand Down Expand Up @@ -142,7 +141,7 @@ def progress(self, *restrictions, display=True):
"""
todo = self.key_source & AndList(restrictions)
total = len(todo)
remaining = len(todo - self.target.proj())
remaining = len(todo.proj() - self.target)
if display:
print('%-20s' % self.__class__.__name__,
'Completed %d of %d (%2.1f%%) %s' % (
Expand Down
4 changes: 4 additions & 0 deletions datajoint/base_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def insert(self, rows, replace=False, ignore_errors=False, skip_duplicates=False
return

heading = self.heading
if heading.attributes is None:
logger.warning('Could not access table {table}'.format(table=self.full_table_name))
return

field_list = None # ensures that all rows have the same attributes in the same order as the first row.

def make_row_to_insert(row):
Expand Down
5 changes: 5 additions & 0 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, host, user, password, init_fun=None):
self.connect()
if self.is_connected:
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
self.connection_id = self.query('SELECT connection_id()').fetchone()[0]
else:
raise DataJointError('Connection failed.')
self._conn.autocommit(True)
Expand Down Expand Up @@ -129,6 +130,10 @@ def query(self, query, args=(), as_dict=False):
cur.execute(query, args)
else:
raise
except err.ProgrammingError as e:
print('Error in query:')
print(query)
raise
return cur

def get_user(self):
Expand Down
3 changes: 1 addition & 2 deletions datajoint/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import DataJointError
from . import key as PRIMARY_KEY


def update_dict(d1, d2):
    """
    Return a new dict with the keys of d1, taking values from d2 where available.

    :param d1: dict that defines the set of keys (and the fallback values)
    :param d2: dict whose values override those of d1 for keys present in both
    :return: new dict; keys that appear only in d2 are ignored, inputs are not modified
    """
    return {k: (d2[k] if k in d2 else d1[k]) for k in d1}

Expand All @@ -29,7 +30,6 @@ def copy(self):
"""
return self.__class__(self)


def _initialize_behavior(self):
self.sql_behavior = {}
self.ext_behavior = dict(squeeze=False)
Expand Down Expand Up @@ -90,7 +90,6 @@ def order_by(self, *args):
self.sql_behavior['order_by'] = args
return self


@property
def as_dict(self):
"""
Expand Down
10 changes: 10 additions & 0 deletions datajoint/hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import hashlib
import base64


def filehash(filename):
    """
    Compute a compact, URL-safe SHA-256 digest of a file's contents.

    :param filename: path of the file to hash; opened in binary mode
    :return: 43-character string: the base64 encoding of the SHA-256 digest using
             '-' and '_' in place of '+' and '/', with the trailing '=' padding removed
    """
    block_size = 65536  # read in chunks so the whole file is never held in memory
    s = hashlib.sha256()
    with open(filename, 'rb') as f:
        # iter() with a sentinel keeps calling f.read until it returns b'' (end of file)
        for block in iter(lambda: f.read(block_size), b''):
            s.update(block)
    # a 32-byte digest encodes to 44 base64 characters, the last being '=' padding
    return base64.b64encode(s.digest(), b'-_')[0:43].decode()
1 change: 1 addition & 0 deletions datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def init_from_database(self, conn, database, table_name):
if info is None:
if table_name == '~log':
logger.warning('Could not create the ~log table')
return
else:
raise DataJointError('The table `{database}`.`{table_name}` is not defined.'.format(
table_name=table_name, database=database))
Expand Down
5 changes: 4 additions & 1 deletion datajoint/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, arg, database=None):
user="" :varchar(255) # database user
host="" :varchar(255) # system hostname
pid=0 :int unsigned # system process id
connection_id = 0 : bigint unsigned # connection_id()
timestamp=CURRENT_TIMESTAMP :timestamp # automatic timestamp
""".format(database=database, error_message_length=ERROR_MESSAGE_LENGTH)
if not self.is_declared:
Expand Down Expand Up @@ -80,10 +81,11 @@ def reserve(self, table_name, key):
status='reserved',
host=os.uname().nodename,
pid=os.getpid(),
connection_id=self.connection.connection_id,
key=key,
user=self._user)
try:
self.insert1(job)
self.insert1(job, ignore_extra_fields=True)
except pymysql.err.IntegrityError:
return False
return True
Expand Down Expand Up @@ -113,6 +115,7 @@ def error(self, table_name, key, error_message):
status="error",
host=os.uname().nodename,
pid=os.getpid(),
connection_id=self.connection.connection_id,
user=self._user,
key=key,
error_message=error_message), replace=True, ignore_extra_fields=True)
Expand Down
2 changes: 1 addition & 1 deletion datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.6.0"
__version__ = "0.6.1"
7 changes: 5 additions & 2 deletions tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def test_reserve_job():
'failed to reserve new jobs')
# finish with error
for key in subjects.fetch.keys():
schema.schema.jobs.error(table_name, key, "error message")
schema.schema.jobs.error(table_name, key,
"error message")
# refuse jobs with errors
for key in subjects.fetch.keys():
assert_false(schema.schema.jobs.reserve(table_name, key),
Expand All @@ -43,6 +44,7 @@ def test_reserve_job():
assert_false(schema.schema.jobs,
'failed to clear error jobs')


def test_restrictions():
# clear out jobs table
jobs = schema.schema.jobs
Expand Down Expand Up @@ -73,6 +75,7 @@ def test_sigint():
assert_equals(error_message, 'KeyboardInterrupt')
schema.schema.jobs.delete()


def test_sigterm():
# clear out job table
schema.schema.jobs.delete()
Expand Down Expand Up @@ -113,4 +116,4 @@ def test_long_error_message():
error_message = schema.schema.jobs.fetch1['error_message']
assert_true(error_message == short_error_message, 'error messages do not agree')
assert_false(error_message.endswith(TRUNCATION_APPENDIX), 'error message should not be truncated')
schema.schema.jobs.delete()
schema.schema.jobs.delete()

0 comments on commit 17acda7

Please sign in to comment.