Commit

Merge pull request #377 from dimitri-yatsenko/master

Add external storage feature and multiple bug fixes

eywalker committed Nov 15, 2017
2 parents 8ea75a6 + 7627569 commit 76e3d56
Showing 27 changed files with 721 additions and 249 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,3 +16,5 @@ dj_local_conf.json
build/
.coverage
./tests/.coverage
./tests/dj-store/*
*.log
63 changes: 44 additions & 19 deletions datajoint/autopopulate.py
@@ -1,9 +1,12 @@
"""autopopulate containing the dj.AutoPopulate class. See `dj.AutoPopulate` for more info."""
import logging
import datetime
import traceback
import random
from tqdm import tqdm
from pymysql import OperationalError
from .relational_operand import RelationalOperand, AndList
from .relational_operand import RelationalOperand, AndList, U
from . import DataJointError
from . import key as KEY
from .base_relation import FreeRelation
@@ -18,15 +21,15 @@ class AutoPopulate:
"""
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
Auto-populated relations must inherit from both Relation and AutoPopulate,
must define the property `key_source`, and must define the callback method _make_tuples.
must define the property `key_source`, and must define the callback method `make`.
"""
_key_source = None

@property
def key_source(self):
"""
:return: the relation whose primary key values are passed, sequentially, to the
`_make_tuples` method when populate() is called. The default value is the
``make`` method when populate() is called. The default value is the
join of the parent relations. Users may override to change the granularity
or the scope of populate() calls.
"""
@@ -40,13 +43,15 @@ def key_source(self):
self._key_source *= FreeRelation(self.connection, parents.pop(0)).proj()
return self._key_source

def _make_tuples(self, key):

def make(self, key):
"""
Derived classes must implement method _make_tuples that fetches data from tables that are
Derived classes must implement method `make` that fetches data from tables that are
above them in the dependency hierarchy, restricting by the given key, computes dependent
attributes, and inserts the new tuples into self.
"""
raise NotImplementedError('Subclasses of AutoPopulate must implement the method "_make_tuples"')
raise NotImplementedError('Subclasses of AutoPopulate must implement the method `make`')


@property
def target(self):
@@ -63,16 +68,19 @@ def _job_key(self, key):
"""
return key

def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, order="original", limit=None):
def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False,
order="original", limit=None, max_calls=None, display_progress=False):
"""
rel.populate() calls rel._make_tuples(key) for every primary key in self.key_source
rel.populate() calls rel.make(key) for every primary key in self.key_source
for which there is not already a tuple in rel.
:param restrictions: a list of restrictions, each restricting (rel.key_source - target.proj())
:param suppress_errors: if True, errors are suppressed and collected into the returned list
:param reserve_jobs: if True, reserves jobs so that multiple processes can populate in parallel
:param order: "original"|"reverse"|"random" - the order of execution
:param limit: if not None, populates at most that many keys
:param display_progress: if True, display a progress bar
:param limit: if not None, checks at most that many keys
:param max_calls: if not None, populates at most that many keys
"""
if self.connection.in_transaction:
raise DataJointError('Populate cannot be called during a transaction.')
@@ -84,11 +92,20 @@ def populate(self, *restrictions, suppress_errors=False, reserve_jobs=False, ord
todo = self.key_source
if not isinstance(todo, RelationalOperand):
raise DataJointError('Invalid key_source value')
todo = todo.proj() & AndList(restrictions)
todo = (todo & AndList(restrictions)).proj()

error_list = [] if suppress_errors else None
# raise error if the populated target lacks any attributes from the primary key of key_source
try:
raise DataJointError(
'The populate target lacks attribute %s from the primary key of key_source' % next(
name for name in todo.heading if name not in self.target.heading))
except StopIteration:
pass

todo -= self.target

jobs = self.connection.jobs[self.target.database] if reserve_jobs else None
error_list = [] if suppress_errors else None
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None

# define and setup signal handler for SIGTERM
if reserve_jobs:
@@ -97,15 +114,20 @@ def handler(signum, frame):
raise SystemExit('SIGTERM received')
old_handler = signal.signal(signal.SIGTERM, handler)

todo -= self.target
keys = todo.fetch(KEY, limit=limit)
if order == "reverse":
keys.reverse()
elif order == "random":
random.shuffle(keys)

call_count = 0  # number of make() calls issued so far
logger.info('Found %d keys to populate' % len(keys))
for key in keys:

make = self._make_tuples if hasattr(self, '_make_tuples') else self.make

for key in (tqdm(keys) if display_progress else keys):
if max_calls is not None and call_count >= max_calls:
break
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
self.connection.start_transaction()
if key in self.target: # already populated
Expand All @@ -114,8 +136,9 @@ def handler(signum, frame):
jobs.complete(self.target.table_name, self._job_key(key))
else:
logger.info('Populating: ' + str(key))
call_count += 1
try:
self._make_tuples(dict(key))
make(dict(key))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
Expand All @@ -124,7 +147,8 @@ def handler(signum, frame):
if reserve_jobs:
# show error name and error message (if any)
error_message = ': '.join([error.__class__.__name__, str(error)]).strip(': ')
jobs.error(self.target.table_name, self._job_key(key), error_message=error_message)
jobs.error(self.target.table_name, self._job_key(key),
error_message=error_message, error_stack=traceback.format_exc())

if not suppress_errors or isinstance(error, SystemExit):
raise
Expand All @@ -139,17 +163,18 @@ def handler(signum, frame):
# place back the original signal handler
if reserve_jobs:
signal.signal(signal.SIGTERM, old_handler)

return error_list

def progress(self, *restrictions, display=True):
"""
report progress of populating this table
:return: remaining, total -- the number of tuples remaining to populate and the total number
"""
todo = self.key_source & AndList(restrictions)
todo = (self.key_source & AndList(restrictions)).proj()
if any(name not in self.target.heading for name in todo.heading):
raise DataJointError('The populated target must have all the attributes of the key source')
total = len(todo)
remaining = len(todo.proj() - self.target)
remaining = len(todo - self.target)
if display:
print('%-20s' % self.__class__.__name__,
'Completed %d of %d (%2.1f%%) %s' % (
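
For orientation, here is a minimal usage sketch of the renamed `make` callback and the new max_calls and display_progress options added in this diff. The schema and table names are hypothetical, and a configured DataJoint connection is assumed.

import numpy as np
import datajoint as dj

schema = dj.schema('demo_populate', locals())  # hypothetical schema name

@schema
class Session(dj.Manual):
    definition = """
    session_id : int
    ---
    spike_rate : float
    """

@schema
class Stats(dj.Computed):
    definition = """
    -> Session
    ---
    doubled_rate : float
    """

    def make(self, key):  # formerly `_make_tuples`; the old name still works
        rate = (Session() & key).fetch1('spike_rate')
        self.insert1(dict(key, doubled_rate=2 * rate))

Session().insert([(1, 3.5), (2, 4.2)], skip_duplicates=True)

# cap the number of make() calls, show a tqdm progress bar,
# and collect errors into the returned list instead of raising
errors = Stats().populate(order='random', max_calls=1,
                          display_progress=True, suppress_errors=True)
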
133 changes: 60 additions & 73 deletions datajoint/base_relation.py
@@ -30,6 +30,7 @@ class BaseRelation(RelationalOperand):
_context = None
database = None
_log_ = None
_external_table = None

# -------------- required by RelationalOperand ----------------- #
@property
@@ -50,14 +51,21 @@ def context(self):

def declare(self):
"""
Loads the table heading. If the table is not declared, use self.definition to declare
Use self.definition to declare the table in the database
"""
try:
self.connection.query(
declare(self.full_table_name, self.definition, self._context))
sql, uses_external = declare(self.full_table_name, self.definition, self._context)
if uses_external:
# trigger the creation of the external hash lookup for the current schema
external_table = self.connection.schemas[self.database].external_table
sql = sql.format(external_table=external_table.full_table_name)
self.connection.query(sql)
except pymysql.OperationalError as error:
# skip if no create privilege
if error.args[0] == server_error_codes['command denied']:
logger.warning(error.args[1])
else:
raise
else:
self._log('Declared ' + self.full_table_name)

@@ -114,6 +122,12 @@ def _log(self):
self._log_ = Log(self.connection, database=self.database)
return self._log_

@property
def external_table(self):
if self._external_table is None:
self._external_table = self.connection.schemas[self.database].external_table
return self._external_table

def insert1(self, row, **kwargs):
"""
Insert one data record or one Mapping (like a dict).
@@ -142,60 +156,29 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
warnings.warn('Use of `ignore_errors` in `insert` and `insert1` is deprecated. Use try...except... '
'to explicitly handle any errors', stacklevel=2)

# handle query safely - if skip_duplicates=True, wraps the query with transaction and checks for warning
def safe_query(*args, **kwargs):
if skip_duplicates:
# check if there is already an open transaction
open_transaction = self.connection.in_transaction
if not open_transaction:
self.connection.start_transaction()
try:
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter('always')
self.connection.query(*args, suppress_warnings=False, **kwargs)
for w in ws:
if w.message.args[0] != server_error_codes['duplicate entry']:
raise InternalError(w.message.args)
except:
if not open_transaction:
try:
self.connection.cancel_transaction()
except OperationalError:
pass
raise
else:
if not open_transaction:
self.connection.commit_transaction()
else:
self.connection.query(*args, **kwargs)

heading = self.heading
if isinstance(rows, RelationalOperand):
# INSERT FROM SELECT - build alternate field-narrowing query (only) when needed
if ignore_extra_fields and not all(name in self.heading.names for name in rows.heading.names):
query = 'INSERT{ignore} INTO {table} ({fields}) SELECT {fields} FROM ({select}) as `__alias`'.format(
ignore=" IGNORE" if skip_duplicates else "",
table=self.full_table_name,
fields='`'+'`,`'.join(self.heading.names)+'`',
select=rows.make_sql())
else:
query = 'INSERT{ignore} INTO {table} ({fields}) {select}'.format(
ignore=" IGNORE" if skip_duplicates else "",
# insert from select
if not ignore_extra_fields:
try:
raise DataJointError(
"Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
next(name for name in rows.heading if name not in heading))
except StopIteration:
pass
fields = list(name for name in heading if name in rows.heading)

query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format(
command='REPLACE' if replace else 'INSERT',
fields='`' + '`,`'.join(fields) + '`',
table=self.full_table_name,
fields='`'+'`,`'.join(rows.heading.names)+'`',
select=rows.make_sql())
try:
safe_query(query)
except (InternalError, IntegrityError) as err:
if err.args[0] == server_error_codes['unknown column']:
# args[1] -> Unknown column 'extra' in 'field list'
raise DataJointError('{} : To ignore extra fields, set ignore_extra_fields=True in insert.'.format(err.args[1]))
elif err.args[0] == server_error_codes['duplicate entry']:
raise DataJointError('{} : To ignore duplicate entries, set skip_duplicates=True in insert.'.format(err.args[1]))
else:
raise
select=rows.make_sql(select_fields=fields),
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`'.format(pk=self.primary_key[0])
if skip_duplicates else '')
)
self.connection.query(query)
return

heading = self.heading
if heading.attributes is None:
logger.warning('Could not access table {table}'.format(table=self.full_table_name))
return
@@ -218,16 +201,18 @@ def make_placeholder(name, value):
"""
if ignore_extra_fields and name not in heading:
return None
if heading[name].is_blob:
value = pack(value)
placeholder = '%s'
if heading[name].is_external:
placeholder, value = '%s', self.external_table.put(heading[name].type, value)
elif heading[name].is_blob:
if value is None:
placeholder, value = 'NULL', None
else:
placeholder, value = '%s', pack(value)
elif heading[name].numeric:
if value is None or value == '' or np.isnan(np.float(value)): # nans are turned into NULLs
placeholder = 'NULL'
value = None
placeholder, value = 'NULL', None
else:
placeholder = '%s'
value = str(int(value) if isinstance(value, bool) else value)
placeholder, value = '%s', (str(int(value) if isinstance(value, bool) else value))
else:
placeholder = '%s'
return name, placeholder, value
@@ -284,13 +269,15 @@ def check_fields(fields):
rows = list(make_row_to_insert(row) for row in rows)
if rows:
try:
safe_query(
"{command} INTO {destination}(`{fields}`) VALUES {placeholders}".format(
command='REPLACE' if replace else 'INSERT IGNORE' if skip_duplicates else 'INSERT',
destination=self.from_clause,
fields='`,`'.join(field_list),
placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows)),
args=list(itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows)))
query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
command='REPLACE' if replace else 'INSERT',
destination=self.from_clause,
fields='`,`'.join(field_list),
placeholders=','.join('(' + ','.join(row['placeholders']) + ')' for row in rows),
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`'.format(pk=self.primary_key[0])
if skip_duplicates else ''))
self.connection.query(query, args=list(
itertools.chain.from_iterable((v for v in r['values'] if v is not None) for r in rows)))
except (OperationalError, InternalError, IntegrityError) as err:
if err.args[0] == server_error_codes['command denied']:
raise DataJointError('Command denied: %s' % err.args[1]) from None
Expand All @@ -304,7 +291,6 @@ def check_fields(fields):
else:
raise


def delete_quick(self):
"""
Deletes the table without cascading and without user prompt. If this table has any dependent
@@ -338,7 +324,7 @@ def delete(self):
# restrict by self
if self.restrictions:
restrict_by_me.add(self.full_table_name)
restrictions[self.full_table_name].append(self.restrictions.simplify()) # copy own restrictions
restrictions[self.full_table_name].append(self.restrictions) # copy own restrictions
# restrict by renamed nodes
restrict_by_me.update(table for table in delete_list if table.isdigit()) # restrict by all renamed nodes
# restrict by tables restricted by a non-primary semijoin
@@ -454,7 +440,7 @@ def describe(self):
if attr.name in fk_props['referencing_attributes']:
do_include = False
if attributes_thus_far.issuperset(fk_props['referencing_attributes']):
# simple foreign keys
# simple foreign key
parents.pop(parent_name)
if not parent_name.isdigit():
definition += '-> {class_name}\n'.format(
@@ -473,9 +459,10 @@
attributes_declared.update(fk_props['referencing_attributes'])
if do_include:
attributes_declared.add(attr.name)
name = attr.name.lstrip('_') # for external
definition += '%-20s : %-28s # %s\n' % (
attr.name if attr.default is None else '%s=%s' % (attr.name, attr.default),
'%s%s' % (attr.type, 'auto_increment' if attr.autoincrement else ''), attr.comment)
name if attr.default is None else '%s=%s' % (name, attr.default),
'%s%s' % (attr.type, ' auto_increment' if attr.autoincrement else ''), attr.comment)
print(definition)
return definition

@@ -547,7 +534,7 @@ def lookup_class_name(name, context, depth=3):
except AttributeError:
pass # not a UserRelation -- cannot have part tables.
else:
for part in (getattr(member, p) for p in parts):
for part in (getattr(member, p) for p in parts if hasattr(member, p)):
if inspect.isclass(part) and issubclass(part, BaseRelation) and part.full_table_name == name:
return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.')
elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint':
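
The headline feature of the commit is external blob storage: externally stored attributes are routed through external_table.put(), which hashes the value, writes it to the configured store, and keeps only the hash in the table row. The sketch below shows assumed usage; the dj.config['external'] keys and the `external` attribute type follow the conventions of this release and are assumptions, not confirmed by this diff.

import numpy as np
import datajoint as dj

# assumed file-based store configuration for the new external storage feature
dj.config['external'] = {'protocol': 'file', 'location': '/data/dj-store'}

schema = dj.schema('demo_external', locals())  # hypothetical schema name

@schema
class Frame(dj.Manual):
    definition = """
    frame_id : int
    ---
    image : external   # blob kept in the external store (assumed syntax)
    """

# insert1 detects the external attribute (heading[name].is_external) and
# calls self.external_table.put(...) before writing the row
Frame().insert1(dict(frame_id=1, image=np.random.rand(64, 64)))
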
