Skip to content

Commit

Permalink
Make sure X and y are sorted equally in select_features (#715)
Browse files Browse the repository at this point in the history
* Make sure X and y are sorted equally in select_features and add a regression test against it

* fix style

* Changelog

* Move the index sorting one step further down
  • Loading branch information
nils-braun committed Jun 20, 2020
1 parent f1fcdd2 commit 7656f51
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Unreleased
- Reworked the notebooks (#701, #704)
- Speed up the result pivoting (#705)
- Bugfixes:
- Fixed a bug in the feature selection that caused wrong results for all regression tasks with an unordered index (#715)
- Fixed readthedocs (#695, #696)
- Fix spark and dask after #705 and for non-id named id columns (#712)

Expand Down
13 changes: 13 additions & 0 deletions tests/integrations/examples/test_driftbif_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from tsfresh.examples.driftbif_simulation import velocity, load_driftbif, sample_tau
from tsfresh import extract_relevant_features


class DriftBifSimlationTestCase(unittest.TestCase):
Expand Down Expand Up @@ -63,6 +64,18 @@ def test_dimensionality(self):
self.assertEqual(v.shape, (Nt, 3),
'The returned vector should reflect the dimension of the initial condition.')

def test_relevant_feature_extraction(self):
    """Check that relevant-feature extraction keeps more than 10 features.

    Loads the drift-bifurcation data with ``classification=False`` and a
    fixed seed, converts the ids of both the time series and the target
    to strings, and runs ``extract_relevant_features`` on the result.
    """
    timeseries, target = load_driftbif(100, 10, classification=False, seed=42)

    # NOTE(review): presumably string ids are used so that the sorted index
    # order differs from the numeric order -- confirm against issue #715.
    timeseries['id'] = timeseries['id'].astype('str')
    target.index = target.index.astype('str')

    column_roles = dict(column_id="id", column_sort="time",
                        column_kind="dimension", column_value="value")
    relevant_features = extract_relevant_features(timeseries, target, **column_roles)

    n_selected = len(relevant_features.columns)
    self.assertGreater(n_selected, 10)


class SampleTauTestCase(unittest.TestCase):
def test_range(self):
Expand Down
10 changes: 8 additions & 2 deletions tests/units/feature_selection/test_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,14 @@ def test_target_real_calls_correct_tests(self, significance_test_feature_binary_

assert 0.5 == relevance_table.loc['feature_binary'].p_value
assert 0.7 == relevance_table.loc['feature_real'].p_value
significance_test_feature_binary_mock.assert_called_once_with(X['feature_binary'], y=y_real)
significance_test_feature_real_mock.assert_called_once_with(X['feature_real'], y=y_real)

assert significance_test_feature_binary_mock.call_count == 1
pd.testing.assert_series_equal(significance_test_feature_binary_mock.call_args[0][0], X["feature_binary"])
pd.testing.assert_series_equal(significance_test_feature_binary_mock.call_args[1]["y"], y_real)

assert significance_test_feature_real_mock.call_count == 1
pd.testing.assert_series_equal(significance_test_feature_real_mock.call_args[0][0], X["feature_real"])
pd.testing.assert_series_equal(significance_test_feature_real_mock.call_args[1]["y"], y_real)

@mock.patch('tsfresh.feature_selection.relevance.target_real_feature_real_test')
@mock.patch('tsfresh.feature_selection.relevance.target_real_feature_binary_test')
Expand Down
7 changes: 7 additions & 0 deletions tsfresh/feature_selection/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def calculate_relevance_table(X, y, ml_task='auto', n_jobs=defaults.N_PROCESSES,
not relevant] for this feature)
:rtype: pandas.DataFrame
"""

# Make sure X and y both have the exact same indices
y = y.sort_index()
X = X.sort_index()

assert list(y.index) == list(X.index), "The index of X and y need to be the same"

if ml_task not in ['auto', 'classification', 'regression']:
raise ValueError('ml_task must be one of: \'auto\', \'classification\', \'regression\'')
elif ml_task == 'auto':
Expand Down
2 changes: 2 additions & 0 deletions tsfresh/feature_selection/significance_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def __check_if_pandas_series(x, y):
raise TypeError("x should be a pandas Series")
if not isinstance(y, pd.Series):
raise TypeError("y should be a pandas Series")
if not list(y.index) == list(x.index):
raise ValueError("X and y need to have the same index!")


def __check_for_binary_target(y):
Expand Down

0 comments on commit 7656f51

Please sign in to comment.