Skip to content

Commit

Permalink
Use rust back end for more of the work. (#1103)
Browse files Browse the repository at this point in the history
Result is fewer passes through graphs in Python

* bump demes-forward-capi dependency
* add back end function to get sum of ancestral sizes at time 0
* replace rint with Decimal internally
* get model duration, parental sizes at time 0 from rust
* delete lots of Python helper code from old API
  that rust tools replace
  • Loading branch information
molpopgen committed Mar 15, 2023
1 parent dd687ed commit 0af4b04
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 219 deletions.
8 changes: 7 additions & 1 deletion cpp/demes/forward_graph.cc
@@ -1,5 +1,6 @@
#include <cstdint>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <core/demes/forward_graph.hpp>

namespace py = pybind11;
Expand All @@ -11,5 +12,10 @@ init_forward_graph(py::module &m)
.def(py::init<const std::string & /*yaml*/, std::uint32_t /*burnin*/,
bool /*round_epoch_sizes*/>(),
py::arg("yaml"), py::arg("burnin"), py::arg("round_epoch_sizes"),
py::kw_only());
py::kw_only())
.def("_sum_deme_sizes_at_time_zero",
&fwdpy11_core::ForwardDemesGraph::sum_deme_sizes_at_time_zero)
.def("_model_end_time", &fwdpy11_core::ForwardDemesGraph::model_end_time)
.def("_parental_deme_sizes_at_time_zero",
&fwdpy11_core::ForwardDemesGraph::parental_deme_sizes_at_time_zero);
}
5 changes: 3 additions & 2 deletions cpptests/test_evolvets.cc
Expand Up @@ -285,9 +285,10 @@ BOOST_FIXTURE_TEST_CASE(test_basic_api_coherence_two_deme_perpetual_island_model
{
auto model = TwoDemePerpetualIslandModel();

// over-write the fixture so that the initial pop is okay
pop = fwdpy11::DiploidPopulation({100, 100}, 10.0);
fwdpy11_core::ForwardDemesGraph forward_demes_graph(model.yaml, 10);
// over-write the fixture so that the initial pop is okay
pop = fwdpy11::DiploidPopulation(forward_demes_graph.parental_deme_sizes_at_time_zero(),
10.0);

// TODO: if we put long run times in here, we get exceptions
// from the ForwardDemesGraph back end.
Expand Down
1 change: 1 addition & 0 deletions cpptests/test_forward_demes_graph.cc
Expand Up @@ -31,6 +31,7 @@ BOOST_FIXTURE_TEST_CASE(single_deme_model_with_burn_in, SingleDemeModel)
fwdpy11_core::ForwardDemesGraph g(yaml, 10);
BOOST_REQUIRE_EQUAL(g.number_of_demes(), 1);
fwdpy11::DiploidPopulation pop(100, 1.0);
BOOST_CHECK_EQUAL(pop.N, g.sum_deme_sizes_at_time_zero());
g.initialize_model(pop.generation);
auto end_time = g.model_end_time();
BOOST_REQUIRE_EQUAL(end_time, 11);
Expand Down
186 changes: 5 additions & 181 deletions fwdpy11/_functions/import_demes.py
@@ -1,16 +1,12 @@
import math
from typing import Dict, Optional, Union

import attr
import demes
import numpy as np

from fwdpy11._types.demographic_model_citation import DemographicModelCitation
from fwdpy11._types.demographic_model_details import DemographicModelDetails
from fwdpy11._types.forward_demes_graph import ForwardDemesGraph


# TODO: need type hints for dg
def demography_from_demes(
dg: Union[str, demes.Graph], burnin: int,
round_non_integer_sizes=Optional[bool],
Expand Down Expand Up @@ -51,34 +47,14 @@ def _build_from_foward_demes_graph(
The workhorse.
"""
idmap = _build_deme_id_to_int_map(fg.graph)
initial_sizes = _get_initial_deme_sizes(fg.graph, idmap)
Nref = _get_ancestral_population_size(fg.graph)

burnin_generation = int(np.rint(burnin * Nref))
model_times = _ModelTimes.from_demes_graph(fg.graph, burnin_generation)

# TODO: size_history now contains model_times, so passing
# the latter into functions is redundant.
# We should clean this up later.
# size_history = _DemeSizeHistory.from_demes_graph(
# dg, burnin, idmap, model_times)
# assert size_history.model_times is not None

# _set_initial_migration_matrix(dg, idmap, events, size_history)
# _process_all_epochs(dg, idmap, model_times, events, size_history)
# _process_migrations(dg, idmap, model_times, events, size_history)
# _process_pulses(dg, idmap, model_times, events, size_history)
# _process_admixtures(dg, dg_events, idmap,
# model_times, events, size_history)
# _process_mergers(dg, dg_events, idmap, model_times, events, size_history)
# _process_splits(dg, dg_events, idmap, model_times, events, size_history)
# _process_branches(dg, dg_events, idmap, model_times, events, size_history)

if fg.graph.doi != "None":
doi = fg.graph.doi
else:
doi = None

_initial_sizes = {i: j for i, j in enumerate(
fg._parental_deme_sizes_at_time_zero()) if j > 0}
return DemographicModelDetails(
model=fg,
name=fg.graph.description,
Expand All @@ -88,133 +64,13 @@ def _build_from_foward_demes_graph(
DOI=doi, full_citation=None, metadata=None),
metadata={
"deme_labels": {j: i for i, j in idmap.items()},
"initial_sizes": initial_sizes,
"burnin_time": burnin_generation,
"total_simulation_length": burnin_generation
+ model_times.model_duration
- 1,
"initial_sizes": _initial_sizes,
"burnin_time": fg.burnin_generation,
"total_simulation_length": fg._model_end_time() - 1
},
)


@attr.s(frozen=True, auto_attribs=True)
class _ModelTimes(object):
"""
These are in units of the deme graph
and increase backwards into the past.
"""

model_start_time: demes.demes.Time
model_end_time: demes.demes.Time
model_duration: int = attr.ib(validator=attr.validators.instance_of(int))
burnin_generation: int = attr.ib(
validator=attr.validators.instance_of(int))

@staticmethod
def from_demes_graph(dg: demes.Graph, burnin_generation: int) -> "_ModelTimes":
"""
In units of dg.time_units, obtain the following:
1. The time when the demographic model starts.
2. The time when it ends.
3. The total simulation length.
"""
# FIXME: this function isn't working well.
# For example, twodemes.yml and twodemes_one_goes_away.yml
# both break it.
oldest_deme_time = _get_most_ancient_deme_start_time(dg)
most_recent_deme_end = _get_most_recent_deme_end_time(dg)

model_start_time = oldest_deme_time
if oldest_deme_time == math.inf:
# We want to find the time of first event or
# the first demographic change, which is when
# burnin will end. To do this, get a list of
# first size change for all demes with inf
# start time, and the start time for all other
# demes, and take max of those.
ends_inf = [
d.epochs[0].end_time for d in dg.demes if d.start_time == math.inf
]
starts = [d.start_time for d in dg.demes if d.start_time != math.inf]
mig_starts = [
m.start_time for m in dg.migrations if m.start_time != math.inf
]
mig_ends = [
m.end_time for m in dg.migrations if m.start_time == math.inf]
pulse_times = [p.time for p in dg.pulses]
# The forward-time model with start with a generation 0,
# which is the earliest end point of a deme with start time
# of inf, minus 1. That definition is forwards in time, so we
# ADD one to the backwards-in-time demes info.
model_start_time = (
max(ends_inf + starts + mig_starts + mig_ends + pulse_times) + 1
)

if most_recent_deme_end != 0:
model_duration = model_start_time - most_recent_deme_end
else:
model_duration = model_start_time

return _ModelTimes(
model_start_time=model_start_time,
model_end_time=most_recent_deme_end,
model_duration=int(np.rint(model_duration)),
burnin_generation=burnin_generation,
)

def convert_time(self, demes_event_time: float) -> int:
"""
Backwards time -> forwards time
"""
if demes_event_time != math.inf:
return self.burnin_generation + int(
self.model_start_time - demes_event_time - 1
)

return 0


@attr.s(auto_attribs=True)
class _MigrationRateChange(object):
"""
Use to make registry of migration rate changes.
"""

when: int = attr.ib(
validator=[demes.demes.non_negative, attr.validators.instance_of(int)]
)
source: int = attr.ib(
validator=[demes.demes.non_negative, attr.validators.instance_of(int)]
)
destination: int = attr.ib(
validator=[demes.demes.non_negative, attr.validators.instance_of(int)]
)
rate_change: float = attr.ib(
validator=[attr.validators.instance_of(float)])
from_deme_graph: bool = attr.ib(
validator=attr.validators.instance_of(bool))


def _get_initial_deme_sizes(dg: demes.Graph, idmap: Dict) -> Dict:
"""
Build a map of a deme's integer label to its size
at the start of the simulation for all demes whose
start_time equals inf.
"""
otime = _get_most_ancient_deme_start_time(dg)
rv = dict()
for deme in dg.demes:
if deme.epochs[0].start_time == otime:
rv[idmap[deme.name]] = int(np.rint(deme.epochs[0].start_size))

if len(rv) == 0:
raise RuntimeError("could not determine initial deme sizes")

return rv


def _build_deme_id_to_int_map(dg: demes.Graph) -> Dict:
"""
Convert the string input ID to output integer values.
Expand All @@ -235,35 +91,3 @@ def _build_deme_id_to_int_map(dg: demes.Graph) -> Dict:
temp = sorted(temp, key=lambda x: -x[0])

return {j[1]: i for i, j in enumerate(temp)}


def _get_most_ancient_deme_start_time(dg: demes.Graph) -> demes.demes.Time:
return max([d.start_time for d in dg.demes])


def _get_most_recent_deme_end_time(dg: demes.Graph) -> demes.demes.Time:
return min([d.end_time for d in dg.demes])


def _get_ancestral_population_size(dg: demes.Graph) -> int:
"""
Need this for the burnin time.
If there are > 1 demes with the same most ancient start_time,
then the ancestral size is considered to be the size
of all those demes (size of ancestral metapopulation).
"""
oldest_deme_time = _get_most_ancient_deme_start_time(dg)

rv = sum(
[
int(np.rint(e.start_size))
for d in dg.demes
for e in d.epochs
if e.start_time == oldest_deme_time
]
)
if rv == 0:
raise ValueError(
"could not determinine ancestral metapopulation size")
return rv
72 changes: 38 additions & 34 deletions fwdpy11/_types/forward_demes_graph.py
@@ -1,5 +1,5 @@
import decimal
import typing
import warnings

import attr
import demes
Expand All @@ -13,6 +13,12 @@
import fwdpy11


def _round_via_decimal(value):
with decimal.localcontext() as ctx:
ctx.rounding = decimal.ROUND_HALF_UP
return int(decimal.Decimal(value).to_integral_value())


@attr.s(repr_ns="fwdpy11")
@attr_class_pickle_with_super
@attr_class_to_from_dict
Expand Down Expand Up @@ -44,49 +50,47 @@ class ForwardDemesGraph(fwdpy11._fwdpy11._ForwardDemesGraph):
burnin_is_exact: int = attr.ib()
round_non_integer_sizes: bool = attr.ib()
graph: demes.Graph
burnin_generation: int

def __attrs_post_init__(self):
from fwdpy11._functions.import_demes import _get_ancestral_population_size
self.graph = demes.loads(self.yaml)
if self.round_non_integer_sizes is False:
ForwardDemesGraph._reject_non_integer_sizes(self.graph)

ForwardDemesGraph._validate_pulses(self.graph)
Nref = _get_ancestral_population_size(self.graph)
Nref = self._get_ancestral_population_size(self.graph)
assert np.modf(Nref)[0] == 0.0
if self.burnin_is_exact is True:
burnin = self.burnin
else:
burnin = int(np.rint(self.burnin)*Nref)
burnin = self.burnin*Nref
self.burnin_generation = burnin
super(ForwardDemesGraph, self).__init__(
self.yaml, burnin, self.round_non_integer_sizes)
x = self._sum_deme_sizes_at_time_zero()
assert Nref == x, f"{Nref}, {x}, {self.yaml}"

def __get_most_ancient_deme_start_time(self, dg: demes.Graph) -> demes.demes.Time:
return max([d.start_time for d in dg.demes])

def _validate_pulses(graph: demes.Graph):
unique_pulse_times = set([np.rint(p.time) for p in graph.pulses])
for time in unique_pulse_times:
pulses = [p for p in graph.pulses if np.rint(p.time) == time]
dests = set()
for p in pulses:
if p.dest in dests:
warnings.warn(
f"multiple pulse events into deme {p.dest} at time {time}."
+ " The effect of these pulses will depend on the order "
+ "in which they are applied."
+ "To avoid unexpected behavior, "
+ "the graph can instead be structured to"
+ " introduce a new deme at this time with"
+ " the desired ancestry proportions or to specify"
+ " concurrent pulses with multiple sources.",
UserWarning)
dests.add(p.dest)

def _reject_non_integer_sizes(graph: demes.Graph):
for deme in graph.demes:
for i, epoch in enumerate(deme.epochs):
for size in [epoch.start_size, epoch.end_size]:
if np.isfinite(size) and np.modf(size)[0] != 0.0:
raise ValueError(
f"deme {deme.name} has non-integer size {size} in epoch {i}")
def _get_ancestral_population_size(self, dg: demes.Graph) -> int:
"""
Need this for the burnin time.
If there are > 1 demes with the same most ancient start_time,
then the ancestral size is considered to be the size
of all those demes (size of ancestral metapopulation).
"""
oldest_deme_time = self.__get_most_ancient_deme_start_time(dg)

rv = sum(
[
_round_via_decimal(e.start_size)
for d in dg.demes
for e in d.epochs
if e.start_time == oldest_deme_time
]
)
if rv == 0:
raise ValueError(
"could not determinine ancestral metapopulation size")
return rv

def number_of_demes(self) -> int:
return len(self.graph.demes)
Expand Down
6 changes: 6 additions & 0 deletions lib/core/demes/forward_graph.hpp
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <cstdint>
#include <string>
#include <vector>

namespace fwdpy11_core
{
Expand Down Expand Up @@ -51,5 +52,10 @@ namespace fwdpy11_core
ForwardDemesGraphDataIterator<double> offspring_cloning_rates() const;
ForwardDemesGraphDataIterator<double>
offspring_ancestry_proportions(std::size_t offspring_deme) const;

// TODO: this can become const once upstream
// gets a fn to handle this
std::uint32_t sum_deme_sizes_at_time_zero();
std::vector<std::uint32_t> parental_deme_sizes_at_time_zero() const;
};
}

0 comments on commit 0af4b04

Please sign in to comment.