Skip to content

Commit

Permalink
Merge pull request #36 from ihmeuw/feature/simple_disease_example
Browse files Browse the repository at this point in the history
Feature/simple disease example
  • Loading branch information
collijk committed Jul 23, 2018
2 parents d7d324f + 04043b6 commit 4c879eb
Show file tree
Hide file tree
Showing 9 changed files with 420 additions and 6 deletions.
Empty file added vivarium/examples/__init__.py
Empty file.
Empty file.
151 changes: 151 additions & 0 deletions vivarium/examples/disease_model/disease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import pandas as pd

from vivarium.framework.state_machine import State, Machine, Transition
from vivarium.framework.util import rate_to_probability
from vivarium.framework.values import list_combiner, joint_value_post_processor


class DiseaseTransition(Transition):

def __init__(self, name, rate, input_state, output_state, **kwargs):
super().__init__(input_state, output_state, probability_func=self._probability, **kwargs)
self.name = name
self.base_rate = lambda index: pd.Series(rate, index=index)

def setup(self, builder):
self.transition_rate = builder.value.register_rate_producer(f'{self.name}_rate',
source=self._risk_deleted_rate)
self.joint_population_attributable_fraction = builder.value.register_value_producer(
f'{self.name}_rate.population_attributable_fraction',
source=lambda index: [pd.Series(0, index=index)],
preferred_combiner=list_combiner,
preferred_post_processor=joint_value_post_processor)

def _probability(self, index):
effective_rate = self.transition_rate(index)
return rate_to_probability(effective_rate)

def _risk_deleted_rate(self, index):
return self.base_rate(index) * (1 - self.joint_population_attributable_fraction(index))


class DiseaseState(State):

def __init__(self, state_name, excess_mortality_rate=0, **kwargs):
super().__init__(state_name, **kwargs)
self._excess_mortality_rate = excess_mortality_rate

def setup(self, builder):
"""Performs this component's simulation setup.
Parameters
----------
builder : `engine.Builder`
Interface to several simulation tools.
"""
super().setup(builder)
self.clock = builder.time.clock()

self.excess_mortality_rate = builder.value.register_rate_producer(
f'{self.state_id}.excess_mortality_rate',
source=self.risk_deleted_excess_mortality_rate
)
self.excess_mortality_rate_paf = builder.value.register_value_producer(
f'{self.state_id}.excess_mortality_rate.population_attributable_fraction',
source=lambda index: [pd.Series(0, index=index)],
preferred_combiner=list_combiner,
preferred_post_processor=joint_value_post_processor
)

builder.value.register_value_modifier('mortality_rate', self.add_in_excess_mortality)
self.population_view = builder.population.get_view(
[self._model], query=f"alive == 'alive' and {self._model} == '{self.state_id}'")

def add_transition(self, transition_name, output, rate=1e6, **kwargs):
t = DiseaseTransition(transition_name, rate, self, output, **kwargs)
self.transition_set.append(t)
return t

def risk_deleted_excess_mortality_rate(self, index):
return pd.Series(self._excess_mortality_rate, index=index) * (1 - self.excess_mortality_rate_paf(index))

def add_in_excess_mortality(self, index, mortality_rates):
affected = self.population_view.get(index)
mortality_rates.loc[affected.index] += self.excess_mortality_rate(affected.index)

return mortality_rates


class DiseaseModel(Machine):

def __init__(self, disease, initial_state, cause_specific_mortality_rate=0., **kwargs):
super().__init__(disease, **kwargs)
self.initial_state = initial_state.state_id
self._cause_specific_mortality_rate = cause_specific_mortality_rate

def setup(self, builder):
super().setup(builder)
self.cause_specific_mortality_rate = builder.value.register_rate_producer(
f'{self.state_column}.cause_specific_mortality_rate',
source=lambda index: pd.Series(self._cause_specific_mortality_rate, index=index)
)
builder.value.register_value_modifier('mortality_rate', modifier=self.delete_cause_specific_mortality)
builder.value.register_value_modifier('metrics', modifier=self.metrics)

