Merge pull request #163 from pagreene/pa-dev
Fix Preassembly
pagreene committed Feb 18, 2021
2 parents fad53b8 + f45cff2 commit 6695d21
Showing 9 changed files with 923 additions and 264 deletions.
1 change: 1 addition & 0 deletions indra_db/client/principal/__init__.py
@@ -1,3 +1,4 @@
from .raw_statements import *
from .pa_statements import *
from .curation import *
from .content import *
172 changes: 172 additions & 0 deletions indra_db/client/principal/pa_statements.py
@@ -0,0 +1,172 @@
__all__ = ["get_pa_stmt_jsons"]

import json
from collections import defaultdict

from sqlalchemy import func, cast, String, null
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.orm import aliased

from indra_db.util.constructors import get_db
from indra_db.client.principal.raw_statements import _fix_evidence


def get_pa_stmt_jsons(clauses=None, with_evidence=True, db=None, limit=1000):
"""Load preassembled Statements from the principal database."""
if db is None:
db = get_db('primary')

if clauses is None:
clauses = []

# Construct the core query.
if with_evidence:
text_ref_cols = [db.Reading.id, db.TextContent.id, db.TextRef.pmid,
db.TextRef.pmcid, db.TextRef.doi, db.TextRef.url,
db.TextRef.pii]
text_ref_types = tuple([str if isinstance(col.type, String) else int
for col in text_ref_cols])
text_ref_cols = tuple([cast(col, String)
if not isinstance(col.type, String) else col
for col in text_ref_cols])
text_ref_labels = ('rid', 'tcid', 'pmid', 'pmcid', 'doi', 'url', 'pii')
core_q = db.session.query(
db.PAStatements.mk_hash.label('mk_hash'),
db.PAStatements.json.label('json'),
func.array_agg(db.RawStatements.json).label("raw_jsons"),
func.array_agg(array(text_ref_cols)).label("text_refs")
).outerjoin(
db.RawUniqueLinks,
db.RawUniqueLinks.pa_stmt_mk_hash == db.PAStatements.mk_hash
).join(
db.RawStatements,
db.RawStatements.id == db.RawUniqueLinks.raw_stmt_id
).outerjoin(
db.Reading,
db.Reading.id == db.RawStatements.reading_id
).outerjoin(
db.TextContent,
db.TextContent.id == db.Reading.text_content_id
).outerjoin(
db.TextRef,
db.TextRef.id == db.TextContent.text_ref_id
)
else:
text_ref_types = None
text_ref_labels = None
core_q = db.session.query(
db.PAStatements.mk_hash.label('mk_hash'),
db.PAStatements.json.label('json'),
null().label('raw_jsons'),
null().label('text_refs')
)
core_q = core_q.filter(
*clauses
).group_by(
db.PAStatements.mk_hash,
db.PAStatements.json
)
if limit:
core_q = core_q.limit(limit)
core_sq = core_q.subquery().alias('core')

# Construct the layer of the query that gathers agent info.
agent_tuple = (cast(db.PAAgents.ag_num, String),
db.PAAgents.db_name,
db.PAAgents.db_id)
at_sq = db.session.query(
core_sq.c.mk_hash,
core_sq.c.json,
core_sq.c.raw_jsons,
core_sq.c.text_refs,
func.array_agg(array(agent_tuple)).label('db_refs')
).filter(
db.PAAgents.stmt_mk_hash == core_sq.c.mk_hash
).group_by(
core_sq.c.mk_hash,
core_sq.c.json,
core_sq.c.raw_jsons,
core_sq.c.text_refs
).subquery().alias('agent_tuples')

# Construct the layer of the query that gathers supports/supported by.
sup_from = aliased(db.PASupportLinks, name='sup_from')
sup_to = aliased(db.PASupportLinks, name='sup_to')
q = db.session.query(
at_sq.c.mk_hash,
at_sq.c.json,
at_sq.c.raw_jsons,
at_sq.c.text_refs,
at_sq.c.db_refs,
func.array_agg(sup_from.supporting_mk_hash).label('supporting_hashes'),
func.array_agg(sup_to.supported_mk_hash).label('supported_hashes')
).outerjoin(
sup_from,
sup_from.supported_mk_hash == at_sq.c.mk_hash
).outerjoin(
sup_to,
sup_to.supporting_mk_hash == at_sq.c.mk_hash
).group_by(
at_sq.c.mk_hash,
at_sq.c.json,
at_sq.c.raw_jsons,
at_sq.c.text_refs,
at_sq.c.db_refs
)

# Run and parse the query.
stmt_jsons = {}
stmts_by_hash = {}
for h, sj, rjs, text_refs, db_refs, supping, supped in q.all():
# Gather the agent refs.
db_ref_dicts = defaultdict(lambda: defaultdict(list))
for ag_num, db_name, db_id in db_refs:
db_ref_dicts[int(ag_num)][db_name].append(db_id)
db_ref_dicts = {k: dict(v) for k, v in db_ref_dicts.items()}

# Clean supping and supped.
supping = [h for h in set(supping) if h is not None]
supped = [h for h in set(supped) if h is not None]

# Parse the JSON bytes into JSON.
stmt_json = json.loads(sj)
if 'supports' not in stmt_json:
stmt_json['supports'] = []
if 'supported_by' not in stmt_json:
stmt_json['supported_by'] = []

# Load the evidence.
if rjs is not None:
for rj, text_ref_values in zip(rjs, text_refs):
raw_json = json.loads(rj)
ev = raw_json['evidence'][0]
if any(v is not None for v in text_ref_values):
tr_dict = {lbl.upper(): None if val == "None" else typ(val)
for lbl, typ, val
in zip(text_ref_labels, text_ref_types,
text_ref_values)}
_fix_evidence(ev, tr_dict.pop('RID'), tr_dict.pop('TCID'),
tr_dict)
if 'evidence' not in stmt_json:
stmt_json['evidence'] = []
stmt_json['evidence'].append(ev)

# Resolve supports/supported-by links, as far as possible.
stmts_by_hash[h] = stmt_json
for supped_h in (h for h in supped if h in stmts_by_hash):
stmt_json['supports'].append(stmts_by_hash[supped_h]['id'])
stmts_by_hash[supped_h]['supported_by'].append(stmt_json['id'])
for supping_h in (h for h in supping if h in stmts_by_hash):
stmt_json['supported_by'].append(stmts_by_hash[supping_h]['id'])
stmts_by_hash[supping_h]['supports'].append(stmt_json['id'])

# Put it together in a dictionary.
result_dict = {
"mk_hash": h,
"stmt": stmt_json,
"db_refs": db_ref_dicts,
"supports_hashes": supping,
"supported_by_hashes": supped
}
stmt_jsons[h] = result_dict
return stmt_jsons
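
As a usage sketch (not part of this diff): the new function is exported through the package __init__, so it can be driven with arbitrary SQLAlchemy clauses against db.PAStatements. The PAStatements.type column and the 'primary' database label below are assumptions about the deployment, not something this patch establishes.

from indra_db.client.principal import get_pa_stmt_jsons
from indra_db.util.constructors import get_db

# Hypothetical usage: restrict to one statement type and keep the result small.
# The PAStatements.type column and the 'primary' label are assumed here.
db = get_db('primary')
results = get_pa_stmt_jsons(
    clauses=[db.PAStatements.type == 'Phosphorylation'],
    with_evidence=True,
    db=db,
    limit=10,
)

# Each value holds the statement JSON plus agent refs and support hashes.
for mk_hash, entry in results.items():
    stmt = entry['stmt']
    print(mk_hash, stmt['type'], len(stmt.get('evidence', [])),
          entry['supports_hashes'], entry['supported_by_hashes'])
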
109 changes: 71 additions & 38 deletions indra_db/client/principal/raw_statements.py
@@ -1,5 +1,5 @@
__all__ = ['get_direct_raw_stmt_jsons_from_agents',
'get_raw_stmt_jsons_from_papers']
__all__ = ['get_raw_stmt_jsons_from_agents', 'get_raw_stmt_jsons_from_papers',
'get_raw_stmt_jsons']

import json
from collections import defaultdict
@@ -17,7 +17,8 @@


@clockit
def get_raw_stmt_jsons_from_papers(id_list, id_type='pmid', db=None):
def get_raw_stmt_jsons_from_papers(id_list, id_type='pmid', db=None,
max_stmts=None, offset=None):
"""Get raw statement jsons for a given list of papers.
Parameters
@@ -81,12 +82,15 @@ def get_raw_stmt_jsons_from_papers(id_list, id_type='pmid', db=None):


@clockit
def get_direct_raw_stmt_jsons_from_agents(agents=None, stmt_type=None, db=None,
max_stmts=None, offset=None):
def get_raw_stmt_jsons_from_agents(agents=None, stmt_type=None, db=None,
max_stmts=None, offset=None):
"""Get Raw statement jsons from a list of agent refs and Statement type."""
if db is None:
db = get_db('primary')

if agents is None:
agents = []

# Turn the agents parameters into an intersection of queries for stmt ids.
entity_queries = []
for role, ag_dbid, ns in agents:
@@ -98,9 +102,11 @@ def get_direct_raw_stmt_jsons_from_agents(agents=None, stmt_type=None, db=None,
ag_dbid = ag_dbid.replace(char, '\%s' % char)

# Generate the query
q = (db.session
.query(db.RawAgents.stmt_id.label('stmt_id'))
.filter(db.RawAgents.db_id.like(ag_dbid)))
q = db.session.query(
db.RawAgents.stmt_id.label('stmt_id')
).filter(
db.RawAgents.db_id.like(ag_dbid)
)

if ns is not None:
q = q.filter(db.RawAgents.db_name.like(ns))
@@ -110,46 +116,64 @@ def get_direct_raw_stmt_jsons_from_agents(agents=None, stmt_type=None, db=None,

entity_queries.append(q)

# Add a constraint for the statement type.
if stmt_type is not None:
q = db.session.query(
db.RawStatements.id.label('stmt_id')
).filter(
db.RawStatements.type == stmt_type
)
entity_queries.append(q)

# Generate the sub-query.
ag_query_al = intersect_all(*entity_queries).alias('intersection')
ag_query = db.session.query(ag_query_al).distinct().subquery('ag_stmt_ids')

# Create a query for the raw statement json
rid_c = db.RawStatements.reading_id.label('rid')
json_q = (db.session.query(db.RawStatements.json, rid_c, ag_query)
.filter(db.RawStatements.id == ag_query.c.stmt_id))
# Get the raw statement JSONs from the database.
res = get_raw_stmt_jsons([db.RawStatements.id == ag_query.c.stmt_id], db=db,
max_stmts=max_stmts, offset=offset)
return res

# Filter by type, if applicable.
if stmt_type is not None:
json_q = json_q.filter(db.RawStatements.type == stmt_type)

# Apply count limits and such.
def get_raw_stmt_jsons(clauses=None, db=None, max_stmts=None, offset=None):
"""Get Raw Statements from the principle database, given arbitrary clauses.
"""
if db is None:
db = get_db('primary')

if clauses is None:
clauses = []

q = db.session.query(
db.RawStatements.id,
db.RawStatements.json,
db.Reading.id,
db.TextContent.id,
db.TextRef
).filter(
*clauses
).outerjoin(
db.Reading,
db.Reading.id == db.RawStatements.reading_id
).outerjoin(
db.TextContent,
db.TextContent.id == db.Reading.text_content_id
).outerjoin(
db.TextRef,
db.TextRef.id == db.TextContent.text_ref_id
)

if max_stmts is not None:
json_q = json_q.limit(max_stmts)
q = q.limit(max_stmts)

if offset is not None:
json_q = json_q.offset(offset)

# Construct final query, that joins with text ref info on the database.
json_q = json_q.subquery('json_content')
ref_q = (db.session
.query(json_q, db.Reading.text_content_id.label('tcid'),
db.TextRef)
.outerjoin(db.Reading, db.Reading.id == json_q.c.rid)
.join(db.TextContent,
db.TextContent.id == db.Reading.text_content_id)
.join(db.TextRef, db.TextRef.id == db.TextContent.text_ref_id))

# Process the jsons, filling text ref info.
q = q.offset(offset)

raw_stmt_jsons = {}
for json_bytes, rid, sid, tcid, tr in ref_q.all():
for sid, json_bytes, rid, tcid, tr in q.all():
raw_j = json.loads(json_bytes)
ev = raw_j['evidence'][0]
ev['text_refs'] = tr.get_ref_dict()
ev['text_refs']['TCID'] = tcid
ev['text_refs']['READING_ID'] = rid
if tr.pmid:
ev['pmid'] = tr.pmid

if rid is not None:
_fix_evidence(raw_j['evidence'][0], rid, tcid, tr.get_ref_dict())
raw_stmt_jsons[sid] = raw_j

return raw_stmt_jsons
@@ -170,3 +194,12 @@ def _get_id_col(tr, id_type):
raise ValueError("Invalid id_type: %s" % id_type)
return id_attr


def _fix_evidence(ev, rid, tcid, tr_dict):
ev['text_refs'] = tr_dict
ev['text_refs']['TCID'] = tcid
ev['text_refs']['READING_ID'] = rid
if 'PMID' in tr_dict:
ev['pmid'] = tr_dict['PMID']
return
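
For reference, a minimal sketch of how the renamed get_raw_stmt_jsons_from_agents and the new generic get_raw_stmt_jsons might be called. The role and namespace values ('SUBJECT', 'NAME') and the reading_id filter are illustrative placeholders, not values fixed by this patch.

from indra_db.client.principal import (get_raw_stmt_jsons,
                                       get_raw_stmt_jsons_from_agents)
from indra_db.util.constructors import get_db

db = get_db('primary')

# Agent-based lookup: each agent is a (role, db_id, namespace) triple, any
# element of which may be None. The concrete values are only examples.
raw_by_agent = get_raw_stmt_jsons_from_agents(
    agents=[('SUBJECT', 'MEK', 'NAME')],
    stmt_type='Phosphorylation',
    db=db,
    max_stmts=50,
)

# Generic lookup with arbitrary clauses, e.g. everything from one reading run.
# The reading_id value is a placeholder.
raw_by_clause = get_raw_stmt_jsons(
    [db.RawStatements.reading_id == 12345], db=db, max_stmts=50)
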

16 changes: 11 additions & 5 deletions indra_db/preassembly/preassemble_db.py
@@ -151,7 +151,7 @@ def _get_cache_path(self, file_name):

def _init_cache(self, continuing):
if self.s3_cache is None:
return
return datetime.utcnow()

import boto3
s3 = boto3.client('s3')
@@ -449,7 +449,9 @@ def create_corpus(self, db, continuing=False):
db.PAStatements.mk_hash.in_(hash_list[in_si:in_ei])
)
inner_batch = [_stmt_from_json(sj) for sj, in inner_sj_q.all()]
split_idx = len(inner_batch)
# NOTE: deliberately subtracting 1 because the INDRA
# implementation is weird.
split_idx = len(inner_batch) - 1
full_list = inner_batch + outer_batch
self._log(f'Getting support between outer batch {outer_idx}/'
f'{len(idx_batches)-1} and inner batch {inner_idx}/'
@@ -553,7 +555,7 @@ def _supplement_support(self, db, new_hashes, start_time, continuing=False):
db.PAStatements.json,
db.PAStatements.mk_hash.in_(new_hashes[out_s:out_e])
)
npa_batch = [_stmt_from_json(s_json) for s_json in npa_json_q.all()]
npa_batch = [_stmt_from_json(s_json) for s_json, in npa_json_q.all()]

# Compare internally
self._log(f"Getting support for new pa batch {outer_idx}/"
@@ -569,7 +571,9 @@ def _supplement_support(self, db, new_hashes, start_time, continuing=False):
)
other_npa_batch = [_stmt_from_json(sj)
for sj, in other_npa_q.all()]
split_idx = len(npa_batch)
# NOTE: deliberately subtracting 1 because the INDRA
# implementation is weird.
split_idx = len(npa_batch) - 1
full_list = npa_batch + other_npa_batch
self._log(f"Comparing outer batch {outer_idx}/"
f"{len(idx_batches)-1} to inner batch {in_idx}/"
@@ -589,7 +593,9 @@ def _supplement_support(self, db, new_hashes, start_time, continuing=False):
for opa_idx, opa_json_batch in opa_json_iter:
opa_batch = [_stmt_from_json(s_json)
for s_json, in opa_json_batch]
split_idx = len(npa_batch)
# NOTE: deliberately subtracting 1 because the INDRA
# implementation is weird.
split_idx = len(npa_batch) - 1
full_list = npa_batch + opa_batch
self._log(f"Comparing new batch {outer_idx}/"
f"{len(idx_batches)-1} to batch {opa_idx} of old pa "
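
The three split_idx = len(...) - 1 changes above all rely on the same convention. The sketch below illustrates the assumption, inferred from the patch rather than from INDRA's documentation, that split_idx is the inclusive index of the last statement in the first sub-list, so only cross-batch pairs get compared.

def cross_split_pairs(full_list, split_idx):
    """Yield index pairs (i, j) that straddle an inclusive split index."""
    n = len(full_list)
    for i in range(n):
        for j in range(i + 1, n):
            # One index at or before split_idx, the other strictly after it.
            if i <= split_idx < j:
                yield i, j


inner_batch = ['new1', 'new2', 'new3']
outer_batch = ['old1', 'old2']
full_list = inner_batch + outer_batch

# With the corrected value len(inner_batch) - 1 == 2, exactly the inner-vs-outer
# pairs are produced. With len(inner_batch) == 3, all pairs against 'old1' would
# be dropped and ('old1', 'old2') would wrongly be compared instead.
pairs = [(full_list[i], full_list[j])
         for i, j in cross_split_pairs(full_list, len(inner_batch) - 1)]
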