Skip to content

Commit

Permalink
Changes to fwdpy11.conditional_models: (#1161)
Browse files Browse the repository at this point in the history
* SimulationStatus is now an enum.
* Simplify implementations of some of the condition monitors.
* Update vignette

Closes #934
  • Loading branch information
molpopgen committed Aug 16, 2023
1 parent e1676b2 commit f1afeb7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 76 deletions.
9 changes: 8 additions & 1 deletion doc/misc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@ updates to latest `fwdpp` version, etc.

## Next release

# New features
Breaking changes

* Refactor {class}`fwdpy11.conditional_models.SimulationStatus` as an enum.
This change makes correct use much easier.
PR {pr}`1161`.
Issue {issue}`934`.

New features

* {func}`fwdpy11.DiploidPopulation.create_from_tskit` is now able to restore
individual metadata, populating {attr}`fwdpy11.DiploidPopulation.diploid_metadata`
Expand Down
8 changes: 4 additions & 4 deletions doc/short_vignettes/incomplete_sweep.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class IncompleteSweep(object):
if pop.mutations[index].key != key:
# it is fixed or lost, neither of
# which we want
return fwdpy11.conditional_models.SimulationStatus(False, False)
return fwdpy11.conditional_models.SimulationStatus.Restart
if pop.mcounts[index] == 0:
return fwdpy11.conditional_models.SimulationStatus(True, False)
return fwdpy11.conditional_models.SimulationStatus.Restart
# Terminate the first time we see the
# variant get about a freq of 0.25
if pop.mcounts[index] / 2 / pop.N >= 0.25:
return fwdpy11.conditional_models.SimulationStatus(False, True)
return fwdpy11.conditional_models.SimulationStatus.Success
# make sure there's a valid return value
return fwdpy11.conditional_models.SimulationStatus(False, False)
return fwdpy11.conditional_models.SimulationStatus.Continue
L = 10000.0
ttl_rec_rate = 1e-5*L
Expand Down
52 changes: 11 additions & 41 deletions fwdpy11/conditional_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,10 @@ class NewMutationParameters:
)


@attr.s(frozen=True)
class SimulationStatus:
"""
The return value of a stopping condition callable.
:param should_terminate: Set to `True` if the simulation
should be terminated.
:param condition_met: Set to `True` if the stopping
condition has been met.
For examples, see implementations of :class:`GlobalFixation`
and :class:`FocalDemeFixation`.
"""

should_terminate: bool = attr.ib(
validator=attr.validators.instance_of(bool))
condition_met: bool = attr.ib(validator=attr.validators.instance_of(bool))
class SimulationStatus(Enum):
Restart = 0
Continue = 1
Success = 2


@attr.s(auto_attribs=True, kw_only=True)
Expand All @@ -247,21 +234,11 @@ class GlobalFixation(object):
def __call__(
self, pop, index: int, key: tuple
) -> SimulationStatus:
if pop.mutations[index].key != key:
# The key has changed, meaning the mutation is
# flagged for recycling.
# First, check if it is in the fixations list
for m in pop.fixations:
if m.key == key:
# It is fixed, so we are done
return SimulationStatus(True, False)
# The mutation is gone from the simulation
return SimulationStatus(True, False)
if pop.mcounts[index] == 0:
return SimulationStatus(True, False)
return SimulationStatus.Restart
if pop.mcounts[index] == 2 * pop.N:
return SimulationStatus(False, True)
return SimulationStatus(False, False)
return SimulationStatus.Success
return SimulationStatus.Continue


@attr.s(auto_attribs=True, frozen=True)
Expand All @@ -276,15 +253,8 @@ class FocalDemeFixation:

def __call__(self, pop, index, key) -> SimulationStatus:
deme_sizes = pop.deme_sizes(as_dict=True)
if pop.mutations[index].key != key:
# check for a global fixation
for m in pop.fixations:
if m.key == key:
return SimulationStatus(True, False)
# The mutation is gone from the simulation
return SimulationStatus(True, False)
if self.deme not in deme_sizes:
return SimulationStatus(False, False)
return SimulationStatus.Restart
count = 0

# TODO: we can probably do better here
Expand All @@ -299,10 +269,10 @@ def __call__(self, pop, index, key) -> SimulationStatus:
count += 1

if count == 0:
return SimulationStatus(True, False)
return SimulationStatus.Restart
if count == 2 * deme_sizes[self.deme]:
return SimulationStatus(False, True)
return SimulationStatus(False, False)
return SimulationStatus.Success
return SimulationStatus.Continue


@attr_class_to_from_dict
Expand Down
60 changes: 30 additions & 30 deletions fwdpy11/conditional_models/_track_added_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@
import fwdpy11
import numpy as np
from fwdpy11.class_decorators import attr_class_to_from_dict
from fwdpy11.conditional_models import (AddMutationFailure, AlleleCount,
AlleleCountRange, AncientSamplePolicy,
ConditionalModelOutput, EvolveOptions,
FrequencyRange, NewMutationParameters,
OutOfAttempts, SimulationStatus,
_non_negative_value)
from fwdpy11.conditional_models import (
AddMutationFailure,
AlleleCount,
AlleleCountRange,
AncientSamplePolicy,
ConditionalModelOutput,
EvolveOptions,
FrequencyRange,
NewMutationParameters,
OutOfAttempts,
SimulationStatus,
_non_negative_value,
)


