perf: Optimize PredictionsModelWrapper lookup from O(n*m) to O(1) #185
Merged
imatiach-msft merged 2 commits into main · Mar 12, 2026
Conversation
Problem: The PredictionsModelWrapper's predict() and predict_proba() methods used row-by-row DataFrame filtering with O(n*m) complexity per row, where n is the dataset size and m is the number of features. This caused severe performance issues when used with RAI insights on large datasets. For example, with 93K training rows and 34 features, predict() processed only ~35-40 rows/sec, causing RAI dashboard generation to time out after 2+ hours when the MimicExplainer needed to call predict() on all training data for surrogate model training.

Solution:
- Build a hash-based index at wrapper initialization (O(N) one-time cost)
- Use hash lookup for each query row (O(1) average case)
- Handle hash collisions by comparing actual values
- Fall back to the original filtering for edge cases where hash lookup fails

The optimization applies to both predict() in the base class and predict_proba() in the classification wrapper. Expected speedup: ~10,000x for large datasets (from 35-40 rows/sec to ~400,000+ rows/sec).

Testing:
- All existing tests pass
- Added performance tests verifying the speedup
- Added tests for hash collision handling
- Added tests for NaN values, mixed dtypes, and index mismatches

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
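The scheme described above can be sketched as follows. `HashLookupSketch`, its attributes, and the error-raising miss path are illustrative stand-ins, not the wrapper's actual names or behavior (the real wrapper falls back to its original filtering instead of raising):

```python
import numpy as np
import pandas as pd


class HashLookupSketch:
    """Illustrative wrapper: maps known feature rows to stored predictions."""

    def __init__(self, examples: pd.DataFrame, predictions: np.ndarray):
        self._examples = examples.reset_index(drop=True)
        self._predictions = np.asarray(predictions)
        # O(N) one-time cost: hash of row values -> list of row positions
        # (a list, because distinct rows can collide on the hash).
        self._index: dict = {}
        for pos, row in enumerate(self._examples.itertuples(index=False)):
            self._index.setdefault(hash(tuple(row)), []).append(pos)

    def predict(self, queries: pd.DataFrame) -> np.ndarray:
        out = []
        for row in queries.itertuples(index=False):
            values = tuple(row)
            # O(1) average-case bucket lookup per query row.
            for pos in self._index.get(hash(values), []):
                # Collision guard: confirm the actual values match.
                if tuple(self._examples.iloc[pos]) == values:
                    out.append(self._predictions[pos])
                    break
            else:
                # Stand-in for the wrapper's fallback filtering path.
                raise KeyError("row not found")
        return np.asarray(out)
```

Building the index costs O(N) once; each query row then averages O(1), with the value comparison preserving correctness when distinct rows happen to share a hash.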
close/reopen for CI
Codecov Report ❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff            @@
##             main     #185     +/-   ##
=========================================
- Coverage   33.11%   32.05%   -1.07%
=========================================
  Files          26       26
  Lines        1869     1931     +62
=========================================
  Hits          619      619
- Misses       1250     1312     +62
Problem
The `PredictionsModelWrapper` methods `predict()` and `predict_proba()` use row-by-row DataFrame filtering with O(n×m) complexity per row, where n is the dataset size and m is the number of features. This causes severe performance issues when used with RAI insights on large datasets.
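For illustration, the row-by-row filtering pattern described above can be sketched as follows (a hypothetical stand-in, not the wrapper's actual code):

```python
import pandas as pd


def filtered_lookup(data: pd.DataFrame, query_row: pd.Series) -> pd.Index:
    # Each equality filter scans the remaining rows, so a single
    # query row costs on the order of n * m comparisons.
    filtered = data
    for feature in data.columns:
        filtered = filtered[filtered[feature] == query_row[feature]]
    return filtered.index
```

Called once per row of a 93K-row query set against a 98K-row dataset with 34 features, this pattern works out to roughly 93,000 × 34 × 98,000 ≈ 3 × 10^11 comparisons, which matches the ~300 billion figure in the root cause analysis.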
Root Cause Analysis
The bottleneck occurs in `_get_filtered_data()`, which iterates through all features and filters the dataset for each feature value. For each query row, this performs roughly n × m comparisons.

When `MimicExplainer._model_distill()` (in interpret-community) trains its surrogate model, it calls `model.predict()` on the entire training dataset (93K rows). With 93K rows × 34 features × a 98K-row dataset, that is ~300 billion comparisons.

Solution
This PR introduces hash-based O(1) lookups:
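A minimal sketch of the NaN normalization such a hash needs; the real `_compute_row_hash()` signature and details are assumptions here:

```python
import math


def compute_row_hash(values) -> int:
    # Map every NaN to one sentinel so rows containing NaN hash
    # consistently: since Python 3.10, hash(float('nan')) is based on
    # object identity, so raw NaNs would not reliably collide with
    # each other across separate float objects.
    normalized = tuple(
        "__nan__" if isinstance(v, float) and math.isnan(v) else v
        for v in values
    )
    return hash(normalized)
```

Because distinct rows can still share a hash, the index must map each hash to a list of row positions and confirm equality (with NaN treated as equal to NaN) before returning a prediction; that is the collision handling this PR adds.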
Key Changes
- `_compute_row_hash()` - Computes a consistent hash for row values, handling NaN and various dtypes
- `_build_row_hash_index()` - Builds the hash→index mapping at initialization
- `_lookup_by_hash()` - O(1) lookup with collision handling
- `_rows_equal()` - Value comparison that treats NaN as equal to NaN
- `predict()` in the base class now uses the hash lookup with a fallback
- `predict_proba()` in the classification wrapper now uses the hash lookup with a fallback
- `__getstate__`/`__setstate__` rebuild the hash index after pickle deserialization

Expected Performance
Testing
- `test_large_dataset_performance` - Verifies the speedup with 10K rows
- `test_hash_collision_handling` - Tests duplicate value scenarios
- `test_nan_values_with_hash_lookup` - NaN handling
- `test_mixed_dtypes_with_hash_lookup` - Mixed int/float/str/bool columns
- `test_reset_index_query_with_hash_lookup` - Index mismatch scenarios
- `test_duplicate_rows_with_hash_lookup` - Duplicate row handling
- `test_consistency_with_original_behavior` - Verifies identical output

Backwards Compatibility