Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: more loading refactoring #53

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 54 additions & 17 deletions invenio_migrator/cli.py
Expand Up @@ -47,7 +47,7 @@ def _loadrecord(record_dump, source_type, eager=False):
:param record_dump: Record dump.
:type record_dump: dict
:param source_type: 'json' or 'marcxml'
:param eager: If True execute the task synchronously.
:param eager: If ``True`` execute the task synchronously.
"""
if eager:
import_record.s(record_dump, source_type=source_type).apply(throw=True)
Expand All @@ -71,7 +71,7 @@ def _loadrecord(record_dump, source_type, eager=False):
def loadrecords(sources, source_type, recid):
"""Load records migration dump."""
# Load all record dumps up-front and find the specific JSON
if recid:
if recid is not None:
for source in sources:
records = json.load(source)
for item in records:
Expand Down Expand Up @@ -144,37 +144,62 @@ def inspectrecords(sources, recid, entity=None):
click.echo(revision)


def loadcommon(sources, load_task, asynchronous=True, task_args=None,
task_kwargs=None):
def loadcommon(sources, load_task, asynchronous=True, predicate=None,
task_args=None, task_kwargs=None):
"""Common helper function for load simple objects.

Note: Extra `args` and `kwargs` are passed to the `load_task` function.
.. note::

Keyword arguments ``task_args`` and ``task_kwargs`` are passed to the
``load_task`` function as ``*task_args`` and ``**task_kwargs``.

.. note::

The `predicate` argument is used as a predicate function to load only
a *single* item from across all dumps (this CLI function will return right
after loading the item). This is primarily used for debugging of
the *dirty* data within the dump. The `predicate` should be a function
with a signature ``f(dict) -> bool``, i.e. taking a single parameter
(an item from the dump) and return ``True`` if the item
should be loaded. See the ``loaddeposit`` for a concrete example.

:param sources: JSON source files with dumps
:type sources: list of str (filepaths)
:param load_task: Shared task which loads the dump.
:type load_task: function
:param asynchronous: Flag for serial or asynchronous execution of the task.
:type asynchronous: bool
:param task_args: positional arguments passed to the task (default: None).
:type task_args: tuple or None
:param task_kwargs: named arguments passed to the task (default: None).
:type task_kwargs: dict or None
:param predicate: Predicate for selecting only a single item from the dump.
:type predicate: function
:param task_args: positional arguments passed to the task.
:type task_args: tuple
:param task_kwargs: named arguments passed to the task.
:type task_kwargs: dict
"""
# resolve the defaults for task_args and task_kwargs
task_args = tuple() if task_args is None else task_args
task_kwargs = dict() if task_kwargs is None else task_kwargs
click.echo('Loading dumps started.')
for idx, source in enumerate(sources, 1):
click.echo('Loading dump {0} of {1} ({2})'.format(idx, len(sources),
source.name))
click.echo('Opening dump file {0} of {1} ({2})'.format(
idx, len(sources), source.name))
data = json.load(source)
with click.progressbar(data) as data_bar:
for d in data_bar:
if asynchronous:
load_task.s(d, *task_args, **task_kwargs).apply_async()
# Load a single item from the dump
if predicate is not None:
if predicate(d):
load_task.s(d, *task_args, **task_kwargs).apply(
throw=True)
click.echo("Loaded a single record.")
return
# Load dumps normally
else:
load_task.s(d, *task_args, **task_kwargs).apply(throw=True)
if asynchronous:
load_task.s(d, *task_args, **task_kwargs).apply_async()
else:
load_task.s(d, *task_args, **task_kwargs).apply(
throw=True)


@dumps.command()
Expand Down Expand Up @@ -209,11 +234,23 @@ def loadusers(sources):

@dumps.command()
@click.argument('sources', type=click.File('r'), nargs=-1)
@click.option('--depid', '-d', type=int,
help='Deposit ID to load (Note: will load only one deposit!).',
default=None)
@with_appcontext
def loaddeposit(sources):
"""Load deposit."""
def loaddeposit(sources, depid):
"""Load deposit.

