Skip to content

Commit

Permalink
support for recording samples
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen committed Nov 2, 2021
1 parent ed484cf commit 2da92e2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
8 changes: 8 additions & 0 deletions fwdpy11/conditional_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# You should have received a copy of the GNU General Public License
# along with fwdpy11. If not, see <http://www.gnu.org/licenses/>.
#
from enum import Enum

import attr
import fwdpy11
from fwdpy11.class_decorators import attr_class_to_from_dict
Expand All @@ -25,6 +27,12 @@ class AddMutationFailure(Exception):
pass


class AncientSamplePolicy(Enum):
NEVER = 0
DURATION = 1
COMPLETION = 2


class GlobalFixation(object):
def __call__(self, pop: fwdpy11.DiploidPopulation, index: int, key: tuple) -> bool:
if pop.mutations[index].key != key:
Expand Down
4 changes: 3 additions & 1 deletion fwdpy11/conditional_models/_selective_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import typing

import fwdpy11
from fwdpy11.conditional_models import EvolveOptions
from fwdpy11.conditional_models import AncientSamplePolicy, EvolveOptions

from ._track_variant import _track_variant

Expand All @@ -40,6 +40,7 @@ def _selective_sweep(
*,
when: typing.Optional[int] = None,
max_attempts: typing.Optional[int] = None,
sampling_policy: typing.Optional[AncientSamplePolicy] = None,
evolvets_options: typing.Optional[EvolveOptions] = None,
) -> typing.Tuple[fwdpy11._types.DiploidPopulation, int, int]:

Expand All @@ -53,4 +54,5 @@ def _selective_sweep(
duration=None,
max_attempts=max_attempts,
evolvets_options=evolvets_options,
sampling_policy=sampling_policy,
)
36 changes: 32 additions & 4 deletions fwdpy11/conditional_models/_track_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
#
import copy
import typing
from dataclasses import dataclass

import attr
import fwdpy11
import numpy as np
from fwdpy11.class_decorators import attr_class_to_from_dict
from fwdpy11.conditional_models import AddMutationFailure, EvolveOptions
from fwdpy11.conditional_models import (AddMutationFailure,
AncientSamplePolicy, EvolveOptions)


@attr_class_to_from_dict
Expand All @@ -45,16 +46,33 @@ def __call__(self, pop: fwdpy11.DiploidPopulation, _) -> bool:

@attr.s(auto_attribs=True)
class _Recorder:
when: typing.Optional[int]
duration: typing.Optional[int]
criterion: typing.Callable[
[fwdpy11.DiploidPopulation, int, typing.Tuple[float, float, int]], bool
]
monitor: _ProgressMonitor
sampling_policy: AncientSamplePolicy

def __call__(self, pop, _) -> None:
def __attrs_post_init__(self):
if self.sampling_policy != AncientSamplePolicy.NEVER:
if self.when is None or self.when < 0:
raise ValueError(
f"sampling policy is {self.sampling_policy} and when is {self.when}"
)

def __call__(self, pop, sampler) -> None:
self.monitor.finished = self.criterion(
pop, self.monitor.index, self.monitor.key
)

if self.sampling_policy == AncientSamplePolicy.DURATION or (
self.sampling_policy == AncientSamplePolicy.COMPLETION
and self.monitor.finished is True
):
# Record all alive individuals
sampler.assign(np.array([i for i in range(pop.N)], dtype=np.uint32))


def _copy_pop_and_add_mutation(
rng,
Expand Down Expand Up @@ -129,6 +147,7 @@ def _track_variant(
when: typing.Optional[int] = None,
duration: typing.Optional[int] = None,
max_attempts: typing.Optional[int] = None,
sampling_policy: typing.Optional[AncientSamplePolicy] = None,
evolvets_options: typing.Optional[EvolveOptions] = None,
) -> typing.Tuple[fwdpy11._types.DiploidPopulation, int, int]:

Expand Down Expand Up @@ -158,8 +177,17 @@ def _track_variant(
if idx is None:
raise AddMutationFailure()

if sampling_policy is None:
_sampling_policy = AncientSamplePolicy.NEVER
else:
_sampling_policy = sampling_policy

recorder = _Recorder(
stopping_condition, _ProgressMonitor(idx, pcopy.mutations[idx].key, False)
when,
duration,
stopping_condition,
_ProgressMonitor(idx, pcopy.mutations[idx].key, False),
_sampling_policy,
)

internal_options = _InternalSweepEvolveOptions()
Expand Down

0 comments on commit 2da92e2

Please sign in to comment.