-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #36 from ihmeuw/feature/simple_disease_example
Feature/simple disease example
- Loading branch information
Showing
9 changed files
with
420 additions
and
6 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import pandas as pd | ||
|
||
from vivarium.framework.state_machine import State, Machine, Transition | ||
from vivarium.framework.util import rate_to_probability | ||
from vivarium.framework.values import list_combiner, joint_value_post_processor | ||
|
||
|
||
class DiseaseTransition(Transition): | ||
|
||
def __init__(self, name, rate, input_state, output_state, **kwargs): | ||
super().__init__(input_state, output_state, probability_func=self._probability, **kwargs) | ||
self.name = name | ||
self.base_rate = lambda index: pd.Series(rate, index=index) | ||
|
||
def setup(self, builder): | ||
self.transition_rate = builder.value.register_rate_producer(f'{self.name}_rate', | ||
source=self._risk_deleted_rate) | ||
self.joint_population_attributable_fraction = builder.value.register_value_producer( | ||
f'{self.name}_rate.population_attributable_fraction', | ||
source=lambda index: [pd.Series(0, index=index)], | ||
preferred_combiner=list_combiner, | ||
preferred_post_processor=joint_value_post_processor) | ||
|
||
def _probability(self, index): | ||
effective_rate = self.transition_rate(index) | ||
return rate_to_probability(effective_rate) | ||
|
||
def _risk_deleted_rate(self, index): | ||
return self.base_rate(index) * (1 - self.joint_population_attributable_fraction(index)) | ||
|
||
|
||
class DiseaseState(State): | ||
|
||
def __init__(self, state_name, excess_mortality_rate=0, **kwargs): | ||
super().__init__(state_name, **kwargs) | ||
self._excess_mortality_rate = excess_mortality_rate | ||
|
||
def setup(self, builder): | ||
"""Performs this component's simulation setup. | ||
Parameters | ||
---------- | ||
builder : `engine.Builder` | ||
Interface to several simulation tools. | ||
""" | ||
super().setup(builder) | ||
self.clock = builder.time.clock() | ||
|
||
self.excess_mortality_rate = builder.value.register_rate_producer( | ||
f'{self.state_id}.excess_mortality_rate', | ||
source=self.risk_deleted_excess_mortality_rate | ||
) | ||
self.excess_mortality_rate_paf = builder.value.register_value_producer( | ||
f'{self.state_id}.excess_mortality_rate.population_attributable_fraction', | ||
source=lambda index: [pd.Series(0, index=index)], | ||
preferred_combiner=list_combiner, | ||
preferred_post_processor=joint_value_post_processor | ||
) | ||
|
||
builder.value.register_value_modifier('mortality_rate', self.add_in_excess_mortality) | ||
self.population_view = builder.population.get_view( | ||
[self._model], query=f"alive == 'alive' and {self._model} == '{self.state_id}'") | ||
|
||
def add_transition(self, transition_name, output, rate=1e6, **kwargs): | ||
t = DiseaseTransition(transition_name, rate, self, output, **kwargs) | ||
self.transition_set.append(t) | ||
return t | ||
|
||
def risk_deleted_excess_mortality_rate(self, index): | ||
return pd.Series(self._excess_mortality_rate, index=index) * (1 - self.excess_mortality_rate_paf(index)) | ||
|
||
def add_in_excess_mortality(self, index, mortality_rates): | ||
affected = self.population_view.get(index) | ||
mortality_rates.loc[affected.index] += self.excess_mortality_rate(affected.index) | ||
|
||
return mortality_rates | ||
|
||
|
||
class DiseaseModel(Machine): | ||
|
||
def __init__(self, disease, initial_state, cause_specific_mortality_rate=0., **kwargs): | ||
super().__init__(disease, **kwargs) | ||
self.initial_state = initial_state.state_id | ||
self._cause_specific_mortality_rate = cause_specific_mortality_rate | ||
|
||
def setup(self, builder): | ||
super().setup(builder) | ||
self.cause_specific_mortality_rate = builder.value.register_rate_producer( | ||
f'{self.state_column}.cause_specific_mortality_rate', | ||
source=lambda index: pd.Series(self._cause_specific_mortality_rate, index=index) | ||
) | ||
builder.value.register_value_modifier('mortality_rate', modifier=self.delete_cause_specific_mortality) | ||
builder.value.register_value_modifier('metrics', modifier=self.metrics) | ||
|
||
creates_columns = [self.state_column] | ||
builder.population.initializes_simulants(self.on_initialize_simulants, creates_columns=creates_columns) | ||
self.population_view = builder.population.get_view(['age', 'sex', self.state_column]) | ||
|
||
builder.event.register_listener('time_step', self.on_time_step) | ||
|
||
def on_initialize_simulants(self, pop_data): | ||
condition_column = pd.Series(self.initial_state, index=pop_data.index, name=self.state_column) | ||
self.population_view.update(condition_column) | ||
|
||
def on_time_step(self, event): | ||
self.transition(event.index, event.time) | ||
|
||
def delete_cause_specific_mortality(self, index, rates): | ||
return rates - self.cause_specific_mortality_rate(index) | ||
|
||
def metrics(self, index, metrics): | ||
pop = self.population_view.get(index, query="alive == 'alive'") | ||
metrics[self.state_column + '_prevalent_cases'] = len(pop[pop[self.state_column] != self.initial_state]) | ||
return metrics | ||
|
||
|
||
class SIS_DiseaseModel: | ||
|
||
configuration_defaults = { | ||
'disease': { | ||
'incidence': 0.005, | ||
'remission': 0.05, | ||
'excess_mortality': 0.01, | ||
} | ||
} | ||
|
||
def __init__(self, disease_name): | ||
self.name = disease_name | ||
self.configuration_defaults = {disease_name: SIS_DiseaseModel.configuration_defaults['disease']} | ||
|
||
def setup(self, builder): | ||
config = builder.configuration[self.name] | ||
|
||
susceptible_state = DiseaseState(f'susceptible_to_{self.name}') | ||
infected_state = DiseaseState(f'infected_with_{self.name}', | ||
excess_mortality_rate=config.excess_mortality) | ||
|
||
susceptible_state.allow_self_transitions() | ||
susceptible_state.add_transition(f'{infected_state.state_id}.incidence', infected_state, rate=config.incidence) | ||
infected_state.allow_self_transitions() | ||
infected_state.add_transition(f'{infected_state.state_id}.remission', susceptible_state, rate=config.remission) | ||
|
||
# Reasonable approximation for short duration diseases. | ||
case_fatality_rate = config.excess_mortality / (config.excess_mortality + config.remission) | ||
cause_specific_mortality = config.incidence * case_fatality_rate | ||
|
||
model = DiseaseModel(self.name, | ||
initial_state=susceptible_state, | ||
cause_specific_mortality_rate=cause_specific_mortality, | ||
states=[susceptible_state, infected_state]) | ||
builder.components.add_components([model]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
components: | ||
vivarium.examples.disease_model: | ||
population: | ||
- BasePopulation() | ||
- Mortality() | ||
- Observer() | ||
disease: | ||
- SIS_DiseaseModel('diarrhea') | ||
risk: | ||
- Risk('child_growth_failure') | ||
- DirectEffect('child_growth_failure', 'infected_with_diarrhea.incidence_rate') | ||
- DirectEffect('child_growth_failure', 'infected_with_diarrhea.excess_mortality_rate') | ||
intervention: | ||
- MagicWandIntervention('breastfeeding_promotion', 'child_growth_failure.proportion_exposed') | ||
|
||
configuration: | ||
randomness: | ||
key_columns: ['entrance_time', 'age'] | ||
time: | ||
start: | ||
year: 2005 | ||
month: 7 | ||
day: 1 | ||
end: | ||
year: 2006 | ||
month: 7 | ||
day: 1 | ||
step_size: 3 # Days | ||
population: | ||
population_size: 10_000 | ||
age_start: 0 | ||
age_end: 30 | ||
mortality: | ||
mortality_rate: 0.05 | ||
life_expectancy: 80 | ||
diarrhea: | ||
incidence: 2.5 # Approximately 2.5 cases per person per year. | ||
remission: 42 # Approximately 6 day median recovery time | ||
excess_mortality: 12 # Approximately 22 % of cases result in death | ||
child_growth_failure: | ||
proportion_exposed: 0.5 | ||
effect_of_child_growth_failure_on_infected_with_diarrhea.incidence_rate: | ||
relative_risk: 5 | ||
effect_of_child_growth_failure_on_infected_with_diarrhea.excess_mortality_rate: | ||
relative_risk: 5 | ||
breastfeeding_promotion: | ||
effect_size: 0.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pandas as pd | ||
|
||
class MagicWandIntervention: | ||
|
||
configuration_defaults = { | ||
'intervention': { | ||
'effect_size': 0.5, | ||
} | ||
} | ||
|
||
def __init__(self, name, affected_value): | ||
self.name = name | ||
self.affected_value = affected_value | ||
self.configuration_defaults = {name: MagicWandIntervention.configuration_defaults['intervention']} | ||
|
||
def setup(self, builder): | ||
effect_size = builder.configuration[self.name].effect_size | ||
builder.value.register_value_modifier(self.affected_value, modifier=self.intervention_effect) | ||
self.effect_size = builder.value.register_value_producer( | ||
f'{self.name}.effect_size', source=lambda index: pd.Series(effect_size, index=index) | ||
) | ||
|
||
def intervention_effect(self, index, value): | ||
return value * (1 - self.effect_size(index)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
class BasePopulation: | ||
configuration_defaults = { | ||
'population': { | ||
'age_start': 0, | ||
'age_end': 100, | ||
}, | ||
} | ||
|
||
def setup(self, builder): | ||
self.config = builder.configuration | ||
|
||
self.with_common_random_numbers = bool(self.config.randomness.key_columns) | ||
self.randomness = builder.randomness.get_stream('age_initialization', | ||
for_initialization=self.with_common_random_numbers) | ||
self.register = builder.randomness.register_simulants | ||
|
||
columns_created = ['age', 'sex', 'alive', 'entrance_time'] | ||
builder.population.initializes_simulants(self.generate_population, creates_columns=columns_created) | ||
self.population_view = builder.population.get_view(columns_created) | ||
builder.event.register_listener('time_step', self.age_simulants) | ||
|
||
def generate_population(self, pop_data): | ||
age_start = pop_data.user_data.get('age_start', self.config.population.age_start) | ||
age_end = pop_data.user_data.get('age_end', self.config.population.age_end) | ||
|
||
age_draw = self.randomness.get_draw(pop_data.index) | ||
age_window = pop_data.creation_window / pd.Timedelta(days=365) if age_start == age_end else age_end - age_start | ||
age = age_start + age_draw * age_window | ||
|
||
if self.with_common_random_numbers: | ||
population = pd.DataFrame({'entrance_time': pop_data.creation_time, | ||
'age': age.values}, index=pop_data.index) | ||
self.register(population) | ||
population['sex'] = self.randomness.choice(pop_data.index, ['Male', 'Female'], additional_key='sex_choice') | ||
population['alive'] = 'alive' | ||
else: | ||
population = pd.DataFrame( | ||
{'age': age.values, | ||
'sex': self.randomness.choice(pop_data.index, ['Male', 'Female'], additional_key='sex_choice'), | ||
'alive': pd.Series('alive', index=pop_data.index), | ||
'entrance_time': pop_data.creation_time}, | ||
index=pop_data.index) | ||
|
||
self.population_view.update(population) | ||
|
||
def age_simulants(self, event): | ||
population = self.population_view.get(event.index, query="alive == 'alive'") | ||
population['age'] += event.step_size / pd.Timedelta(days=365) | ||
self.population_view.update(population) | ||
|
||
|
||
class Mortality: | ||
|
||
configuration_defaults = { | ||
'mortality': { | ||
'mortality_rate': 0.01, | ||
'life_expectancy': 80, | ||
} | ||
} | ||
|
||
def setup(self, builder): | ||
self.config = builder.configuration.mortality | ||
self.population_view = builder.population.get_view(['alive'], query="alive == 'alive'") | ||
self.randomness = builder.randomness.get_stream('mortality') | ||
|
||
self.mortality_rate = builder.value.register_rate_producer('mortality_rate', source=self.base_mortality_rate) | ||
|
||
builder.event.register_listener('time_step', self.determine_deaths) | ||
|
||
def base_mortality_rate(self, index): | ||
return pd.Series(self.config.mortality_rate, index=index) | ||
|
||
def determine_deaths(self, event): | ||
effective_rate = self.mortality_rate(event.index) | ||
effective_probability = 1 - np.exp(-effective_rate) | ||
draw = self.randomness.get_draw(event.index) | ||
affected_simulants = draw < effective_probability | ||
self.population_view.update(pd.Series('dead', index=event.index[affected_simulants])) | ||
|
||
|
||
class Observer: | ||
|
||
def setup(self, builder): | ||
self.life_expectancy = builder.configuration.mortality.life_expectancy | ||
self.population_view = builder.population.get_view(['age', 'alive']) | ||
|
||
builder.value.register_value_modifier('metrics', self.metrics) | ||
|
||
def metrics(self, index, metrics): | ||
|
||
pop = self.population_view.get(index) | ||
metrics['total_population_alive'] = len(pop[pop.alive == 'alive']) | ||
metrics['total_population_dead'] = len(pop[pop.alive == 'dead']) | ||
|
||
metrics['years_of_life_lost'] = (self.life_expectancy - pop.age[pop.alive == 'dead']).sum() | ||
|
||
return metrics |
Oops, something went wrong.