Permalink
Browse files

Merge pull request #224 from nextstrain/kde-frequency-node-filters

Enable filtering nodes by attributes in frequency estimation
  • Loading branch information...
trvrb committed Oct 28, 2018
2 parents f41137e + 722dac8 commit dde76589fc3ec5092ed3e5a4dd0ccb98376ded51
Showing with 61 additions and 4 deletions.
  1. +34 −4 base/frequencies.py
  2. +27 −0 tests/test_frequencies.py
View
@@ -579,7 +579,7 @@ class KdeFrequencies(object):
"""
def __init__(self, sigma_narrow=1 / 12.0, sigma_wide=3 / 12.0, proportion_wide=0.2,
pivot_frequency=1 / 12.0, start_date=None, end_date=None, weights=None, weights_attribute=None,
max_date=None, include_internal_nodes=False):
node_filters=None, max_date=None, include_internal_nodes=False):
"""Define parameters for KDE-based frequency estimation.
Args:
@@ -591,6 +591,7 @@ def __init__(self, sigma_narrow=1 / 12.0, sigma_wide=3 / 12.0, proportion_wide=0
end_date (float): end of the pivots interval
weights (dict): Numerical weights indexed by attribute values and applied to individual tips
weights_attribute (str): Attribute annotated on tips of a tree to use for weighting
node_filters (dict): Mapping of node attribute names (keys) to a list of valid values to keep
max_date (float): Maximum year beyond which tips are excluded from frequency estimation and are assigned
frequencies of zero
include_internal_nodes (bool): Whether internal (non-tip) nodes should have their frequencies estimated
@@ -606,6 +607,7 @@ def __init__(self, sigma_narrow=1 / 12.0, sigma_wide=3 / 12.0, proportion_wide=0
self.end_date = end_date
self.weights = weights
self.weights_attribute = weights_attribute
self.node_filters = node_filters
self.max_date = max_date
self.include_internal_nodes = include_internal_nodes
@@ -626,7 +628,8 @@ def get_params(self):
"weights": self.weights,
"weights_attribute": self.weights_attribute,
"max_date": self.max_date,
"include_internal_nodes": self.include_internal_nodes
"include_internal_nodes": self.include_internal_nodes,
"node_filters": self.node_filters
}
@classmethod
@@ -766,13 +769,24 @@ def estimate_frequencies(cls, tip_dates, pivots, normalize_to=1.0, **kwargs):
return normalized_freq_matrix
def tip_passes_filters(self, tip):
    """Report whether a tip satisfies every node filter on this instance.

    Each entry of ``self.node_filters`` maps an attribute name to the
    collection of values that are accepted for that attribute. A tip passes
    when, for every entry, ``tip.attr[name]`` is one of the accepted values.

    When ``self.node_filters`` is None, no filtering is configured and every
    tip passes.
    """
    filters = self.node_filters
    if filters is None:
        return True

    # Evaluate every filter up front (rather than lazily) so that a lookup
    # of an attribute missing from ``tip.attr`` raises for any filter key.
    outcomes = [tip.attr[name] in accepted for name, accepted in filters.items()]
    return all(outcomes)
def estimate_frequencies_for_tree(self, tree):
"""Estimate frequencies for all nodes in a tree across the given pivots.
"""
clade_frequencies = {}
# Collect dates for tips.
tips = [(tip.clade, tip.attr["num_date"]) for tip in tree.get_terminals()]
tips = [(tip.clade, tip.attr["num_date"])
for tip in tree.get_terminals()
if self.tip_passes_filters(tip)]
tips = np.array(sorted(tips, key=lambda row: row[1]))
clades = tips[:, 0].astype(int)
tip_dates = tips[:, 1].astype(float)
@@ -794,6 +808,11 @@ def estimate_frequencies_for_tree(self, tree):
for clade in clades:
clade_frequencies[clade] = normalized_freq_matrix[clade_to_index[clade]]
# Assign zero frequencies to any tips that were filtered out of the frequency estimation.
for tip in tree.get_terminals():
if not tip.clade in clade_frequencies:
clade_frequencies[tip.clade] = np.zeros_like(self.pivots)
if self.include_internal_nodes:
for node in tree.find_clades(order="postorder"):
if not node.is_terminal():
@@ -815,7 +834,13 @@ def estimate_weighted_frequencies_for_tree(self, tree):
# Find tips with the current weight attribute.
tips = [(tip.clade, tip.attr["num_date"])
for tip in tree.get_terminals()
if tip.attr[self.weights_attribute] == weight_key]
if tip.attr[self.weights_attribute] == weight_key and self.tip_passes_filters(tip)]
# If none of the tips pass the given node filters, do not try to
# normalize tip frequencies.
if len(tips) == 0:
continue
tips = np.array(sorted(tips, key=lambda row: row[1]))
clades = tips[:, 0].astype(int)
tip_dates = tips[:, 1].astype(float)
@@ -837,6 +862,11 @@ def estimate_weighted_frequencies_for_tree(self, tree):
for clade in clades:
clade_frequencies[clade] = normalized_freq_matrix[clade_to_index[clade]]
# Assign zero frequencies to any tips that were filtered out of the frequency estimation.
for tip in tree.get_terminals():
if not tip.clade in clade_frequencies:
clade_frequencies[tip.clade] = np.zeros_like(self.pivots)
if self.include_internal_nodes:
for node in tree.find_clades(order="postorder"):
if not node.is_terminal():
View
@@ -176,6 +176,32 @@ def test_censored_frequencies(self, tree):
for tip in tree.get_terminals()
if tip.attr["num_date"] <= max_date])
def test_node_filter(self, tree):
    """Check that node filters zero out the frequencies of excluded tips.

    Frequencies are estimated with a filter that keeps only tips whose
    "region" attribute is in the requested list. Every tip must still get an
    entry in the result; kept tips must have non-zero total frequency and all
    other tips must have a total frequency of exactly zero.
    """
    # Keep only tips from a single region.
    regions = ["china"]
    kde_frequencies = KdeFrequencies(node_filters={"region": regions})
    frequencies = kde_frequencies.estimate(tree)

    # Split the tips by whether the filter keeps them.
    kept_tips = [tip for tip in tree.get_terminals()
                 if tip.attr["region"] in regions]
    dropped_tips = [tip for tip in tree.get_terminals()
                    if tip.attr["region"] not in regions]

    # Filtering must not remove any tip from the output.
    has_estimate = [tip.clade in frequencies for tip in tree.get_terminals()]
    assert all(has_estimate)

    # Tips from the requested region carry some frequency.
    kept_is_positive = [frequencies[tip.clade].sum() > 0 for tip in kept_tips]
    assert all(kept_is_positive)

    # Tips from any other region are zero at every pivot.
    dropped_is_zero = [frequencies[tip.clade].sum() == 0 for tip in dropped_tips]
    assert all(dropped_is_zero)
def test_export_with_frequencies(self, tree):
"""Test frequencies export to JSON when frequencies have been estimated.
"""
@@ -199,6 +225,7 @@ def test_export_without_frequencies(self):
assert "params" in frequencies_json
assert kde_frequencies.pivot_frequency == frequencies_json["params"]["pivot_frequency"]
assert "node_filters" in frequencies_json["params"]
assert "data" not in frequencies_json
def test_import(self, tree, tmpdir):

0 comments on commit dde7658

Please sign in to comment.