diff --git a/src/sentry/utils/concurrent.py b/src/sentry/utils/concurrent.py index cb42d28f6612d8..129e26c385ca8f 100644 --- a/src/sentry/utils/concurrent.py +++ b/src/sentry/utils/concurrent.py @@ -2,8 +2,9 @@ import functools import logging import threading -from concurrent.futures import Future +from concurrent.futures import Future, InvalidStateError from concurrent.futures._base import FINISHED, RUNNING +from contextlib import contextmanager from queue import Full, PriorityQueue from time import time @@ -90,16 +91,22 @@ def cancel(self, *args, **kwargs): self.__timing[1] = time() return super().cancel(*args, **kwargs) + @contextmanager + def __set_finished_time_on_success(self): + prev_value = self.__timing[1] + self.__timing[1] = time() + try: + yield + except InvalidStateError: + self.__timing[1] = prev_value + raise + def set_result(self, *args, **kwargs): - with self._condition: - _time = time() - result = super().set_result(*args, **kwargs) - self.__timing[1] = _time - return result + with self._condition, self.__set_finished_time_on_success(): + return super().set_result(*args, **kwargs) def set_exception(self, *args, **kwargs): - with self._condition: - self.__timing[1] = time() + with self._condition, self.__set_finished_time_on_success(): return super().set_exception(*args, **kwargs) diff --git a/tests/sentry/utils/test_concurrent.py b/tests/sentry/utils/test_concurrent.py index eb66d418d29f39..498c008da0d76d 100644 --- a/tests/sentry/utils/test_concurrent.py +++ b/tests/sentry/utils/test_concurrent.py @@ -98,13 +98,26 @@ def test_timed_future_success(): future = TimedFuture() assert future.get_timing() == (None, None) - with timestamp(1.0): + expected_result = mock.sentinel.RESULT_VALUE + start_time, finish_time = expected_timing = (1.0, 2.0) + + callback_results = [] + callback = lambda future: callback_results.append((future.result(), future.get_timing())) + + future.add_done_callback(callback) + + with timestamp(start_time): future.set_running_or_notify_cancel() - assert future.get_timing() == (1.0, None) + assert future.get_timing() == (start_time, None) - with timestamp(2.0): - future.set_result(None) - assert future.get_timing() == (1.0, 2.0) + assert len(callback_results) == 0 + + with timestamp(finish_time): + future.set_result(expected_result) + assert future.get_timing() == expected_timing + + assert len(callback_results) == 1 + assert callback_results[0] == (expected_result, expected_timing) def test_time_is_not_overwritten_if_fail_to_set_result(): @@ -130,13 +143,25 @@ def test_timed_future_error(): future = TimedFuture() assert future.get_timing() == (None, None) - with timestamp(1.0): + start_time, finish_time = expected_timing = (1.0, 2.0) + + callback_timings = [] + callback = lambda future: callback_timings.append(future.get_timing()) + + future.add_done_callback(callback) + + with timestamp(start_time): future.set_running_or_notify_cancel() - assert future.get_timing() == (1.0, None) + assert future.get_timing() == (start_time, None) - with timestamp(2.0): + assert len(callback_timings) == 0 + + with timestamp(finish_time): future.set_exception(None) - assert future.get_timing() == (1.0, 2.0) + assert future.get_timing() == expected_timing + + assert len(callback_timings) == 1 + assert callback_timings[0] == expected_timing def test_timed_future_cancel():