@attr_class_to_from_dict
Expand All @@ -51,25 +58,23 @@ class _ProgressMonitor:

def __call__(self, pop: fwdpy11.DiploidPopulation, _) -> bool:
if (
self.status.condition_met is True
self.status == SimulationStatus.Success
and self.return_when_stopping_condition_met is True
):
return True
return self.status.should_terminate
return self.status == SimulationStatus.Restart


@attr.s(auto_attribs=True)
class _MutationPresent:
when: int = attr.ib(
validator=[attr.validators.instance_of(int), _non_negative_value]
)
until: typing.Optional[int] = attr.ib(
attr.validators.optional(int)) # type: ignore
until: typing.Optional[int] = attr.ib(attr.validators.optional(int)) # type: ignore

def __attrs_post_init__(self):
if self.until is None:
raise ValueError(
"until cannot be None if stopping_condition is None")
raise ValueError("until cannot be None if stopping_condition is None")
if self.until is not None and self.until <= self.when:
raise ValueError("until must be > when")

Expand All @@ -79,17 +84,15 @@ def is_fixed(self, pop: fwdpy11.DiploidPopulation, key: tuple) -> bool:
return True
return False

def __call__(
self, pop, index: int, key: tuple
) -> SimulationStatus:
def __call__(self, pop, index: int, key: tuple) -> SimulationStatus:
if pop.generation == self.until:
if pop.mutations[index].key != key and self.is_fixed(pop, key) is False:
return SimulationStatus(True, False)
return SimulationStatus.Restart
if pop.mcounts[index] > 0:
return SimulationStatus(False, True)
return SimulationStatus.Success
if self.is_fixed(pop, key):
return SimulationStatus(False, True)
return SimulationStatus(False, False)
return SimulationStatus.Success
return SimulationStatus.Continue


@attr.s(auto_attribs=True)
Expand All @@ -111,7 +114,7 @@ def __attrs_post_init__(self):
)

def __call__(self, pop, sampler) -> None:
if self.monitor.status.condition_met is False:
if self.monitor.status != SimulationStatus.Success:
self.monitor.status = self.criterion(
pop, self.monitor.index, self.monitor.key
)
Expand All @@ -126,7 +129,7 @@ def __call__(self, pop, sampler) -> None:

y = (
self.sampling_policy == AncientSamplePolicy.COMPLETION
and self.monitor.status.condition_met is True
and self.monitor.status == SimulationStatus.Success
and (self.until is None or pop.generation == self.until)
)

Expand Down Expand Up @@ -206,8 +209,7 @@ def _get_allele_count_range(
)
return _integer_count_details(lo, hi, mutation_parameters.deme, pop)
else:
raise TypeError(
f"unsupported type {type(mutation_parameters.frequency)}")
raise TypeError(f"unsupported type {type(mutation_parameters.frequency)}")


def _copy_pop_and_add_mutation(
Expand Down Expand Up @@ -249,12 +251,10 @@ def _copy_pop_and_add_mutation(
raise ValueError(f"when must be >= 0, got {when}")

pcopy = copy.deepcopy(pop)
pre_sweep_pdict = {k: copy.deepcopy(v)
for k, v in params.asdict().items()}
pre_sweep_pdict = {k: copy.deepcopy(v) for k, v in params.asdict().items()}
pre_sweep_pdict["simlen"] = when
pre_sweep_params = fwdpy11.ModelParams(**pre_sweep_pdict)
fwdpy11.evolvets(rng, pcopy, pre_sweep_params,
**evolvets_options.asdict())
fwdpy11.evolvets(rng, pcopy, pre_sweep_params, **evolvets_options.asdict())

count_range = _get_allele_count_range(pcopy, mutation_parameters)
for c in count_range:
Expand Down Expand Up @@ -371,7 +371,7 @@ def _track_added_mutation(
idx,
pcopy.mutations[idx].key,
return_when_stopping_condition_met,
SimulationStatus(False, False),
SimulationStatus.Continue,
),
_sampling_policy,
)
Expand All @@ -390,7 +390,6 @@ def _track_added_mutation(
or max_attempts is not None
and attempt < max_attempts
):

# NOTE: deepcopy and not copy!
pcopy_loop = copy.deepcopy(pcopy)
local_params_copy = copy.deepcopy(local_params)
Expand All @@ -410,7 +409,7 @@ def _track_added_mutation(

# The sim ended, so
# check if condition was satisfied or not
if recorder.monitor.status.condition_met is True:
if recorder.monitor.status == SimulationStatus.Success:
pop_to_return = pcopy_loop
_evolvets_options = evolvets_options_copy
local_params = local_params_copy
Expand All @@ -422,6 +421,7 @@ def _track_added_mutation(
raise OutOfAttempts()

assert pop_to_return is not None

return ConditionalModelOutput(
pop=pop_to_return,
params=local_params,
Expand Down

0 comments on commit f1afeb7

Please sign in to comment.