Permalink
Browse files

Merge pull request #207 from nextstrain/fix_fitness_intervals

Use frequency time intervals for fitness model
  • Loading branch information...
huddlej committed Aug 24, 2018
2 parents f61cb0b + 1705c48 commit 86a89697856c9f26d7950e951dd99871f292a960
Showing with 5 additions and 13 deletions.
  1. +2 −9 base/fitness_model.py
  2. +1 −1 base/frequencies.py
  3. +1 −2 base/process.py
  4. +1 −1 tests/test_frequencies.py
View
@@ -100,15 +100,14 @@ def sum_of_squared_errors(observed_freq, predicted_freq):
class fitness_model(object):
def __init__(self, tree, frequencies, time_interval, predictor_input, censor_frequencies=True,
def __init__(self, tree, frequencies, predictor_input, censor_frequencies=True,
pivot_spacing=1.0 / 12, verbose=0, enforce_positive_predictors=True, predictor_kwargs=None,
cost_function=sum_of_squared_errors, **kwargs):
"""
Args:
tree (Bio.Phylo): an annotated tree for which a fitness model is to be determined
frequencies (KdeFrequencies): a frequency estimator and its parameters
time_interval:
predictor_input: a list of predictors to fit or dict of predictors to coefficients / std deviations
censor_frequencies (bool): whether frequencies should censor future data or not
pivot_spacing:
@@ -136,13 +135,6 @@ def __init__(self, tree, frequencies, time_interval, predictor_input, censor_fre
self.time_window = kwargs.get("time_window", 6.0 / 12.0)
# Convert datetime date interval to floating point interval from
# earliest to latest.
self.time_interval = (
time_interval[1].year + (time_interval[1].month) / 12.0,
time_interval[0].year + (time_interval[0].month - 1) / 12.0
)
if isinstance(predictor_input, dict):
predictor_names = predictor_input.keys()
self.estimate_coefficients = False
@@ -164,6 +156,7 @@ def __init__(self, tree, frequencies, time_interval, predictor_input, censor_fre
self.pivots = self.frequencies.pivots
# final timepoint is end of interval and is only projected forward, not tested
self.time_interval = (self.frequencies.start_date, self.frequencies.end_date)
self.timepoint_step_size = 0.5 # amount of time between timepoints chosen for fitting
self.delta_time = 1.0 # amount of time projected forward to do fitting
self.timepoints = np.around(
View
@@ -699,7 +699,7 @@ def calculate_pivots(cls, pivot_frequency, tree=None, start_date=None, end_date=
pivots = np.arange(
pivot_start,
pivot_end,
pivot_end + 0.0001,
pivot_frequency
)
View
@@ -646,8 +646,7 @@ def annotate_fitness(self):
kwargs = {
"tree": self.tree.tree,
"frequencies": self.kde_frequencies,
"time_interval": self.info["time_interval"]
"frequencies": self.kde_frequencies
}
if "predictors" in self.config:
@@ -60,7 +60,7 @@ def test_calculate_pivots_from_start_and_end_date(self):
assert isinstance(pivots, np.ndarray)
assert pivots[1] - pivots[0] == pivot_frequency
assert pivots[0] == start_date
assert pivots[-1] != end_date
assert pivots[-1] == end_date
assert pivots[-1] >= end_date - pivot_frequency
def test_estimate(self, tree):

0 comments on commit 86a8969

Please sign in to comment.