Skip to content
This repository has been archived by the owner on Aug 13, 2021. It is now read-only.

Commit

Permalink
[326] NiH: impute base id (#337)
Browse files Browse the repository at this point in the history
* imputing base id

* added missing import

* Using starmap and cleared up logic if no ids found

* migrated to starmap

* moved to chain and starmap

* fixed test
  • Loading branch information
Joel Klinger committed Nov 24, 2020
1 parent a1bc3f4 commit aee5ab9
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 0 deletions.
1 change: 1 addition & 0 deletions nesta/core/orms/nih_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Projects(Base):
application_type = Column(INTEGER)
arra_funded = Column(VARCHAR(1))
award_notice_date = Column(DATETIME)
base_core_project_num = Column(VARCHAR(50), index=True)
budget_start = Column(DATETIME)
budget_end = Column(DATETIME)
cfda_code = Column(TEXT)
Expand Down
62 changes: 62 additions & 0 deletions nesta/core/routines/datasets/nih/nih_impute_base_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
'''
Impute Base ID
==============
What NiH don't tell you is that the `core_project_num` field
has a base project number, which is effectively the actual
core project number. This is essential for aggregating projects,
otherwise it will appear that many duplicates exist in the data.
This task imputes these values using a simple regex of the form:
{BASE_ID}-{an integer}-{an integer}-{an integer}
Any `core_project_num` failing this regex is used unchanged as the base ID.
'''


import logging
import luigi
from datetime import datetime as dt
from multiprocessing.dummy import Pool as ThreadPool
from itertools import chain


from nesta.core.luigihacks.mysqldb import make_mysql_target
from nesta.packages.nih.impute_base_id import retrieve_id_ranges
from nesta.packages.nih.impute_base_id import impute_base_id_thread


class ImputeBaseIDTask(luigi.Task):
    '''Impute the base ID using a regex of the form

        {BASE_ID}-{an integer}-{an integer}-{an integer}

    The work is split into ranges of project IDs, which are
    processed concurrently on a pool of threads.

    Args:
        date (datetime): Date stamp.
        test (bool): Running in test mode?
    '''
    date = luigi.DateParameter()
    test = luigi.BoolParameter()

    def output(self):
        '''Return the target produced by `make_mysql_target` for this task.'''
        return make_mysql_target(self)

    def run(self):
        '''Impute the base ID over every ID range, then touch the output.'''
        database = 'dev' if self.test else 'production'
        id_ranges = retrieve_id_ranges(database)
        # One (from_id, to_id, database) argument tuple per thread call
        thread_args = [(from_id, to_id, database)
                       for from_id, to_id in id_ranges]
        # Threading required because it takes 2-3 days to impute
        # all values on a single thread, or a few hours on 15 threads
        pool = ThreadPool(15)
        try:
            pool.starmap(impute_base_id_thread, thread_args)
        finally:
            # Always release the worker threads, even if a chunk failed
            pool.close()
            pool.join()
        # Only reached if every chunk succeeded
        self.output().touch()


class RootTask(luigi.WrapperTask):
    '''Wrapper task which launches `ImputeBaseIDTask`.

    Args:
        date (datetime): Date stamp. The default is evaluated once,
                         when this class body is executed.
        production (bool): Run in production mode? If False (default),
                           the wrapped task is run in test mode.
    '''
    date = luigi.DateParameter(default=dt.now())
    production = luigi.BoolParameter(default=False)

    def requires(self):
        '''Yield the imputation task, in test mode unless `production`.'''
        test_mode = not self.production
        yield ImputeBaseIDTask(date=self.date, test=test_mode)
95 changes: 95 additions & 0 deletions nesta/packages/nih/impute_base_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
'''
Impute Base ID
==============
What NiH don't tell you is that the `core_project_num` field
has a base project number, which is effectively the actual
core project number. This is essential for aggregating projects,
otherwise it will appear that many duplicates exist in the data.
This code here is for imputing these values using a simple regex of the form:
{BASE_ID}-{an integer}-{an integer}-{an integer}
Any `core_project_num` failing this regex is used unchanged as the base ID.
'''

import re
from sqlalchemy.orm import load_only

from nesta.core.orms.nih_orm import Projects
from nesta.core.orms.orm_utils import get_mysql_engine
from nesta.core.orms.orm_utils import db_session

# Matches {BASE_ID}-{an integer}-{an integer}-{an integer}, with the
# base ID captured (greedily) in the first group. Must be a raw string,
# since "\d" is not a valid escape sequence in a regular string literal.
REGEX = re.compile(r'^(.*)-(\d+)-(\d+)-(\d+)$')


def get_base_code(core_code):
    """Extract the base code from the core project number
    if the pattern matches, otherwise return the
    core project number.

    Args:
        core_code (str): A core project number.
    Returns:
        str: The base code if `core_code` ends with three hyphen-separated
             integers, otherwise `core_code` unchanged.
    """
    match = REGEX.match(core_code)
    if match is None:
        return core_code
    return match.group(1)


def impute_base_id(session, from_id, to_id):
    """Impute the base ID values back into the database.

    Projects whose `core_project_num` is null are left untouched. For
    all others, `base_core_project_num` is set to the output of
    `get_base_code`.

    Args:
        session (sqlalchemy.Session): Active context-managed db session
        from_id (str): First NiH project PK to impute base ID for.
        to_id (str): Last NiH project PK to impute base ID for.
    """
    q = session.query(Projects)
    # Don't load bloaty fields, we don't need them
    q = q.options(load_only('application_id', 'core_project_num',
                            'base_core_project_num'))
    # Don't update cached objects, since we're not using them
    q = q.execution_options(synchronize_session=False)
    # Retrieve the projects in this range of IDs
    q = q.filter(Projects.application_id.between(from_id, to_id))
    for project in q.all():
        # Ignore those which are null
        if project.core_project_num is None:
            continue
        # NB: this assignment does not write to the database by itself;
        # the SQL UPDATE is expected to be triggered when the caller's
        # session context manager commits on going out of scope
        project.base_core_project_num = get_base_code(project.core_project_num)


def retrieve_id_ranges(database, chunksize=1000):
    """Retrieve and calculate the input arguments,
    over which "impute_base_id" can be mapped.

    Args:
        database (str): Name of the database to read project IDs from.
        chunksize (int): Number of IDs spanned by each range
                         (the final range may be shorter).
    Returns:
        list: Consecutive (from_id, to_id) tuples, where each from_id is
              equal to the previous to_id. Empty if there are no IDs.
    Raises:
        ValueError: If chunksize is less than 1.
    """
    if chunksize < 1:
        raise ValueError("chunksize must be a positive integer, "
                         "got {}".format(chunksize))
    engine = get_mysql_engine("MYSQLDB", "mysqldb", database)
    # First get all ID values, in ascending order
    with db_session(engine) as session:
        q = session.query(Projects.application_id)
        q = q.order_by(Projects.application_id)
        ids = [row[0] for row in q.all()]
    if not ids:  # Nothing to impute if there are no IDs in the DB
        return []

    final_id = ids[-1]
    boundaries = ids[0::chunksize]  # Every {chunksize}th id
    # Pop the final ID back in if it has been truncated, or if there
    # is only one boundary (otherwise a lone ID would yield no ranges
    # at all and would never be imputed)
    if len(boundaries) == 1 or boundaries[-1] != final_id:
        boundaries.append(final_id)
    # Zip together consecutive pairs of arguments, i.e.
    # n-1 values of (from_id, to_id)
    # where from_id[n] == to_id[n-1]
    return list(zip(boundaries, boundaries[1:]))


def impute_base_id_thread(from_id, to_id, database):
    """Apply "impute_base_id" over this chunk of IDs, using a
    database engine and session created for this call only.

    Args:
        from_id (str): First NiH project PK in this chunk.
        to_id (str): Last NiH project PK in this chunk.
        database (str): Name of the database to connect to.
    """
    db_engine = get_mysql_engine("MYSQLDB", "mysqldb", database)
    with db_session(db_engine) as active_session:
        impute_base_id(active_session, from_id, to_id)
        # Note: the commit is expected to happen on leaving this block
71 changes: 71 additions & 0 deletions nesta/packages/nih/tests/test_impute_base_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest import mock

from nesta.packages.nih.impute_base_id import get_base_code
from nesta.packages.nih.impute_base_id import impute_base_id
from nesta.packages.nih.impute_base_id import retrieve_id_ranges
from nesta.packages.nih.impute_base_id import impute_base_id_thread

from nesta.packages.nih.impute_base_id import Projects

PATH = "nesta.packages.nih.impute_base_id.{}"

def test_get_base_code():
    """The base code is extracted only if the regex passes."""
    expected_base_codes = {
        # Pass regex: return base code
        "helloworld-1-2-3": "helloworld",
        "foobar-11-0-23": "foobar",
        # Fail regex: return input
        "foo-bar-2-3": "foo-bar-2-3",
        "foo-bar-hello-world": "foo-bar-hello-world",
        "foobar-11-0": "foobar-11-0",
        "foobar123": "foobar123",
    }
    for core_code, base_code in expected_base_codes.items():
        assert get_base_code(core_code) == base_code


def test_impute_base_id():
    """Every project returned by the query has its base ID imputed."""
    # Pairs of (core_project_num, expected base_core_project_num)
    cases = [("helloworld-1-2-3", "helloworld"),  # <-- Regex passes
             ("foobar-11-0-23", "foobar"),        # <-- Regex passes
             # Regex fails:
             ("foo-bar-2-3", "foo-bar-2-3"),
             ("foo-bar-hello-world", "foo-bar-hello-world"),
             ("foobar123", "foobar123")]
    projects = [Projects(application_id=idx, core_project_num=core_id)
                for idx, (core_id, _) in enumerate(cases)]

    # Mock the query chain so that it yields the projects above
    session = mock.Mock()
    mocked_query = session.query().options().execution_options().filter()
    mocked_query.all.return_value = projects

    # Check that the base_project_num has not been imputed yet
    for project in projects:
        assert project.base_core_project_num is None

    # Impute the ids
    impute_base_id(session, from_id=None, to_id=None)

    # Check that the base_project_num has been imputed
    imputed_values = [project.base_core_project_num for project in projects]
    assert imputed_values == [base_code for _, base_code in cases]


@mock.patch(PATH.format("get_mysql_engine"))
@mock.patch(PATH.format("db_session"))
def test_retrieve_id_ranges(mocked_session_context, mocked_engine):
    """IDs are chunked into consecutive, boundary-sharing ranges."""
    rows = [(0,), (1,), ("1",), (2,), (3,),
            (5,), (8,), (13,), (21,)]
    mocked_query = mocked_session_context().__enter__().query().order_by()
    mocked_query.all.return_value = rows

    expected_ranges = [(0, 2),    # 0 <= x <= 2
                       (2, 8),    # 2 <= x <= 8
                       (8, 21)]   # 8 <= x <= 21
    assert retrieve_id_ranges(database="db_name",
                              chunksize=3) == expected_ranges


@mock.patch(PATH.format("get_mysql_engine"))
@mock.patch(PATH.format("db_session"))
@mock.patch(PATH.format("impute_base_id"))
def test_impute_base_id_thread(mocked_impute_base_id,
                               mocked_session_context, mocked_engine):
    """The thread function calls impute_base_id exactly once, passing
    the session yielded by the context manager and the ID range."""
    expected_session = mocked_session_context().__enter__()
    impute_base_id_thread(0, 2, 'db_name')
    mocked_impute_base_id.assert_called_once_with(expected_session, 0, 2)
1 change: 1 addition & 0 deletions nesta/packages/nih/tests/test_preprocess_nih.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_get_json_cols():

def test_get_long_text_cols():
assert get_long_text_cols(Projects) == {'cfda_code', 'core_project_num',
'base_core_project_num',
'ed_inst_type', 'foa_number',
'full_project_num',
'funding_mechanism', 'ic_name',
Expand Down

0 comments on commit aee5ab9

Please sign in to comment.