Skip to content

Commit

Permalink
Refactor handling of deme sizes
Browse files Browse the repository at this point in the history
* bump demes-forward-capi dependency
* add back end function to get sum of ancestral sizes at time 0
* replace rint with Decimal internally
  • Loading branch information
molpopgen committed Mar 14, 2023
1 parent 2ebe9e0 commit 0b576f3
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 41 deletions.
4 changes: 3 additions & 1 deletion cpp/demes/forward_graph.cc
Expand Up @@ -11,5 +11,7 @@ init_forward_graph(py::module &m)
.def(py::init<const std::string & /*yaml*/, std::uint32_t /*burnin*/,
bool /*round_epoch_sizes*/>(),
py::arg("yaml"), py::arg("burnin"), py::arg("round_epoch_sizes"),
py::kw_only());
py::kw_only())
.def("_sum_deme_sizes_at_time_zero",
&fwdpy11_core::ForwardDemesGraph::sum_deme_sizes_at_time_zero);
}
1 change: 1 addition & 0 deletions cpptests/test_forward_demes_graph.cc
Expand Up @@ -31,6 +31,7 @@ BOOST_FIXTURE_TEST_CASE(single_deme_model_with_burn_in, SingleDemeModel)
fwdpy11_core::ForwardDemesGraph g(yaml, 10);
BOOST_REQUIRE_EQUAL(g.number_of_demes(), 1);
fwdpy11::DiploidPopulation pop(100, 1.0);
BOOST_CHECK_EQUAL(pop.N, g.sum_deme_sizes_at_time_zero());
g.initialize_model(pop.generation);
auto end_time = g.model_end_time();
BOOST_REQUIRE_EQUAL(end_time, 11);
Expand Down
18 changes: 12 additions & 6 deletions fwdpy11/_functions/import_demes.py
@@ -1,16 +1,21 @@
import decimal
import math
from typing import Dict, Optional, Union

import attr
import demes
import numpy as np

from fwdpy11._types.demographic_model_citation import DemographicModelCitation
from fwdpy11._types.demographic_model_details import DemographicModelDetails
from fwdpy11._types.forward_demes_graph import ForwardDemesGraph


# TODO: need type hints for dg
def _round_via_decimal(value):
with decimal.localcontext() as ctx:
ctx.rounding = decimal.ROUND_HALF_UP
return int(decimal.Decimal(value).to_integral_value())


