# Unit tests comparing the BDMS and MTBD likelihoods

In [6]:
from gcdyn import bdms, mutators, model, responses

from numpy import log, inf, sqrt, exp, max
from scipy.stats import expon

In [2]:
# Short aliases for the event labels stored as class attributes on bdms.TreeNode.
BIRTH = bdms.TreeNode._BIRTH_EVENT
DEATH = bdms.TreeNode._DEATH_EVENT
MUTATION = bdms.TreeNode._MUTATION_EVENT
SAMPLED_SURVIVOR = bdms.TreeNode._SAMPLING_EVENT
UNSAMPLED_SURVIVOR = bdms.TreeNode._SURVIVAL_EVENT


def add_event(node, event, edge_length):
    """Attach a new child carrying `event` to `node` and return the child.

    The child is a `bdms.TreeNode` created `edge_length` time units after
    `node` (``t = node.t + edge_length``, ``dist = edge_length``) and copies
    `node.x`. Its `event` attribute is set before it is added to `node`.
    """
    event_time = node.t + edge_length
    new_node = bdms.TreeNode(t=event_time, x=node.x, dist=edge_length)
    new_node.event = event

    node.add_child(new_node)
    return new_node


def log_dexp(x, rate):
    """Return the log-density at `x` of an exponential distribution with the given `rate`."""
    # scipy parameterises the exponential by its scale, i.e. the inverse rate.
    scale = 1 / rate
    return expon.logpdf(x, scale=scale)


def log_hazard_exp(x, rate):
    """Return the log survival function at `x` of an exponential with the given `rate`.

    Despite the name, this is `expon.logsf` (log-probability that the waiting
    time exceeds `x`), not the log of the hazard rate.
    """
    # scipy parameterises the exponential by its scale, i.e. the inverse rate.
    scale = 1 / rate
    return expon.logsf(x, scale=scale)

## Single edge to sample time

In [3]:
# Root with x = 10 and a single child edge of length 3 ending in a sampled survivor.
tree = bdms.TreeNode(x=10)
event = add_event(tree, SAMPLED_SURVIVOR, edge_length=3)

print(tree)


-- /-0


In [17]:
# BDMS likelihood by code
# λ is passed to log_likelihood as the birth rate; μ and γ are given to the
# model as death and mutation rates; ρ is the sampling probability.
# σ is not passed to BdmsModel; it is only used by the hand-written MTBD formulas.
λ = responses.SigmoidResponse()
μ = responses.ConstantResponse(1)
γ = responses.ConstantResponse(1)
ρ = 1
σ = 1

tree._sampled = True # We do this ourselves

# NOTE(review): init_numpy runs after the response objects are created — confirm this order is intended.
responses.init_numpy(use_jax=True)
#model.register_pytree_node_class(responses.SigmoidResponse)
m = model.BdmsModel([tree],
    death_rate = μ,
    mutation_rate = γ,
    mutator = None,
    sampling_probability = ρ
)

# Last expression of the cell: the log-likelihood as a plain Python scalar.
m.log_likelihood(birth_rate = λ).item()

-11.999727249145508

In [13]:
# BDMS likelihood by hand.
# The likelihood is the probability that no BDM event happens along this branch
# (i.e. the next event would have happened after the sample time),
# times the probability of being sampled.
# Λ is the total event rate at a node: birth + death + mutation.
Λ = lambda x: λ(x) + μ(x) + γ(x)

# log_hazard_exp is the exponential log survival function (expon.logsf),
# i.e. the log-probability of waiting longer than the edge length.
log_hazard_exp(event.dist, Λ(event)) + log(ρ)


-11.999726778529322

In [18]:
# MTBD likelihood by hand (eq'n 6)

# Λ is the total event rate at a node: birth + death + mutation.
Λ = lambda x: λ(x) + μ(x) + γ(x)

def log_f_N(event):
    """Log of the MTBD edge factor (eq'n 6) for the branch that ends at `event`.

    Reads the notebook globals λ, μ and Λ (rate responses), σ, ρ and `tree`.
    Branch times are measured backwards from the latest leaf time in `tree`.
    """
    # Evaluate each rate response once at this node.
    total_rate = Λ(event)
    birth_rate = λ(event)
    death_rate = μ(event)

    c = sqrt(total_rate**2 - 4 * death_rate * (1 - σ) * birth_rate)
    x = (-total_rate - c) / 2
    y = (-total_rate + c) / 2
    unsampled_birth = birth_rate * (1 - ρ)

    def helper(t):
        return (y + unsampled_birth) * exp(-c * t) - x - unsampled_birth

    # `max` is numpy's max (imported at the top), so it is given a list, not a generator.
    present_time = max([leaf.t for leaf in tree.get_leaves()])
    branch_start = present_time - (event.t - event.dist)
    branch_end = present_time - event.t

    return c * (branch_end - branch_start) + 2 * (
        log(helper(branch_end)) - log(helper(branch_start))
    )


# Edge factor for the single branch plus log(ρ), the log sampling probability.
log_f_N(event) + log(ρ)

-11.999727487564087

## Single edge dying and being sampled

