Skip to content

Commit

Permalink
Merge pull request #426 from nextstrain/handle-missing-weight-attributes
Browse files Browse the repository at this point in the history
Handle missing weight attributes in KDE frequency estimation
  • Loading branch information
trvrb committed Dec 17, 2019
2 parents 1bb1456 + f7bff49 commit 71cd517
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
9 changes: 7 additions & 2 deletions augur/frequencies.py
Expand Up @@ -8,7 +8,7 @@
from Bio.Align import MultipleSeqAlignment from Bio.Align import MultipleSeqAlignment


from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError
from .utils import read_metadata, read_node_data, write_json, get_numerical_dates from .utils import read_metadata, read_node_data, write_json, get_numerical_dates




Expand Down Expand Up @@ -162,7 +162,12 @@ def run(args):
include_internal_nodes=args.include_internal_nodes, include_internal_nodes=args.include_internal_nodes,
censored=args.censored censored=args.censored
) )
frequencies = kde_frequencies.estimate(tree)
try:
frequencies = kde_frequencies.estimate(tree)
except TreeKdeFrequenciesError as e:
print("ERROR: %s" % str(e), file=sys.stderr)
return 1


# Export frequencies in auspice-format by strain name. # Export frequencies in auspice-format by strain name.
frequency_dict = {"pivots": list(kde_frequencies.pivots)} frequency_dict = {"pivots": list(kde_frequencies.pivots)}
Expand Down
15 changes: 14 additions & 1 deletion augur/frequency_estimators.py
Expand Up @@ -12,6 +12,12 @@
log_thres = 10.0 log_thres = 10.0




class TreeKdeFrequenciesError(Exception):
    """Represents an error estimating KDE frequencies for a tree.

    Raised when estimation cannot proceed, for example when none of the
    provided weight attributes are represented by tips in the given tree.
    Callers can catch this exception to report the problem to the user
    instead of failing with an unrelated low-level error.
    """


def get_pivots(observations, pivot_interval, start_date=None, end_date=None): def get_pivots(observations, pivot_interval, start_date=None, end_date=None):
"""Calculate pivots for a given list of floating point observation dates and """Calculate pivots for a given list of floating point observation dates and
interval between pivots. interval between pivots.
Expand Down Expand Up @@ -1143,6 +1149,13 @@ def estimate(self, tree):
for key, value in self.weights.items(): for key, value in self.weights.items():
self.weights[key] = value / weight_total self.weights[key] = value / weight_total


# Confirm that one or more weights are represented by tips in the
# tree. If none are represented, raise an exception because this
# likely represents a data error (either in the tree annotations or
# the weight definitions).
if len(self.weights) == 0:
raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.")

# Estimate frequencies for all tips within each weight attribute # Estimate frequencies for all tips within each weight attribute
# group. # group.
weight_keys, weight_values = zip(*sorted(self.weights.items())) weight_keys, weight_values = zip(*sorted(self.weights.items()))
Expand All @@ -1153,7 +1166,7 @@ def estimate(self, tree):
# Find tips with the current weight attribute. # Find tips with the current weight attribute.
tips = [(tip.name, tip.attr["num_date"]) tips = [(tip.name, tip.attr["num_date"])
for tip in tree.get_terminals() for tip in tree.get_terminals()
if tip.attr[self.weights_attribute].lower() == weight_key and self.tip_passes_filters(tip)] if tip.attr[self.weights_attribute] == weight_key and self.tip_passes_filters(tip)]
frequencies.update(self.estimate_tip_frequencies_to_proportion(tips, proportion)) frequencies.update(self.estimate_tip_frequencies_to_proportion(tips, proportion))
else: else:
tips = [(tip.name, tip.attr["num_date"]) tips = [(tip.name, tip.attr["num_date"])
Expand Down
15 changes: 14 additions & 1 deletion tests/python3/test_frequencies.py
Expand Up @@ -12,7 +12,7 @@
# we assume (and assert) that this script is running from the tests/ directory # we assume (and assert) that this script is running from the tests/ directory
sys.path.append(str(Path(__file__).parent.parent.parent)) sys.path.append(str(Path(__file__).parent.parent.parent))


from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError
from augur.utils import json_to_tree from augur.utils import json_to_tree


# Define regions to use for testing weighted frequencies. # Define regions to use for testing weighted frequencies.
Expand Down Expand Up @@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree):
# Frequencies should sum to 1 at all pivots. # Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots)) assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))


# Estimate weighted frequencies such that all weighted attributes are
# missing. This should raise an exception because none of the tips will
# match any of the weights and the weighting of frequencies will be
# impossible.
weights = {"fake_region_1": 1.0, "fake_region_2": 2.0}
kde_frequencies = TreeKdeFrequencies(
weights=weights,
weights_attribute="region"
)

with pytest.raises(TreeKdeFrequenciesError):
frequencies = kde_frequencies.estimate(tree)

def test_only_tip_estimates(self, tree): def test_only_tip_estimates(self, tree):
"""Test frequency estimation for only tips in a given tree. """Test frequency estimation for only tips in a given tree.
""" """
Expand Down

0 comments on commit 71cd517

Please sign in to comment.