Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Limit the maximum number of candidate pairs #605

Merged
merged 6 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions backend/entityservice/cache/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,45 @@ def get_total_number_of_comparisons(project_id):
r.expire(key, 60*60)
return total_comparisons

def save_current_progress(comparisons, run_id, config=None):

def save_current_progress(comparisons, candidate_pairs, run_id, config=None):
    """
    Record progress for a run in the redis cache, adding the number of candidates identified
    from the number of comparisons carried out.

    Uses per-field HINCRBY updates, so it is safe to call from concurrent processes.

    :param int comparisons: The number of pairwise comparisons that have been computed for this update.
    :param int candidate_pairs: Number of candidates, edge count where the similarity was above
        the configured threshold.
    :param run_id: Identifier of the run whose progress is being updated.
    :param config: Optional settings object supplying CACHE_EXPIRY; the global
        config is used when omitted.
    """
    cfg = globalconfig if config is None else config
    logger.debug(f"Updating progress. Compared {comparisons} CLKS", run_id=run_id)
    if comparisons <= 0:
        # No work was done in this update; leave the cache untouched.
        return
    cache = connect_to_redis()
    run_key = _get_run_hash_key(run_id)
    cache.hincrby(run_key, 'comparisons', comparisons)
    cache.hincrby(run_key, 'candidates', candidate_pairs)
    cache.expire(run_key, cfg.CACHE_EXPIRY)


def get_progress(run_id):
def get_comparison_count_for_run(run_id):
    """
    Fetch the cached cumulative comparison count for a run.

    :param run_id: Identifier of the run.
    :return: The 'comparisons' hash field converted via
        _convert_redis_result_to_int.
    """
    cache = connect_to_redis(read_only=True)
    raw_count = cache.hget(_get_run_hash_key(run_id), 'comparisons')
    return _convert_redis_result_to_int(raw_count)


def get_candidate_count_for_run(run_id):
    """
    Fetch the cached number of candidate pairs recorded for a run.

    :param run_id: Identifier of the run.
    :return: The 'candidates' hash field converted via
        _convert_redis_result_to_int.
    """
    r = connect_to_redis(read_only=True)
    key = _get_run_hash_key(run_id)
    # Fetch only the 'candidates' field; the earlier read of the legacy
    # 'progress' field was a dead store that cost an extra Redis round trip.
    res = r.hget(key, 'candidates')
    return _convert_redis_result_to_int(res)


def clear_progress(run_id):
    """
    Remove all cached progress counters for a run.

    Deletes the legacy 'progress' field together with the current
    'comparisons' and 'candidates' fields. HDEL accepts multiple fields and
    silently ignores missing ones, so this is done in a single round trip
    instead of three separate calls.
    """
    r = connect_to_redis()
    key = _get_run_hash_key(run_id)
    r.hdel(key, 'progress', 'comparisons', 'candidates')
34 changes: 20 additions & 14 deletions backend/entityservice/integrationtests/redistests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import redis

from entityservice.settings import Config as config
from entityservice.cache import connect_to_redis, clear_progress, save_current_progress, get_progress
from entityservice.cache import connect_to_redis, clear_progress, get_candidate_count_for_run, save_current_progress, \
get_comparison_count_for_run


class TestProgress:
Expand All @@ -19,35 +20,40 @@ def test_clear_missing_progress(self):
def test_clear_progress(self):
    """Counters can be read back after saving and are gone after clearing."""
    config.CACHE_EXPIRY = datetime.timedelta(seconds=1)
    runid = 'runtest_clear_progress'
    # Removed stale pre-refactor lines calling the old 3-argument
    # save_current_progress and the removed get_progress helper, which would
    # raise TypeError/NameError against the current API.
    save_current_progress(1, 1, runid, config)
    assert 1 == get_comparison_count_for_run(runid)
    assert 1 == get_candidate_count_for_run(runid)
    clear_progress(runid)
    assert get_comparison_count_for_run(runid) is None
    assert get_candidate_count_for_run(runid) is None

