Skip to content

Commit

Permalink
Merge pull request #174 from pagreene/paper-dev
Browse files Browse the repository at this point in the history
Fix paper search feature
  • Loading branch information
pagreene committed Jun 4, 2021
2 parents 41353ab + 518c3a2 commit f2bf56d
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 72 deletions.
34 changes: 20 additions & 14 deletions indra_db/client/readonly/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def agg(self, ro, with_hashes=True, sort_by='ev_count'):
def run(self):
raise NotImplementedError

def print(self):
print(self.agg_q)
def __str__(self):
    """Render the aggregation query as a complete SQL string.

    Parameters are inlined (``literal_binds``) so the result can be
    read or pasted into a SQL console directly.
    """
    compiled = self.agg_q.selectable.compile(
        compile_kwargs={'literal_binds': True}
    )
    return str(compiled)


class InteractionSQL(AgentJsonSQL):
Expand Down Expand Up @@ -513,7 +515,7 @@ class must be provided.
# Put it all together.
selection = select(cols).select_from(stmts_q)
if self._print_only:
print(selection)
print(selection.compile(compile_kwargs={'literal_binds': True}))
return

logger.debug(f"Executing query (get_statements):\n{selection}")
Expand Down Expand Up @@ -663,7 +665,7 @@ def get_hashes(self, ro=None, limit=None, offset=None, sort_by='ev_count',
q = mk_hashes_q

if self._print_only:
print(q)
print(q.selectable.compile(compile_kwargs={'literal_binds': True}))
return

# Make the query, and package the results.
Expand Down Expand Up @@ -828,7 +830,7 @@ def _run_meta_sql(self, ms, ro, limit, offset, sort_by, with_hashes=None):
order_params = ms.agg(ro, **kwargs)
ms = self._apply_limits(ms, order_params, limit, offset)
if self._print_only:
ms.print()
print(ms)
return
return ms.run()

Expand Down Expand Up @@ -1801,27 +1803,31 @@ def _get_table(self, ro):

def _get_conditions(self, ro):
conditions = []
id_groups = defaultdict(set)
for id_type, paper_id in self.paper_list:
if paper_id is None:
logger.warning("Got paper with id None.")
continue

# TODO: upgrade this to use new id formatting. This will require
# updating the ReadingRefLink table in the readonly build.
if id_type in ['trid', 'tcid']:
id_groups[id_type].add(int(paper_id))
else:
id_groups[id_type].add(str(paper_id))

for id_type, id_list in id_groups.items():
tbl_attr = getattr(ro.ReadingRefLink, id_type)
if not self._inverted:
if id_type in ['trid', 'tcid']:
conditions.append(tbl_attr == int(paper_id))
conditions.append(tbl_attr.in_(id_list))
else:
constraint = ro.ReadingRefLink.has_ref(id_type,
[str(paper_id)])
constraint = ro.ReadingRefLink.has_ref(id_type, id_list)
conditions.append(constraint)
else:
if id_type in ['trid', 'tcid']:
conditions.append(tbl_attr != int(paper_id))
conditions.append(tbl_attr.notin_(id_list))
else:
# Note that this is a highly non-optimized approach.
conditions.append(tbl_attr.notlike(str(paper_id)))
constraint = ro.ReadingRefLink.not_has_ref(id_type, id_list)
conditions.append(constraint)
return conditions

def _get_hash_query(self, ro, inject_queries=None):
Expand Down Expand Up @@ -2305,7 +2311,7 @@ def _run_meta_sql(self, ms, ro, limit, offset, sort_by, with_hashes=None):
order_params = ms.agg(ro, **kwargs)
ms = self._apply_limits(ms, order_params, limit, offset)
if self._print_only:
ms.print()
print(ms)
return
return ms.run()

Expand Down
10 changes: 9 additions & 1 deletion indra_db/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,15 @@ def grab_session(self):
raise IndraDbException("Failed to grab session.")
if self.__protected:
def no_flush(*a, **k):
logger.error("Write not allowed!")
# If further errors occur as a result, please first think
# carefully whether you really want to write, and if you
# do, instantiate your database handle with
# "protected=False". Note that you should NOT be writing to
# readonly unless you are doing the initial load, or are
# testing something on a dev database. Do NOT write to a
# stable deployment.
logger.info("Session flush attempted. Write not allowed in "
"protected mode.")
self.session.flush = no_flush

def get_tables(self):
Expand Down
167 changes: 118 additions & 49 deletions indra_db/schemas/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from termcolor import colored
from psycopg2.errors import DuplicateTable
from sqlalchemy import inspect, Column, BigInteger, tuple_, and_, \
UniqueConstraint
UniqueConstraint, or_
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm.attributes import InstrumentedAttribute

Expand Down Expand Up @@ -260,6 +260,8 @@ class IndraDBRefTable:
doi_ns = NotImplemented
doi_id = NotImplemented

# PMID Processing methods.

@staticmethod
def process_pmid(pmid):
if not pmid:
Expand All @@ -270,6 +272,46 @@ def process_pmid(pmid):

return pmid, int(pmid)

@classmethod
def _get_pmid_lookups(cls, pmid_list, filter_ids=False):
    """Convert a list of pmid strings into a set of integer pmid nums.

    Parameters
    ----------
    pmid_list : iterable of str
        The pmids to process (each run through ``cls.process_pmid``).
    filter_ids : bool
        If True, invalid pmids are logged and skipped; if False, an
        invalid pmid raises a ValueError.

    Returns
    -------
    set of int
        The numeric pmid values usable in SQL comparisons.
    """
    # Process the ID list.
    pmid_num_set = set()
    for pmid in pmid_list:
        _, pmid_num = cls.process_pmid(pmid)
        if pmid_num is None:
            if filter_ids:
                logger.warning('"%s" is not a valid pmid. Skipping.'
                               % pmid)
                continue
            else:
                # BUG FIX: the ValueError was previously constructed but
                # never raised, so invalid pmids fell through and None
                # was silently added to the lookup set.
                raise ValueError('"%s" is not a valid pmid.' % pmid)
        pmid_num_set.add(pmid_num)
    return pmid_num_set

@classmethod
def pmid_in(cls, pmid_list, filter_ids=False):
    """Get sqlalchemy clauses for entries IN a list of pmids."""
    nums = cls._get_pmid_lookups(pmid_list, filter_ids)

    # A single value compares with plain equality; otherwise use IN.
    if len(nums) == 1:
        return cls.pmid_num == nums.pop()
    return cls.pmid_num.in_(nums)

@classmethod
def pmid_notin(cls, pmid_list, filter_ids=False):
    """Get sqlalchemy clauses for entries NOT IN a list of pmids."""
    nums = cls._get_pmid_lookups(pmid_list, filter_ids)

    # A single value compares with inequality; otherwise use NOT IN.
    if len(nums) == 1:
        return cls.pmid_num != nums.pop()
    return cls.pmid_num.notin_(nums)

# PMCID Processing methods

@staticmethod
def process_pmcid(pmcid):
if not pmcid:
Expand All @@ -292,6 +334,47 @@ def process_pmcid(pmcid):

return pmcid, int(pmcid[3:]), version_number

@classmethod
def _get_pmcid_lookups(cls, pmcid_list, filter_ids=False):
    """Convert a list of pmcids into the set of numeric lookup values.

    Invalid pmcids are logged and skipped when `filter_ids` is True,
    and raise a ValueError otherwise.
    """
    num_set = set()
    for pmcid in pmcid_list:
        _, num, _ = cls.process_pmcid(pmcid)
        if num:
            num_set.add(num)
        elif filter_ids:
            logger.warning('"%s" does not look like a valid '
                           'pmcid. Skipping.' % pmcid)
        else:
            raise ValueError('"%s" is not a valid pmcid.' % pmcid)
    return num_set

@classmethod
def pmcid_in(cls, pmcid_list, filter_ids=False):
    """Get the sqlalchemy clauses for entries IN a list of pmcids."""
    nums = cls._get_pmcid_lookups(pmcid_list, filter_ids)

    # Use equality for a single value, IN for several.
    if len(nums) == 1:
        return cls.pmcid_num == nums.pop()
    return cls.pmcid_num.in_(nums)

@classmethod
def pmcid_notin(cls, pmcid_list, filter_ids=False):
    """Get the sqlalchemy clause for entries NOT IN a list of pmcids."""
    nums = cls._get_pmcid_lookups(pmcid_list, filter_ids)

    # Use inequality for a single value, NOT IN for several.
    if len(nums) == 1:
        return cls.pmcid_num != nums.pop()
    return cls.pmcid_num.notin_(nums)

# DOI Processing methods

@staticmethod
def process_doi(doi):
# Check for invalid DOIs
Expand Down Expand Up @@ -321,53 +404,7 @@ def process_doi(doi):
return doi, namespace, group_id

@classmethod
def pmid_in(cls, pmid_list, filter_ids=False):
    """Get sqlalchemy clauses for a list of pmids.

    Parameters
    ----------
    pmid_list : iterable of str
        The pmids to match against.
    filter_ids : bool
        If True, invalid pmids are logged and skipped; if False, an
        invalid pmid raises a ValueError.
    """
    # Process the ID list.
    pmid_num_set = set()
    for pmid in pmid_list:
        _, pmid_num = cls.process_pmid(pmid)
        if pmid_num is None:
            if filter_ids:
                logger.warning('"%s" is not a valid pmid. Skipping.'
                               % pmid)
                continue
            else:
                # BUG FIX: the ValueError was previously constructed but
                # never raised, so None slipped into the set below.
                raise ValueError('"%s" is not a valid pmid.' % pmid)
        pmid_num_set.add(pmid_num)

    # Return the constraint: equality for one value, IN for several.
    if len(pmid_num_set) == 1:
        return cls.pmid_num == pmid_num_set.pop()
    else:
        return cls.pmid_num.in_(pmid_num_set)

@classmethod
def pmcid_in(cls, pmcid_list, filter_ids=False):
    """Get the sqlalchemy clauses for a list of pmcids."""
    # Gather the numeric portion of each pmcid.
    num_set = set()
    for pmcid in pmcid_list:
        _, num, _ = cls.process_pmcid(pmcid)
        if num:
            num_set.add(num)
        elif filter_ids:
            logger.warning('"%s" does not look like a valid '
                           'pmcid. Skipping.' % pmcid)
        else:
            raise ValueError('"%s" is not a valid pmcid.' % pmcid)

    # Build the constraint: equality for one value, IN for several.
    if len(num_set) == 1:
        return cls.pmcid_num == num_set.pop()
    return cls.pmcid_num.in_(num_set)

@classmethod
def doi_in(cls, doi_list, filter_ids=False):
"""Get clause for looking up a list of dois."""
def _get_doi_lookups(cls, doi_list, filter_ids=False):
# Parse the DOIs in the list.
doi_tuple_set = set()
for doi in doi_list:
Expand All @@ -381,6 +418,12 @@ def doi_in(cls, doi_list, filter_ids=False):
raise ValueError('"%s" is not a valid doi.' % doi)
else:
doi_tuple_set.add((doi_ns, doi_id))
return doi_tuple_set

@classmethod
def doi_in(cls, doi_list, filter_ids=False):
"""Get clause for looking up entities IN a list of dois."""
doi_tuple_set = cls._get_doi_lookups(doi_list, filter_ids)

# Return the constraint
if len(doi_tuple_set) == 1:
Expand All @@ -389,9 +432,21 @@ def doi_in(cls, doi_list, filter_ids=False):
else:
return tuple_(cls.doi_ns, cls.doi_id).in_(doi_tuple_set)

@classmethod
def doi_notin(cls, doi_list, filter_ids=False):
    """Get clause for looking up entities NOT IN a list of dois."""
    pairs = cls._get_doi_lookups(doi_list, filter_ids)

    # Several pairs: exclude with a composite NOT IN.
    if len(pairs) != 1:
        return tuple_(cls.doi_ns, cls.doi_id).notin_(pairs)

    # One pair: a row is excluded iff it differs in namespace OR id
    # (De Morgan of "ns equal AND id equal").
    ns, group_id = pairs.pop()
    return or_(cls.doi_ns != ns, cls.doi_id != group_id)

@classmethod
def has_ref(cls, id_type, id_list, filter_ids=False):
"""Get the appropriate constraint for the given ID list."""
"""Get clause for entries IN the given ID list."""
id_type = id_type.lower()
if id_type == 'pmid':
return cls.pmid_in(id_list, filter_ids)
Expand All @@ -402,7 +457,21 @@ def has_ref(cls, id_type, id_list, filter_ids=False):
else:
return getattr(cls, id_type).in_(id_list)

@classmethod
def not_has_ref(cls, id_type, id_list, filter_ids=False):
    """Get clause for entries NOT IN the given ID list."""
    key = id_type.lower()

    # pmid/pmcid/doi ids need normalization before comparison; all
    # other id types map directly onto a column of the same name.
    special_handlers = {
        'pmid': cls.pmid_notin,
        'pmcid': cls.pmcid_notin,
        'doi': cls.doi_notin,
    }
    handler = special_handlers.get(key)
    if handler is not None:
        return handler(id_list, filter_ids)
    return getattr(cls, key).notin_(id_list)

def get_ref_dict(self):
"""Return the refs as a dictionary keyed by type."""
ref_dict = {}
for ref in self._ref_cols:
val = getattr(self, ref, None)
Expand Down
16 changes: 8 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ def main():
print("Installing `indra_db` Packages:\n", '\n'.join(packages))
extras_require = {'test': ['nose', 'coverage', 'python-coveralls',
'nose-timer'],
'rest_api': ['flask', 'flask-jwt-extended', 'flask-cors',
'flask-compress']}
'service': ['flask', 'flask-jwt-extended', 'flask-cors',
'flask-compress'],
'cli': ['click', 'boto3'],
'copy': ['pgcopy']}
extras_require['all'] = list({dep for deps in extras_require.values()
for dep in deps})
setup(name='indra_db',
Expand All @@ -19,15 +21,13 @@ def main():
author_email='patrick_greene@hms.harvard.edu',
packages=packages,
include_package_data=True,
install_requires=['indra', 'boto3', 'sqlalchemy', 'psycopg2',
'pgcopy', 'matplotlib', 'nltk', 'reportlab',
'cachetools', 'termcolor', 'click'],
install_requires=['sqlalchemy', 'psycopg2'],
extras_require=extras_require,
entry_points="""
[console_scripts]
indra_db=indra_db.cli:main
indra_db_rest=rest_api.cli:main
indra_db_benchmarker=benchmarker.cli:main
indra-db=indra_db.cli:main
indra-db-service=rest_api.cli:main
indra-db-benchmarker=benchmarker.cli:main
""")


Expand Down

0 comments on commit f2bf56d

Please sign in to comment.