Merge pull request #106 from pagreene/xdd-integration
Integrate Xdd with its own manager
pagreene committed May 19, 2020
2 parents 34753f7 + 26635eb commit 4199051
Showing 8 changed files with 325 additions and 135 deletions.
185 changes: 185 additions & 0 deletions indra_db/managers/xdd_manager.py
@@ -0,0 +1,185 @@
import json
import boto3
import logging
from collections import defaultdict

from indra.statements import Statement
from indra_db.reading.read_db import DatabaseStatementData, generate_reading_id
from indra_db.util import S3Path, get_db, insert_raw_agents

logger = logging.getLogger(__name__)


class XddManager:
    bucket = S3Path(bucket='hms-uw-collaboration')
    reader_versions = {'REACH': '1.3.3-61059a-biores-e9ee36',
                       'SPARSER': 'February2020-linux'}
    indra_version = '1.16.0-c439fdbc936f4eac00cafd559927d7ee06c492e8'

    def __init__(self):
        self.groups = None
        self.statements = None
        self.text_content = None

    def load_groups(self, db):
        logger.info("Finding groups that have not been handled yet.")
        s3 = boto3.client('s3')
        groups = self.bucket.list_prefixes(s3)
        previous_groups = {s for s, in db.select_all(db.XddUpdates.day_str)}

        self.groups = [group for group in groups
                       if group.key[:-1] not in previous_groups]
        return

    def load_statements(self, db):
        logger.info("Loading statements.")
        s3 = boto3.client('s3')
        self.statements = defaultdict(lambda: defaultdict(list))
        self.text_content = {}
        for group in self.groups:
            logger.info(f"Processing {group.key}")
            file_pair_dict = _get_file_pairs_from_group(s3, group)
            for run_id, (bibs, stmts) in file_pair_dict.items():
                logger.info(f"Loading {run_id}")
                doi_lookup = {bib['_xddid']: bib['identifier'][0]['id'].upper()
                              for bib in bibs if 'identifier' in bib}
                pub_lookup = {bib['_xddid']: bib['publisher'] for bib in bibs}
                dois = {doi for doi in doi_lookup.values()}
                trids = _get_trids_from_dois(db, dois)

                for sj in stmts:
                    ev = sj['evidence'][0]
                    xddid = ev['text_refs']['CONTENT_ID']
                    ev.pop('pmid', None)
                    if xddid not in doi_lookup:
                        logger.warning("Skipping statement because bib "
                                       "lacked a DOI.")
                        continue
                    ev['text_refs']['DOI'] = doi_lookup[xddid]

                    trid = trids[doi_lookup[xddid]]
                    ev['text_refs']['TRID'] = trid
                    ev['text_refs']['XDD_RUN_ID'] = run_id
                    ev['text_refs']['XDD_GROUP_ID'] = group.key

                    self.statements[trid][ev['text_refs']['READER']].append(sj)
                    if trid not in self.text_content:
                        self.text_content[trid] = \
                            (trid, 'xdd', 'xdd', 'fulltext',
                             pub_lookup[xddid] == 'bioRxiv')
        return

    def dump_statements(self, db):
        tc_rows = set(self.text_content.values())
        tc_cols = ('text_ref_id', 'source', 'format', 'text_type', 'preprint')
        logger.info(f"Dumping {len(tc_rows)} text content.")
        db.copy_lazy('text_content', tc_rows, tc_cols)

        # Look up tcids for newly entered content.
        tcids = db.select_all(
            [db.TextContent.text_ref_id, db.TextContent.id],
            db.TextContent.text_ref_id.in_(self.statements.keys()),
            db.TextContent.source == 'xdd'
        )
        tcid_lookup = {trid: tcid for trid, tcid in tcids}

        # Compile readings and statements into rows.
        r_rows = set()
        r_cols = ('id', 'text_content_id', 'reader', 'reader_version',
                  'format', 'batch_id')
        s_rows = set()
        rd_batch_id = db.make_copy_batch_id()
        stmt_batch_id = db.make_copy_batch_id()
        stmts = []
        for trid, trid_set in self.statements.items():
            for reader, stmt_list in trid_set.items():
                tcid = tcid_lookup[trid]
                reader_version = self.reader_versions[reader.upper()]
                reading_id = generate_reading_id(tcid, reader, reader_version)
                r_rows.add((reading_id, tcid, reader.upper(), reader_version,
                            'xdd', rd_batch_id))
                for sj in stmt_list:
                    stmt = Statement._from_json(sj)
                    stmts.append(stmt)
                    sd = DatabaseStatementData(
                        stmt,
                        reading_id,
                        indra_version=self.indra_version
                    )
                    s_rows.add(sd.make_tuple(stmt_batch_id))

        logger.info(f"Dumping {len(r_rows)} readings.")
        db.copy_lazy('reading', r_rows, r_cols, commit=False)

        logger.info(f"Dumping {len(s_rows)} raw statements.")
        db.copy_lazy('raw_statements', s_rows,
                     DatabaseStatementData.get_cols(), commit=False)
        if len(stmts):
            insert_raw_agents(db, stmt_batch_id, stmts, verbose=False,
                              commit=False)

        update_rows = [(json.dumps(self.reader_versions), self.indra_version,
                        group.key[:-1])
                       for group in self.groups]
        db.copy('xdd_updates', update_rows,
                ('reader_versions', 'indra_version', 'day_str'))
        return

    def run(self, db):
        self.load_groups(db)
        self.load_statements(db)
        self.dump_statements(db)


def _get_file_pairs_from_group(s3, group: S3Path):
    files = group.list_objects(s3)
    file_pairs = defaultdict(dict)
    for file_path in files:
        run_id, file_suffix = file_path.key.split('_')
        file_type = file_suffix.split('.')[0]
        try:
            file_obj = s3.get_object(**file_path.kw())
            file_json = json.loads(file_obj['Body'].read())
            file_pairs[run_id][file_type] = file_json
        except Exception as e:
            logger.error(f"Failed to load {file_path}")
            logger.exception(e)
            if run_id in file_pairs:
                del file_pairs[run_id]

    ret = {}
    for run_id, files in file_pairs.items():
        if len(files) != 2 or 'bib' not in files or 'stmts' not in files:
            logger.warning(f"Run {run_id} does not have both 'bib' and "
                           f"'stmts' in files: {files.keys()}. Skipping.")
            continue
        ret[run_id] = (files['bib'], files['stmts'])
    return ret


def _get_trids_from_dois(db, dois):
    # Get current relevant text refs (if any)
    tr_list = db.select_all(db.TextRef, db.TextRef.doi_in(dois))

    # Add new dois (if any)
    new_dois = set(dois) - {tr.doi.upper() for tr in tr_list}
    if new_dois:
        new_trs = [db.TextRef.new(doi=doi) for doi in new_dois]
        logger.info(f"Adding {len(new_trs)} new text refs.")
        db.session.add_all(new_trs)
        db.session.commit()
        tr_list += new_trs

    # Make the full mapping table.
    return {tr.doi.upper(): tr.id for tr in tr_list}


def main():
    db = get_db('primary')

    m = XddManager()
    m.run(db)


if __name__ == '__main__':
    main()
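
A note on the bucket layout: the key parsing in _get_file_pairs_from_group implies that each day-group prefix holds paired files named <run_id>_bib.json and <run_id>_stmts.json. The sketch below is illustrative only; the prefix and run IDs are invented, not taken from the bucket.

# Hypothetical layout implied by the key parsing above (names invented
# for illustration):
#
#   2020-05-18/            <- one group per day; '2020-05-18' is the day_str
#       001_bib.json       <- bibliography records for run 001
#       001_stmts.json     <- statement JSON for run 001
#       002_bib.json
#       002_stmts.json
#
# For the key '2020-05-18/001_bib.json',
#   file_path.key.split('_')  ->  ['2020-05-18/001', 'bib.json']
# so run_id is '2020-05-18/001' and file_type is 'bib'; a run is kept only
# if both its 'bib' and 'stmts' files load successfully.
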
11 changes: 9 additions & 2 deletions indra_db/reading/read_db.py
@@ -149,12 +149,19 @@ class DatabaseStatementData(object):
     reading_id : int or None
         The id number of the entry in the `readings` table of the database.
         None if no such id is available.
+    indra_version : str or None
+        Override the default indra version, which is the version of indra
+        currently installed.
     """
-    def __init__(self, statement, reading_id=None, db_info_id=None):
+    def __init__(self, statement, reading_id=None, db_info_id=None,
+                 indra_version=None):
         self.reading_id = reading_id
         self.db_info_id = db_info_id
         self.statement = statement
-        self.indra_version = get_indra_version()
+        if indra_version is None:
+            self.indra_version = get_indra_version()
+        else:
+            self.indra_version = indra_version
         self.__text_patt = re.compile('[\W_]+')
         return
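
This override is used exactly this way by XddManager.dump_statements above; a minimal sketch (stmt and reading_id are placeholders):

# Sketch: record a pinned indra version with the statement rather than the
# version currently installed (stmt and reading_id are placeholders).
sd = DatabaseStatementData(stmt, reading_id,
                           indra_version=XddManager.indra_version)
row = sd.make_tuple(batch_id)  # batch_id from db.make_copy_batch_id()
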

35 changes: 33 additions & 2 deletions indra_db/schemas/principal_schema.py
@@ -5,7 +5,7 @@
 from sqlalchemy import Column, Integer, String, UniqueConstraint, ForeignKey, \
     Boolean, DateTime, func, BigInteger, or_, tuple_
 from sqlalchemy.orm import relationship
-from sqlalchemy.dialects.postgresql import BYTEA, INET
+from sqlalchemy.dialects.postgresql import BYTEA, INET, JSONB
 
 from indra_db.schemas.mixins import IndraDBTable
 from indra_db.schemas.indexes import StringIndex, BtreeIndex
@@ -60,6 +60,26 @@ class TextRef(Base, IndraDBTable):
         UniqueConstraint('pmcid', 'doi', name='pmcid-doi')
     )
 
+    def __repr__(self):
+        terms = [f'id={self.id}']
+        for col in ['pmid', 'pmcid', 'doi', 'pii', 'url', 'manuscript_id']:
+            if getattr(self, col) is not None:
+                terms.append(f'{col}={getattr(self, col)}')
+            if len(terms) > 2:
+                break
+        return f'{self.__class__.__name__}({", ".join(terms)})'
+
+    @classmethod
+    def new(cls, pmid=None, pmcid=None, doi=None, pii=None, url=None,
+            manuscript_id=None):
+        pmid, pmid_num = cls.process_pmid(pmid)
+        pmcid, pmcid_num, pmcid_version = cls.process_pmcid(pmcid)
+        doi, doi_ns, doi_id = cls.process_doi(doi)
+        return cls(pmid=pmid, pmid_num=pmid_num, pmcid=pmcid,
+                   pmcid_num=pmcid_num, pmcid_version=pmcid_version,
+                   doi=doi, doi_ns=doi_ns, doi_id=doi_id, pii=pii, url=url,
+                   manuscript_id=manuscript_id)
+
     @staticmethod
     def process_pmid(pmid):
         if not pmid:
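
A quick sketch of the two new TextRef helpers together (the DOI is invented, and the exact string stored depends on process_doi, which is not shown in this diff):

tr = TextRef.new(doi='10.1101/2020.05.18.101113')  # invented DOI
# Before the ref is flushed to the database, id is still None, so the
# __repr__ above would print something like:
#   TextRef(id=None, doi=10.1101/2020.05.18.101113)
print(tr)
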
@@ -226,9 +246,10 @@ class TextContent(Base, IndraDBTable):
     source = Column(String(250), nullable=False)
     format = Column(String(250), nullable=False)
     text_type = Column(String(250), nullable=False)
-    content = Column(BYTEA, nullable=False)
+    content = Column(BYTEA)
     insert_date = Column(DateTime, default=func.now())
     last_updated = Column(DateTime, onupdate=func.now())
+    preprint = Column(Boolean)
     __table_args__ = (
         UniqueConstraint('text_ref_id', 'source', 'format',
                          'text_type', name='content-uniqueness'),
@@ -269,6 +290,16 @@ class ReadingUpdates(Base, IndraDBTable):
     latest_datetime = Column(DateTime, nullable=False)
 table_dict[ReadingUpdates.__tablename__] = ReadingUpdates
 
+class XddUpdates(Base, IndraDBTable):
+    __tablename__ = 'xdd_updates'
+    _always_disp = ['day_str']
+    id = Column(Integer, primary_key=True)
+    reader_versions = Column(JSONB)
+    indra_version = Column(String)
+    day_str = Column(String, nullable=False, unique=True)
+    processed_date = Column(DateTime, default=func.now())
+table_dict[XddUpdates.__tablename__] = XddUpdates
+
 class DBInfo(Base, IndraDBTable):
     __tablename__ = 'db_info'
     _always_disp = ['id', 'db_name', 'source_api']
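
The unique day_str column is what makes the update idempotent: XddManager.load_groups skips any group whose day string already has a row here, and dump_statements inserts one row per group it processed. A condensed sketch of that gating (day strings invented):

# Sketch of the idempotency check in XddManager.load_groups:
previous_groups = {s for s, in db.select_all(db.XddUpdates.day_str)}
# e.g. previous_groups == {'2020-05-17'}
groups = [g for g in XddManager.bucket.list_prefixes(s3)
          if g.key[:-1] not in previous_groups]  # '2020-05-18/' survives
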
43 changes: 43 additions & 0 deletions indra_db/tests/test_xdd_manager.py
@@ -0,0 +1,43 @@
import json
import boto3
import random

from indra_db.tests.util import get_temp_db
from indra_db.managers.xdd_manager import XddManager


def test_dump():
    db = get_temp_db(clear=True)
    m = XddManager()

    # Enter "old" DOIs
    s3 = boto3.client('s3')
    res = s3.list_objects_v2(**m.bucket.kw())
    dois = set()
    for ref in res['Contents']:
        key = ref['Key']
        if 'bib' not in key:
            continue
        try:
            obj = s3.get_object(Key=key, **m.bucket.kw())
        except Exception:
            print('ack')
            continue
        bibs = json.loads(obj['Body'].read())
        dois |= {bib['identifier'][0]['id'] for bib in bibs
                 if 'identifier' in bib}
    sample_dois = random.sample(dois, len(dois)//2)
    new_trs = [db.TextRef.new(doi=doi) for doi in sample_dois]
    print(f"Adding {len(new_trs)} 'old' text refs.")
    db.session.add_all(new_trs)
    db.session.commit()

    # Run the update.
    m.run(db)

    # Check the result.
    assert db.select_all(db.TextRef)
    assert db.select_all(db.TextContent)
    assert db.select_all(db.Reading)
    assert db.select_all(db.RawStatements)
    assert db.select_all(db.RawAgents)
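
Note that this test talks to the live hms-uw-collaboration bucket and a temporary database. Assuming a standard pytest setup (the runner is not specified in this diff), it could be invoked as:

# Run from the repository root; pytest is an assumption, the project may
# use another runner such as nose:
#   pytest indra_db/tests/test_xdd_manager.py -k test_dump
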
