Skip to content

Commit

Permalink
Merge pull request #160 from pagreene/belief-dev
Browse files Browse the repository at this point in the history
Integrate Belief sorting at the REST API level.
  • Loading branch information
pagreene committed Jan 27, 2021
2 parents 285bf8c + 2f697f9 commit ca2fce9
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 16 deletions.
38 changes: 31 additions & 7 deletions rest_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,20 @@
from rest_api.util import sec_since, get_s3_client, gilda_ground, \
_make_english_from_meta, get_html_source_info


# =========================
# A lot of config and setup
# =========================

# Get a logger, and set the logging level.
logger = logging.getLogger("db rest api")
logger.setLevel(logging.INFO)

set_log_service_name(f"db-rest-api-{DEPLOYMENT}")
# Set the name of this service for the usage logs.
set_log_service_name(f"db-rest-api-{DEPLOYMENT if DEPLOYMENT else 'stable'}")


# Define a custom flask class to handle the deployment name prefix.
class MyFlask(Flask):
def route(self, url, *args, **kwargs):
if DEPLOYMENT is not None:
Expand All @@ -44,45 +52,53 @@ def route(self, url, *args, **kwargs):
return flask_dec


# Propagate the deployment name to the static path and auth URLs. Without a
# deployment name the static path is left as None and auth is untouched.
if DEPLOYMENT is None:
    static_url_path = None
else:
    static_url_path = f'/{DEPLOYMENT}/static'
    auth.url_prefix = f'/{DEPLOYMENT}'


# Initialize the flask application (with modified static path).
app = MyFlask(__name__, static_url_path=static_url_path)


# Register the auth application. Auth is only configured when we are not
# testing; in testing mode a warning is logged instead.
app.register_blueprint(auth)
app.config['DEBUG'] = True
if TESTING['status']:
    logger.warning("TESTING: No auth will be enabled.")
else:
    SC, jwt = config_auth(app)

# Apply wrappers to the app that will compress responses and enable CORS.
Compress(app)
CORS(app)

# The directory path to this location (works in any file system).
HERE = path.abspath(path.dirname(__file__))

# Instantiate a jinja2 env.
env = Environment(loader=ChoiceLoader([app.jinja_loader, auth.jinja_loader,
indra_loader]))


# Overwrite url_for function in jinja to handle DEPLOYMENT prefix gracefully.
def url_for(*args, **kwargs):
    """Build the URL for an endpoint, ensuring it carries the DEPLOYMENT prefix.

    All arguments are forwarded unchanged to `base_url_for`. When DEPLOYMENT
    is set and the generated URL does not already begin with `/<DEPLOYMENT>`,
    that prefix is prepended; otherwise the URL is returned as generated.
    """
    url = base_url_for(*args, **kwargs)
    if DEPLOYMENT is None:
        return url
    prefix = f'/{DEPLOYMENT}'
    if url.startswith(prefix):
        # Already prefixed: avoid doubling the deployment name.
        return url
    return prefix + url


# Here we can add functions to the jinja2 env.
env.globals.update(url_for=url_for)


# Define a useful helper function.
def render_my_template(template, title, **kwargs):
"""Render a Jinja2 template wrapping in identity and other details."""
kwargs['title'] = TITLE + ': ' + title
if not TESTING['status']:
kwargs['identity'] = get_jwt_identity()
Expand All @@ -102,10 +118,15 @@ def render_my_template(template, title, **kwargs):


@app.route('/', methods=['GET'])
def iamalive():
def root():
return redirect(url_for('search'), code=302)


@app.route('/healthcheck', methods=['GET'])
def i_am_alive():
    """Return a JSON status: 'testing' in testing mode, otherwise 'healthy'."""
    if TESTING['status']:
        status = 'testing'
    else:
        status = 'healthy'
    return jsonify({'status': status})


@app.route('/ground', methods=['GET'])
def ground():
ag = request.args['agent']
Expand Down Expand Up @@ -286,22 +307,25 @@ def expand_meta_row():
has_medscan = api_key is not None
logger.info(f'Auths for medscan: {has_medscan}')

# Get the sorting parameter.
sort_by = request.args.get('sort_by', 'ev_count')

# Get the more detailed results.
q = AgentJsonExpander(agent_json, stmt_type=stmt_type, hashes=hashes)
result = q.expand(sort_by='ev_count')
result = q.expand(sort_by=sort_by)

# Filter out any medscan content, and construct english.
entry_hash_lookup = defaultdict(list)
for key, entry in result.results.copy().items():
# Filter medscan...
if not has_medscan:
result.evidence_totals[key] -= \
result.evidence_counts[key] -= \
entry['source_counts'].pop('medscan', 0)
entry['total_count'] = result.evidence_totals[key]
entry['total_count'] = result.evidence_counts[key]
if not entry['source_counts']:
logger.warning("Censored content present. Removing it.")
result.results.pop(key)
result.evidence_totals.pop(key)
result.evidence_counts.pop(key)
continue

# Add english...
Expand Down
31 changes: 24 additions & 7 deletions rest_api/call_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,31 @@ def __init__(self, env):
self._env = env

self.web_query = request.args.copy()

# Get the offset and limit
self.offs = self._pop('offset', type_cast=int)
self.best_first = self._pop('best_first', True, bool)
if 'limit' in self.web_query:
self.limit = min(self._pop('limit', MAX_STMTS, int), MAX_STMTS)
else:
self.limit = min(self._pop('max_stmts', MAX_STMTS, int), MAX_STMTS)