In [19]:
# Root with x = 10 and a single child edge of length 3 ending in a death event.
tree = bdms.TreeNode(x=10)
event = add_event(tree, DEATH, edge_length=3)

print(tree)


-- /-0


In [21]:
# BDMS likelihood by code
# λ is passed to log_likelihood as the birth rate; μ and γ are given to the
# model as death and mutation rates; ρ is the sampling probability.
# σ is not passed to BdmsModel; it is only used by the hand-written MTBD formulas.
λ = responses.SigmoidResponse()
μ = responses.ConstantResponse(1)
γ = responses.ConstantResponse(1)
ρ = 1
σ = 1

tree._sampled = True # We do this ourselves

m = model.BdmsModel([tree],
    death_rate = μ,
    mutation_rate = γ,
    mutator = None,
    sampling_probability = ρ
)

# Last expression of the cell: the log-likelihood as a plain Python scalar.
m.log_likelihood(birth_rate = λ).item()

-11.999727249145508

In [22]:
# BDMS likelihood by hand.
# The likelihood is the density of the first event occurring exactly at the end of this branch
# (exponential waiting time with total rate Λ),
# times the probability that this event is a death (μ / Λ).
Λ = lambda x: λ(x) + μ(x) + γ(x)

log_dexp(event.dist, Λ(event)) + (log(μ(event)) - log(Λ(event)))

-11.999726778529322

In [23]:
# MTBD likelihood by hand (eq'n 6)

# Λ is the total event rate at a node: birth + death + mutation.
Λ = lambda x: λ(x) + μ(x) + γ(x)

def log_f_N(event):
    """Log of the MTBD edge factor (eq'n 6) for the branch that ends at `event`.

    Reads the notebook globals λ, μ and Λ (rate responses), σ, ρ and `tree`.
    Branch times are measured backwards from the latest leaf time in `tree`.
    """
    # Evaluate each rate response once at this node.
    total_rate = Λ(event)
    birth_rate = λ(event)
    death_rate = μ(event)

    c = sqrt(total_rate**2 - 4 * death_rate * (1 - σ) * birth_rate)
    x = (-total_rate - c) / 2
    y = (-total_rate + c) / 2
    unsampled_birth = birth_rate * (1 - ρ)

    def helper(t):
        return (y + unsampled_birth) * exp(-c * t) - x - unsampled_birth

    # `max` is numpy's max (imported at the top), so it is given a list, not a generator.
    present_time = max([leaf.t for leaf in tree.get_leaves()])
    branch_start = present_time - (event.t - event.dist)
    branch_end = present_time - event.t

    return c * (branch_end - branch_start) + 2 * (
        log(helper(branch_end)) - log(helper(branch_start))
    )


# Edge factor for the single branch, plus log(σ) and the log death rate at the node.
# NOTE(review): σ is presumably the probability of being sampled at death — confirm against eq'n 6.
log_f_N(event) + log(σ) + log(μ(event))

-11.999727487564087

## Edge with a type change, then eventually sampled

In [24]:
mutator = mutators.GaussianMutator(-1, 1)

# Root with x = 10, an edge of length 3 ending in a mutation event (whose node
# is then mutated by `mutator`), followed by an edge of length 4 to a sampled survivor.
tree = bdms.TreeNode(x=10)
m_event = add_event(tree, MUTATION, edge_length=3)
mutator.mutate(m_event)
s_event = add_event(m_event, SAMPLED_SURVIVOR, edge_length=4)

print(tree)


-- /- /-0


In [25]:
# BDMS likelihood by code
# λ is passed to log_likelihood as the birth rate; μ and γ are given to the
# model as death and mutation rates; ρ is the sampling probability.
# σ is not passed to BdmsModel; it is only used by the hand-written MTBD formulas.
λ = responses.SigmoidResponse()
μ = responses.ConstantResponse(1)
γ = responses.ConstantResponse(1)
ρ = 1
σ = 1

tree._sampled = True

# The model gets the same mutator object that produced the mutation on the tree.
m = model.BdmsModel([tree],
    death_rate = μ,
    mutation_rate = γ,
    mutator = mutator,
    sampling_probability = ρ
)

# Last expression of the cell: the log-likelihood as a plain Python scalar.
m.log_likelihood(birth_rate = λ).item()

-29.020896911621094

In [26]:
# BDMS likelihood by hand.
# The likelihood is the density of a mutation occurring at the end of the first edge
# (the density of any event at that time, times the probability that the event is a mutation),
# times the probability of the specific mutation that occurred,
# times the likelihood as derived in "Single edge to sample time" for the second edge.
Λ = lambda x: λ(x) + μ(x) + γ(x)

(
    log_dexp(m_event.dist, Λ(m_event)) + (log(γ(m_event)) - log(Λ(m_event)))
    + mutator.logprob(tree, m_event)
    + log_hazard_exp(s_event.dist, Λ(s_event)) + log(ρ)
)

-29.020700177600865

In [28]:
# MTBD likelihood by hand (eq'n 6)

