Skip to content

Commit

Permalink
New example: bacteria in daycare centers (#250)
Browse files Browse the repository at this point in the history
* Implemented the daycare example

* Fix Lotka-Volterra to accept kwargs

* Address comments to PR
  • Loading branch information
vuolleko committed Jun 13, 2018
1 parent 043c87e commit 8cbbf26
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Changelog
dev
---
- Add support for kwargs to elfi.set_client
- Added new example: inference of transmission dynamics of bacteria in daycare centers

0.7.1 (2018-04-11)
------------------
Expand Down
305 changes: 305 additions & 0 deletions elfi/examples/daycare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
"""Example of inference of transmission dynamics of bacteria in day care centers.
Treatment roughly follows:
Numminen, E., Cheng, L., Gyllenberg, M. and Corander, J.: Estimating the transmission dynamics
of Streptococcus pneumoniae from strain prevalence data, Biometrics, 69, 748-757, 2013.
"""

import logging
from functools import partial

import numpy as np

import elfi


def daycare(t1, t2, t3, n_dcc=29, n_ind=53, n_strains=33, freq_strains_commun=None,
n_obs=36, time_end=10., batch_size=1, random_state=None):
r"""Generate cross-sectional data from a stochastic variant of the SIS-model.
This function simulates the transmission dynamics of bacterial infections in daycare centers
(DCC) as described in Nummelin et al. [2013]. The observation model is however simplified to
an equal number of sampled individuals among the daycare centers.
The model is defined as a continuous-time Markov process with transition probabilities:
Pr(I_{is}(t+dt)=1 | I_{is}(t)=0) = t1 * E_s(I(t)) + t2 * P_s, if \sum_{j=1}^N_s I_{ij}(t)=0
Pr(I_{is}(t+dt)=1 | I_{is}(t)=0) = t3 * (t1 * E_s(I(t)) + t2 * P_s), otherwise
Pr(I_{is}(t+dt)=0 | I_{is}(t)=1) = \gamma
where:
I_{is}(t) is the status of carriage of strain s for individual i.
E_s(I(t)) is the probability of sampling the strain s
t1 is the rate of transmission from other children at the DCC (\beta in paper).
t2 is the rate of transmission from the community outside the DCC (\Lambda in paper).
t3 scales the rate of an infected child being infected with another strain (\theta in paper).
\gamma is the relative probability of healing from a strain.
As in the paper, \gamma=1, and the other inferred parameters are relative to it.
The system is solved using the Direct method [Gillespie, 1977].
References
----------
Numminen, E., Cheng, L., Gyllenberg, M. and Corander, J. (2013) Estimating the transmission
dynamics of Streptococcus pneumoniae from strain prevalence data, Biometrics, 69, 748-757.
Gillespie, D. T. (1977) Exact stochastic simulation of coupled chemical reactions.
The Journal of Physical Chemistry 81 (25), 2340–2361.
Parameters
----------
t1 : float or np.array
Rate of transmission from other individuals at the DCC.
t2 : float or np.array
Rate of transmission from the community outside the DCC.
t3 : float or np.array
Scaling of co-infection for individuals infected with another strain.
n_dcc : int, optional
Number of daycare centers.
n_ind : int, optional
Number of individuals in a DCC (same for all).
n_strains : int, optional
Number of bacterial strains considered.
freq_strains_commun : np.array of shape (n_strains,), optional
Prevalence of each strain in the community outside the DCC. Defaults to 0.1.
n_obs : int, optional
Number of individuals sampled from each DCC (same for all).
time_end : float, optional
The system is solved using the Direct method until all cases within the batch exceed this.
batch_size : int, optional
random_state : np.random.RandomState, optional
Returns
-------
state_obs : np.array
Observations in shape (batch_size, n_dcc, n_obs, n_strains).
"""
random_state = random_state or np.random

t1 = np.asanyarray(t1).reshape((-1, 1, 1, 1))
t2 = np.asanyarray(t2).reshape((-1, 1, 1, 1))
t3 = np.asanyarray(t3).reshape((-1, 1, 1, 1))

if freq_strains_commun is None:
freq_strains_commun = np.full(n_strains, 0.1)

prob_commun = t2 * freq_strains_commun

# the state (infection status) is a 4D tensor for computational performance
state = np.zeros((batch_size, n_dcc, n_ind, n_strains), dtype=np.bool)

# time for each DCC in the batch
time = np.zeros((batch_size, n_dcc))

n_factor = 1. / (n_ind - 1)
gamma = 1. # relative, see paper
ind_b_dcc = [np.repeat(np.arange(batch_size), n_dcc), np.tile(np.arange(n_dcc), batch_size)]

while np.any(time < time_end):
with np.errstate(divide='ignore', invalid='ignore'):
# probability of sampling a strain; in paper: E_s(I(t))
prob_strain = np.sum(state / np.sum(state, axis=3, keepdims=True),
axis=2, keepdims=True) * n_factor
prob_strain = np.where(np.isfinite(prob_strain), prob_strain, 0)

# init prob to get infected, same for all
hazards = t1 * prob_strain + prob_commun # shape (batch_size, n_dcc, 1, n_strains)

# co-infection, depends on the individual's state
hazards = np.tile(hazards, (1, 1, n_ind, 1))
any_infection = np.any(state, axis=3, keepdims=True)
hazards = np.where(any_infection, t3 * hazards, hazards)

# (relative) probability to be cured
hazards[state] = gamma

# normalize to probabilities
inv_sum_hazards = 1. / np.sum(hazards, axis=(2, 3), keepdims=True)
probs = hazards * inv_sum_hazards

# times until next transition (for each DCC in the batch)
delta_t = random_state.exponential(inv_sum_hazards[:, :, 0, 0])
time = time + delta_t

# choose transition
probs = probs.reshape((batch_size, n_dcc, -1))
cumprobs = np.cumsum(probs[:, :, :-1], axis=2)
x = random_state.uniform(size=(batch_size, n_dcc, 1))
ind_transit = np.sum(x >= cumprobs, axis=2)

# update state, need to find the correct indices first
ind_transit = ind_b_dcc + list(np.unravel_index(ind_transit.ravel(), (n_ind, n_strains)))
state[ind_transit] = np.logical_not(state[ind_transit])

# observation model: simply take the first n_obs individuals
state_obs = state[:, :, :n_obs, :]

return state_obs


def get_model(true_params=None, seed_obs=None, **kwargs):
"""Return a complete ELFI graph ready for inference.
Selection of true values, priors etc. follows the approach in
Numminen, E., Cheng, L., Gyllenberg, M. and Corander, J.: Estimating the transmission dynamics
of Streptococcus pneumoniae from strain prevalence data, Biometrics, 69, 748-757, 2013.
and
Gutmann M U, Corander J (2016). Bayesian Optimization for Likelihood-Free Inference
of Simulator-Based Statistical Models. JMLR 17(125):1−47, 2016.
Parameters
----------
true_params : list, optional
Parameters with which the observed data is generated.
seed_obs : int, optional
Seed for the observed data generation.
Returns
-------
m : elfi.ElfiModel
"""
logger = logging.getLogger()
if true_params is None:
true_params = [3.6, 0.6, 0.1]

m = elfi.ElfiModel()
y_obs = daycare(*true_params, random_state=np.random.RandomState(seed_obs), **kwargs)
sim_fn = partial(daycare, **kwargs)
priors = []
sumstats = []

priors.append(elfi.Prior('uniform', 0, 11, model=m, name='t1'))
priors.append(elfi.Prior('uniform', 0, 2, model=m, name='t2'))
priors.append(elfi.Prior('uniform', 0, 1, model=m, name='t3'))

elfi.Simulator(sim_fn, *priors, observed=y_obs, name='DCC')

sumstats.append(elfi.Summary(ss_shannon, m['DCC'], name='Shannon'))
sumstats.append(elfi.Summary(ss_strains, m['DCC'], name='n_strains'))
sumstats.append(elfi.Summary(ss_prevalence, m['DCC'], name='prevalence'))
sumstats.append(elfi.Summary(ss_prevalence_multi, m['DCC'], name='multi'))

elfi.Discrepancy(distance, *sumstats, name='d')

logger.info("Generated observations with true parameters "
"t1: %.1f, t2: %.3f, t3: %.1f, ", *true_params)

return m


def ss_shannon(data):
r"""Calculate the Shannon index of diversity of the distribution of observed strains.
H = -\sum p \log(p)
https://en.wikipedia.org/wiki/Diversity_index#Shannon_index
Parameters
----------
data : np.array of shape (batch_size, n_dcc, n_obs, n_strains)
Returns
-------
np.array of shape (batch_size, n_dcc)
"""
proportions = np.sum(data, axis=2) / data.shape[2]
shannon = -np.sum(proportions * np.log(proportions + 1e-9), axis=2) # axis 3 is now 2

return shannon


def ss_strains(data):
"""Calculate the number of different strains observed.
Parameters
----------
data : np.array of shape (batch_size, n_dcc, n_obs, n_strains)
Returns
-------
np.array of shape (batch_size, n_dcc)
"""
strain_active = np.any(data, axis=2)
n_strain_obs = np.sum(strain_active, axis=2) # axis 3 is now 2

return n_strain_obs


def ss_prevalence(data):
"""Calculate the prevalence of carriage among the observed individuals.
Parameters
----------
data : np.array of shape (batch_size, n_dcc, n_obs, n_strains)
Returns
-------
np.array of shape (batch_size, n_dcc)
"""
any_infection = np.any(data, axis=3)
n_infected = np.sum(any_infection, axis=2)

return n_infected / data.shape[2]


def ss_prevalence_multi(data):
"""Calculate the prevalence of multiple infections among the observed individuals.
Parameters
----------
data : np.array of shape (batch_size, n_dcc, n_obs, n_strains)
Returns
-------
np.array of shape (batch_size, n_dcc)
"""
n_infections = np.sum(data, axis=3)
n_multi_infections = np.sum(n_infections > 1, axis=2)

return n_multi_infections / data.shape[2]


def distance(*summaries, observed):
"""Calculate an L1-based distance between the simulated and observed summaries.
Follows the simplified single-distance approach in:
Gutmann M U, Corander J (2016). Bayesian Optimization for Likelihood-Free Inference
of Simulator-Based Statistical Models. JMLR 17(125):1−47, 2016.
Parameters
----------
*summaries : k np.arrays of shape (m, n)
observed : list of k np.arrays of shape (1, n)
Returns
-------
np.array of shape (m,)
"""
summaries = np.stack(summaries)
observed = np.stack(observed)
n_ss, _, n_dcc = summaries.shape

# scale summaries with max observed
obs_max = np.max(observed, axis=2, keepdims=True)
obs_max = np.where(obs_max == 0, 1, obs_max)
y = observed / obs_max
x = summaries / obs_max

# sort to make comparison more robust
y = np.sort(y, axis=2)
x = np.sort(x, axis=2)

# L1 norm divided by the dimension
dist = np.sum(np.abs(x - y), axis=(0, 2)) / (n_ss * n_dcc)

return dist
11 changes: 6 additions & 5 deletions elfi/examples/lotka_volterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def lotka_volterra(r1, r2, r3, prey_init=50, predator_init=100, sigma=0., n_obs=
return stock_out


def get_model(n_obs=50, true_params=None, seed_obs=None, stochastic=True):
def get_model(n_obs=50, true_params=None, seed_obs=None, **kwargs):
"""Return a complete Lotka-Volterra model in inference task.
Parameters
Expand All @@ -159,13 +159,14 @@ def get_model(n_obs=50, true_params=None, seed_obs=None, stochastic=True):
"""
logger = logging.getLogger()
simulator = partial(lotka_volterra, n_obs=n_obs)
if true_params is None:
true_params = [1.0, 0.005, 0.6, 50, 100, 10.]
true_params = [1.0, 0.005, 0.6, 50, 100, 10.]

kwargs['n_obs'] = n_obs
y_obs = lotka_volterra(*true_params, random_state=np.random.RandomState(seed_obs), **kwargs)

m = elfi.ElfiModel()
y_obs = simulator(*true_params, n_obs=n_obs, random_state=np.random.RandomState(seed_obs))
sim_fn = partial(simulator, n_obs=n_obs)
sim_fn = partial(lotka_volterra, **kwargs)
priors = []
sumstats = []

Expand Down
11 changes: 8 additions & 3 deletions tests/unit/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import elfi
from elfi.examples import bdm, bignk, gauss, gnk, lotka_volterra, ricker
from elfi.examples import bdm, bignk, gauss, gnk, lotka_volterra, ricker, daycare


def test_bdm():
Expand Down Expand Up @@ -85,6 +85,11 @@ def test_bignk(stats_summary=['ss_octile']):


def test_Lotka_Volterra():
m = lotka_volterra.get_model()
m = lotka_volterra.get_model(time_end=0.05)
rej = elfi.Rejection(m, m['d'], batch_size=10)
rej.sample(20)
rej.sample(10, quantile=0.5)

def test_daycare():
m = daycare.get_model(time_end=0.05)
rej = elfi.Rejection(m['d'], batch_size=10)
rej.sample(10, quantile=0.5)

0 comments on commit 8cbbf26

Please sign in to comment.