# Sort out the sorting.
sort_by = self._pop('sort_by', None)
best_first = self._pop('best_first', True, bool)
if sort_by is not None:
self.sort_by = sort_by
elif best_first:
self.sort_by = 'ev_count'
else:
self.sort_by = None

# Gather other miscellaneous options
self.fmt = self._pop('format', 'json')
self.w_english = self._pop('with_english', False, bool)
self.w_cur_counts = self._pop('with_cur_counts', False, bool)
self.strict = self._pop('strict', False, bool)

# Prime agent recorders.
self.agent_dict = None
self.agent_set = None

Expand Down Expand Up @@ -90,7 +105,7 @@ def run(self, result_type):

# Actually run the function
params = dict(offset=self.offs, limit=self.limit,
best_first=self.best_first)
sort_by=self.sort_by)
logger.info(f"Sending query with params: {params}")
if result_type == 'statements':
self.special['ev_limit'] = \
Expand Down Expand Up @@ -231,9 +246,9 @@ def process_entries(self, result):
if result.result_type == 'statements':
result.source_counts[key].pop('medscan', 0)
else:
result.evidence_totals[key] -= \
result.evidence_counts[key] -= \
entry['source_counts'].pop('medscan', 0)
entry['total_count'] = result.evidence_totals[key]
entry['total_count'] = result.evidence_counts[key]
if not entry['source_counts']:
logger.warning("Censored content present.")

Expand Down Expand Up @@ -326,13 +341,15 @@ def produce_response(self, result):
stmts_json = result.results
if self.fmt == 'html':
title = TITLE
ev_totals = res_json.pop('evidence_totals')
ev_counts = res_json.pop('evidence_counts')
beliefs = res_json.pop('belief_scores')
stmts = stmts_from_json(stmts_json.values())
db_rest_url = request.url_root[:-1] \
+ self._env.globals['url_for']('iamalive')[:-1]
+ self._env.globals['url_for']('root')[:-1]
html_assembler = \
HtmlAssembler(stmts, summary_metadata=res_json,
ev_counts=ev_totals, title=title,
ev_counts=ev_counts, beliefs=beliefs,
sort_by=self.sort_by, title=title,
source_counts=result.source_counts,
db_rest_url=db_rest_url)
idbr_template = \
Expand Down
32 changes: 30 additions & 2 deletions rest_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ def __check_stmts(self, json_stmts, check_support=False, check_stmts=False):
if s1.matches(s2)]))
return

def test_health_check(self):
    """Test that the healthcheck endpoint returns 200 and a testing status."""
    response = self.app.get('healthcheck')
    status_code = response.status_code
    assert status_code == 200, status_code
    payload = response.json
    assert payload == {'status': 'testing'}, payload

def test_blank_response(self):
"""Test the response to an empty request."""
resp, dt, size = self.__time_get_query('statements/from_agents', '')
Expand Down Expand Up @@ -243,6 +249,28 @@ def test_query_with_other(self):
(0, 'HGNC', hgnc_client.get_hgnc_id('MAPK1'))])
return

def test_belief_sort_in_agent_search(self):
    """Test that sort_by='belief' orders results by belief, not ev count."""
    resp = self.__check_good_statement_query(agent='MAPK1', sort_by='belief')
    belief_map = resp.json['belief_scores']
    count_map = resp.json['evidence_counts']

    # Every returned statement should have both a belief and an ev count.
    assert len(belief_map) == len(count_map)

    # Beliefs must be non-increasing in the order returned...
    beliefs = list(belief_map.values())
    assert all(prev >= nxt for prev, nxt in zip(beliefs, beliefs[1:]))

    # ...while evidence counts must NOT be, or the sort had no visible effect.
    ev_counts = list(count_map.values())
    assert not all(prev >= nxt for prev, nxt in zip(ev_counts, ev_counts[1:]))

def test_explicit_ev_count_sort_agent_search(self):
    """Test that sort_by='ev_count' orders results by ev count, not belief."""
    resp = self.__check_good_statement_query(agent='MAPK1', sort_by='ev_count')
    belief_map = resp.json['belief_scores']
    count_map = resp.json['evidence_counts']

    # Every returned statement should have both a belief and an ev count.
    assert len(belief_map) == len(count_map)

    # Beliefs must NOT be non-increasing, or the two sort orders coincide...
    beliefs = list(belief_map.values())
    assert not all(prev >= nxt for prev, nxt in zip(beliefs, beliefs[1:]))

    # ...while evidence counts must be non-increasing in the order returned.
    ev_counts = list(count_map.values())
    assert all(prev >= nxt for prev, nxt in zip(ev_counts, ev_counts[1:]))

def test_bad_camel(self):
"""Test that a type can be poorly formatted and resolve correctly."""
resp = self.__check_good_statement_query(agent='MAPK1',
Expand Down Expand Up @@ -277,7 +305,7 @@ def test_offset(self):
time_goal=20)
j1 = json.loads(resp1.data)
hashes1 = set(j1['statements'].keys())
ev_counts1 = j1['evidence_totals']
ev_counts1 = j1['evidence_counts']
resp2 = self.__check_good_statement_query(agent='NFkappaB@FPLX',
offset=MAX_STMTS,
check_stmts=False,
Expand All @@ -286,7 +314,7 @@ def test_offset(self):
hashes2 = set(j2['statements'].keys())
assert not hashes2 & hashes1

ev_counts2 = j2['evidence_totals']
ev_counts2 = j2['evidence_counts']
assert max(ev_counts2.values()) <= min(ev_counts1.values())

return
Expand Down

0 comments on commit ca2fce9

Please sign in to comment.