diff --git a/glance/db/simple/api.py b/glance/db/simple/api.py index 83806e132d..6b5d137093 100644 --- a/glance/db/simple/api.py +++ b/glance/db/simple/api.py @@ -31,6 +31,7 @@ 'tags': {}, 'locations': [], 'tasks': {}, + 'task_info': {} } @@ -57,6 +58,7 @@ def reset(): 'tags': {}, 'locations': [], 'tasks': {}, + 'task_info': {} } @@ -119,16 +121,32 @@ def _image_member_format(image_id, tenant_id, can_share, status='pending'): } +def _pop_task_info_values(values): + task_info_values = {} + for k, v in values.items(): + if k in ['input', 'result', 'message']: + values.pop(k) + task_info_values[k] = v + + return task_info_values + + +def _format_task_from_db(task_ref, task_info_ref): + task = copy.deepcopy(task_ref) + if task_info_ref: + task_info = copy.deepcopy(task_info_ref) + task_info_values = _pop_task_info_values(task_info) + task.update(task_info_values) + return task + + def _task_format(task_id, **values): dt = timeutils.utcnow() task = { 'id': task_id, 'type': 'import', 'status': 'pending', - 'input': None, - 'result': None, 'owner': None, - 'message': None, 'expires_at': None, 'created_at': dt, 'updated_at': dt, @@ -139,6 +157,17 @@ def _task_format(task_id, **values): return task +def _task_info_format(task_id, **values): + task_info = { + 'task_id': task_id, + 'input': None, + 'result': None, + 'message': None, + } + task_info.update(values) + return task_info + + def _image_format(image_id, **values): dt = timeutils.utcnow() image = { @@ -697,9 +726,11 @@ def user_get_storage_usage(context, owner_id, image_id=None, session=None): @log_call -def task_create(context, task_values): +def task_create(context, values): """Create a task object""" global DATA + + task_values = copy.deepcopy(values) task_id = task_values.get('id', uuidutils.generate_uuid()) required_attributes = ['type', 'status', 'input'] allowed_attributes = ['id', 'type', 'status', 'input', 'result', 'owner', @@ -718,16 +749,20 @@ def task_create(context, task_values): raise exception.Invalid( 'The keys %s are not valid' % str(incorrect_keys)) + task_info_values = _pop_task_info_values(task_values) task = _task_format(task_id, **task_values) DATA['tasks'][task_id] = task + task_info = _task_info_create(task['id'], task_info_values) - return copy.deepcopy(task) + return _format_task_from_db(task, task_info) @log_call -def task_update(context, task_id, values, purge_props=False): +def task_update(context, task_id, values): """Update a task object""" global DATA + task_values = copy.deepcopy(values) + task_info_values = _pop_task_info_values(task_values) try: task = DATA['tasks'][task_id] except KeyError: @@ -735,16 +770,18 @@ def task_update(context, task_id, values, purge_props=False): LOG.debug(msg) raise exception.TaskNotFound(task_id=task_id) - task.update(values) + task.update(task_values) task['updated_at'] = timeutils.utcnow() DATA['tasks'][task_id] = task - return task + task_info = _task_info_update(task['id'], task_info_values) + + return _format_task_from_db(task, task_info) @log_call def task_get(context, task_id, force_show_deleted=False): - task = _task_get(context, task_id, force_show_deleted) - return copy.deepcopy(task) + task, task_info = _task_get(context, task_id, force_show_deleted) + return _format_task_from_db(task, task_info) def _task_get(context, task_id, force_show_deleted=False): @@ -765,7 +802,9 @@ def _task_get(context, task_id, force_show_deleted=False): LOG.debug(msg) raise exception.Forbidden(msg) - return task + task_info = _task_info_get(task_id) + + return task, task_info @log_call @@ -802,7 +841,12 @@ def task_get_all(context, filters=None, marker=None, limit=None, tasks = _paginate_tasks(context, tasks, marker, limit, filters.get('deleted')) - return tasks + filtered_tasks = [] + for task in tasks: + task_info = DATA['task_info'][task['id']] + filtered_tasks.append(_format_task_from_db(task, task_info)) + + return filtered_tasks def _is_task_visible(context, task): @@ -878,3 +922,41 @@ def _paginate_tasks(context, tasks, marker, limit, show_deleted): end = start + limit if limit is not None else None return tasks[start:end] + + +def _task_info_create(task_id, values): + """Create a Task Info for Task with given task ID""" + global DATA + task_info = _task_info_format(task_id, **values) + DATA['task_info'][task_id] = task_info + + return task_info + + +def _task_info_update(task_id, values): + """Update Task Info for Task with given task ID and updated values""" + global DATA + try: + task_info = DATA['task_info'][task_id] + except KeyError: + msg = (_("No task info found with task id %s") % task_id) + LOG.debug(msg) + raise exception.TaskNotFound(task_id=task_id) + + task_info.update(values) + DATA['task_info'][task_id] = task_info + + return task_info + + +def _task_info_get(task_id): + """Get Task Info for Task with given task ID""" + global DATA + try: + task_info = DATA['task_info'][task_id] + except KeyError: + msg = _('Could not find task info %s') % task_id + LOG.info(msg) + raise exception.TaskNotFound(task_id=task_id) + + return task_info diff --git a/glance/db/sqlalchemy/api.py b/glance/db/sqlalchemy/api.py index 6002a1a4ab..738df8c7f5 100644 --- a/glance/db/sqlalchemy/api.py +++ b/glance/db/sqlalchemy/api.py @@ -1177,42 +1177,120 @@ def user_get_storage_usage(context, owner_id, image_id=None, session=None): return total_size +def _task_info_format(task_info_ref): + """Format a task info ref for consumption outside of this module""" + if task_info_ref is None: + return {} + return { + 'task_id': task_info_ref['task_id'], + 'input': task_info_ref['input'], + 'result': task_info_ref['result'], + 'message': task_info_ref['message'], + } + + +def _task_info_create(context, task_id, values, session=None): + """Create an TaskInfo object""" + session = session or _get_session() + task_info_ref = models.TaskInfo() + task_info_ref.task_id = task_id + task_info_ref.update(values) + task_info_ref.save(session=session) + return _task_info_format(task_info_ref) + + +def _task_info_update(context, task_id, values, session=None): + """Update an TaskInfo object""" + session = session or _get_session() + task_info_ref = _task_info_get(context, task_id, session=session) + if task_info_ref: + task_info_ref.update(values) + task_info_ref.save(session=session) + return _task_info_format(task_info_ref) + + +def _task_info_get(context, task_id, session=None): + """Fetch an TaskInfo entity by task_id""" + session = session or _get_session() + query = session.query(models.TaskInfo) + query = query.filter_by(task_id=task_id) + try: + task_info_ref = query.one() + except sa_orm.exc.NoResultFound: + msg = (_("TaskInfo was not found for task with id %(task_id)s") % + {'task_id': task_id}) + LOG.debug(msg) + task_info_ref = None + + return task_info_ref + + def task_create(context, values, session=None): """Create a task object""" - task_ref = models.Task() - _task_update(context, task_ref, values, session=session) - return _task_format(task_ref) + + values = values.copy() + session = session or _get_session() + with session.begin(): + task_info_values = _pop_task_info_values(values) + + task_ref = models.Task() + _task_update(context, task_ref, values, session=session) + + _task_info_create(context, + task_ref.id, + task_info_values, + session=session) + + return task_get(context, task_ref.id, session) + + +def _pop_task_info_values(values): + task_info_values = {} + for k, v in values.items(): + if k in ['input', 'result', 'message']: + values.pop(k) + task_info_values[k] = v + + return task_info_values def task_update(context, task_id, values, session=None): """Update a task object""" + session = session or _get_session() - task_ref = _task_get(context, task_id, session) - _task_update(context, task_ref, values, session) - return _task_format(task_ref) + with session.begin(): + task_info_values = _pop_task_info_values(values) + + task_ref = _task_get(context, task_id, session) + _drop_protected_attrs(models.Task, values) + + values['updated_at'] = timeutils.utcnow() + + _task_update(context, task_ref, values, session) + + if task_info_values: + _task_info_update(context, + task_id, + task_info_values, + session) + + return task_get(context, task_id, session) -def task_get(context, task_id, session=None): + +def task_get(context, task_id, session=None, force_show_deleted=False): """Fetch a task entity by id""" - task_ref = _task_get(context, task_id, session=session) - return _task_format(task_ref) + task_ref = _task_get(context, task_id, session=session, + force_show_deleted=force_show_deleted) + return _task_format(task_ref, task_ref.info) def task_delete(context, task_id, session=None): """Delete a task""" session = session or _get_session() - query = session.query(models.Task)\ - .filter_by(id=task_id)\ - .filter_by(deleted=False) - try: - task_ref = query.one() - except sa_orm.exc.NoResultFound: - msg = (_("No task found with ID %s") % task_id) - LOG.debug(msg) - raise exception.TaskNotFound(task_id=task_id) - + task_ref = _task_get(context, task_id, session=session) task_ref.delete(session=session) - return _task_format(task_ref) + return _task_format(task_ref, task_ref.info) def task_get_all(context, filters=None, marker=None, limit=None, @@ -1233,7 +1311,8 @@ def task_get_all(context, filters=None, marker=None, limit=None, filters = filters or {} session = _get_session() - query = session.query(models.Task) + query = session.query(models.Task)\ + .options(sa_orm.joinedload(models.Task.info)) if not (context.is_admin or admin_as_user == True) and \ context.owner is not None: @@ -1266,7 +1345,17 @@ def task_get_all(context, filters=None, marker=None, limit=None, marker=marker_task, sort_dir=sort_dir) - return [_task_format(task) for task in query.all()] + task_refs = query.all() + + tasks = [] + for task_ref in task_refs: + # NOTE(venkatesh): call to task_ref.info does not make any + # seperate query call to fetch task info as it has been + # eagerly loaded using joinedload(models.Task.info) method above. + task_info_ref = task_ref.info + tasks.append(_task_format(task_ref, task_info_ref)) + + return tasks def _is_task_visible(context, task): @@ -1290,8 +1379,10 @@ def _is_task_visible(context, task): def _task_get(context, task_id, session=None, force_show_deleted=False): """Fetch a task entity by id""" session = session or _get_session() - query = session.query(models.Task) - query = query.filter_by(id=task_id) + query = session.query(models.Task).options( + sa_orm.joinedload(models.Task.info) + ).filter_by(id=task_id) + if not force_show_deleted and not _can_show_deleted(context): query = query.filter_by(deleted=False) try: @@ -1312,26 +1403,32 @@ def _task_get(context, task_id, session=None, force_show_deleted=False): def _task_update(context, task_ref, values, session=None): """Apply supplied dictionary of values to a task object.""" - _drop_protected_attrs(models.Task, values) values["deleted"] = False task_ref.update(values) task_ref.save(session=session) return task_ref -def _task_format(task_ref): +def _task_format(task_ref, task_info_ref=None): """Format a task ref for consumption outside of this module""" - return { + task_dict = { 'id': task_ref['id'], 'type': task_ref['type'], 'status': task_ref['status'], - 'input': task_ref['input'], - 'result': task_ref['result'], 'owner': task_ref['owner'], - 'message': task_ref['message'], 'expires_at': task_ref['expires_at'], 'created_at': task_ref['created_at'], 'updated_at': task_ref['updated_at'], 'deleted_at': task_ref['deleted_at'], 'deleted': task_ref['deleted'] } + + if task_info_ref: + task_info_dict = { + 'input': task_info_ref['input'], + 'result': task_info_ref['result'], + 'message': task_info_ref['message'], + } + task_dict.update(task_info_dict) + + return task_dict diff --git a/glance/db/sqlalchemy/migrate_repo/versions/032_add_task_info_table.py b/glance/db/sqlalchemy/migrate_repo/versions/032_add_task_info_table.py new file mode 100644 index 0000000000..860539803e --- /dev/null +++ b/glance/db/sqlalchemy/migrate_repo/versions/032_add_task_info_table.py @@ -0,0 +1,96 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Rackspace +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy.schema import (Column, ForeignKey, MetaData, Table) + +from glance.db.sqlalchemy.migrate_repo.schema import (String, + Text, + create_tables, + drop_tables) + +TASKS_MIGRATE_COLUMNS = ['input', 'message', 'result'] + + +def define_task_info_table(meta): + Table('tasks', meta, autoload=True) + #NOTE(nikhil): input and result are stored as text in the DB. + # SQLAlchemy marshals the data to/from JSON using custom type + # JSONEncodedDict. It uses simplejson underneath. + task_info = Table('task_info', + meta, + Column('task_id', String(36), + ForeignKey('tasks.id'), + primary_key=True, + nullable=False), + Column('input', Text()), + Column('result', Text()), + Column('message', Text()), + mysql_engine='InnoDB') + + return task_info + + +def upgrade(migrate_engine): + meta = MetaData() + meta.bind = migrate_engine + + tables = [define_task_info_table(meta)] + create_tables(tables) + + tasks_table = Table('tasks', meta, autoload=True) + task_info_table = Table('task_info', meta, autoload=True) + + tasks = tasks_table.select().execute().fetchall() + for task in tasks: + values = { + 'task_id': task.id, + 'input': task.input, + 'result': task.result, + 'message': task.message, + } + task_info_table.insert(values=values).execute() + + for col_name in TASKS_MIGRATE_COLUMNS: + tasks_table.columns[col_name].drop() + + +def downgrade(migrate_engine): + meta = MetaData() + meta.bind = migrate_engine + + tasks_table = Table('tasks', meta, autoload=True) + task_info_table = Table('task_info', meta, autoload=True) + + for col_name in TASKS_MIGRATE_COLUMNS: + column = Column(col_name, Text()) + column.create(tasks_table) + + task_info_records = task_info_table.select().execute().fetchall() + + for task_info in task_info_records: + values = { + 'input': task_info.input, + 'result': task_info.result, + 'message': task_info.message + } + + tasks_table\ + .update(values=values)\ + .where(tasks_table.c.id == task_info.task_id)\ + .execute() + + drop_tables([task_info_table]) diff --git a/glance/db/sqlalchemy/models.py b/glance/db/sqlalchemy/models.py index 20a2913835..c0d1b463d8 100644 --- a/glance/db/sqlalchemy/models.py +++ b/glance/db/sqlalchemy/models.py @@ -219,11 +219,27 @@ class Task(BASE, GlanceBase): id = Column(String(36), primary_key=True, default=uuidutils.generate_uuid) type = Column(String(30)) status = Column(String(30)) + owner = Column(String(255)) + expires_at = Column(DateTime, nullable=True) + + +class TaskInfo(BASE, models.ModelBase): + """Represents task info in the datastore""" + __tablename__ = 'task_info' + + task_id = Column(String(36), + ForeignKey('tasks.id'), + primary_key=True, + nullable=False) + + task = relationship(Task, backref=backref('info', uselist=False)) + + #NOTE(nikhil): input and result are stored as text in the DB. + # SQLAlchemy marshals the data to/from JSON using custom type + # JSONEncodedDict. It uses simplejson underneath. input = Column(JSONEncodedDict()) result = Column(JSONEncodedDict()) - owner = Column(String(255)) message = Column(Text) - expires_at = Column(DateTime, nullable=True) def register_models(engine): diff --git a/glance/tests/functional/db/base.py b/glance/tests/functional/db/base.py index 4e0d9e75b1..1de5871e00 100644 --- a/glance/tests/functional/db/base.py +++ b/glance/tests/functional/db/base.py @@ -1210,10 +1210,10 @@ def test_storage_quota_multiple_locations(self): self.assertEqual(total, x) -class DriverTaskTests(test_utils.BaseTestCase): +class TaskTests(test_utils.BaseTestCase): def setUp(self): - super(DriverTaskTests, self).setUp() + super(TaskTests, self).setUp() self.owner_id1 = uuidutils.generate_uuid() self.adm_context = context.RequestContext(is_admin=True, auth_tok='user:user:admin') @@ -1261,8 +1261,8 @@ def build_task_fixtures(self): def test_task_get_all_with_filter(self): for fixture in self.fixtures: - task = self.db_api.task_create(self.context, - build_task_fixture(**fixture)) + self.db_api.task_create(self.context, + build_task_fixture(**fixture)) import_tasks = self.db_api.task_get_all(self.context, filters={'type': 'import'}) @@ -1295,8 +1295,8 @@ def test_task_get_all_marker(self): def test_task_get_all_limit(self): for fixture in self.fixtures: - task = self.db_api.task_create(self.context, - build_task_fixture(**fixture)) + self.db_api.task_create(self.context, + build_task_fixture(**fixture)) tasks = self.db_api.task_get_all(self.context, limit=2) self.assertEqual(2, len(tasks)) @@ -1336,11 +1336,14 @@ def test_task_get_all_owned(self): def test_task_get(self): expires_at = timeutils.utcnow() + image_id = uuidutils.generate_uuid() fixture = { 'owner': self.context.owner, 'type': 'import', 'status': 'pending', 'input': '{"loc": "fake"}', + 'result': "{'image_id': %s}" % image_id, + 'message': 'blah', 'expires_at': expires_at } @@ -1357,8 +1360,67 @@ def test_task_get(self): self.assertEqual(task['owner'], self.context.owner) self.assertEqual(task['type'], 'import') self.assertEqual(task['status'], 'pending') + self.assertEqual(task['input'], fixture['input']) + self.assertEqual(task['result'], fixture['result']) + self.assertEqual(task['message'], fixture['message']) self.assertEqual(task['expires_at'], expires_at) + def test_task_get_all(self): + now = timeutils.utcnow() + image_id = uuidutils.generate_uuid() + fixture1 = { + 'owner': self.context.owner, + 'type': 'import', + 'status': 'pending', + 'input': '{"loc": "fake_1"}', + 'result': "{'image_id': %s}" % image_id, + 'message': 'blah_1', + 'expires_at': now, + 'created_at': now, + 'updated_at': now + } + + fixture2 = { + 'owner': self.context.owner, + 'type': 'import', + 'status': 'pending', + 'input': '{"loc": "fake_2"}', + 'result': "{'image_id': %s}" % image_id, + 'message': 'blah_2', + 'expires_at': now, + 'created_at': now, + 'updated_at': now + } + + task1 = self.db_api.task_create(self.context, fixture1) + task2 = self.db_api.task_create(self.context, fixture2) + + self.assertIsNotNone(task1) + self.assertIsNotNone(task2) + + task1_id = task1['id'] + task2_id = task2['id'] + task_fixtures = {task1_id: fixture1, task2_id: fixture2} + tasks = self.db_api.task_get_all(self.context) + + self.assertEqual(len(tasks), 2) + self.assertEqual(set((tasks[0]['id'], tasks[1]['id'])), + set((task1_id, task2_id))) + for task in tasks: + fixture = task_fixtures[task['id']] + + self.assertEqual(task['owner'], self.context.owner) + self.assertEqual(task['type'], fixture['type']) + self.assertEqual(task['status'], fixture['status']) + self.assertEqual(task['expires_at'], fixture['expires_at']) + self.assertFalse(task['deleted']) + self.assertIsNone(task['deleted_at']) + self.assertEqual(task['created_at'], fixture['created_at']) + self.assertEqual(task['updated_at'], fixture['updated_at']) + self.assertEqual(task['input'], fixture['input']) + self.assertEqual(task['result'], fixture['result']) + self.assertEqual(task['message'], fixture['message']) + def test_task_create(self): task_id = uuidutils.generate_uuid() self.context.tenant = uuidutils.generate_uuid() @@ -1375,10 +1437,64 @@ def test_task_create(self): self.assertEqual(task['owner'], self.context.owner) self.assertEqual(task['type'], 'export') self.assertEqual(task['status'], 'pending') + self.assertEqual(task['input'], {'ping': 'pong'}) + + def test_task_create_with_all_task_info_null(self): + task_id = uuidutils.generate_uuid() + self.context.tenant = uuidutils.generate_uuid() + values = { + 'id': task_id, + 'owner': self.context.owner, + 'type': 'export', + 'status': 'pending', + 'input': None, + 'result': None, + 'message': None, + } + task_values = build_task_fixture(**values) + task = self.db_api.task_create(self.context, task_values) + self.assertIsNotNone(task) + self.assertEqual(task['id'], task_id) + self.assertEqual(task['owner'], self.context.owner) + self.assertEqual(task['type'], 'export') + self.assertEqual(task['status'], 'pending') + self.assertEqual(task['input'], None) + self.assertEqual(task['result'], None) + self.assertEqual(task['message'], None) def test_task_update(self): self.context.tenant = uuidutils.generate_uuid() - task_values = build_task_fixture(owner=self.context.owner) + result = {'foo': 'bar'} + task_values = build_task_fixture(owner=self.context.owner, + result=result) + task = self.db_api.task_create(self.context, task_values) + + task_id = task['id'] + fixture = { + 'status': 'processing', + 'message': 'This is a error string', + } + task = self.db_api.task_update(self.context, task_id, fixture) + + self.assertEqual(task['id'], task_id) + self.assertEqual(task['owner'], self.context.owner) + self.assertEqual(task['type'], 'import') + self.assertEqual(task['status'], 'processing') + self.assertEqual(task['input'], {'ping': 'pong'}) + self.assertEqual(task['result'], result) + self.assertEqual(task['message'], 'This is a error string') + self.assertEqual(task['deleted'], False) + self.assertIsNone(task['deleted_at']) + self.assertIsNone(task['expires_at']) + self.assertEqual(task['created_at'], task_values['created_at']) + self.assertTrue(task['updated_at'] > task['created_at']) + + def test_task_update_with_all_task_info_null(self): + self.context.tenant = uuidutils.generate_uuid() + task_values = build_task_fixture(owner=self.context.owner, + input=None, + result=None, + message=None) task = self.db_api.task_create(self.context, task_values) task_id = task['id'] @@ -1389,9 +1505,17 @@ def test_task_update(self): self.assertEqual(task['owner'], self.context.owner) self.assertEqual(task['type'], 'import') self.assertEqual(task['status'], 'processing') + self.assertEqual(task['input'], None) + self.assertEqual(task['result'], None) + self.assertEqual(task['message'], None) + self.assertEqual(task['deleted'], False) + self.assertIsNone(task['deleted_at']) + self.assertIsNone(task['expires_at']) + self.assertEqual(task['created_at'], task_values['created_at']) + self.assertTrue(task['updated_at'] > task['created_at']) def test_task_delete(self): - task_values = build_task_fixture() + task_values = build_task_fixture(owner=self.context.owner) task = self.db_api.task_create(self.context, task_values) self.assertIsNotNone(task) @@ -1403,6 +1527,24 @@ def test_task_delete(self): self.assertRaises(exception.TaskNotFound, self.db_api.task_get, self.context, task_id) + def test_task_delete_as_admin(self): + task_values = build_task_fixture(owner=self.context.owner) + task = self.db_api.task_create(self.context, task_values) + + self.assertIsNotNone(task) + self.assertEqual(task['deleted'], False) + self.assertIsNone(task['deleted_at']) + + task_id = task['id'] + self.db_api.task_delete(self.context, task_id) + del_task = self.db_api.task_get(self.adm_context, + task_id, + force_show_deleted=True) + self.assertIsNotNone(del_task) + self.assertEqual(task_id, del_task['id']) + self.assertEqual(True, del_task['deleted']) + self.assertIsNotNone(del_task['deleted_at']) + class TestVisibility(test_utils.BaseTestCase): def setUp(self): diff --git a/glance/tests/functional/db/test_simple.py b/glance/tests/functional/db/test_simple.py index 965f4632e9..8ac301a2e7 100644 --- a/glance/tests/functional/db/test_simple.py +++ b/glance/tests/functional/db/test_simple.py @@ -63,7 +63,7 @@ def setUp(self): self.addCleanup(db_tests.reset) -class TestSimpleTask(base.DriverTaskTests): +class TestSimpleTask(base.TaskTests): def setUp(self): db_tests.load(get_db, reset_db) diff --git a/glance/tests/functional/db/test_sqlalchemy.py b/glance/tests/functional/db/test_sqlalchemy.py index 01db4e8674..2be31ad6a0 100644 --- a/glance/tests/functional/db/test_sqlalchemy.py +++ b/glance/tests/functional/db/test_sqlalchemy.py @@ -100,7 +100,7 @@ def fake_paginate_query(query, model, limit, sort_key='name') -class TestSqlAlchemyTask(base.DriverTaskTests): +class TestSqlAlchemyTask(base.TaskTests): def setUp(self): db_tests.load(get_db, reset_db) diff --git a/glance/tests/unit/test_migrations.py b/glance/tests/unit/test_migrations.py index 9b44682d75..1217ab6306 100644 --- a/glance/tests/unit/test_migrations.py +++ b/glance/tests/unit/test_migrations.py @@ -1093,3 +1093,74 @@ def _check_031(self, engine, image_id): ('file://ab1', '{"a": "that one, please"}'), ]) self.assertFalse(actual_locations.symmetric_difference(locations)) + + def _pre_upgrade_032(self, engine): + self.assertRaises(sqlalchemy.exc.NoSuchTableError, + get_table, engine, 'task_info') + + tasks = get_table(engine, 'tasks') + now = datetime.datetime.now() + base_values = { + 'deleted': False, + 'created_at': now, + 'updated_at': now, + 'status': 'active', + 'owner': 'TENANT', + 'type': 'import', + } + data = [ + { + 'id': 'task-1', + 'input': 'some input', + 'message': None, + 'result': 'successful' + }, + { + 'id': 'task-2', + 'input': None, + 'message': None, + 'result': None + }, + ] + map(lambda task: task.update(base_values), data) + for task in data: + tasks.insert().values(task).execute() + return data + + def _check_032(self, engine, data): + task_info_table = get_table(engine, 'task_info') + + task_info_refs = task_info_table.select().execute().fetchall() + + self.assertEquals(len(task_info_refs), 2) + + for x in range(len(task_info_refs)): + self.assertEqual(task_info_refs[x].task_id, data[x]['id']) + self.assertEqual(task_info_refs[x].input, data[x]['input']) + self.assertEqual(task_info_refs[x].result, data[x]['result']) + self.assertIsNone(task_info_refs[x].message) + + tasks_table = get_table(engine, 'tasks') + self.assertNotIn('input', tasks_table.c) + self.assertNotIn('result', tasks_table.c) + self.assertNotIn('message', tasks_table.c) + + def _post_downgrade_032(self, engine): + self.assertRaises(sqlalchemy.exc.NoSuchTableError, + get_table, engine, 'task_info') + + tasks_table = get_table(engine, 'tasks') + records = tasks_table.select().execute().fetchall() + self.assertEquals(len(records), 2) + + tasks = dict([(t.id, t) for t in records]) + + task_1 = tasks.get('task-1') + self.assertEqual(task_1.input, 'some input') + self.assertEqual(task_1.result, 'successful') + self.assertIsNone(task_1.message) + + task_2 = tasks.get('task-2') + self.assertIsNone(task_2.input) + self.assertIsNone(task_2.result) + self.assertIsNone(task_2.message) diff --git a/glance/tests/unit/utils.py b/glance/tests/unit/utils.py index af4b2a70dc..3f8fb70c6c 100644 --- a/glance/tests/unit/utils.py +++ b/glance/tests/unit/utils.py @@ -86,7 +86,8 @@ def reset(): 'members': [], 'tags': {}, 'locations': [], - 'tasks': {} + 'tasks': {}, + 'task_info': {} } def __getattr__(self, key):