creates_columns = [self.state_column]
builder.population.initializes_simulants(self.on_initialize_simulants, creates_columns=creates_columns)
self.population_view = builder.population.get_view(['age', 'sex', self.state_column])

builder.event.register_listener('time_step', self.on_time_step)

def on_initialize_simulants(self, pop_data):
condition_column = pd.Series(self.initial_state, index=pop_data.index, name=self.state_column)
self.population_view.update(condition_column)

def on_time_step(self, event):
self.transition(event.index, event.time)

def delete_cause_specific_mortality(self, index, rates):
return rates - self.cause_specific_mortality_rate(index)

def metrics(self, index, metrics):
pop = self.population_view.get(index, query="alive == 'alive'")
metrics[self.state_column + '_prevalent_cases'] = len(pop[pop[self.state_column] != self.initial_state])
return metrics


class SIS_DiseaseModel:

configuration_defaults = {
'disease': {
'incidence': 0.005,
'remission': 0.05,
'excess_mortality': 0.01,
}
}

def __init__(self, disease_name):
self.name = disease_name
self.configuration_defaults = {disease_name: SIS_DiseaseModel.configuration_defaults['disease']}

def setup(self, builder):
config = builder.configuration[self.name]

susceptible_state = DiseaseState(f'susceptible_to_{self.name}')
infected_state = DiseaseState(f'infected_with_{self.name}',
excess_mortality_rate=config.excess_mortality)

susceptible_state.allow_self_transitions()
susceptible_state.add_transition(f'{infected_state.state_id}.incidence', infected_state, rate=config.incidence)
infected_state.allow_self_transitions()
infected_state.add_transition(f'{infected_state.state_id}.remission', susceptible_state, rate=config.remission)

# Reasonable approximation for short duration diseases.
case_fatality_rate = config.excess_mortality / (config.excess_mortality + config.remission)
cause_specific_mortality = config.incidence * case_fatality_rate

model = DiseaseModel(self.name,
initial_state=susceptible_state,
cause_specific_mortality_rate=cause_specific_mortality,
states=[susceptible_state, infected_state])
builder.components.add_components([model])
47 changes: 47 additions & 0 deletions vivarium/examples/disease_model/disease_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
components:
vivarium.examples.disease_model:
population:
- BasePopulation()
- Mortality()
- Observer()
disease:
- SIS_DiseaseModel('diarrhea')
risk:
- Risk('child_growth_failure')
- DirectEffect('child_growth_failure', 'infected_with_diarrhea.incidence_rate')
- DirectEffect('child_growth_failure', 'infected_with_diarrhea.excess_mortality_rate')
intervention:
- MagicWandIntervention('breastfeeding_promotion', 'child_growth_failure.proportion_exposed')

configuration:
randomness:
key_columns: ['entrance_time', 'age']
time:
start:
year: 2005
month: 7
day: 1
end:
year: 2006
month: 7
day: 1
step_size: 3 # Days
population:
population_size: 10_000
age_start: 0
age_end: 30
mortality:
mortality_rate: 0.05
life_expectancy: 80
diarrhea:
incidence: 2.5 # Approximately 2.5 cases per person per year.
remission: 42 # Approximately 6 day median recovery time
excess_mortality: 12 # Approximately 22 % of cases result in death
child_growth_failure:
proportion_exposed: 0.5
effect_of_child_growth_failure_on_infected_with_diarrhea.incidence_rate:
relative_risk: 5
effect_of_child_growth_failure_on_infected_with_diarrhea.excess_mortality_rate:
relative_risk: 5
breastfeeding_promotion:
effect_size: 0.5
24 changes: 24 additions & 0 deletions vivarium/examples/disease_model/intervention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pandas as pd

class MagicWandIntervention:

configuration_defaults = {
'intervention': {
'effect_size': 0.5,
}
}

def __init__(self, name, affected_value):
self.name = name
self.affected_value = affected_value
self.configuration_defaults = {name: MagicWandIntervention.configuration_defaults['intervention']}

