Skip to content

Commit

Permalink
fix indexing error when creating a filtered error analysis tree view …
Browse files Browse the repository at this point in the history
…with a dataset that contains categoricals
  • Loading branch information
imatiach-msft committed Apr 5, 2023
1 parent dd7d6d5 commit ea4f305
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def get_surrogate_booster_local(filtered_df, analyzer, is_model_analyzer,
else:
string_indexed_data = analyzer.string_indexed_data
for idx, c_i in enumerate(analyzer.categorical_indexes):
input_data[:, c_i] = string_indexed_data[row_index, idx]
input_data[:, c_i] = string_indexed_data[:, idx]
dataset_sub_features = input_data[:, indexes]

categorical_info = get_categorical_info(analyzer,
Expand Down
14 changes: 13 additions & 1 deletion erroranalysis/tests/test_surrogate_error_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,24 @@ def test_surrogate_error_tree_iris(self, analyzer_type):
def test_surrogate_error_tree_int_categorical(self, analyzer_type):
X_train, X_test, y_train, y_test, categorical_features = \
create_adult_census_data()

model = create_kneighbors_classifier(X_train, y_train)

run_error_analyzer(model, X_test, y_test, list(X_train.columns),
analyzer_type, categorical_features)

@pytest.mark.parametrize('analyzer_type', [AnalyzerType.MODEL,
AnalyzerType.PREDICTIONS])
def test_surrogate_error_tree_categorical_filtered(self, analyzer_type):
X_train, X_test, y_train, y_test, categorical_features = \
create_adult_census_data()
model = create_kneighbors_classifier(X_train, y_train)
filters = [{'arg': [40],
'column': 'Age',
'method': 'less and equal'}]
run_error_analyzer(model, X_test, y_test, list(X_train.columns),
analyzer_type, categorical_features,
filters=filters)

def test_large_data_surrogate_error_tree(self):
# validate tree trains quickly for large data
X_train, y_train, X_test, y_test, _ = \
Expand Down

0 comments on commit ea4f305

Please sign in to comment.