Skip to content

Commit

Permalink
Improve spiral and cycle tests
Browse files Browse the repository at this point in the history
  • Loading branch information
guillett committed Feb 6, 2024
1 parent f1dc190 commit a20efed
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
27 changes: 11 additions & 16 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,25 +409,20 @@ def _check_for_cycle(self, variable: str, period):
raise errors.SpiralError(message, variable)

def invalidate_cache_entry(self, variable: str, period):
print((variable, period))
self.invalidated_caches.add(Cache(variable, period))

def invalidate_spiral_variables(self, variable: str):
print((variable))
print(self.tracer.stack)
initial_call_found = False
invalidate_entries = False
# 1. find the initial variable call
# 2. find the next variable call
# 3. invalidate all frame items from there
for frame in self.tracer.stack:
if initial_call_found:
if not invalidate_entries and frame["name"] == variable:
invalidate_entries = True
if invalidate_entries:
self.invalidate_cache_entry(str(frame["name"]), frame["period"])
elif frame["name"] == variable:
initial_call_found = True
# Visit the stack, from the bottom (most recent) up; we know that we'll find
# the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
# intermediate values computed (to avoid impacting performance) but we mark them
# for deletion from the cache once the calculation ends.
count = 0
for frame in reversed(self.tracer.stack):
self.invalidate_cache_entry(str(frame["name"]), frame["period"])
if frame["name"] == variable:
count += 1
if count > self.max_spiral_loops:
break

# ----- Methods to access stored values ----- #

Expand Down
6 changes: 2 additions & 4 deletions tests/core/test_cycles.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,12 @@ def test_spirals_result_in_default_value(simulation, reference_period):

def test_spiral_heuristic(simulation, reference_period):
variable5 = simulation.calculate("variable5", period=reference_period)
tools.assert_near(variable5, [11])

variable6 = simulation.calculate("variable6", period=reference_period)
tools.assert_near(variable6, [11])

variable6_last_month = simulation.calculate(
"variable6", reference_period.last_month
)
tools.assert_near(variable5, [11])
tools.assert_near(variable6, [11])
tools.assert_near(variable6_last_month, [11])


Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def test_spiral_error(tracer):
tracer.record_calculation_start("a", period(2016))

with raises(SpiralError):
simulation._check_for_cycle("a", 2016)
simulation._check_for_cycle("a", period(2016))

assert len(simulation.invalidated_cache_items) == 3
assert len(tracer.stack) == 5
assert len(tracer.stack) == 3


def test_full_tracer_one_calculation(tracer):
Expand Down

0 comments on commit a20efed

Please sign in to comment.