def setup(self, builder):
effect_size = builder.configuration[self.name].effect_size
builder.value.register_value_modifier(self.affected_value, modifier=self.intervention_effect)
self.effect_size = builder.value.register_value_producer(
f'{self.name}.effect_size', source=lambda index: pd.Series(effect_size, index=index)
)

def intervention_effect(self, index, value):
return value * (1 - self.effect_size(index))
101 changes: 101 additions & 0 deletions vivarium/examples/disease_model/population.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import numpy as np
import pandas as pd


class BasePopulation:
configuration_defaults = {
'population': {
'age_start': 0,
'age_end': 100,
},
}

def setup(self, builder):
self.config = builder.configuration

self.with_common_random_numbers = bool(self.config.randomness.key_columns)
self.randomness = builder.randomness.get_stream('age_initialization',
for_initialization=self.with_common_random_numbers)
self.register = builder.randomness.register_simulants

columns_created = ['age', 'sex', 'alive', 'entrance_time']
builder.population.initializes_simulants(self.generate_population, creates_columns=columns_created)
self.population_view = builder.population.get_view(columns_created)
builder.event.register_listener('time_step', self.age_simulants)

def generate_population(self, pop_data):
age_start = pop_data.user_data.get('age_start', self.config.population.age_start)
age_end = pop_data.user_data.get('age_end', self.config.population.age_end)

age_draw = self.randomness.get_draw(pop_data.index)
age_window = pop_data.creation_window / pd.Timedelta(days=365) if age_start == age_end else age_end - age_start
age = age_start + age_draw * age_window

if self.with_common_random_numbers:
population = pd.DataFrame({'entrance_time': pop_data.creation_time,
'age': age.values}, index=pop_data.index)
self.register(population)
population['sex'] = self.randomness.choice(pop_data.index, ['Male', 'Female'], additional_key='sex_choice')
population['alive'] = 'alive'
else:
population = pd.DataFrame(
{'age': age.values,
'sex': self.randomness.choice(pop_data.index, ['Male', 'Female'], additional_key='sex_choice'),
'alive': pd.Series('alive', index=pop_data.index),
'entrance_time': pop_data.creation_time},
index=pop_data.index)

self.population_view.update(population)

def age_simulants(self, event):
population = self.population_view.get(event.index, query="alive == 'alive'")
population['age'] += event.step_size / pd.Timedelta(days=365)
self.population_view.update(population)


class Mortality:

configuration_defaults = {
'mortality': {
'mortality_rate': 0.01,
'life_expectancy': 80,
}
}

def setup(self, builder):
self.config = builder.configuration.mortality
self.population_view = builder.population.get_view(['alive'], query="alive == 'alive'")
self.randomness = builder.randomness.get_stream('mortality')

self.mortality_rate = builder.value.register_rate_producer('mortality_rate', source=self.base_mortality_rate)

builder.event.register_listener('time_step', self.determine_deaths)

def base_mortality_rate(self, index):
return pd.Series(self.config.mortality_rate, index=index)

def determine_deaths(self, event):
effective_rate = self.mortality_rate(event.index)
effective_probability = 1 - np.exp(-effective_rate)
draw = self.randomness.get_draw(event.index)
affected_simulants = draw < effective_probability
self.population_view.update(pd.Series('dead', index=event.index[affected_simulants]))


class Observer:

def setup(self, builder):
self.life_expectancy = builder.configuration.mortality.life_expectancy
self.population_view = builder.population.get_view(['age', 'alive'])

builder.value.register_value_modifier('metrics', self.metrics)

def metrics(self, index, metrics):

pop = self.population_view.get(index)
metrics['total_population_alive'] = len(pop[pop.alive == 'alive'])
metrics['total_population_dead'] = len(pop[pop.alive == 'dead'])

metrics['years_of_life_lost'] = (self.life_expectancy - pop.age[pop.alive == 'dead']).sum()

return metrics

0 comments on commit 4c879eb

Please sign in to comment.