Skip to content

Commit

Permalink
Merge pull request #7346 from jmchilton/item_attrs_refactor_2
Browse files Browse the repository at this point in the history
Refactor item_attrs for reuse without mixin inheritance.
  • Loading branch information
nsoranzo committed Feb 13, 2019
2 parents 110d955 + 834e9f2 commit 0020c7d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 101 deletions.
158 changes: 86 additions & 72 deletions lib/galaxy/model/item_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
log = logging.getLogger(__name__)


class RuntimeException(Exception):
pass


class UsesItemRatings(object):
"""
Mixin for getting and setting item ratings.
Expand All @@ -28,7 +24,7 @@ def get_ave_item_rating_data(self, db_session, item, webapp_model=None):
webapp_model = galaxy.model
item_rating_assoc_class = self._get_item_rating_assoc_class(item, webapp_model=webapp_model)
if not item_rating_assoc_class:
raise RuntimeException("Item does not have ratings: %s" % item.__class__.__name__)
raise Exception("Item does not have ratings: %s" % item.__class__.__name__)
item_id_filter = self._get_item_id_filter_str(item, item_rating_assoc_class)
ave_rating = db_session.query(func.avg(item_rating_assoc_class.rating)).filter(item_id_filter).scalar()
# Convert ave_rating to float; note: if there are no item ratings, ave rating is None.
Expand Down Expand Up @@ -65,7 +61,7 @@ def get_user_item_rating(self, db_session, user, item, webapp_model=None):
webapp_model = galaxy.model
item_rating_assoc_class = self._get_item_rating_assoc_class(item, webapp_model=webapp_model)
if not item_rating_assoc_class:
raise RuntimeException("Item does not have ratings: %s" % item.__class__.__name__)
raise Exception("Item does not have ratings: %s" % item.__class__.__name__)

# Query rating table by user and item id.
item_id_filter = self._get_item_id_filter_str(item, item_rating_assoc_class)
Expand All @@ -82,15 +78,7 @@ def _get_item_id_filter_str(self, item, item_rating_assoc_class, webapp_model=No
# Get foreign key in item-rating association table that references item table.
if webapp_model is None:
webapp_model = galaxy.model
item_fk = None
for fk in item_rating_assoc_class.table.foreign_keys:
if fk.references(item.table):
item_fk = fk
break

if not item_fk:
raise RuntimeException("Cannot find item id column in item-rating association table: %s, %s" % item_rating_assoc_class.__name__, item_rating_assoc_class.table.name)

item_fk = get_foreign_key(item_rating_assoc_class, item)
# TODO: can we provide a better filter than a raw string?
return "%s=%i" % (item_fk.parent.name, item.id)

Expand All @@ -99,64 +87,16 @@ class UsesAnnotations(object):
""" Mixin for getting and setting item annotations. """

def get_item_annotation_str(self, db_session, user, item):
""" Returns a user's annotation string for an item. """
if hasattr(item, 'annotations'):
# If we already have an annotations object we use it.
annotation_obj = None
for annotation in item.annotations:
if annotation.user == user:
annotation_obj = annotation
break
else:
annotation_obj = self.get_item_annotation_obj(db_session, user, item)
if annotation_obj:
return galaxy.util.unicodify(annotation_obj.annotation)
return None
return get_item_annotation_str(db_session, user, item)

def get_item_annotation_obj(self, db_session, user, item):
""" Returns a user's annotation object for an item. """
# Get annotation association class.
annotation_assoc_class = self._get_annotation_assoc_class(item)
if not annotation_assoc_class:
return None

# Get annotation association object.
annotation_assoc = db_session.query(annotation_assoc_class).filter_by(user=user)

# TODO: use filtering like that in _get_item_id_filter_str()
if item.__class__ == galaxy.model.History:
annotation_assoc = annotation_assoc.filter_by(history=item)
elif item.__class__ == galaxy.model.HistoryDatasetAssociation:
annotation_assoc = annotation_assoc.filter_by(hda=item)
elif item.__class__ == galaxy.model.HistoryDatasetCollectionAssociation:
annotation_assoc = annotation_assoc.filter_by(history_dataset_collection=item)
elif item.__class__ == galaxy.model.StoredWorkflow:
annotation_assoc = annotation_assoc.filter_by(stored_workflow=item)
elif item.__class__ == galaxy.model.WorkflowStep:
annotation_assoc = annotation_assoc.filter_by(workflow_step=item)
elif item.__class__ == galaxy.model.Page:
annotation_assoc = annotation_assoc.filter_by(page=item)
elif item.__class__ == galaxy.model.Visualization:
annotation_assoc = annotation_assoc.filter_by(visualization=item)
return annotation_assoc.first()
return get_item_annotation_obj(db_session, user, item)

