Handle missing weight attributes in KDE frequency estimation

Throw an exception if the user has requested weighted KDE frequencies with
weights that do not match any of the tips in the given tree. This commit
explicitly checks for an empty dictionary of weights after filtering for
representation by tips and raises an exception with a meaningful error
message (instead of allowing the code to continue running and throwing a less
meaningful ValueError when no valid weights remain). This commit also adds a
unit test for this behavior.

Closes #425.
huddlej committed Dec 16, 2019
1 parent 1bb1456 commit 6e006b9246312ca2b5d6de46c8ba06bee91595a1
@@ -8,7 +8,7 @@
from Bio.Align import MultipleSeqAlignment

from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError
from .utils import read_metadata, read_node_data, write_json, get_numerical_dates

@@ -162,7 +162,12 @@ def run(args):
frequencies = kde_frequencies.estimate(tree)

frequencies = kde_frequencies.estimate(tree)
except TreeKdeFrequenciesError as e:
print("ERROR: %s" % str(e), file=sys.stderr)
return 1

# Export frequencies in auspice-format by strain name.
frequency_dict = {"pivots": list(kde_frequencies.pivots)}
@@ -12,6 +12,12 @@
log_thres = 10.0

class TreeKdeFrequenciesError(Exception):
"""Represents an error estimating KDE frequencies for a tree.

def get_pivots(observations, pivot_interval, start_date=None, end_date=None):
"""Calculate pivots for a given list of floating point observation dates and
interval between pivots.
@@ -1143,6 +1149,13 @@ def estimate(self, tree):
for key, value in self.weights.items():
self.weights[key] = value / weight_total

# Confirm that one or more weights are represented by tips in the
# tree. If there are no more weights, raise an exception because
# this likely represents a data error (either in the tree
# annotations or the weight definitions).
if len(self.weights) == 0:
raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.")

# Estimate frequencies for all tips within each weight attribute
# group.
weight_keys, weight_values = zip(*sorted(self.weights.items()))
@@ -12,7 +12,7 @@
# we assume (and assert) that this script is running from the tests/ directory

from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies
from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError
from augur.utils import json_to_tree

# Define regions to use for testing weighted frequencies.
@@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree):
# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

# Estimate weighted frequencies such that all weighted attributes are
# missing. This should raise an exception because none of the tips will
# match any of the weights and the weighting of frequencies will be
# impossible.
weights = {"fake_region_1": 1.0, "fake_region_2": 2.0}
kde_frequencies = TreeKdeFrequencies(

with pytest.raises(TreeKdeFrequenciesError):
frequencies = kde_frequencies.estimate(tree)

def test_only_tip_estimates(self, tree):
"""Test frequency estimation for only tips in a given tree.