def log_f_N(event):
    """Log of the MTBD edge factor (eq'n 6) for the branch that ends at `event`.

    Reads the notebook globals λ, μ and Λ (rate responses), σ, ρ and `tree`.
    Branch times are measured backwards from the latest leaf time in `tree`.
    """
    # Evaluate each rate response once at this node.
    total_rate = Λ(event)
    birth_rate = λ(event)
    death_rate = μ(event)

    c = sqrt(total_rate**2 - 4 * death_rate * (1 - σ) * birth_rate)
    x = (-total_rate - c) / 2
    y = (-total_rate + c) / 2
    unsampled_birth = birth_rate * (1 - ρ)

    def helper(t):
        return (y + unsampled_birth) * exp(-c * t) - x - unsampled_birth

    # `max` is numpy's max (imported at the top), so it is given a list, not a generator.
    present_time = max([leaf.t for leaf in tree.get_leaves()])
    branch_start = present_time - (event.t - event.dist)
    branch_end = present_time - event.t

    return c * (branch_end - branch_start) + 2 * (
        log(helper(branch_end)) - log(helper(branch_start))
    )


# Edge factors for both branches, plus the log mutation rate and the log-probability
# of the specific mutation at the mutation node, plus log(ρ) for the sampled tip.
(
    log_f_N(m_event) + log_f_N(s_event)
    + log(γ(m_event)) + mutator.logprob(tree, m_event)
    + log(ρ)
)

-29.020700134533033

## A bifurcation with both children sampled

In [29]:
# Root with x = 10 and an edge of length 3 ending in a birth event.
tree = bdms.TreeNode(x=10)
b_event = add_event(tree, BIRTH, edge_length=3)

# One sampled-survivor child per offspring; the i-th child's edge has length 3 + i.
s_events = [
    add_event(b_event, SAMPLED_SURVIVOR, edge_length=3 + i)
    for i in range(bdms.TreeNode._OFFSPRING_NUMBER)
]

print(tree)


      /-0
-- /-|
      \-0


In [31]:
# BDMS likelihood by code
# λ is passed to log_likelihood as the birth rate; μ and γ are given to the
# model as death and mutation rates; ρ is the sampling probability.
# σ is not passed to BdmsModel; it is only used by the hand-written MTBD formulas.
λ = responses.SigmoidResponse()
μ = responses.ConstantResponse(1)
γ = responses.ConstantResponse(1)
ρ = 1
σ = 1

tree._sampled = True

# NOTE(review): `mutator` is the object created in the earlier type-change section;
# this tree contains no mutation event.
m = model.BdmsModel([tree],
    death_rate = μ,
    mutation_rate = γ,
    mutator = mutator,
    sampling_probability = ρ
)

# Last expression of the cell: the log-likelihood as a plain Python scalar.
m.log_likelihood(birth_rate = λ).item()

-39.30598831176758

In [32]:
# BDMS likelihood by hand.
# The likelihood is the density of a birth occurring at the end of the first edge
# (the density of any event at that time, times the probability that the event is a birth),
# times the likelihood as derived in "Single edge to sample time", once for each child.
# NOTE(review): the two children are indexed explicitly, so this assumes
# bdms.TreeNode._OFFSPRING_NUMBER == 2.
Λ = lambda x: λ(x) + μ(x) + γ(x)

(
    log_dexp(b_event.dist, Λ(b_event)) + (log(λ(b_event)) - log(Λ(b_event)))
    + log_hazard_exp(s_events[0].dist, Λ(s_events[0])) + log(ρ)
    + log_hazard_exp(s_events[1].dist, Λ(s_events[1])) + log(ρ)
)

-39.30598749803913

In [33]:
# MTBD likelihood by hand (eq'n 6)

def log_f_N(event):
    """Log of the MTBD edge factor (eq'n 6) for the branch that ends at `event`.

    Reads the notebook globals λ, μ and Λ (rate responses), σ, ρ and `tree`.
    Branch times are measured backwards from the latest leaf time in `tree`.
    """
    # Evaluate each rate response once at this node.
    total_rate = Λ(event)
    birth_rate = λ(event)
    death_rate = μ(event)

    c = sqrt(total_rate**2 - 4 * death_rate * (1 - σ) * birth_rate)
    x = (-total_rate - c) / 2
    y = (-total_rate + c) / 2
    unsampled_birth = birth_rate * (1 - ρ)

    def helper(t):
        return (y + unsampled_birth) * exp(-c * t) - x - unsampled_birth

    # `max` is numpy's max (imported at the top), so it is given a list, not a generator.
    present_time = max([leaf.t for leaf in tree.get_leaves()])
    branch_start = present_time - (event.t - event.dist)
    branch_end = present_time - event.t

    return c * (branch_end - branch_start) + 2 * (
        log(helper(branch_end)) - log(helper(branch_start))
    )


# Edge factors for the parent branch and both child branches, plus the log birth rate
# at the birth node, plus log(ρ) once per sampled tip.
# NOTE(review): the log(b_event.t) term below is commented out — confirm it is not needed.
(
    log_f_N(b_event) + log_f_N(s_events[0]) + log_f_N(s_events[1])
    + log(λ(b_event)) #+ log(b_event.t)
    + 2 * log(ρ)
)

-39.30598986148834