Usage:
invenio dumps loaddeposit ~/data/deposit_dump_*.json
invenio dumps loaddeposit -d 12345 ~/data/deposit_dump_*.json
"""
from .tasks.deposit import load_deposit
loadcommon(sources, load_deposit)
if depid is not None:
pred = lambda dep: int(dep["_p"]["id"]) == depid
loadcommon(sources, load_deposit, predicate=pred, asynchronous=False)
else:
loadcommon(sources, load_deposit)


@dumps.command()
Expand Down
9 changes: 6 additions & 3 deletions invenio_migrator/tasks/deposit.py
Expand Up @@ -30,8 +30,7 @@
from celery.utils.log import get_task_logger

from .utils import empty_str_if_none
from .errors import DepositMultipleRecids, DepositRecidDoesNotExist, \
DepositSIPUserDoesNotExist
from .errors import DepositMultipleRecids, DepositRecidDoesNotExist

logger = get_task_logger(__name__)

Expand Down Expand Up @@ -133,6 +132,8 @@ def create_files_and_sip(deposit, dep_pid):
sip['agents'][0].get('email_address', "")),
)
user_id = sip['agents'][0]['user_id']
if user_id == 0:
user_id = None
content = sip['package']
sip_format = 'marcxml'
try:
Expand All @@ -144,7 +145,9 @@ def create_files_and_sip(deposit, dep_pid):
logger.exception('User ID {user_id} referred in deposit {depid} '
'does not exists.'.format(
user_id=user_id, depid=dep_pid.pid_value))
raise DepositSIPUserDoesNotExist(dep_pid.pid_value, user_id)
sip = SIP.create(sip_format,
content,
agent=agent)

# If recid was found, attach it to SIP
# TODO: This is always uses the first recid, as we quit if multiple
Expand Down
14 changes: 0 additions & 14 deletions invenio_migrator/tasks/errors.py
Expand Up @@ -72,20 +72,6 @@ def __str__(self):
", recids: {recids}".format(recids=self.recids)


class DepositSIPUserDoesNotExist(DepositError):
"""Deposit SIP user does not exist."""

def __init__(self, pid, user_id, *args, **kwargs):
"""Initialize exception."""
self.user_id = user_id
super(DepositSIPUserDoesNotExist, self).__init__(pid, *args, **kwargs)

def __str__(self):
"""String representation of the error."""
return super(DepositRecidDoesNotExist, self).__str__() + \
", user_id: {user_id}".format(user_id=self.user_id)


class UserEmailExistsError(Exception):
"""User email already exists in the database."""

Expand Down
17 changes: 10 additions & 7 deletions tests/unit/test_deposit_load.py
Expand Up @@ -35,35 +35,38 @@

from invenio_migrator.tasks.deposit import load_deposit
from invenio_migrator.tasks.errors import DepositMultipleRecids, \
DepositRecidDoesNotExist, DepositSIPUserDoesNotExist
DepositRecidDoesNotExist


def test_deposit_load(dummy_location, deposit_user, deposit_record_pid):
"""Test the deposit loading function."""
dep1 = dict(sips=[dict(metadata=dict(recid='10'),
agents=[dict(user_id=1), ],
package='Foobar'), ],
package='Content1'), ],
_p=dict(id='1'))
dep2 = dict(sips=[dict(metadata=dict(recid='50'),
agents=[dict(user_id=1), ],
package='Foobar'), ],
package='Content2'), ],
_p=dict(id='2'))
dep3 = dict(sips=[dict(metadata=dict(recid='10'),
agents=[dict(user_id=5), ],
package='Foobar'), ],
package='Content3'), ],
_p=dict(id='3'))
dep4 = dict(sips=[dict(metadata=dict(recid='10'),
agents=[dict(user_id=5), ],
package='Foobar'),
package='Content4'),
dict(metadata=dict(recid='11'),
agents=[dict(user_id=5), ],
package='Foobar'), ],
package='Content5'), ],
_p=dict(id='4'))
load_deposit(dep1)
pytest.raises(DepositRecidDoesNotExist, load_deposit, dep2)
pytest.raises(DepositSIPUserDoesNotExist, load_deposit, dep3)
pytest.raises(DepositMultipleRecids, load_deposit, dep4)

# Should set user to null because user_id does not exist
load_deposit(dep3)
assert SIP.query.filter_by(content="Content3").one().user_id is None


def test_deposit_load_task(dummy_location, deposit_dump, deposit_user,
deposit_record_pid):
Expand Down