def add_item_annotation(self, db_session, user, item, annotation):
""" Add or update an item's annotation; a user can only have a single annotation for an item. """
# Get/create annotation association object.
annotation_assoc = self.get_item_annotation_obj(db_session, user, item)
if not annotation_assoc:
annotation_assoc_class = self._get_annotation_assoc_class(item)
if not annotation_assoc_class:
return None
annotation_assoc = annotation_assoc_class()
item.annotations.append(annotation_assoc)
annotation_assoc.user = user
# Set annotation.
annotation_assoc.annotation = annotation
return annotation_assoc
return add_item_annotation(db_session, user, item, annotation)

def delete_item_annotation(self, db_session, user, item):
annotation_assoc = self.get_item_annotation_obj(db_session, user, item)
annotation_assoc = get_item_annotation_obj(db_session, user, item)
if annotation_assoc:
db_session.delete(annotation_assoc)
db_session.flush()
Expand All @@ -170,14 +110,88 @@ def copy_item_annotation(self, db_session, source_user, source_item, target_user
return annotation
return None

def _get_annotation_assoc_class(self, item):
""" Returns an item's item-annotation association class. """
class_name = '%sAnnotationAssociation' % item.__class__.__name__
return getattr(galaxy.model, class_name, None)

def get_item_annotation_obj(db_session, user, item):
"""Returns a user's annotation object for an item."""

# Get annotation association class.
annotation_assoc_class = _get_annotation_assoc_class(item)
if not annotation_assoc_class:
return None

# Get annotation association object.
annotation_assoc = db_session.query(annotation_assoc_class).filter_by(user=user)

# TODO: use filtering like that in _get_item_id_filter_str()
if item.__class__ == galaxy.model.History:
annotation_assoc = annotation_assoc.filter_by(history=item)
elif item.__class__ == galaxy.model.HistoryDatasetAssociation:
annotation_assoc = annotation_assoc.filter_by(hda=item)
elif item.__class__ == galaxy.model.HistoryDatasetCollectionAssociation:
annotation_assoc = annotation_assoc.filter_by(history_dataset_collection=item)
elif item.__class__ == galaxy.model.StoredWorkflow:
annotation_assoc = annotation_assoc.filter_by(stored_workflow=item)
elif item.__class__ == galaxy.model.WorkflowStep:
annotation_assoc = annotation_assoc.filter_by(workflow_step=item)
elif item.__class__ == galaxy.model.Page:
annotation_assoc = annotation_assoc.filter_by(page=item)
elif item.__class__ == galaxy.model.Visualization:
annotation_assoc = annotation_assoc.filter_by(visualization=item)
return annotation_assoc.first()


def get_item_annotation_str(db_session, user, item):
""" Returns a user's annotation string for an item. """
if hasattr(item, 'annotations'):
# If we already have an annotations object we use it.
annotation_obj = None
for annotation in item.annotations:
if annotation.user == user:
annotation_obj = annotation
break
else:
annotation_obj = get_item_annotation_obj(db_session, user, item)
if annotation_obj:
return galaxy.util.unicodify(annotation_obj.annotation)
return None


def add_item_annotation(db_session, user, item, annotation):
""" Add or update an item's annotation; a user can only have a single annotation for an item. """
# Get/create annotation association object.
annotation_assoc = get_item_annotation_obj(db_session, user, item)
if not annotation_assoc:
annotation_assoc_class = _get_annotation_assoc_class(item)
if not annotation_assoc_class:
return None
annotation_assoc = annotation_assoc_class()
item.annotations.append(annotation_assoc)
annotation_assoc.user = user
# Set annotation.
annotation_assoc.annotation = annotation
return annotation_assoc


def _get_annotation_assoc_class(item):
""" Returns an item's item-annotation association class. """
class_name = '%sAnnotationAssociation' % item.__class__.__name__
return getattr(galaxy.model, class_name, None)


def get_foreign_key(source_class, target_class):
""" Returns foreign key in source class that references target class. """
target_fk = None
for fk in source_class.table.foreign_keys:
if fk.references(target_class.table):
target_fk = fk
break
if not target_fk:
raise Exception("No foreign key found between objects: %s, %s" % source_class.table, target_class.table)
return target_fk


__all__ = (
'get_foreign_key',
'UsesAnnotations',
'UsesItemRatings',
'RuntimeException',
)
14 changes: 7 additions & 7 deletions lib/galaxy/tools/imp_exp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from galaxy import model
from galaxy.exceptions import MalformedContents
from galaxy.exceptions import ObjectNotFound
from galaxy.model.item_attrs import UsesAnnotations
from galaxy.model.item_attrs import add_item_annotation, get_item_annotation_str
from galaxy.web.framework.helpers import to_unicode

log = logging.getLogger(__name__)


class JobImportHistoryArchiveWrapper(UsesAnnotations):
class JobImportHistoryArchiveWrapper:
"""
Class provides support for performing jobs that import a history from
an archive.
Expand Down Expand Up @@ -84,7 +84,7 @@ def get_tag_str(tag, value):

# Add annotation, tags.
if user:
self.add_item_annotation(self.sa_session, user, new_history, history_attrs['annotation'])
add_item_annotation(self.sa_session, user, new_history, history_attrs['annotation'])
"""
TODO: figure out to how add tags to item.
for tag, value in history_attrs[ 'tags' ].items():
Expand Down Expand Up @@ -161,7 +161,7 @@ def get_tag_str(tag, value):

# Set tags, annotations.
if user:
self.add_item_annotation(self.sa_session, user, hda, dataset_attrs['annotation'])
add_item_annotation(self.sa_session, user, hda, dataset_attrs['annotation'])
# TODO: Set tags.
"""
for tag, value in dataset_attrs[ 'tags' ].items():
Expand Down Expand Up @@ -269,7 +269,7 @@ def default(self, obj):
raise


class JobExportHistoryArchiveWrapper(UsesAnnotations):
class JobExportHistoryArchiveWrapper:
"""
Class provides support for performing jobs that export a history to an
archive.
Expand Down Expand Up @@ -379,7 +379,7 @@ def default(self, obj):
"name": to_unicode(history.name),
"hid_counter": history.hid_counter,
"genome_build": history.genome_build,
"annotation": to_unicode(self.get_item_annotation_str(trans.sa_session, history.user, history)),
"annotation": to_unicode(get_item_annotation_str(trans.sa_session, history.user, history)),
"tags": get_item_tag_dict(history),
"includes_hidden_datasets": include_hidden,
"includes_deleted_datasets": include_deleted
Expand All @@ -396,7 +396,7 @@ def default(self, obj):
datasets_attrs = []
provenance_attrs = []
for dataset in datasets:
dataset.annotation = self.get_item_annotation_str(trans.sa_session, history.user, dataset)
dataset.annotation = get_item_annotation_str(trans.sa_session, history.user, dataset)
if (not dataset.visible and not include_hidden) or (dataset.deleted and not include_deleted):
provenance_attrs.append(dataset)
else:
Expand Down
12 changes: 1 addition & 11 deletions lib/galaxy/web/framework/helpers/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from six import string_types, text_type
from sqlalchemy.sql.expression import and_, false, func, null, or_, true

from galaxy.model.item_attrs import RuntimeException, UsesAnnotations, UsesItemRatings
from galaxy.model.item_attrs import get_foreign_key, UsesAnnotations, UsesItemRatings
from galaxy.util import restore_text, sanitize_text, unicodify
from galaxy.util.odict import odict
from galaxy.web.framework import decorators, url_for
Expand Down Expand Up @@ -584,16 +584,6 @@ def get_value(self, trans, grid, item):
item_id=trans.security.encode_id(item.id))

