Skip to content

Commit

Permalink
Use rust back end for more of the work.
Browse files Browse the repository at this point in the history
Result is fewer passes through graphs in Python

* bump demes-forward-capi dependency
* add back end function to get sum of ancestral sizes at time 0
* replace rint with Decimal internally
* get model duration, parental sizes at time 0 from rust
* delete lots of Python helper code from old API
  that rust tools replace
  • Loading branch information
molpopgen committed Mar 15, 2023
1 parent dd687ed commit 9c6d064
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 219 deletions.
8 changes: 7 additions & 1 deletion cpp/demes/forward_graph.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cstdint>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <core/demes/forward_graph.hpp>

namespace py = pybind11;
Expand All @@ -11,5 +12,10 @@ init_forward_graph(py::module &m)
.def(py::init<const std::string & /*yaml*/, std::uint32_t /*burnin*/,
bool /*round_epoch_sizes*/>(),
py::arg("yaml"), py::arg("burnin"), py::arg("round_epoch_sizes"),
py::kw_only());
py::kw_only())
.def("_sum_deme_sizes_at_time_zero",
&fwdpy11_core::ForwardDemesGraph::sum_deme_sizes_at_time_zero)
.def("_model_end_time", &fwdpy11_core::ForwardDemesGraph::model_end_time)
.def("_parental_deme_sizes_at_time_zero",
&fwdpy11_core::ForwardDemesGraph::parental_deme_sizes_at_time_zero);
}
5 changes: 3 additions & 2 deletions cpptests/test_evolvets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ BOOST_FIXTURE_TEST_CASE(test_basic_api_coherence_two_deme_perpetual_island_model
{
auto model = TwoDemePerpetualIslandModel();

// over-write the fixture so that the initial pop is okay
pop = fwdpy11::DiploidPopulation({100, 100}, 10.0);
fwdpy11_core::ForwardDemesGraph forward_demes_graph(model.yaml, 10);
// over-write the fixture so that the initial pop is okay
pop = fwdpy11::DiploidPopulation(forward_demes_graph.parental_deme_sizes_at_time_zero(),
10.0);

// TODO: if we put long run times in here, we get exceptions
// from the ForwardDemesGraph back end.
Expand Down
1 change: 1 addition & 0 deletions cpptests/test_forward_demes_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ BOOST_FIXTURE_TEST_CASE(single_deme_model_with_burn_in, SingleDemeModel)
fwdpy11_core::ForwardDemesGraph g(yaml, 10);
BOOST_REQUIRE_EQUAL(g.number_of_demes(), 1);
fwdpy11::DiploidPopulation pop(100, 1.0);
BOOST_CHECK_EQUAL(pop.N, g.sum_deme_sizes_at_time_zero());
g.initialize_model(pop.generation);
auto end_time = g.model_end_time();
BOOST_REQUIRE_EQUAL(end_time, 11);
Expand Down
186 changes: 5 additions & 181 deletions fwdpy11/_functions/import_demes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import math
from typing import Dict, Optional, Union

import attr
import demes
import numpy as np

from fwdpy11._types.demographic_model_citation import DemographicModelCitation
from fwdpy11._types.demographic_model_details import DemographicModelDetails
from fwdpy11._types.forward_demes_graph import ForwardDemesGraph


# TODO: need type hints for dg
def demography_from_demes(
dg: Union[str, demes.Graph], burnin: int,
round_non_integer_sizes=Optional[bool],
Expand Down Expand Up @@ -51,34 +47,14 @@ def _build_from_foward_demes_graph(
The workhorse.
"""
idmap = _build_deme_id_to_int_map(fg.graph)
initial_sizes = _get_initial_deme_sizes(fg.graph, idmap)
Nref = _get_ancestral_population_size(fg.graph)

burnin_generation = int(np.rint(burnin * Nref))
model_times = _ModelTimes.from_demes_graph(fg.graph, burnin_generation)

# TODO: size_history now contains model_times, so passing
# the latter into functions is redundant.
# We should clean this up later.
# size_history = _DemeSizeHistory.from_demes_graph(
# dg, burnin, idmap, model_times)
# assert size_history.model_times is not None

# _set_initial_migration_matrix(dg, idmap, events, size_history)
# _process_all_epochs(dg, idmap, model_times, events, size_history)
# _process_migrations(dg, idmap, model_times, events, size_history)
# _process_pulses(dg, idmap, model_times, events, size_history)
# _process_admixtures(dg, dg_events, idmap,
# model_times, events, size_history)
# _process_mergers(dg, dg_events, idmap, model_times, events, size_history)
# _process_splits(dg, dg_events, idmap, model_times, events, size_history)
# _process_branches(dg, dg_events, idmap, model_times, events, size_history)

if fg.graph.doi != "None":
doi = fg.graph.doi
else:
doi = None

_initial_sizes = {i: j for i, j in enumerate(
fg._parental_deme_sizes_at_time_zero()) if j > 0}
return DemographicModelDetails(
model=fg,
name=fg.graph.description,
Expand All @@ -88,133 +64,13 @@ def _build_from_foward_demes_graph(
DOI=doi, full_citation=None, metadata=None),
metadata={
"deme_labels": {j: i for i, j in idmap.items()},
"initial_sizes": initial_sizes,
"burnin_time": burnin_generation,
"total_simulation_length": burnin_generation
+ model_times.model_duration
- 1,
"initial_sizes": _initial_sizes,
"burnin_time": fg.burnin_generation,
"total_simulation_length": fg._model_end_time() - 1
},
)


@attr.s(frozen=True, auto_attribs=True)
class _ModelTimes(object):
    """
    Time bookkeeping for a demographic model.

    ``model_start_time`` and ``model_end_time`` are expressed in the
    time units of the demes graph, which increase backwards into the past.
    """

    model_start_time: demes.demes.Time
    model_end_time: demes.demes.Time
    model_duration: int = attr.ib(validator=attr.validators.instance_of(int))
    burnin_generation: int = attr.ib(
        validator=attr.validators.instance_of(int))

    @staticmethod
    def from_demes_graph(dg: demes.Graph, burnin_generation: int) -> "_ModelTimes":
        """
        Build a ``_ModelTimes`` from a demes graph.

        In units of dg.time_units, this determines:

        1. The time when the demographic model starts.
        2. The time when it ends.
        3. The total simulation length (rounded to an integer).

        ``burnin_generation`` is stored unchanged.
        """
        # FIXME: this function isn't working well.
        # For example, twodemes.yml and twodemes_one_goes_away.yml
        # both break it.
        oldest_start = _get_most_ancient_deme_start_time(dg)
        latest_end = _get_most_recent_deme_end_time(dg)

        start = oldest_start
        if oldest_start == math.inf:
            # Burn-in ends at the first event or demographic change.
            # Collect, for demes that start at inf, the end of their first
            # epoch and, for all other demes, their start time.  Do the
            # analogous thing for migrations, add the pulse times, and
            # take the maximum (i.e. the most ancient) of all of them.
            inf_deme_first_ends = []
            finite_deme_starts = []
            for deme in dg.demes:
                if deme.start_time == math.inf:
                    inf_deme_first_ends.append(deme.epochs[0].end_time)
                else:
                    finite_deme_starts.append(deme.start_time)

            finite_migration_starts = []
            inf_migration_ends = []
            for migration in dg.migrations:
                if migration.start_time == math.inf:
                    inf_migration_ends.append(migration.end_time)
                else:
                    finite_migration_starts.append(migration.start_time)

            pulse_times = [pulse.time for pulse in dg.pulses]

            # The forward-time model will start with a generation 0,
            # which is the earliest end point of a deme with start time
            # of inf, minus 1.  That definition is forwards in time, so we
            # ADD one to the backwards-in-time demes info.
            candidates = (
                inf_deme_first_ends
                + finite_deme_starts
                + finite_migration_starts
                + inf_migration_ends
                + pulse_times
            )
            start = max(candidates) + 1

        duration = start - latest_end if latest_end != 0 else start

        return _ModelTimes(
            model_start_time=start,
            model_end_time=latest_end,
            model_duration=int(np.rint(duration)),
            burnin_generation=burnin_generation,
        )

    def convert_time(self, demes_event_time: float) -> int:
        """
        Convert a backwards-in-time demes time to a forward-time generation.

        An infinite time maps to generation 0.
        """
        if demes_event_time == math.inf:
            return 0

        elapsed = int(self.model_start_time - demes_event_time - 1)
        return self.burnin_generation + elapsed


@attr.s(auto_attribs=True)
class _MigrationRateChange(object):
    """
    A single entry in a registry of migration rate changes.

    ``when``, ``source`` and ``destination`` are validated as
    non-negative ``int`` values, ``rate_change`` as a ``float`` and
    ``from_deme_graph`` as a ``bool``.
    """

    when: int = attr.ib(
        validator=[
            demes.demes.non_negative,
            attr.validators.instance_of(int),
        ]
    )
    source: int = attr.ib(
        validator=[
            demes.demes.non_negative,
            attr.validators.instance_of(int),
        ]
    )
    destination: int = attr.ib(
        validator=[
            demes.demes.non_negative,
            attr.validators.instance_of(int),
        ]
    )
    rate_change: float = attr.ib(
        validator=[attr.validators.instance_of(float)]
    )
    from_deme_graph: bool = attr.ib(
        validator=attr.validators.instance_of(bool)
    )


def _get_initial_deme_sizes(dg: demes.Graph, idmap: Dict) -> Dict:
    """
    Map integer deme labels to sizes at the start of the simulation.

    Only demes whose first epoch starts at the most ancient deme start
    time in ``dg`` are included.  Sizes are the first epoch's
    ``start_size`` rounded with ``np.rint``.

    Raises ``RuntimeError`` if no deme qualifies.
    """
    oldest = _get_most_ancient_deme_start_time(dg)
    sizes = {
        idmap[deme.name]: int(np.rint(deme.epochs[0].start_size))
        for deme in dg.demes
        if deme.epochs[0].start_time == oldest
    }

    if not sizes:
        raise RuntimeError("could not determine initial deme sizes")

    return sizes


def _build_deme_id_to_int_map(dg: demes.Graph) -> Dict:
"""
Convert the string input ID to output integer values.
Expand All @@ -235,35 +91,3 @@ def _build_deme_id_to_int_map(dg: demes.Graph) -> Dict:
temp = sorted(temp, key=lambda x: -x[0])

return {j[1]: i for i, j in enumerate(temp)}


def _get_most_ancient_deme_start_time(dg: demes.Graph) -> demes.demes.Time:
    """Return the largest ``start_time`` found among the demes of ``dg``."""
    start_times = (deme.start_time for deme in dg.demes)
    return max(start_times)


def _get_most_recent_deme_end_time(dg: demes.Graph) -> demes.demes.Time:
    """Return the smallest ``end_time`` found among the demes of ``dg``."""
    end_times = (deme.end_time for deme in dg.demes)
    return min(end_times)


def _get_ancestral_population_size(dg: demes.Graph) -> int:
    """
    Return the ancestral (meta)population size, needed for the burnin time.

    Every epoch whose ``start_time`` equals the most ancient deme start
    time contributes its ``start_size``, rounded with ``np.rint``.  If
    more than one deme shares that start time, the result is therefore
    the summed size of all of them (the ancestral metapopulation).

    Raises ``ValueError`` if the total is zero.
    """
    oldest_deme_time = _get_most_ancient_deme_start_time(dg)

    # A generator avoids building a temporary list just to sum it.
    rv = sum(
        int(np.rint(e.start_size))
        for d in dg.demes
        for e in d.epochs
        if e.start_time == oldest_deme_time
    )
    if rv == 0:
        raise ValueError(
            "could not determine ancestral metapopulation size")
    return rv
72 changes: 38 additions & 34 deletions fwdpy11/_types/forward_demes_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import decimal
import typing
import warnings

import attr
import demes
Expand All @@ -13,6 +13,12 @@
import fwdpy11


def _round_via_decimal(value):
    """
    Round ``value`` to the nearest integer, with ties going away from zero.

    The value is converted to ``decimal.Decimal`` and rounded using
    ``decimal.ROUND_HALF_UP``; the result is returned as an ``int``.
    """
    as_decimal = decimal.Decimal(value)
    rounded = as_decimal.to_integral_value(rounding=decimal.ROUND_HALF_UP)
    return int(rounded)


@attr.s(repr_ns="fwdpy11")
@attr_class_pickle_with_super
@attr_class_to_from_dict
Expand Down Expand Up @@ -44,49 +50,47 @@ class ForwardDemesGraph(fwdpy11._fwdpy11._ForwardDemesGraph):
burnin_is_exact: int = attr.ib()
round_non_integer_sizes: bool = attr.ib()
graph: demes.Graph
burnin_generation: int

def __attrs_post_init__(self):
from fwdpy11._functions.import_demes import _get_ancestral_population_size
self.graph = demes.loads(self.yaml)
if self.round_non_integer_sizes is False:
ForwardDemesGraph._reject_non_integer_sizes(self.graph)

ForwardDemesGraph._validate_pulses(self.graph)
Nref = _get_ancestral_population_size(self.graph)
Nref = self._get_ancestral_population_size(self.graph)
assert np.modf(Nref)[0] == 0.0
if self.burnin_is_exact is True:
burnin = self.burnin
else:
burnin = int(np.rint(self.burnin)*Nref)
burnin = self.burnin*Nref
self.burnin_generation = burnin
super(ForwardDemesGraph, self).__init__(
self.yaml, burnin, self.round_non_integer_sizes)
x = self._sum_deme_sizes_at_time_zero()
assert Nref == x, f"{Nref}, {x}, {self.yaml}"

def __get_most_ancient_deme_start_time(self, dg: demes.Graph) -> demes.demes.Time:
    """Return the largest ``start_time`` found among the demes of ``dg``."""
    start_times = (deme.start_time for deme in dg.demes)
    return max(start_times)

def _validate_pulses(graph: "demes.Graph"):
    """
    Warn about multiple pulses into the same deme at the same time.

    Pulse times are rounded with ``np.rint`` before being compared.  For
    every pulse whose (rounded time, destination deme) pair has already
    been seen, a ``UserWarning`` is issued, because the outcome of such
    pulses depends on the order in which they are applied.
    """
    # A single pass over the pulses; each (time, dest) pair is remembered
    # so that any repeat triggers a warning.
    seen = set()
    for p in graph.pulses:
        time = np.rint(p.time)
        key = (time, p.dest)
        if key in seen:
            warnings.warn(
                f"multiple pulse events into deme {p.dest} at time {time}."
                + " The effect of these pulses will depend on the order "
                + "in which they are applied."
                + " To avoid unexpected behavior, "
                + "the graph can instead be structured to"
                + " introduce a new deme at this time with"
                + " the desired ancestry proportions or to specify"
                + " concurrent pulses with multiple sources.",
                UserWarning)
        seen.add(key)

def _reject_non_integer_sizes(graph: demes.Graph):
    """
    Raise ``ValueError`` if any finite epoch size has a fractional part.

    Both the start and end size of every epoch of every deme are checked;
    non-finite sizes are ignored.
    """
    for deme in graph.demes:
        for index, epoch in enumerate(deme.epochs):
            for size in (epoch.start_size, epoch.end_size):
                if not np.isfinite(size):
                    continue
                fractional_part = np.modf(size)[0]
                if fractional_part != 0.0:
                    raise ValueError(
                        f"deme {deme.name} has non-integer size {size} in epoch {index}")
def _get_ancestral_population_size(self, dg: demes.Graph) -> int:
    """
    Return the ancestral (meta)population size, needed for the burnin time.

    Every epoch whose ``start_time`` equals the most ancient deme start
    time contributes its ``start_size``, rounded by ``_round_via_decimal``.
    If more than one deme shares that start time, the result is therefore
    the summed size of all of them (the ancestral metapopulation).

    Raises ``ValueError`` if the total is zero.
    """
    oldest_deme_time = self.__get_most_ancient_deme_start_time(dg)

    # A generator avoids building a temporary list just to sum it.
    rv = sum(
        _round_via_decimal(e.start_size)
        for d in dg.demes
        for e in d.epochs
        if e.start_time == oldest_deme_time
    )
    if rv == 0:
        raise ValueError(
            "could not determine ancestral metapopulation size")
    return rv

def number_of_demes(self) -> int:
    """Return the number of demes in ``self.graph``."""
    demes_in_graph = self.graph.demes
    return len(demes_in_graph)
Expand Down
6 changes: 6 additions & 0 deletions lib/core/demes/forward_graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <cstdint>
#include <string>
#include <vector>

namespace fwdpy11_core
{
Expand Down Expand Up @@ -51,5 +52,10 @@ namespace fwdpy11_core
ForwardDemesGraphDataIterator<double> offspring_cloning_rates() const;
ForwardDemesGraphDataIterator<double>
offspring_ancestry_proportions(std::size_t offspring_deme) const;

// TODO: this can become const once upstream
// gets a fn to handle this
std::uint32_t sum_deme_sizes_at_time_zero();
std::vector<std::uint32_t> parental_deme_sizes_at_time_zero() const;
};
}

0 comments on commit 9c6d064

Please sign in to comment.