Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cascading Delete Implementation #130

Merged
merged 5 commits into from
Jul 17, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions datajoint/relation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Mapping
from collections import defaultdict
import numpy as np
import logging
import abc
Expand Down Expand Up @@ -82,10 +83,16 @@ def children(self):

@property
def references(self):
"""
:return: list of tables that this tables refers to
"""
return self.connection.erm.references[self.full_table_name]

@property
def referenced(self):
"""
:return: list of tables for which this table is referenced by
"""
return self.connection.erm.referenced[self.full_table_name]

@property
Expand Down Expand Up @@ -187,22 +194,35 @@ def delete(self):
User is prompted for confirmation if config['safemode']
"""
relations = self.descendants
if self.restrictions and len(relations)>1:
raise NotImplementedError('Restricted cascading deletes are not yet implemented')
do_delete = True
#if self.restrictions and len(relations)>1:
# raise NotImplementedError('Restricted cascading deletes are not yet implemented')
restrict_by_me = defaultdict(lambda: False)
rel_by_name = {r.full_table_name:r for r in relations}
for r in relations:
for ref in r.references:
restrict_by_me[ref] = True

if self.restrictions is not None:
restrict_by_me[self.full_table_name] = True
rel_by_name[self.full_table_name]._restrict(self.restrictions)

for r in relations:
for dep in (r.children + r.references):
rel_by_name[dep]._restrict(r.project() if restrict_by_me[r.full_table_name] else r.restrictions)

if config['safemode']:
do_delete = False
do_delete = False # indicate if there is anything to delete
print('The contents of the following tables are about to be deleted:')
for relation in relations:
count = len(relation)
if count:
do_delete = True
print(relation.full_table_name, '(%d tuples)' % count)
do_delete = do_delete and user_choice("Proceed?", default='no') == 'yes'
if do_delete:
with self.connection.transaction:
while relations:
relations.pop().delete_quick()
if not do_delete or user_choice("Proceed?", default='no') != 'yes':
return
with self.connection.transaction:
while relations:
relations.pop().delete_quick()

def drop_quick(self):
"""
Expand Down
15 changes: 10 additions & 5 deletions datajoint/relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def _restrict(self, restriction):
if restriction is not None:
if self._restrictions is None:
self._restrictions = []
self._restrictions.append(restriction)
if isinstance(restriction, list):
self._restrictions.extend(restriction)
else:
self._restrictions.append(restriction)
return self

def __iand__(self, restriction):
Expand Down Expand Up @@ -414,12 +417,14 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes):
self._arg = Subquery(arg)
else:
self._group = None
if arg.heading.computed:
self._arg = Subquery(arg)
else:
# project without subquery
if arg.heading.computed or\
(isinstance(arg.restrictions, RelationalOperand) and \
all(attr in self._attributes for attr in arg.restrictions.heading.names)) :
# can simply the expression because all restrictions attrs are projected out anyway!
self._arg = arg
self._restrictions = self._arg.restrictions
else:
self._arg = Subquery(arg)

@property
def connection(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _make_tuples(self, key):
"""
from datetime import date, timedelta
users = User().fetch()['username']
random.seed('Amazing Seed')
for experiment_id in range(self.fake_experiments_per_subject):
self.insert1(
dict(key,
Expand All @@ -82,6 +83,7 @@ def _make_tuples(self, key):
"""
populate with random data (pretend reading from raw files)
"""
random.seed('Amazing Seed')
for trial_id in range(10):
self.insert1(
dict(key,
Expand All @@ -103,6 +105,7 @@ def _make_tuples(self, key):
"""
populate with random data
"""
random.seed('Amazing seed')
row = dict(key,
sampling_frequency=6000,
duration=np.minimum(2, random.expovariate(1)))
Expand All @@ -124,6 +127,7 @@ def fill(self, key, number_samples):
"""
populate random trace of specified length
"""
random.seed('Amazing seed')
for channel in range(2):
self.insert1(
dict(key,
Expand Down