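# Demo of `models.py`

Simulate a couple of trees, check that the naive and Stadler likelihoods agree when every lineage is sampled, and fit the birth and death responses by maximum likelihood.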

In [1]:
from gcdyn import models, poisson, utils, mutators
import jax.numpy as np
# `from jax.config import config` relied on the `jax.config` submodule, which JAX
# deprecated and later removed; `from jax import config` binds the same Config object.
from jax import config

# JAX computes in 32-bit by default; enable 64-bit so the likelihoods below are float64.
config.update("jax_enable_x64", True)

## Simulate some trees to use later

In [2]:
seed = 10
PRESENT_TIME = 2

# Ground-truth parameters used for simulation. The rates are deliberately not all
# constant 1, because that degenerate case tends to hide bugs.
true_parameters = dict(
    birth_response=poisson.SigmoidResponse(1.0, 0.0, 2.0, 0.0),
    death_response=poisson.ConstantResponse(1.3),
    mutation_response=poisson.ConstantResponse(1.2),
    mutator=mutators.DiscreteMutator(
        state_space=(1, 2, 3),
        transition_matrix=utils.random_transition_matrix(length=3, seed=seed),
    ),
    extant_sampling_probability=1,
)

# Draw two trees up to PRESENT_TIME, starting from init_x=1.
trees = utils.sample_trees(
    n=2,
    t=PRESENT_TIME,
    init_x=1,
    **true_parameters,
    seed=seed,
)

Notice: obtained error 'number of survivors 0 is less than min_survivors=1' 4 times.
Success: average of 27.5 nodes per tree, over 2 trees.


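## Computing likelihoods under different models

When every extant and extinct lineage is sampled (the ρ = σ = 1 setting), the naive, approximate Stadler, and full Stadler likelihoods should coincide. The three values below do.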

In [3]:
# Log-likelihood of the simulated trees under the true parameters, using the naive model.
models.naive_log_likelihood(trees=trees, **true_parameters)

Array(-44.24312484, dtype=float64)

In [4]:
# Don't call tree.prune() here: it also removes extinct lineages, and removing them
# is not part of the ρ=σ=1 setting that makes these models equivalent. Instead, just
# set the flag that the following likelihood functions assert on.
for tree in trees:
    tree._pruned = True

In [5]:
# Same trees and parameters under the approximate Stadler likelihood. The extinct
# sampling probability is passed explicitly because `true_parameters` does not contain it.
models.stadler_appx_log_likelihood(
    trees=trees,
    **true_parameters,
    extinct_sampling_probability=1,
    present_time=PRESENT_TIME,
)

Array(-44.24312484, dtype=float64)

In [6]:
# Same trees and parameters under the full Stadler likelihood.
# NOTE(review): `dtmax` presumably caps the step size of a numerical solver — TODO confirm.
models.stadler_full_log_likelihood(
    trees=trees,
    **true_parameters,
    extinct_sampling_probability=1,
    present_time=PRESENT_TIME,
    dtmax=0.01,
)

Array(-44.24312484, dtype=float64, weak_type=True)

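## Maximum likelihood estimation of rate parameters

Both fits below start from the true parameter values. They optimize only the birth and death responses and hold everything else fixed. With just two small trees, the estimates land far from the truth (e.g. the constant death rate comes out near 0.85 rather than 1.3).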

In [7]:
# Start the optimizer at the true values. Only the birth and death responses are
# optimized; the remaining parameters are held fixed at their true values.
model = models.BirthDeathModel(
    log_likelihood=models.stadler_appx_log_likelihood,
    trees=trees,
    optimized_parameters={
        name: true_parameters[name]
        for name in ("birth_response", "death_response")
    },
    fixed_parameters={
        **{
            name: true_parameters[name]
            for name in ("mutation_response", "mutator", "extant_sampling_probability")
        },
        "extinct_sampling_probability": 1,
        "present_time": PRESENT_TIME,
    },
)

model.fit()

ScipyMinimizeInfo(fun_val=Array(42.70447212, dtype=float64, weak_type=True), success=True, status=0, iter_num=78)

In [8]:
# Fitted values of the optimized (birth and death) parameters.
model.parameters

{'birth_response': SigmoidResponse(xscale=9.631466180159949, xshift=3.2767300177925405, yscale=10.53124412056242, yshift=1.0868446373469547),
 'death_response': ConstantResponse(value=0.8489909154295577)}

In [9]:
# Same fit as before, but with the full Stadler likelihood, again starting from the
# true values and optimizing only the birth and death responses.
model = models.BirthDeathModel(
    log_likelihood=models.stadler_full_log_likelihood,
    trees=trees,
    optimized_parameters={
        name: true_parameters[name]
        for name in ("birth_response", "death_response")
    },
    fixed_parameters={
        **{
            name: true_parameters[name]
            for name in ("mutation_response", "mutator", "extant_sampling_probability")
        },
        "extinct_sampling_probability": 1,
        "present_time": PRESENT_TIME,
        "dtmax": 0.1,
    },
)

# Expect this to run a little slowly: jitting the gradient takes a while.
model.fit()

ScipyMinimizeInfo(fun_val=Array(42.70444679, dtype=float64, weak_type=True), success=True, status=0, iter_num=73)

In [10]:
# Fitted values of the optimized (birth and death) parameters under the full likelihood.
model.parameters

{'birth_response': SigmoidResponse(xscale=8.94261956733765, xshift=3.300719959473646, yscale=10.768037542642809, yshift=1.0868504527608402),
 'death_response': ConstantResponse(value=0.8489772170201427)}