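# Demo of `models.py`

This notebook simulates a few trees, evaluates several log-likelihood functions from `models` on them, and then fits the birth and death rates by maximum likelihood.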

In [1]:
from gcdyn import models, poisson, utils, mutators
import jax.numpy as np
from jax.config import config

# Enable 64-bit precision in JAX; the likelihood values shown in this notebook
# are reported as float64.
# NOTE(review): `from jax.config import config` is deprecated in newer JAX
# releases — confirm the pinned JAX version, or move to `jax.config.update(...)`
# when upgrading.
config.update("jax_enable_x64", True)

## Simulate some trees to use later

In [2]:
# Ground-truth parameters used both to simulate the trees and, later, as the
# reference point for the likelihood and fitting cells.
# (Deliberately avoiding rate parameters == constant 1, because a lot of bugs hide in this scenario)

# Single source of truth for the mutator's states, so that the transition
# matrix size below can never drift out of sync with the state space.
STATE_SPACE = (1, 2, 3)

true_parameters = {
    # NOTE(review): positional order looks like (xscale, xshift, yscale, yshift),
    # judging by the repr of the fitted response further down — confirm.
    "birth_rate": poisson.SigmoidResponse(1.0, 0.0, 2.0, 0.0),
    "death_rate": poisson.ConstantResponse(1.3),
    "mutation_rate": poisson.ConstantResponse(1.2),
    "mutator": mutators.DiscreteMutator(
        state_space=STATE_SPACE,
        # NOTE(review): no seed is passed here — confirm whether this matrix is
        # reproducible across kernel restarts.
        transition_matrix=utils.random_transition_matrix(length=len(STATE_SPACE)),
    ),
    "extant_sampling_probability": 1,
}

# Also passed as `present_time` to the likelihood functions later on.
PRESENT_TIME = 2

# Simulate two trees with a fixed seed.
trees = utils.sample_trees(
    n=2,
    t=PRESENT_TIME,
    init_x=1,
    **true_parameters,
    seed=10,
)

Notice: obtained error 'minimum number of survivors 1 not attained' 6 times.
Success: average of 54.0 nodes per tree, over 2 trees.


## Computing likelihoods under different models

In [3]:
# Reference value: log-likelihood of the simulated trees at the true parameters,
# for comparison with the other likelihood functions evaluated below.
models.naive_log_likelihood(
    **true_parameters,
    trees=trees,
)

Array(-75.8863699, dtype=float64)

In [4]:
# Workaround: instead of actually pruning, flag every tree as pruned by setting
# the private `_pruned` attribute directly.
# TODO: enable once `prune` preserves mutation event nodes
# for tree in trees:
#    tree.prune()

# NOTE(review): presumably the Stadler likelihoods below check this flag —
# confirm against `models.py`, and drop this loop once the TODO is resolved.
for tree in trees:
    tree._pruned = True

In [5]:
# Same trees and true parameters as before; this function additionally takes
# `extinct_sampling_probability` and `present_time`.
models.stadler_appx_log_likelihood(
    present_time=PRESENT_TIME,
    extinct_sampling_probability=1,
    trees=trees,
    **true_parameters,
)

Array(-75.8863699, dtype=float64)

In [6]:
# Same inputs as the appx variant, plus `dtmax`.
# NOTE(review): `dtmax` is presumably the maximum integration step — confirm in `models.py`.
models.stadler_full_log_likelihood(
    dtmax=0.01,
    present_time=PRESENT_TIME,
    extinct_sampling_probability=1,
    trees=trees,
    **true_parameters,
)

Array(-75.8863699, dtype=float64, weak_type=True)

In [7]:
# SciPy-named counterpart of the full likelihood; it takes `max_step` where the
# other variant takes `dtmax`.
# NOTE(review): presumably the two arguments play the same role — confirm in `models.py`.
models.stadler_full_log_likelihood_scipy(
    max_step=0.01,
    present_time=PRESENT_TIME,
    extinct_sampling_probability=1,
    trees=trees,
    **true_parameters,
)

Array(-75.8863699, dtype=float64)

## Maximum likelihood estimation of rate parameters

In [8]:
# Fit the birth and death rates with `stadler_appx_log_likelihood`, starting the
# optimizer at the true (simulation) values. Everything listed under
# `fixed_parameters` is passed through unchanged.
model = models.BirthDeathModel(
    log_likelihood=models.stadler_appx_log_likelihood,
    trees=trees,
    optimized_parameters={
        name: true_parameters[name] for name in ("birth_rate", "death_rate")
    },
    fixed_parameters={
        **{
            name: true_parameters[name]
            for name in ("mutation_rate", "mutator", "extant_sampling_probability")
        },
        "extinct_sampling_probability": 1,
        "present_time": PRESENT_TIME,
    },
)

model.fit()

ScipyMinimizeInfo(fun_val=Array(75.32768632, dtype=float64, weak_type=True), success=True, status=0, iter_num=22)

In [9]:
# Show the fitted values of the optimized parameters.
model.parameters

{'birth_rate': SigmoidResponse(xscale=1.936334913485592, xshift=-2.007671650593814, yscale=1.1833484109309638, yshift=0.5757177848367284),
 'death_rate': ConstantResponse(value=1.0927569207867909)}

In [10]:
# Repeat the fit with `stadler_full_log_likelihood`, again starting from the
# true (simulation) values of the birth and death rates.
# NOTE(review): `dtmax` is 0.1 here but 0.01 in the standalone evaluation of the
# same likelihood earlier — presumably coarser on purpose for speed; confirm.
model = models.BirthDeathModel(
    log_likelihood=models.stadler_full_log_likelihood,
    trees=trees,
    optimized_parameters={
        name: true_parameters[name] for name in ("birth_rate", "death_rate")
    },
    fixed_parameters={
        **{
            name: true_parameters[name]
            for name in ("mutation_rate", "mutator", "extant_sampling_probability")
        },
        "extinct_sampling_probability": 1,
        "present_time": PRESENT_TIME,
        "dtmax": 0.1,
    },
)

# This runs a little slow because it takes a while to jit the grad
# (the recorded run below reports a compile of about five minutes).
model.fit()

2023-04-21 16:08:28.601835: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_objective] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-04-21 16:11:31.456034: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5m2.854811s

********************************
[Compiling module jit_objective] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


ScipyMinimizeInfo(fun_val=Array(75.32755671, dtype=float64, weak_type=True), success=True, status=0, iter_num=22)

In [11]:
# Show the fitted values of the optimized parameters for this second fit.
model.parameters

{'birth_rate': SigmoidResponse(xscale=1.9335909694211946, xshift=-2.0019620615547145, yscale=1.1853365376056422, yshift=0.5737605752872461),
 'death_rate': ConstantResponse(value=1.09276954025032)}