def demography_from_demes(
dg: Union[str, demes.Graph], burnin: int,
round_non_integer_sizes=Optional[bool],
Expand Down Expand Up @@ -54,7 +59,7 @@ def _build_from_foward_demes_graph(
initial_sizes = _get_initial_deme_sizes(fg.graph, idmap)
Nref = _get_ancestral_population_size(fg.graph)

burnin_generation = int(np.rint(burnin * Nref))
burnin_generation = _round_via_decimal(burnin*Nref)
model_times = _ModelTimes.from_demes_graph(fg.graph, burnin_generation)

# TODO: size_history now contains model_times, so passing
Expand Down Expand Up @@ -160,7 +165,7 @@ def from_demes_graph(dg: demes.Graph, burnin_generation: int) -> "_ModelTimes":
return _ModelTimes(
model_start_time=model_start_time,
model_end_time=most_recent_deme_end,
model_duration=int(np.rint(model_duration)),
model_duration=_round_via_decimal(model_duration),
burnin_generation=burnin_generation,
)

Expand Down Expand Up @@ -207,7 +212,8 @@ def _get_initial_deme_sizes(dg: demes.Graph, idmap: Dict) -> Dict:
rv = dict()
for deme in dg.demes:
if deme.epochs[0].start_time == otime:
rv[idmap[deme.name]] = int(np.rint(deme.epochs[0].start_size))
rv[idmap[deme.name]] = _round_via_decimal(
deme.epochs[0].start_size)

if len(rv) == 0:
raise RuntimeError("could not determine initial deme sizes")
Expand Down Expand Up @@ -257,7 +263,7 @@ def _get_ancestral_population_size(dg: demes.Graph) -> int:

rv = sum(
[
int(np.rint(e.start_size))
_round_via_decimal(e.start_size)
for d in dg.demes
for e in d.epochs
if e.start_time == oldest_deme_time
Expand Down
36 changes: 3 additions & 33 deletions fwdpy11/_types/forward_demes_graph.py
@@ -1,5 +1,4 @@
import typing
import warnings

import attr
import demes
Expand Down Expand Up @@ -48,45 +47,16 @@ class ForwardDemesGraph(fwdpy11._fwdpy11._ForwardDemesGraph):
def __attrs_post_init__(self):
from fwdpy11._functions.import_demes import _get_ancestral_population_size
self.graph = demes.loads(self.yaml)
if self.round_non_integer_sizes is False:
ForwardDemesGraph._reject_non_integer_sizes(self.graph)

ForwardDemesGraph._validate_pulses(self.graph)
Nref = _get_ancestral_population_size(self.graph)
assert np.modf(Nref)[0] == 0.0
if self.burnin_is_exact is True:
burnin = self.burnin
else:
burnin = int(np.rint(self.burnin)*Nref)
burnin = self.burnin*Nref
super(ForwardDemesGraph, self).__init__(
self.yaml, burnin, self.round_non_integer_sizes)

def _validate_pulses(graph: demes.Graph):
unique_pulse_times = set([np.rint(p.time) for p in graph.pulses])
for time in unique_pulse_times:
pulses = [p for p in graph.pulses if np.rint(p.time) == time]
dests = set()
for p in pulses:
if p.dest in dests:
warnings.warn(
f"multiple pulse events into deme {p.dest} at time {time}."
+ " The effect of these pulses will depend on the order "
+ "in which they are applied."
+ "To avoid unexpected behavior, "
+ "the graph can instead be structured to"
+ " introduce a new deme at this time with"
+ " the desired ancestry proportions or to specify"
+ " concurrent pulses with multiple sources.",
UserWarning)
dests.add(p.dest)

def _reject_non_integer_sizes(graph: demes.Graph):
for deme in graph.demes:
for i, epoch in enumerate(deme.epochs):
for size in [epoch.start_size, epoch.end_size]:
if np.isfinite(size) and np.modf(size)[0] != 0.0:
raise ValueError(
f"deme {deme.name} has non-integer size {size} in epoch {i}")
x = self._sum_deme_sizes_at_time_zero()
assert Nref == x, f"{Nref}, {x}, {self.yaml}"

def number_of_demes(self) -> int:
return len(self.graph.demes)
Expand Down
4 changes: 4 additions & 0 deletions lib/core/demes/forward_graph.hpp
Expand Up @@ -51,5 +51,9 @@ namespace fwdpy11_core
ForwardDemesGraphDataIterator<double> offspring_cloning_rates() const;
ForwardDemesGraphDataIterator<double>
offspring_ancestry_proportions(std::size_t offspring_deme) const;

// TODO: this can become const once upstream
// gets a fn to handle this
std::uint32_t sum_deme_sizes_at_time_zero();
};
}
16 changes: 16 additions & 0 deletions lib/demes/forward_graph.cc
@@ -1,3 +1,4 @@
#include <limits>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
Expand Down Expand Up @@ -282,4 +283,19 @@ namespace fwdpy11_core
throw_if_null(begin, __FILE__, __LINE__);
return ForwardDemesGraphDataIterator<double>{begin, begin + number_of_demes()};
}

std::uint32_t
ForwardDemesGraph::sum_deme_sizes_at_time_zero()
{
std::int32_t status;
auto rv
= demes_forward_graph_sum_sizes_at_time_zero(&status, pimpl->graph.get());
pimpl->handle_error_code(status);
if (rv >= static_cast<double>(std::numeric_limits<std::uint32_t>::max()))
{
throw std::runtime_error(
"sum of sizes at time zero is too large for integer type");
}
return rv;
}
}
2 changes: 1 addition & 1 deletion rust/fp11rust/Cargo.toml
Expand Up @@ -16,5 +16,5 @@ panic = "abort"
strip = true

[dependencies]
demes-forward-capi = {version="0.4.0-alpha.0"}
demes-forward-capi = {version="0.4.0-alpha.1"}
libc = "~0.2"
31 changes: 31 additions & 0 deletions rust/fp11rust/src/demes_capi_bridge.rs
Expand Up @@ -147,3 +147,34 @@ pub unsafe extern "C" fn demes_forward_graph_model_end_time(
) -> f64 {
forward_graph_model_end_time(status, graph)
}

// Below are functions defined only for fwdpy11.
// These make use of demes_forward_capi.

#[no_mangle]
pub unsafe extern "C" fn demes_forward_graph_sum_sizes_at_time_zero(
status: *mut i32,
graph: *mut OpaqueForwardGraph,
) -> f64 {
if forward_graph_is_error_state(graph) {
*status = -1;
return f64::NAN;
}
forward_graph_update_state(0.0, graph);
if forward_graph_is_error_state(graph) {
*status = -1;
return f64::NAN;
}
let ptr = forward_graph_parental_deme_sizes(graph, status);
if *status < 0 {
return f64::NAN;
}
assert!(!ptr.is_null());
let num_demes = forward_graph_number_of_demes(graph);
if num_demes < 0 {
return f64::NAN;
}
let size_slice = std::slice::from_raw_parts(ptr, num_demes as usize);
assert_eq!(size_slice.len(), num_demes as usize);
size_slice.iter().sum()
}

0 comments on commit 0b576f3

Please sign in to comment.