def test_storing_wrong_type(self):
    """Non-integer counts must be rejected by redis with a ResponseError."""
    config.CACHE_EXPIRY = datetime.timedelta(seconds=1)
    runid = 'test_storing_wrong_type'
    # Removed the stale 3-argument call save_current_progress(1.5, runid, config)
    # left over from before the candidate_pairs parameter was added; it would
    # raise TypeError rather than the expected redis ResponseError.
    with pytest.raises(redis.exceptions.ResponseError):
        save_current_progress(1.5, 1, runid, config)
    with pytest.raises(redis.exceptions.ResponseError):
        save_current_progress(1, 1.5, runid, config)

def test_progress_expires(self):
    """Counters disappear once the configured cache expiry has elapsed."""
    # Uses the minimum expiry of 1 second
    config.CACHE_EXPIRY = datetime.timedelta(seconds=1)
    runid = 'test_progress_expires'
    # Removed stale old-API lines (3-argument save_current_progress and the
    # removed get_progress helper) interleaved from the pre-refactor version.
    save_current_progress(42, 7, runid, config)
    cached_progress = get_comparison_count_for_run(runid)
    assert cached_progress == 42
    time.sleep(1)
    # After expiry the progress should be reset to None
    assert get_comparison_count_for_run(runid) is None
    assert get_candidate_count_for_run(runid) is None

def test_progress_increments(self):
    """Repeated saves accumulate both comparison and candidate counts."""
    # 5 seconds so the entries cannot expire while the loop below runs.
    config.CACHE_EXPIRY = datetime.timedelta(seconds=5)
    runid = 'test_progress_increments'
    # Removed stale pre-refactor lines (1-second expiry, 3-argument calls and
    # get_progress assertions) that were interleaved from the old version and
    # would fail against the current API.
    save_current_progress(2, 1, runid, config)
    assert get_comparison_count_for_run(runid) == 2
    for _ in range(99):  # loop index was unused; renamed i -> _
        save_current_progress(2, 1, runid, config)

    assert 200 == get_comparison_count_for_run(runid)
    assert 100 == get_candidate_count_for_run(runid)
6 changes: 6 additions & 0 deletions backend/entityservice/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ class Config(object):
# If there are more than 1M CLKS, don't cache them in redis
MAX_CACHE_SIZE = int(os.getenv('MAX_CACHE_SIZE', '1000000'))

# Global limits on maximum number of candidate pairs considered.
# If a run exceeds these limits, the run is put into an error state and further processing
# is abandoned to protect the service from running out of memory.
# NOTE: int() accepts underscore digit separators (PEP 515), so the default
# strings below parse as 100 million and 500 million respectively.
SOLVER_MAX_CANDIDATE_PAIRS = int(os.getenv('SOLVER_MAX_CANDIDATE_PAIRS', '100_000_000'))
SIMILARITY_SCORES_MAX_CANDIDATE_PAIRS = int(os.getenv('SIMILARITY_SCORES_MAX_CANDIDATE_PAIRS', '500_000_000'))

# Cache time-to-live; defaults to 10 days unless CACHE_EXPIRY_SECONDS is set.
_CACHE_EXPIRY_SECONDS = int(os.getenv('CACHE_EXPIRY_SECONDS', datetime.timedelta(days=10).total_seconds()))
CACHE_EXPIRY = datetime.timedelta(seconds=_CACHE_EXPIRY_SECONDS)

