Skip to content

Commit

Permalink
Merge pull request #287 from ihmeuw/collijk/feature/add-simulation-name
Browse files Browse the repository at this point in the history
Add unique simulation names
  • Loading branch information
collijk committed Feb 27, 2023
2 parents 1bb26bf + 4098d72 commit 9a077fd
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 15 deletions.
65 changes: 57 additions & 8 deletions src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
tools to easily setup and run a simulation.
"""
import time
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Set, Union

import numpy as np
import pandas as pd
from loguru import logger

from vivarium.config_tree import ConfigTree
from vivarium.exceptions import VivariumError
from vivarium.framework.configuration import build_model_specification

from .artifact import ArtifactInterface
Expand All @@ -46,13 +46,65 @@


class SimulationContext:

_created_simulation_contexts: Set[str] = set()

@staticmethod
def _get_context_name(sim_name: Union[str, None]) -> str:
"""Get a unique name for a simulation context.
Parameters
----------
sim_name
The name of the simulation context. If None, a unique name will be generated.
Returns
-------
str
A unique name for the simulation context.
Note
----
This method mutates process global state (the class attribute
``_created_simulation_contexts``) in order to keep track contexts that have been
generated. This functionality makes generating simulation contexts in parallel
a non-threadsafe operation.
"""
if sim_name is None:
sim_number = len(SimulationContext._created_simulation_contexts) + 1
sim_name = f"simulation_{sim_number}"

if sim_name in SimulationContext._created_simulation_contexts:
msg = (
"Attempting to create two SimulationContexts "
f"with the same name {sim_name}"
)
raise VivariumError(msg)

SimulationContext._created_simulation_contexts.add(sim_name)
return sim_name

@staticmethod
def _clear_context_cache():
"""Clear the cache of simulation context names.
This is primarily useful for testing purposes.
"""
SimulationContext._created_simulation_contexts = set()

def __init__(
self,
model_specification: Union[str, Path, ConfigTree] = None,
components: Union[List, Dict, ConfigTree] = None,
configuration: Union[Dict, ConfigTree] = None,
plugin_configuration: Union[Dict, ConfigTree] = None,
sim_name: str = None,
):

self._name = self._get_context_name(sim_name)

# Bootstrap phase: Parse arguments, make private managers
component_configuration = (
components if isinstance(components, (dict, ConfigTree)) else None
Expand Down Expand Up @@ -138,7 +190,7 @@ def __init__(

@property
def name(self):
return "simulation_context"
return self._name

def setup(self):
self._lifecycle.set_state("setup")
Expand Down Expand Up @@ -230,10 +282,7 @@ def get_population(self, untracked: bool = True):
return self._population.get_population(untracked)

def __repr__(self):
return "SimulationContext()"

def __str__(self):
return str(self._lifecycle)
return f"SimulationContext({self.name})"


class Builder:
Expand All @@ -256,7 +305,7 @@ class Builder:
population: PopulationInterface
Provides access to simulant state table via the
:ref:`population<population_concept>` system.
resource: ResourceInterface
resources: ResourceInterface
Provides access to the :ref:`resource<resource_concept>` system,
which manages dependencies between components.
time: TimeInterface
Expand Down
50 changes: 43 additions & 7 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
ComponentManager,
OrderedComponentSet,
)
from vivarium.framework.engine import Builder, SimulationContext
from vivarium.framework.engine import Builder
from vivarium.framework.engine import SimulationContext as SimulationContext_
from vivarium.framework.event import EventInterface, EventManager
from vivarium.framework.lifecycle import LifeCycleInterface, LifeCycleManager
from vivarium.framework.lookup import LookupTableInterface, LookupTableManager
Expand All @@ -25,6 +26,12 @@ def is_same_object_method(m1, m2):
return m1.__func__ is m2.__func__ and m1.__self__ is m2.__self__


@pytest.fixture()
def SimulationContext():
yield SimulationContext_
SimulationContext_._clear_context_cache()


@pytest.fixture
def components():
return [
Expand All @@ -39,7 +46,16 @@ def log(mocker):
return mocker.patch("vivarium.framework.engine.logger")


def test_SimulationContext_init_default(components):
def test_SimulationContext_get_sim_name(SimulationContext):
assert SimulationContext._created_simulation_contexts == set()

assert SimulationContext._get_context_name(None) == "simulation_1"
assert SimulationContext._get_context_name("foo") == "foo"

assert SimulationContext._created_simulation_contexts == {"simulation_1", "foo"}


def test_SimulationContext_init_default(SimulationContext, components):
sim = SimulationContext(components=components)

assert isinstance(sim._lifecycle, LifeCycleManager)
Expand Down Expand Up @@ -103,7 +119,27 @@ def test_SimulationContext_init_default(components):
assert isinstance(list(sim._component_manager._components)[-1], Metrics)


def test_SimulationContext_setup_default(base_config, components):
def test_SimulationContext_name_management(SimulationContext):
assert SimulationContext._created_simulation_contexts == set()

sim1 = SimulationContext()
assert sim1._name == "simulation_1"
assert SimulationContext._created_simulation_contexts == {"simulation_1"}

sim2 = SimulationContext(sim_name="foo")
assert sim2._name == "foo"
assert SimulationContext._created_simulation_contexts == {"simulation_1", "foo"}

sim3 = SimulationContext()
assert sim3._name == "simulation_3"
assert SimulationContext._created_simulation_contexts == {
"simulation_1",
"foo",
"simulation_3",
}


def test_SimulationContext_setup_default(SimulationContext, base_config, components):
sim = SimulationContext(base_config, components)
listener = [c for c in components if "listener" in c.args][0]
assert not listener.post_setup_called
Expand Down Expand Up @@ -140,7 +176,7 @@ def test_SimulationContext_setup_default(base_config, components):
assert listener.post_setup_called


def test_SimulationContext_initialize_simulants(base_config, components):
def test_SimulationContext_initialize_simulants(SimulationContext, base_config, components):
sim = SimulationContext(base_config, components)
sim.setup()
pop_size = sim.configuration.population.population_size
Expand All @@ -151,7 +187,7 @@ def test_SimulationContext_initialize_simulants(base_config, components):
assert sim._clock.time == current_time


def test_SimulationContext_step(log, base_config, components):
def test_SimulationContext_step(SimulationContext, log, base_config, components):
sim = SimulationContext(base_config, components)
sim.setup()
sim.initialize_simulants()
Expand All @@ -177,7 +213,7 @@ def test_SimulationContext_step(log, base_config, components):
assert sim._clock.time == current_time + step_size


def test_SimulationContext_finalize(base_config, components):
def test_SimulationContext_finalize(SimulationContext, base_config, components):
sim = SimulationContext(base_config, components)
listener = [c for c in components if "listener" in c.args][0]
sim.setup()
Expand All @@ -188,7 +224,7 @@ def test_SimulationContext_finalize(base_config, components):
assert listener.simulation_end_called


def test_SimulationContext_report(base_config, components):
def test_SimulationContext_report(SimulationContext, base_config, components):
sim = SimulationContext(base_config, components)
sim.setup()
sim.initialize_simulants()
Expand Down

0 comments on commit 9a077fd

Please sign in to comment.