def sort(self, trans, query, ascending, column_name=None):
def get_foreign_key(source_class, target_class):
""" Returns foreign key in source class that references target class. """
target_fk = None
for fk in source_class.table.foreign_keys:
if fk.references(target_class.table):
target_fk = fk
break
if not target_fk:
raise RuntimeException("No foreign key found between objects: %s, %s" % source_class.table, target_class.table)
return target_fk
# Get the columns that connect item's table and item's rating association table.
item_rating_assoc_class = getattr(trans.model, '%sRatingAssociation' % self.model_class.__name__)
foreign_key = get_foreign_key(item_rating_assoc_class, self.model_class)
Expand Down
12 changes: 1 addition & 11 deletions lib/galaxy/webapps/reports/framework/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from six import string_types, text_type
from sqlalchemy.sql.expression import and_, false, func, null, or_, true

from galaxy.model.item_attrs import RuntimeException, UsesAnnotations, UsesItemRatings
from galaxy.model.item_attrs import get_foreign_key, UsesAnnotations, UsesItemRatings
from galaxy.util import sanitize_text, unicodify
from galaxy.util.odict import odict
from galaxy.web.framework import decorators, url_for
Expand Down Expand Up @@ -503,16 +503,6 @@ def get_value(self, trans, grid, item):
item_id=trans.security.encode_id(item.id))

def sort(self, trans, query, ascending, column_name=None):
def get_foreign_key(source_class, target_class):
""" Returns foreign key in source class that references target class. """
target_fk = None
for fk in source_class.table.foreign_keys:
if fk.references(target_class.table):
target_fk = fk
break
if not target_fk:
raise RuntimeException("No foreign key found between objects: %s, %s" % source_class.table, target_class.table)
return target_fk
# Get the columns that connect item's table and item's rating association table.
item_rating_assoc_class = getattr(trans.model, '%sRatingAssociation' % self.model_class.__name__)
foreign_key = get_foreign_key(item_rating_assoc_class, self.model_class)
Expand Down

0 comments on commit 0020c7d

Please sign in to comment.