Expand Down
20 changes: 14 additions & 6 deletions backend/entityservice/tasks/comparing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from entityservice.async_worker import celery, logger
from entityservice.cache.encodings import remove_from_cache
from entityservice.cache.progress import save_current_progress
from entityservice.cache.progress import get_candidate_count_for_run, save_current_progress
from entityservice.encoding_storage import get_encoding_chunk, get_encoding_chunks
from entityservice.errors import InactiveRun
from entityservice.database import (
Expand Down Expand Up @@ -285,7 +285,6 @@ def new_child_span(name, parent_scope=None):
log.debug("Checking that the resource exists (in case of run being canceled/deleted)")
assert_valid_run(project_id, run_id, log)

#chunk_info_dp1, chunk_info_dp2 = chunk_info
def reindex_using_encoding_ids(recordarray, encoding_id_list):
# Map results from "index in chunk" to encoding id.
return array.array('I', [encoding_id_list[i] for i in recordarray])
Expand All @@ -294,7 +293,6 @@ def reindex_using_encoding_ids(recordarray, encoding_id_list):
num_comparisons = 0
sim_results = []


with DBConn() as conn:
if len(package) > 1: # multiple full blocks in one package
with new_child_span(f'fetching-encodings of package of size {len(package)}'):
Expand Down Expand Up @@ -336,16 +334,26 @@ def reindex_using_encoding_ids(recordarray, encoding_id_list):
sim_results.append((sims, (rec_is0, rec_is1), chunk_dp1['datasetIndex'], chunk_dp2['datasetIndex']))
log.debug(f'comparison is done. {num_comparisons} comparisons got {num_results} pairs above the threshold')

##### progress reporting
# progress reporting
log.debug('Encoding similarities calculated')

with new_child_span('update-comparison-progress') as scope:
# Update the number of comparisons completed
save_current_progress(num_comparisons, run_id)
save_current_progress(num_comparisons, num_results, run_id)
scope.span.log_kv({'comparisons': num_comparisons, 'num_similar': num_results})
log.debug("Comparisons: {}, Links above threshold: {}".format(num_comparisons, num_results))

###### results into file into minio
with new_child_span('check-within-candidate-limits') as scope:
global_candidates_for_run = get_candidate_count_for_run(run_id)
scope.span.log_kv({'global candidate count for run': global_candidates_for_run})

if global_candidates_for_run is not None and global_candidates_for_run > Config.SIMILARITY_SCORES_MAX_CANDIDATE_PAIRS:
log.warning(f"This run has created more than the global limit of candidate pairs. Setting state to 'error'")
with DBConn() as conn:
update_run_mark_failure(conn, run_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the user will never know why the run failed?

return

# Save results file into minio
with new_child_span('save-comparison-results-to-minio'):

file_iters = []
Expand Down
6 changes: 6 additions & 0 deletions backend/entityservice/tasks/solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anonlink
from anonlink.candidate_generation import _merge_similarities

from entityservice.database import DBConn, update_run_mark_failure
from entityservice.object_store import connect_to_object_store
from entityservice.async_worker import celery, logger
from entityservice.settings import Config as config
Expand All @@ -25,6 +26,11 @@ def solver_task(similarity_scores_filename, project_id, run_id, dataset_sizes, p
# https://github.com/data61/anonlink/issues/271
candidate_pairs = _merge_similarities([zip(similarity_scores, dset_is0, dset_is1, rec_is0, rec_is1)], k=None)
log.info(f"Number of candidate pairs after deduplication: {len(candidate_pairs[0])}")
if len(candidate_pairs[0]) > config.SOLVER_MAX_CANDIDATE_PAIRS:
log.warning(f"Attempting to solve with more than the global limit of candidate pairs.")
with DBConn() as conn:
update_run_mark_failure(conn, run_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here. Has the run table the ability to store an error message, and could that be passed on to the user?

return

log.info("Calculating the optimal mapping from similarity matrix")
groups = anonlink.solving.greedy_solve(candidate_pairs)
Expand Down
2 changes: 1 addition & 1 deletion backend/entityservice/views/run/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get(project_id, run_id):
max_val = db.get_project_column(conn, project_id, 'parties')
elif stage == 2:
# Computing similarity
abs_val = progress_cache.get_progress(run_id)
abs_val = progress_cache.get_comparison_count_for_run(run_id)
if abs_val is not None:
max_val = progress_cache.get_total_number_of_comparisons(project_id)
logger.debug(f"total comparisons: {max_val}")
Expand Down