
Commit

Merge pull request #230 from dianakolusheva/load_tests
Fix test loading
bgyori committed Jun 16, 2021
2 parents 7679ba5 + 5475435 commit 7bd2fc3
Showing 5 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -73,7 +73,7 @@ jobs:
export AWS_DEFAULT_REGION='us-east-1'
export NOSEATTR="!notravis"
export NOSEATTR=$(if [ "$GITHUB_EVENT_NAME" == "pull_request" ]; then echo $NOSEATTR,!nonpublic; else echo $NOSEATTR; fi)
- export PYTHONPATH=$PYTHONPATH:`pwd`/covid-19:`pwd`/kappy:`pwd`/indra_world:`pwd`/automates/scripts
+ export PYTHONPATH=$PYTHONPATH:`pwd`/covid-19:`pwd`/kappy:`pwd`/indra_world:`pwd`/automates/scripts/gromet
export BNGPATH=`pwd`/BioNetGen-2.4.0
export INDRA_WM_CACHE="."
nosetests -v -a $NOSEATTR emmaa/tests/test_s3.py
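The only change in this workflow is the final PYTHONPATH entry, which now points one level deeper, at the gromet directory inside the automates checkout, so that GroMEt code can be imported as a top-level gromet module rather than as gromet.gromet (see the matching import changes in emmaa/model.py and emmaa/tests/test_model.py below). A minimal sketch of the same resolution done from Python rather than the shell, assuming the automates repository is checked out in the working directory as in the workflow:

import sys
from pathlib import Path

# Assumption: the automates repo is checked out in the working directory,
# mirroring the GitHub Actions job above.
gromet_dir = Path.cwd() / 'automates' / 'scripts' / 'gromet'
sys.path.append(str(gromet_dir))

# With scripts/gromet on the path, the GroMEt code resolves as a
# top-level module, matching the new imports in this commit.
from gromet import Gromet  # noqa: E402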
2 changes: 1 addition & 1 deletion emmaa/analyze_tests_results.py
@@ -222,7 +222,7 @@ def get_assembled_stmts_by_paper(self, id_type='TRID'):
logger.info('Mapping papers to statements')
stmts_by_papers = {}
for stmt in self.statements:
- stmt_hash = stmt.get_hash()
+ stmt_hash = stmt.get_hash(refresh=True)
for evid in stmt.evidence:
paper_id = None
if id_type == 'pii':
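The switch from stmt.get_hash() to stmt.get_hash(refresh=True) makes INDRA recompute the statement's hash from its current content instead of returning a value cached on the statement object; the same substitution is applied throughout emmaa_service/api.py below. A minimal sketch of the call, using an arbitrary example statement rather than one from this repository:

from indra.statements import Agent, Phosphorylation

# Arbitrary example statement, for illustration only.
stmt = Phosphorylation(Agent('MAP2K1'), Agent('MAPK1'))

cached = stmt.get_hash()                   # computed once, then cached on the object
recomputed = stmt.get_hash(refresh=True)   # recomputed from the current content

# For an unchanged statement the two values agree.
assert cached == recomputed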
2 changes: 1 addition & 1 deletion emmaa/model.py
@@ -946,7 +946,7 @@ def pysb_to_gromet(pysb_model, model_name, fname=None):
g : automates.script.gromet.gromet.Gromet
A GroMEt object built from PySB model.
"""
- from gromet.gromet import Gromet, gromet_to_json, \
+ from gromet import Gromet, gromet_to_json, \
Junction, Wire, UidJunction, UidType, UidWire, Relation, \
UidBox, UidGromet, Literal, Val
from pysb import Parameter
2 changes: 1 addition & 1 deletion emmaa/tests/test_model.py
@@ -138,7 +138,7 @@ def test_papers():


def test_pysb_to_gromet():
- from gromet.gromet import Gromet
+ from gromet import Gromet
emmaa_model = create_model()
pysb_model = emmaa_model.assemble_pysb()
gromet = pysb_to_gromet(pysb_model, 'test_model', 'gromet_test.json')
37 changes: 20 additions & 17 deletions emmaa_service/api.py
@@ -447,8 +447,11 @@ def _get_all_tests(bucket=EMMAA_BUCKET_NAME):

def _load_tests_from_cache(test_corpus):
tests, file_key = tests_cache.get(test_corpus, (None, None))
- latest_on_s3 = find_latest_s3_file(
-     EMMAA_BUCKET_NAME, f'tests/{test_corpus}', '.pkl')
+ try:
+     latest_on_s3 = find_latest_s3_file(
+         EMMAA_BUCKET_NAME, f'tests/{test_corpus}', '.pkl')
+ except ValueError:
+     latest_on_s3 = f'tests/{test_corpus}.pkl'
if file_key != latest_on_s3:
tests, file_key = load_tests_from_s3(test_corpus, EMMAA_BUCKET_NAME)
if isinstance(tests, dict):
@@ -777,7 +780,7 @@ def _count_curations(curations, stmts_by_hash):
def _get_stmt_row(stmt, source, model, cur_counts, date, test_corpus=None,
path_counts=None, cur_dict=None, with_evid=False,
paper_id=None, paper_id_type=None):
- stmt_hash = str(stmt.get_hash())
+ stmt_hash = str(stmt.get_hash(refresh=True))
english = _format_stmt_text(stmt)
evid_count = len(stmt.evidence)
evid = []
@@ -800,7 +803,7 @@ def _get_stmt_row(stmt, source, model, cur_counts, date, test_corpus=None,
badges = _make_badges(evid_count, json_link, path_count,
cur_counts.get(stmt_hash))
stmt_row = [
-     (stmt.get_hash(), english, evid, evid_count, badges)]
+     (stmt.get_hash(refresh=True), english, evid, evid_count, badges)]
return stmt_row


@@ -1212,7 +1215,7 @@ def annotate_paper_statements(model):
model_stats = _load_model_stats_from_cache(model, date)
paper_hashes = model_stats['paper_summary']['stmts_by_paper'][trid]
paper_stmts = [stmt for stmt in all_stmts
-     if stmt.get_hash() in paper_hashes]
+     if stmt.get_hash(refresh=True) in paper_hashes]
for stmt in paper_stmts:
stmt.evidence = [ev for ev in stmt.evidence
if str(ev.text_refs.get('TRID')) == trid]
@@ -1269,7 +1272,7 @@ def get_paper_statements(model):
if paper_hashes:
all_stmts = _load_stmts_from_cache(model, date)
paper_stmts = [stmt for stmt in all_stmts
-     if stmt.get_hash() in paper_hashes]
+     if stmt.get_hash(refresh=True) in paper_hashes]
updated_stmts = [filter_evidence(stmt, trid, 'TRID')
for stmt in paper_stmts]
updated_stmts = sorted(updated_stmts, key=lambda x: len(x.evidence),
@@ -1282,7 +1285,7 @@
stmt_rows = []
stmts_by_hash = {}
for stmt in updated_stmts:
- stmts_by_hash[str(stmt.get_hash())] = stmt
+ stmts_by_hash[str(stmt.get_hash(refresh=True))] = stmt
curations = get_curations(pa_hash=paper_hashes)
cur_dict = defaultdict(list)
for cur in curations:
@@ -1465,13 +1468,13 @@ def get_statement_evidence_page():
all_stmts = _load_stmts_from_cache(model, date)
for stmt in all_stmts:
for stmt_hash in stmt_hashes:
- if str(stmt.get_hash()) == str(stmt_hash):
+ if str(stmt.get_hash(refresh=True)) == str(stmt_hash):
stmts.append(stmt)
elif source == 'paper':
all_stmts = _load_stmts_from_cache(model, date)
for stmt in all_stmts:
for stmt_hash in stmt_hashes:
- if str(stmt.get_hash()) == str(stmt_hash):
+ if str(stmt.get_hash(refresh=True)) == str(stmt_hash):
stmts.append(filter_evidence(stmt, paper_id, paper_id_type))
elif source == 'test':
if not test_corpus:
@@ -1480,7 +1483,7 @@
stmt_counts_dict = None
for t in tests:
for stmt_hash in stmt_hashes:
- if str(t.stmt.get_hash()) == str(stmt_hash):
+ if str(t.stmt.get_hash(refresh=True)) == str(stmt_hash):
stmts.append(t.stmt)
else:
abort(Response(f'Source should be model_statement or test', 404))
@@ -1489,7 +1492,7 @@
stmt_rows = []
stmts_by_hash = {}
for stmt in stmts:
- stmts_by_hash[str(stmt.get_hash())] = stmt
+ stmts_by_hash[str(stmt.get_hash(refresh=True))] = stmt
curations = get_curations(pa_hash=stmt_hashes)
cur_dict = defaultdict(list)
for cur in curations:
@@ -1541,13 +1544,13 @@ def get_all_statements_page(model):
stmts = _load_stmts_from_cache(model, date)
stmts_by_hash = {}
for stmt in stmts:
- stmts_by_hash[str(stmt.get_hash())] = stmt
+ stmts_by_hash[str(stmt.get_hash(refresh=True))] = stmt
msg = None
curations = get_curations()
cur_counts = _count_curations(curations, stmts_by_hash)
if filter_curated:
- stmts = [stmt for stmt in stmts if str(stmt.get_hash()) not in
-          cur_counts]
+ stmts = [stmt for stmt in stmts if str(stmt.get_hash(refresh=True))
+          not in cur_counts]
# Add up paths per statement count across test corpora
stmt_counts_dict = Counter()
test_corpora = _get_test_corpora(model)
@@ -1927,7 +1930,7 @@ def get_statement_by_hash_model(model, date, hash_val):
cur_dict[(cur['pa_hash'], cur['source_hash'])].append(
{'error_type': cur['tag']})
for st in stmts:
- if str(st.get_hash()) == str(hash_val):
+ if str(st.get_hash(refresh=True)) == str(hash_val):
st_json = st.to_json()
ev_list = _format_evidence_text(
st, cur_dict, ['correct', 'act_vs_amt', 'hypothesis'])
@@ -1946,7 +1949,7 @@ def get_tests_by_hash(test_corpus, hash_val):
{'error_type': cur['tag']})
st_json = {}
for test in tests:
- if str(test.stmt.get_hash()) == str(hash_val):
+ if str(test.stmt.get_hash(refresh=True)) == str(hash_val):
st_json = test.stmt.to_json()
ev_list = _format_evidence_text(
test.stmt, cur_dict, ['correct', 'act_vs_amt', 'hypothesis'])
@@ -1966,7 +1969,7 @@ def get_statement_by_paper(model, paper_id, paper_id_type, date, hash_val):
cur_dict[(cur['pa_hash'], cur['source_hash'])].append(
{'error_type': cur['tag']})
for st in stmts:
- if str(st.get_hash()) == str(hash_val):
+ if str(st.get_hash(refresh=True)) == str(hash_val):
stmt = filter_evidence(st, paper_id, paper_id_type)
st_json = stmt.to_json()
ev_list = _format_evidence_text(
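Besides the get_hash(refresh=True) substitutions, the change to _load_tests_from_cache at the top of this file makes test loading tolerant of corpora that have no dated pickle on S3: when find_latest_s3_file raises ValueError, the plain tests/<corpus>.pkl key is used instead. A minimal, self-contained sketch of that fallback pattern, with a stub standing in for the real emmaa.util.find_latest_s3_file and a made-up corpus name:

def resolve_test_file_key(test_corpus, find_latest):
    # find_latest stands in for emmaa.util.find_latest_s3_file here.
    try:
        return find_latest(f'tests/{test_corpus}', '.pkl')
    except ValueError:
        # No dated file found on S3; fall back to the undated key.
        return f'tests/{test_corpus}.pkl'


def no_dated_files(prefix, extension):
    # Stub simulating a prefix with no matching files on S3.
    raise ValueError(f'no files found under {prefix}')


# Hypothetical corpus name, for illustration only.
assert resolve_test_file_key('example_tests', no_dated_files) == 'tests/example_tests.pkl'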

