Skip to content

Commit

Permalink
Ignore NaNs in FeatureFeatureCorrelation (#1718)
Browse files Browse the repository at this point in the history
* NaNs encoded as -1 are replaced back with NaN

* NaNs dropped at FeatureFeatureCorrelation
  • Loading branch information
TheSolY committed Jul 5, 2022
1 parent a2d8231 commit 54ca182
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from typing import List, Union

import numpy as np
import pandas as pd
import plotly.express as px

Expand Down Expand Up @@ -81,6 +82,8 @@ def run_logic(self, context: Context, dataset_kind) -> CheckResult:
num_features = [f for f in dataset.numerical_features if f in df.columns]
cat_features = [f for f in dataset.cat_features if f in df.columns]
encoded_cat_data = df.loc[:, cat_features].apply(lambda x: pd.factorize(x)[0])
# NaNs are encoded as -1, replace back to NaN
encoded_cat_data.replace(-1, np.NaN, inplace=True)

all_features = num_features + cat_features
full_df = pd.DataFrame(index=all_features, columns=all_features)
Expand Down
3 changes: 2 additions & 1 deletion deepchecks/utils/correlation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def correlation_ratio(categorical_data: Union[List, np.ndarray, pd.Series],
if ignore_mask:
numerical_data = numerical_data[~np.asarray(ignore_mask)]
categorical_data = categorical_data[~np.asarray(ignore_mask)]
cat_num = np.max(categorical_data) + 1

cat_num = int(np.max(categorical_data) + 1)
y_avg_array = np.zeros(cat_num)
n_array = np.zeros(cat_num)
for i in range(cat_num):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def test_feature_feature_correlation_pass_condition(adult_no_split):
def test_feature_feature_correlation_fail_condition(adult_no_split):
threshold = 0.5
num_pairs = 3
high_pairs = [('age', 'marital-status'), ('education-num', 'occupation'), ('edu'
''
'cation', 'education-num'),
high_pairs = [('age', 'marital-status'), ('education-num', 'occupation'), ('education', 'education-num'),
('marital-status', 'relationship')]
check = FeatureFeatureCorrelation()
result = check.add_condition_max_number_of_pairs_above_threshold(threshold, num_pairs).run(adult_no_split)
Expand Down

0 comments on commit 54ca182

Please sign in to comment.